root/trunk/src/common.py @ 881

Revision 881, 69.6 KB (checked in by pierre, 2 years ago)

Fix a bug in the _print method of the Assembly object. The method now also works in compatible output mode. Nevertheless, since the id_to_index methods of PopulationView and Assembly are slow (because IDs are allowed to be unsorted), writing is slower with an Assembly than with simple Populations. Note that I don't think the method will work well in parallel.

  • Property svn:eol-style set to native
  • Property svn:keywords set to Revision Id
Line 
1# encoding: utf-8
2"""
3Defines a common implementation of the PyNN API.
4
5Simulator modules are not required to use any of the code herein, provided they
6provide the correct interface, but it is suggested that they use as much as is
7consistent with good performance (optimisations may require overriding some of
8the default definitions given here).
9
10Utility functions and classes:
11    is_conductance()
12    check_weight()
13    check_delay()
14
15Accessing individual neurons:
16    IDMixin
17
18Common API implementation/base classes:
19  1. Simulation set-up and control:
20    setup()
21    end()
22    run()
23    get_time_step()
24    get_current_time()
25    get_min_delay()
26    get_max_delay()
27    rank()
28    num_processes()
29
30  2. Creating, connecting and recording from individual neurons:
31    create()
32    connect()
33    set()
34    build_record()
35
36  3. Creating, connecting and recording from populations of neurons:
37    Population
38    Projection
39
40$Id$
41"""
42
43import numpy, os
44import logging
45from warnings import warn
46import operator
47import tempfile
48from pyNN import random, recording, errors, standardmodels, core, space, descriptions
49from pyNN.recording import files
50from itertools import chain
# Simulator-specific modules are expected to provide `simulator`; when it has
# not been defined in this namespace yet, fall back to None.
if 'simulator' not in locals():
    simulator = None

# Module-wide default values.
DEFAULT_WEIGHT = 0.0
DEFAULT_BUFFER_SIZE = 10000
DEFAULT_MAX_DELAY = 10.0
DEFAULT_TIMESTEP = 0.1
DEFAULT_MIN_DELAY = DEFAULT_TIMESTEP  # by default the minimum delay is one time step

logger = logging.getLogger("PyNN")
61
62# =============================================================================
63#   Utility functions and classes
64# =============================================================================
65
66
def is_conductance(target_cell):
    """
    Tell whether the synapses of `target_cell` are conductance-based.

    Returns the `conductance_based` flag of the cell's cell type (True for
    conductance-based, False for current-based synapses) when the cell has a
    truthy `local` attribute and a `celltype` attribute. Otherwise returns
    None, meaning the synapse basis cannot be determined here.
    """
    determinable = (hasattr(target_cell, 'local') and target_cell.local
                    and hasattr(target_cell, 'celltype'))
    if not determinable:
        return None
    return target_cell.celltype.conductance_based
78
79
def check_weight(weight, synapse_type, is_conductance):
    """
    Validate the sign of a synaptic weight (or of an array of weights).

    weight may be None (replaced by DEFAULT_WEIGHT), a single real number, or
    a list/array of numbers. NaN entries in an array are ignored. Weights must
    be non-negative for conductance-based and/or excitatory synapses, and
    non-positive for current-based inhibitory synapses. Zero satisfies both
    rules, for scalars as well as for arrays.

    Returns the weight; list-like weights are converted to a numpy array.

    Raises errors.InvalidWeightError if the weight is not numeric, if an array
    mixes positive and negative values, or if the sign is wrong for the
    synapse.
    """
    if weight is None:
        weight = DEFAULT_WEIGHT
    if core.is_listlike(weight):
        weight = numpy.array(weight)
        # weight arrays may contain NaN, which should be ignored
        filtered_weight = weight[~numpy.isnan(weight)]
        all_negative = (filtered_weight <= 0).all()
        all_positive = (filtered_weight >= 0).all()
        if not (all_negative or all_positive):
            raise errors.InvalidWeightError("Weights must be either all positive or all negative")
    elif numpy.isreal(weight):
        # Zero counts as both non-negative and non-positive, exactly as in the
        # array branch above. With a strict `< 0` test, the default weight of
        # 0.0 would be rejected for current-based inhibitory synapses.
        all_positive = weight >= 0
        all_negative = weight <= 0
    else:
        raise errors.InvalidWeightError("Weight must be a number or a list/array of numbers.")
    if is_conductance or synapse_type == 'excitatory':
        if not all_positive:
            raise errors.InvalidWeightError("Weights must be positive for conductance-based and/or excitatory synapses")
    elif is_conductance == False and synapse_type == 'inhibitory':
        if not all_negative:
            raise errors.InvalidWeightError("Weights must be negative for current-based, inhibitory synapses")
    else:
        # Either is_conductance is None (e.g. the cell does not exist on the
        # current node), or the cell is current-based and the synapse type is
        # neither 'excitatory' nor 'inhibitory': the sign cannot be checked.
        logger.debug("Can't check weight, conductance status unknown.")
    return weight
105
106
def check_delay(delay):
    """
    Validate a synaptic delay and return it.

    A delay of None is replaced by the current minimum delay.

    Raises errors.ConnectionError if the delay is smaller than get_min_delay()
    or larger than get_max_delay().
    """
    lower, upper = get_min_delay(), get_max_delay()
    if delay is None:
        delay = lower
    # If the delay is too small (or too large), we have to throw an error
    if delay < lower or delay > upper:
        raise errors.ConnectionError("delay (%s) is out of range [%s,%s]" % (delay, lower, upper))
    return delay
115
116
117# =============================================================================
118#   Accessing individual neurons
119# =============================================================================
120
class IDMixin(object):
    """
    Mixin that lets a cell id behave like a reference to the cell itself.

    Instead of storing ids as bare integers, we store them as ID objects,
    which allows a syntax like:
        p[3,4].tau_m = 20.0
    where p is a Population object. Attribute reads that fail normally, and
    attribute writes whose name is a parameter of the parent's cell type, are
    routed to get_parameters() / set_parameters(). The `parent` attribute must
    be set before any of this works.
    """
    # Simulator ID classes should inherit both from the base type of the ID
    # (e.g., int or long) and from IDMixin.

    def __getattr__(self, name):
        # Python only calls __getattr__ after ordinary lookup has failed.
        # Retry the ordinary lookup, then fall back to the cell's parameters.
        try:
            return self.__getattribute__(name)
        except AttributeError:
            if name == "parent":
                # Without this guard, the parameter lookup below (which needs
                # self.parent) would recurse back into __getattr__ forever.
                raise Exception("parent is not set")
            try:
                return self.get_parameters()[name]
            except KeyError:
                celltype = self.celltype
                raise errors.NonExistentParameterError(name,
                                                       celltype.__class__.__name__,
                                                       celltype.get_parameter_names())

    def __setattr__(self, name, value):
        # `parent` is tested first because self.celltype is read through it.
        if name != "parent" and self.celltype.has_parameter(name):
            self.set_parameters(**{name: value})
        else:
            object.__setattr__(self, name, value)

    def set_parameters(self, **parameters):
        """
        Set cell parameters given as name=value keyword arguments.

        For standard cell types the values are translated to native parameters
        first. If any given name is one of the cell type's computed
        parameters, all current parameters are fetched and merged in, so that
        the translation sees every parameter.

        Raises errors.NotLocalError if the cell does not exist on this node.
        """
        if not self.local:
            raise errors.NotLocalError("Cannot set parameters for a cell that does not exist on this node.")
        if self.is_standard_cell:
            computed = self.celltype.computed_parameters()
            touches_computed = numpy.any([p_name in computed for p_name in parameters])
            if touches_computed:
                merged = self.get_parameters()
                merged.update(parameters)
                parameters = merged
            parameters = self.celltype.translate(parameters)
        self.set_native_parameters(parameters)

    def get_parameters(self):
        """
        Return a dict of all cell parameters.

        Native parameters are reverse-translated for standard cell types.

        Raises errors.NotLocalError if the cell does not exist on this node.
        """
        if not self.local:
            raise errors.NotLocalError("Cannot obtain parameters for a cell that does not exist on this node.")
        parameters = self.get_native_parameters()
        if self.is_standard_cell:
            parameters = self.celltype.reverse_translate(parameters)
        return parameters

    @property
    def celltype(self):
        """The cell type of the parent population."""
        return self.parent.celltype

    @property
    def is_standard_cell(self):
        """True if the cell type is a standardmodels.StandardCellType."""
        return issubclass(self.celltype.__class__, standardmodels.StandardCellType)

    def _set_position(self, pos):
        """
        Set the cell position in 3D space.

        pos must be a tuple or numpy array of length 3; it is stored by the
        parent population through parent._set_cell_position().
        """
        assert isinstance(pos, (tuple, numpy.ndarray))
        assert len(pos) == 3
        self.parent._set_cell_position(self, pos)

    def _get_position(self):
        """
        Return the cell position in 3D space, as provided by
        parent._get_cell_position().
        """
        return self.parent._get_cell_position(self)

    position = property(_get_position, _set_position)

    @property
    def local(self):
        """Whether the cell exists on this node, according to the parent."""
        return self.parent.is_local(self)

    def inject(self, current_source):
        """Inject current from a current source object into the cell."""
        current_source.inject_into([self])

    def get_initial_value(self, variable):
        """Return the initial value of state variable `variable` of the cell."""
        return self.parent._get_cell_initial_value(self, variable)

    def set_initial_value(self, variable, value):
        """Set the initial value of state variable `variable` of the cell."""
        self.parent._set_cell_initial_value(self, variable, value)
