| 1 | from pyNN import common, errors, random, standardmodels, recording |
|---|
| 2 | from pyNN.common import populations |
|---|
| 3 | from pyNN.parameters import Sequence, ParameterSpace |
|---|
| 4 | from nose.tools import assert_equal, assert_raises |
|---|
| 5 | import numpy |
|---|
| 6 | from mock import Mock, patch |
|---|
| 7 | from pyNN.utility import assert_arrays_equal |
|---|
| 8 | from pyNN import core |
|---|
| 9 | from lazyarray import VectorizedIterable |
|---|
| 10 | |
|---|
builtin_open = open
# Cell label -> global index (0..6); consumed by MockPopulation.id_to_index
# and MockPopulation.id_to_local_index via MockID.label.
id_map = {name: index for index, name in enumerate(
    ('larry', 'curly', 'moe', 'joe', 'william', 'jack', 'averell'))}
|---|
| 13 | |
|---|
| 14 | |
|---|
class MockSimulator(object):
    """Fake simulator module: exposes only a ``state`` object with MPI info."""

    class MockState(object):
        """Pretend MPI state: this process is rank 1 out of 3."""
        num_processes = 3
        mpi_rank = 1

    state = MockState()
|---|
| 20 | |
|---|
class MockStandardCell(standardmodels.StandardCellType):
    """Stand-in cell type with four parameters; can record 'v' and 'spikes'."""
    recordable = ['v', 'spikes']
    default_parameters = {
        'tau_m': 999.9,
        'i_offset': 321.0,
        'spike_times': Sequence([0, 1, 2]),
        'foo': 33.3,
    }
    # Every translation entry is None; translate() below is an identity anyway.
    translations = dict.fromkeys(('tau_m', 'i_offset', 'spike_times', 'foo'))

    @classmethod
    def translate(cls, parameters):
        """Identity translation: return ``parameters`` unchanged."""
        return parameters

    @classmethod
    def computed_parameters(cls):
        """No parameter of this mock is derived from the others."""
        return []
|---|
| 31 | |
|---|
class MockPopulation(populations.BasePopulation):
    """Minimal concrete BasePopulation: 13 cells, every odd-indexed cell local.

    NOTE: the arrays and dicts below are class attributes and therefore shared
    by every instance; a test that mutates one should first assign a private
    copy on the instance.
    """
    _simulator = MockSimulator
    size = 13
    all_cells = numpy.arange(100, 113)  # cell IDs 100..112
    # True at odd indices -> cells 1, 3, 5, 7, 9, 11 are "on this node"
    _mask_local = numpy.array([0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0], bool)
    local_cells = all_cells[_mask_local]
    # shape (3, 13): column i holds (3i, 3i+1, 3i+2)
    positions = numpy.arange(39).reshape((13, 3)).T
    label = "mock_population"
    celltype = MockStandardCell({})
    initial_values = {"foo": core.LazyArray(numpy.array((98, 100, 102)), shape=(3,))}
    _assembly_class = populations.Assembly

    def _get_view(self, selector, label=None):
        """Return a PopulationView of this population.

        ``populations.PopulationView`` is looked up at call time, which is
        what lets tests patch it.
        """
        return populations.PopulationView(self, selector, label)

    def id_to_index(self, id):
        """Map a MockID to its global index via ``id_map``.

        Raises:
            Exception: if the ID's label is not in ``id_map``.
        """
        if id.label in id_map:
            return id_map[id.label]
        else:
            raise Exception("Invalid ID")

    def id_to_local_index(self, id):
        """Map a MockID to its position among the local cells.

        Local cells are those with an odd global index (see ``_mask_local``),
        so global index g maps to local index g // 2.

        Raises:
            Exception: if the label is unknown, or the cell is not local.
        """
        if id.label in id_map:
            global_index = id_map[id.label]
            if global_index % 2 == 1:
                # Floor division keeps the index an int on Python 3 as well;
                # plain "/" would produce a float there.
                return global_index // 2
            else:
                raise Exception("ID not on this node")
        else:
            raise Exception("Invalid ID")
|---|
| 62 | |
|---|
class MockID(object):
    """Lightweight cell ID: a label (looked up in ``id_map``) plus its owner."""

    def __init__(self, label, parent):
        self.label, self.parent = label, parent
