| 1 | # encoding: utf-8 |
|---|
| 2 | """ |
|---|
| 3 | Defines a common implementation of the PyNN API. |
|---|
| 4 | |
|---|
| 5 | Simulator modules are not required to use any of the code herein, provided they |
|---|
| 6 | provide the correct interface, but it is suggested that they use as much as is |
|---|
| 7 | consistent with good performance (optimisations may require overriding some of |
|---|
| 8 | the default definitions given here). |
|---|
| 9 | |
|---|
| 10 | Utility functions and classes: |
|---|
| 11 | is_conductance() |
|---|
| 12 | check_weight() |
|---|
| 13 | check_delay() |
|---|
| 14 | |
|---|
| 15 | Accessing individual neurons: |
|---|
| 16 | IDMixin |
|---|
| 17 | |
|---|
| 18 | Common API implementation/base classes: |
|---|
| 19 | 1. Simulation set-up and control: |
|---|
| 20 | setup() |
|---|
| 21 | end() |
|---|
| 22 | run() |
|---|
| 23 | get_time_step() |
|---|
| 24 | get_current_time() |
|---|
| 25 | get_min_delay() |
|---|
| 26 | get_max_delay() |
|---|
| 27 | rank() |
|---|
| 28 | num_processes() |
|---|
| 29 | |
|---|
| 30 | 2. Creating, connecting and recording from individual neurons: |
|---|
| 31 | create() |
|---|
| 32 | connect() |
|---|
| 33 | set() |
|---|
| 34 | build_record() |
|---|
| 35 | |
|---|
| 36 | 3. Creating, connecting and recording from populations of neurons: |
|---|
| 37 | Population |
|---|
| 38 | Projection |
|---|
| 39 | |
|---|
| 40 | $Id$ |
|---|
| 41 | """ |
|---|
| 42 | |
|---|
| 43 | import numpy, os |
|---|
| 44 | import logging |
|---|
| 45 | from warnings import warn |
|---|
| 46 | import operator |
|---|
| 47 | import tempfile |
|---|
| 48 | from pyNN import random, recording, errors, standardmodels, core, space, descriptions |
|---|
| 49 | from pyNN.recording import files |
|---|
| 50 | from itertools import chain |
|---|
| 51 | if not 'simulator' in locals(): |
|---|
| 52 | simulator = None # should be set by simulator-specific modules |
|---|
| 53 | |
|---|
| 54 | DEFAULT_WEIGHT = 0.0 |
|---|
| 55 | DEFAULT_BUFFER_SIZE = 10000 |
|---|
| 56 | DEFAULT_MAX_DELAY = 10.0 |
|---|
| 57 | DEFAULT_TIMESTEP = 0.1 |
|---|
| 58 | DEFAULT_MIN_DELAY = DEFAULT_TIMESTEP |
|---|
| 59 | |
|---|
| 60 | logger = logging.getLogger("PyNN") |
|---|
| 61 | |
|---|
| 62 | # ============================================================================= |
|---|
| 63 | # Utility functions and classes |
|---|
| 64 | # ============================================================================= |
|---|
| 65 | |
|---|
| 66 | |
|---|
| 67 | def is_conductance(target_cell): |
|---|
| 68 | """ |
|---|
| 69 | Returns True if the target cell uses conductance-based synapses, False if |
|---|
| 70 | it uses current-based synapses, and None if the synapse-basis cannot be |
|---|
| 71 | determined. |
|---|
| 72 | """ |
|---|
| 73 | if hasattr(target_cell, 'local') and target_cell.local and hasattr(target_cell, 'celltype'): |
|---|
| 74 | is_conductance = target_cell.celltype.conductance_based |
|---|
| 75 | else: |
|---|
| 76 | is_conductance = None |
|---|
| 77 | return is_conductance |
|---|
| 78 | |
|---|
| 79 | |
|---|
| 80 | def check_weight(weight, synapse_type, is_conductance): |
|---|
| 81 | if weight is None: |
|---|
| 82 | weight = DEFAULT_WEIGHT |
|---|
| 83 | if core.is_listlike(weight): |
|---|
| 84 | weight = numpy.array(weight) |
|---|
| 85 | nan_filter = (1 - numpy.isnan(weight)).astype(bool) # weight arrays may contain NaN, which should be ignored |
|---|
| 86 | filtered_weight = weight[nan_filter] |
|---|
| 87 | all_negative = (filtered_weight <= 0).all() |
|---|
| 88 | all_positive = (filtered_weight >= 0).all() |
|---|
| 89 | if not (all_negative or all_positive): |
|---|
| 90 | raise errors.InvalidWeightError("Weights must be either all positive or all negative") |
|---|
| 91 | elif numpy.isreal(weight): |
|---|
| 92 | all_positive = weight >= 0 |
|---|
| 93 | all_negative = weight < 0 |
|---|
| 94 | else: |
|---|
| 95 | raise errors.InvalidWeightError("Weight must be a number or a list/array of numbers.") |
|---|
| 96 | if is_conductance or synapse_type == 'excitatory': |
|---|
| 97 | if not all_positive: |
|---|
| 98 | raise errors.InvalidWeightError("Weights must be positive for conductance-based and/or excitatory synapses") |
|---|
| 99 | elif is_conductance == False and synapse_type == 'inhibitory': |
|---|
| 100 | if not all_negative: |
|---|
| 101 | raise errors.InvalidWeightError("Weights must be negative for current-based, inhibitory synapses") |
|---|
| 102 | else: # is_conductance is None. This happens if the cell does not exist on the current node. |
|---|
| 103 | logger.debug("Can't check weight, conductance status unknown.") |
|---|
| 104 | return weight |
|---|
| 105 | |
|---|
| 106 | |
|---|
| 107 | def check_delay(delay): |
|---|
| 108 | if delay is None: |
|---|
| 109 | delay = get_min_delay() |
|---|
| 110 | # If the delay is too small , we have to throw an error |
|---|
| 111 | if delay < get_min_delay() or delay > get_max_delay(): |
|---|
| 112 | raise errors.ConnectionError("delay (%s) is out of range [%s,%s]" % \ |
|---|
| 113 | (delay, get_min_delay(), get_max_delay())) |
|---|
| 114 | return delay |
|---|
| 115 | |
|---|
| 116 | |
|---|
| 117 | # ============================================================================= |
|---|
| 118 | # Accessing individual neurons |
|---|
| 119 | # ============================================================================= |
|---|
| 120 | |
|---|
| 121 | class IDMixin(object): |
|---|
| 122 | """ |
|---|
| 123 | Instead of storing ids as integers, we store them as ID objects, |
|---|
| 124 | which allows a syntax like: |
|---|
| 125 | p[3,4].tau_m = 20.0 |
|---|
| 126 | where p is a Population object. |
|---|
| 127 | """ |
|---|
| 128 | # Simulator ID classes should inherit both from the base type of the ID |
|---|
| 129 | # (e.g., int or long) and from IDMixin. |
|---|
| 130 | |
|---|
| 131 | def __getattr__(self, name): |
|---|
| 132 | try: |
|---|
| 133 | val = self.__getattribute__(name) |
|---|
| 134 | except AttributeError: |
|---|
| 135 | if name == "parent": |
|---|
| 136 | raise Exception("parent is not set") |
|---|
| 137 | try: |
|---|
| 138 | val = self.get_parameters()[name] |
|---|
| 139 | except KeyError: |
|---|
| 140 | raise errors.NonExistentParameterError(name, |
|---|
| 141 | self.celltype.__class__.__name__, |
|---|
| 142 | self.celltype.get_parameter_names()) |
|---|
| 143 | return val |
|---|
| 144 | |
|---|
| 145 | def __setattr__(self, name, value): |
|---|
| 146 | if name == "parent": |
|---|
| 147 | object.__setattr__(self, name, value) |
|---|
| 148 | elif self.celltype.has_parameter(name): |
|---|
| 149 | self.set_parameters(**{name: value}) |
|---|
| 150 | else: |
|---|
| 151 | object.__setattr__(self, name, value) |
|---|
| 152 | |
|---|
| 153 | def set_parameters(self, **parameters): |
|---|
| 154 | """ |
|---|
| 155 | Set cell parameters, given as a sequence of parameter=value arguments. |
|---|
| 156 | """ |
|---|
| 157 | # if some of the parameters are computed from the values of other |
|---|
| 158 | # parameters, need to get and translate all parameters |
|---|
| 159 | if self.local: |
|---|
| 160 | if self.is_standard_cell: |
|---|
| 161 | computed_parameters = self.celltype.computed_parameters() |
|---|
| 162 | have_computed_parameters = numpy.any([p_name in computed_parameters |
|---|
| 163 | for p_name in parameters]) |
|---|
| 164 | if have_computed_parameters: |
|---|
| 165 | all_parameters = self.get_parameters() |
|---|
| 166 | all_parameters.update(parameters) |
|---|
| 167 | parameters = all_parameters |
|---|
| 168 | parameters = self.celltype.translate(parameters) |
|---|
| 169 | self.set_native_parameters(parameters) |
|---|
| 170 | else: |
|---|
| 171 | raise errors.NotLocalError("Cannot set parameters for a cell that does not exist on this node.") |
|---|
| 172 | |
|---|
| 173 | def get_parameters(self): |
|---|
| 174 | """Return a dict of all cell parameters.""" |
|---|
| 175 | if self.local: |
|---|
| 176 | parameters = self.get_native_parameters() |
|---|
| 177 | if self.is_standard_cell: |
|---|
| 178 | parameters = self.celltype.reverse_translate(parameters) |
|---|
| 179 | return parameters |
|---|
| 180 | else: |
|---|
| 181 | raise errors.NotLocalError("Cannot obtain parameters for a cell that does not exist on this node.") |
|---|
| 182 | |
|---|
| 183 | @property |
|---|
| 184 | def celltype(self): |
|---|
| 185 | return self.parent.celltype |
|---|
| 186 | |
|---|
| 187 | @property |
|---|
| 188 | def is_standard_cell(self): |
|---|
| 189 | return issubclass(self.celltype.__class__, standardmodels.StandardCellType) |
|---|
| 190 | |
|---|
| 191 | def _set_position(self, pos): |
|---|
| 192 | """ |
|---|
| 193 | Set the cell position in 3D space. |
|---|
| 194 | |
|---|
| 195 | Cell positions are stored in an array in the parent Population. |
|---|
| 196 | """ |
|---|
| 197 | assert isinstance(pos, (tuple, numpy.ndarray)) |
|---|
| 198 | assert len(pos) == 3 |
|---|
| 199 | self.parent._set_cell_position(self, pos) |
|---|
| 200 | |
|---|
| 201 | def _get_position(self): |
|---|
| 202 | """ |
|---|
| 203 | Return the cell position in 3D space. |
|---|
| 204 | |
|---|
| 205 | Cell positions are stored in an array in the parent Population, if any, |
|---|
| 206 | or within the ID object otherwise. Positions are generated the first |
|---|
| 207 | time they are requested and then cached. |
|---|
| 208 | """ |
|---|
| 209 | return self.parent._get_cell_position(self) |
|---|
| 210 | |
|---|
| 211 | position = property(_get_position, _set_position) |
|---|
| 212 | |
|---|
| 213 | @property |
|---|
| 214 | def local(self): |
|---|
| 215 | return self.parent.is_local(self) |
|---|
| 216 | |
|---|
| 217 | def inject(self, current_source): |
|---|
| 218 | """Inject current from a current source object into the cell.""" |
|---|
| 219 | current_source.inject_into([self]) |
|---|
| 220 | |
|---|
| 221 | def get_initial_value(self, variable): |
|---|
| 222 | """Get the initial value of a state variable of the cell.""" |
|---|
| 223 | return self.parent._get_cell_initial_value(self, variable) |
|---|
| 224 | |
|---|
| 225 | def set_initial_value(self, variable, value): |
|---|
| 226 | """Set the initial value of a state variable of the cell.""" |
|---|
| 227 | self.parent._set_cell_initial_value(self, variable, value) |
|---|
| 228 | |
|---|
| 229 | |
|---|
| 230 | # ============================================================================= |
|---|
| 231 | # Functions for simulation set-up and control |
|---|
| 232 | # ============================================================================= |
|---|
| 233 | |
|---|
| 234 | |
|---|
| 235 | def setup(timestep=DEFAULT_TIMESTEP, min_delay=DEFAULT_MIN_DELAY, |
|---|
| 236 | max_delay=DEFAULT_MAX_DELAY, **extra_params): |
|---|
| 237 | """ |
|---|
| 238 | Should be called at the very beginning of a script. |
|---|
| 239 | extra_params contains any keyword arguments that are required by a given |
|---|
| 240 | simulator but not by others. |
|---|
| 241 | """ |
|---|
| 242 | invalid_extra_params = ('mindelay', 'maxdelay', 'dt') |
|---|
| 243 | for param in invalid_extra_params: |
|---|
| 244 | if param in extra_params: |
|---|
| 245 | raise Exception("%s is not a valid argument for setup()" % param) |
|---|
| 246 | if min_delay > max_delay: |
|---|
| 247 | raise Exception("min_delay has to be less than or equal to max_delay.") |
|---|
| 248 | if min_delay < timestep: |
|---|
| 249 | raise Exception("min_delay (%g) must be greater than timestep (%g)" % (min_delay, timestep)) |
|---|
| 250 | |
|---|
| 251 | def end(compatible_output=True): |
|---|
| 252 | """Do any necessary cleaning up before exiting.""" |
|---|
| 253 | raise NotImplementedError |
|---|
| 254 | |
|---|
| 255 | def run(simtime): |
|---|
| 256 | """Run the simulation for simtime ms.""" |
|---|
| 257 | raise NotImplementedError |
|---|
| 258 | |
|---|
| 259 | def reset(): |
|---|
| 260 | """ |
|---|
| 261 | Reset the time to zero, neuron membrane potentials and synaptic weights to |
|---|
| 262 | their initial values, and delete any recorded data. The network structure |
|---|
| 263 | is not changed, nor is the specification of which neurons to record from. |
|---|
| 264 | """ |
|---|
| 265 | simulator.reset() |
|---|
| 266 | |
|---|
| 267 | def initialize(cells, variable, value): |
|---|
| 268 | assert isinstance(cells, (BasePopulation, Assembly)), type(cells) |
|---|
| 269 | cells.initialize(variable, value) |
|---|
| 270 | |
|---|
| 271 | def get_current_time(): |
|---|
| 272 | """Return the current time in the simulation.""" |
|---|
| 273 | return simulator.state.t |
|---|
| 274 | |
|---|
| 275 | def get_time_step(): |
|---|
| 276 | """Return the integration time step.""" |
|---|
| 277 | return simulator.state.dt |
|---|
| 278 | |
|---|
| 279 | def get_min_delay(): |
|---|
| 280 | """Return the minimum allowed synaptic delay.""" |
|---|
| 281 | return simulator.state.min_delay |
|---|
| 282 | |
|---|
| 283 | def get_max_delay(): |
|---|
| 284 | """Return the maximum allowed synaptic delay.""" |
|---|
| 285 | return simulator.state.max_delay |
|---|
| 286 | |
|---|
| 287 | def num_processes(): |
|---|
| 288 | """Return the number of MPI processes.""" |
|---|
| 289 | return simulator.state.num_processes |
|---|
| 290 | |
|---|
| 291 | def rank(): |
|---|
| 292 | """Return the MPI rank of the current node.""" |
|---|
| 293 | return simulator.state.mpi_rank |
|---|
| 294 | |
|---|
| 295 | # ============================================================================= |
|---|
| 296 | # Low-level API for creating, connecting and recording from individual neurons |
|---|
| 297 | # ============================================================================= |
|---|
| 298 | |
|---|
| 299 | def build_create(population_class): |
|---|
| 300 | def create(cellclass, cellparams=None, n=1): |
|---|
| 301 | """ |
|---|
| 302 | Create n cells all of the same type. |
|---|
| 303 | |
|---|
| 304 | If n > 1, return a list of cell ids/references. |
|---|
| 305 | If n==1, return just the single id. |
|---|
| 306 | """ |
|---|
| 307 | return population_class(n, cellclass, cellparams) # return the Population or Population.all_cells? |
|---|
| 308 | return create |
|---|
| 309 | |
|---|
| 310 | def build_connect(projection_class, connector_class): |
|---|
| 311 | def connect(source, target, weight=0.0, delay=None, synapse_type=None, |
|---|
| 312 | p=1, rng=None): |
|---|
| 313 | """ |
|---|
| 314 | Connect a source of spikes to a synaptic target. |
|---|
| 315 | |
|---|
| 316 | source and target can both be individual cells or lists of cells, in |
|---|
| 317 | which case all possible connections are made with probability p, using |
|---|
| 318 | either the random number generator supplied, or the default rng |
|---|
| 319 | otherwise. Weights should be in nA or µS. |
|---|
| 320 | """ |
|---|
| 321 | if isinstance(source, IDMixin): |
|---|
| 322 | source = source.parent |
|---|
| 323 | if isinstance(target, IDMixin): |
|---|
| 324 | target = target.parent |
|---|
| 325 | connector = connector_class(p_connect=p, weights=weight, delays=delay) |
|---|
| 326 | return projection_class(source, target, connector, target=synapse_type, rng=rng) |
|---|
| 327 | return connect |
|---|
| 328 | |
|---|
| 329 | def set(cells, param, val=None): |
|---|
| 330 | """ |
|---|
| 331 | Set one or more parameters of an individual cell or list of cells. |
|---|
| 332 | |
|---|
| 333 | param can be a dict, in which case val should not be supplied, or a string |
|---|
| 334 | giving the parameter name, in which case val is the parameter value. |
|---|
| 335 | """ |
|---|
| 336 | assert isinstance(cells, (BasePopulation, Assembly)) |
|---|
| 337 | cells.set(param, val) |
|---|
| 338 | |
|---|
| 339 | def build_record(variable, simulator): |
|---|
| 340 | def record(source, filename): |
|---|
| 341 | """ |
|---|
| 342 | Record spikes to a file. source can be an individual cell or a list of |
|---|
| 343 | cells. |
|---|
| 344 | """ |
|---|
| 345 | # would actually like to be able to record to an array and choose later |
|---|
| 346 | # whether to write to a file. |
|---|
| 347 | assert isinstance(source, (BasePopulation, Assembly)) |
|---|
| 348 | source._record(variable, to_file=filename) |
|---|
| 349 | if isinstance(source, BasePopulation): |
|---|
| 350 | simulator.recorder_list.append(source.recorders[variable]) # this is a bit hackish - better to add to Population.__del__? |
|---|
| 351 | if isinstance(source, Assembly): |
|---|
| 352 | for population in source.populations: |
|---|
| 353 | simulator.recorder_list.append(population.recorders[variable]) |
|---|
| 354 | if variable == 'v': |
|---|
| 355 | record.__doc__ = """ |
|---|
| 356 | Record membrane potential to a file. source can be an individual cell or |
|---|
| 357 | a list of cells.""" |
|---|
| 358 | elif variable == 'gsyn': |
|---|
| 359 | record.__doc__ = """ |
|---|
| 360 | Record synaptic conductances to a file. source can be an individual cell |
|---|
| 361 | or a list of cells.""" |
|---|
| 362 | return record |
|---|
| 363 | |
|---|
| 364 | |
|---|
| 365 | # ============================================================================= |
|---|
| 366 | # High-level API for creating, connecting and recording from populations of |
|---|
| 367 | # neurons. |
|---|
| 368 | # ============================================================================= |
|---|
| 369 | |
|---|
| 370 | class BasePopulation(object): |
|---|
| 371 | record_filter = None |
|---|
| 372 | |
|---|
| 373 | def __getitem__(self, index): |
|---|
| 374 | """ |
|---|
| 375 | Return a representation of the cell with the given index, |
|---|
| 376 | suitable for being passed to other methods that require a cell id. |
|---|
| 377 | Note that __getitem__ is called when using [] access, e.g. |
|---|
| 378 | p = Population(...) |
|---|
| 379 | p[2] is equivalent to p.__getitem__(2). |
|---|
| 380 | Also accepts slices, e.g. |
|---|
| 381 | p[3:6] |
|---|
| 382 | which returns an array of cells. |
|---|
| 383 | """ |
|---|
| 384 | if isinstance(index, int): |
|---|
| 385 | return self.all_cells[index] |
|---|
| 386 | elif isinstance(index, (slice, list, numpy.ndarray)): |
|---|
| 387 | return PopulationView(self, index) |
|---|
| 388 | elif isinstance(index, tuple): |
|---|
| 389 | return PopulationView(self, list(index)) |
|---|
| 390 | else: |
|---|
| 391 | raise TypeError("indices must be integers, slices, lists, arrays or tuples, not %s" % type(index).__name__) |
|---|
| 392 | |
|---|
| 393 | def __len__(self): |
|---|
| 394 | """Return the total number of cells in the population (all nodes).""" |
|---|
| 395 | return self.size |
|---|
| 396 | |
|---|
| 397 | def __iter__(self): |
|---|
| 398 | """Iterator over cell ids on the local node.""" |
|---|
| 399 | return iter(self.local_cells) |
|---|
| 400 | |
|---|
| 401 | def is_local(self, id): |
|---|
| 402 | assert id.parent is self |
|---|
| 403 | index = self.id_to_index(id) |
|---|
| 404 | return self._mask_local[index] |
|---|
| 405 | |
|---|
| 406 | def all(self): |
|---|
| 407 | """Iterator over cell ids on all nodes.""" |
|---|
| 408 | return iter(self.all_cells) |
|---|
| 409 | |
|---|
| 410 | def __add__(self, other): |
|---|
| 411 | assert isinstance(other, BasePopulation) |
|---|
| 412 | return Assembly(self, other) |
|---|
| 413 | |
|---|
| 414 | def _get_cell_position(self, id): |
|---|
| 415 | index = self.id_to_index(id) |
|---|
| 416 | return self.positions[:, index] |
|---|
| 417 | |
|---|
| 418 | def _set_cell_position(self, id, pos): |
|---|
| 419 | index = self.id_to_index(id) |
|---|
| 420 | self.positions[:, index] = pos |
|---|
| 421 | |
|---|
| 422 | def _get_cell_initial_value(self, id, variable): |
|---|
| 423 | assert isinstance(self.initial_values[variable], core.LazyArray) |
|---|
| 424 | index = self.id_to_index(id) |
|---|
| 425 | return self.initial_values[variable][index] |
|---|
| 426 | |
|---|
| 427 | def _set_cell_initial_value(self, id, variable, value): |
|---|
| 428 | assert isinstance(self.initial_values[variable], core.LazyArray) |
|---|
| 429 | index = self.id_to_index(id) |
|---|
| 430 | self.initial_values[variable][index] = value |
|---|
| 431 | |
|---|
| 432 | def nearest(self, position): |
|---|
| 433 | """Return the neuron closest to the specified position.""" |
|---|
| 434 | # doesn't always work correctly if a position is equidistant between |
|---|
| 435 | # two neurons, i.e. 0.5 should be rounded up, but it isn't always. |
|---|
| 436 | # also doesn't take account of periodic boundary conditions |
|---|
| 437 | pos = numpy.array([position] * self.positions.shape[1]).transpose() |
|---|
| 438 | dist_arr = (self.positions - pos)**2 |
|---|
| 439 | distances = dist_arr.sum(axis=0) |
|---|
| 440 | nearest = distances.argmin() |
|---|
| 441 | return self[nearest] |
|---|
| 442 | |
|---|
| 443 | def sample(self, n, rng=None): |
|---|
| 444 | """ |
|---|
| 445 | Randomly sample n cells from the Population, and return a PopulationView |
|---|
| 446 | object. |
|---|
| 447 | """ |
|---|
| 448 | assert isinstance(n, int) |
|---|
| 449 | if not rng: |
|---|
| 450 | rng = random.NumpyRNG() |
|---|
| 451 | indices = rng.permutation(numpy.arange(len(self)))[0:n] |
|---|
| 452 | logger.debug("The %d cells recorded have indices %s" % (n, indices)) |
|---|
| 453 | logger.debug("%s.sample(%s)", self.label, n) |
|---|
| 454 | return PopulationView(self, indices) |
|---|
| 455 | |
|---|
| 456 | def get(self, parameter_name, gather=False): |
|---|
| 457 | """ |
|---|
| 458 | Get the values of a parameter for every local cell in the population. |
|---|
| 459 | """ |
|---|
| 460 | # if all the cells have the same value for this parameter, should |
|---|
| 461 | # we return just the number, rather than an array? |
|---|
| 462 | |
|---|
| 463 | if hasattr(self, "_get_array"): |
|---|
| 464 | values = self._get_array(parameter_name).tolist() |
|---|
| 465 | else: |
|---|
| 466 | values = [getattr(cell, parameter_name) for cell in self] # list or array? |
|---|
| 467 | |
|---|
| 468 | if gather == True and num_processes() > 1: |
|---|
| 469 | all_values = { rank(): values } |
|---|
| 470 | all_indices = { rank(): self.local_cells.tolist()} |
|---|
| 471 | all_values = recording.gather_dict(all_values) |
|---|
| 472 | all_indices = recording.gather_dict(all_indices) |
|---|
| 473 | if rank() == 0: |
|---|
| 474 | values = reduce(operator.add, all_values.values()) |
|---|
| 475 | indices = reduce(operator.add, all_indices.values()) |
|---|
| 476 | idx = numpy.argsort(indices) |
|---|
| 477 | values = numpy.array(values)[idx] |
|---|
| 478 | return values |
|---|
| 479 | |
|---|
| 480 | def set(self, param, val=None): |
|---|
| 481 | """ |
|---|
| 482 | Set one or more parameters for every cell in the population. param |
|---|
| 483 | can be a dict, in which case val should not be supplied, or a string |
|---|
| 484 | giving the parameter name, in which case val is the parameter value. |
|---|
| 485 | val can be a numeric value, or list of such (e.g. for setting spike |
|---|
| 486 | times). |
|---|
| 487 | e.g. p.set("tau_m",20.0). |
|---|
| 488 | p.set({'tau_m':20,'v_rest':-65}) |
|---|
| 489 | """ |
|---|
| 490 | #""" |
|---|
| 491 | # -- Proposed change to arguments -- |
|---|
| 492 | #Set one or more parameters for every cell in the population. |
|---|
| 493 | # |
|---|
| 494 | #Each value may be a single number or a list/array of numbers of the same |
|---|
| 495 | #size as the population. If the parameter itself takes lists/arrays as |
|---|
| 496 | #values (e.g. spike times), then the value provided may be either a |
|---|
| 497 | #single lists/1D array, a list of lists/1D arrays, or a 2D array. |
|---|
| 498 | # |
|---|
| 499 | #e.g. p.set(tau_m=20.0). |
|---|
| 500 | # p.set(tau_m=20, v_rest=[-65.0, -65.3, ... , -67.2]) |
|---|
| 501 | #""" |
|---|
| 502 | if isinstance(param, str): |
|---|
| 503 | param_dict = {param: val} |
|---|
| 504 | elif isinstance(param, dict): |
|---|
| 505 | param_dict = param |
|---|
| 506 | else: |
|---|
| 507 | raise errors.InvalidParameterValueError |
|---|
| 508 | for name, val in param_dict.items(): |
|---|
| 509 | if name not in self.celltype.get_parameter_names(): |
|---|
| 510 | raise errors.NonExistentParameterError(name, self.celltype, self.celltype.get_parameter_names()) |
|---|
| 511 | if isinstance(val, (float, int)): |
|---|
| 512 | param_dict[name] = float(val) |
|---|
| 513 | elif isinstance(val, (list, numpy.ndarray)): |
|---|
| 514 | pass # ought to check list/array only contains numeric types |
|---|
| 515 | else: |
|---|
| 516 | raise errors.InvalidParameterValueError |
|---|
| 517 | logger.debug("%s.set(%s)", self.label, param_dict) |
|---|
| 518 | if hasattr(self, "_set_array"): |
|---|
| 519 | self._set_array(**param_dict) |
|---|
| 520 | else: |
|---|
| 521 | for cell in self: |
|---|
| 522 | cell.set_parameters(**param_dict) |
|---|
| 523 | |
|---|
| 524 | def tset(self, parametername, value_array): |
|---|
| 525 | """ |
|---|
| 526 | 'Topographic' set. Set the value of parametername to the values in |
|---|
| 527 | value_array, which must have the same dimensions as the Population. |
|---|
| 528 | """ |
|---|
| 529 | #""" |
|---|
| 530 | # -- Proposed change to arguments -- |
|---|
| 531 | #'Topographic' set. Each value in parameters should be a function that |
|---|
| 532 | #accepts arguments x,y,z and returns a single value. |
|---|
| 533 | #""" |
|---|
| 534 | if parametername not in self.celltype.get_parameter_names(): |
|---|
| 535 | raise errors.NonExistentParameterError(parametername, self.celltype, self.celltype.get_parameter_names()) |
|---|
| 536 | if (self.size,) == value_array.shape: # the values are numbers or non-array objects |
|---|
| 537 | local_values = value_array[self._mask_local] |
|---|
| 538 | assert local_values.size == self.local_cells.size, "%d != %d" % (local_values.size, self.local_cells.size) |
|---|
| 539 | elif len(value_array.shape) == 2: # the values are themselves 1D arrays |
|---|
| 540 | if value_array.shape[0] != self.size: |
|---|
| 541 | raise errors.InvalidDimensionsError("Population: %d, value_array first dimension: %s" % (self.size, |
|---|
| 542 | value_array.shape[0])) |
|---|
| 543 | local_values = value_array[self._mask_local] # not sure this works |
|---|
| 544 | else: |
|---|
| 545 | raise errors.InvalidDimensionsError("Population: %d, value_array: %s" % (self.size, |
|---|
| 546 | str(value_array.shape))) |
|---|
| 547 | assert local_values.shape[0] == self.local_cells.size, "%d != %d" % (local_values.size, self.local_cells.size) |
|---|
| 548 | |
|---|
| 549 | try: |
|---|
| 550 | logger.debug("%s.tset('%s', array(shape=%s, min=%s, max=%s))", |
|---|
| 551 | self.label, parametername, value_array.shape, |
|---|
| 552 | value_array.min(), value_array.max()) |
|---|
| 553 | except TypeError: # min() and max() won't work for non-numeric values |
|---|
| 554 | logger.debug("%s.tset('%s', non_numeric_array(shape=%s))", |
|---|
| 555 | self.label, parametername, value_array.shape) |
|---|
| 556 | |
|---|
| 557 | # Set the values for each cell |
|---|
| 558 | if hasattr(self, "_set_array"): |
|---|
| 559 | self._set_array(**{parametername: local_values}) |
|---|
| 560 | else: |
|---|
| 561 | for cell, val in zip(self, local_values): |
|---|
| 562 | setattr(cell, parametername, val) |
|---|
| 563 | |
|---|
| 564 | def rset(self, parametername, rand_distr): |
|---|
| 565 | """ |
|---|
| 566 | 'Random' set. Set the value of parametername to a value taken from |
|---|
| 567 | rand_distr, which should be a RandomDistribution object. |
|---|
| 568 | """ |
|---|
| 569 | # Note that we generate enough random numbers for all cells on all nodes |
|---|
| 570 | # but use only those relevant to this node. This ensures that the |
|---|
| 571 | # sequence of random numbers does not depend on the number of nodes, |
|---|
| 572 | # provided that the same rng with the same seed is used on each node. |
|---|
| 573 | logger.debug("%s.rset('%s', %s)", self.label, parametername, rand_distr) |
|---|
| 574 | if isinstance(rand_distr.rng, random.NativeRNG): |
|---|
| 575 | self._native_rset(parametername, rand_distr) |
|---|
| 576 | else: |
|---|
| 577 | rarr = rand_distr.next(n=self.all_cells.size, mask_local=False) |
|---|
| 578 | rarr = numpy.array(rarr) # isn't rarr already an array? |
|---|
| 579 | assert rarr.size == self.size, "%s != %s" % (rarr.size, self.size) |
|---|
| 580 | self.tset(parametername, rarr) |
|---|
| 581 | |
|---|
| 582 | def _call(self, methodname, arguments): |
|---|
| 583 | """ |
|---|
| 584 | Call the method methodname(arguments) for every cell in the population. |
|---|
| 585 | e.g. p.call("set_background","0.1") if the cell class has a method |
|---|
| 586 | set_background(). |
|---|
| 587 | """ |
|---|
| 588 | raise NotImplementedError() |
|---|
| 589 | |
|---|
| 590 | def _tcall(self, methodname, objarr): |
|---|
| 591 | """ |
|---|
| 592 | `Topographic' call. Call the method methodname() for every cell in the |
|---|
| 593 | population. The argument to the method depends on the coordinates of |
|---|
| 594 | the cell. objarr is an array with the same dimensions as the |
|---|
| 595 | Population. |
|---|
| 596 | e.g. p.tcall("memb_init", vinitArray) calls |
|---|
| 597 | p.cell[i][j].memb_init(vInitArray[i][j]) for all i,j. |
|---|
| 598 | """ |
|---|
| 599 | raise NotImplementedError() |
|---|
| 600 | |
|---|
| 601 | def randomInit(self, rand_distr): |
|---|
| 602 | """ |
|---|
| 603 | Set initial membrane potentials for all the cells in the population to |
|---|
| 604 | random values. |
|---|
| 605 | """ |
|---|
| 606 | warn("The randomInit() method is deprecated, and will be removed in a future release. Use initialize('v', rand_distr) instead.") |
|---|
| 607 | self.initialize('v', rand_distr) |
|---|
| 608 | |
|---|
| 609 | def initialize(self, variable, value): |
|---|
| 610 | """ |
|---|
| 611 | Set initial values of state variables, e.g. the membrane potential. |
|---|
| 612 | |
|---|
| 613 | `value` may either be a numeric value (all neurons set to the same |
|---|
| 614 | value) or a `RandomDistribution` object (each neuron gets a |
|---|
| 615 | different value) |
|---|
| 616 | """ |
|---|
| 617 | if isinstance(value, random.RandomDistribution): |
|---|
| 618 | initial_value = value.next(n=self.all_cells.size, mask_local=self._mask_local) |
|---|
| 619 | else: |
|---|
| 620 | initial_value = value |
|---|
| 621 | self.initial_values[variable] = core.LazyArray(initial_value, shape=(self.size,)) |
|---|
| 622 | if hasattr(self, "_set_initial_value_array"): |
|---|
| 623 | self._set_initial_value_array(variable, initial_value) |
|---|
| 624 | else: |
|---|
| 625 | if isinstance(value, random.RandomDistribution): |
|---|
| 626 | for cell, val in zip(self, initial_value): |
|---|
| 627 | cell.set_initial_value(variable, val) |
|---|
| 628 | else: |
|---|
| 629 | for cell in self: # only on local node |
|---|
| 630 | cell.set_initial_value(variable, initial_value) |
|---|
| 631 | |
|---|
| 632 | def can_record(self, variable): |
|---|
| 633 | """Determine whether `variable` can be recorded from this population.""" |
|---|
| 634 | return (variable in self.celltype.recordable) |
|---|
| 635 | |
|---|
| 636 | def _record(self, variable, to_file=True): |
|---|
| 637 | """ |
|---|
| 638 | Private method called by record() and record_v(). |
|---|
| 639 | """ |
|---|
| 640 | if not self.can_record(variable): |
|---|
| 641 | raise errors.RecordingError(variable, self.celltype) |
|---|
| 642 | logger.debug("%s.record('%s')", self.label, variable) |
|---|
| 643 | if self.record_filter is not None: |
|---|
| 644 | self.recorders[variable].record(self.record_filter) |
|---|
| 645 | else: |
|---|
| 646 | self.recorders[variable].record(self.all_cells) |
|---|
| 647 | if isinstance(to_file, basestring): |
|---|
| 648 | self.recorders[variable].file = to_file |
|---|
| 649 | |
|---|
def record(self, to_file=True):
    """Start recording spike times from every cell in the Population."""
    quantity = 'spikes'
    self._record(quantity, to_file)
