| 1 | # encoding: utf-8 |
|---|
| 2 | """ |
|---|
| 3 | nrnpython implementation of the PyNN API. |
|---|
| 4 | |
|---|
| 5 | This is an attempt at a parallel-enabled implementation. |
|---|
| 6 | $Id:__init__.py 188 2008-01-29 10:03:59Z apdavison $ |
|---|
| 7 | """ |
|---|
| 8 | __version__ = "$Rev$" |
|---|
| 9 | |
|---|
| 10 | from neuron import hoc, Vector |
|---|
| 11 | h = hoc.HocObject() |
|---|
| 12 | from pyNN import __path__ as pyNN_path |
|---|
| 13 | from pyNN.random import * |
|---|
| 14 | from math import * |
|---|
| 15 | from pyNN import common |
|---|
| 16 | from pyNN.neuron.cells import * |
|---|
| 17 | from pyNN.neuron.connectors import * |
|---|
| 18 | from pyNN.neuron.synapses import * |
|---|
| 19 | import os.path |
|---|
| 20 | import types |
|---|
| 21 | import sys |
|---|
| 22 | import numpy |
|---|
| 23 | import logging |
|---|
| 24 | import platform |
|---|
| 25 | Set = set |
|---|
| 26 | |
|---|
# Module-level simulator state shared by the functions below.

# Next cell id to allocate: create()/Population hand out ids gid..gid+n-1
# and then advance this counter.
gid = 0
# Running count of NetCons created by connect(); connect() returns the range
# of values it consumed.
ncid = 0
# Ids of the cells instantiated on this host (rank).
gidlist = []
# filename -> ids of cells whose membrane potential end() writes to that file.
vfilelist = {}
# filename -> ids of cells whose spike times end() writes to that file.
spikefilelist = {}
# Set by run() after the first call, once the network has been initialised.
running = False
# Set by setup() so that later calls skip the one-off hoc declarations.
initialised = False
# Paths whose compiled NEURON mechanisms load_mechanisms() has already loaded.
nrn_dll_loaded = []
# Whether end() calls hoc quit(); overridable with setup(quit_on_end=...).
quit_on_end = True
|---|
| 36 | |
|---|
| 37 | # ============================================================================== |
|---|
| 38 | # Utility classes and functions |
|---|
| 39 | # ============================================================================== |
|---|
| 40 | |
|---|
class ID(int, common.IDMixin):
    """
    Instead of storing ids as integers, we store them as ID objects,
    which allows a syntax like:
        p[3,4].tau_m = 20.0
    where p is a Population object. The question is, how big a memory/performance
    hit is it to replace integers with ID objects?
    """

    def __init__(self, n):
        # The integer value itself is set by int.__new__; int.__init__ is a
        # no-op, so only the mixin needs initialising.
        int.__init__(self)
        common.IDMixin.__init__(self)

    def __getattr__(self, name):
        """Return parameter `name` of the underlying cell.

        Only called when normal attribute lookup fails. Overrides the version
        from common because there is no way to get the list of parameters of
        a native model, so for those the value is read straight from hoc.
        Unknown names raise AttributeError (not KeyError), so hasattr() and
        getattr(obj, name, default) work as expected.
        """
        if name.startswith('__') and name.endswith('__'):
            # Special-method probes (copy, pickle, numpy, ...) are never cell
            # parameters; don't forward them to hoc.
            raise AttributeError(name)
        if self.is_standard_cell():
            try:
                return self.get_parameters()[name]
            except KeyError:
                raise AttributeError(name)
        cell = self._hoc_cell()
        return self._get_hoc_parameter(cell, name)

    def _hoc_cell(self):
        """Return the hoc object for this cell, which must exist on this host.

        Cells belonging to a Population are fetched from the population's hoc
        List; stand-alone cells from the hoc variable cell<id>.
        """
        assert self in gidlist, "Cell %d does not exist on this node" % self
        if self.parent:
            # Population creates its hoc List under hoc_label (the label with
            # spaces replaced by underscores), not under the raw label.
            list_name = getattr(self.parent, 'hoc_label', self.parent.label)
            hoc_cell_list = getattr(h, list_name)
            try:
                list_index = self.parent.gidlist.index(self)
                cell = hoc_cell_list.object(list_index)
            except RuntimeError:
                logging.error("Could not fetch hoc cell: id=%s, parent.gid_start=%s, "
                              "len(parent)=%s, hoc_cell_list.count()=%s, list_index=%s",
                              int(self), self.parent.gid_start, len(self.parent),
                              hoc_cell_list.count(), list_index)
                raise
        else:
            cell = getattr(h, "cell%d" % int(self))
        return cell

    def _get_hoc_parameter(self, cell, name):
        """Read `name` from the hoc cell, falling back to cell.source."""
        try:
            val = getattr(cell, name)
        except (AttributeError, LookupError, HocError):
            # HocError is only raised by hoc_execute(); a failed attribute
            # lookup on a hoc object raises a Python lookup error instead
            # (AttributeError, or LookupError in older nrnpython - TODO
            # confirm), so catching HocError alone never reached the fallback.
            val = getattr(cell.source, name)
        return val

    def get_native_parameters(self):
        """Return a dict of this cell's hoc-level parameter values.

        For standard cells the names are the translated (hoc) names of the
        cell class; hoc Vector values are converted to Python lists. For
        native cells the parameter list is unknown, so an empty dict is
        returned.
        """
        # Construct the list of hoc parameter names to get
        if self.is_standard_cell():
            parameter_names = [D['translated_name'] for D in self.cellclass.translations.values()]
        else:
            parameter_names = []  # for native cells, don't have a way to get their list of parameters
        # Obtain the hoc object whose parameters we are going to get
        cell = self._hoc_cell()
        # Get the values from hoc
        parameters = {}
        for name in parameter_names:
            val = self._get_hoc_parameter(cell, name)
            if isinstance(val, hoc.HocObject):
                val = [val.x[i] for i in range(int(val.size()))]
            parameters[name] = val
        return parameters

    def set_native_parameters(self, parameters):
        """Write hoc-level parameter values to the cell, then call its param_update().

        Sequence values (but not strings) are passed to hoc as a Vector.
        """
        cell = self._hoc_cell()
        for name, val in parameters.items():
            if hasattr(val, '__len__') and not isinstance(val, str):
                setattr(cell, name, Vector(val).hoc_obj)
            else:
                setattr(cell, name, val)
        cell.param_update()
|---|
| 119 | |
|---|
def list_standard_models():
    """Return the standard cell classes present in this module's namespace.

    A standard model is any class that subclasses common.StandardCellType.
    """
    models = []
    for candidate in globals().values():
        if not isinstance(candidate, type):
            continue
        if issubclass(candidate, common.StandardCellType):
            models.append(candidate)
    return models
|---|
| 122 | |
|---|
def load_mechanisms(path=pyNN_path[0]):
    """Load the compiled NEURON mechanisms found under <path>/hoc.

    Each path is loaded at most once (tracked in nrn_dll_loaded). Raises
    Exception if no libnrnmech.so is found for any candidate architecture.
    """
    if path in nrn_dll_loaded:
        return
    # In case NEURON is assuming a different architecture to Python, try
    # several possibilities, starting with the one Python reports.
    candidate_archs = [platform.machine(), 'i686', 'x86_64', 'powerpc']
    for arch in candidate_archs:
        library = os.path.join(path, 'hoc', arch, '.libs', 'libnrnmech.so')
        if not os.path.exists(library):
            continue
        h.nrn_load_dll(library)
        nrn_dll_loaded.append(path)
        return
    raise Exception("NEURON mechanisms not found in %s." % os.path.join(path, 'hoc'))
|---|
| 135 | |
|---|
| 136 | # ============================================================================== |
|---|
| 137 | # Module-specific functions and classes (not part of the common API) |
|---|
| 138 | # ============================================================================== |
|---|
| 139 | |
|---|
class HocError(Exception):
    """Raised when a hoc command fails or a hoc value cannot be retrieved."""
|---|
| 141 | |
|---|
def hoc_execute(hoc_commands, comment=None):
    """Execute hoc commands one at a time.

    hoc_commands -- list (or tuple) of hoc statement strings.
    comment      -- optional text logged at DEBUG level before the commands.

    Every command is logged at DEBUG level before it runs. Raises TypeError if
    hoc_commands is not a list/tuple (a bare string would otherwise be run one
    character at a time; the previous assert vanished under python -O), and
    HocError as soon as hoc reports that a command failed, in which case the
    remaining commands are not executed.
    """
    if not isinstance(hoc_commands, (list, tuple)):
        raise TypeError("hoc_commands must be a list of strings, not %s"
                        % type(hoc_commands).__name__)
    if comment:
        logging.debug(comment)
    for cmd in hoc_commands:
        logging.debug(cmd)
        if not hoc.execute(cmd):
            raise HocError('Error produced by hoc command "%s"' % cmd)
|---|
| 151 | |
|---|
def hoc_comment(comment):
    """Write `comment` to the log at DEBUG level; nothing is sent to hoc."""
    logging.debug("%s", comment)
|---|
| 154 | |
|---|
def _hoc_number(value):
    """Format a number as a hoc literal without losing precision.

    '%.17g' round-trips any IEEE-754 double. The '%g' and '%f' formats used
    previously kept only 6 significant digits / 6 decimal places (e.g. a spike
    time of 12345.678 became 12345.7 and 1e-7 became 0.000000).
    Raises TypeError if value is not a number.
    """
    # Wrap in a tuple so a tuple argument is rejected rather than unpacked.
    return '%.17g' % (value,)


def _hoc_vector_commands(index, values):
    """Return hoc commands that create Vector argvec<index> holding values."""
    commands = ['objref argvec%d' % index,
                'argvec%d = new Vector(%d)' % (index, len(values))]
    for i, value in enumerate(values):
        # assume only numerical values
        commands.append('argvec%d.x[%d] = %s' % (index, i, _hoc_number(value)))
    return commands


def _hoc_arglist(paramlist):
    """Convert a list of Python objects to a list of hoc commands which will
    generate equivalent hoc objects.

    Each item of paramlist becomes one hoc argument:
      list or 1-D numpy array -> Vector  argvec<N>
      str                     -> strdef  argstr<N>
      dict                    -> Dict    argdict<N> (str values stay strings,
                                         list values become Vectors, anything
                                         else is converted with float())
      2-D numpy array         -> Matrix  argmat<N>
      None                    -> skipped
      anything else           -> number  argvar<N>

    Returns (hoc_commands, argstr), where argstr is the comma-separated list of
    the hoc names to pass to a constructor or function.
    Raises common.InvalidParameterValueError for a non-numeric matrix entry and
    common.InvalidDimensionsError for arrays that are not 1- or 2-dimensional.
    """
    hoc_commands = []
    args = []
    nvec = nstr = nvar = ndict = nmat = 0
    for item in paramlist:
        if isinstance(item, numpy.ndarray) and item.ndim == 1:
            # Treat exactly like a list. This used to recurse into
            # _hoc_arglist, which restarted numbering at argvec0 (clobbering
            # earlier vectors) and dropped the separating comma.
            item = list(item)
        if isinstance(item, list):
            hoc_commands += _hoc_vector_commands(nvec, item)
            args.append('argvec%d' % nvec)
            nvec += 1
        elif isinstance(item, str):
            hoc_commands += ['strdef argstr%d' % nstr,
                             'argstr%d = "%s"' % (nstr, item)]
            args.append('argstr%d' % nstr)
            nstr += 1
        elif isinstance(item, dict):
            dict_init_list = []
            for k, v in item.items():
                if isinstance(v, str):
                    dict_init_list.append('"%s", "%s"' % (k, v))
                elif isinstance(v, list):
                    # Shares the argvec counter with top-level lists so names never collide.
                    hoc_commands += _hoc_vector_commands(nvec, v)
                    dict_init_list.append('"%s", argvec%d' % (k, nvec))
                    nvec += 1
                else:  # assume number
                    dict_init_list.append('"%s", %s' % (k, _hoc_number(float(v))))
            hoc_commands += ['objref argdict%d' % ndict,
                             'argdict%d = new Dict(%s)' % (ndict, ", ".join(dict_init_list))]
            args.append('argdict%d' % ndict)
            ndict += 1
        elif isinstance(item, numpy.ndarray):
            if item.ndim != 2:
                raise common.InvalidDimensionsError('number of dimensions must be 1 or 2')
            rows, cols = item.shape
            hoc_commands += ['objref argmat%d' % nmat,
                             'argmat%d = new Matrix(%d,%d)' % (nmat, rows, cols)]
            for i in range(rows):
                for j in range(cols):
                    try:
                        value = _hoc_number(item[i, j])
                    except TypeError:
                        raise common.InvalidParameterValueError
                    hoc_commands.append('argmat%d.x[%d][%d] = %s' % (nmat, i, j, value))
            args.append('argmat%d' % nmat)
            nmat += 1
        elif item is None:
            pass
        else:
            hoc_commands.append('argvar%d = %s' % (nvar, _hoc_number(item)))
            args.append('argvar%d' % nvar)
            nvar += 1
    return hoc_commands, ", ".join(args)
