A very simple example of using STDPΒΆ

../_images/simple_stdp_neuron_20170505-150331.png
# encoding: utf8
"""
A very simple example of using STDP.

A single postsynaptic neuron fires at a constant rate. We connect several
presynaptic neurons to it, each of which fires spikes with a fixed time
lag or time advance with respect to the postsynaptic neuron.
The weights of these connections are very small, so they will not
significantly affect the firing times of the post-synaptic neuron.
We plot the amount of potentiation or depression of each synapse as a
function of the time difference.


Usage: python simple_STDP.py [-h] [--plot-figure] [--debug DEBUG] simulator

positional arguments:
  simulator      neuron, nest, brian or another backend simulator

optional arguments:
  -h, --help     show this help message and exit
  --plot-figure  Plot the simulation results to a file
  --fit-curve    Calculate the best-fit curve to the weight-delta_t measurements
  --debug DEBUG  Print debugging information

"""

from math import exp
import numpy as np
import neo
from quantities import ms
from pyNN.utility import get_simulator, init_logging, normalized_filename
from pyNN.utility.plotting import DataTable
from pyNN.parameters import Sequence


# === Parameters ============================================================

firing_period = 100.0    # (ms) interval between spikes
cell_parameters = {
    "tau_m": 10.0,       # (ms)
    "v_thresh": -50.0,   # (mV)
    "v_reset": -60.0,    # (mV)
    "v_rest": -60.0,     # (mV)
    "cm": 1.0,           # (nF)
    "tau_refrac": firing_period / 2,  # (ms) long refractory period to prevent bursting
}
n = 60                   # number of synapses / number of presynaptic neurons
delta_t = 1.0            # (ms) time difference between the firing times of neighbouring neurons
t_stop = 10 * firing_period + n * delta_t
delay = 3.0              # (ms) synaptic time delay


# === Configure the simulator ===============================================

sim, options = get_simulator(("--plot-figure", "Plot the simulation results to a file", {"action": "store_true"}),
                             ("--fit-curve", "Calculate the best-fit curve to the weight-delta_t measurements",
                              {"action": "store_true"}),
                             ("--dendritic-delay-fraction",
                              "What fraction of the total transmission delay is due to dendritic propagation", {"default": 1}),
                             ("--debug", "Print debugging information"))

if options.debug:
    init_logging(None, debug=True)

sim.setup(timestep=0.01, min_delay=delay, max_delay=delay)


# === Build the network =====================================================