|---|
| 655 | |
|---|
def record_v(self, to_file=True):
    """Start recording the membrane potential of every cell in the Population."""
    quantity = 'v'
    self._record(quantity, to_file)
|---|
| 661 | |
|---|
def record_gsyn(self, to_file=True):
    """Start recording the synaptic conductances of every cell in the Population."""
    quantity = 'gsyn'
    self._record(quantity, to_file)
|---|
| 667 | |
|---|
def printSpikes(self, file, gather=True, compatible_output=True):
    """
    Write the recorded spike times to `file` (a filename or a PyNN File object).

    With compatible_output=True each line has the form "spiketime cell_id",
    where cell_id is the cell's index counting along rows and down columns
    (extended in the obvious way to 3-D). This makes a raster plot easy, with
    one line per cell. A header of '#'-prefixed lines gives the timestep,
    first id, last id and number of data points per cell.

    With compatible_output=False the simulator's raw format is written
    instead, which can be faster because no post-processing is done.

    On parallel simulators, gather=True collects all data on the master node
    and writes a single file there. gather=False makes each node write its
    own file holding only its local cells.
    """
    spike_recorder = self.recorders['spikes']
    spike_recorder.write(file, gather, compatible_output, self.record_filter)
|---|
| 692 | |
|---|
def getSpikes(self, gather=True, compatible_output=True):
    """
    Return the recorded spikes of this population as a numpy array.

    Each row pairs a cell id with a spike time. This is handy for small
    populations, e.g. single-neuron Monte-Carlo studies.
    """
    spike_recorder = self.recorders['spikes']
    return spike_recorder.get(gather, compatible_output, self.record_filter)