228
229
230# =============================================================================
231#   Functions for simulation set-up and control
232# =============================================================================
233
234
def setup(timestep=DEFAULT_TIMESTEP, min_delay=DEFAULT_MIN_DELAY,
          max_delay=DEFAULT_MAX_DELAY, **extra_params):
    """
    Should be called at the very beginning of a script.

    extra_params contains any keyword arguments that are required by a given
    simulator but not by others. This common implementation only validates
    the arguments and returns None.

    Raises Exception if extra_params contains one of the obsolete names
    'mindelay', 'maxdelay' or 'dt', if min_delay > max_delay, or if
    min_delay < timestep (min_delay equal to timestep is allowed, and is
    the default).
    """
    invalid_extra_params = ('mindelay', 'maxdelay', 'dt')
    for param in invalid_extra_params:
        if param in extra_params:
            raise Exception("%s is not a valid argument for setup()" % param)
    if min_delay > max_delay:
        raise Exception("min_delay has to be less than or equal to max_delay.")
    if min_delay < timestep:
        # Equality is accepted, so the message must say "or equal to".
        raise Exception("min_delay (%g) must be greater than or equal to timestep (%g)" % (min_delay, timestep))
250
def end(compatible_output=True):
    """
    Do any necessary cleaning up before exiting.

    This generic version is a placeholder: it always raises
    NotImplementedError.
    """
    raise NotImplementedError()


def run(simtime):
    """
    Run the simulation for simtime ms.

    This generic version is a placeholder: it always raises
    NotImplementedError.
    """
    raise NotImplementedError()
258
def reset():
    """
    Reset the time to zero, neuron membrane potentials and synaptic weights to
    their initial values, and delete any recorded data. The network structure
    is not changed, nor is the specification of which neurons to record from.

    The work is delegated to simulator.reset().
    """
    simulator.reset()


def initialize(cells, variable, value):
    """
    Set the initial value of state variable `variable` for `cells`.

    cells must be a BasePopulation or an Assembly; the call is forwarded to
    cells.initialize(variable, value).
    """
    accepted_types = (BasePopulation, Assembly)
    assert isinstance(cells, accepted_types), type(cells)
    cells.initialize(variable, value)
270
def get_current_time():
    """Return the current time in the simulation (simulator.state.t)."""
    state = simulator.state
    return state.t


def get_time_step():
    """Return the integration time step (simulator.state.dt)."""
    state = simulator.state
    return state.dt


def get_min_delay():
    """Return the minimum allowed synaptic delay (simulator.state.min_delay)."""
    state = simulator.state
    return state.min_delay


def get_max_delay():
    """Return the maximum allowed synaptic delay (simulator.state.max_delay)."""
    state = simulator.state
    return state.max_delay


def num_processes():
    """Return the number of MPI processes (simulator.state.num_processes)."""
    state = simulator.state
    return state.num_processes


def rank():
    """Return the MPI rank of the current node (simulator.state.mpi_rank)."""
    state = simulator.state
    return state.mpi_rank
294
295# =============================================================================
296#  Low-level API for creating, connecting and recording from individual neurons
297# =============================================================================
298
def build_create(population_class):
    """
    Return a create() function that builds instances of `population_class`.
    """
    def create(cellclass, cellparams=None, n=1):
        """
        Create n cells all of the same type.

        Returns population_class(n, cellclass, cellparams): a single object
        holding all the new cells, whatever the value of n (including n == 1).
        """
        # NOTE: open design question: return the Population or
        # Population.all_cells?
        return population_class(n, cellclass, cellparams)
    return create
309
def build_connect(projection_class, connector_class):
    """
    Return a connect() function that creates `projection_class` instances
    using a `connector_class` connector.
    """
    def _single_cell_view(cell):
        # Wrap one cell in a one-element view of its parent population, so
        # that only this cell (not its whole parent) takes part in the
        # connection. A list index makes the parent return a PopulationView.
        parent = cell.parent
        return parent[[parent.id_to_index(cell)]]

    def connect(source, target, weight=0.0, delay=None, synapse_type=None,
                p=1, rng=None):
        """
        Connect a source of spikes to a synaptic target.

        source and target can both be individual cells or lists of cells, in
        which case all possible connections are made with probability p, using
        either the random number generator supplied, or the default rng
        otherwise. Weights should be in nA or µS.

        Returns the Projection that was created.
        """
        if isinstance(source, IDMixin):
            source = _single_cell_view(source)
        if isinstance(target, IDMixin):
            target = _single_cell_view(target)
        connector = connector_class(p_connect=p, weights=weight, delays=delay)
        return projection_class(source, target, connector, target=synapse_type, rng=rng)
    return connect
328
def set(cells, param, val=None):
    """
    Set one or more parameters of a BasePopulation or an Assembly.

    param is either a dict mapping parameter names to values (val must then
    be omitted), or the name of a single parameter whose new value is val.
    The call is forwarded to cells.set(param, val).
    """
    supported = (BasePopulation, Assembly)
    assert isinstance(cells, supported)
    cells.set(param, val)
338
def build_record(variable, simulator):
    """
    Return a record(source, filename) function for `variable`.

    The returned function asks `source` (a BasePopulation or an Assembly) to
    record `variable` to `filename`, then appends the matching recorder(s) to
    simulator.recorder_list. Its docstring is replaced for 'v' and 'gsyn'.
    """
    def record(source, filename):
        """
        Record spikes to a file. source can be an individual cell or a list of
        cells.
        """
        # NOTE: the docstring above is the runtime __doc__ of record() for any
        # variable other than 'v' and 'gsyn'.
        # would actually like to be able to record to an array and choose later
        # whether to write to a file.
        assert isinstance(source, (BasePopulation, Assembly))
        source._record(variable, to_file=filename)
        # Registering recorders here is a bit hackish - better to add to
        # Population.__del__?
        populations = []
        if isinstance(source, BasePopulation):
            populations.append(source)
        if isinstance(source, Assembly):
            populations.extend(source.populations)
        for population in populations:
            simulator.recorder_list.append(population.recorders[variable])

    special_docs = {
        'v': """
            Record membrane potential to a file. source can be an individual cell or
            a list of cells.""",
        'gsyn': """
            Record synaptic conductances to a file. source can be an individual cell
            or a list of cells.""",
    }
    if variable in special_docs:
        record.__doc__ = special_docs[variable]
    return record