def build_spike_sequences(period, duration, n, delta_t):
    """
    Return a spike time generator for `n` neurons (spike sources), where
    all neurons fire with the same period, but neighbouring neurons have a relative
    firing time difference of `delta_t`.
    """
    def spike_time_gen(i):
        """Spike time generator. `i` should be an array of indices."""
        return [Sequence(np.arange(period + j * delta_t, duration, period)) for j in (i - n // 2)]
    return spike_time_gen


spike_sequence_generator = build_spike_sequences(firing_period, t_stop, n, delta_t)
# presynaptic population
p1 = sim.Population(n, sim.SpikeSourceArray(spike_times=spike_sequence_generator),
                    label="presynaptic")
# single postsynaptic neuron
p2 = sim.Population(1, sim.IF_cond_exp(**cell_parameters),
                    initial_values={"v": cell_parameters["v_reset"]}, label="postsynaptic")
# drive to the postsynaptic neuron, ensuring it fires at exact multiples of the firing period
p3 = sim.Population(1, sim.SpikeSourceArray(spike_times=np.arange(firing_period - delay, t_stop, firing_period)),
                    label="driver")

# we set the initial weights to be very small, to avoid perturbing the firing times of the
# postsynaptic neurons
stdp_model = sim.STDPMechanism(
    timing_dependence=sim.SpikePairRule(tau_plus=20.0, tau_minus=20.0,
                                        A_plus=0.01, A_minus=0.012),
    weight_dependence=sim.AdditiveWeightDependence(w_min=0, w_max=0.0000001),
    weight=0.00000005,
    delay=delay,
    dendritic_delay_fraction=float(options.dendritic_delay_fraction))
connections = sim.Projection(p1, p2, sim.AllToAllConnector(), stdp_model)

# the connection weight from the driver neuron is very strong, to ensure the
# postsynaptic neuron fires at the correct times
driver_connection = sim.Projection(p3, p2, sim.OneToOneConnector(),
                                   sim.StaticSynapse(weight=10.0, delay=delay))

# == Instrument the network =================================================

p1.record('spikes')
p2.record(['spikes', 'v'])


class WeightRecorder(object):
    """
    Recording of weights is not yet built in to PyNN, so therefore we need
    to construct a callback object, which reads the current weights from
    the projection at regular intervals.
    """

    def __init__(self, sampling_interval, projection):
        self.interval = sampling_interval
        self.projection = projection
        self._weights = []

    def __call__(self, t):
        self._weights.append(self.projection.get('weight', format='list', with_address=False))
        return t + self.interval

    def get_weights(self):
        signal = neo.AnalogSignal(self._weights, units='nA', sampling_period=self.interval * ms,
                                  name="weight", array_annotations={"channel_index": np.arange(len(self._weights[0]))})
        return signal


weight_recorder = WeightRecorder(sampling_interval=1.0, projection=connections)

# === Run the simulation =====================================================

sim.run(t_stop, callbacks=[weight_recorder])


# === Save the results, optionally plot a figure =============================

filename = normalized_filename("Results", "simple_stdp", "pkl", options.simulator)
p2.write_data(filename, annotations={'script_name': __file__})

presynaptic_data = p1.get_data().segments[0]
postsynaptic_data = p2.get_data().segments[0]
print("Post-synaptic spike times: %s" % postsynaptic_data.spiketrains[0])

weights = weight_recorder.get_weights()
final_weights = np.array(weights[-1])
deltas = delta_t * np.arange(n // 2, -n // 2, -1)
print("Final weights: %s" % final_weights)
plasticity_data = DataTable(deltas, final_weights)


if options.fit_curve:
    def double_exponential(t, t0, w0, wp, wn, tau):
        return w0 + np.where(t >= t0, wp * np.exp(-(t - t0) / tau), wn * np.exp((t - t0) / tau))
    p0 = (-1.0, 5e-8, 1e-8, -1.2e-8, 20.0)
    popt, pcov = plasticity_data.fit_curve(double_exponential, p0, ftol=1e-10)
    print("Best fit parameters: t0={0}, w0={1}, wp={2}, wn={3}, tau={4}".format(*popt))


if options.plot_figure:
    from pyNN.utility.plotting import Figure, Panel, DataTable
    figure_filename = filename.replace("pkl", "png")
    Figure(
        # raster plot of the presynaptic neuron spike times
        Panel(presynaptic_data.spiketrains,
              yticks=True, markersize=0.2, xlim=(0, t_stop)),
        # membrane potential of the postsynaptic neuron
        Panel(postsynaptic_data.filter(name='v')[0],
              ylabel="Membrane potential (mV)",
              data_labels=[p2.label], yticks=True, xlim=(0, t_stop)),
        # evolution of the synaptic weights with time
        Panel(weights, xticks=True, yticks=True, xlabel="Time (ms)",
              legend=False, xlim=(0, t_stop)),
        # scatterplot of the final weight of each synapse against the relative
        # timing of pre- and postsynaptic spikes for that synapse
        Panel(plasticity_data,
              xticks=True, yticks=True, xlim=(-n / 2 * delta_t, n / 2 * delta_t),
              ylim=(0.9 * final_weights.min(), 1.1 * final_weights.max()),
              xlabel="t_post - t_pre (ms)", ylabel="Final weight (nA)",
              show_fit=options.fit_curve),
        title="Simple STDP example",
        annotations="Simulated with %s" % options.simulator.upper()
    ).save(figure_filename)
    print(figure_filename)


# === Clean up and quit ========================================================

sim.end()