|---|
| 701 | |
|---|
def print_v(self, file, gather=True, compatible_output=True):
    """
    Write the recorded membrane potential traces to `file` (a filename or a
    PyNN File object).

    With compatible_output=True each line has the form "v cell_id", where
    cell_id is the cell's index counting along rows and down columns
    (extended in the obvious way to 3-D). A header of '#'-prefixed lines
    gives the timestep, first id, last id and number of data points per cell.

    With compatible_output=False the simulator's raw format is written
    instead, which can be faster because no post-processing is done.

    On parallel simulators, gather=True collects all data on the master node
    and writes a single file there. gather=False makes each node write its
    own file holding only its local cells.
    """
    v_recorder = self.recorders['v']
    v_recorder.write(file, gather, compatible_output, self.record_filter)
|---|
| 724 | |
|---|
def get_v(self, gather=True, compatible_output=True):
    """
    Return the recorded membrane potential data of this population as a
    numpy array (cell ids together with Vm values; the exact column layout
    is set by the recorder).
    """
    v_recorder = self.recorders['v']
    return v_recorder.get(gather, compatible_output, self.record_filter)
|---|
| 731 | |
|---|
def print_gsyn(self, file, gather=True, compatible_output=True):
    """
    Write the recorded synaptic conductance traces to `file` (a filename or
    a PyNN File object).

    With compatible_output=True each line has the form "t g cell_id", where
    cell_id is the cell's index counting along rows and down columns
    (extended in the obvious way to 3-D). A header of '#'-prefixed lines
    gives the timestep, first id, last id and number of data points per cell.

    With compatible_output=False the simulator's raw format is written
    instead, which can be faster because the conductance files are not
    post-processed.

    `gather` is passed to the recorder's `write` unchanged.
    """
    gsyn_recorder = self.recorders['gsyn']
    gsyn_recorder.write(file, gather, compatible_output, self.record_filter)