363
364
365# =============================================================================
366#   High-level API for creating, connecting and recording from populations of
367#   neurons.
368# =============================================================================
369
370class BasePopulation(object):
371    record_filter = None
372
373    def __getitem__(self, index):
374        """
375        Return a representation of the cell with the given index,
376        suitable for being passed to other methods that require a cell id.
377        Note that __getitem__ is called when using [] access, e.g.
378            p = Population(...)
379            p[2] is equivalent to p.__getitem__(2).
380        Also accepts slices, e.g.
381            p[3:6]
382        which returns an array of cells.
383        """
384        if isinstance(index, int):
385            return self.all_cells[index]
386        elif isinstance(index, (slice, list, numpy.ndarray)):
387            return PopulationView(self, index)
388        elif isinstance(index, tuple):
389            return PopulationView(self, list(index))
390        else:
391            raise TypeError("indices must be integers, slices, lists, arrays or tuples, not %s" % type(index).__name__)
392
393    def __len__(self):
394        """Return the total number of cells in the population (all nodes)."""
395        return self.size
396
397    def __iter__(self):
398        """Iterator over cell ids on the local node."""
399        return iter(self.local_cells)
400
401    def is_local(self, id):
402        assert id.parent is self
403        index = self.id_to_index(id)
404        return self._mask_local[index]
405
406    def all(self):
407        """Iterator over cell ids on all nodes."""
408        return iter(self.all_cells)
409
410    def __add__(self, other):
411        assert isinstance(other, BasePopulation)
412        return Assembly(self, other)
413
414    def _get_cell_position(self, id):
415        index = self.id_to_index(id)
416        return self.positions[:, index]
417
418    def _set_cell_position(self, id, pos):
419        index = self.id_to_index(id)
420        self.positions[:, index] = pos
421
422    def _get_cell_initial_value(self, id, variable):
423        assert isinstance(self.initial_values[variable], core.LazyArray)
424        index = self.id_to_index(id)
425        return self.initial_values[variable][index]
426
427    def _set_cell_initial_value(self, id, variable, value):
428        assert isinstance(self.initial_values[variable], core.LazyArray)
429        index = self.id_to_index(id)
430        self.initial_values[variable][index] = value
431
432    def nearest(self, position):
433        """Return the neuron closest to the specified position."""
434        # doesn't always work correctly if a position is equidistant between
435        # two neurons, i.e. 0.5 should be rounded up, but it isn't always.
436        # also doesn't take account of periodic boundary conditions
437        pos = numpy.array([position] * self.positions.shape[1]).transpose()
438        dist_arr = (self.positions - pos)**2
439        distances = dist_arr.sum(axis=0)
440        nearest = distances.argmin()
441        return self[nearest]
442
443    def sample(self, n, rng=None):
444        """
445        Randomly sample n cells from the Population, and return a PopulationView
446        object.
447        """
448        assert isinstance(n, int)
449        if not rng:
450            rng = random.NumpyRNG()
451        indices = rng.permutation(numpy.arange(len(self)))[0:n]
452        logger.debug("The %d cells recorded have indices %s" % (n, indices))
453        logger.debug("%s.sample(%s)", self.label, n)
454        return PopulationView(self, indices)
455
456    def get(self, parameter_name, gather=False):
457        """
458        Get the values of a parameter for every local cell in the population.
459        """
460        # if all the cells have the same value for this parameter, should
461        # we return just the number, rather than an array?
462       
463        if hasattr(self, "_get_array"):
464            values = self._get_array(parameter_name).tolist()
465        else:
466            values = [getattr(cell, parameter_name) for cell in self]  # list or array?
467       
468        if gather == True and num_processes() > 1:
469            all_values  = { rank(): values }
470            all_indices = { rank(): self.local_cells.tolist()}
471            all_values  = recording.gather_dict(all_values)
472            all_indices = recording.gather_dict(all_indices)
473            if rank() == 0:
474                values  = reduce(operator.add, all_values.values())
475                indices = reduce(operator.add, all_indices.values())
476            idx    = numpy.argsort(indices)
477            values = numpy.array(values)[idx]
478        return values
479
480    def set(self, param, val=None):
481        """
482        Set one or more parameters for every cell in the population. param
483        can be a dict, in which case val should not be supplied, or a string
484        giving the parameter name, in which case val is the parameter value.
485        val can be a numeric value, or list of such (e.g. for setting spike
486        times).
487        e.g. p.set("tau_m",20.0).
488             p.set({'tau_m':20,'v_rest':-65})
489        """
490        #"""
491        # -- Proposed change to arguments --
492        #Set one or more parameters for every cell in the population.
493        #
494        #Each value may be a single number or a list/array of numbers of the same
495        #size as the population. If the parameter itself takes lists/arrays as
496        #values (e.g. spike times), then the value provided may be either a
497        #single lists/1D array, a list of lists/1D arrays, or a 2D array.
498        #
499        #e.g. p.set(tau_m=20.0).
500        #     p.set(tau_m=20, v_rest=[-65.0, -65.3, ... , -67.2])
501        #"""
502        if isinstance(param, str):
503            param_dict = {param: val}
504        elif isinstance(param, dict):
505            param_dict = param
506        else:
507            raise errors.InvalidParameterValueError
508        for name, val in param_dict.items():
509            if name not in self.celltype.get_parameter_names():
510                raise errors.NonExistentParameterError(name, self.celltype, self.celltype.get_parameter_names())
511            if isinstance(val, (float, int)):
512                param_dict[name] = float(val)
513            elif isinstance(val, (list, numpy.ndarray)):
514                pass  # ought to check list/array only contains numeric types
515            else:
516                raise errors.InvalidParameterValueError
517        logger.debug("%s.set(%s)", self.label, param_dict)
518        if hasattr(self, "_set_array"):
519            self._set_array(**param_dict)
520        else:
521            for cell in self:
522                cell.set_parameters(**param_dict)
523
524    def tset(self, parametername, value_array):
525        """
526        'Topographic' set. Set the value of parametername to the values in
527        value_array, which must have the same dimensions as the Population.
528        """
529        #"""
530        # -- Proposed change to arguments --
531        #'Topographic' set. Each value in parameters should be a function that
532        #accepts arguments x,y,z and returns a single value.
533        #"""
534        if parametername not in self.celltype.get_parameter_names():
535            raise errors.NonExistentParameterError(parametername, self.celltype, self.celltype.get_parameter_names())
536        if (self.size,) == value_array.shape:  # the values are numbers or non-array objects
537            local_values = value_array[self._mask_local]
538            assert local_values.size == self.local_cells.size, "%d != %d" % (local_values.size, self.local_cells.size)
539        elif len(value_array.shape) == 2:  # the values are themselves 1D arrays
540            if value_array.shape[0] != self.size:
541                raise errors.InvalidDimensionsError("Population: %d, value_array first dimension: %s" % (self.size,
542                                                                                                         value_array.shape[0]))
543            local_values = value_array[self._mask_local]  # not sure this works
544        else:
545            raise errors.InvalidDimensionsError("Population: %d, value_array: %s" % (self.size,
546                                                                                     str(value_array.shape)))
547        assert local_values.shape[0] == self.local_cells.size, "%d != %d" % (local_values.size, self.local_cells.size)
548
549        try:
550            logger.debug("%s.tset('%s', array(shape=%s, min=%s, max=%s))",
551                         self.label, parametername, value_array.shape,
552                         value_array.min(), value_array.max())
553        except TypeError:  # min() and max() won't work for non-numeric values
554            logger.debug("%s.tset('%s', non_numeric_array(shape=%s))",
555                         self.label, parametername, value_array.shape)
556
557        # Set the values for each cell
558        if hasattr(self, "_set_array"):
559            self._set_array(**{parametername: local_values})
560        else:
561            for cell, val in zip(self, local_values):
562                setattr(cell, parametername, val)
563
564    def rset(self, parametername, rand_distr):
565        """
566        'Random' set. Set the value of parametername to a value taken from
567        rand_distr, which should be a RandomDistribution object.
568        """
569        # Note that we generate enough random numbers for all cells on all nodes
570        # but use only those relevant to this node. This ensures that the
571        # sequence of random numbers does not depend on the number of nodes,
572        # provided that the same rng with the same seed is used on each node.
573        logger.debug("%s.rset('%s', %s)", self.label, parametername, rand_distr)
574        if isinstance(rand_distr.rng, random.NativeRNG):
575            self._native_rset(parametername, rand_distr)
576        else:
577            rarr = rand_distr.next(n=self.all_cells.size, mask_local=False)
578            rarr = numpy.array(rarr)  # isn't rarr already an array?
579            assert rarr.size == self.size, "%s != %s" % (rarr.size, self.size)
580            self.tset(parametername, rarr)
581
582    def _call(self, methodname, arguments):
583        """
584        Call the method methodname(arguments) for every cell in the population.
585        e.g. p.call("set_background","0.1") if the cell class has a method
586        set_background().
587        """
588        raise NotImplementedError()
589
590    def _tcall(self, methodname, objarr):
591        """
592        `Topographic' call. Call the method methodname() for every cell in the
593        population. The argument to the method depends on the coordinates of
594        the cell. objarr is an array with the same dimensions as the
595        Population.
596        e.g. p.tcall("memb_init", vinitArray) calls
597        p.cell[i][j].memb_init(vInitArray[i][j]) for all i,j.
598        """
599        raise NotImplementedError()
600
601    def randomInit(self, rand_distr):
602        """
603        Set initial membrane potentials for all the cells in the population to
604        random values.
605        """
606        warn("The randomInit() method is deprecated, and will be removed in a future release. Use initialize('v', rand_distr) instead.")
607        self.initialize('v', rand_distr)
608
609    def initialize(self, variable, value):
610        """
611        Set initial values of state variables, e.g. the membrane potential.
612
613        `value` may either be a numeric value (all neurons set to the same
614                value) or a `RandomDistribution` object (each neuron gets a
615                different value)
616        """
617        if isinstance(value, random.RandomDistribution):
618            initial_value = value.next(n=self.all_cells.size, mask_local=self._mask_local)
619        else:
620            initial_value = value
621        self.initial_values[variable] = core.LazyArray(initial_value, shape=(self.size,))
622        if hasattr(self, "_set_initial_value_array"):
623            self._set_initial_value_array(variable, initial_value)
624        else:
625            if isinstance(value, random.RandomDistribution):
626                for cell, val in zip(self, initial_value):
627                    cell.set_initial_value(variable, val)
628            else:
629                for cell in self:  # only on local node
630                    cell.set_initial_value(variable, initial_value)
631
632    def can_record(self, variable):
633        """Determine whether `variable` can be recorded from this population."""
634        return (variable in self.celltype.recordable)
635
636    def _record(self, variable, to_file=True):
637        """
638        Private method called by record() and record_v().
639        """
640        if not self.can_record(variable):
641            raise errors.RecordingError(variable, self.celltype)       
642        logger.debug("%s.record('%s')", self.label, variable)
643        if self.record_filter is not None:
644            self.recorders[variable].record(self.record_filter)
645        else:
646            self.recorders[variable].record(self.all_cells)
647        if isinstance(to_file, basestring):
648            self.recorders[variable].file = to_file
649
650    def record(self, to_file=True):
651        """
652        Record spikes from all cells in the Population.
653        """
654        self._record('spikes', to_file)
655
656    def record_v(self, to_file=True):
657        """
658        Record the membrane potential for all cells in the Population.
659        """
660        self._record('v', to_file)
661
662    def record_gsyn(self, to_file=True):
663        """
664        Record synaptic conductances for all cells in the Population.
665        """
666        self._record('gsyn', to_file)
667
668    def printSpikes(self, file, gather=True, compatible_output=True):
669        """
670        Write spike times to file.
671
672        file should be either a filename or a PyNN File object.
673
674        If compatible_output is True, the format is "spiketime cell_id",
675        where cell_id is the index of the cell counting along rows and down
676        columns (and the extension of that for 3-D).
677        This allows easy plotting of a `raster' plot of spiketimes, with one
678        line for each cell.
679        The timestep, first id, last id, and number of data points per cell are
680        written in a header, indicated by a '#' at the beginning of the line.
681
682        If compatible_output is False, the raw format produced by the simulator
683        is used. This may be faster, since it avoids any post-processing of the
684        spike files.
685
686        For parallel simulators, if gather is True, all data will be gathered
687        to the master node and a single output file created there. Otherwise, a
688        file will be written on each node, containing only the cells simulated
689        on that node.
690        """
691        self.recorders['spikes'].write(file, gather, compatible_output, self.record_filter)
692
693    def getSpikes(self, gather=True, compatible_output=True):
694        """
695        Return a 2-column numpy array containing cell ids and spike times for
696        recorded cells.
697
698        Useful for small populations, for example for single neuron Monte-Carlo.
699        """
700        return self.recorders['spikes'].get(gather, compatible_output, self.record_filter)
701
702    def print_v(self, file, gather=True, compatible_output=True):
703        """
704        Write membrane potential traces to file.
705
706        file should be either a filename or a PyNN File object.
707
708        If compatible_output is True, the format is "v cell_id",
709        where cell_id is the index of the cell counting along rows and down
710        columns (and the extension of that for 3-D).
711        The timestep, first id, last id, and number of data points per cell are
712        written in a header, indicated by a '#' at the beginning of the line.
713
714        If compatible_output is False, the raw format produced by the simulator
715        is used. This may be faster, since it avoids any post-processing of the
716        voltage files.
717
718        For parallel simulators, if gather is True, all data will be gathered
719        to the master node and a single output file created there. Otherwise, a
720        file will be written on each node, containing only the cells simulated
721        on that node.
722        """
723        self.recorders['v'].write(file, gather, compatible_output, self.record_filter)
724
725    def get_v(self, gather=True, compatible_output=True):
726        """
727        Return a 2-column numpy array containing cell ids and Vm for
728        recorded cells.
729        """
730        return self.recorders['v'].get(gather, compatible_output, self.record_filter)
731
732    def print_gsyn(self, file, gather=True, compatible_output=True):
733        """
734        Write synaptic conductance traces to file.
735
736        file should be either a filename or a PyNN File object.
737
738        If compatible_output is True, the format is "t g cell_id",
739        where cell_id is the index of the cell counting along rows and down
740        columns (and the extension of that for 3-D).
741        The timestep, first id, last id, and number of data points per cell are
742        written in a header, indicated by a '#' at the beginning of the line.
743
744        If compatible_output is False, the raw format produced by the simulator
745        is used. This may be faster, since it avoids any post-processing of the
746        voltage files.
747        """
748        self.recorders['gsyn'].write(file, gather, compatible_output, self.record_filter)
749
750    def get_gsyn(self, gather=True, compatible_output=True):
751        """
752        Return a 3-column numpy array containing cell ids and synaptic
753        conductances for recorded cells.
754        """
755        return self.recorders['gsyn'].get(gather, compatible_output, self.record_filter)
756
757    def get_spike_counts(self, gather=True):
758        """
759        Returns the number of spikes for each neuron.
760        """
761        return self.recorders['spikes'].count(gather, self.record_filter)
762
763    def meanSpikeCount(self, gather=True):
764        """
765        Returns the mean number of spikes per neuron.
766        """
767        spike_counts = self.recorders['spikes'].count(gather, self.record_filter)
768        total_spikes = sum(spike_counts.values())
769        if rank() == 0 or not gather:  # should maybe use allgather, and get the numbers on all nodes
770            if len(spike_counts) > 0:
771                return float(total_spikes)/len(spike_counts)
772            else:
773                return numpy.nan
774        else:
775            return numpy.nan
776       
777    def inject(self, current_source):
778        """
779        Connect a current source to all cells in the Population.
780        """
781        if not self.celltype.injectable:
782            raise TypeError("Can't inject current into a spike source.")
783        current_source.inject_into(self)
784
785    def save_positions(self, file):
786        """
787        Save positions to file. The output format is id x y z
788        """
789        # first column should probably be indices, not ids. This would make it
790        # simulator independent.
791        if isinstance(file, basestring):
792            file = files.StandardTextFile(file, mode='w')
793        cells  = self.all_cells
794        result = numpy.empty((len(cells), 4))
795        result[:,0]   = cells
796        result[:,1:4] = self.positions.T
797        if rank() == 0:
798            file.write(result, {'population' : self.label})
799            file.close()
800
801
class Population(BasePopulation):
    """
    A group of neurons all of the same type.
    """
    nPop = 0  # number of Populations created so far; used for default labels

    def __init__(self, size, cellclass, cellparams=None, structure=None,
                 label=None):
        """
        Create a population of neurons all of the same type.

        size - number of cells in the Population. For backwards-compatibility,
               `size` may also be a tuple giving the dimensions of a grid,
               e.g. size=(10,10) is equivalent to size=100 with structure=Grid2D()
        cellclass should either be a standardized cell class (a class inheriting
        from common.standardmodels.StandardCellType) or a string giving the
        name of the simulator-specific model that makes up the population.
        cellparams should be a dict which is passed to the neuron model
          constructor
        structure should be a Structure instance.
        label is an optional name for the population.
        """
        if not isinstance(size, int):  # also allow a single integer, for a 1D population
            assert isinstance(size, tuple), "`size` must be an integer or a tuple of ints. You have supplied a %s" % type(size)
            # check the things inside are ints
            for e in size:
                assert isinstance(e, int), "`size` must be an integer or a tuple of ints. Element '%s' is not an int" % str(e)

            assert structure is None, "If you specify `size` as a tuple you may not specify structure."
            if len(size) == 1:
                structure = space.Line()
            elif len(size) == 2:
                nx, ny = size
                structure = space.Grid2D(nx/float(ny))
            elif len(size) == 3:
                nx, ny, nz = size
                structure = space.Grid3D(nx/float(ny), nx/float(nz))
            else:
                raise Exception("A maximum of 3 dimensions is allowed. What do you think this is, string theory?")
            size = reduce(operator.mul, size)
        self.size = size
        self.label = label or 'population%d' % Population.nPop
        # NOTE(review): cellclass is called directly here, so a string model
        # name (see docstring) must be handled elsewhere - confirm.
        self.celltype = cellclass(cellparams)
        self._structure = structure or space.Line()
        self._positions = None
        # Build the arrays of cell ids
        # Cells on the local node are represented as ID objects, other cells by integers
        # All are stored in a single numpy array for easy lookup by address
        # The local cells are also stored in a list, for easy iteration
        self._create_cells(cellclass, cellparams, size)
        self.initial_values = {}
        for variable, value in self.celltype.default_initial_values.items():
            self.initialize(variable, value)
        self.recorders = {'spikes': self.recorder_class('spikes', population=self),
                          'v'     : self.recorder_class('v', population=self),
                          'gsyn'  : self.recorder_class('gsyn', population=self)}
        Population.nPop += 1

    @property
    def local_cells(self):
        """IDs of the cells simulated on this node (all_cells filtered by _mask_local)."""
        return self.all_cells[self._mask_local]

    @property
    def cell(self):
        """Deprecated alias of `all_cells`; emits a warning on every access."""
        warn("The `Population.cell` attribute is not an official part of the \
              API, and its use is deprecated. It will be removed in a future \
              release. All uses of `cell` may be replaced by `all_cells`")
        return self.all_cells

    def id_to_index(self, id):
        """
        Given the ID(s) of cell(s) in the Population, return its (their) index
        (order in the Population).

        A single ID gives an int. An iterable of IDs (or a PopulationView,
        whose all_cells are used) gives an integer array; an empty iterable
        gives an empty integer array. Raises ValueError for IDs outside
        [first_id, last_id].

        >>> assert p.id_to_index(p[5]) == 5
        >>> assert p.id_to_index(p.index([1,2,3])) == [1,2,3]
        """
        if not numpy.iterable(id):
            if not self.first_id <= id <= self.last_id:
                raise ValueError("id should be in the range [%d,%d], actually %d" % (self.first_id, self.last_id, id))
            return int(id - self.first_id)  # this assumes ids are consecutive
        else:
            if isinstance(id, PopulationView):
                id = id.all_cells
            id = numpy.array(id)
            if id.size == 0:
                # min()/max() below are undefined for an empty array
                return numpy.array([], dtype=int)
            if (self.first_id > id.min()) or (self.last_id < id.max()):
                raise ValueError("ids should be in the range [%d,%d], actually [%d, %d]" % (self.first_id, self.last_id, id.min(), id.max()))
            return (id - self.first_id).astype(int)  # this assumes ids are consecutive

    def id_to_local_index(self, id):
        """Return the position of `id` among the cells simulated on this node."""
        if num_processes() > 1:
            return self.local_cells.tolist().index(id)  # probably very slow
        else:
            return self.id_to_index(id)

    def _get_structure(self):
        return self._structure

    def _set_structure(self, structure):
        assert isinstance(structure, space.BaseStructure)
        if structure != self._structure:
            self._positions = None  # setting a new structure invalidates previously calculated positions
            self._structure = structure
    structure = property(fget=_get_structure, fset=_set_structure)
    # arguably structure should be read-only, i.e. it is not possible to change it after Population creation

    @property
    def position_generator(self):
        """Return a function mapping a cell index i to its (x, y, z) position."""
        def gen(i):
            return self.positions[:,i]
        return gen

    def _get_positions(self):
        """
        Try to return self._positions. If it does not exist, create it and then
        return it.
        """
        if self._positions is None:
            self._positions = self.structure.generate_positions(self.size)
        assert self._positions.shape == (3, self.size)
        return self._positions

    def _set_positions(self, pos_array):
        assert isinstance(pos_array, numpy.ndarray)
        assert pos_array.shape == (3, self.size), "%s != %s" % (pos_array.shape, (3, self.size))
        self._positions = pos_array.copy()  # take a copy in case pos_array is changed later
        self._structure = None  # explicitly setting positions destroys any previous structure

    positions = property(_get_positions, _set_positions,
                         """A 3xN array (where N is the number of neurons in the Population)
                         giving the x,y,z coordinates of all the neurons (soma, in the
                         case of non-point models).""")

    def describe(self, template='population_default.txt', engine='default'):
        """
        Returns a human-readable description of the population.

        The output may be customized by specifying a different template
        together with an associated template engine (see ``pyNN.descriptions``).

        If template is None, then a dictionary containing the template context
        will be returned.
        """
        context = {
            "label": self.label,
            "celltype": self.celltype.describe(template=None),
            "structure": None,
            "size": self.size,
            "size_local": len(self.local_cells),
            "first_id": self.first_id,
            "last_id": self.last_id,
        }
        if len(self.local_cells) > 0:
            first_id = self.local_cells[0]
            context.update({
                "local_first_id": first_id,
                "cell_parameters": first_id.get_parameters(),
            })
        if self.structure:
            context["structure"] = self.structure.describe(template=None)
        return descriptions.render(engine, template, context)