|---|
| 67 | |
|---|
def test__getitem__int():
    """Integer indexing returns cell IDs and supports negative indices."""
    population = MockPopulation()
    for index, expected_id in ((0, 100), (12, 112)):
        assert_equal(population[index], expected_id)
    # one past the end is rejected ...
    assert_raises(IndexError, population.__getitem__, 13)
    # ... while negative indices count back from the end
    assert_equal(population[-1], 112)
|---|
| 74 | |
|---|
def test__getitem__slice():
    """Slicing a population builds a PopulationView over that slice.

    PopulationView is patched only for the duration of the ``with`` block,
    so the real class is restored even if the assertion fails.
    """
    with patch.object(populations, "PopulationView") as mock_view:
        p = MockPopulation()
        p[3:9]  # result unused; only the constructor call is checked
        mock_view.assert_called_with(p, slice(3, 9, None), None)
|---|
| 82 | |
|---|
def test__getitem__list():
    """Indexing with a list of indices builds a PopulationView over them.

    PopulationView is patched only inside the ``with`` block, so it is
    restored even if the assertion fails.
    """
    with patch.object(populations, "PopulationView") as mock_view:
        p = MockPopulation()
        p[range(3, 9)]  # result unused; only the constructor call is checked
        mock_view.assert_called_with(p, range(3, 9), None)
|---|
| 90 | |
|---|
def test__getitem__tuple():
    """Indexing with a tuple builds a PopulationView from the equivalent list.

    PopulationView is patched only inside the ``with`` block, so it is
    restored even if the assertion fails.
    """
    with patch.object(populations, "PopulationView") as mock_view:
        p = MockPopulation()
        p[(3, 5, 7)]  # result unused; only the constructor call is checked
        mock_view.assert_called_with(p, [3, 5, 7], None)
|---|
| 98 | |
|---|
def test__getitem__invalid():
    """Indexing with a string is rejected with TypeError."""
    population = MockPopulation()
    bad_key = "foo"
    assert_raises(TypeError, population.__getitem__, bad_key)
|---|
| 102 | |
|---|
def test_len():
    """len() reports the population's size."""
    population = MockPopulation()
    expected_size = MockPopulation.size
    assert_equal(len(population), expected_size)
|---|
| 106 | |
|---|
def test_iter():
    """Iterating a population yields an iterator over the 6 local cells.

    The iterator check accepts both protocol spellings: ``next`` (Python 2)
    and ``__next__`` (Python 3).
    """
    p = MockPopulation()
    itr = p.__iter__()
    assert hasattr(itr, "next") or hasattr(itr, "__next__")
    assert_equal(len(list(itr)), 6)
|---|
| 112 | |
|---|
def test_is_local():
    """is_local() separates local from remote IDs and rejects foreign IDs."""
    owner, stranger = MockPopulation(), MockPopulation()
    local_id = MockID("curly", parent=owner)    # global index 1: set in _mask_local
    remote_id = MockID("larry", parent=owner)   # global index 0: not set
    assert owner.is_local(local_id)
    assert not owner.is_local(remote_id)
    # asking a population about an ID it does not own trips an assertion
    assert_raises(AssertionError, stranger.is_local, local_id)
|---|
| 121 | |
|---|
def test_all():
    """all() returns an iterator over all 13 cells, local or not.

    The iterator check accepts both ``next`` (Python 2) and ``__next__``
    (Python 3).
    """
    p = MockPopulation()
    itr = p.all()
    assert hasattr(itr, "next") or hasattr(itr, "__next__")
    assert_equal(len(list(itr)), 13)
|---|
| 127 | |
|---|
def test_add():
    """Adding two populations yields an Assembly holding both, in order."""
    first = MockPopulation()
    second = MockPopulation()
    combined = first + second
    assert isinstance(combined, populations.Assembly)
    assert_equal(combined.populations, [first, second])
|---|
| 134 | |
|---|
def test_get_cell_position():
    """A cell's position is the column of ``positions`` at its index."""
    population = MockPopulation()
    # positions is arange(39) reshaped to (13, 3) and transposed, so cell i
    # sits at (3i, 3i+1, 3i+2)
    for label, expected in (("larry", [0, 1, 2]), ("moe", [6, 7, 8])):
        cell_id = MockID(label, parent=population)
        assert_arrays_equal(population._get_cell_position(cell_id),
                            numpy.array(expected))