|---|
| 749 | |
|---|
def get_gsyn(self, gather=True, compatible_output=True):
    """
    Return the recorded synaptic conductance data of this population as a
    numpy array (cell ids together with conductance values; the exact
    column layout is set by the recorder).
    """
    gsyn_recorder = self.recorders['gsyn']
    return gsyn_recorder.get(gather, compatible_output, self.record_filter)
|---|
| 756 | |
|---|
def get_spike_counts(self, gather=True):
    """Return the spike count of each recorded neuron, as produced by the spike recorder's `count`."""
    spike_recorder = self.recorders['spikes']
    return spike_recorder.count(gather, self.record_filter)
|---|
| 762 | |
|---|
def meanSpikeCount(self, gather=True):
    """
    Return the mean number of spikes per recorded neuron.

    numpy.nan is returned when there are no spike counts, and on non-master
    nodes when gather is True (only rank 0 is treated as holding the
    gathered data).
    """
    counts = self.recorders['spikes'].count(gather, self.record_filter)
    n_spikes = sum(counts.values())
    # should maybe use allgather, and get the numbers on all nodes
    if rank() != 0 and gather:
        return numpy.nan
    n_cells = len(counts)
    if n_cells == 0:
        return numpy.nan
    return float(n_spikes) / n_cells
|---|
| 776 | |
|---|
def inject(self, current_source):
    """
    Attach `current_source` to every cell of the Population.

    Raises TypeError when the cell type does not accept injected current.
    """
    if self.celltype.injectable:
        current_source.inject_into(self)
    else:
        raise TypeError("Can't inject current into a spike source.")