|---|
| 218 | |
|---|
def _translate_synapse_type(synapse_type, weight=None, extra_mechanism=None):
    """
    Return the name of the synapse object on the target cell to connect to.

    If synapse_type is given (not None or empty), it determines the object:
    'excitatory' -> "esyn", 'inhibitory' -> "isyn", and any other value is
    used verbatim as the object name.
    Otherwise, the synapse type is inferred from the sign of the weight: a
    missing or non-negative weight gives "esyn", a negative one "isyn".
    If extra_mechanism is Tsodyks-Markram, "_tm" is appended.
    Much testing needed to check if this behaviour matches nest and pcsim.
    """
    if synapse_type:
        if synapse_type == 'excitatory':
            syn_objref = "esyn"
        elif synapse_type == 'inhibitory':
            syn_objref = "isyn"
        else:
            # More sophisticated treatment needed once we have more sophisticated synapse
            # models, e.g. NMDA...
            syn_objref = synapse_type
    else:
        if weight is None or weight >= 0.0:
            syn_objref = "esyn"
        else:
            syn_objref = "isyn"
    # Accept the correct spelling as well as the historical misspelling
    # 'tsodkys-markram', which existing callers may still pass.
    if extra_mechanism in ('tsodyks-markram', 'tsodkys-markram'):
        syn_objref += "_tm"
    return syn_objref
|---|
| 244 | |
|---|
def checkParams(param, val=None):
    """Check parameters are of valid types, normalise the different ways of
    specifying parameters and values by putting everything in a dict.

    param -- a parameter name (str), in which case val is its value, or a
             dict mapping names to values (returned as-is; val is ignored).
    val   -- a real number (converted to float), a str, or a list.

    Raises common.InvalidParameterValueError for any other combination.
    """
    import numbers  # local import keeps this block self-contained
    if isinstance(param, str):
        # numbers.Real also admits Python 2 longs, numpy scalars and other
        # real types, which the previous float/int check wrongly rejected.
        if isinstance(val, numbers.Real):
            return {param: float(val)}
        if isinstance(val, (str, list)):
            return {param: val}
        raise common.InvalidParameterValueError
    if isinstance(param, dict):
        return param
    raise common.InvalidParameterValueError
|---|
| 261 | |
|---|
class HocToPy:
    """Static class to simplify getting variables from hoc.

    Values come back by having hoc run (via nrnpython) a Python statement
    that assigns to the class attribute HocToPy.hocvar.
    """

    # return_type -> sprint format used to embed the value in that statement.
    fmt_dict = {'int' : '%d', 'integer' : '%d', 'float' : '%f', 'double' : '%f',
                'string' : '\\"%s\\"', 'str' : '\\"%s\\"'}

    @staticmethod
    def get(name, return_type='float'):
        """Return a variable from hoc.
        name can be a hoc variable (int, float, string) or a function/method
        that returns such a variable. return_type must be a key of fmt_dict.
        Raises HocError if no value was passed back from hoc.
        """
        # We execute some commands here to avoid too much outputs in the log file
        errorstr = '"raise HocError(\'caused by HocToPy.get(%s, return_type=\\"%s\\")\')"' % (name, return_type)
        hoc_commands = ['success = sprint(cmd,"HocToPy.hocvar = %s",%s)' % (HocToPy.fmt_dict[return_type], name),
                        'if (success) { nrnpython(cmd) } else { nrnpython(%s) }' % errorstr]
        # Clear any value left over from a previous call, so that a failed
        # nrnpython callback cannot silently return stale data (bool() below
        # already does the same).
        HocToPy.hocvar = None
        hoc_execute(hoc_commands)
        if HocToPy.hocvar is None:
            raise HocError("caused by HocToPy.get('%s', return_type='%s')" % (name, return_type))
        return HocToPy.hocvar

    @staticmethod
    def bool(condition):
        """Evaluate the condition in hoc and return True or False.

        Raises HocError if hoc did not report a result.
        """
        HocToPy.hocvar = None
        hoc.execute('if (%s) { nrnpython("HocToPy.hocvar = True") } '
                    'else { nrnpython("HocToPy.hocvar = False") }' % condition)
        if HocToPy.hocvar is None:
            raise HocError("caused by HocToPy.bool('%s')" % condition)
        return HocToPy.hocvar
|---|
| 290 | |
|---|
| 291 | # ============================================================================== |
|---|
| 292 | # Functions for simulation set-up and control |
|---|
| 293 | # ============================================================================== |
|---|
| 294 | |
|---|
def setup(timestep=0.1, min_delay=0.1, max_delay=10.0, debug=False,**extra_params):
    """
    Should be called at the very beginning of a script.
    extra_params contains any keyword arguments that are required by a given
    simulator but not by others.

    Extra parameters recognised here: quit_on_end (stored in the module-level
    flag that end() reads) and use_cvode (when True, a CVode object is created
    and activated). max_delay is accepted but not used by this function.

    Loads the compiled NEURON mechanisms, configures logging to 'neuron.log'
    (DEBUG level when debug is true, INFO otherwise), runs the hoc set-up
    commands and returns the rank of this host (0 when running serially).
    A repeated call only updates dt and min_delay in hoc.
    """
    global nhost, myid, logger, initialised, quit_on_end
    # Load the compiled NEURON mechanisms shipped with pyNN.
    load_mechanisms()
    if 'quit_on_end' in extra_params:
        quit_on_end = extra_params['quit_on_end']
    # Initialisation of the log module. To write in the logfile, simply enter
    # logging.critical(), logging.debug(), logging.info(), logging.warning()
    # Note: logging.basicConfig() does nothing if the root logger already has
    # handlers, so only the first configuration takes effect.
    if debug:
        logging.basicConfig(level=logging.DEBUG,
                            format='%(asctime)s %(levelname)s %(message)s',
                            filename='neuron.log',
                            filemode='w')
    else:
        logging.basicConfig(level=logging.INFO,
                            format='%(asctime)s %(levelname)s %(message)s',
                            filename='neuron.log',
                            filemode='w')

    logging.info("Initialization of NEURON (use setup(.., debug=True) to see a full logfile)")

    # All the objects that will be used frequently in the hoc code are declared in the setup

    if initialised:
        # Repeat call: the hoc objects already exist; only update the timing values.
        # NOTE(review): '%f' keeps only 6 decimal places of timestep.
        hoc_commands = ['dt = %f' % timestep,
                        'min_delay = %g' % min_delay]
    else:
        # First call: load the hoc templates and declare the shared hoc
        # objects (ParallelContext, netconlist, scratch objrefs/strdefs).
        hoc_commands = [
            'tmp = xopen("%s")' % os.path.join(pyNN_path[0],'hoc','standardCells.hoc'),
            'tmp = xopen("%s")' % os.path.join(pyNN_path[0],'hoc','odict.hoc'),
            'objref pc',
            'pc = new ParallelContext()',
            'dt = %f' % timestep,
            'tstop = 0',
            'min_delay = %g' % min_delay,
            'create dummy_section',
            'access dummy_section',
            'objref netconlist, nil',
            'netconlist = new List()',
            'strdef cmd',
            'strdef fmt',
            'objref nc',
            'objref rng',
            'objref cell']
        # NOTE(review): the copy this was reconstructed from had lost its
        # indentation; placing the spike_compress and CVode commands inside
        # this `else` (i.e. first call only) is a reconstruction -- confirm
        # against version control.
        #---Experimental--- Optimize the simulation time ? / Reduce inter-processors exchanges ?
        hoc_commands += [
            'tmp = pc.spike_compress(1,0)']
        if extra_params.has_key('use_cvode') and extra_params['use_cvode'] == True:
            hoc_commands += [
                'objref cvode',
                'cvode = new CVode()',
                'cvode.active(1)']

    hoc_execute(hoc_commands,"--- setup() ---")

    #nhost = HocToPy.get('pc.nhost()','int')
    nhost = int(h.pc.nhost())
    if nhost < 2:
        # Serial run: normalise to a single host with rank 0.
        nhost = 1; myid = 0
    else:
        #myid = HocToPy.get('pc.id()','int')
        myid = int(h.pc.id())
        print "\nHost #%d of %d" % (myid+1, nhost)

    initialised = True
    return int(myid)
|---|
| 365 | |
|---|
def end(compatible_output=True):
    """Do any necessary cleaning up before exiting.

    Writes the recorded membrane-potential traces (vfilelist) and spike times
    (spikefilelist) of this host's cells, one file per registered filename,
    each with a header of comment lines. It then calls pc.runworker() and
    pc.done(), calls hoc quit() when quit_on_end is set, and finally calls
    sys.exit(0).

    compatible_output is accepted for API compatibility; it is not used here.
    """
    dt = get_time_step()

    def id_range_header(cell_list):
        # A host that owns none of a file's cells gets no id lines instead
        # of an IndexError.
        if not cell_list:
            return ""
        return "# first_id = %d\\n# last_id = %d\\n" % (cell_list[0], cell_list[-1])

    hoc_commands = []
    if len(vfilelist) > 0:
        hoc_commands += ['objref fileobj',
                         'fileobj = new File()']
    while len(vfilelist):
        filename, cell_list = vfilelist.popitem()
        tstop = h.tstop
        header = "# dt = %g\\n# n = %d\\n" % (dt, int(tstop/dt))
        header += id_range_header(cell_list)
        hoc_commands += ['tmp = fileobj.wopen("%s")' % filename,
                         'tmp = fileobj.printf("%s")' % header]
        for cell in cell_list:
            hoc_commands += ['fmt = "%s\\t%d\\n"' % ("%.6g", cell),
                             'tmp = cell%d.vtrace.printf(fileobj, fmt)' % cell]
        hoc_commands += ['tmp = fileobj.close()']
    if len(spikefilelist) > 0:
        hoc_commands += ['objref fileobj',
                         'fileobj = new File()']
    while len(spikefilelist):
        filename, cell_list = spikefilelist.popitem()
        # Built per file. The header used to be built once, before this loop,
        # from a cell_list left over from the voltage files (a NameError if
        # there were none), and it wrote "\n #last_id".
        header = "# dt = %g\\n" % dt + id_range_header(cell_list)
        hoc_commands += ['tmp = fileobj.wopen("%s")' % filename,
                         'tmp = fileobj.printf("%s")' % header]
        for cell in cell_list:
            # Only spikes up to the hoc tstop are written.
            hoc_commands += ['fmt = "%s\\t%d\\n"' % ("%.2f", cell),
                             'tmp = cell%d.spiketimes.where("<=", tstop).printf(fileobj, fmt)' % cell]
        hoc_commands += ['tmp = fileobj.close()']
    hoc_commands += ['tmp = pc.runworker()',
                     'tmp = pc.done()']
    hoc_execute(hoc_commands, "--- end() ---")
    if quit_on_end:
        hoc.execute('tmp = quit()')  # sometimes needed, sometimes not wanted; see setup(quit_on_end=...)
    logging.info("Finishing up with NEURON.")
    sys.exit(0)
|---|
| 406 | |
|---|
def run(simtime):
    """Run the simulation for simtime ms and return the new current time.

    On the first call the network is initialised (pc.set_maxstep(10),
    finitialize(), tstop = 0). In a parallel run this also checks that no
    connection delay is shorter than the minimum delay. Every call then
    advances tstop by simtime and integrates up to it with pc.psolve().
    """
    global running
    if not running:
        running = True
        hoc_execute(['local_minimum_delay = pc.set_maxstep(10)',
                     'tmp = finitialize()',
                     'tstop = 0'], "--- run() ---")
        logging.debug("local_minimum_delay on host #%d = %g" % (myid, h.local_minimum_delay))
        if nhost > 1:
            assert h.local_minimum_delay >= get_min_delay(), \
                "There are connections with delays (%g) shorter than the minimum delay (%g)" % (h.local_minimum_delay, get_min_delay())
    # The initialisation commands are not resent here: the old code reused
    # the same command list, so the first call ran finitialize() twice.
    # '%.17g' keeps full precision ('%f' kept only 6 decimal places).
    hoc_execute(['tstop += %.17g' % simtime,
                 'tmp = pc.psolve(tstop)'], "--- run() ---")
    return get_current_time()