|---|
| 141 | |
|---|
def test_set_cell_position():
    """Setting larry's position updates column 0 and leaves cell 1 untouched.

    ``positions`` is a class attribute shared by every MockPopulation. The
    test therefore works on a private copy, so the write does not leak into
    other tests.
    """
    p = MockPopulation()
    p.positions = MockPopulation.positions.copy()
    cell_id = MockID("larry", parent=p)
    p._set_cell_position(cell_id, numpy.array([100, 101, 102]))
    assert_equal(p.positions[0, 0], 100)
    # the neighbouring cell (index 1) keeps its original x-coordinate
    assert_equal(p.positions[0, 1], 3)
|---|
| 148 | |
|---|
def test_get_cell_initial_value():
    """Reading a local cell's initial value returns the stored entry."""
    population = MockPopulation()
    # "curly" has global index 1, i.e. local index 0 -> first entry (98)
    curly = MockID("curly", parent=population)
    value = population._get_cell_initial_value(curly, "foo")
    assert_equal(value, 98)
|---|
| 153 | |
|---|
def test_set_cell_initial_value():
    """A value written with _set_cell_initial_value can be read back.

    The class-level ``initial_values`` dict is shared by every
    MockPopulation, so the test installs a private copy first; otherwise the
    write would leak into other tests.
    """
    p = MockPopulation()
    p.initial_values = {"foo": core.LazyArray(numpy.array((98, 100, 102)), shape=(3,))}
    cell_id = MockID("curly", parent=p)
    p._set_cell_initial_value(cell_id, "foo", -1)
    assert_equal(p._get_cell_initial_value(cell_id, "foo"), -1)
|---|
| 159 | |
|---|
def test_nearest():
    """nearest() returns the cell closest to a point, splitting ties at midpoints."""
    population = MockPopulation()
    population.positions = numpy.arange(39).reshape((13, 3)).T
    cases = [
        ((0.0, 1.0, 2.0), 0),
        ((3.0, 4.0, 5.0), 1),
        ((36.0, 37.0, 38.0), 12),
        ((1.49, 2.49, 3.49), 0),   # just short of the midpoint between cells 0 and 1
        ((1.51, 2.51, 3.51), 1),   # just past it
    ]
    for point, expected_index in cases:
        assert_equal(population.nearest(point), population[expected_index])
|---|
| 168 | |
|---|
def test_sample():
    """sample(n) builds a PopulationView from the first n permuted indices.

    PopulationView is patched only inside the ``with`` block, so it is
    restored even if the assertion fails.
    """
    with patch.object(populations, "PopulationView") as mock_view:
        p = MockPopulation()
        rng = Mock()
        rng.permutation = Mock(return_value=numpy.array([7, 4, 8, 12, 0, 3, 9, 1, 2, 11, 5, 10, 6]))
        p.sample(5, rng=rng)
        # the view must be built from the first 5 entries of the permutation
        assert_arrays_equal(mock_view.call_args[0][1], numpy.array([7, 4, 8, 12, 0]))
|---|
| 178 | |
|---|
def test_get_should_call_get_parameters():
    """get() delegates to _get_parameters with the requested name."""
    population = MockPopulation()
    fake_getter = Mock(return_value={'tau_m': Mock()})
    population._get_parameters = fake_getter
    population.get("tau_m")
    fake_getter.assert_called_with("tau_m")
|---|
| 184 | |
|---|
def test_get_with_gather():
    """get(..., gather=True) on rank 0 of 2 merges values from both ranks.

    The simulator state and ``recording.gather_dict`` are overridden with
    ``patch.object``. They are therefore restored even when an assertion
    fails, rather than leaking into later tests.
    """
    state = MockPopulation._simulator.state

    def mock_gather_dict(D):  # really hacky
        # fabricate the values "rank 1" would contribute from rank 0's values
        assert isinstance(D[0], (list, numpy.ndarray))
        D[1] = [i - 1 for i in D[0]] + [D[0][-1] + 1]
        return D

    with patch.object(state, "num_processes", 2), \
            patch.object(state, "mpi_rank", 0), \
            patch.object(recording, "gather_dict", mock_gather_dict):
        p = MockPopulation()
        ps = Mock()
        ps.evaluate = Mock(return_value=numpy.arange(11.0, 23.0, 2.0))
        p._get_parameters = Mock(return_value={'tau_m': ps})
        assert_arrays_equal(p.get("tau_m", gather=True),
                            numpy.arange(10.0, 23.0))