|---|
| 784 | |
|---|
def save_positions(self, file):
    """
    Save cell positions to `file`, one row per cell: id x y z.

    `file` may be a filename or a PyNN File object. Only the master node
    (rank 0) writes and closes the file. When a filename is given, the File
    object is created on the master node only, so the other nodes never hold
    a 'w'-mode output file that they neither write nor close.
    """
    # first column should probably be indices, not ids. This would make it
    # simulator independent.
    cells = self.all_cells
    result = numpy.empty((len(cells), 4))
    result[:, 0] = cells
    result[:, 1:4] = self.positions.T
    if rank() == 0:
        if isinstance(file, basestring):
            file = files.StandardTextFile(file, mode='w')
        file.write(result, {'population': self.label})
        file.close()
|---|
| 800 | |
|---|
| 801 | |
|---|
class Population(BasePopulation):
    """
    A group of neurons all of the same type.
    """
    nPop = 0  # number of Populations created so far; used to build default labels

    def __init__(self, size, cellclass, cellparams=None, structure=None,
                 label=None):
        """
        Create a population of neurons all of the same type.

        size - number of cells in the Population. For backwards-compatibility,
               `size` may also be a tuple giving the dimensions of a grid,
               e.g. size=(10,10) is equivalent to size=100 with
               structure=Grid2D()
        cellclass - the cell type class; it is called as
               cellclass(cellparams) to build self.celltype (typically a
               class inheriting from common.standardmodels.StandardCellType)
        cellparams - a dict which is passed to the neuron model constructor
        structure - a Structure instance (defaults to space.Line())
        label - an optional name for the population
        """
        import functools  # `reduce` is not a builtin on Python 3

        if not isinstance(size, int):  # also allow a single integer, for a 1D population
            assert isinstance(size, tuple), "`size` must be an integer or a tuple of ints. You have supplied a %s" % type(size)
            # check the things inside are ints
            for e in size:
                assert isinstance(e, int), "`size` must be an integer or a tuple of ints. Element '%s' is not an int" % str(e)

            assert structure is None, "If you specify `size` as a tuple you may not specify structure."
            if len(size) == 1:
                structure = space.Line()
            elif len(size) == 2:
                nx, ny = size
                structure = space.Grid2D(nx/float(ny))
            elif len(size) == 3:
                nx, ny, nz = size
                structure = space.Grid3D(nx/float(ny), nx/float(nz))
            else:
                raise Exception("A maximum of 3 dimensions is allowed. What do you think this is, string theory?")
            size = functools.reduce(operator.mul, size)
        self.size = size
        self.label = label or 'population%d' % Population.nPop
        self.celltype = cellclass(cellparams)
        self._structure = structure or space.Line()
        self._positions = None
        # Build the arrays of cell ids
        # Cells on the local node are represented as ID objects, other cells by integers
        # All are stored in a single numpy array for easy lookup by address
        # The local cells are also stored in a list, for easy iteration
        self._create_cells(cellclass, cellparams, size)
        self.initial_values = {}
        for variable, value in self.celltype.default_initial_values.items():
            self.initialize(variable, value)
        self.recorders = {'spikes': self.recorder_class('spikes', population=self),
                          'v': self.recorder_class('v', population=self),
                          'gsyn': self.recorder_class('gsyn', population=self)}
        Population.nPop += 1

    @property
    def local_cells(self):
        """The entries of all_cells selected by the local-node mask."""
        return self.all_cells[self._mask_local]

    @property
    def cell(self):
        """Deprecated alias for `all_cells` (emits a warning on every access)."""
        # implicit string concatenation keeps source indentation out of the message
        warn("The `Population.cell` attribute is not an official part of the "
             "API, and its use is deprecated. It will be removed in a future "
             "release. All uses of `cell` may be replaced by `all_cells`")
        return self.all_cells

    def id_to_index(self, id):
        """
        Given the ID(s) of cell(s) in the Population, return its (their) index
        (order in the Population).

        A scalar id gives an int; an iterable of ids (or a PopulationView)
        gives an int array, which is empty for empty input. Raises ValueError
        for ids outside [first_id, last_id].
        >>> assert p.id_to_index(p[5]) == 5
        >>> assert p.id_to_index(p.index([1,2,3])) == [1,2,3]
        """
        if not numpy.iterable(id):
            if not self.first_id <= id <= self.last_id:
                raise ValueError("id should be in the range [%d,%d], actually %d" % (self.first_id, self.last_id, id))
            return int(id - self.first_id)  # this assumes ids are consecutive
        else:
            if isinstance(id, PopulationView):
                id = id.all_cells
            id = numpy.array(id)
            if id.size == 0:
                # min()/max() of an empty array would raise a confusing ValueError
                return numpy.array([], dtype=int)
            if (self.first_id > id.min()) or (self.last_id < id.max()):
                raise ValueError("ids should be in the range [%d,%d], actually [%d, %d]" % (self.first_id, self.last_id, id.min(), id.max()))
            return (id - self.first_id).astype(int)  # this assumes ids are consecutive

    def id_to_local_index(self, id):
        """
        Return the position of `id` among this node's local cells.

        With more than one process this searches local_cells linearly; with a
        single process every cell is local, so it equals id_to_index(id).
        """
        if num_processes() > 1:
            return self.local_cells.tolist().index(id)  # probably very slow
        else:
            return self.id_to_index(id)

    def _get_structure(self):
        return self._structure

    def _set_structure(self, structure):
        assert isinstance(structure, space.BaseStructure)
        if structure != self._structure:
            self._positions = None  # setting a new structure invalidates previously calculated positions
            self._structure = structure
    structure = property(fget=_get_structure, fset=_set_structure)
    # arguably structure should be read-only, i.e. it is not possible to change it after Population creation

    @property
    def position_generator(self):
        """A function mapping a cell index i to column i of self.positions."""
        def gen(i):
            return self.positions[:, i]
        return gen

    def _get_positions(self):
        """
        Try to return self._positions. If it does not exist, create it and then
        return it.
        """
        if self._positions is None:
            self._positions = self.structure.generate_positions(self.size)
        assert self._positions.shape == (3, self.size)
        return self._positions

    def _set_positions(self, pos_array):
        assert isinstance(pos_array, numpy.ndarray)
        assert pos_array.shape == (3, self.size), "%s != %s" % (pos_array.shape, (3, self.size))
        self._positions = pos_array.copy()  # take a copy in case pos_array is changed later
        self._structure = None  # explicitly setting positions destroys any previous structure

    positions = property(_get_positions, _set_positions,
                         """A 3xN array (where N is the number of neurons in the Population)
                         giving the x,y,z coordinates of all the neurons (soma, in the
                         case of non-point models).""")

    def describe(self, template='population_default.txt', engine='default'):
        """
        Returns a human-readable description of the population.

        The output may be customized by specifying a different template
        together with an associated template engine (see ``pyNN.descriptions``).

        If template is None, then a dictionary containing the template context
        will be returned.
        """
        context = {
            "label": self.label,
            "celltype": self.celltype.describe(template=None),
            "structure": None,
            "size": self.size,
            "size_local": len(self.local_cells),
            "first_id": self.first_id,
            "last_id": self.last_id,
        }
        if len(self.local_cells) > 0:
            first_id = self.local_cells[0]
            context.update({
                "local_first_id": first_id,
                "cell_parameters": first_id.get_parameters(),
            })
        if self.structure:
            context["structure"] = self.structure.describe(template=None)
        return descriptions.render(engine, template, context)
|---|
| 962 | |
|---|
| 963 | |
|---|
class PopulationView(BasePopulation):
    """
    A subset of the cells of a parent Population.

    The subset is chosen by `selector`: a slice, a list/array of integer
    indices, or a boolean array with one entry per cell of the parent.
    Duplicated indices are dropped with a logged warning.
    """

    def __init__(self, parent, selector, label=None):
        self.parent = parent
        self.mask = selector  # later we can have fancier selectors, for now we just have numpy masks
        self.label = label or "view of %s with mask %s" % (parent.label, self.mask)
        # maybe just redefine __getattr__ instead of the following...
        self.celltype = self.parent.celltype
        # If the mask is a slice, IDs will be consecutive without duplication.
        # If not, then we need to remove duplicated IDs
        if not isinstance(self.mask, slice):
            if isinstance(self.mask, list):
                self.mask = numpy.array(self.mask)
            # compare dtypes by equality, not identity
            if self.mask.dtype == numpy.dtype('bool'):
                if len(self.mask) != len(self.parent):
                    raise Exception("Boolean masks should have the size of Parent Population")
                # turn the boolean mask into the equivalent integer indices
                self.mask = numpy.arange(len(self.parent))[self.mask]
            if len(numpy.unique(self.mask)) != len(self.mask):
                logging.warning("PopulationView can contain only once each ID, duplicated IDs are removed")
                # note: numpy.unique also sorts the indices
                self.mask = numpy.unique(self.mask)
        self.all_cells = self.parent.all_cells[self.mask]  # do we need to ensure this is ordered?
        self.size = len(self.all_cells)
        self._mask_local = self.parent._mask_local[self.mask]
        self.local_cells = self.all_cells[self._mask_local]
        # all_cells need not be sorted, hence min/max rather than first/last element
        self.first_id = numpy.min(self.all_cells)
        self.last_id = numpy.max(self.all_cells)
        self.recorders = self.parent.recorders
        # the recorders are shared with the parent; restrict them to this view's cells
        self.record_filter = self.all_cells

    @property
    def initial_values(self):
        # this is going to be complex - if we keep initial_values as a dict,
        # need to return a dict-like object that takes account of self.mask
        raise NotImplementedError

    @property
    def structure(self):
        """The parent's structure."""
        return self.parent.structure
    # should we allow setting structure for a PopulationView? Maybe if the
    # parent has some kind of CompositeStructure?

    @property
    def positions(self):
        """The parent's 3xN positions restricted to the cells of this view."""
        return self.parent.positions.T[self.mask].T  # make positions N,3 instead of 3,N to avoid all this transposing?

    def id_to_index(self, id):
        """
        Given the ID(s) of cell(s) in the PopulationView, return its/their
        index/indices (order in the PopulationView).

        A scalar id gives an int; an iterable of ids (or a PopulationView)
        gives an int array. Raises IndexError for an ID not in the view.
        >>> assert pv.id_to_index(pv.all_cells[5]) == 5
        >>> assert list(pv.id_to_index(pv.all_cells[[1,2,3]])) == [1,2,3]
        """
        all_cells = self.all_cells

        def _locate(single_id):
            # position of one ID within all_cells
            matches = numpy.where(all_cells == single_id)[0]
            if len(matches) == 0:
                raise IndexError("ID %s not present in the View" % single_id)
            elif len(matches) > 1:
                raise Exception("ID %s is duplicated in the View" % single_id)
            return int(matches[0])

        if not numpy.iterable(id):
            return _locate(id)
        if isinstance(id, PopulationView):
            id = id.all_cells  # iterating a view directly may only visit local cells
        # integer dtype so the result can be used directly as an index array
        return numpy.array([_locate(item) for item in id], dtype=int)

    def describe(self, template='populationview_default.txt', engine='default'):
        """
        Returns a human-readable description of the population view.

        The output may be customized by specifying a different template
        together with an associated template engine (see ``pyNN.descriptions``).

        If template is None, then a dictionary containing the template context
        will be returned.
        """
        context = {"label": self.label,
                   "parent": self.parent.label,
                   "mask": self.mask,
                   "size": self.size}
        return descriptions.render(engine, template, context)
