# Source code for pyNN.neuron.populations

# encoding: utf-8
"""
nrnpython implementation of the PyNN API.

:copyright: Copyright 2006-2024 by the PyNN team, see AUTHORS.
:license: CeCILL, see LICENSE for details.

"""

from collections import defaultdict
import logging

import numpy as np

from .. import common
from ..parameters import ArrayParameter, Sequence, ParameterSpace, simplify, LazyArray
from ..standardmodels import StandardCellType
from ..random import RandomDistribution
from . import simulator
from .recording import Recorder
from .random import NativeRNG


logger = logging.getLogger("PyNN")


class PopulationMixin(object):
    """
    Parameter and initial-value access shared by the classes in this module
    that inherit from it.

    The methods expect the host class to be iterable over cell objects that
    expose a `_cell` attribute, and to provide the `celltype`, `local_size`
    and `_mask_local` attributes.
    """

    def _set_parameters(self, parameter_space):
        """
        Set native parameter values on the cells iterated over by `self`.

        `parameter_space` should contain native parameters. It is evaluated
        for the indices at which `self._mask_local` is true, and every
        resulting value is set as an attribute on the matching `cell._cell`.
        """
        parameter_space.evaluate(mask=np.where(self._mask_local)[0])
        for cell, parameters in zip(self, parameter_space):
            for name, val in parameters.items():
                setattr(cell._cell, name, val)

    def _get_parameters(self, *names):
        """
        Return a ParameterSpace containing PyNN parameters

        `names` should be PyNN names. For standard cell types a name may be
        written as "component.parameter"; undotted names given alongside
        dotted ones are attributed to the "neuron" component, and the values
        of every other component are stored under that component's label in
        the returned ParameterSpace.

        Raises ValueError if a name contains more than one ".".
        """
        def _get_component_parameters(component, names, component_label=None):
            if component.computed_parameters_include(names):
                # need all parameters in order to calculate values
                native_names = component.get_native_names()
            else:
                native_names = component.get_native_names(*names)

            native_parameter_space = self._get_native_parameters(*native_names,
                                                                 component_label=component_label)
            ps = component.reverse_translate(native_parameter_space)
            # extract values for this component from any ArrayParameters
            for name, value in ps.items():
                if isinstance(value.base_value, ArrayParameter):
                    index = self.celltype.receptor_types.index(component_label)
                    ps[name] = LazyArray(value.base_value[index])
                    ps[name].operations = value.operations
            return ps

        if isinstance(self.celltype, StandardCellType):
            if any("." in name for name in names):
                # First pass: group every requested name by component.
                names_by_component = defaultdict(list)
                for name in names:
                    parts = name.split(".")
                    if len(parts) == 1:
                        names_by_component["neuron"].append(parts[0])
                    elif len(parts) == 2:
                        names_by_component[parts[0]].append(parts[1])
                    else:
                        raise ValueError("Invalid name: {}".format(name))
                # Second pass, run once after grouping is complete. Building
                # the result inside the loop above would pop the "neuron"
                # entry early and overwrite the result on later iterations,
                # losing previously collected names.
                if "neuron" in names_by_component:
                    parameter_space = _get_component_parameters(
                        self.celltype.neuron,
                        names_by_component.pop("neuron"))
                else:
                    parameter_space = ParameterSpace({})
                for component_label in names_by_component:
                    parameter_space[component_label] = _get_component_parameters(
                            self.celltype.post_synaptic_receptors[component_label],
                            names_by_component[component_label],
                            component_label)
            else:
                parameter_space = _get_component_parameters(self.celltype, names)
        else:
            parameter_space = self._get_native_parameters(*names)
        return parameter_space

    def _get_native_parameters(self, *names, component_label=None):
        """
        return a ParameterSpace containing native parameters

        Values are read with `getattr` from each cell's `_cell` object or,
        when `component_label` is given, from the attribute of that name on
        `_cell`. The returned ParameterSpace has shape `(self.local_size,)`.
        """
        parameter_dict = {}
        for name in names:
            if name == 'spike_times':  # hack
                parameter_dict[name] = [Sequence(getattr(cell._cell, name)) for cell in self]
            else:
                if component_label:
                    val = np.array([getattr(getattr(cell._cell, component_label), name)
                                    for cell in self])
                else:
                    val = np.array([getattr(cell._cell, name)
                                    for cell in self])
                if isinstance(val[0], tuple) or len(val.shape) == 2:
                    # one array-valued parameter per cell: wrap each row so
                    # that it is carried as a single ArrayParameter element
                    val = np.array([ArrayParameter(v) for v in val])
                    val = LazyArray(simplify(val), shape=(self.local_size,), dtype=ArrayParameter)
                    parameter_dict[name] = val
                else:
                    parameter_dict[name] = simplify(val)
        return ParameterSpace(parameter_dict, shape=(self.local_size,))

    def _set_initial_value_array(self, variable_name, initial_values):
        """
        Store initial values for a state variable on each cell.

        `variable_name` is first translated through `celltype.variable_map`
        when the cell type defines one. A translated name of the form
        "mechanism.state" is stored in
        `cell._cell.initial_values[mechanism][state]`; any other name is
        stored in `cell._cell.initial_values[name]`.

        `initial_values` must provide `is_homogeneous`, `evaluate()` and
        indexing by a boolean mask, as used below.
        """
        # todo: support different initial values in different segments
        if hasattr(self.celltype, "variable_map"):
            variable_name = self.celltype.variable_map[variable_name]

        if "." in variable_name:
            mech_name, state_name = variable_name.split(".")
        else:
            mech_name, state_name = None, variable_name

        if initial_values.is_homogeneous:
            value = initial_values.evaluate(simplify=True)
            for cell in self:  # only on local node
                if mech_name:
                    cell._cell.initial_values[mech_name][state_name] = value
                else:
                    cell._cell.initial_values[state_name] = value
        else:
            if (
                isinstance(initial_values.base_value, RandomDistribution)
                and initial_values.base_value.rng.parallel_safe
            ):
                # evaluate the full array first, then pick out the entries
                # selected by the local mask
                local_values = initial_values.evaluate()[self._mask_local]
            else:
                local_values = initial_values[self._mask_local]
            for cell, value in zip(self, local_values):
                if mech_name:
                    cell._cell.initial_values[mech_name][state_name] = value
                else:
                    cell._cell.initial_values[state_name] = value