|---|
| 425 | |
|---|
def get_current_time():
    """Return the current simulation time, i.e. hoc's global t."""
    current_time = h.t
    return current_time
|---|
| 429 | |
|---|
def get_time_step():
    """Return the integration time step, i.e. hoc's global dt."""
    step = h.dt
    return step
# Make the simulator's time step available to pyNN.common as well.
common.get_time_step = get_time_step
|---|
| 433 | |
|---|
def get_min_delay():
    """Return the minimum synaptic delay stored in hoc (set by setup())."""
    delay = h.min_delay
    return delay
# Make the minimum delay available to pyNN.common as well.
common.get_min_delay = get_min_delay
|---|
| 437 | |
|---|
def num_processes():
    """Return the number of hosts reported by the hoc ParallelContext."""
    host_count = h.pc.nhost()
    return int(host_count)
|---|
| 440 | |
|---|
def rank():
    """Return the MPI rank."""
    return int(h.pc.id())
|---|
| 445 | |
|---|
| 446 | # ============================================================================== |
|---|
| 447 | # Low-level API for creating, connecting and recording from individual neurons |
|---|
| 448 | # ============================================================================== |
|---|
| 449 | |
|---|
def create(cellclass, param_dict=None, n=1):
    """
    Create n cells all of the same type.
    If n > 1, return a list of cell ids/references.
    If n==1, return just the single id.

    cellclass is either a standard cell class, which is instantiated with
    param_dict to obtain its hoc template name and parameters, or a string
    naming a hoc template, which receives param_dict directly. Ids are
    assigned round-robin over the hosts. Only the cells assigned to this host
    are instantiated in hoc (as cell<id>) and added to gidlist, but ids for
    all n cells are returned.
    Raises TypeError if cellclass is neither a class nor a string.
    """
    global gid

    assert n > 0, 'n must be a positive integer'
    if isinstance(cellclass, type):
        celltype = cellclass(param_dict)
        hoc_name = celltype.hoc_name
        hoc_commands, argstr = _hoc_arglist([celltype.parameters])
    elif isinstance(cellclass, str):
        hoc_name = cellclass
        hoc_commands, argstr = _hoc_arglist([param_dict])
    else:
        # Previously this fell through to a NameError on hoc_name.
        raise TypeError("cellclass must be a cell class or the name of a hoc "
                        "template, not %r" % (cellclass,))
    argstr = argstr.strip().strip(',')

    # round-robin partitioning: this host takes ids gid+myid, gid+myid+nhost, ...
    newgidlist = [i+myid for i in range(gid, gid+n, nhost) if i < gid+n-myid]
    logging.debug("Creating cells %s on host %d" % (newgidlist, myid))
    for cell_id in newgidlist:
        hoc_commands += ['tmp = pc.set_gid2node(%d,%d)' % (cell_id, myid),
                         'objref cell%d' % cell_id,
                         'cell%d = new %s(%s)' % (cell_id, hoc_name, argstr),
                         'tmp = cell%d.connect2target(nil, nc)' % cell_id,
                         'tmp = pc.cell(%d, nc)' % cell_id]
    hoc_execute(hoc_commands, "--- create() ---")

    gidlist.extend(newgidlist)
    cell_list = [ID(i) for i in range(gid, gid+n)]
    for cell_id in cell_list:
        cell_id.cellclass = cellclass
    gid += n
    if n == 1:
        return cell_list[0]
    return cell_list
|---|
| 488 | |
|---|
def connect(source, target, weight=None, delay=None, synapse_type=None, p=1, rng=None):
    """Connect a source of spikes to a synaptic target. source and target can
    both be individual cells or lists of cells, in which case all possible
    connections are made with probability p, using either the random number
    generator supplied, or the default rng otherwise.
    Weights should be in nA or µS.

    Lists, tuples and numpy arrays of ids are accepted. Connections are only
    created onto targets that live on this host. The target is addressed in
    hoc as cell<id>, i.e. as created by create(). The synapse object is chosen
    by _translate_synapse_type() from synapse_type or the sign of weight.
    Weights onto conductance-based cells ("cond" in the class name) are made
    positive. Returns the list of connection counters (ncid values) used by
    this call. Raises common.ConnectionError if any id has not been created.
    """
    global ncid

    def as_list(ids):
        # Wrap a single id; copy sequences, including numpy arrays of IDs,
        # which the old ListType check wrapped whole.
        if isinstance(ids, (list, tuple, numpy.ndarray)):
            return list(ids)
        return [ids]

    def exists(cell_id):
        # Ids 0..gid-1 have been allocated. The old test `id > gid` wrongly
        # accepted the not-yet-allocated id gid.
        return isinstance(cell_id, int) and 0 <= cell_id < gid

    source = as_list(source)
    target = as_list(target)
    if weight is None:
        weight = 0.0
    if delay is None:
        delay = get_min_delay()
    syn_objref = _translate_synapse_type(synapse_type, weight)

    for tgt in target:
        if not exists(tgt):
            raise common.ConnectionError("Postsynaptic cell id %s does not exist." % str(tgt))
    for src in source:
        if not exists(src):
            raise common.ConnectionError("Presynaptic cell id %s does not exist." % str(src))

    nc_start = ncid
    hoc_commands = []
    logging.debug("connecting %s to %s on host %d" % (source, target, myid))
    for tgt in target:
        if tgt not in gidlist:
            continue  # only create connections to cells that exist on this machine
        cellclass = getattr(tgt, 'cellclass', None)
        if isinstance(cellclass, str):
            cellclass_name = cellclass  # native hoc template given by name
        else:
            cellclass_name = getattr(cellclass, '__name__', '')
        # Computed per target so that abs() does not leak into later,
        # non-conductance-based targets.
        if "cond" in cellclass_name:
            tgt_weight = abs(weight)  # weights must be positive for conductance-based synapses
        else:
            tgt_weight = weight
        if p < 1:
            # The old code referred to an undefined `self.rng`. Assumes rng
            # offers a numpy-style uniform(low, high, size) - TODO confirm for
            # pyNN RNG wrappers.
            uniform = rng.uniform if rng else numpy.random.uniform
            rarr = uniform(0, 1, len(source))
        for j, src in enumerate(source):
            if p >= 1 or rarr[j] < p:  # might be more efficient to vectorise the latter comparison
                hoc_commands += ['nc = pc.gid_connect(%d, cell%d.%s)' % (src, tgt, syn_objref),
                                 'nc.delay = %.17g' % delay,
                                 'nc.weight = %.17g' % tgt_weight,
                                 'tmp = netconlist.append(nc)']
                ncid += 1
    hoc_execute(hoc_commands, "--- connect(%s,%s) ---" % (str(source), str(target)))
    return list(range(nc_start, ncid))
|---|
| 535 | |
|---|
def set(cells, param, val=None):
    """Set one or more parameters of an individual cell or list of cells.
    param can be a dict, in which case val should not be supplied, or a string
    giving the parameter name, in which case val is the parameter value.

    Only cells that exist on this host are updated, via their
    set_parameters() method.
    """
    # `is not None` rather than truthiness, so falsy values such as 0, 0.0
    # or '' are actually set. Previously param stayed a plain string and the
    # **param call below failed.
    if val is not None:
        param = {param: val}
    if not hasattr(cells, '__len__'):
        cells = [cells]
    # Note: each `in gidlist` test scans the whole (list) gidlist.
    cells = [cell for cell in cells if cell in gidlist]
    for cell in cells:
        cell.set_parameters(**param)
|---|
| 550 | |
|---|
def record(source, filename):
    """Record spikes to a file. source can be an individual cell or a list of
    cells.

    Cells not on this host are ignored. For each local cell, hoc's
    cell<id>.record(1) is called and the id is appended to
    spikefilelist[filename]; end() writes the file.
    """
    # would actually like to be able to record to an array and choose later
    # whether to write to a file.
    cells = source if type(source) == list else [source]
    recorded_ids = spikefilelist.setdefault(filename, [])
    local_ids = [cell_id for cell_id in cells if cell_id in gidlist]
    recorded_ids.extend(local_ids)
    commands = ['tmp = cell%d.record(1)' % cell_id for cell_id in local_ids]
    hoc_execute(commands, "---record() ---")
|---|
| 567 | |
|---|
def record_v(source, filename):
    """
    Record membrane potential to a file. source can be an individual cell or
    a list of cells.

    Cells not on this host are ignored. For each local cell, hoc's
    cell<id>.record_v(1, dt) is called and the id is appended to
    vfilelist[filename]; end() writes the file. Raises Exception for cells
    that belong to a Population.
    """
    # would actually like to be able to record to an array and
    # choose later whether to write to a file.
    cells = source if type(source) == list else [source]
    recorded_ids = vfilelist.setdefault(filename, [])
    commands = []
    for cell_id in cells:
        if cell_id not in gidlist:
            continue
        if cell_id.parent:
            raise Exception("The record_v() function does not work with cells in a Population. Please use the record_v() method of the Population object.")
        commands.append('tmp = cell%d.record_v(1,%g)' % (cell_id, get_time_step()))
        recorded_ids.append(cell_id)
    hoc_execute(commands, "---record_v() ---")
|---|
| 588 | |
|---|
| 589 | # ============================================================================== |
|---|
| 590 | # High-level API for creating, connecting and recording from populations of |
|---|
| 591 | # neurons. |
|---|
| 592 | # ============================================================================== |
|---|
| 593 | |
|---|
| 594 | class Population(common.Population): |
|---|
| 595 | """ |
|---|
| 596 | An array of neurons all of the same type. `Population' is used as a generic |
|---|
| 597 | term intended to include layers, columns, nuclei, etc., of cells. |
|---|
| 598 | All cells have both an address (a tuple) and an id (an integer). If p is a |
|---|
| 599 | Population object, the address and id can be inter-converted using : |
|---|
| 600 | id = p[address] |
|---|
| 601 | address = p.locate(id) |
|---|
| 602 | """ |
|---|
| 603 | nPop = 0 |
|---|
| 604 | |
|---|
def __init__(self, dims, cellclass, cellparams=None, label=None):
    """
    Create a Population whose cells are distributed round-robin over the nodes.

    dims should be a tuple containing the population dimensions, or a single
    integer, for a one-dimensional population.
    e.g., (10,10) will create a two-dimensional population of size 10x10.
    cellclass should either be a standardized cell class (a class inheriting
    from common.StandardCellType) or a string giving the name of the
    simulator-specific model that makes up the population.
    cellparams should be a dict which is passed to the neuron model
    constructor.
    label is an optional name for the population.

    Raises TypeError if cellclass is neither a class nor a string.
    """
    global gid, gidlist

    common.Population.__init__(self, dims, cellclass, cellparams, label)

    # Row-major strides used by __getitem__(): steps[i] is the product of the
    # sizes of all dimensions after dimension i.
    self.steps = [1] * self.ndim
    for i in range(self.ndim - 1):
        for j in range(i + 1, self.ndim):
            self.steps[i] *= self.dim[j]

    if isinstance(cellclass, type):
        self.celltype = cellclass(cellparams)
        self.cellparams = self.celltype.parameters
        hoc_name = self.celltype.hoc_name
    elif isinstance(cellclass, str):  # not a standard model
        hoc_name = cellclass
    else:
        # Fail early with a clear message; otherwise hoc_name would be
        # unbound when the hoc commands are built below.
        raise TypeError("cellclass must be a cell type class or a string, not %r" % (cellclass,))

    if self.cellparams is not None:
        hoc_commands, argstr = _hoc_arglist([self.cellparams])
        argstr = argstr.strip().strip(',')
    else:
        hoc_commands = []
        argstr = ''

    if not self.label:
        self.label = 'population%d' % Population.nPop
    self.hoc_label = self.label.replace(" ", "_")

    self.record_from = {'spiketimes': Set(), 'vtrace': Set()}

    # Every node holds the ID objects of ALL cells of the population, so that
    # addressing (p[i, j]) and positions work on any node.
    self.fullgidlist = numpy.array([ID(i) for i in range(gid, gid + self.size)], ID)
    self.cell = self.fullgidlist

    # Only the cells assigned to this node are instantiated here: the cell at
    # position k of fullgidlist lives on node k % nhost.
    self.gidlist = list(self.fullgidlist[myid::nhost])
    self.gid_start = gid

    # Write hoc commands: a hoc List holds the local cells, and each cell is
    # registered with `pc` under its gid.
    hoc_commands += ['objref %s' % self.hoc_label,
                     '%s = new List()' % self.hoc_label]
    for cell_id in self.gidlist:
        hoc_commands += ['tmp = pc.set_gid2node(%d,%d)' % (cell_id, myid),
                         'cell = new %s(%s)' % (hoc_name, argstr),
                         'tmp = cell.connect2target(nil, nc)',
                         'tmp = pc.cell(%d, nc)' % cell_id,
                         'tmp = %s.append(cell)' % (self.hoc_label)]
    hoc_execute(hoc_commands, "--- Population[%s].__init__() ---" % self.label)
    Population.nPop += 1
    gid = gid + self.size

    # Add this population's local gids to the module-wide gidlist.
    gidlist += self.gidlist

    # Every node needs to know the parent population of every cell ...
    for cell_id in self.fullgidlist:
        cell_id.parent = self

    # ... but only local cells get a hoc name, used by the low-level set()
    # function. A cell's position in self.gidlist is its index in the hoc
    # List. enumerate() avoids the quadratic cost of self.gidlist.index().
    for position, cell_id in enumerate(self.gidlist):
        cell_id.hocname = "%s.o(%d)" % (self.hoc_label, position)