|---|
| 1051 | |
|---|
| 1052 | |
|---|
| 1053 | # ============================================================================= |
|---|
| 1054 | |
|---|
| 1055 | class Assembly(object): |
|---|
| 1056 | """ |
|---|
| 1057 | A group of neurons, may be heterogeneous, in contrast to a Population where |
|---|
| 1058 | all the neurons are of the same type. |
|---|
| 1059 | """ |
|---|
| 1060 | count = 0 |
|---|
| 1061 | |
|---|
def __init__(self, *populations, **kwargs):
    """
    Create an Assembly from zero or more Population/PopulationView objects.

    Each argument is added via _insert(). The only accepted keyword argument
    is `label`; any other keyword raises TypeError (an explicit check, so it
    is not stripped under `python -O`). Without a label, 'assembly<N>' is
    used, where N is the running Assembly.count.
    """
    unexpected = sorted(key for key in kwargs if key != 'label')
    if unexpected:
        raise TypeError("Assembly() got unexpected keyword argument(s): %s"
                        % ", ".join(unexpected))
    self.populations = []
    for p in populations:
        self._insert(p)
    self.label = kwargs.get('label', 'assembly%d' % Assembly.count)
    assert isinstance(self.label, basestring), "label must be a string or unicode"
    Assembly.count += 1
|---|
| 1071 | |
|---|
def _insert(self, element):
    """
    Append `element` to self.populations unless that would duplicate cells.

    Raises TypeError if `element` is not a BasePopulation. The element is
    skipped, with a logged warning, when:
      - it is a PopulationView whose parent is already in the Assembly;
      - it is a Population already in the Assembly;
      - it shares cells with any element already in the Assembly. This check
        applies to Populations as well as PopulationViews, so adding a
        Population after a view of it cannot create duplicated IDs.
    """
    if not isinstance(element, BasePopulation):
        raise TypeError("argument is a %s, not a Population." % type(element).__name__)
    is_view = isinstance(element, PopulationView)
    if is_view and element.parent in self.populations:
        logging.warning('Adding a PopulationView to an Assembly when parent Population is there is not possible')
        return
    if not is_view and element in self.populations:
        logging.warning('Adding a Population twice in an Assembly is not possible')
        return
    new_cells = element.all_cells
    for p in self.populations:
        combined = numpy.concatenate((p.all_cells, new_cells))
        if len(numpy.unique(combined)) != len(p.all_cells) + len(new_cells):
            # Should we automatically remove duplicated IDs?
            logging.warning('Adding a %s to an Assembly containing elements already present is not possible'
                            % ('PopulationView' if is_view else 'Population'))
            return
    self.populations.append(element)
|---|
| 1093 | |
|---|
@property
def local_cells(self):
    """
    Local cells of every member population, concatenated in assembly order.

    Concatenates once (not pairwise in a loop), and returns an empty int
    array for an Assembly without members instead of raising IndexError.
    """
    parts = [p.local_cells for p in self.populations]
    if not parts:
        return numpy.array([], dtype=int)
    return numpy.concatenate(parts)
|---|
| 1100 | |
|---|
@property
def all_cells(self):
    """
    All cells (all nodes) of every member population, concatenated in
    assembly order.

    Concatenates once, and returns an empty int array for an Assembly
    without members instead of raising IndexError.
    """
    parts = [p.all_cells for p in self.populations]
    if not parts:
        return numpy.array([], dtype=int)
    return numpy.concatenate(parts)
|---|
| 1107 | |
|---|
@property
def _mask_local(self):
    """
    Local-node masks of every member population, concatenated in assembly
    order (aligned with all_cells).

    Concatenates once, and returns an empty bool array for an Assembly
    without members instead of raising IndexError.
    """
    parts = [p._mask_local for p in self.populations]
    if not parts:
        return numpy.array([], dtype=bool)
    return numpy.concatenate(parts)
|---|
| 1114 | |
|---|
@property
def first_id(self):
    """Smallest cell ID found in the Assembly."""
    cells = self.all_cells
    return numpy.min(cells)
|---|
| 1118 | |
|---|
@property
def last_id(self):
    """Largest cell ID found in the Assembly."""
    cells = self.all_cells
    return numpy.max(cells)
|---|
| 1122 | |
|---|
def id_to_index(self, id):
    """
    Given the ID(s) of cell(s) in the Assembly, return its (their) index
    (order in the Assembly).

    A scalar id gives an int; an iterable of ids gives an int array, so the
    result can be used directly for indexing. Raises IndexError for an ID
    that is not in the Assembly.
    >>> assert a.id_to_index(a.all_cells[5]) == 5
    >>> assert list(a.id_to_index(a.all_cells[[1,2,3]])) == [1,2,3]
    """
    all_cells = self.all_cells

    def _locate(single_id):
        # position of one ID within all_cells
        matches = numpy.where(all_cells == single_id)[0]
        if len(matches) == 0:
            raise IndexError("ID %s not present in the Assembly" % single_id)
        elif len(matches) > 1:
            raise Exception("ID %s is duplicated in the Assembly" % single_id)
        return int(matches[0])

    if not numpy.iterable(id):
        return _locate(id)
    return numpy.array([_locate(item) for item in id], dtype=int)
|---|
| 1150 | |
|---|
@property
def positions(self):
    """
    3xN array of x,y,z coordinates of all cells, with the members' position
    arrays stacked column-wise in assembly order.

    Stacks once, and returns an empty (3, 0) array for an Assembly without
    members instead of raising IndexError.
    """
    parts = [p.positions for p in self.populations]
    if not parts:
        return numpy.empty((3, 0))
    return numpy.hstack(parts)
|---|
| 1157 | |
|---|
@property
def size(self):
    """Total number of cells summed over all member populations."""
    total = 0
    for member in self.populations:
        total += member.size
    return total
|---|
| 1161 | |
|---|
def __iter__(self):
    """
    Iterate over the cells of every member population, in assembly order.

    chain(gen) would yield each population's iterator object rather than the
    cells; chain.from_iterable flattens one level, as intended.
    """
    return chain.from_iterable(self.populations)
|---|
| 1164 | |
|---|
def __len__(self):
    """Number of cells in the Assembly, counted over all nodes."""
    n_cells = self.size
    return n_cells
|---|
| 1168 | |
|---|
def __getitem__(self, index):
    """
    Select member populations by position.

    An integer (Python or numpy) returns that member population. A slice, a
    list of integers, or an integer/boolean numpy array returns a new
    Assembly of the selected members. Raises TypeError for other index types.
    """
    if isinstance(index, (int, numpy.integer)):
        return self.populations[index]
    elif isinstance(index, slice):
        return Assembly(*self.populations[index])
    elif isinstance(index, (list, numpy.ndarray)):
        # a Python list cannot itself be indexed by a list/array, so select explicitly
        positions = numpy.asarray(index)
        if positions.dtype == bool:
            positions = numpy.nonzero(positions)[0]
        return Assembly(*[self.populations[i] for i in positions])
    else:
        raise TypeError("indices must be integers, slices, lists, arrays, not %s" % type(index).__name__)
|---|
| 1176 | |
|---|
def __add__(self, other):
    """
    Return a new Assembly made of this Assembly's members followed by
    `other` (a Population/PopulationView, or the members of another
    Assembly). Raises TypeError for anything else.
    """
    if isinstance(other, BasePopulation):
        extra = [other]
    elif isinstance(other, Assembly):
        extra = other.populations
    else:
        raise TypeError("can only add a Population or another Assembly to an Assembly")
    combined = self.populations + extra
    return Assembly(*combined)
|---|
| 1184 | |
|---|
def __iadd__(self, other):
    """
    Add `other` (a Population/PopulationView, or every member of another
    Assembly) to this Assembly in place via _insert(), and return self.
    Raises TypeError for anything else.
    """
    if isinstance(other, BasePopulation):
        additions = [other]
    elif isinstance(other, Assembly):
        additions = other.populations
    else:
        raise TypeError("can only add a Population or another Assembly to an Assembly")
    for item in additions:
        self._insert(item)
    return self
|---|
| 1194 | |
|---|
def initialize(self, variable, value):
    """Forward initialize(variable, value) to each member population in turn."""
    for member in self.populations:
        member.initialize(variable, value)
|---|
| 1198 | |
|---|
def _record(self, variable, to_file=True):
    """Forward _record(variable, to_file) to each member population in turn."""
    # need to think about record_from
    for member in self.populations:
        member._record(variable, to_file)
|---|
| 1203 | |
|---|
def record(self, to_file=True):
    """Record spikes from every member population."""
    quantity = 'spikes'
    self._record(quantity, to_file)
|---|
| 1206 | |
|---|
def record_v(self, to_file=True):
    """Record the membrane potential of every member population."""
    quantity = 'v'
    self._record(quantity, to_file)
|---|
| 1209 | |
|---|
def record_gsyn(self, to_file=True):
    """Record the synaptic conductances of every member population."""
    quantity = 'gsyn'
    self._record(quantity, to_file)
|---|
| 1212 | |
|---|
def get_population(self, label):
    """
    Return the first member population whose label equals `label`.

    Raises KeyError if no member has that label.
    """
    missing = object()
    found = next((member for member in self.populations if label == member.label), missing)
    if found is missing:
        raise KeyError("Assembly does not contain a population with the label %s" % label)
    return found