963
class PopulationView(BasePopulation):
    """
    A view of a subset of the cells of a parent population, selected by a
    slice, a list/array of indices or a boolean mask. The view shares its
    parent's cell type and recorders.
    """

    def __init__(self, parent, selector, label=None):
        """
        parent   - the population being viewed.
        selector - a slice, a list/array of integer indices into `parent`, or
                   a boolean array with one entry per cell of `parent`.
                   Duplicated indices are removed (with a logged warning),
                   which also sorts them.
        label    - optional name; by default one is built from the parent's
                   label and the selector.
        """
        self.parent = parent
        self.mask = selector # later we can have fancier selectors, for now we just have numpy masks
        self.label  = label or "view of %s with mask %s" % (parent.label, self.mask)
        # maybe just redefine __getattr__ instead of the following...
        self.celltype     = self.parent.celltype
        # If the mask is a slice, IDs will be consecutive without duplication.
        # If not, then we need to remove duplicated IDs
        if not isinstance(self.mask, slice):
            if isinstance(self.mask, list):
                self.mask = numpy.array(self.mask)
            if self.mask.dtype == numpy.dtype('bool'):
                if len(self.mask) != len(self.parent):
                    raise Exception("Boolean masks should have the size of Parent Population")
                # turn the boolean mask into the equivalent integer indices
                self.mask = numpy.arange(len(self.parent))[self.mask]
            if len(numpy.unique(self.mask)) != len(self.mask):
                logging.warning("PopulationView can contain only once each ID, duplicated IDs are remove")
                self.mask = numpy.unique(self.mask)  # numpy.unique also sorts the indices
        self.all_cells    = self.parent.all_cells[self.mask]  # do we need to ensure this is ordered?
        self.size         = len(self.all_cells)
        self._mask_local  = self.parent._mask_local[self.mask]
        self.local_cells  = self.all_cells[self._mask_local]
        # numpy.min/max do not require all_cells to be sorted
        self.first_id     = numpy.min(self.all_cells)
        self.last_id      = numpy.max(self.all_cells)
        self.recorders    = self.parent.recorders
        self.record_filter= self.all_cells

    @property
    def initial_values(self):
        # this is going to be complex - if we keep initial_values as a dict,
        # need to return a dict-like object that takes account of self.mask
        raise NotImplementedError

    @property
    def structure(self):
        """The parent's structure (views have no structure of their own)."""
        return self.parent.structure
    # should we allow setting structure for a PopulationView? Maybe if the
    # parent has some kind of CompositeStructure?

    @property
    def positions(self):
        """
        A 3xM array with the x,y,z coordinates of the M cells in the view,
        taken from the corresponding columns of the parent's 3xN positions.
        """
        # make positions N,3 instead of 3,N to avoid all this transposing?
        return self.parent.positions[:, self.mask]

    def id_to_index(self, id):
        """
        Given the ID(s) of cell(s) in the PopulationView, return its/their
        index/indices (order in the PopulationView).

        A single ID gives a one-element integer array; an iterable of IDs
        gives an integer array with one index per ID. Raises IndexError for
        an ID not in the view and Exception for an ID that occurs more than
        once.

        >>> assert id_to_index(p.index(5)) == 5
        >>> assert id_to_index(p.index([1,2,3])) == [1,2,3]
        """
        if not numpy.iterable(id):
            result = numpy.where(self.all_cells == id)[0]
            if len(result) == 0:
                raise IndexError("ID %s not present in the View" %id)
            elif len(result) > 1:
                raise Exception("ID %s is duplicated in the View" %id)
            return result
        # Map every ID to its position(s) once, so converting m IDs costs
        # O(n + m) rather than one O(n) numpy.where() scan per ID.
        positions = {}
        for index, cell in enumerate(self.all_cells):
            positions.setdefault(cell, []).append(index)
        result = []
        for item in id:
            matches = positions.get(item, ())
            if len(matches) == 0:
                raise IndexError("ID %s not present in the View" %item)
            elif len(matches) > 1:
                raise Exception("ID %s is duplicated in the View" %item)
            result.append(matches[0])
        return numpy.array(result, dtype=int)

    def describe(self, template='populationview_default.txt', engine='default'):
        """
        Returns a human-readable description of the population view.

        The output may be customized by specifying a different template
        together with an associated template engine (see ``pyNN.descriptions``).

        If template is None, then a dictionary containing the template context
        will be returned.
        """
        context = {"label": self.label,
                   "parent": self.parent.label,
                   "mask": self.mask,
                   "size": self.size}
        return descriptions.render(engine, template, context)