|---|
| 690 | |
|---|
def __getitem__(self, addr):
    """
    Return the ID of the cell whose grid coordinates are `addr`.

    `addr` is either a plain int (1-D populations) or a tuple holding one
    coordinate per dimension, so p[2,3] is the same as p.__getitem__((2,3)).
    The ID is taken from fullgidlist, which every node holds in full, so any
    node can look up any cell, whether or not it is local.
    """
    coords = (addr,) if isinstance(addr, int) else addr
    if len(coords) != len(self.dim):
        raise common.InvalidDimensionsError(
            "Population has %d dimensions. Address was %s" % (self.ndim, str(coords)))
    # Flatten the address with the row-major strides computed in __init__().
    flat = sum([coord * stride for coord, stride in zip(coords, self.steps)])
    gid_value = flat + self.gid_start
    assert coords == self.locate(gid_value), \
        'index=%s addr=%s id=%s locate(id)=%s' % (flat, coords, gid_value, self.locate(gid_value))
    return self.fullgidlist[flat]
|---|
| 715 | |
|---|
def __iter__(self):
    """Iterate over the IDs of the cells that live on this node."""
    local_ids = self.__gid_gen()
    return local_ids
|---|
| 718 | |
|---|
def __address_gen(self):
    """
    Yield the grid address (coordinate tuple) of each cell on this node,
    in the order of self.gidlist.
    """
    for local_id in self.gidlist:
        address = self.locate(local_id)
        yield address
|---|
| 726 | |
|---|
def __gid_gen(self):
    """
    Yield, one at a time, the gid of every cell that lives on this node
    (the entries of self.gidlist, in order).
    """
    for local_id in self.gidlist:
        yield local_id
|---|
| 734 | |
|---|
def addresses(self):
    """Return an iterator over the grid addresses of the cells on this node."""
    address_iter = self.__address_gen()
    return address_iter
|---|
| 737 | |
|---|
def ids(self):
    """Return an iterator over the IDs (gids) of the cells on this node."""
    id_iter = self.__gid_gen()
    return id_iter
|---|
| 740 | |
|---|
def locate(self, id):
    """
    Given the gid of a cell in this Population, return its grid coordinates.

    Cells are numbered in row-major order starting at self.gid_start, e.g.
    for a 2x2 population with gid_start = 6:

        6 7
        8 9      -> gid 8 has coordinates (1, 0)

    Works for any number of dimensions.

    Raises common.InvalidDimensionsError if the population has no dimensions.
    """
    # id should be a gid
    assert isinstance(id, int), "id is %s, not int" % type(id)
    if self.ndim < 1:
        raise common.InvalidDimensionsError
    offset = id - self.gid_start
    coords = []
    # Peel off the fastest-varying (last) dimension first. divmod() uses floor
    # division, which matches Python 2's int '/' and stays correct on Python 3,
    # where '/' would produce floats.
    for size in reversed(list(self.dim)[1:]):
        offset, remainder = divmod(offset, size)
        coords.append(remainder)
    coords.append(offset)
    coords.reverse()
    return tuple(coords)
|---|
| 763 | |
|---|
def index(self, n):
    """
    Return the nth cell in the population (indexing starts at 0).

    n may also be a sequence of indices, in which case it is turned into a
    numpy array and used to select several cells at once.
    """
    selector = numpy.array(n) if hasattr(n, '__len__') else n
    return self.fullgidlist[selector]
|---|
| 769 | |
|---|
def get(self, parameter_name, as_array=False):
    """
    Return the value of `parameter_name` for every cell on this node.

    The result is a list in self.gidlist order, or a numpy array when
    as_array is true. (Arguably it should be reshaped to the shape of the
    Population.)
    """
    collected = []
    for cell_id in self.gidlist:
        collected.append(getattr(cell_id, parameter_name))
    if not as_array:
        return collected
    return numpy.array(collected)
|---|
| 779 | |
|---|
def set(self, param, val=None):
    """
    Set one or more parameters for every cell on this node.

    param is either a dict of {name: value}, in which case val should not be
    supplied, or a parameter name, in which case val is the parameter value.
    A scalar val (any real number, including numpy scalars and Python 2
    longs, or a numeric string) is converted to float. A list or numpy array
    (e.g. spike times) is passed on unchanged.
    e.g. p.set("tau_m", 20.0)
         p.set({'tau_m': 20, 'v_rest': -65})

    Raises common.InvalidParameterValueError for any other param/val type.
    """
    import numbers
    if isinstance(param, str):
        # numbers.Real also covers numpy scalar types and Python 2 longs,
        # which a plain (float, int) check would reject.
        if isinstance(val, str) or isinstance(val, numbers.Real):
            param_dict = {param: float(val)}
        elif isinstance(val, (list, numpy.ndarray)):
            param_dict = {param: val}
        else:
            raise common.InvalidParameterValueError(
                "invalid value %r for parameter '%s'" % (val, param))
    elif isinstance(param, dict):
        param_dict = param
    else:
        raise common.InvalidParameterValueError(
            "param must be a string or a dict, not %s" % type(param))
    for cell in self.gidlist:
        cell.set_parameters(**param_dict)
|---|
| 802 | |
|---|
def tset(self, parametername, value_array):
    """
    'Topographic' set. Set the value of parametername to the values in
    value_array, which must have the same dimensions as the Population.

    value_array has either exactly the Population's shape (one value per
    cell) or that shape plus one trailing axis (one 1-D array per cell, e.g.
    spike times). Each cell on this node receives the entry at its own grid
    position. Per-cell arrays are converted to lists.

    Raises common.InvalidDimensionsError if the shapes do not match.
    """
    value_array = numpy.asarray(value_array)
    dims = tuple(self.dim)
    # Convert everything to one entry (scalar or row) per cell, in gid order.
    if value_array.shape == dims:  # the values are numbers or non-array objects
        values = value_array.flatten()
    elif value_array.ndim == len(dims) + 1 and value_array.shape[:-1] == dims:
        # the values are themselves 1D arrays: one row per cell
        values = numpy.reshape(value_array, (self.cell.size, -1))
    else:
        raise common.InvalidDimensionsError("Population: %s, value_array: %s" % (str(self.dim),
                                                                              str(value_array.shape)))
    # Take just the entries for cells on this machine. axis=0 keeps whole rows
    # in the per-cell-array case (take() without an axis would flatten).
    local_positions = numpy.array([int(cell) - self.gid_start for cell in self.gidlist], dtype=int)
    values = values.take(local_positions, axis=0)
    assert len(values) == len(self.gidlist)

    # Set the values for each cell
    for cell, val in zip(self.gidlist, values):
        if not isinstance(val, str) and hasattr(val, "__len__"):
            # tuples, arrays are all converted to lists, since this is
            # what SpikeSourceArray expects. This is not very robust
            # though - we might want to add things that do accept arrays.
            val = list(val)
        setattr(cell, parametername, val)
|---|
| 828 | |
|---|
def rset(self, parametername, rand_distr):
    """
    'Random' set. Set the value of parametername on every cell on this node
    to a value taken from rand_distr, which should be a RandomDistribution
    object.

    With a NativeRNG the values are drawn inside hoc from a Random object
    seeded with rand_distr.rng.seed (0 if it has no seed). Otherwise
    self.size values are drawn in Python, and each local cell receives the
    value at its own position in the population (as tset() does). Presumably
    this makes the result independent of the number of nodes when every node
    draws the same sequence.

    Raises Exception for computed parameters of standard cell types when a
    NativeRNG is used.
    """
    if isinstance(rand_distr.rng, NativeRNG):
        if isinstance(self.celltype, common.StandardCellType):
            parametername = self.celltype.__class__.translations[parametername]['translated_name']
            if parametername in self.celltype.__class__.computed_parameters():
                raise Exception("rset() with NativeRNG not (yet) supported for computed parameters.")
        distr_params = ",".join(["%g" % p for p in rand_distr.parameters])
        # Use the RNG's own seed, falling back to 0 when it has none.
        seed = getattr(rand_distr.rng, 'seed', None) or 0
        hoc_commands = ['rng = new Random(%d)' % seed,
                        'tmp = rng.%s(%s)' % (rand_distr.name, distr_params)]
        # We do the loop in hoc, to speed up the code
        loop = "for tmp = 0, %d" % (len(self.gidlist) - 1)
        cmd = '%s.object(tmp).%s = rng.repick()' % (self.hoc_label, parametername)
        hoc_commands += ['cmd="%s { %s success = %s.object(tmp).param_update()}"' % (loop, cmd, self.hoc_label),
                         'success = execute1(cmd)']
        hoc_execute(hoc_commands, "--- Population[%s].__rset()__ ---" % self.label)
    else:
        # atleast_1d guards against a scalar being returned for a 1-cell population.
        rarr = numpy.atleast_1d(rand_distr.next(n=self.size))
        hoc_comment("--- Population[%s].__rset()__ --- " % self.label)
        for cell in self.gidlist:
            setattr(cell, parametername, rarr[int(cell) - self.gid_start])
|---|
| 854 | |
|---|
def _call(self, methodname, arguments):
    """
    Call methodname(arguments) on every cell in the population.

    Not implemented: always raises Exception. The intended use would be
    e.g. p.call("set_background", "0.1") for a cell class that has a
    set_background() method.
    """
    # Not sure this belongs in the API, because cell classes only have
    # parameters/attributes, not methods.
    message = "Method not yet implemented"
    raise Exception(message)
|---|
| 864 | |
|---|
def _tcall(self, methodname, objarr):
    """
    'Topographic' call: call methodname() on every cell, with an argument
    that depends on the cell's coordinates (objarr has the Population's
    shape), e.g. p.tcall("memb_init", vinitArray) would call
    p.cell[i][j].memb_init(vinitArray[i][j]) for all i, j.

    Not implemented: always raises Exception.
    """
    message = "Method not yet implemented"
    raise Exception(message)
