Changeset 377

Show
Ignore:
Timestamp:
06/19/08 18:01:38 (5 months ago)
Author:
apdavison
Message:

Major reorganisation of neuron2 to better encapsulate global simulator state information and reduce problems with multiple imports.

Files:

Legend:

Unmodified
Added
Removed
Modified
Copied
Moved
  • trunk/src/neuron2/__init__.py

    r376 r377  
    88 
    99from pyNN.random import * 
    10 from pyNN.neuron2.utility import * 
     10from pyNN.neuron2 import simulator 
    1111from pyNN import common, utility 
    1212from pyNN.neuron2.cells import * 
     
    2121 
    2222# Global variables 
    23 gid_counter = 0 
    24 initialised = False 
    25 running = False 
    2623quit_on_end = True 
    2724recorder_list = [] 
     
    3734    simulator but not by others. 
    3835    """ 
    39     global initialised, quit_on_end, running, parallel_context, initializer 
    40     if not initialised: 
    41         h('min_delay = 0') 
    42         h('tstop = 0') 
    43         parallel_context = neuron.ParallelContext() 
    44         parallel_context.spike_compress(1,0) 
    45         cvode = neuron.CVode() 
    46         initializer = Initializer() 
     36    global quit_on_end 
     37    if not simulator.state.initialized: 
    4738        utility.init_logging("neuron2.log.%d" % rank(), debug) 
    4839        logging.info("Initialization of NEURON (use setup(.., debug=True) to see a full logfile)") 
    49     h.dt = timestep 
    50     h.tstop = 0 
    51     h.min_delay = min_delay 
    52     running = False 
     40        simulator.state.initialized = True 
     41    simulator.state.dt = timestep 
     42    simulator.state.min_delay = min_delay 
     43    simulator.reset() 
    5344    if 'quit_on_end' in extra_params: 
    5445        quit_on_end = extra_params['quit_on_end'] 
    5546    if extra_params.has_key('use_cvode'): 
    56         cvode.active(int(extra_params['use_cvode'])) 
     47        simulator.state.cvode.active(int(extra_params['use_cvode'])) 
    5748    return rank() 
    5849 
     
    6152    for recorder in recorder_list: 
    6253        recorder.write(gather=False, compatible_output=compatible_output) 
    63     parallel_context.runworker() 
    64     parallel_context.done() 
    65     if quit_on_end: 
    66         logging.info("Finishing up with NEURON.") 
    67         h.quit() 
     54    simulator.finalize(quit_on_end) 
    6855         
    6956def run(simtime): 
    7057    """Run the simulation for simtime ms.""" 
    71     global running 
    72     if not running: 
    73         running = True 
    74         local_minimum_delay = parallel_context.set_maxstep(10) 
    75         h.finitialize() 
    76         h.tstop = 0 
    77         logging.debug("local_minimum_delay on host #%d = %g" % (rank(), local_minimum_delay)) 
    78         if num_processes() > 1: 
    79             assert local_minimum_delay >= get_min_delay(),\ 
    80                    "There are connections with delays (%g) shorter than the minimum delay (%g)" % (local_minimum_delay, get_min_delay()) 
    81     h.tstop = simtime 
    82     logging.info("Running the simulation for %d ms" % simtime) 
    83     parallel_context.psolve(h.tstop) 
    84     return get_current_time() 
     58    simulator.run(simtime) 
    8559 
    8660# ============================================================================== 
     
    9064def get_current_time(): 
    9165    """Return the current time in the simulation.""" 
    92     return h.t 
     66    return simulator.state.t 
    9367common.get_current_time = get_current_time 
    9468 
    9569def get_time_step(): 
    96     return h.dt 
     70    return simulator.state.dt 
    9771common.get_time_step = get_time_step 
    9872 
    9973def get_min_delay(): 
    100     return h.min_delay 
     74    return simulator.state.min_delay 
    10175common.get_min_delay = get_min_delay 
    10276 
    10377def num_processes(): 
    104     return int(parallel_context.nhost()) 
     78    return simulator.state.num_processes 
    10579 
    10680def rank(): 
    10781    """Return the MPI rank.""" 
    108     return int(parallel_context.id()) 
     82    return simulator.state.mpi_rank 
    10983 
    11084def list_standard_models(): 
     
    12094        gid = int(self) 
    12195        self._cell = cell_model(**cell_parameters)          # create the cell object 
    122         parallel_context.set_gid2node(gid, rank())          # assign the gid to this node 
    123         nc = neuron.NetCon(self._cell.source, None)         # } associate the cell spike source 
    124         parallel_context.cell(gid, nc.hoc_obj)              # } with the gid (using a temporary NetCon) 
     96        simulator.register_gid(gid, self._cell.source) 
    12597        self.parent = parent 
    12698     
     
    144116    Function used by both `create()` and `Population.__init__()` 
    145117    """ 
    146     global gid_counter 
    147118    assert n > 0, 'n must be a positive integer' 
    148119    if isinstance(cellclass, basestring): # cell defined in hoc template 
    149120        try: 
    150             cell_model = getattr(h, cellclass) 
     121            cell_model = getattr(simulator.h, cellclass) 
    151122        except AttributeError: 
    152123            raise common.InvalidModelError("There is no hoc template called %s" % cellclass) 
     
    159130        cell_model = cellclass 
    160131        cell_parameters = param_dict 
    161     first_id = gid_counter 
    162     last_id = gid_counter + n 
     132    first_id = simulator.state.gid_counter 
     133    last_id = simulator.state.gid_counter + n 
    163134    all_ids = numpy.array([id for id in range(first_id, last_id)], ID) 
    164135    # mask_local is used to extract those elements from arrays that apply to the cells on the current node 
     
    168139            all_ids[i] = ID(id) 
    169140            all_ids[i]._build_cell(cell_model, cell_parameters, parent=parent) 
    170     gid_counter += n 
     141    simulator.state.gid_counter += n 
    171142    return all_ids, mask_local, first_id, last_id 
    172143 
     
    180151    for id in all_ids[mask_local]: 
    181152        id.cellclass = cellclass 
    182     initializer.register(*all_ids[mask_local]) 
     153    simulator.initializer.register(*all_ids[mask_local]) 
    183154    all_ids = all_ids.tolist() # not sure this is desirable, but it is consistent with the other modules 
    184155    if len(all_ids) == 1: 
    185156        all_ids = all_ids[0] 
    186157    return all_ids 
    187  
    188 def _single_connect(source, target, weight, delay, synapse_type): 
    189     """ 
    190     Private function to connect two neurons. 
    191     Used by `connect()` and the `Connector` classes. 
    192     """ 
    193     global gid_counter 
    194     if not isinstance(source, int) or source > gid_counter or source < 0: 
    195         errmsg = "Invalid source ID: %s (gid_counter=%d)" % (source, gid_counter) 
    196         raise common.ConnectionError(errmsg) 
    197     if not isinstance(target, ID): 
    198         raise common.ConnectionError("Invalid target ID: %s" % target) 
    199     if synapse_type is None: 
    200         synapse_type = weight>=0 and 'excitatory' or 'inhibitory' 
    201     if weight is None: 
    202         weight = common.DEFAULT_WEIGHT 
    203     if "cond" in target.cellclass.__name__: 
    204         weight = abs(weight) # weights must be positive for conductance-based synapses 
    205     elif synapse_type == 'inhibitory' and weight > 0: 
    206         weight *= -1         # and negative for inhibitory, current-based synapses 
    207     if delay is None: 
    208         delay = get_min_delay() 
    209     elif delay < get_min_delay(): 
    210         raise common.ConnectionError("delay (%s) is too small (< %s)" % (delay, get_min_delay())) 
    211     synapse_object = getattr(target._cell, synapse_type).hoc_obj 
    212     nc = parallel_context.gid_connect(int(source), synapse_object) 
    213     nc.weight[0] = weight 
    214     nc.delay  = delay 
    215     return nc 
    216158 
    217159def connect(source, target, weight=None, delay=None, synapse_type=None, p=1, rng=None): 
     
    235177            sources = sources[rarr<p] 
    236178        for src in sources: 
    237             nc = _single_connect(src, tgt, weight, delay, synapse_type) 
     179            nc = simulator.single_connect(src, tgt, weight, delay, synapse_type) 
    238180            connection_list.append(nc) 
    239181    return connection_list 
     
    260202    if not hasattr(source, '__len__'): 
    261203        source = [source] 
    262     recorder = Recorder('spikes', file=filename) 
     204    recorder = simulator.Recorder('spikes', file=filename) 
    263205    recorder.record(source) 
    264206    recorder_list.append(recorder) 
     
    272214    if not hasattr(source, '__len__'): 
    273215        source = [source] 
    274     recorder = Recorder('v', file=filename) 
     216    recorder = simulator.Recorder('v', file=filename) 
    275217    recorder.record(source) 
    276218    recorder_list.append(recorder) 
     
    305247        """ 
    306248        common.Population.__init__(self, dims, cellclass, cellparams, label) 
    307         self.recorders = {'spikes': Recorder('spikes', population=self), 
    308                           'v': Recorder('v', population=self)} 
     249        self.recorders = {'spikes': simulator.Recorder('spikes', population=self), 
     250                          'v': simulator.Recorder('v', population=self)} 
    309251        self.label = self.label or 'population%d' % Population.nPop 
    310252        if isinstance(cellclass, type) and issubclass(cellclass, common.StandardCellType): 
     
    322264        self._mask_local = self._mask_local.reshape(self.dim) 
    323265         
    324         initializer.register(self) 
     266        simulator.initializer.register(self) 
    325267        Population.nPop += 1 
    326268        logging.info(self.describe('Creating Population "%(label)s" of shape %(dim)s, '+ 
     
    456398        # provided that the same rng with the same seed is used on each node. 
    457399        if isinstance(rand_distr.rng, NativeRNG): 
    458             rng = h.Random(rand_distr.rng.seed or 0) 
     400            rng = simulator.h.Random(rand_distr.rng.seed or 0) 
    459401            native_rand_distr = getattr(rng, rand_distr.name) 
    460402            rarr = [native_rand_distr(*rand_distr.parameters)] + [rng.repick() for i in range(self._all_ids.size-1)] 
     
    477419        Private method called by record() and record_v(). 
    478420        """ 
    479         global myid 
    480421        fixed_list=False 
    481422        if isinstance(record_from, list): #record from the fixed list specified by user 
     
    666607            connection_method(method_parameters) 
    667608        elif isinstance(method, common.Connector): 
    668             print "gid_counter = ", gid_counter 
     609            print "simulator.gid_counter = ", simulator.State.gid_counter 
    669610            method.connect(self) 
    670611             
  • trunk/src/neuron2/cells.py

    r376 r377  
    164164            for name in 'start', 'interval', 'number': 
    165165                setattr(self.source, name, locals()[name]) 
    166             self.spiketimes = neuron.Vector() 
     166            self.source.noise = 1 
     167            self.spiketimes = neuron.Vector(0) 
    167168            self.do_not_record = False 
    168169 
     
    174175        if not self.do_not_record: # for VecStims, etc, recording doesn't make sense as we already have the spike times 
    175176            if active: 
    176                 rec = neuron.NetCon(self.source, None) 
    177                 rec.record(self.spiketimes.hoc_obj) 
     177                self.spiketimes.hoc_obj.printf() 
     178                self.rec = neuron.NetCon(self.source, None) 
     179                self.rec.record(self.spiketimes.hoc_obj) 
    178180             
    179181 
  • trunk/src/neuron2/connectors.py

    r376 r377  
    66from pyNN import common 
    77from pyNN.random import RandomDistribution, NativeRNG 
    8 from pyNN.neuron2.__init__ import get_min_delay, _single_connect 
     8from pyNN.neuron2 import simulator 
    99import numpy 
    1010from math import * 
    11  
    12 common.get_min_delay = get_min_delay 
    1311 
    1412# ============================================================================== 
     
    4240                delays = d.__iter__() 
    4341            else: 
    44                 delays = ConstIter(max((d, get_min_delay()))) 
     42                delays = ConstIter(max((d, simulator.state.min_delay))) 
    4543        else: 
    46             delays = ConstIter(get_min_delay()
     44            delays = ConstIter(simulator.state.min_delay
    4745        return delays 
    4846 
     
    8280                if create[j]: 
    8381                    projection.connections.append( 
    84                             _single_connect(src, tgt, 
    85                                             weights.next(), delays.next(), 
    86                                             projection.synapse_type)) 
     82                            simulator.single_connect(src, tgt, 
     83                                                     weights.next(), delays.next(), 
     84                                                     projection.synapse_type)) 
    8785 
    8886 
     
    102100                src = tgt - projection.post.first_id + projection.pre.first_id 
    103101                projection.connections.append( 
    104                     _single_connect(src, tgt, weights.next(), delays.next(), projection.synapse_type)) 
     102                    simulator.single_connect(src, tgt, weights.next(), delays.next(), projection.synapse_type)) 
    105103        else: 
    106104            raise Exception("OneToOneConnector does not support presynaptic and postsynaptic Populations of different sizes.") 
  • trunk/src/neuron2/simulator.py

    r376 r377  
    6969            for id in self.recorded: 
    7070                spikes = id._cell.spiketimes.toarray() 
    71                 print "t = ", common.get_current_time() 
    72                 spikes = spikes[spikes<=common.get_current_time()+1e-9] 
     71                spikes = spikes[spikes<=state.t+1e-9] 
    7372                if len(spikes) > 0: 
    7473                    new_data = numpy.array([spikes, numpy.ones(spikes.shape)*id]).T 
     
    9190                                              self.population, common.get_time_step()) 
    9291         
    93 class Initializer(object): 
     92class _Initializer(object): 
    9493     
    9594    def __init__(self): 
    9695        self.cell_list = [] 
    9796        self.population_list = [] 
    98         neuron.h('objref initializer') 
     97        h('objref initializer') 
    9998        neuron.h.initializer = self 
    10099        self.fih = h.FInitializeHandler("initializer.initialize()") 
     100     
     101    def __call__(self): 
     102        """This is to make the Initializer a Singleton.""" 
     103        return self 
    101104     
    102105    def register(self, *items): 
     
    118121                cell._cell.memb_init() 
    119122 
    120  
    121 load_mechanisms() 
     123def h_property(name): 
     124    def _get(self): 
     125        return getattr(h,name) 
     126    def _set(self, val): 
     127        setattr(h, name, val) 
     128    return property(fget=_get, fset=_set) 
     129 
     130class _State(object): 
     131    """Represent the simulator state.""" 
     132     
     133    def __init__(self): 
     134        self.gid_counter = 0 
     135        self.running = False 
     136        self.initialized = False 
     137        h('min_delay = 0') 
     138        h('tstop = 0') 
     139        self.parallel_context = neuron.ParallelContext() 
     140        self.parallel_context.spike_compress(1,0) 
     141        self.num_processes = int(self.parallel_context.nhost()) 
     142        self.mpi_rank = int(self.parallel_context.id()) 
     143        self.cvode = neuron.CVode() 
     144     
     145    t = h_property('t') 
     146    dt = h_property('dt') 
     147    tstop = h_property('tstop')         # } do these really need to be stored in hoc? 
     148    min_delay = h_property('min_delay') # } 
     149     
     150     
     151    def __call__(self): 
     152        """This is to make the State a Singleton.""" 
     153        return self 
     154     
     155def reset(): 
     156    state.running = False 
     157    state.t = 0 
     158    state.tstop = 0 
     159 
     160def run(simtime): 
     161    if not state.running: 
     162        state.running = True 
     163        local_minimum_delay = state.parallel_context.set_maxstep(10) 
     164        h.finitialize() 
     165        state.tstop = 0 
     166        logging.debug("local_minimum_delay on host #%d = %g" % (state.mpi_rank, local_minimum_delay)) 
     167        if state.num_processes > 1: 
     168            assert local_minimum_delay >= state.min_delay,\ 
     169                   "There are connections with delays (%g) shorter than the minimum delay (%g)" % (local_minimum_delay, state.min_delay) 
     170    state.tstop = simtime 
     171    logging.info("Running the simulation for %d ms" % simtime) 
     172    state.parallel_context.psolve(state.tstop) 
     173    return state.t 
     174 
     175 
     176def finalize(quit=True): 
     177    state.parallel_context.runworker() 
     178    state.parallel_context.done() 
     179    if quit: 
     180        logging.info("Finishing up with NEURON.") 
     181        h.quit() 
     182 
     183def register_gid(gid, source): 
     184    state.parallel_context.set_gid2node(gid, state.mpi_rank)  # assign the gid to this node 
     185    nc = neuron.NetCon(source, None)                          # } associate the cell spike source 
     186    state.parallel_context.cell(gid, nc.hoc_obj)              # } with the gid (using a temporary NetCon) 
     187 
     188def single_connect(source, target, weight, delay, synapse_type): 
     189    """ 
     190    Private function to connect two neurons. 
     191    Used by `connect()` and the `Connector` classes. 
     192    """ 
     193    if not isinstance(source, int) or source > state.gid_counter or source < 0: 
     194        errmsg = "Invalid source ID: %s (gid_counter=%d)" % (source, state.gid_counter) 
     195        raise common.ConnectionError(errmsg) 
     196    if not isinstance(target, common.IDMixin): 
     197        raise common.ConnectionError("Invalid target ID: %s" % target) 
     198    if synapse_type is None: 
     199        synapse_type = weight>=0 and 'excitatory' or 'inhibitory' 
     200    if weight is None: 
     201        weight = common.DEFAULT_WEIGHT 
     202    if "cond" in target.cellclass.__name__: 
     203        weight = abs(weight) # weights must be positive for conductance-based synapses 
     204    elif synapse_type == 'inhibitory' and weight > 0: 
     205        weight *= -1         # and negative for inhibitory, current-based synapses 
     206    if delay is None: 
     207        delay = state.min_delay 
     208    elif delay < state.min_delay: 
     209        raise common.ConnectionError("delay (%s) is too small (< %s)" % (delay, state.min_delay)) 
     210    synapse_object = getattr(target._cell, synapse_type).hoc_obj 
     211    nc = state.parallel_context.gid_connect(int(source), synapse_object) 
     212    nc.weight[0] = weight 
     213    nc.delay  = delay 
     214    return nc 
     215 
     216# The following are executed every time the module is imported. 
     217load_mechanisms() # maintains a list of mechanisms that have already been imported 
     218state = _State()  # a Singleton, so only a single instance ever exists 
     219initializer = _Initializer() 
  • trunk/test/neuron2tests.py

    r376 r377  
    2323     
    2424    def tearDown(self): 
    25         neuron.gid_counter = 0 
     25        neuron.simulator.state.gid_counter = 0 
    2626     
    2727    def testCreateStandardCell(self): 
     
    2929        logging.info('=== CreationTest.testCreateStandardCell() ===') 
    3030        ifcell = neuron.create(neuron.IF_curr_alpha) 
    31         assert ifcell == 0, 'Failed to create standard cell' 
     31        assert ifcell == 0, 'Failed to create standard cell (cell=%s)' % ifcell 
    3232         
    3333    def testCreateStandardCells(self): 
     
    437437        self.pop2 = neuron.Population((3,3), neuron.IF_curr_alpha) 
    438438 
     439    def tearDown(self): 
     440        neuron.simulator.reset() 
     441 
    439442    def testRecordAll(self): 
    440443        """Population.record(): not a full test, just checking there are no Exceptions raised.""" 
     
    463466        self.pop1.record() 
    464467        simtime = 1000.0 
    465         neuron.running = False 
    466468        neuron.run(simtime) 
    467         self.pop1.printSpikes("temp_neuron.ras", gather=True) 
     469        #self.pop1.printSpikes("temp_neuron.ras", gather=True) 
    468470        rate = self.pop1.meanSpikeCount()*1000/simtime 
    469471        if neuron.rank() == 0: # only on master node 
     
    493495        spike_source = neuron.Population(1, neuron.SpikeSourceArray, {'spike_times': spike_times}) 
    494496        spike_source.record() 
    495         neuron.running = False 
    496497        neuron.run(100.0) 
    497         spikes = spike_source.getSpikes()[:,0] 
     498        spikes = spike_source.getSpikes() 
     499        print spikes 
     500        spikes = spikes[:,0] 
    498501        if neuron.rank() == 0: 
    499502            self.assert_( max(spikes) == 100.0, str(spikes) )