|---|
| 206 | |
|---|
def test_set():
    """set() wraps keyword values in a ParameterSpace for _set_parameters."""
    population = MockPopulation()
    setter = Mock()
    population._set_parameters = setter
    population.set(tau_m=43.21)
    expected = ParameterSpace({'tau_m': 43.21},
                              population.celltype.get_schema(),
                              size=population.size)
    setter.assert_called_with(expected)
|---|
| 213 | |
|---|
def test_set_invalid_type():
    """set() rejects a dict or a string for a numeric parameter."""
    population = MockPopulation()
    for bad_value in ({}, 'bar'):
        assert_raises(errors.InvalidParameterValueError, population.set, foo=bad_value)
|---|
| 218 | |
|---|
def test_set_with_list():
    """set() accepts a list value for a Sequence-typed parameter."""
    population = MockPopulation()
    setter = population._set_parameters = Mock()
    population.set(spike_times=range(10))
    expected = ParameterSpace({'spike_times': range(10)},
                              population.celltype.get_schema(),
                              size=population.size)
    setter.assert_called_with(expected)
|---|
| 225 | |
|---|
def test_tset_with_numeric_values():
    """tset() passes a per-cell numeric array through to _set_parameters."""
    population = MockPopulation()
    population._set_parameters = Mock()
    values = numpy.linspace(10.0, 20.0, num=population.size)
    population.tset("tau_m", values)
    passed_space = population._set_parameters.call_args[0][0]
    mask = population._mask_local
    # only local entries are compared
    assert_arrays_equal(passed_space['tau_m'][mask], values[mask])
|---|
| 233 | |
|---|
def test_tset_with_array_values():
    """tset() passes one Sequence per cell through to _set_parameters."""
    population = MockPopulation()
    population._set_parameters = Mock()
    spike_times = [Sequence(numpy.linspace(i, 100.0 + i, 10))
                   for i in range(population.size)]
    population.tset("spike_times", spike_times)
    param = population._set_parameters.call_args[0][0]['spike_times']
    mask = population._mask_local
    assert_equal(param.shape[0], len(spike_times))
    assert_arrays_equal(param[mask], numpy.array(spike_times)[mask])
|---|
| 243 | |
|---|
def test_tset_invalid_dimensions_2D():
    """Population.tset(): a value array whose size does not match the Population must raise InvalidDimensionsError."""
    population = MockPopulation()
    two_rows = numpy.array([[0.1, 0.2, 0.3],
                            [0.4, 0.5, 0.6]])
    assert_raises(errors.InvalidDimensionsError, population.tset, 'i_offset', two_rows)
|---|
| 249 | |
|---|
def test_tset_invalid_dimensions_1D():
    """tset() rejects a 1D array with one element too many."""
    population = MockPopulation()
    too_long = numpy.linspace(10.0, 20.0, num=population.size + 1)
    assert_raises(errors.InvalidDimensionsError, population.tset, "tau_m", too_long)
|---|
| 254 | |
|---|
| 255 | |
|---|
class MockRandDistr(VectorizedIterable):
    """Deterministic stand-in for a random distribution: next(n) gives 0..n-1."""

    def next(self, n):
        sequence = numpy.arange(n)
        return sequence
|---|
| 259 | |
|---|
def test_rset():
    """Population.rset()"""
    # NOTE: the test should assume MPI, i.e. use of mask_local
    population = MockPopulation()
    population._set_parameters = Mock()
    distribution = MockRandDistr()
    expected = numpy.arange(population.size)
    population.rset("foo", distribution)
    passed_space = population._set_parameters.call_args[0][0]
    assert_arrays_equal(passed_space['foo'].evaluate(), expected)
|---|
| 270 | |
|---|
def test_rset_with_native_rng():
    """rset() hands distributions backed by a NativeRNG to _native_rset."""
    population = MockPopulation()
    native_rset = population._native_rset = Mock()
    distribution = Mock()
    distribution.rng = random.NativeRNG()
    population.rset('tau_m', distribution)
    native_rset.assert_called_with('tau_m', distribution)
|---|
| 278 | |
|---|
def test_initialize():
    """initialize() stores a scalar initial value and forwards it to the backend."""
    population = MockPopulation()
    population.initial_values = {}
    backend_setter = Mock()
    population._set_initial_value_array = backend_setter
    population.initialize('v', -65.0)
    stored = population.initial_values['v'].evaluate(simplify=True)
    assert_equal(stored, -65.0)
    backend_setter.assert_called_with('v', -65.0)