|---|
| 1218 | |
|---|
def save_positions(self, file):
    """
    Save the positions of all cells in the Assembly to `file`, one row per
    cell: id x y z.

    `file` may be a filename or a PyNN File object. Only the master node
    (rank 0) writes and closes the file. When a filename is given, the File
    object is created on the master node only, so the other nodes never hold
    a 'w'-mode output file that they neither write nor close.
    """
    cells = self.all_cells
    result = numpy.empty((len(cells), 4))
    result[:, 0] = cells
    result[:, 1:4] = self.positions.T
    if rank() == 0:
        if isinstance(file, basestring):
            file = files.StandardTextFile(file, mode='w')
        file.write(result, {'assembly': self.label})
        file.close()
|---|
| 1233 | |
|---|
@property
def position_generator(self):
    """A function that maps cell index i to column i (x, y, z) of self.positions."""
    return lambda i: self.positions[:, i]
|---|
| 1239 | |
|---|
def meanSpikeCount(self, gather=True):
    """
    Return the mean number of spikes per recorded neuron over all members.

    Members with nothing recorded (errors.NothingToWriteError) are skipped.
    Counts are merged into a fresh dict, so the dicts returned by the
    recorders are not mutated. numpy.nan is returned when there are no
    counts (rather than raising ZeroDivisionError), and on non-master nodes
    when gather is True.
    """
    spike_counts = {}
    for p in self.populations:
        try:
            spike_counts.update(p.recorders['spikes'].count(gather, p.record_filter))
        except errors.NothingToWriteError:
            pass
    total_spikes = sum(spike_counts.values())
    if not gather or rank() == 0:  # should maybe use allgather, and get the numbers on all nodes
        if spike_counts:
            return float(total_spikes) / len(spike_counts)
        return numpy.nan
    else:
        return numpy.nan
|---|
| 1258 | |
|---|
def get_v(self, gather=True, compatible_output=True):
    """
    Return the recorded membrane-potential data of all members, stacked
    row-wise in assembly order (column layout as produced by each
    population's 'v' recorder).

    Members with nothing recorded (errors.NothingToWriteError) are skipped.
    If no member has data, including an Assembly with no members, an empty
    (0, 3) array is returned.
    """
    chunks = []
    for p in self.populations:
        try:
            chunks.append(p.recorders['v'].get(gather, compatible_output, p.record_filter))
        except errors.NothingToWriteError:
            pass
    if not chunks:
        return numpy.zeros((0, 3))
    if len(chunks) == 1:
        return chunks[0]
    return numpy.vstack(chunks)
|---|
| 1274 | |
|---|
def get_gsyn(self, gather=True, compatible_output=True):
    """
    Return the recorded synaptic-conductance data of all members, stacked
    row-wise in assembly order (column layout as produced by each
    population's 'gsyn' recorder).

    Members with nothing recorded (errors.NothingToWriteError) are skipped.
    If no member has data, including an Assembly with no members, an empty
    (0, 4) array is returned.
    """
    chunks = []
    for p in self.populations:
        try:
            chunks.append(p.recorders['gsyn'].get(gather, compatible_output, p.record_filter))
        except errors.NothingToWriteError:
            pass
    if not chunks:
        return numpy.zeros((0, 4))
    if len(chunks) == 1:
        return chunks[0]
    return numpy.vstack(chunks)
|---|
| 1290 | |
|---|
def get_spike_counts(self, gather=True):
    """
    Return a dict of spike counts merged over all member populations.

    Members with nothing recorded (errors.NothingToWriteError) are skipped.
    Counts are merged into a fresh dict, so the recorders' own dicts are not
    mutated. An Assembly with no members gives an empty dict.
    """
    spike_counts = {}
    for p in self.populations:
        try:
            spike_counts.update(p.recorders['spikes'].count(gather, p.record_filter))
        except errors.NothingToWriteError:
            pass
    return spike_counts
|---|
| 1305 | |
|---|
def _print(self, file, variable, format, gather=True, compatible_output=True):
    """
    Write the data recorded for `variable` by every population in the
    assembly to a single file, so the output is consistent with that of a
    Population.

    file - a filename or a PyNN File object. Only the rank-0 node writes
           (and closes) it.
    variable - name of the recorder to dump, e.g. 'spikes', 'v' or 'gsyn'.
    format - shape of the empty array the merged data is stacked onto,
             e.g. (0, 2); its column count must match the rows read back
             from the per-population files.
    gather, compatible_output - forwarded to each population's recorder.
             When compatible_output is True, the last column of each
             population's data (presumably ids relative to that population's
             first_id -- TODO confirm) is converted to indices within the
             assembly via self.id_to_index().
    """
    import shutil  # only needed to clean up the temporary directory

    # First, each population's data is written to its own temporary file as
    # a NumPy binary object (to speed things up). Files are named by position
    # rather than by label so that populations sharing a label cannot
    # overwrite each other's data.
    tempdir = tempfile.mkdtemp()
    try:
        dumped = []  # (population, filename, has_data), in assembly order
        # self[0] is paired with self.populations[1:], i.e. it is presumably
        # the first population of the assembly.
        for position, pop in enumerate([self[0]] + list(self.populations[1:])):
            filename = os.path.join(tempdir, '%d.%s' % (position, variable))
            p_file = files.NumpyBinaryFile(filename, mode='w')
            try:
                pop.recorders[variable].write(p_file, gather, compatible_output, pop.record_filter)
            except errors.NothingToWriteError:
                has_data = False
            else:
                has_data = True
            dumped.append((pop, filename, has_data))

        # Then the files are merged into a single array, to be consistent
        # with a Population object. Note that the header should be better
        # considered.
        metadata = {'variable': variable,
                    'size': self.size,
                    'label': self.label,
                    'populations': ", ".join(["%s[%d-%d]" % (p.label, p.first_id, p.last_id) for p in self.populations]),
                    'first_id': self.first_id,
                    'last_id': self.last_id}
        metadata['dt'] = simulator.state.dt  # note that this has to run on all nodes (at least for NEST)

        # Merge in population order so the output row order is deterministic.
        chunks = [numpy.zeros(format)]
        for pop, filename, has_data in dumped:
            if not has_data:
                continue
            tmp_data = files.NumpyBinaryFile(filename, mode='r').read()
            if compatible_output:
                tmp_data[:, -1] = self.id_to_index(tmp_data[:, -1] + pop.first_id)
            chunks.append(tmp_data)
        data = numpy.vstack(chunks)
        metadata['n'] = data.shape[0]
    finally:
        # Remove every temporary file (including any left behind for
        # populations with nothing to write) and the directory itself, even
        # if an error occurred above.
        shutil.rmtree(tempdir, ignore_errors=True)

    if isinstance(file, basestring):
        file = files.StandardTextFile(file, mode='w')

    if rank() == 0:
        file.write(data, metadata)
        file.close()
|---|
| 1358 | |
|---|
| 1359 | |
|---|
def printSpikes(self, file, gather=True, compatible_output=True):
    """
    Write the spike times recorded from all populations of the assembly to
    `file`, which may be a filename or a PyNN File object.

    With compatible_output=True each line reads "spiketime cell_id", where
    cell_id is the cell's index within the assembly (counting along rows and
    down columns, and the extension of that for 3-D), which makes it easy to
    draw a raster plot with one line per cell. A header of lines starting
    with '#' records metadata such as the timestep and the first and last
    ids.

    With compatible_output=False the raw format produced by the simulator is
    kept, which can be faster because the spike files are not
    post-processed.

    `gather` is forwarded to each population's recorder; for parallel
    simulators, True gathers all data to the master node. The merged file
    itself is only written on the rank-0 node.
    """
    empty_shape = (0, 2)  # zero rows, two columns: spike time and cell id
    self._print(file, 'spikes', empty_shape,
                gather=gather, compatible_output=compatible_output)
|---|
| 1384 | |
|---|
def print_v(self, file, gather=True, compatible_output=True):
    """
    Write the membrane-potential traces recorded from all populations of the
    assembly to `file`, which may be a filename or a PyNN File object.

    With compatible_output=True each line reads "v cell_id", where cell_id
    is the cell's index within the assembly (counting along rows and down
    columns, and the extension of that for 3-D). A header of lines starting
    with '#' records metadata such as the timestep and the first and last
    ids.

    With compatible_output=False the raw format produced by the simulator is
    kept, which can be faster because the voltage files are not
    post-processed.

    `gather` is forwarded to each population's recorder; for parallel
    simulators, True gathers all data to the master node. The merged file
    itself is only written on the rank-0 node.
    """
    empty_shape = (0, 2)  # zero rows, two columns: v and cell id
    self._print(file, 'v', empty_shape,
                gather=gather, compatible_output=compatible_output)
|---|
| 1407 | |
|---|
def print_gsyn(self, file, gather=True, compatible_output=True):
    """
    Write the synaptic-conductance traces recorded from all populations of
    the assembly to `file`, which may be a filename or a PyNN File object.

    With compatible_output=True each line reads "t g cell_id", where cell_id
    is the cell's index within the assembly (counting along rows and down
    columns, and the extension of that for 3-D). A header of lines starting
    with '#' records metadata such as the timestep and the first and last
    ids.

    With compatible_output=False the raw format produced by the simulator is
    kept, which can be faster because the conductance files are not
    post-processed.

    `gather` is forwarded to each population's recorder; for parallel
    simulators, True gathers all data to the master node. The merged file
    itself is only written on the rank-0 node.
    """
    empty_shape = (0, 3)  # zero rows, three columns
    self._print(file, 'gsyn', empty_shape,
                gather=gather, compatible_output=compatible_output)
|---|
| 1425 | |
|---|
def inject(self, current_source):
    """
    Connect `current_source` to all cells in the Assembly by injecting it
    into each of the assembly's populations in turn.
    """
    members = self.populations
    for member in members:
        current_source.inject_into(member)
|---|
| 1432 | |
|---|
def describe(self, template='assembly_default.txt', engine='default'):
    """
    Return a human-readable description of the assembly.

    The description is produced by rendering `template` with the given
    template `engine` (see ``pyNN.descriptions``). The template context
    holds the assembly's label and the context dict of each population
    (obtained with template=None). If template is None, that context
    dictionary itself is returned.
    """
    context = dict(label=self.label)
    context["populations"] = [member.describe(template=None)
                              for member in self.populations]
    return descriptions.render(engine, template, context)
