Changeset 364

Show
Ignore:
Timestamp:
06/13/08 17:07:23 (5 months ago)
Author:
apdavison
Message:

Implemented a Recorder class for the neuron module, although gathering is temporarily broken

Files:

Legend:

Unmodified
Added
Removed
Modified
Copied
Moved
  • trunk/src/nest2/__init__.py

    r363 r364  
    717717            #rarr = rand_distr.next(n=len(self.cell_local)) 
    718718            rarr = rand_distr.next(n=self.size) 
    719             print rank(), self.cell_local[:5], self.cell_local[-5:], len(rarr), len(self.cell_local) 
    720719            assert len(rarr) >= len(self.cell_local), "The length of rarr (%d) must be greater than that of cell_local (%d)" % (len(rarr), len(self.cell_local)) 
    721720            rarr = rarr[:len(self.cell_local)] 
  • trunk/src/neuron/__init__.py

    r361 r364  
    2828ncid          = 0 
    2929gidlist       = [] 
    30 vfilelist     = {} 
    31 spikefilelist = {} 
     30recorder_list = [] 
    3231running       = False 
    3332initialised   = False 
    3433nrn_dll_loaded = [] 
    3534quit_on_end   = True 
     35RECORDING_VECTOR_NAMES = {'spikes': 'spiketimes', 
     36                          'v': 'vtrace'} 
    3637 
    3738# ============================================================================== 
     
    289290            raise HocError("caused by HocToPy.bool('%s')" % condition) 
    290291        return HocToPy.hocvar 
     292 
     293 
     294class Recorder(object): 
     295    """Encapsulates data and functions related to recording model variables.""" 
     296     
     297    def __init__(self, variable, population=None, file=None): 
     298        """ 
     299        `file` should be one of: 
     300            a file-name, 
     301            `None` (write to a temporary file) 
     302            `False` (write to memory). 
     303        """ 
     304        assert variable in RECORDING_VECTOR_NAMES 
     305        self.variable = variable 
     306        self.filename = file or None 
     307        self.population = population # needed for writing header information 
     308        self.recorded = Set([])         
     309 
     310    def record(self, ids): 
     311        """Add the cells in `ids` to the set of recorded cells.""" 
     312        ids = Set([id for id in ids if id in self.population.gidlist]) 
     313        new_ids = list( ids.difference(self.recorded) ) 
     314        self.recorded = self.recorded.union(ids) 
     315        if self.population is None: 
     316            cell_template = "cell%d" 
     317            id_list = new_ids 
     318        else: 
     319            cell_template = "%s.object(%s)" % (self.population.hoc_label, "%d") 
     320            id_list = self.population.gidlist 
     321        if self.variable == 'spikes': 
     322            template = 'tmp = %s.record(1)' % cell_template 
     323        elif self.variable == 'v':  
     324            template = 'tmp = %s.record_v(1,%g)' % (cell_template, get_time_step()) 
     325        hoc_commands = [] 
     326        for src in new_ids: 
     327            hoc_commands += [template % id_list.index(src)] 
     328        hoc_execute(hoc_commands, "---Recorder.record() ---") 
     329         
     330    def get(self, gather=False): 
     331        """Returns the recorded data.""" 
     332        data = recording.readArray(filename, sepchar=None) 
     333        data = recording.convert_compatible_output(data, self.population, variable) 
     334        return data 
     335     
     336    def write(self, file=None, gather=False, compatible_output=True): 
     337        hoc_execute(['objref gathered_vec_list', 
     338                     'gathered_vec_list =  new List()']) 
     339        vector_operation = '' 
     340        if self.variable == 'spikes': 
     341            vector_operation = '.where("<=", tstop)' 
     342        header = "# dt = %g\\n# n = %d\\n" % (get_time_step(), int(h.tstop/get_time_step())) 
     343        if self.population is None: 
     344            cell_template = "cell%d" 
     345            post_label = "node%d: post cellX.%s" % (myid, self.variable) 
     346            id_list = gidlist 
     347            padding = 0 
     348        else: 
     349            cell_template = "%s.object(%s)" % (self.population.hoc_label, "%d") 
     350            post_label = 'node%d: post_%s.%s' % (myid, self.population.hoc_label, self.variable) 
     351            id_list = self.population.gidlist 
     352            padding = self.population.gid_start 
     353             
     354        def post_data(): 
     355            pack_template = 'tmp = pc.pack(%s.%s%s)' % (cell_template, 
     356                                                        RECORDING_VECTOR_NAME[self.variable], 
     357                                                        vector_operation) 
     358            for cell in self.recorded: 
     359                hoc_commands += ['tmp = pc.pack(%d)' % id_list.index(cell), 
     360                                 pack_template % id_list.index(cell)] 
     361            hoc_commands += ['tmp = pc.post("%s")' % post_label] 
     362            hoc_execute(hoc_commands,"--- Population[%s].__print()__ --- [Post objects to master]" %self.label) 
     363        def take_data(): 
     364            hoc_commands = ['tmp = pc.take(post_label)'] 
     365            for node in range(1, num_processes()): 
     366                hoc_commands += ['gathered_vec_list.append(pc.upkscalar())', 
     367                                 'gathered_vec_list.append(pc.upkvec())'] 
     368        def write_data(): 
     369            if self.population is None: 
     370                header = "# first_id = %d\\n# last_id = %d\\n" % (min(self.recorded), max(self.recorded)) 
     371            else: 
     372                header = "# %d" % self.population.dim[0] 
     373                for dimension in list(self.population.dim)[1:]: 
     374                    header = "%s\t%d" % (header, dimension) 
     375                header += "\\n# first_id = %d\\n# last_id = %d\\n" % (self.population.gid_start, self.population.gid_start+self.population.size-1) 
     376             
     377            if self.variable == 'v': 
     378                header += "# dt = %g\\n# n = %d\\n" % (get_time_step(), int(h.tstop/get_time_step())) 
     379                num_format = "%.6g" 
     380            elif self.variable == 'spikes': 
     381                header += "# dt = %g\\n"% get_time_step() 
     382                num_format = "%.2f" 
     383            filename = file or self.filename 
     384            hoc_commands = ['objref fileobj', 
     385                            'fileobj = new File()', 
     386                            'tmp = fileobj.wopen("%s")' % filename, 
     387                            'tmp = fileobj.printf("%s")' % header, 
     388                            'i = 0'] 
     389            write_template = 'tmp = %s.%s%s.printf(fileobj, fmt)' % (cell_template, 
     390                                                                     RECORDING_VECTOR_NAMES[self.variable], 
     391                                                                     vector_operation) 
     392            for cell in self.recorded: 
     393                hoc_commands += ['fmt = "%s\\t%d\\n"' % (num_format, cell-padding), 
     394                                 write_template % id_list.index(cell)] 
     395            # writing gathered data is currently broken 
     396            #hoc_commands += ['while i < gathered_vec_list.count()-2 { gathered_vec_list.o(i+1).printf(fileobj, "%s broken") ' % num_format] 
     397            hoc_commands += ['tmp = fileobj.close()'] 
     398            hoc_execute(hoc_commands, "Recorder.write()") 
     399             
     400        if gather: 
     401            if myid != 0: # on slave nodes, post data 
     402                post_data() 
     403            else: 
     404                take_data() 
     405                write_data() 
     406        else: 
     407            filename += ".%d" % myid     
     408            write_data() 
    291409                 
    292410# ============================================================================== 
     
    371489    """Do any necessary cleaning up before exiting.""" 
    372490    global logfile, myid #, vfilelist, spikefilelist 
     491     
     492    for recorder in recorder_list: 
     493        recorder.write(gather=False, compatible_output=compatible_output) 
    373494    hoc_commands = [] 
    374     if len(vfilelist) > 0: 
    375         hoc_commands = ['objref fileobj', 
    376                         'fileobj = new File()'] 
    377         while len(vfilelist): 
    378             filename, cell_list = vfilelist.popitem() 
    379             #tstop = HocToPy.get('tstop','float') 
    380             tstop = h.tstop 
    381             header = "# dt = %g\\n# n = %d\\n" % (get_time_step(), int(tstop/get_time_step())) 
    382             header += "# first_id = %d\\n# last_id = %d\\n" % (cell_list[0], cell_list[-1]) 
    383             hoc_commands += ['tmp = fileobj.wopen("%s")' % filename, 
    384                              'tmp = fileobj.printf("%s")' % header] 
    385             for cell in cell_list: 
    386                 hoc_commands += ['fmt = "%s\\t%d\\n"' % ("%.6g", cell), 
    387                                  'tmp = cell%d.vtrace.printf(fileobj, fmt)' % cell] 
    388             hoc_commands += ['tmp = fileobj.close()'] 
    389     if len(spikefilelist) > 0: 
    390         hoc_commands += ['objref fileobj', 
    391                         'fileobj = new File()'] 
    392         header = "# dt = %g\\n"% get_time_step() 
    393         header += "# first_id = %d\\n #last_id = %d\\n" % (cell_list[0], cell_list[-1]) 
    394         while len(spikefilelist): 
    395             filename, cell_list = spikefilelist.popitem() 
    396             hoc_commands += ['tmp = fileobj.wopen("%s")' % filename, 
    397                              'tmp = fileobj.printf("%s")' % header] 
    398             for cell in cell_list: 
    399                 hoc_commands += ['fmt = "%s\\t%d\\n"' % ("%.2f", cell), 
    400                                  #'tmp = fileobj.printf("# cell%d\\n")' % cell, 
    401                                  'tmp = cell%d.spiketimes.where("<=", tstop).printf(fileobj, fmt)' % cell] 
    402             hoc_commands += ['tmp = fileobj.close()'] 
    403495    hoc_commands += ['tmp = pc.runworker()', 
    404496                     'tmp = pc.done()'] 
     
    558650    # would actually like to be able to record to an array and choose later 
    559651    # whether to write to a file. 
    560     global spikefilelist, gidlist 
    561     if type(source) != types.ListType: 
     652    if not hasattr(source, '__len__'): 
    562653        source = [source] 
    563     hoc_commands = [] 
    564     if not spikefilelist.has_key(filename): 
    565         spikefilelist[filename] = [] 
    566     for src in source: 
    567         if src in gidlist: 
    568             hoc_commands += ['tmp = cell%d.record(1)' % src] 
    569             spikefilelist[filename] += [src] # writing to file is done in end() 
    570     hoc_execute(hoc_commands, "---record() ---") 
     654    recorder = Recorder('spikes', file=filename) 
     655    recorder.record(source) 
     656    recorder_list.append(recorder) 
    571657 
    572658def record_v(source, filename): 
     
    576662    # would actually like to be able to record to an array and 
    577663    # choose later whether to write to a file. 
    578     global vfilelist, gidlist 
    579     if type(source) != types.ListType: 
     664    if not hasattr(source, '__len__'): 
    580665        source = [source] 
    581     hoc_commands = [] 
    582     if not vfilelist.has_key(filename): 
    583         vfilelist[filename] = [] 
    584     for src in source: 
    585         if src in gidlist: 
    586             if src.parent: 
    587                 raise Exception("The record_v() function does not work with cells in a Population. Please use the record_v() method of the Population object.") 
    588             else: 
    589                 hoc_commands += ['tmp = cell%d.record_v(1,%g)' % (src, get_time_step())] 
    590             vfilelist[filename] += [src] # writing to file is done in end() 
    591     hoc_execute(hoc_commands, "---record_v() ---") 
     666    recorder = Recorder('spikes', file=filename) 
     667    recorder.record(source) 
     668    recorder_list.append(recorder) 
     669 
    592670 
    593671# ============================================================================== 
     
    647725        self.hoc_label = self.label.replace(" ","_") 
    648726         
    649         self.record_from = { 'spiketimes': Set(), 'vtrace': Set() } 
    650          
     727        self.recorders = {} 
     728        for variable in RECORDING_VECTOR_NAMES: 
     729            self.recorders[variable] = Recorder(variable, population=self)         
    651730         
    652731        # Now the gid and cellclass are stored as instance of the ID class, which will allow a syntax like 
     
    884963        """ 
    885964        global myid 
    886         hoc_commands = [] 
    887965        fixed_list=False 
    888  
    889966        if isinstance(record_from, list): #record from the fixed list specified by user 
    890967            fixed_list=True 
     
    904981            raise Exception("record_from must be either a list of cells or the number of cells to record from") 
    905982        # record_from is now a list or numpy array 
    906  
    907         suffix = ''*(record_what=='spiketimes') + '_v'*(record_what=='vtrace') 
    908         for id in record_from: 
    909             if id in self.gidlist: 
    910                 hoc_commands += ['tmp = %s.object(%d).record%s(1)' % (self.hoc_label, self.gidlist.index(id), suffix)] 
    911  
    912         # note that self.record_from is not the same on all nodes, like self.gidlist, for example. 
    913         self.record_from[record_what].update(Set(record_from)) 
    914         hoc_commands += ['objref record_from'] 
    915         hoc_execute(hoc_commands) 
    916  
    917         # Then we have to send the lists of local recorded objects to the master node, 
    918         # but only if the list has not been specified by the user. 
    919         if fixed_list is False: 
    920             if myid != 0:  # on slave nodes 
    921                 hoc_commands = ['record_from = new Vector()'] 
    922                 for id in self.record_from[record_what]: 
    923                     if id in self.gidlist: 
    924                         hoc_commands += ['record_from = record_from.append(%d)' %id] 
    925                 hoc_commands += ['tmp = pc.post("%s.record_from[%s].node[%d]", record_from)' %(self.hoc_label, record_what, myid)] 
    926                 hoc_execute(hoc_commands, "   (Posting recorded cells)") 
    927             else:          # on the master node 
    928                 for id in range (1, nhost): 
    929                     hoc_commands = ['record_from = new Vector()'] 
    930                     hoc_commands += ['tmp = pc.take("%s.record_from[%s].node[%d]", record_from)' %(self.hoc_label, record_what, id)] 
    931                     hoc_execute(hoc_commands) 
    932                     for j in xrange(int(h.record_from.size())): 
    933                         self.record_from[record_what].add(int(h.record_from.x[j])) 
     983        self.recorders[record_what].record(record_from) 
    934984 
    935985    def record(self, record_from=None, rng=None): 
     
    941991        """ 
    942992        hoc_comment("--- Population[%s].__record()__ ---" %self.label) 
    943         self.__record('spiketimes', record_from, rng) 
     993        self.__record('spikes', record_from, rng) 
    944994 
    945995    def record_v(self, record_from=None, rng=None): 
     
    9521002        """ 
    9531003        hoc_comment("--- Population[%s].__record_v()__ ---" %self.label) 
    954         self.__record('vtrace', record_from, rng) 
    955  
    956     def __print(self, print_what, filename, num_format, gather, header=None): 
    957         """Private method used by printSpikes() and print_v().""" 
    958         global myid 
    959         vector_operation = '' 
    960         if print_what == 'spiketimes': 
    961             vector_operation = '.where("<=", tstop)' 
    962         if gather and myid != 0: # on slave nodes, post data 
    963             hoc_commands = [] 
    964             for id in self.record_from[print_what]: 
    965                 if id in self.gidlist: 
    966                     hoc_commands += ['tmp = pc.post("%s[%d].%s",%s.object(%d).%s%s)' % (self.hoc_label, id, 
    967                                                                                         print_what, 
    968                                                                                         self.hoc_label, 
    969                                                                                         self.gidlist.index(id), 
    970                                                                                         print_what, 
    971                                                                                         vector_operation)] 
    972             hoc_execute(hoc_commands,"--- Population[%s].__print()__ --- [Post objects to master]" %self.label) 
    973  
    974         if not gather: 
    975             filename += ".%d" % myid 
    976              
    977         if myid==0 or not gather: 
    978             hoc_commands = ['objref fileobj', 
    979                             'fileobj = new File()', 
    980                             'tmp = fileobj.wopen("%s")' % filename] 
    981             if header: 
    982                 hoc_commands += ['tmp = fileobj.printf("%s\\n")' % header] 
    983             if gather: 
    984                 hoc_commands += ['objref gatheredvec'] 
    985             padding = self.fullgidlist[0] 
    986             for id in self.record_from[print_what]: 
    987                 addr = self.locate(id) 
    988                 #hoc_commands += ['fmt = "%s\\t%s\\n"' % (num_format, "\\t".join([str(j) for j in addr]))] 
    989                 hoc_commands += ['fmt = "%s\\t%d\\n"' % (num_format, id-padding)] 
    990                 if id in self.gidlist: 
    991                     hoc_commands += ['tmp = %s.object(%d).%s%s.printf(fileobj, fmt)' % (self.hoc_label, 
    992                                                                                        self.gidlist.index(id), 
    993                                                                                        print_what, 
    994                                                                                        vector_operation)] 
    995                 elif gather:  
    996                     hoc_commands += ['gatheredvec = new Vector()'] 
    997                     hoc_commands += ['tmp = pc.take("%s[%d].%s", gatheredvec)' % (self.hoc_label, id, print_what), 
    998                                      'tmp = gatheredvec.printf(fileobj, fmt)'] 
    999             hoc_commands += ['tmp = fileobj.close()'] 
    1000             hoc_execute(hoc_commands,"--- Population[%s].__print()__ ---" %self.label) 
     1004        self.__record('v', record_from, rng) 
    10011005 
    10021006    def printSpikes(self, filename, gather=True, compatible_output=True): 
     
    10221026        """         
    10231027        hoc_comment("--- Population[%s].__printSpikes()__ ---" %self.label) 
    1024         header = "# %d" %self.dim[0] 
    1025         for dimension in list(self.dim)[1:]: 
    1026             header = "%s\t%d" %(header, dimension) 
    1027         header += "\\n# first_id = %d\\n# last_id = %d\\n" % (self.fullgidlist[0], self.fullgidlist[-1]) 
    1028         self.__print('spiketimes', filename,"%.2f", gather, header) 
     1028        self.recorders['spikes'].write(filename, gather, compatible_output) 
    10291029 
    10301030    def print_v(self, filename, gather=True, compatible_output=True): 
     
    10471047        on that node. 
    10481048        """ 
    1049         #tstop = HocToPy.get('tstop','float') 
    1050         tstop = h.tstop 
    1051         header = "# dt = %f\\n# n = %d\\n" % (get_time_step(), int(tstop/get_time_step())) 
    1052         header = "%s# %d" %(header, self.dim[0]) 
    1053         for dimension in list(self.dim)[1:]: 
    1054                 header = "%s\t%d" %(header, dimension) 
    1055         header += "\\n# first_id = %d\\n# last_id = %d\\n" % (self.fullgidlist[0], self.fullgidlist[-1]) 
    10561049        hoc_comment("--- Population[%s].__print_v()__ ---" %self.label) 
    1057         self.__print('vtrace', filename,"%.4g", gather, header
     1050        self.recorders['v'].write(filename, gather, compatible_output
    10581051 
    10591052    def getSpikes(self, gather=True): 
     
    10661059        # This is a bit of a hack implemetation 
    10671060        tmpfile = "neuron_tmpfile" # should really use tempfile module 
    1068         self.__print('spiketimes', tmpfile, "%.2f", gather
     1061        self.recorders['spikes'].write(tmpfile, gather, compatible_output=False
    10691062        if not gather: 
    10701063            tmpfile += '%d' % myid 
    10711064        if myid==0 or not gather: 
    10721065            f = open(tmpfile, 'r') 
    1073             lines = [line for line in f.read().split('\n') if line] # remove blank lines 
     1066            lines = [line for line in f.read().split('\n') if line and line[0]!='#'] # remove blank and comment lines 
    10741067            line2spike = lambda s: (int(s[1]), float(s[0])) 
    10751068            spikes = numpy.array([line2spike(line.split()) for line in lines]) 
     
    10901083            hoc_commands = [] 
    10911084            nspikes = 0;ncells  = 0 
    1092             for id in self.record_from['spiketimes']
     1085            for id in self.recorders['spikes'].recorded
    10931086                if id in self.gidlist: 
    10941087                    #nspikes += HocToPy.get('%s.object(%d).spiketimes.size()' %(self.hoc_label, self.gidlist.index(id)),'int') 
     
    11031096            nspikes = 0.0; ncells = 0.0 
    11041097            hoc_execute(["nspikes = 0", "ncells = 0"]) 
    1105             for id in self.record_from['spiketimes']
     1098            for id in self.recorders['spikes'].recorded
    11061099                if id in self.gidlist: 
    11071100                    nspikes += getattr(h, self.hoc_label).object(self.gidlist.index(id)).spiketimes.size() 
  • trunk/test/neurontests.py

    r361 r364  
    445445     
    446446    def setUp(self): 
     447        neuron.Population.nPop = 0 
    447448        self.pop1 = neuron.Population((3,3), neuron.SpikeSourcePoisson,{'rate': 20}) 
    448449        self.pop2 = neuron.Population((3,3), neuron.IF_curr_alpha) 
     
    455456        """Population.record(n): not a full test, just checking there are no Exceptions raised.""" 
    456457        # Partial record         
    457        self.pop1.record(5) 
     458        self.pop1.record(5) 
    458459         
    459460    def testRecordWithRNG(self):