1051
1052
1053# =============================================================================
1054
class Assembly(object):
    """
    A group of neurons, may be heterogeneous, in contrast to a Population where
    all the neurons are of the same type.

    An Assembly holds an ordered list of Population/PopulationView objects
    (`self.populations`). Insertion refuses, with a logged warning, elements
    whose cells are already part of the Assembly.
    """
    count = 0  # number of Assemblies created so far; used for default labels

    def __init__(self, *populations, **kwargs):
        """
        Create an Assembly from the given Population/PopulationView objects.

        The only accepted keyword argument is `label` (a string); by default a
        label of the form 'assemblyN' is generated.
        """
        if kwargs:
            assert list(kwargs) == ['label']
        self.populations = []
        for p in populations:
            self._insert(p)
        self.label = kwargs.get('label', 'assembly%d' % Assembly.count)
        assert isinstance(self.label, basestring), "label must be a string or unicode"
        Assembly.count += 1

    def _shares_cells_with(self, element):
        """Return True if any cell of `element` already belongs to a member population."""
        for p in self.populations:
            data = numpy.concatenate((p.all_cells, element.all_cells))
            if len(numpy.unique(data)) != len(p.all_cells) + len(element.all_cells):
                return True
        return False

    def _insert(self, element):
        """
        Append `element` unless that would duplicate cells already present.

        Raises TypeError if `element` is not a population; rejected elements
        are reported with logging.warning and not added.
        """
        if not isinstance(element, BasePopulation):
            raise TypeError("argument is a %s, not a Population." % type(element).__name__)
        if isinstance(element, PopulationView):
            if element.parent in self.populations:
                logging.warning('Adding a PopulationView to an Assembly when parent Population is there is not possible')
            elif self._shares_cells_with(element):
                # Should we automatically remove duplicated IDs ?
                logging.warning('Adding a PopulationView to an Assembly containing elements already present is not posible')
            else:
                self.populations.append(element)
        else:
            if element in self.populations:
                logging.warning('Adding a Population twice in an Assembly is not possible')
            elif any(isinstance(p, PopulationView) and p.parent is element
                     for p in self.populations):
                # a view of this Population is already a member, so adding the
                # Population itself would duplicate the view's IDs
                logging.warning('Adding a Population to an Assembly containing a PopulationView of it is not possible')
            else:
                self.populations.append(element)

    @property
    def local_cells(self):
        """IDs of the cells simulated on this node, in population order."""
        return numpy.concatenate([p.local_cells for p in self.populations])

    @property
    def all_cells(self):
        """IDs of all cells, concatenated in population order."""
        return numpy.concatenate([p.all_cells for p in self.populations])

    @property
    def _mask_local(self):
        """Concatenation of the members' _mask_local arrays, in population order."""
        return numpy.concatenate([p._mask_local for p in self.populations])

    @property
    def first_id(self):
        return numpy.min(self.all_cells)

    @property
    def last_id(self):
        return numpy.max(self.all_cells)

    def id_to_index(self, id):
        """
        Given the ID(s) of cell(s) in the Assembly, return its (their) index
        (position in self.all_cells).

        A single ID gives a one-element integer array; an iterable of IDs
        gives an integer array with one index per ID. Raises IndexError for an
        ID not in the Assembly and Exception for an ID that occurs more than
        once.
        """
        all_cells = self.all_cells
        if not numpy.iterable(id):
            result = numpy.where(all_cells == id)[0]
            if len(result) == 0:
                raise IndexError("ID %s not present in the Assembly" % id)
            elif len(result) > 1:
                raise Exception("ID %s is duplicated in the Assembly" % id)
            return result
        # Map every ID to its position(s) once, so converting m IDs costs
        # O(n + m) rather than one O(n) numpy.where() scan per ID.
        positions = {}
        for index, cell in enumerate(all_cells):
            positions.setdefault(cell, []).append(index)
        result = []
        for item in id:
            matches = positions.get(item, ())
            if len(matches) == 0:
                raise IndexError("ID %s not present in the Assembly" % item)
            elif len(matches) > 1:
                raise Exception("ID %s is duplicated in the Assembly" % item)
            result.append(matches[0])
        return numpy.array(result, dtype=int)

    @property
    def positions(self):
        """3xN array of cell positions, member populations side by side in order."""
        return numpy.hstack([p.positions for p in self.populations])

    @property
    def size(self):
        return sum(p.size for p in self.populations)

    def __iter__(self):
        """Iterate over each member population in turn, yielding what iterating it yields."""
        return chain(*[iter(p) for p in self.populations])

    def __len__(self):
        """Return the total number of cells in the population (all nodes)."""
        return self.size

    def __getitem__(self, index):
        """
        Return the member population at integer position `index`, or a new
        Assembly of the members selected by a slice, a list/array of integer
        positions, or a boolean array.
        """
        if isinstance(index, int):
            return self.populations[index]
        elif isinstance(index, slice):
            return Assembly(*self.populations[index])
        elif isinstance(index, (list, numpy.ndarray)):
            # plain lists cannot be indexed by a list, so select element-wise
            index = numpy.asarray(index)
            if index.dtype == numpy.dtype('bool'):
                index = numpy.arange(len(self.populations))[index]
            return Assembly(*[self.populations[i] for i in index])
        else:
            raise TypeError("indices must be integers, slices, lists, arrays, not %s" % type(index).__name__)

    def __add__(self, other):
        if isinstance(other, BasePopulation):
            return Assembly(*(self.populations + [other]))
        elif isinstance(other, Assembly):
            return Assembly(*(self.populations + other.populations))
        else:
            raise TypeError("can only add a Population or another Assembly to an Assembly")

    def __iadd__(self, other):
        if isinstance(other, BasePopulation):
            self._insert(other)
        elif isinstance(other, Assembly):
            for p in other.populations:
                self._insert(p)
        else:
            raise TypeError("can only add a Population or another Assembly to an Assembly")
        return self

    def initialize(self, variable, value):
        for p in self.populations:
            p.initialize(variable, value)

    def _record(self, variable, to_file=True):
        # need to think about record_from
        for p in self.populations:
            p._record(variable, to_file)

    def record(self, to_file=True):
        self._record('spikes', to_file)

    def record_v(self, to_file=True):
        self._record('v', to_file)

    def record_gsyn(self, to_file=True):
        self._record('gsyn', to_file)

    def get_population(self, label):
        for p in self.populations:
            if label == p.label:
                return p
        raise KeyError("Assembly does not contain a population with the label %s" % label)

    def save_positions(self, file):
        """
        Save positions to file, one row per cell in the format ``id x y z``.

        `file` may be a filename or a PyNN File object. Only the master node
        (rank 0) opens, writes and closes the output file.
        """
        cells = self.all_cells
        result = numpy.empty((len(cells), 4))
        result[:,0] = cells
        result[:,1:4] = self.positions.T
        if rank() == 0:
            # open a filename only on the master node, so that other nodes do
            # not create/truncate the output file
            if isinstance(file, basestring):
                file = files.StandardTextFile(file, mode='w')
            file.write(result, {'assembly' : self.label})
            file.close()

    @property
    def position_generator(self):
        """Return a function mapping a cell index i to its (x, y, z) position."""
        def gen(i):
            return self.positions[:,i]
        return gen

    def meanSpikeCount(self, gather=True):
        """
        Returns the mean number of spikes per neuron.

        With gather=True only the master node computes the mean; other nodes
        get numpy.nan. numpy.nan is also returned when no counts are available.
        """
        spike_counts = self.get_spike_counts(gather)
        total_spikes = sum(spike_counts.values())
        if rank() == 0 or not gather:  # should maybe use allgather, and get the numbers on all nodes
            if len(spike_counts) > 0:
                return float(total_spikes)/len(spike_counts)
            else:
                return numpy.nan
        else:
            return numpy.nan

    def _stack_recorded(self, variable, empty, gather, compatible_output):
        """
        Vertically stack the `variable` data of all members, in order.

        `empty` is the starting array when the first member has nothing
        recorded; later members with nothing recorded are skipped.
        """
        first = self.populations[0]
        try:
            result = first.recorders[variable].get(gather, compatible_output, first.record_filter)
        except errors.NothingToWriteError:
            result = empty
        for p in self.populations[1:]:
            try:
                result = numpy.vstack((result, p.recorders[variable].get(gather, compatible_output, p.record_filter)))
            except errors.NothingToWriteError:
                pass
        return result

    def get_v(self, gather=True, compatible_output=True):
        """
        Return a numpy array containing cell ids and Vm for recorded cells
        of all member populations.
        """
        return self._stack_recorded('v', numpy.zeros((0, 3)), gather, compatible_output)

    def get_gsyn(self, gather=True, compatible_output=True):
        """
        Return a numpy array containing cell ids and synaptic conductances
        for recorded cells of all member populations.
        """
        return self._stack_recorded('gsyn', numpy.zeros((0, 4)), gather, compatible_output)

    def get_spike_counts(self, gather=True):
        """
        Returns the number of spikes for each neuron, as a new dict merging
        the counts of all member populations that recorded spikes.
        """
        spike_counts = {}
        for p in self.populations:
            try:
                spike_counts.update(p.recorders['spikes'].count(gather, p.record_filter))
            except errors.NothingToWriteError:
                pass
        return spike_counts

    def _print(self, file, variable, format, gather=True, compatible_output=True):
        """
        Write the recorded `variable` of all member populations to one file.

        Each member's data is first written to a temporary NumPy binary file;
        on the master node these are read back, stacked in the order of
        self.populations and written to `file` with a metadata header.

        file     - a filename or a PyNN File object.
        variable - recorder name ('spikes', 'v' or 'gsyn').
        format   - shape of the empty starting array, e.g. (0, 2); its column
                   count must match the per-population data.
        """
        import shutil
        tempdir = tempfile.mkdtemp()
        try:
            ## First, write the data of each population to its own temporary file
            written = []  # (population, filename) for populations that had data
            for i, p in enumerate(self.populations):
                # name by position: labels need not be unique nor valid file names
                filename = os.path.join(tempdir, '%d.%s' % (i, variable))
                p_file = files.NumpyBinaryFile(filename, mode='w')
                try:
                    p.recorders[variable].write(p_file, gather, compatible_output, p.record_filter)
                except errors.NothingToWriteError:
                    p_file.close()
                else:
                    written.append((p, filename))

            ## Then merge them into a single file, to be consistent with a
            ## Population object. Note that the header should be better considered.
            metadata = {'variable'    : variable,
                        'size'        : self.size,
                        'label'       : self.label,
                        'populations' : ", ".join(["%s[%d-%d]" %(p.label, p.first_id, p.last_id) for p in self.populations]),
                        'first_id'    : self.first_id,
                        'last_id'     : self.last_id}
            metadata['dt'] = simulator.state.dt # note that this has to run on all nodes (at least for NEST)
            if rank() != 0:
                # only the master node writes the merged output
                return
            data = numpy.zeros(format)
            for pop, filename in written:
                p_file = files.NumpyBinaryFile(filename, mode='r')
                tmp_data = p_file.read()
                p_file.close()
                if compatible_output:
                    # NOTE(review): assumes the recorder wrote indices relative
                    # to pop.first_id; confirm for PopulationViews, whose
                    # recorders belong to the parent.
                    tmp_data[:, -1] = self.id_to_index(tmp_data[:,-1] + pop.first_id)
                data = numpy.vstack((data, tmp_data))
            metadata['n'] = data.shape[0]
        finally:
            shutil.rmtree(tempdir, ignore_errors=True)

        if isinstance(file, basestring):
            file = files.StandardTextFile(file, mode='w')
        file.write(data, metadata)
        file.close()

    def printSpikes(self, file, gather=True, compatible_output=True):
        """
        Write spike times to file.

        file should be either a filename or a PyNN File object.

        If compatible_output is True, the format is "spiketime cell_id",
        where cell_id is the index of the cell counting along rows and down
        columns (and the extension of that for 3-D).
        This allows easy plotting of a `raster' plot of spiketimes, with one
        line for each cell.
        The timestep, first id, last id, and number of data points per cell are
        written in a header, indicated by a '#' at the beginning of the line.

        If compatible_output is False, the raw format produced by the simulator
        is used. This may be faster, since it avoids any post-processing of the
        spike files.

        For parallel simulators, if gather is True, all data will be gathered
        to the master node and a single output file created there. Otherwise, a
        file will be written on each node, containing only the cells simulated
        on that node.
        """
        self._print(file, 'spikes', (0, 2), gather, compatible_output)

    def print_v(self, file, gather=True, compatible_output=True):
        """
        Write membrane potential traces to file.

        file should be either a filename or a PyNN File object.

        If compatible_output is True, the format is "v cell_id",
        where cell_id is the index of the cell counting along rows and down
        columns (and the extension of that for 3-D).
        The timestep, first id, last id, and number of data points per cell are
        written in a header, indicated by a '#' at the beginning of the line.

        If compatible_output is False, the raw format produced by the simulator
        is used. This may be faster, since it avoids any post-processing of the
        voltage files.

        For parallel simulators, if gather is True, all data will be gathered
        to the master node and a single output file created there. Otherwise, a
        file will be written on each node, containing only the cells simulated
        on that node.
        """
        self._print(file, 'v', (0, 2), gather, compatible_output)

    def print_gsyn(self, file, gather=True, compatible_output=True):
        """
        Write synaptic conductance traces to file.

        file should be either a filename or a PyNN File object.

        If compatible_output is True, the format is "t g cell_id",
        where cell_id is the index of the cell counting along rows and down
        columns (and the extension of that for 3-D).
        The timestep, first id, last id, and number of data points per cell are
        written in a header, indicated by a '#' at the beginning of the line.

        If compatible_output is False, the raw format produced by the simulator
        is used. This may be faster, since it avoids any post-processing of the
        conductance files.
        """
        self._print(file, 'gsyn', (0, 3), gather, compatible_output)

    def inject(self, current_source):
        """
        Connect a current source to all cells of every population in the Assembly.
        """
        for p in self.populations:
            current_source.inject_into(p)

    def describe(self, template='assembly_default.txt', engine='default'):
        """
        Returns a human-readable description of the assembly.

        The output may be customized by specifying a different template
        together with an associated template engine (see ``pyNN.descriptions``).

        If template is None, then a dictionary containing the template context
        will be returned.
        """
        context = {"label": self.label,
                   "populations": [p.describe(template=None) for p in self.populations]}
        return descriptions.render(engine, template, context)