|---|
| 1446 | |
|---|
| 1447 | # ============================================================================= |
|---|
| 1448 | |
|---|
| 1449 | |
|---|
class Projection(object):
    """
    A container for all the connections of a given type (same synapse type and
    plasticity mechanisms) between two populations, together with methods to
    set parameters of those connections, including of plasticity mechanisms.

    Several attributes used below are not set in __init__ and are presumably
    provided by simulator-specific subclasses: `connection_manager` (used for
    len(), indexing and getting/setting connection attributes),
    `connections` (iterated by saveConnections()) and `synapse_type` (used by
    setWeights()).
    """

    def __init__(self, presynaptic_neurons, postsynaptic_neurons, method,
                 source=None, target=None, synapse_dynamics=None,
                 label=None, rng=None):
        """
        presynaptic_neurons and postsynaptic_neurons - Population, PopulationView
                           or Assembly objects.

        source - string specifying which attribute of the presynaptic cell
                 signals action potentials. This is only needed for
                 multicompartmental cells with branching axons or
                 dendrodendritic synapses. All standard cells have a single
                 source, and this is the default.

        target - string specifying which synapse on the postsynaptic cell to
                 connect to. For standard cells, this can be 'excitatory' or
                 'inhibitory'. For non-standard cells, it could be 'NMDA', etc.
                 If target is not given, the default value of 'excitatory' is
                 used.

        method - a Connector object, encapsulating the algorithm to use for
                 connecting the neurons.

        synapse_dynamics - a `standardmodels.SynapseDynamics` object specifying
                 which synaptic plasticity mechanisms to use.

        label - name of the projection; if None and both populations have
                labels, it defaults to "<pre label>→<post label>".

        rng - specify an RNG object to be used by the Connector. If None, a
              NumpyRNG with a fixed seed is used.

        Raises errors.ConnectionError if either side is not a Population,
        PopulationView or Assembly, and TypeError if rng is neither None nor
        an AbstractRNG.
        """
        for prefix, pop in zip(("pre", "post"),
                               (presynaptic_neurons, postsynaptic_neurons)):
            if not isinstance(pop, (BasePopulation, Assembly)):
                raise errors.ConnectionError("%ssynaptic_neurons must be a Population, PopulationView or Assembly, not a %s" % (prefix, type(pop)))
        self.pre = presynaptic_neurons    # } these really
        self.source = source              # } should be
        self.post = postsynaptic_neurons  # } read-only
        self.target = target              # }
        self.label = label
        if isinstance(rng, random.AbstractRNG):
            self.rng = rng
        elif rng is None:
            # fixed seed, so connectivity is reproducible when no RNG is given
            self.rng = random.NumpyRNG(seed=151985012)
        else:
            # TypeError is a subclass of Exception, so existing handlers still catch it
            raise TypeError("rng must be either None, or a subclass of pyNN.random.AbstractRNG")
        self._method = method
        self.synapse_dynamics = synapse_dynamics
        self.weights = []
        if label is None:
            if self.pre.label and self.post.label:
                self.label = "%s→%s" % (self.pre.label, self.post.label)
        if self.synapse_dynamics:
            assert isinstance(self.synapse_dynamics, standardmodels.SynapseDynamics), \
                "The synapse_dynamics argument, if specified, must be a standardmodels.SynapseDynamics object, not a %s" % type(synapse_dynamics)

    def __len__(self):
        """Return the total number of local connections."""
        return len(self.connection_manager)

    def size(self, gather=True):
        """
        Return the total number of connections.
            - only local connections, if gather is False,
            - all connections, if gather is True (default), summed over all
              MPI processes with recording.mpi_sum().
        """
        n_local = len(self)
        if gather:
            return recording.mpi_sum(n_local)
        return n_local

    def __repr__(self):
        return 'Projection("%s")' % self.label

    def __getitem__(self, i):
        """Return the i-th local connection, as stored by connection_manager."""
        return self.connection_manager[i]

    # --- Methods for setting connection parameters ---------------------------

    def setWeights(self, w):
        """
        w can be a single number, in which case all weights are set to this
        value, or a list/1D array of length equal to the number of connections
        in the projection, or a 2D array with the same dimensions as the
        connectivity matrix (as returned by `getWeights(format='array')`).
        Weights should be in nA for current-based and µS for conductance-based
        synapses. The values are validated with check_weight() first.
        """
        # should perhaps add a "distribute" argument, for symmetry with "gather" in getWeights()
        # if post is an Assembly, some components might have cond-synapses, others curr, so need a more sophisticated check here
        w = check_weight(w, self.synapse_type, is_conductance(self.post.local_cells[0]))
        self.connection_manager.set('weight', w)

    def randomizeWeights(self, rand_distr):
        """
        Set weights to random values taken from rand_distr (one draw per
        local connection).
        """
        # Arguably, we could merge this with set_weights just by detecting the
        # argument type. It could make for easier-to-read simulation code to
        # give it a separate name, though. Comments?
        self.setWeights(rand_distr.next(len(self)))

    def setDelays(self, d):
        """
        d can be a single number, in which case all delays are set to this
        value, or a list/1D array of length equal to the number of connections
        in the projection, or a 2D array with the same dimensions as the
        connectivity matrix (as returned by `getDelays(format='array')`).
        """
        self.connection_manager.set('delay', d)

    def randomizeDelays(self, rand_distr):
        """
        Set delays to random values taken from rand_distr (one draw per local
        connection).
        """
        self.setDelays(rand_distr.next(len(self)))

    def setSynapseDynamics(self, param, value):
        """
        Set parameters of the dynamic synapses for all connections in this
        projection.
        """
        self.connection_manager.set(param, value)

    def randomizeSynapseDynamics(self, param, rand_distr):
        """
        Set parameters of the synapse dynamics to values taken from rand_distr
        (one draw per local connection).
        """
        self.setSynapseDynamics(param, rand_distr.next(len(self)))

    # --- Methods for writing/reading information to/from file. ---------------

    def getWeights(self, format='list', gather=True):
        """
        Get synaptic weights for all connections in this Projection.

        Possible formats are: a list of length equal to the number of connections
        in the projection, a 2D weight array (with NaN for non-existent
        connections). Note that for the array format, if there is more than
        one connection between two cells, the summed weight will be given.

        gather=True is not implemented yet: an error is logged and only the
        local connections are returned.
        """
        if gather:
            logger.error("getWeights() with gather=True not yet implemented")
        return self.connection_manager.get('weight', format)

    def getDelays(self, format='list', gather=True):
        """
        Get synaptic delays for all connections in this Projection.

        Possible formats are: a list of length equal to the number of connections
        in the projection, a 2D delay array (with NaN for non-existent
        connections).

        gather=True is not implemented yet: an error is logged and only the
        local connections are returned.
        """
        if gather:
            logger.error("getDelays() with gather=True not yet implemented")
        return self.connection_manager.get('delay', format)

    def getSynapseDynamics(self, parameter_name, format='list', gather=True):
        """
        Get parameters of the dynamic synapses for all connections in this
        Projection.

        gather=True is not implemented yet: an error is logged and only the
        local connections are returned.
        """
        if gather:
            logger.error("getSynapseDynamics() with gather=True not yet implemented")
        return self.connection_manager.get(parameter_name, format)

    def saveConnections(self, file, gather=True, compatible_output=True):
        """
        Save connections to file in a format suitable for reading in with a
        FromFileConnector.

        file - a filename or a PyNN File object.
        compatible_output - if True, source and target are written as indices
                            within the pre/post populations; otherwise the raw
                            cell ids are written.
        gather - with several MPI processes: if True, all lines are gathered
                 and written by rank 0 only; if False, each rank writes its own
                 local connections to "<file name>.<rank>".
        """
        if isinstance(file, basestring):
            file = files.StandardTextFile(file, mode='w')

        if compatible_output:
            lines = [[self.pre.id_to_index(c.source), self.post.id_to_index(c.target), c.weight, c.delay]
                     for c in self.connections]
        else:
            lines = [[c.source, c.target, c.weight, c.delay] for c in self.connections]

        if gather and num_processes() > 1:
            all_lines = recording.gather_dict({rank(): lines})
            if rank() == 0:
                # flatten in a single pass instead of repeated list addition
                lines = [line for chunk in all_lines.values() for line in chunk]
        elif num_processes() > 1:
            file.rename('%s.%d' % (file.name, rank()))

        logger.debug("--- Projection[%s].__saveConnections__() ---", self.label)

        if not gather or rank() == 0:
            file.write(lines, {'pre': self.pre.label, 'post': self.post.label})
            file.close()

    def printWeights(self, file, format='list', gather=True):
        """
        Print synaptic weights to file (a filename or a PyNN File object). In
        the array format, zeros are printed for non-existent connections.
        """
        weights = self.getWeights(format=format, gather=gather)

        if isinstance(file, basestring):
            file = files.StandardTextFile(file, mode='w')

        if format == 'array':
            weights = numpy.where(numpy.isnan(weights), 0.0, weights)
        file.write(weights, {})
        file.close()

    def weightHistogram(self, min=None, max=None, nbins=10):
        """
        Return a histogram of synaptic weights, as the (counts, bin_edges)
        pair returned by numpy.histogram, using nbins equal-width bins.
        If min and max are not given, the minimum and maximum weights are
        calculated automatically (which raises ValueError when there are no
        weights).
        """
        # it is arguable whether functions operating on the set of weights
        # should be put here or in an external module.
        # The 'list' format may be a plain Python list, so convert before
        # calling .min()/.max().
        weights = numpy.asarray(self.getWeights(format='list', gather=True))
        if min is None:
            min = weights.min()
        if max is None:
            max = weights.max()
        bins = numpy.linspace(min, max, nbins + 1)
        # The "new" histogram semantics are NumPy's default; the `new`
        # keyword itself has been removed from numpy.histogram.
        return numpy.histogram(weights, bins)  # returns n, bins

    def describe(self, template='projection_default.txt', engine='default'):
        """
        Returns a human-readable description of the projection.

        The output may be customized by specifying a different template
        together with an associated template engine (see ``pyNN.descriptions``).

        If template is None, then a dictionary containing the template context
        will be returned.
        """
        context = {
            "label": self.label,
            "pre": self.pre.describe(template=None),
            "post": self.post.describe(template=None),
            "source": self.source,
            "target": self.target,
            "size_local": len(self),
            "size": self.size(gather=True),
            "connector": self._method.describe(template=None),
            "plasticity": None,
        }
        if self.synapse_dynamics:
            context.update(plasticity=self.synapse_dynamics.describe(template=None))
        return descriptions.render(engine, template, context)
|---|
| 1707 | |
|---|
| 1708 | |
|---|
| 1709 | # ============================================================================= |
|---|