|---|
| 286 | |
|---|
def test_initialize_random_distribution():
    """initialize() with a RandomDistribution stores one drawn value per local cell."""
    population = MockPopulation()
    population.initial_values = {}
    population._set_initial_value_array = Mock()

    class MockRandomDistribution(random.RandomDistribution):
        def next(self, n, mask_local):
            return 42 * numpy.ones(n)[mask_local]

    population.initialize('v', MockRandomDistribution())
    expected = 42 * numpy.ones(population.local_size)
    assert_arrays_equal(population.initial_values['v'].evaluate(simplify=True), expected)
    # TODO: also check the array passed to _set_initial_value_array (an earlier,
    # disabled check expected it to be called with ('v', 42*numpy.ones(p.size))).
|---|
| 297 | |
|---|
def test_can_record():
    """can_record() accepts recordable variables and rejects other names."""
    population = MockPopulation()
    population.celltype = MockStandardCell({})
    # 'v' is listed in MockStandardCell.recordable; 'foo' is only a parameter
    assert population.can_record('v')
    assert not population.can_record('foo')
|---|
| 303 | |
|---|
def test_record_with_single_variable():
    """record('v') asks the recorder to record 'v' for all cells."""
    population = MockPopulation()
    population.recorder = Mock()
    population.record('v')
    method_name, positional, _ = population.recorder.method_calls[0]
    recorded_vars, recorded_ids = positional
    assert_equal(method_name, 'record')
    assert_equal(recorded_vars, 'v')
    assert_arrays_equal(recorded_ids, population.all_cells)
|---|
| 313 | |
|---|
def test_record_with_multiple_variables():
    """record([...]) passes the whole variable list to the recorder."""
    population = MockPopulation()
    population.recorder = Mock()
    requested = ['v', 'gsyn_exc', 'spikes']
    population.record(requested)
    method_name, positional, _ = population.recorder.method_calls[0]
    recorded_vars, recorded_ids = positional
    assert_equal(method_name, 'record')
    assert_equal(recorded_vars, ['v', 'gsyn_exc', 'spikes'])
    assert_arrays_equal(recorded_ids, population.all_cells)
|---|
| 323 | |
|---|
def test_record_v():
    """record_v() is a shortcut for record('v', ...)."""
    population = MockPopulation()
    fake_record = population.record = Mock()
    population.record_v("arg1")
    fake_record.assert_called_with('v', "arg1")
|---|
| 329 | |
|---|
def test_record_gsyn():
    """record_gsyn() records both excitatory and inhibitory conductances."""
    population = MockPopulation()
    fake_record = population.record = Mock()
    population.record_gsyn("arg1")
    fake_record.assert_called_with(['gsyn_exc', 'gsyn_inh'], "arg1")
|---|
| 335 | |
|---|
def test_printSpikes():
    """printSpikes() calls recorder.write for 'spikes' with the record filter."""
    population = MockPopulation()
    population.recorder = Mock()
    population.record_filter = "filter"
    population.printSpikes("file", "gather", "compatible_output")
    method_name, positional, _ = population.recorder.method_calls[0]
    assert_equal(method_name, 'write')
    assert_equal(positional, ("spikes", "file", "gather", "filter"))
|---|
| 344 | |
|---|
def test_getSpikes():
    """getSpikes() calls recorder.get for 'spikes' with the record filter."""
    population = MockPopulation()
    population.recorder = Mock()
    population.record_filter = "filter"
    population.getSpikes("gather", "compatible_output")
    method_name, positional, _ = population.recorder.method_calls[0]
    assert_equal(method_name, 'get')
    assert_equal(positional, ("spikes", "gather", "filter", False))
|---|
| 353 | |
|---|
def test_print_v():
    """print_v() calls recorder.write for 'v' with the record filter."""
    population = MockPopulation()
    population.recorder = Mock()
    population.record_filter = "filter"
    population.print_v("file", "gather", "compatible_output")
    method_name, positional, _ = population.recorder.method_calls[0]
    assert_equal(method_name, 'write')
    assert_equal(positional, ("v", "file", "gather", "filter"))
|---|
| 362 | |
|---|
def test_get_v():
    """get_v() calls recorder.get for 'v' with the record filter."""
    population = MockPopulation()
    population.recorder = Mock()
    population.record_filter = "filter"
    population.get_v("gather", "compatible_output")
    method_name, positional, _ = population.recorder.method_calls[0]
    assert_equal(method_name, 'get')
    assert_equal(positional, ("v", "gather", "filter", False))