1446
1447# =============================================================================
1448
1449
class Projection(object):
    """
    A container for all the connections of a given type (same synapse type and
    plasticity mechanisms) between two populations, together with methods to
    set parameters of those connections, including of plasticity mechanisms.

    NOTE(review): several attributes used below are not assigned in this
    class (`connection_manager`, `connections`, `synapse_type`); they are
    presumably provided by simulator-specific subclasses -- TODO confirm.
    """

    def __init__(self, presynaptic_neurons, postsynaptic_neurons, method,
                 source=None, target=None, synapse_dynamics=None,
                 label=None, rng=None):
        """
        presynaptic_neurons and postsynaptic_neurons - Population, PopulationView
                                                       or Assembly objects.

        source - string specifying which attribute of the presynaptic cell
                 signals action potentials. This is only needed for
                 multicompartmental cells with branching axons or
                 dendrodendritic synapses. All standard cells have a single
                 source, and this is the default.

        target - string specifying which synapse on the postsynaptic cell to
                 connect to. For standard cells, this can be 'excitatory' or
                 'inhibitory'. For non-standard cells, it could be 'NMDA', etc.
                 If target is not given, the default value of 'excitatory' is
                 used.

        method - a Connector object, encapsulating the algorithm to use for
                 connecting the neurons.

        synapse_dynamics - a `standardmodels.SynapseDynamics` object specifying
                 which synaptic plasticity mechanisms to use.

        label - optional string. If not given and both the pre- and
                post-synaptic objects have a label, "<pre label>→<post label>"
                is used.

        rng - specify an RNG object to be used by the Connector. If None, a
              NumpyRNG with a fixed seed is created.

        Raises errors.ConnectionError if either neuron argument is not a
        Population, PopulationView or Assembly, and TypeError if rng is
        neither None nor a pyNN.random.AbstractRNG.
        """
        for prefix, pop in zip(("pre", "post"),
                               (presynaptic_neurons, postsynaptic_neurons)):
            if not isinstance(pop, (BasePopulation, Assembly)):
                raise errors.ConnectionError("%ssynaptic_neurons must be a Population, PopulationView or Assembly, not a %s" % (prefix, type(pop)))
        self.pre    = presynaptic_neurons  #  } these really
        self.source = source               #  } should be
        self.post   = postsynaptic_neurons #  } read-only
        self.target = target               #  }
        self.label  = label
        if isinstance(rng, random.AbstractRNG):
            self.rng = rng
        elif rng is None:
            # fixed seed, so that by default the draws from this RNG are the
            # same from one run to the next
            self.rng = random.NumpyRNG(seed=151985012)
        else:
            # TypeError subclasses Exception, so existing handlers still match
            raise TypeError("rng must be either None, or a subclass of pyNN.random.AbstractRNG")
        self._method = method
        self.synapse_dynamics = synapse_dynamics
        self.weights = []
        if label is None and self.pre.label and self.post.label:
            self.label = "%s→%s" % (self.pre.label, self.post.label)
        if self.synapse_dynamics:
            # NOTE: this check is skipped when Python runs with -O
            assert isinstance(self.synapse_dynamics, standardmodels.SynapseDynamics), \
              "The synapse_dynamics argument, if specified, must be a standardmodels.SynapseDynamics object, not a %s" % type(synapse_dynamics)

    def __len__(self):
        """Return the total number of local connections."""
        return len(self.connection_manager)

    def size(self, gather=True):
        """
        Return the total number of connections.
            - only local connections, if gather is False,
            - all connections, if gather is True (default); the local count
              is combined across processes with recording.mpi_sum().
        """
        if gather:
            return recording.mpi_sum(len(self))
        return len(self)

    def __repr__(self):
        return 'Projection("%s")' % self.label

    def __getitem__(self, i):
        """Return connection `i`, as indexed by the connection manager."""
        return self.connection_manager[i]

    # --- Methods for setting connection parameters ---------------------------

    def setWeights(self, w):
        """
        w can be a single number, in which case all weights are set to this
        value, or a list/1D array of length equal to the number of connections
        in the projection, or a 2D array with the same dimensions as the
        connectivity matrix (as returned by `getWeights(format='array')`).
        Weights should be in nA for current-based and µS for conductance-based
        synapses.

        The weights are validated with check_weight(), using the first local
        cell of the postsynaptic object to decide whether synapses are
        conductance-based (so at least one local postsynaptic cell is
        assumed).
        """
        # should perhaps add a "distribute" argument, for symmetry with "gather" in getWeights()
        # if post is an Assembly, some components might have cond-synapses, others curr, so need a more sophisticated check here
        w = check_weight(w, self.synapse_type, is_conductance(self.post.local_cells[0]))
        self.connection_manager.set('weight', w)

    def randomizeWeights(self, rand_distr):
        """
        Set weights to random values taken from rand_distr (one draw per
        local connection).
        """
        # Arguably, we could merge this with setWeights just by detecting the
        # argument type. It could make for easier-to-read simulation code to
        # give it a separate name, though. Comments?
        self.setWeights(rand_distr.next(len(self)))

    def setDelays(self, d):
        """
        d can be a single number, in which case all delays are set to this
        value, or a list/1D array of length equal to the number of connections
        in the projection, or a 2D array with the same dimensions as the
        connectivity matrix (as returned by `getDelays(format='array')`).
        """
        self.connection_manager.set('delay', d)

    def randomizeDelays(self, rand_distr):
        """
        Set delays to random values taken from rand_distr (one draw per
        local connection).
        """
        self.setDelays(rand_distr.next(len(self)))

    def setSynapseDynamics(self, param, value):
        """
        Set parameters of the dynamic synapses for all connections in this
        projection.
        """
        self.connection_manager.set(param, value)

    def randomizeSynapseDynamics(self, param, rand_distr):
        """
        Set parameters of the synapse dynamics to values taken from rand_distr
        (one draw per local connection).
        """
        self.setSynapseDynamics(param, rand_distr.next(len(self)))

    # --- Methods for writing/reading information to/from file. ---------------

    def getWeights(self, format='list', gather=True):
        """
        Get synaptic weights for all connections in this Projection.

        Possible formats are: a list of length equal to the number of connections
        in the projection, a 2D weight array (with NaN for non-existent
        connections). Note that for the array format, if there is more than
        one connection between two cells, the summed weight will be given.

        gather=True is not implemented yet: an error is logged and the value
        from connection_manager.get() is returned regardless.
        """
        if gather:
            logger.error("getWeights() with gather=True not yet implemented")
        return self.connection_manager.get('weight', format)

    def getDelays(self, format='list', gather=True):
        """
        Get synaptic delays for all connections in this Projection.

        Possible formats are: a list of length equal to the number of connections
        in the projection, a 2D delay array (with NaN for non-existent
        connections).

        gather=True is not implemented yet: an error is logged and the value
        from connection_manager.get() is returned regardless.
        """
        if gather:
            logger.error("getDelays() with gather=True not yet implemented")
        return self.connection_manager.get('delay', format)

    def getSynapseDynamics(self, parameter_name, format='list', gather=True):
        """
        Get parameters of the dynamic synapses for all connections in this
        Projection.

        gather=True is not implemented yet: an error is logged and the value
        from connection_manager.get() is returned regardless.
        """
        if gather:
            logger.error("getSynapseDynamics() with gather=True not yet implemented")
        return self.connection_manager.get(parameter_name, format)

    def saveConnections(self, file, gather=True, compatible_output=True):
        """
        Save connections to file in a format suitable for reading in with a
        FromFileConnector.

        file - a filename (opened as a files.StandardTextFile in write mode)
               or a PyNN File object. It is closed after writing.
        gather - with more than one process: if True, the lines of all
                 processes are gathered and written by rank 0 only; if False,
                 each process renames its file to '<name>.<rank>' and writes
                 its own lines.
        compatible_output - if True, source and target are written as indices
                 (self.pre.id_to_index / self.post.id_to_index); otherwise the
                 raw cell IDs are written.

        Each line is [source, target, weight, delay]. file.write() also receives
        a dict holding the pre and post labels.
        """
        if isinstance(file, basestring):
            file = files.StandardTextFile(file, mode='w')

        if compatible_output:
            # id_to_index is slow for PopulationView and Assembly (IDs need
            # not be sorted), so this branch can dominate the cost
            lines = [[self.pre.id_to_index(c.source),
                      self.post.id_to_index(c.target),
                      c.weight, c.delay]
                     for c in self.connections]
        else:
            lines = [[c.source, c.target, c.weight, c.delay]
                     for c in self.connections]

        if gather and num_processes() > 1:
            all_lines = recording.gather_dict({rank(): lines})
            if rank() == 0:
                # concatenate in rank order: linear time and a deterministic
                # order, unlike summing lists in dict-iteration order
                lines = list(chain.from_iterable(all_lines[r]
                                                 for r in sorted(all_lines)))
        elif num_processes() > 1:
            file.rename('%s.%d' % (file.name, rank()))

        logger.debug("--- Projection[%s].__saveConnections__() ---" % self.label)

        if not gather or rank() == 0:
            file.write(lines, {'pre' : self.pre.label, 'post' : self.post.label})
            file.close()

    def printWeights(self, file, format='list', gather=True):
        """
        Print synaptic weights to file. In the array format, zeros are printed
        for non-existent connections.

        file - a filename (opened as a files.StandardTextFile in write mode)
               or a PyNN File object. It is closed after writing.
        format, gather - passed through to getWeights().
        """
        weights = self.getWeights(format=format, gather=gather)

        if isinstance(file, basestring):
            file = files.StandardTextFile(file, mode='w')

        if format == 'array':
            # getWeights(format='array') marks missing connections with NaN
            weights = numpy.where(numpy.isnan(weights), 0.0, weights)
        file.write(weights, {})
        file.close()

    def weightHistogram(self, min=None, max=None, nbins=10):
        """
        Return a histogram of synaptic weights as the (counts, bin_edges)
        pair produced by numpy.histogram.

        min, max - range covered by the bins. If not given, the minimum and
                   maximum weights are calculated automatically.
        nbins - number of equal-width bins between min and max.
        """
        # it is arguable whether functions operating on the set of weights
        # should be put here or in an external module.
        # asarray: getWeights(format='list') may return a plain Python list,
        # which has no .min()/.max() methods
        weights = numpy.asarray(self.getWeights(format='list', gather=True))
        if min is None:
            min = weights.min()
        if max is None:
            max = weights.max()
        bins = numpy.linspace(min, max, nbins + 1)
        # the transitional `new=` keyword has been removed from numpy.histogram;
        # the current semantics are the default
        return numpy.histogram(weights, bins)

    def describe(self, template='projection_default.txt', engine='default'):
        """
        Returns a human-readable description of the projection.

        The output may be customized by specifying a different template
        together with an associated template engine (see ``pyNN.descriptions``).

        If template is None, then a dictionary containing the template context
        will be returned.
        """
        context = {
            "label": self.label,
            "pre": self.pre.describe(template=None),
            "post": self.post.describe(template=None),
            "source": self.source,
            "target": self.target,
            "size_local": len(self),
            "size": self.size(gather=True),
            "connector": self._method.describe(template=None),
            "plasticity": None,
        }
        if self.synapse_dynamics:
            context.update(plasticity=self.synapse_dynamics.describe(template=None))
        return descriptions.render(engine, template, context)
1707
1708
1709# =============================================================================
Note: See TracBrowser for help on using the browser.