|---|
| 874 | |
|---|
def __record(self, record_what, record_from=None, rng=None):
    """
    Private method called by record() and record_v().

    record_what is the recording key ('spiketimes' or 'vtrace').
    record_from is one of:
      - None: record from every cell on this node;
      - an int N: each node records N // nhost of its own cells, chosen at
        random (with `rng` if given, else numpy.random);
      - a list, tuple or numpy array of cell ids: record exactly those (only
        the ones local to this node are recorded here).

    Unless a fixed list was given, each slave node then posts the ids it
    records, and the master node takes them into its own
    self.record_from[record_what].

    Raises Exception for any other record_from type.
    """
    global myid
    hoc_commands = []
    fixed_list = False

    if isinstance(record_from, (list, tuple, numpy.ndarray)):  # fixed list specified by user
        fixed_list = True
    elif record_from is None:  # record from all cells:
        record_from = self.gidlist
    elif isinstance(record_from, int):  # record from a number of cells, selected at random
        # Each node will record N/nhost cells...
        nrec = record_from // nhost
        if rng:
            record_from = rng.permutation(self.gidlist)
        else:
            record_from = numpy.random.permutation(self.gidlist)
        # ...taken at random from self.gidlist
        record_from = numpy.array(record_from[0:nrec])
    else:
        raise Exception("record_from must be either a list of cells or the number of cells to record from")
    # record_from is now a list, tuple or numpy array

    # gid -> position in the local hoc List (O(1) lookups instead of
    # repeated self.gidlist.index() scans).
    local_index = dict((cell_id, i) for i, cell_id in enumerate(self.gidlist))
    suffix = '_v' if record_what == 'vtrace' else ''
    for id in record_from:
        if id in local_index:
            hoc_commands += ['tmp = %s.object(%d).record%s(1)' % (self.hoc_label, local_index[id], suffix)]

    # note that self.record_from is not the same on all nodes, like self.gidlist, for example.
    self.record_from[record_what].update(Set(record_from))
    hoc_commands += ['objref record_from']
    hoc_execute(hoc_commands)

    # Then we have to send the lists of local recorded objects to the master node,
    # but only if the list has not been specified by the user.
    if fixed_list is False:
        if myid != 0:  # on slave nodes
            hoc_commands = ['record_from = new Vector()']
            for id in self.record_from[record_what]:
                if id in local_index:
                    hoc_commands += ['record_from = record_from.append(%d)' % id]
            hoc_commands += ['tmp = pc.post("%s.record_from[%s].node[%d]", record_from)' % (self.hoc_label, record_what, myid)]
            hoc_execute(hoc_commands, " (Posting recorded cells)")
        else:  # on the master node
            for node in range(1, nhost):
                hoc_commands = ['record_from = new Vector()']
                hoc_commands += ['tmp = pc.take("%s.record_from[%s].node[%d]", record_from)' % (self.hoc_label, record_what, node)]
                hoc_execute(hoc_commands)
                for j in range(int(h.record_from.size())):
                    self.record_from[record_what].add(int(h.record_from.x[j]))
|---|
| 928 | |
|---|
def record(self, record_from=None, rng=None):
    """
    Record spike times.

    Without record_from, every cell in the Population is recorded. An int
    record_from selects that many cells at random (an optional random number
    generator `rng` may be supplied); a list selects the given cell ids.
    """
    comment = "--- Population[%s].__record()__ ---" % self.label
    hoc_comment(comment)
    self.__record('spiketimes', record_from, rng)
|---|
| 938 | |
|---|
def record_v(self, record_from=None, rng=None):
    """
    Record membrane potential traces.

    Without record_from, every cell in the Population is recorded. An int
    record_from selects that many cells at random (an optional random number
    generator `rng` may be supplied); a list selects the given cell ids.
    """
    comment = "--- Population[%s].__record_v()__ ---" % self.label
    hoc_comment(comment)
    self.__record('vtrace', record_from, rng)
|---|
| 949 | |
|---|
def __print(self, print_what, filename, num_format, gather, header=None):
    """
    Private method used by printSpikes(), print_v() and getSpikes().

    Writes the recorded vectors named print_what ('spiketimes' or 'vtrace')
    to `filename` through hoc File/printf commands. Each line holds a value
    formatted with num_format, a tab, and the cell's index relative to the
    first gid of the population. For spike times, only values <= the hoc
    variable tstop are written. `header`, if given, is written first.

    With gather, slave nodes post their vectors and the master writes a
    single file. Without gather, every node writes "<filename>.<myid>".
    """
    global myid
    # gid -> position in the local hoc List (O(1) lookups instead of
    # repeated self.gidlist.index() scans).
    local_index = dict((cell_id, i) for i, cell_id in enumerate(self.gidlist))
    vector_operation = ''
    if print_what == 'spiketimes':
        vector_operation = '.where("<=", tstop)'
    if gather and myid != 0:  # on slave nodes, post data
        hoc_commands = []
        for id in self.record_from[print_what]:
            if id in local_index:
                hoc_commands += ['tmp = pc.post("%s[%d].%s",%s.object(%d).%s%s)' % (self.hoc_label, id,
                                                                                    print_what,
                                                                                    self.hoc_label,
                                                                                    local_index[id],
                                                                                    print_what,
                                                                                    vector_operation)]
        hoc_execute(hoc_commands, "--- Population[%s].__print()__ --- [Post objects to master]" % self.label)

    if not gather:
        filename += ".%d" % myid

    if myid == 0 or not gather:
        hoc_commands = ['objref fileobj',
                        'fileobj = new File()',
                        'tmp = fileobj.wopen("%s")' % filename]
        if header:
            hoc_commands += ['tmp = fileobj.printf("%s\\n")' % header]
        if gather:
            hoc_commands += ['objref gatheredvec']
        padding = self.fullgidlist[0]
        for id in self.record_from[print_what]:
            hoc_commands += ['fmt = "%s\\t%d\\n"' % (num_format, id - padding)]
            if id in local_index:
                hoc_commands += ['tmp = %s.object(%d).%s%s.printf(fileobj, fmt)' % (self.hoc_label,
                                                                                    local_index[id],
                                                                                    print_what,
                                                                                    vector_operation)]
            elif gather:
                hoc_commands += ['gatheredvec = new Vector()']
                hoc_commands += ['tmp = pc.take("%s[%d].%s", gatheredvec)' % (self.hoc_label, id, print_what),
                                 'tmp = gatheredvec.printf(fileobj, fmt)']
        hoc_commands += ['tmp = fileobj.close()']
        hoc_execute(hoc_commands, "--- Population[%s].__print()__ ---" % self.label)
|---|
| 995 | |
|---|
def printSpikes(self, filename, gather=True, compatible_output=True):
    """
    Write the recorded spike times to `filename`.

    Each line is "spiketime cell_id", where cell_id is the index of the cell
    counting along rows and down columns (and the 3-D extension of that), so
    the file plots directly as a raster. Header lines starting with '#' give
    the population dimensions and the first and last ids.

    compatible_output is currently ignored by this implementation.

    For parallel runs, gather=True collects everything on the master node
    into one file. Otherwise each node writes its own file containing only
    the cells simulated on that node.
    """
    hoc_comment("--- Population[%s].__printSpikes()__ ---" % self.label)
    # "\t" is a real tab; "\\n" is left for hoc's printf to expand.
    dims_text = "\t".join(["%d" % size for size in self.dim])
    header = "# " + dims_text
    header += "\\n# first_id = %d\\n# last_id = %d\\n" % (self.fullgidlist[0], self.fullgidlist[-1])
    self.__print('spiketimes', filename, "%.2f", gather, header)
|---|
| 1023 | |
|---|
def print_v(self, filename, gather=True, compatible_output=True):
    """
    Write the recorded membrane-potential traces to `filename`.

    Each line is "v cell_id", where cell_id is the index of the cell counting
    along rows and down columns (and the 3-D extension of that). Header lines
    starting with '#' give the timestep, the number of points per cell
    (tstop / dt), the population dimensions and the first and last ids.

    compatible_output is currently ignored by this implementation.

    For parallel runs, gather=True collects everything on the master node
    into one file. Otherwise each node writes its own file containing only
    the cells simulated on that node.
    """
    tstop = h.tstop
    dt = get_time_step()
    # "\t" is a real tab; "\\n" is left for hoc's printf to expand.
    pieces = ["# dt = %f\\n# n = %d\\n" % (dt, int(tstop / dt)),
              "# ",
              "\t".join(["%d" % size for size in self.dim]),
              "\\n# first_id = %d\\n# last_id = %d\\n" % (self.fullgidlist[0], self.fullgidlist[-1])]
    header = "".join(pieces)
    hoc_comment("--- Population[%s].__print_v()__ ---" % self.label)
    self.__print('vtrace', filename, "%.4g", gather, header)
|---|
| 1052 | |
|---|
def getSpikes(self, gather=True):
    """
    Return a 2-column numpy array (cell index, spike time) for the recorded
    cells. The index is relative to the first gid of the population, as
    written by __print().

    Useful for small populations, for example for single neuron Monte-Carlo.
    On slave nodes with gather=True, and when there are no spikes, an empty
    (0, 2) array is returned.
    """
    # This is a bit of a hack implementation: round-trip through a file.
    tmpfile = "neuron_tmpfile"  # should really use tempfile module
    self.__print('spiketimes', tmpfile, "%.2f", gather)
    if not gather:
        # __print() appends ".<myid>" when not gathering; use the same name.
        tmpfile += '.%d' % myid
    if myid == 0 or not gather:
        f = open(tmpfile, 'r')
        try:
            lines = [line for line in f.read().split('\n') if line.strip()]  # remove blank lines
        finally:
            f.close()
        #os.remove(tmpfile)
        spikes = []
        for line in lines:
            fields = line.split()
            spikes.append((int(fields[1]), float(fields[0])))
        if not spikes:
            return numpy.empty((0, 2))
        return numpy.array(spikes)
    else:
        return numpy.empty((0, 2))
|---|
| 1075 | |
|---|
def meanSpikeCount(self, gather=True):
    """
    Return the mean number of spikes per recorded neuron.

    With gather=True, each slave node posts its local spike and cell counts
    to the master node (myid == 0) and returns 0. The master adds them to its
    own counts and returns the overall mean. With gather=False every node
    returns the mean over its own recorded cells. Returns 0.0 when no
    recorded cell contributes, instead of dividing by zero.
    """
    global myid
    hoc_cells = getattr(h, self.hoc_label)
    # gid -> position in the local hoc List (O(1) lookups instead of
    # repeated self.gidlist.index() scans).
    local_index = dict((cell_id, i) for i, cell_id in enumerate(self.gidlist))

    if gather and myid != 0:
        nspikes = 0; ncells = 0
        for id in self.record_from['spiketimes']:
            if id in local_index:
                nspikes += hoc_cells.object(local_index[id]).spiketimes.size()
                ncells += 1
        hoc_commands = ['tmp = pc.post("%s.node[%d].nspikes",%d)' % (self.hoc_label, myid, nspikes),
                        'tmp = pc.post("%s.node[%d].ncells",%d)' % (self.hoc_label, myid, ncells)]
        hoc_execute(hoc_commands, "--- Population[%s].__meanSpikeCount()__ --- [Post spike count to master]" % self.label)
        return 0

    # Here myid == 0 or gather is False.
    nspikes = 0.0; ncells = 0.0
    hoc_execute(["nspikes = 0", "ncells = 0"])
    for id in self.record_from['spiketimes']:
        if id in local_index:
            nspikes += hoc_cells.object(local_index[id]).spiketimes.size()
            ncells += 1
    if gather:
        for node in range(1, nhost):
            hoc_execute(['tmp = pc.take("%s.node[%d].nspikes",&nspikes)' % (self.hoc_label, node)])
            nspikes += int(h.nspikes)
            hoc_execute(['tmp = pc.take("%s.node[%d].ncells",&ncells)' % (self.hoc_label, node)])
            ncells += int(h.ncells)
    if ncells == 0:
        return 0.0
    return float(nspikes) / ncells
|---|
| 1112 | |
|---|
def randomInit(self, rand_distr):
    """
    Give every cell in the population a random initial membrane potential
    drawn from rand_distr (delegates to rset() on "v_init").
    """
    comment = "--- Population[%s].__randomInit()__ ---" % self.label
    hoc_comment(comment)
    self.rset("v_init", rand_distr)
|---|
| 1120 | |
|---|
def describe(self):
    """
    Print a human-readable description of the population to stdout.

    Shows the total and local cell counts, the grid shape, the cell type,
    and, for every entry of self.cellparams, its initial value next to the
    current value read from self.cell[0]. Returns None.
    """
    # Each print() has a single pre-formatted argument, so the output is the
    # same under Python 2 and Python 3.
    print("\n------- Population description -------")
    print("Population called %s is made of %d cells [%d being local]" % (self.label, len(self.fullgidlist), len(self.gidlist)))
    print("-> Cells are aranged on a %dD grid of size %s" % (len(self.dim), self.dim))
    print("-> Celltype is %s" % self.celltype)
    print("-> Cell Parameters used for cell[0] (during initialization and now) are: ")
    # cellparams may be None (no parameters given); print no rows then.
    for key, value in (self.cellparams or {}).items():
        print("\t| %s \t:  init-> %s \t now-> %s" % (key, value, getattr(self.cell[0], key)))
    print("--- End of Population description ----")