|---|
| 371 | |
|---|
def test_print_gsyn():
    """print_gsyn() calls recorder.write for both conductance variables."""
    population = MockPopulation()
    population.recorder = Mock()
    population.record_filter = "filter"
    population.print_gsyn("file", "gather", "compatible_output")
    method_name, positional, _ = population.recorder.method_calls[0]
    assert_equal(method_name, 'write')
    assert_equal(positional, (["gsyn_exc", "gsyn_inh"], "file", "gather", "filter"))
|---|
| 380 | |
|---|
def test_get_gsyn():
    """get_gsyn() calls recorder.get for both conductance variables."""
    population = MockPopulation()
    population.recorder = Mock()
    population.record_filter = "filter"
    population.get_gsyn("gather", "compatible_output")
    method_name, positional, _ = population.recorder.method_calls[0]
    assert_equal(method_name, 'get')
    assert_equal(positional, (["gsyn_exc", "gsyn_inh"], "gather", "filter", False))
|---|
| 389 | |
|---|
def test_get_spike_counts():
    """get_spike_counts() calls recorder.count for 'spikes' with no filter."""
    population = MockPopulation()
    population.recorder = Mock()
    population.get_spike_counts("gather")
    method_name, positional, _ = population.recorder.method_calls[0]
    assert_equal(method_name, 'count')
    assert_equal(positional, ("spikes", "gather", None))
|---|
| 397 | |
|---|
def test_mean_spike_count():
    """On the master node (rank 0), mean_spike_count() averages the counts.

    The MPI rank is overridden with ``patch.object``, so the original value
    is restored even if the assertion fails.
    """
    with patch.object(MockPopulation._simulator.state, "mpi_rank", 0):
        p = MockPopulation()
        p.recorder = Mock()
        p.recorder.count = Mock(return_value={0: 2, 1: 5})
        assert_equal(p.mean_spike_count(), 3.5)
|---|
| 406 | |
|---|
def test_mean_spike_count_on_slave_node():
    """On a non-master node (rank 1), mean_spike_count() returns NaN.

    NaN is checked with ``numpy.isnan`` rather than an identity comparison
    against ``numpy.NaN``. The identity check depends on the exact float
    object returned, and ``numpy.NaN`` no longer exists in NumPy 2. The rank
    is overridden with ``patch.object`` so it is restored even on failure.
    """
    with patch.object(MockPopulation._simulator.state, "mpi_rank", 1):
        p = MockPopulation()
        p.recorder = Mock()
        p.recorder.count = Mock(return_value={0: 2, 1: 5})
        assert numpy.isnan(p.mean_spike_count())
|---|
| 415 | |
|---|
def test_inject():
    """inject() hands the population to the current source's inject_into."""
    population = MockPopulation()
    current_source = Mock()
    population.inject(current_source)
    method_name, positional, _ = current_source.method_calls[0]
    assert_equal(method_name, "inject_into")
    assert_equal(positional, (population,))
|---|
| 423 | |
|---|
def test_inject_into_invalid_celltype():
    """inject() raises TypeError when the cell type is not injectable.

    ``MockPopulation.celltype`` is a single instance shared by every
    MockPopulation. Flipping ``injectable`` on it would leak into later
    tests, so this test gives the population its own cell type first.
    """
    p = MockPopulation()
    p.celltype = MockStandardCell({})
    p.celltype.injectable = False
    assert_raises(TypeError, p.inject, Mock())
|---|
| 428 | |
|---|
def test_save_positions():
    """save_positions() writes one (id, x, y, z) row per cell plus metadata.

    The MPI rank is overridden with ``patch.object``, so it is restored even
    if an assertion fails. (The previous unused ``import os`` is dropped.)
    """
    with patch.object(MockPopulation._simulator.state, "mpi_rank", 0):
        p = MockPopulation()
        p.all_cells = numpy.array([34, 45, 56, 67])
        p.positions = numpy.arange(12).reshape((4, 3)).T
        output_file = Mock()
        p.save_positions(output_file)
        assert_arrays_equal(output_file.write.call_args[0][0],
                            numpy.array([[34, 0, 1, 2], [45, 3, 4, 5], [56, 6, 7, 8], [67, 9, 10, 11]]))
        assert_equal(output_file.write.call_args[0][1], {'population': p.label})
        # arguably, the first column should contain indices, not ids.
|---|