class Assembly(common.Assembly):
    # Reuse the documentation of the simulator-independent base class.
    __doc__ = common.Assembly.__doc__
    # Simulator module this class is bound to.
    _simulator = simulator
class PopulationView(common.PopulationView, PopulationMixin):
    # Reuse the documentation of the simulator-independent base class.
    __doc__ = common.PopulationView.__doc__
    # Simulator module this class is bound to.
    _simulator = simulator
    _assembly_class = Assembly

    def _get_view(self, selector, label=None):
        """Return a PopulationView onto this view for the given `selector`."""
        return PopulationView(self, selector, label)
class Population(common.Population, PopulationMixin):
    # Reuse the documentation of the simulator-independent base class.
    __doc__ = common.Population.__doc__
    # Simulator module this class is bound to.
    _simulator = simulator
    _recorder_class = Recorder
    _assembly_class = Assembly

    def __init__(self, size, cellclass, cellparams=None, structure=None,
                 initial_values=None, label=None):
        # Create a fresh dict per call so that no mutable default object is
        # shared between Population instances; passing `None` is equivalent
        # to passing an empty dict.
        if initial_values is None:
            initial_values = {}
        common.Population.__init__(self, size, cellclass, cellparams,
                                   structure, initial_values, label)
        simulator.initializer.register(self)

    def _get_view(self, selector, label=None):
        """Return a PopulationView onto this population for `selector`."""
        return PopulationView(self, selector, label)

    def _create_cells(self):
        """
        Create cells in NEURON using the celltype of the current Population.
        """
        # this method should never be called more than once
        # perhaps should check for that
        self.first_id = simulator.state.gid_counter
        self.last_id = simulator.state.gid_counter + self.size - 1
        self.all_cells = np.array(list(range(self.first_id, self.last_id + 1)),
                                  simulator.ID)

        # mask_local is used to extract those elements from arrays
        # that apply to the cells on the current node, assuming
        # round-robin distribution of cells between nodes
        self._mask_local = self.all_cells % simulator.state.num_processes == simulator.state.mpi_rank  # noqa: E501

        if isinstance(self.celltype, StandardCellType):
            parameter_space = self.celltype.native_parameters
        else:
            parameter_space = self.celltype.parameter_space
        parameter_space.shape = (self.size,)
        parameter_space.evaluate(mask=None, simplify=True)

        if hasattr(self.celltype, "post_synaptic_receptors"):
            psrs = {name: psr.model
                    for name, psr in self.celltype.post_synaptic_receptors.items()}
        else:
            psrs = None

        for i, (cell_id, is_local, params) in enumerate(
            zip(self.all_cells, self._mask_local, parameter_space)
        ):
            # every entry becomes an ID object, but only the entries flagged
            # in the local mask get a cell built for them
            self.all_cells[i] = simulator.ID(cell_id)
            self.all_cells[i].parent = self
            if is_local:
                if hasattr(self.celltype, "extra_parameters"):
                    params.update(self.celltype.extra_parameters)
                self.all_cells[i]._build_cell(self.celltype.model, params, psrs)
        simulator.initializer.register(*self.all_cells[self._mask_local])
        simulator.state.gid_counter += self.size

    def _native_rset(self, parametername, rand_distr):
        """
        'Random' set. Set the value of parametername to a value taken from
        rand_distr, which should be a RandomDistribution object.
        """
        # kept as an assert so that the exception type seen by existing
        # callers does not change
        assert isinstance(rand_distr.rng, NativeRNG)
        rng = simulator.h.Random(rand_distr.rng.seed or 0)
        native_rand_distr = getattr(rng, rand_distr.name)
        # the first call selects the distribution and draws one value;
        # repick() then draws the remaining values, one per cell in all_cells
        rarr = ([native_rand_distr(*rand_distr.parameters)]
                + [rng.repick() for _ in range(self.all_cells.size - 1)])
        self.tset(parametername, rarr)