|---|
| 1133 | |
|---|
| 1134 | |
|---|
| 1135 | class Projection(common.Projection): |
|---|
| 1136 | """ |
|---|
| 1137 | A container for all the connections of a given type (same synapse type and |
|---|
| 1138 | plasticity mechanisms) between two populations, together with methods to set |
|---|
| 1139 | parameters of those connections, including of plasticity mechanisms. |
|---|
| 1140 | """ |
|---|
| 1141 | |
|---|
| 1142 | nProj = 0 |
|---|
| 1143 | |
|---|
def __init__(self, presynaptic_population, postsynaptic_population, method='allToAll',
             method_parameters=None, source=None, target=None,
             synapse_dynamics=None, label=None, rng=None):
    """
    Create the connections between two Populations.

    presynaptic_population, postsynaptic_population -- Population objects.
    source -- attribute of the presynaptic cell that signals action
        potentials.
    target -- synapse on the postsynaptic cell to connect to.
        Default values are used when source and/or target are not given.
    method -- either the name of a connection algorithm, dispatched to the
        matching self._<method>() helper ('allToAll', 'oneToOne',
        'fixedProbability', 'distanceDependentProbability',
        'fixedNumberPre', 'fixedNumberPost', 'fromFile', 'fromList'),
        or a common.Connector instance.
    method_parameters -- dict of parameters for the connection method
        (ideally a bare number/string would be allowed when there is only
        one parameter).
    synapse_dynamics -- a SynapseDynamics object selecting the synaptic
        plasticity mechanisms.
    label -- optional name; defaults to 'projection<N>'.
    rng -- RNG used by the connection methods; a NumpyRNG is created when
        none is given.
    """
    common.Projection.__init__(self, presynaptic_population, postsynaptic_population, method,
                               method_parameters, source, target, synapse_dynamics, label, rng)
    self.connections = []
    if not label:
        self.label = 'projection%d' % Projection.nProj
    self.hoc_label = self.label.replace(" ", "_")
    if not rng:
        self.rng = NumpyRNG()
    hoc_commands = ['objref %s' % self.hoc_label,
                    '%s = new List()' % self.hoc_label]
    self.synapse_type = target

    # Short-term plasticity: switch each postsynaptic cell over to
    # Tsodyks-Markram synapses before any connection is made.
    stp_mechanism = self.short_term_plasticity_mechanism
    if stp_mechanism:
        stp = self._short_term_plasticity_parameters
        U, tau_rec, tau_facil, u0 = [stp[name] for name in ('U', 'tau_rec', 'tau_facil', 'u0')]
        syn_codes = {None: 1, 'excitatory': 1, 'inhibitory': 2}
        for cell in self.post:
            cell._hoc_cell().use_Tsodyks_Markram_synapses(syn_codes[self.synapse_type],
                                                          U, tau_rec, tau_facil, u0)

    self._syn_objref = _translate_synapse_type(self.synapse_type, extra_mechanism=stp_mechanism)

    # Create connections
    if isinstance(method, str):
        builder = getattr(self, '_%s' % method)
        hoc_commands += builder(method_parameters)
    elif isinstance(method, common.Connector):
        # delays should already be set to min_delay
        hoc_commands += method.connect(self)
    hoc_execute(hoc_commands, "--- Projection[%s].__init__() ---" % self.label)

    # Named methods get min_delay everywhere, except when the connections come
    # from a file or a list. Connector objects are expected to set delays
    # themselves.
    if isinstance(method, str) and method not in ('fromList', 'fromFile'):
        self.setDelays(get_min_delay())

    # Long-term synaptic plasticity
    if self.long_term_plasticity_mechanism:
        self._setupSTDP(self.long_term_plasticity_mechanism, self._stdp_parameters)

    Projection.nProj += 1
|---|
| 1223 | |
|---|
def __len__(self):
    """Return how many connections this Projection holds."""
    made = self.connections
    return len(made)
|---|
| 1227 | |
|---|
| 1228 | # --- Connection methods --------------------------------------------------- |
|---|
| 1229 | |
|---|
def __connect(self, src, tgt):
    """
    Build the hoc commands that connect gid ``src`` to the synapse
    ``self._syn_objref`` of post-synaptic cell ``tgt``.

    The target is addressed by its position in ``self.post.gidlist``. The
    pair (src, tgt) is appended to ``self.connections``. Returns the list of
    hoc command strings; they are not executed here.
    """
    target_position = self.post.gidlist.index(tgt)
    connect_cmd = 'nc = pc.gid_connect(%d,%s.object(%d).%s)' % (
        src, self.post.hoc_label, target_position, self._syn_objref)
    store_cmd = 'tmp = %s.append(nc)' % self.hoc_label
    self.connections.append((src, tgt))
    return [connect_cmd, store_cmd]
|---|
| 1241 | |
|---|
def _allToAll(self, parameters=None):
    """
    Connect every cell of the pre-synaptic population to every cell of the
    post-synaptic population.

    ``parameters`` may be a dict. Its optional 'allow_self_connections'
    entry (default True) decides whether a cell may connect to itself when
    pre and post are the same population. Returns the hoc commands produced
    by AllToAllConnector.connect().
    """
    self_ok = True
    if parameters and parameters.has_key('allow_self_connections'):
        self_ok = parameters['allow_self_connections']
    return AllToAllConnector(self_ok).connect(self)
|---|
| 1253 | |
|---|
def _oneToOne(self, parameters=None):
    """
    Connect pre-synaptic cell i to post-synaptic cell i for every i.

    Intended for pre and post populations of the same size. ``parameters``
    is accepted for signature compatibility with the other connection
    methods and is ignored.

    Possible generalisation (not implemented): when the populations have
    different dimensions, cell i of a 1D pre population of size n could
    connect to every cell in row i of a 2D post population of size (n, m).
    """
    connector = OneToOneConnector()
    return connector.connect(self)
|---|
| 1266 | |
|---|
def _fixedProbability(self, parameters):
    """
    Connect each pre/post pair independently with a constant probability.

    ``parameters`` is either the probability itself (anything ``float()``
    accepts) or a dict with key 'p_connect' and an optional
    'allow_self_connections' entry (default True).
    """
    self_ok = True
    try:
        probability = float(parameters)
    except TypeError:
        # Not convertible to a number: treat it as a parameter dict.
        probability = parameters['p_connect']
        if parameters.has_key('allow_self_connections'):
            self_ok = parameters['allow_self_connections']
    connector = FixedProbabilityConnector(p_connect=probability,
                                          allow_self_connections=self_ok)
    return connector.connect(self)
|---|
| 1282 | |
|---|
def _distanceDependentProbability(self, parameters):
    """
    Connect each pre/post pair with a probability that depends on the
    distance d between the two cells.

    ``parameters`` is either the expression string itself or a dict with
    key 'd_expression' and an optional 'allow_self_connections' entry
    (default True). The expression is the right-hand side of a valid Python
    expression for the probability in terms of 'd', e.g. "exp(-abs(d))" or
    "float(d<3)".
    """
    self_ok = True
    if type(parameters) is types.StringType:
        expression = parameters
    else:
        expression = parameters['d_expression']
        if parameters.has_key('allow_self_connections'):
            self_ok = parameters['allow_self_connections']
    connector = DistanceDependentProbabilityConnector(d_expression=expression,
                                                      allow_self_connections=self_ok)
    return connector.connect(self)
|---|
| 1300 | |
|---|
def _fixedNumberPre(self, parameters):
    """
    Connect each pre-synaptic cell to a randomly chosen set of n
    post-synaptic cells.

    ``parameters`` may be:
      - an int n (> 0): every pre-synaptic cell makes n connections;
      - a dict with either 'n' (a fixed count) or 'rand_distr' (a
        RandomDistribution giving the count for each cell), plus an optional
        'allow_self_connections' entry (default True);
      - a RandomDistribution giving the count for each cell.

    Targets are drawn without replacement from ``self.post.gidlist``, using
    ``self.rng`` if it is set and ``numpy.random`` otherwise.
    Returns the list of hoc commands that create the connections.

    Raises ValueError if a dict has neither 'n' nor 'rand_distr', and
    TypeError for any other argument type.
    """
    allow_self_connections = True
    rand_distr = None  # set when the count is drawn per cell
    n = None
    if isinstance(parameters, int):
        n = parameters
        assert n > 0
    elif isinstance(parameters, dict):
        if 'n' in parameters:  # all cells have the same number of connections
            n = int(parameters['n'])
            assert n > 0
        elif 'rand_distr' in parameters:  # count per cell follows a distribution
            rand_distr = parameters['rand_distr']
            assert isinstance(rand_distr, RandomDistribution)
        else:
            raise ValueError("Parameter dictionary must contain either 'n' or 'rand_distr'.")
        allow_self_connections = parameters.get('allow_self_connections', True)
    elif isinstance(parameters, RandomDistribution):
        rand_distr = parameters
    else:
        raise TypeError("Invalid argument type: should be an integer, dictionary or RandomDistribution object.")

    rng = self.rng or numpy.random
    hoc_commands = []
    for src in self.pre.gidlist:
        if rand_distr is not None:
            # A distribution may return a float or a negative value; a slice
            # bound must be a non-negative int.
            n = max(int(rand_distr.next()), 0)
        # Pick n targets at random.
        for tgt in rng.permutation(self.post.gidlist)[:n]:
            if allow_self_connections or src != tgt:
                hoc_commands += self.__connect(src, tgt)
    return hoc_commands
|---|
| 1338 | |
|---|
def _fixedNumberPost(self, parameters):
    """
    Connect each post-synaptic cell to a randomly chosen set of n
    pre-synaptic cells.

    ``parameters`` may be:
      - an int n (> 0): every post-synaptic cell receives n connections;
      - a dict with either 'n' (a fixed count) or 'rand_distr' (a
        RandomDistribution giving the count for each cell), plus an optional
        'allow_self_connections' entry (default True);
      - a RandomDistribution giving the count for each cell.

    Sources are drawn without replacement from ``self.pre.gidlist``, using
    ``self.rng`` if it is set and ``numpy.random`` otherwise.
    Returns the list of hoc commands that create the connections.

    Raises ValueError if a dict has neither 'n' nor 'rand_distr', and
    TypeError for any other argument type.
    """
    allow_self_connections = True
    rand_distr = None  # set when the count is drawn per cell
    n = None
    if isinstance(parameters, int):
        n = parameters
        assert n > 0
    elif isinstance(parameters, dict):
        if 'n' in parameters:  # all cells have the same number of connections
            n = int(parameters['n'])
            assert n > 0
        elif 'rand_distr' in parameters:  # count per cell follows a distribution
            rand_distr = parameters['rand_distr']
            assert isinstance(rand_distr, RandomDistribution)
        else:
            raise ValueError("Parameter dictionary must contain either 'n' or 'rand_distr'.")
        allow_self_connections = parameters.get('allow_self_connections', True)
    elif isinstance(parameters, RandomDistribution):
        rand_distr = parameters
    else:
        raise TypeError("Invalid argument type: should be an integer, dictionary or RandomDistribution object.")

    rng = self.rng or numpy.random
    hoc_commands = []
    for tgt in self.post.gidlist:
        if rand_distr is not None:
            # A distribution may return a float or a negative value; a slice
            # bound must be a non-negative int.
            n = max(int(rand_distr.next()), 0)
        # Pick n sources at random.
        for src in rng.permutation(self.pre.gidlist)[:n]:
            if allow_self_connections or src != tgt:
                hoc_commands += self.__connect(src, tgt)
    return hoc_commands
|---|
| 1376 | |
|---|
def _fromFile(self, parameters):
    """
    Load connections from a file and create them via ``self._fromList``.

    ``parameters`` is either an open file-like object (anything with a
    ``readlines()`` method; it is read but not closed) or a filename, which
    is opened, read and closed here. Each non-blank line must hold four
    tab-separated fields: source address, target address, weight and delay.
    An address is written as ``<label>[i, j, ...]``; only the bracketed part
    is used.

    Raises NotImplementedError for dict arguments and TypeError for any
    other unsupported argument type.
    """
    if hasattr(parameters, 'readlines'):
        # Caller owns the file object; it should already be open for reading.
        lines = parameters.readlines()
    elif isinstance(parameters, str):
        f = open(parameters, 'r', 10000)
        try:
            lines = f.readlines()
        finally:
            f.close()
    elif isinstance(parameters, dict):
        # A dict could carry a 'filename' key or a 'file' key; not implemented yet.
        raise NotImplementedError("Argument type not yet implemented")
    else:
        raise TypeError("fromFile expects a file object or a filename, not %s" % type(parameters))

    # Gather the data as one (src_addr, tgt_addr, weight, delay) tuple per line.
    input_tuples = []
    for line in lines:
        single_line = line.rstrip()
        if not single_line:
            continue  # tolerate blank lines, e.g. a trailing newline
        src, tgt, w, d = single_line.split("\t", 4)
        src = "[%s" % src.split("[", 1)[1]
        tgt = "[%s" % tgt.split("[", 1)[1]
        # SECURITY NOTE(review): eval() executes arbitrary expressions read
        # from the file; only load connection files from trusted sources.
        input_tuples.append((eval(src), eval(tgt), float(w), float(d)))
    return self._fromList(input_tuples)
