Dear Aditya,
Thanks for writing in! In principle, both of your approaches should work: you should be
able to set the learning rate (``lambda``) to 0 on the STDP synapse, and the weights
should remain unchanged from that point on. Your approach with saving the weights from the
STDP synapses, and then reconstructing them using static synapses, should also work.
However, it is risky to loop over a set of connections while you're changing them, as
the connection IDs might change during the loop iterations. I would recommend instead to
do a ResetKernel() inbetween, and then connect the network from scratch. Then you also
don't have to work with the time offset (= nest.biological_time) any longer.
Hope this helps, please let us know how you fare!
With kind regards,
Charl
On Tue, Dec 10, 2024, at 01:11, Aditya Srivastava wrote:
I have built a nest model with a reverse stdp approach
in which with
time I'm trying to reduce the spikes in training data which is a normal
bearing metal data and then during testing I'll try to test it on a
anomalous data which should produce opposite results to that of
training data so now firstly I made it in brian and it worked fine but
here in nest after training the model when I have to test it then again
the synaptic weights are changing which should not happen during
testing. In brian I had put the learning rate as 0 thats why it was
working fine but in Nest I am not able to find a way to test a model by
freezing the weights. So I tried something which I'm not sure is the
correct approach so can anyone please review this approach or if
suggest a better approach for testing these kind of models.
Also I want to mention one thing that with my current approach the
weights are freezing but as I'm assigning input spikes from anomalous
signal and start testing, the kernel is getting died which usually
happens when there is a some issue in nest network.
nest.ResetKernel()
# Simulation parameters
dt = 0.01 # Simulation time step in ms
nest.SetKernelStatus({"resolution": dt})
N = 16 # Number of neurons
gmax = 3
neuron_params = {
"tau_m": 140.0, # Membrane time constant in ms
"E_L": -70.6, # Resting potential (in mV)
"V_th": -40.4, # Threshold potential (in mV)
"V_reset": -70.6, # Reset potential (in mV)
"V_m":0.0,
}
nest.CopyModel("stdp_synapse_hom", "stdp_synapse_exc",
{"weight":gmax,"alpha":1,"Wmax":gmax, "tau_plus":
20.0,"lambda":
-0.0005, "mu_minus": 0, "mu_plus": 0})
nest.CopyModel("stdp_synapse_hom", "stdp_synapse_inh",
{"weight":-gmax/(N-1),"alpha":1,"Wmax":-gmax,
"tau_plus":
20.0,"lambda": 0.005, "mu_minus": 0, "mu_plus": 0})
G_spike_gen = nest.Create("spike_generator", N)
G_input = nest.Create("parrot_neuron", N)
G_hidden = nest.Create("iaf_psc_delta", N,params=neuron_params)
G_output = nest.Create("iaf_psc_delta", 1,params=neuron_params)
S_in = nest.Connect(G_spike_gen, G_input, "one_to_one")
# Connect input to hidden layer with excitatory STDP synapses
S_exc = nest.Connect(G_input, G_hidden,"one_to_one",syn_spec =
"stdp_synapse_exc")
# Retrieve the neuron IDs for G_input and G_hidden
input_ids = G_input.tolist()
hidden_ids = G_hidden.tolist()
# Get the existing excitatory connections and extract their
source-target ID pairs
exc_connections = nest.GetConnections(G_input, G_hidden)
exc_connection_pairs = set(zip(exc_connections.get("source"),
exc_connections.get("target")))
for i in input_ids:
for h in hidden_ids:
if (i, h) not in exc_connection_pairs:
nc1 = nest.NodeCollection([i])
nc2 = nest.NodeCollection([h])
nest.Connect(nc1, nc2, syn_spec= "stdp_synapse_inh")
S_out = nest.Connect(G_hidden, G_output,"all_to_all",syn_spec =
{"weight": 40.4/N})
def train_net(signals, spike_tol=0):
output_spike_recorder = nest.Create("spike_recorder")
output_membrane_recorder = nest.Create("voltmeter", 1)
nest.Connect(G_output, output_spike_recorder) # Record output layer spikes
nest.Connect(output_membrane_recorder, G_output, "one_to_one")
while True:
recorder_for_epoch = nest.Create("spike_recorder")
nest.Connect(G_output, recorder_for_epoch)
for signal in signals:
indices, times = process_signal(signal) # Get the indices
and times
offset = nest.GetKernelStatus()["biological_time"]
for i in range(N):
times_for_neuron = [t+offset for j,t in
enumerate(times) if indices[j] == i]
G_spike_gen[i].set(spike_times = times_for_neuron)
nest.Simulate(len(signal)*(1/SAMPLE_RATE)*1000)
output_spikes_num =
len(recorder_for_epoch.get()["events"]["times"])
print("# of output spikes: " + str(output_spikes_num))
del recorder_for_epoch
if output_spikes_num <= spike_tol:
print("Training finished")
break
return output_membrane_recorder, output_spike_recorder
def synapse_freeze():
conn = nest.GetConnections()
sources = conn.get("source")
targets = conn.get("target")
weights = conn.get("weight")
# Store as a list of tuples
connections = list(zip(sources, targets, weights))
# Replace dynamic synapses with static synapses
for s,t,w in connections:
nc1 = nest.NodeCollection([s])
nc2 = nest.NodeCollection([t])
conn_t = nest.GetConnections(nc1,nc2)
model=nest.GetConnections(nc1,nc2).get("synapse_model")
if model=="stdp_synapse_exc":
nest.Disconnect(nc1,nc2,syn_spec =
{"synapse_model":"stdp_synapse_exc"})
nest.Connect(nc1,nc2,syn_spec =
{"synapse_model":"static_synapse","weight":w})
if model=="stdp_synapse_inh":
nest.Disconnect(nc1,nc2,syn_spec =
{"synapse_model":"stdp_synapse_inh"})
nest.Connect(nc1,nc2,syn_spec =
{"synapse_model":"static_synapse","weight":w})
print("Switched to static synapses with trained weights.")
def inference(signal):
output_spike_recorder_t = nest.Create("spike_recorder")
output_membrane_recorder_t = nest.Create("voltmeter", 1)
nest.Connect(G_output, output_spike_recorder_t) # Record output
layer spikes
nest.Connect(output_membrane_recorder_t, G_output, "one_to_one")
indices, times = process_signal(signal) # Prepare test signal
offset = nest.GetKernelStatus()["biological_time"]
for i in range(N):
times_for_neuron = [t+offset for j,t in enumerate(times) if
indices[j] == i]
G_spike_gen[i].set(spike_times = times_for_neuron)
nest.Simulate(len(signal)*(1/SAMPLE_RATE)*1000) # Run simulation
output_spikes_num =
len(output_spike_recorder_t.get()["events"]["times"])
print("# of output spikes: " + str(output_spikes_num))
return output_membrane_recorder_t, output_spike_recorder_t
_______________________________________________
NEST Users mailing list -- users(a)nest-simulator.org
To unsubscribe send an email to users-leave(a)nest-simulator.org