|---|
| 1406 | |
|---|
def _fromList(self, conn_list):
    """
    Create connections from a sequence of [pre_addr, post_addr, weight, delay]
    entries.

    pre_addr and post_addr are neuron addresses, i.e. tuples or lists with
    the cell's array coordinates in self.pre / self.post. Returns the hoc
    commands that create each connection and then set the weight and delay
    of NetCon i in the projection's list. Values are written with 17
    significant digits: fixed-point '%f' would round weights below 5e-7
    (e.g. small conductances in uS) to zero.
    """
    hoc_commands = []
    for i, entry in enumerate(conn_list):
        src, tgt, weight, delay = entry[:]
        src = self.pre[tuple(src)]
        tgt = self.post[tuple(tgt)]
        hoc_commands += self.__connect(src, tgt)
        # NOTE(review): index i assumes the projection held no connections
        # before this call, so that entry i maps to NetCon i.
        hoc_commands += ['%s.object(%d).weight = %.17g' % (self.hoc_label, i, float(weight)),
                         '%s.object(%d).delay = %.17g' % (self.hoc_label, i, float(delay))]
    return hoc_commands
|---|
| 1425 | |
|---|
| 1426 | # --- Methods for setting connection parameters ---------------------------- |
|---|
| 1427 | |
|---|
def setWeights(self, w):
    """
    Set the weights of all connections in the projection.

    w can be a single number, in which case every weight is set to it, or a
    list/1D numpy array with one value per connection, in connection order.
    Weights should be in nA for current-based and uS for conductance-based
    synapses. Values are sent to hoc with 17 significant digits; fixed-point
    '%f' formatting would round weights below 5e-7 to zero.

    Raises TypeError for unsupported argument types; a length mismatch
    fails the assertion below.
    """
    if isinstance(w, (float, int)):
        loop = ['for tmp = 0, %d {' % (len(self) - 1),
                '%s.object(tmp).weight = %.17g ' % (self.hoc_label, float(w)),
                '}']
        # Run the whole loop inside hoc with one execute1 call.
        hoc_commands = ['cmd = "%s"' % "".join(loop),
                        'success = execute1(cmd)']
    elif isinstance(w, (list, numpy.ndarray)):
        assert len(w) == len(self), "List of weights has length %d, Projection %s has length %d" % (len(w), self.label, len(self))
        hoc_commands = ['%s.object(%d).weight = %.17g' % (self.hoc_label, i, weight)
                        for i, weight in enumerate(w)]
    else:
        raise TypeError("Argument should be a numeric type (int, float...), a list, or a numpy array.")
    hoc_execute(hoc_commands, "--- Projection[%s].__setWeights__() ---" % self.label)
|---|
| 1451 | |
|---|
def randomizeWeights(self, rand_distr):
    """
    Set every connection weight to a value drawn from rand_distr.

    If rand_distr uses a NativeRNG, the draws happen inside hoc: a hoc
    Random object is seeded from the RNG's seed (0 when unset), configured
    with the distribution's name and parameters, and looped over in hoc.
    Otherwise values are drawn in Python with rand_distr.next().
    """
    if isinstance(rand_distr.rng, NativeRNG):
        # Parenthesised so the seed really comes from the RNG (defaulting to 0).
        seed = getattr(rand_distr.rng, 'seed', None) or 0
        distr_params = ",".join(["%.17g" % p for p in rand_distr.parameters])
        hoc_commands = ['rng = new Random(%d)' % seed,
                        'tmp = rng.%s(%s)' % (rand_distr.name, distr_params)]
        loop = ['for tmp = 0, %d {' % (len(self) - 1),
                '%s.object(tmp).weight = rng.repick() ' % self.hoc_label,
                '}']
        hoc_commands += ['cmd = "%s"' % "".join(loop),
                         'success = execute1(cmd)']
    else:
        # 17 significant digits so that small weights are not rounded to 0.
        hoc_commands = ['%s.object(%d).weight = %.17g' % (self.hoc_label, i, float(rand_distr.next()))
                        for i in range(len(self))]
    hoc_execute(hoc_commands, "--- Projection[%s].__randomizeWeights__() ---" % self.label)
|---|
| 1475 | |
|---|
def setDelays(self, d):
    """
    Set the delays (ms) of all connections in the projection.

    d can be a single number, in which case every delay is set to it, or a
    list/1D numpy array with one value per connection, in connection order.
    Every delay must be >= get_min_delay().

    If the projection has long-term (STDP) plasticity, the delays of the
    <label>_pre2wa[i] and <label>_post2wa[i] NetCons are updated as well:
    pre2wa gets d*(1-ddf) and post2wa gets d*ddf, where ddf is the
    dendritic delay fraction.

    Raises Exception for delays below the minimum delay and TypeError for
    unsupported argument types.
    """
    has_stdp = self.synapse_dynamics and self.synapse_dynamics.slow
    if isinstance(d, (float, int)):
        min_delay = get_min_delay()
        if d < min_delay:
            raise Exception("Delays must be greater than or equal to the minimum delay, currently %g ms" % min_delay)
        loop = ['for tmp = 0, %d {' % (len(self) - 1),
                '%s.object(tmp).delay = %f ' % (self.hoc_label, float(d)),
                '}']
        hoc_commands = ['cmd = "%s"' % "".join(loop),
                        'success = execute1(cmd)']
        if has_stdp:
            ddf = self.synapse_dynamics.slow.dendritic_delay_fraction
            loop = ['for i = 0, %d {' % (len(self) - 1),
                    '%s_pre2wa[i].delay = %f ' % (self.hoc_label, float(d) * (1 - ddf)),
                    '%s_post2wa[i].delay = %f ' % (self.hoc_label, float(d) * ddf),
                    '}']
            # Append, not replace, so the connection delays above are still set.
            hoc_commands += ['cmd = "%s"' % "".join(loop),
                             'success = execute1(cmd)']
    elif isinstance(d, (list, numpy.ndarray)):
        assert len(d) == len(self), "List of delays has length %d, Projection %s has length %d" % (len(d), self.label, len(self))
        min_delay = get_min_delay()
        if len(d) > 0 and numpy.min(d) < min_delay:
            raise Exception("Delays must be greater than or equal to the minimum delay, currently %g ms" % min_delay)
        hoc_commands = ['%s.object(%d).delay = %f' % (self.hoc_label, i, delay)
                        for i, delay in enumerate(d)]
        if has_stdp:
            ddf = self.synapse_dynamics.slow.dendritic_delay_fraction
            for i, delay in enumerate(d):
                hoc_commands += ['%s_pre2wa[%d].delay = %f' % (self.hoc_label, i, delay * (1 - ddf)),
                                 '%s_post2wa[%d].delay = %f' % (self.hoc_label, i, delay * ddf)]
    else:
        raise TypeError("Argument should be a numeric type (int, float...), a list, or a numpy array.")
    hoc_execute(hoc_commands, "--- Projection[%s].__setDelays__() ---" % self.label)
|---|
| 1515 | |
|---|
def randomizeDelays(self, rand_distr):
    """
    Set every connection delay to a value drawn from rand_distr.

    If rand_distr uses a NativeRNG, the draws happen inside hoc with a hoc
    Random object seeded from the RNG's seed (0 when unset). Otherwise
    values are drawn in Python with rand_distr.next(). With long-term
    (STDP) plasticity, each drawn delay r is also split onto the
    <label>_pre2wa[i] (r*(1-ddf)) and <label>_post2wa[i] (r*ddf) NetCons.

    NOTE(review): drawn delays are not checked against get_min_delay().
    """
    has_stdp = self.synapse_dynamics and self.synapse_dynamics.slow
    if isinstance(rand_distr.rng, NativeRNG):
        # Parenthesised so the seed really comes from the RNG (defaulting to 0).
        seed = getattr(rand_distr.rng, 'seed', None) or 0
        distr_params = ",".join(["%.17g" % p for p in rand_distr.parameters])
        hoc_commands = ['rng = new Random(%d)' % seed,
                        'tmp = rng.%s(%s)' % (rand_distr.name, distr_params)]
        if has_stdp:
            ddf = self.synapse_dynamics.slow.dendritic_delay_fraction
            hoc_commands += ['ddf = %g' % ddf]
            loop = ['for i = 0, %d {' % (len(self) - 1),
                    'rr = rng.repick()',
                    '%s.object(i).delay = rr' % self.hoc_label,
                    '%s_pre2wa[i].delay = rr*(1-ddf)' % self.hoc_label,
                    '%s_post2wa[i].delay = rr*ddf' % self.hoc_label,
                    '}']
        else:
            loop = ['for tmp = 0, %d {' % (len(self) - 1),
                    '%s.object(tmp).delay = rng.repick()' % self.hoc_label,
                    '}']
        # Separate the hoc statements with spaces so tokens never run together.
        hoc_commands += ['cmd = "%s"' % " ".join(loop),
                         'success = execute1(cmd)']
    else:
        hoc_commands = []
        if has_stdp:
            ddf = self.synapse_dynamics.slow.dendritic_delay_fraction
            for i in range(len(self)):
                rr = float(rand_distr.next())
                hoc_commands += ['%s.object(%d).delay = %f' % (self.hoc_label, i, rr),
                                 '%s_pre2wa[%d].delay = %f' % (self.hoc_label, i, rr * (1 - ddf)),
                                 '%s_post2wa[%d].delay = %f' % (self.hoc_label, i, rr * ddf)]
        else:
            for i in range(len(self)):
                hoc_commands += ['%s.object(%d).delay = %f' % (self.hoc_label, i, float(rand_distr.next()))]
    hoc_execute(hoc_commands, "--- Projection[%s].__randomizeDelays__() ---" % self.label)
|---|
| 1557 | |
|---|
def setSynapseDynamics(self, param, value):
    """
    Set a parameter of the synapse dynamics linked with the projection.

    Not implemented for this backend. Raises NotImplementedError, a
    subclass of Exception, so existing ``except Exception`` handlers still
    catch it.
    """
    raise NotImplementedError("Method not yet implemented !")
|---|
| 1563 | |
|---|
def randomizeSynapseDynamics(self, param, rand_distr):
    """
    Set a parameter of the synapse dynamics to values drawn from rand_distr.

    Not implemented for this backend. Raises NotImplementedError, a
    subclass of Exception, so existing ``except Exception`` handlers still
    catch it.
    """
    raise NotImplementedError("Method not yet implemented !")
|---|
| 1569 | |
|---|
def setTopographicDelays(self, delay_rule, rand_distr=None, mask=None, scale_factor=1.0):
    """
    Set delays from a rule given as a Python expression string.

    The rule may use the variable 'd', the distance between the pre- and
    post-synaptic cells of each connection (from common.distance with
    ``mask`` and ``scale_factor``). When ``rand_distr`` is given it may also
    use 'rng', a fresh value drawn from rand_distr for each connection.
    For example, the rule "rng*d + 0.5" uses a random draw 'rng' and the
    distance 'd'.

    The rule is evaluated with eval() in this module's global namespace, so
    names imported at module level (e.g. math functions) are available. It
    must come from a trusted source.

    Raises Exception if the projection has long-term (STDP) plasticity.
    """
    if self.synapse_dynamics and self.synapse_dynamics.slow:
        raise Exception("setTopographicDelays() does not currently work with STDP")

    # Map gid -> entry of fullgidlist (the object passed to common.distance).
    # This works whether fullgidlist is a list or a numpy array and avoids a
    # linear search per connection.
    pre_cells = dict((int(cell), cell) for cell in self.pre.fullgidlist)
    post_cells = dict((int(cell), cell) for cell in self.post.fullgidlist)

    if rand_distr is None:
        draw = None
    elif isinstance(rand_distr.rng, NativeRNG):
        seed = getattr(rand_distr.rng, 'seed', None) or 0
        distr_params = ",".join(["%.17g" % p for p in rand_distr.parameters])
        # h.rng.repick() is called in the loop below, so the hoc Random object
        # must be created now rather than batched with the delay commands.
        hoc_execute(['rng = new Random(%d)' % seed,
                     'tmp = rng.%s(%s)' % (rand_distr.name, distr_params)],
                    "--- Projection[%s].__setTopographicDelays__() [rng set-up] ---" % self.label)
        draw = lambda: h.rng.repick()
    else:
        draw = rand_distr.next

    hoc_commands = []
    for i in range(len(self)):
        src, tgt = self.connections[i][0], self.connections[i][1]
        variables = {'d': common.distance(pre_cells[int(src)], post_cells[int(tgt)],
                                          mask, scale_factor)}
        if draw is not None:
            variables['rng'] = draw()
        # Bind 'd' and 'rng' as variables rather than substituting their text
        # into the rule: other names containing those letters (e.g. round(d))
        # stay intact, and values keep full float precision.
        delay = eval(delay_rule, globals(), variables)
        hoc_commands.append('%s.object(%d).delay = %f' % (self.hoc_label, i, float(delay)))

    hoc_execute(hoc_commands, "--- Projection[%s].__setTopographicDelays__() ---" % self.label)
|---|
| 1628 | |
|---|
| 1629 | # --- Methods relating to synaptic plasticity ------------------------------ |
|---|
| 1630 | |
|---|
def _setupSTDP(self, stdp_model, parameterDict):
    """
    Set up long-term (STDP) plasticity for every connection of the projection.

    For connection i this generates hoc commands that create:
      - <label>_wa[i] = new <stdp_model>(0.5), the weight-adjuster object;
      - <label>_pre2wa[i], a NetCon from the pre-synaptic gid (pc.gid_connect)
        to wa[i] with weight 1 and delay = connection delay * (1 - ddf);
      - <label>_post2wa[i], a NetCon from the post-synaptic cell's ``source``
        to wa[i] with threshold 1, weight -1 and delay = connection delay * ddf.
    It then points wa[i].wsyn at the weight of connection i (setpointer) and
    sets each (param, value) of ``parameterDict`` on wa[i]. ddf is the
    dendritic delay fraction.

    Raises Exception when ddf > 0.5 and more than one host is in use.
    """
    ddf = self.synapse_dynamics.slow.dendritic_delay_fraction
    if ddf > 0.5 and nhost > 1:
        # Depending on delays, the delay from the pre-synaptic neuron to the
        # weight-adjuster mechanism can become zero. The best (only?)
        # solution would be to create connections on the node with the
        # pre-synaptic neurons for ddf > 0.5, and on the node with the
        # post-synaptic neuron (as is done now) for ddf < 0.5.
        raise Exception("STDP with dendritic_delay_fraction > 0.5 is not yet supported for parallel computation.")
    label = self.hoc_label
    hoc_commands = ['objref %s_wa[%d]' % (label, len(self)),
                    'objref %s_pre2wa[%d]' % (label, len(self)),
                    'objref %s_post2wa[%d]' % (label, len(self))]
    for i in range(len(self)):
        src, tgt = self.connections[i][0], self.connections[i][1]
        post_index = self.post.gidlist.index(tgt)
        # This reproduces the structure of STDP found in layerConn.hoc.
        hoc_commands += [
            '%s_wa[%d] = new %s(0.5)' % (label, i, stdp_model),
            '%s_pre2wa[%d] = pc.gid_connect(%d, %s_wa[%d])' % (label, i, src, label, i),
            '%s_pre2wa[%d].threshold = %s.object(%d).threshold' % (label, i, label, i),
            '%s_pre2wa[%d].delay = %s.object(%d).delay * %g' % (label, i, label, i, (1 - ddf)),
            '%s_pre2wa[%d].weight = 1' % (label, i),
            # Create the NetCon directly: wa is on the same machine as the
            # post-synaptic cell.
            '%s_post2wa[%d] = new NetCon(%s.object(%d).source, %s_wa[%d])' % (label, i, self.post.hoc_label, post_index, label, i),
            '%s_post2wa[%d].threshold = 1' % (label, i),
            '%s_post2wa[%d].delay = %s.object(%d).delay * %g' % (label, i, label, i, ddf),
            '%s_post2wa[%d].weight = -1' % (label, i),
            'setpointer %s_wa[%d].wsyn, %s.object(%d).weight' % (label, i, label, i)]
        # 17 significant digits so that small parameter values (e.g. a
        # maximum weight in uS) are not rounded to 0 as with '%f'.
        for param, val in parameterDict.items():
            hoc_commands += ['%s_wa[%d].%s = %.17g' % (label, i, param, val)]
    hoc_execute(hoc_commands, "--- Projection[%s].__setupSTDP__() ---" % self.label)
|---|
| 1675 | |
|---|
| 1676 | # --- Methods for writing/reading information to/from file. ---------------- |
|---|
| 1677 | |
|---|
def getWeights(self, format='list', gather=True):
    """
    Return the weights of this projection's connections.

    format='list' gives a list with one weight per connection, in connection
    order. format='array' gives a (len(pre), len(post)) float array in which
    entry [src - pre.gid_start, tgt - post.gid_start] holds the weight of
    connection src -> tgt and every other entry is 0.
    ``gather`` is accepted but not used here.
    """
    assert format in ('list', 'array'), "`format` is '%s', should be one of 'list', 'array'" % format
    if format == 'list':
        netcons = getattr(h, self.hoc_label)
        values = [netcons.object(k).weight[0] for k in range(len(self))]
    elif format == 'array':
        netcons = getattr(h, self.hoc_label)
        values = numpy.zeros((len(self.pre), len(self.post)), 'float')
        for k in range(len(self)):
            weight = netcons.object(k).weight[0]
            pre_gid, post_gid = self.connections[k][0], self.connections[k][1]
            values[pre_gid - self.pre.gid_start, post_gid - self.post.gid_start] = weight
    return values
|---|
| 1694 | |
|---|
def getDelays(self, format='list', gather=True):
    """
    Return the delays (ms) of this projection's connections.

    format='list' gives a list with one delay per connection, in connection
    order. format='array' gives a (len(pre), len(post)) float array in which
    entry [src - pre.gid_start, tgt - post.gid_start] holds the delay of
    connection src -> tgt and every other entry is 1e12 (no connection).
    ``gather`` is accepted but not used here.
    """
    assert format in ('list', 'array'), "`format` is '%s', should be one of 'list', 'array'" % format
    if format == 'list':
        netcons = getattr(h, self.hoc_label)
        values = [netcons.object(i).delay for i in range(len(self))]
    elif format == 'array':
        netcons = getattr(h, self.hoc_label)
        values = numpy.empty((len(self.pre), len(self.post)), 'float')
        values.fill(1e12)  # marks pairs with no connection
        for i in range(len(self)):
            values[self.connections[i][0] - self.pre.gid_start,
                   self.connections[i][1] - self.post.gid_start] = netcons.object(i).delay
    return values
|---|
| 1707 | |
|---|
def saveConnections(self, filename, gather=False):
    """
    Save connections to a file in a format suitable for reading in with the
    'fromFile' method.

    Each line holds the tab-separated pre-synaptic address, post-synaptic
    address, weight and delay of one connection. An address is the
    population's hoc_label followed by its ``locate()`` result, with '()'
    rewritten as '[]'. With more than one process, each process writes its
    own file named ``<filename>.<rank>``.

    Raises Exception for gather=True (not yet implemented).
    """
    if gather:
        raise Exception("saveConnections() with gather=True not yet implemented")
    elif num_processes() > 1:
        filename += '.%d' % rank()
    hoc_comment("--- Projection[%s].__saveConnections__() ---" % self.label)
    netcons = getattr(h, self.hoc_label)
    f = open(filename, 'w', 10000)
    try:  # make sure the file is closed even if a hoc lookup fails
        for i in range(len(self)):
            src, tgt = self.connections[i][0], self.connections[i][1]
            line = "%s%s\t%s%s\t%g\t%g\n" % (self.pre.hoc_label,
                                              self.pre.locate(src),
                                              self.post.hoc_label,
                                              self.post.locate(tgt),
                                              netcons.object(i).weight[0],
                                              netcons.object(i).delay)
            f.write(line.replace('(', '[').replace(')', ']'))
    finally:
        f.close()
|---|
| 1729 | |
|---|
def printWeights(self, filename, format='list', gather=True):
    """
    Print synaptic weights to a file.

    ``filename`` may be a path, or an object with a ``write`` method that is
    written to but not closed. format='list' writes one weight per line
    ("%f"). format='array' writes a len(pre) x len(post) matrix, one row per
    line ("%g"), with 0 where there is no connection.

    With gather=True, nodes with myid != 0 post their weights with pc.post()
    and node 0 retrieves them with pc.take(), appending them after its own.
    Gathering is not implemented for the 'array' format.

    Raises Exception for an unknown format or for gathering in 'array' format.
    """
    hoc_execute(['objref weight_list'])
    hoc_comment("--- Projection[%s].__printWeights__() ---" % self.label)

    # When gathering, every non-master node posts its list of weights.
    if gather and myid != 0:
        if format == 'array':
            raise Exception("Gather not implemented for 'array'.")
        netcons = getattr(h, self.hoc_label)
        hoc_commands = ['weight_list = new Vector()']
        for i in range(len(self)):
            hoc_commands += ['weight_list = weight_list.append(%f)' % netcons.object(i).weight[0]]
        hoc_commands += ['tmp = pc.post("%s.weight_list.node[%d]", weight_list)' % (self.hoc_label, myid)]
        hoc_execute(hoc_commands, "--- [Posting weights list to master] ---")

    if not gather or myid == 0:
        # Validate before opening, so a bad request does not create or truncate the file.
        if format not in ('list', 'array'):
            raise Exception("Valid formats are 'list' and 'array'")
        if gather and format == 'array' and nhost > 1:
            raise Exception("Gather not implemented for array format.")
        owns_file = not hasattr(filename, 'write')
        if owns_file:
            f = open(filename, 'w', 10000)
        else:
            f = filename
        try:
            netcons = getattr(h, self.hoc_label)
            if format == 'list':
                for i in range(len(self)):
                    f.write("%f\n" % netcons.object(i).weight[0])
            else:
                weights = numpy.zeros((len(self.pre), len(self.post)), 'float')
                fmt = "%g " * len(self.post) + "\n"
                for i in range(len(self)):
                    weights[self.connections[i][0] - self.pre.gid_start,
                            self.connections[i][1] - self.post.gid_start] = netcons.object(i).weight[0]
                for row in weights:
                    f.write(fmt % tuple(row))
            if gather:
                for node in range(1, nhost):
                    hoc_execute(['weight_list = new Vector()',
                                 'tmp = pc.take("%s.weight_list.node[%d]", weight_list)' % (self.hoc_label, node)])
                    # Vector.size is a method; it must be called to get the length.
                    for j in range(int(h.weight_list.size())):
                        f.write("%f\n" % h.weight_list.x[j])
        finally:
            if owns_file:
                f.close()
|---|
| 1784 | |
|---|
def weightHistogram(self, min=None, max=None, nbins=10):
    """
    Return a histogram of synaptic weights as (counts, bins), the value
    returned by numpy.histogram.

    The interval [min, max] is split into nbins equal-width bins. If min or
    max is not given, the smallest or largest weight is used instead. When
    the projection has no connections, the defaults are 0 and 1.
    """
    # min/max shadow the builtins but are kept as parameter names for
    # backward compatibility with existing callers.
    weights = numpy.asarray(self.getWeights(), dtype=float)
    if min is None:
        min = weights.min() if weights.size else 0.0
    if max is None:
        max = weights.max() if weights.size else 1.0
    # bins=<int> with range=(lo, hi) gives exactly nbins bins covering
    # [min, max]. It also avoids the integer division and missing endpoint
    # of numpy.arange(min, max, (max-min)/nbins).
    return numpy.histogram(weights, bins=nbins, range=(min, max))
|---|
| 1795 | |
|---|
| 1796 | # ============================================================================== |
|---|
| 1797 | # Utility classes |
|---|
| 1798 | # ============================================================================== |
|---|
| 1799 | |
|---|
# Module-level alias for pyNN.common.Timer (presumably re-exported as part of
# this backend's public API; confirm against the package's other backends).
Timer = common.Timer
|---|
| 1801 | |
|---|
| 1802 | # ============================================================================== |
|---|