| 1 | # -*- coding: utf-8 -*- |
|---|
| 2 | |
|---|
| 3 | |
|---|
| 4 | """ |
|---|
| 5 | |
|---|
| 6 | """ |
|---|
| 7 | |
|---|
| 8 | |
|---|
| 9 | |
|---|
| 10 | from PyQt4.QtCore import * |
|---|
| 11 | from PyQt4.QtGui import * |
|---|
| 12 | import numpy |
|---|
| 13 | |
|---|
| 14 | from guiutil.icons import icons |
|---|
| 15 | #~ from guiutil.globalapplicationdict import * |
|---|
| 16 | from guiutil.paramwidget import ParamWidget, LimitWidget |
|---|
| 17 | from enhancedmatplotlib import * |
|---|
| 18 | import numpy |
|---|
| 19 | from numpy import inf, zeros, unique, mean, std, arange |
|---|
| 20 | |
|---|
| 21 | from ..computing.spikesorting import filtering, detection, extraction, projection, clustering |
|---|
| 22 | |
|---|
| 23 | #~ from ..classes import allclasses, Oscillation |
|---|
| 24 | #~ from ..computing.timefrequency import LineDetector, PlotLineDetector |
|---|
| 25 | from enhancedmatplotlib import SimpleCanvasAndTool |
|---|
| 26 | #~ from queryresultbox import QueryResultBox |
|---|
| 27 | |
|---|
| 28 | #~ from sqlalchemy import and_, or_ |
|---|
| 29 | |
|---|
| 30 | |
|---|
| 31 | from mpl_toolkits.mplot3d import Axes3D |
|---|
| 32 | |
|---|
| 33 | |
|---|
| 34 | |
|---|
| 35 | |
|---|
| 36 | |
|---|
# Matplotlib one-letter colour codes, cycled 100 times so that an integer
# cluster label in 0..699 can be used directly as an index.
colors = list('cgrbkmy') * 100
|---|
| 38 | |
|---|
| 39 | |
|---|
class WidgetMultiMethodsParam(QFrame):
    """
    Widget for choosing a method and its parameters.

    A combo box lists the ``name`` of every entry of ``list_method``. Below
    it, a ParamWidget built from the ``params`` of the selected entry is
    shown and rebuilt each time the selection changes.

    Parameters:
        parent: optional Qt parent widget.
        list_method: sequence of method classes. Each entry must expose
            ``name`` and ``params`` and be callable without argument
            (see get_method). None (the default) means "no method".
        method_name: text of the label shown above the combo box; it is
            also part of the ``keyformemory`` key given to ParamWidget.
        globalApplicationDict: forwarded unchanged to ParamWidget as
            ``applicationdict``.
    """
    def __init__(self, parent=None,
                 list_method=None,
                 method_name='',
                 globalApplicationDict=None,
                 ):
        QFrame.__init__(self, parent)

        # None replaces the former mutable ``[]`` default; a list given by the
        # caller is still stored by reference, as before.
        self.list_method = [] if list_method is None else list_method
        self.method_name = method_name
        self.globalApplicationDict = globalApplicationDict

        self.setFrameStyle(QFrame.Raised | QFrame.StyledPanel)
        self.v1 = QVBoxLayout()
        self.setLayout(self.v1)

        self.v1.addWidget(QLabel(self.method_name))
        self.comboBox_method = QComboBox()
        self.v1.addWidget(self.comboBox_method)
        self.comboBox_method.addItems([method.name for method in self.list_method])

        self.connect(self.comboBox_method, SIGNAL('currentIndexChanged( int )'), self.comboBoxChangeMethod)

        self.paramWidget = None

        # Build the parameter panel of the initially selected method.
        self.comboBoxChangeMethod()

    def comboBoxChangeMethod(self):
        """
        Replace the parameter panel by the one of the selected method.

        The previous panel is scheduled for deletion: taking it out of the
        layout alone would leave it alive as a hidden child of this frame.
        When nothing is selected (empty method list) no panel is created.
        """
        pos = self.comboBox_method.currentIndex()
        if self.paramWidget is not None:
            old = self.paramWidget
            self.paramWidget = None
            old.setVisible(False)
            self.v1.removeWidget(old)
            old.deleteLater()
        if pos < 0:
            # currentIndex() is -1 for an empty combo box.
            return
        method = self.list_method[pos]
        self.paramWidget = ParamWidget(method.params,
                                       applicationdict=self.globalApplicationDict,
                                       keyformemory='spikesorting/%s/%s' % (self.method_name, method.name),
                                       title=method.name,
                                       )
        self.v1.addWidget(self.paramWidget, 1)

    def get_method(self):
        """
        Return a new instance of the currently selected method class.

        Raises IndexError when the method list is empty.
        """
        pos = self.comboBox_method.currentIndex()
        return self.list_method[pos]()

    def get_dict(self):
        """
        Return the parameter dict of the current panel (ParamWidget.get_dict),
        or an empty dict when no panel exists.
        """
        if self.paramWidget is None:
            return {}
        return self.paramWidget.get_dict()
|---|
| 94 | |
|---|
| 95 | |
|---|
| 96 | |
|---|
| 97 | |
|---|
class WidgetFiltering(QWidget):
    """
    Plot widget of the filtering step.

    Two stacked axes sharing their x axis: the signals of
    ``tab.anaSigList`` on top (drawn once at construction) and the signals
    of ``tab.anaSigFilteredList`` below (redrawn by refresh). ``tab`` is
    read from the parent given at construction.
    """
    def __init__(self, parent=None):
        QWidget.__init__(self, parent)

        self.spikeSortingWin = self.parent()
        layout = QVBoxLayout()
        self.setLayout(layout)
        self.canvas = SimpleCanvasAndTool(orientation=Qt.Horizontal)
        layout.addWidget(self.canvas)
        self.fig = self.canvas.fig
        self.ax1 = self.fig.add_subplot(2, 1, 1)
        self.ax2 = self.fig.add_subplot(2, 1, 2, sharex=self.ax1)

        # Upper axes: the unfiltered signals.
        self._replot(self.ax1, self.spikeSortingWin.tab.anaSigList)

    def _replot(self, axes, signals):
        """Clear ``axes``, draw every signal of ``signals`` on it, repaint."""
        axes.clear()
        for sig in signals:
            axes.plot(sig.t(), sig.signal)
        self.canvas.draw()

    def refresh(self):
        """Redraw the lower axes from ``tab.anaSigFilteredList``."""
        self._replot(self.ax2, self.spikeSortingWin.tab.anaSigFilteredList)
|---|
| 124 | |
|---|
| 125 | |
|---|
| 126 | |
|---|
| 127 | |
|---|
class WidgetDetection(QWidget):
    """
    Widget to plot the detection.

    One axes per filtered signal (all sharing x and y). The filtered signals
    are drawn once, on the first refresh; each refresh then replaces the
    spike markers, coloured by the label found in ``tab.sorted``.
    """
    def __init__(self, parent=None):
        QWidget.__init__(self, parent)
        self.spikeSortingWin = self.parent()
        mainLayout = QVBoxLayout()
        self.setLayout(mainLayout)
        self.canvas = SimpleCanvasAndTool(orientation=Qt.Horizontal)
        self.fig = self.canvas.fig
        mainLayout.addWidget(self.canvas)

        # Axes are created lazily by plotSigs(), on the first refresh().
        self.axs = None
        # Marker artists currently displayed: one entry per plot() call,
        # each entry being the list of lines returned by that call.
        self.lines = []

    def plotSigs(self):
        """Create one axes per filtered signal and draw the signals in blue."""
        filtered = self.spikeSortingWin.tab.anaSigFilteredList
        n = len(filtered)
        self.axs = []
        ax = None
        for i, anaSig in enumerate(filtered):
            # Each axes shares x and y with the one created just before it.
            ax = self.fig.add_subplot(n, 1, i + 1, sharex=ax, sharey=ax)
            self.axs.append(ax)
            ax.plot(anaSig.t(), anaSig.signal, color='b')
        self.canvas.draw()

    def refresh(self):
        """
        Redraw the spike markers on every axes.

        Reads ``tab.sorted`` (labels, used as a mask over
        ``tab.spikePosistion``) and ``tab.anaSigFilteredList``.
        """
        if self.axs is None:
            self.plotSigs()

        tab = self.spikeSortingWin.tab
        labels = tab.sorted

        # Remove the previous markers. Each artist detaches itself from the
        # axes that owns it; the former code looked the axes up with the
        # index into self.lines, which is filled in (cluster, channel) order
        # and therefore does not match self.axs once there are several
        # clusters.
        for plotted in self.lines:
            for line in plotted:
                line.remove()
        self.lines = []

        for c in unique(labels):
            sp = tab.spikePosistion[c == labels]
            for ax, anaSig in zip(self.axs, tab.anaSigFilteredList):
                plotted = ax.plot(anaSig.t()[sp], anaSig.signal[sp],
                                  linestyle='None', marker='o', color=colors[c])
                self.lines.append(plotted)

        self.canvas.draw()
|---|
| 174 | |
|---|
| 175 | |
|---|
class WidgetExtraction(QWidget):
    """
    Plot widget of the extraction step.

    Two rows of axes with one column per entry of ``tab.anaSigList``:
    the top row (``ax_moy``) shows, per label of ``tab.sorted``, the mean
    waveform with a +/- one standard deviation band; the bottom row
    (``ax_all``) shows every waveform.
    """
    def __init__(self, parent=None):
        QWidget.__init__(self, parent)
        self.spikeSortingWin = self.parent()
        layout = QHBoxLayout()
        self.setLayout(layout)
        self.canvas = SimpleCanvasAndTool(orientation=Qt.Horizontal)
        self.fig = self.canvas.fig
        layout.addWidget(self.canvas)

        ncols = len(self.spikeSortingWin.tab.anaSigList)
        self.ax_moy = self._add_row(ncols, 1)
        self.ax_all = self._add_row(ncols, ncols + 1)

    def _add_row(self, ncols, first):
        """
        Add ``ncols`` subplots at positions ``first``, ``first + 1``, ... of
        the 2 x ncols grid; each one shares x and y with the previous one.
        """
        row = []
        previous = None
        for k in range(ncols):
            previous = self.fig.add_subplot(2, ncols, first + k,
                                            sharex=previous, sharey=previous)
            row.append(previous)
        return row

    def refresh(self):
        """Redraw both rows from ``tab.waveforms`` and ``tab.sorted``."""
        tab = self.spikeSortingWin.tab
        labels = tab.sorted
        nsig = len(tab.anaSigList)
        waveforms = tab.waveforms
        clusters = unique(labels)

        for chan in range(nsig):
            # Bottom row: all waveforms of this channel, one colour per label.
            every = self.ax_all[chan]
            every.clear()
            for c in clusters:
                every.plot(waveforms[labels == c, chan, :].transpose(), color=colors[c])

            # Top row: mean waveform and standard deviation band per label.
            average = self.ax_moy[chan]
            average.clear()
            for c in clusters:
                selected = waveforms[c == labels, chan, :]
                m = mean(selected, axis=0)
                sd = std(selected, axis=0)
                average.plot(m, color=colors[c], linewidth=2)
                average.fill_between(arange(m.size), m - sd, m + sd, color=colors[c], alpha=.3)

        self.canvas.draw()
|---|
| 219 | |
|---|
| 220 | |
|---|
| 221 | |
|---|
class Widget3DViewer(QWidget):
    """
    3D scatter view of the projected data.

    Three combo boxes select which columns of ``projected`` are used as the
    x, y and z coordinates; points are coloured by their label in ``sorted``.
    Both arrays are supplied through refresh().
    """
    def __init__(self, parent=None):
        QWidget.__init__(self, parent)
        self.spikeSortingWin = self.parent()
        mainLayout = QVBoxLayout()
        self.setLayout(mainLayout)

        toolbar = QHBoxLayout()
        mainLayout.addLayout(toolbar)
        toolbar.addWidget(QLabel('Choose dim'))

        # One selector per plotted axis.
        self.combos = []
        for _ in range(3):
            selector = QComboBox()
            self.combos.append(selector)
            self.connect(selector, SIGNAL('activated(int)'), self.change_dim)
            toolbar.addWidget(selector)

        refreshButton = QPushButton(QIcon(':/view-refresh.png'), 'refresh')
        toolbar.addWidget(refreshButton)
        self.connect(refreshButton, SIGNAL('clicked()'), self.change_dim)

        self.canvas1 = SimpleCanvas()
        self.ax = Axes3D(self.canvas1.fig)
        mainLayout.addWidget(self.canvas1)

        # Nothing to show until refresh() has been called.
        self.projected = None
        self.sorted = None

    def change_dim(self, index=None):
        """
        Redraw the scatter plot with the columns currently selected.

        ``index`` only exists so the method can be used as a slot of
        ``activated(int)``; it is not used. Does nothing before the first
        refresh().
        """
        if self.projected is None:
            return
        self.ax.clear()
        xs, ys, zs = [self.projected[:, selector.currentIndex()]
                      for selector in self.combos]

        for c in unique(self.sorted):
            mask = self.sorted == c
            self.ax.scatter(xs[mask], ys[mask], zs[mask], color=colors[c])
        self.canvas1.draw()

    def refresh(self, projected, sorted):
        """
        Store the new data, refill the selectors and redraw.

        Each selector lists ``0 .. projected.shape[1] - 1``; selector number
        i is preset to column i when that column exists.
        """
        ndim = projected.shape[1]
        names = [str(n) for n in range(ndim)]
        for i, selector in enumerate(self.combos):
            selector.clear()
            selector.addItems(names)
            if i < ndim:
                selector.setCurrentIndex(i)

        self.projected = projected
        self.sorted = sorted

        self.change_dim()
|---|
| 276 | |
|---|
| 277 | |
|---|
| 278 | |
|---|
| 279 | |
|---|
class WidgetProjection(QWidget):
    """
    Plot widget of the projection step.

    A combo box switches between three views held in a QStackedWidget:
    a flattened 1D view, a grid of pairwise 2D scatter plots, and a
    Widget3DViewer. All views are redrawn by refresh() from
    ``tab.projected`` and ``tab.sorted``.
    """
    def __init__(self, parent=None):
        QWidget.__init__(self, parent)

        self.spikeSortingWin = self.parent()
        mainLayout = QVBoxLayout()
        self.setLayout(mainLayout)

        h = QHBoxLayout()
        mainLayout.addLayout(h)
        h.addWidget(QLabel('Choose a view for projection'))
        self.comboView = QComboBox()
        h.addWidget(self.comboView)
        self.stacked = QStackedWidget()
        mainLayout.addWidget(self.stacked)
        # Combo entries and stacked pages are added in the same order, so the
        # combo index can drive the page index directly.
        self.connect(self.comboView, SIGNAL('activated(int)'), self.stacked, SLOT('setCurrentIndex(int)'))

        # flattened 1D view
        self.comboView.addItem('flatened 1D view')
        self.canvas1 = SimpleCanvasAndTool()
        self.stacked.addWidget(self.canvas1)
        self.ax1 = self.canvas1.fig.add_subplot(1, 1, 1)

        # combined 2D
        self.comboView.addItem('combinated 2D')
        self.canvas2 = SimpleCanvasAndTool()
        self.stacked.addWidget(self.canvas2)

        # 3D viewer
        self.comboView.addItem('3D viewer')
        self.widget3Dviewer = Widget3DViewer()
        self.stacked.addWidget(self.widget3Dviewer)

    def refresh(self):
        """Redraw the three views from ``tab.projected`` and ``tab.sorted``."""
        tab = self.spikeSortingWin.tab
        labels = tab.sorted
        projected = tab.projected
        ndim = projected.shape[1]

        # (colour, boolean mask) of every label, computed once for all views.
        clusters = [(colors[c], c == labels) for c in unique(labels)]

        # flattened 1D view: one curve per row of projected.
        self.ax1.clear()
        for color, ind in clusters:
            self.ax1.plot(projected[ind, :].transpose(), color=color, marker='.')
        self.canvas1.draw()

        # combined 2D: lower-triangular grid of scatter plots, column i
        # against column j for every i < j. At most the 16 first columns
        # are shown.
        ndim2 = min(ndim, 16)
        fig2 = self.canvas2.fig
        fig2.clear()
        if ndim2 > 1:
            for i in range(ndim2):
                for j in range(i + 1, ndim2):
                    # Grid cell: row j-1, column i (1-based linear index).
                    p = (j - 1) * (ndim2 - 1) + i + 1
                    # The axes of a cell is created once and reused for all
                    # labels, instead of calling add_subplot per label.
                    ax = fig2.add_subplot(ndim2 - 1, ndim2 - 1, p)
                    for color, ind in clusters:
                        ax.plot(projected[ind, i], projected[ind, j],
                                color=color, marker='.', linestyle='None')
                    if i == 0:
                        ax.set_ylabel(str(j))
                    # Bottom row of the displayed grid; this must use ndim2
                    # so the labels still appear when columns are truncated.
                    if j == ndim2 - 1:
                        ax.set_xlabel(str(i))
                    ax.set_xticks([])
                    ax.set_yticks([])
        self.canvas2.draw()

        # 3D viewer
        self.widget3Dviewer.refresh(projected, labels)
|---|
| 354 | |
|---|
| 355 | |
|---|
| 356 | |
|---|
| 357 | |
|---|
| 358 | |
|---|
| 359 | |
|---|
| 360 | |
|---|
class WidgetClustering(QWidget):
    """Plot widget of the clustering step; it currently displays nothing."""

    def __init__(self, parent=None):
        super(WidgetClustering, self).__init__(parent)

    def refresh(self):
        """Do nothing; present so that every step widget offers refresh()."""
        return None
|---|
| 366 | |
|---|
| 367 | |
|---|
| 368 | |
|---|
| 369 | |
|---|
# Processing steps, in pipeline order. Each entry is
# [step name, computing module, plot widget class]; the step name is used as
# tab title, as dictionary key and to build the 'compute<Name>' method name.
steps = [
    ['Filtering', filtering, WidgetFiltering],
    ['Detection', detection, WidgetDetection],
    ['Extraction', extraction, WidgetExtraction],
    ['Projection', projection, WidgetProjection],
    ['Clustering', clustering, WidgetClustering],
]
|---|
| 377 | |
|---|
| 378 | |
|---|
class TabSpikeSorting(QTabWidget):
    """
    Tab widget holding one tab per processing step plus a database tab.

    Each step tab contains a WidgetMultiMethodsParam used to choose the
    method of that step and its parameters. The object also stores the
    intermediate results of the pipeline and provides one ``compute<Step>``
    method per step.

    Parameters:
        parent: optional Qt parent widget.
        metadata, session: stored as attributes; not used by this class.
        globalApplicationDict: forwarded to the parameter widgets.
        anaSigList: list of signals to process (stored as ``anaSigList``).
    """

    def __init__(self, parent=None,
                 metadata=None,
                 session=None,
                 globalApplicationDict=None,
                 anaSigList=None,
                 ):
        QTabWidget.__init__(self, parent)
        self.setTabPosition(QTabWidget.West)

        self.metadata = metadata
        self.session = session
        self.globalApplicationDict = globalApplicationDict

        # Layouts and method selectors of the step tabs, keyed by step name.
        self.hboxes = {}
        self.vboxes = {}
        self.widgetMultimethods = {}

        for name, module, _plotWidget in steps:
            hbox, vbox = self._new_tab(name)
            self.hboxes[name] = hbox
            self.vboxes[name] = vbox
            selector = WidgetMultiMethodsParam(
                list_method=module.list_method,
                method_name='Choose methd for %s:' % name,
                globalApplicationDict=self.globalApplicationDict,
            )
            self.widgetMultimethods[name] = selector
            vbox.addWidget(selector)
            vbox.addStretch(0)

        # tab for database options
        _hbox, vbox = self._new_tab('Database option')
        params = [
            ('save_filtered_waveform', {'value': True, 'label': 'Save filterered waveform'}),
        ]
        self.databaseOptions = ParamWidget(params,
                                           applicationdict=self.globalApplicationDict,
                                           keyformemory='spikesorting/databaseoptions',
                                           title='database options',
                                           )
        vbox.addWidget(self.databaseOptions)
        vbox.addStretch(0)

        # Pipeline state; every result is None until its step has been run.
        # The original author marked the anaSigList assignment with a bare
        # FIXME (no explanation given).
        self.anaSigList = anaSigList
        self.anaSigFilteredList = None
        self.spikePosistion = None
        self.spikeSign = None
        self.left_sweep = None
        self.right_sweep = None
        self.waveforms = None
        self.projected = None
        self.sorted = None

    def _new_tab(self, title):
        """
        Append a tab named ``title`` and return ``(hbox, vbox)``: the
        horizontal layout of the page and a vertical layout placed in it.
        """
        page = QWidget()
        self.addTab(page, title)
        hbox = QHBoxLayout()
        page.setLayout(hbox)
        vbox = QVBoxLayout()
        hbox.addLayout(vbox)
        return hbox, vbox

    def _method_and_params(self, step):
        """Return ``(method instance, parameter dict)`` selected for ``step``."""
        selector = self.widgetMultimethods[step]
        method = selector.get_method()
        return method, selector.get_dict()

    def computeFiltering(self):
        """Filter every signal of ``anaSigList`` into ``anaSigFilteredList``."""
        m, kargs = self._method_and_params('Filtering')

        filtered = self.anaSigFilteredList = []
        for anaSig in self.anaSigList:
            filtered.append(m.compute(anaSig, **kargs))

    def computeDetection(self):
        """
        Detect spikes on the filtered signals.

        Stores the result in ``spikePosistion``, keeps the 'sign',
        'left_sweep' and 'right_sweep' parameters for the extraction step and
        resets ``sorted`` to one zero label per detected position.
        """
        m, kargs = self._method_and_params('Detection')

        self.spikeSign = kargs['sign']
        self.left_sweep = kargs['left_sweep']
        self.right_sweep = kargs['right_sweep']
        self.spikePosistion = m.compute(self.anaSigFilteredList, **kargs)

        self.sorted = zeros(self.spikePosistion.size, dtype='i')

    def computeExtraction(self):
        """Extract the waveforms around the detected positions."""
        # NOTE(review): the parameter dict of the extraction panel is read but
        # not forwarded to compute(); only the sign and sweeps saved by
        # computeDetection are passed - confirm this is intended.
        m, kargs = self._method_and_params('Extraction')

        self.waveforms = m.compute(self.anaSigFilteredList,
                                   self.spikePosistion,
                                   self.spikeSign,
                                   left_sweep=self.left_sweep,
                                   right_sweep=self.right_sweep)

    def computeProjection(self):
        """Project the waveforms; the result is stored in ``projected``."""
        m, kargs = self._method_and_params('Projection')

        # The sampling rate is taken from the first filtered signal.
        sampling_rate = self.anaSigFilteredList[0].sampling_rate
        self.projected = m.compute(self.waveforms, sampling_rate, **kargs)

    def computeClustering(self):
        """Cluster the projected data; the labels are stored in ``sorted``."""
        m, kargs = self._method_and_params('Clustering')

        self.sorted = m.compute(self.projected, self.spikePosistion, **kargs)

    def recomputeAllSteps(self):
        """Run the five steps in pipeline order."""
        for compute in (self.computeFiltering,
                        self.computeDetection,
                        self.computeExtraction,
                        self.computeProjection,
                        self.computeClustering):
            compute()
|---|
| 560 | |
|---|
| 561 | |
|---|
| 562 | #~ def save_to_db(self) : |
|---|
| 563 | #~ n_cluster = unique(self.cluster).size |
|---|
| 564 | #~ waveform_size = self.param_database.get_one_param('waveform_size') |
|---|
| 565 | #~ oversampling = self.param_database.get_one_param('oversampling') |
|---|
| 566 | |
|---|
| 567 | #~ if self.id_electrode is not None : |
|---|
| 568 | #~ # mode one electrode |
|---|
| 569 | |
|---|
| 570 | #~ # delete old spiketrain and spike in database |
|---|
| 571 | #~ id_spiketrains, = sql('SELECT id_spiketrain FROM spiketrain WHERE id_electrode = %s' , self.id_electrode) |
|---|
| 572 | #~ for id_spiketrain in id_spiketrains : |
|---|
| 573 | #~ sptr = SpikeTrain() |
|---|
| 574 | #~ sptr.id_spiketrain = id_spiketrain |
|---|
| 575 | #~ sptr.id_principal = id_spiketrain |
|---|
| 576 | #~ sptr.delete_from_db_and_child(dict_hierarchic_class ) |
|---|
| 577 | |
|---|
| 578 | #~ #create new ones |
|---|
| 579 | #~ for n,cl in enumerate(unique(self.cluster)) : |
|---|
| 580 | |
|---|
| 581 | #~ sptr = SpikeTrain() |
|---|
| 582 | #~ sptr.id_trial = self.elec.id_trial |
|---|
| 583 | #~ sptr.id_electrode = self.elec.id_electrode |
|---|
| 584 | #~ sptr.id_cell = None |
|---|
| 585 | #~ sptr.fs = self.elec.fs |
|---|
| 586 | #~ sptr.shift_t0 = self.elec.shift_t0 |
|---|
| 587 | #~ sptr.oversampling = oversampling |
|---|
| 588 | #~ sptr.f_low = None |
|---|
| 589 | #~ sptr.f_hight = None |
|---|
| 590 | #~ sptr.label = u'' |
|---|
| 591 | #~ sptr.coment = u'' |
|---|
| 592 | #~ id_spiketrain = sptr.save_to_db() |
|---|
| 593 | |
|---|
| 594 | #~ pos = self.pos_spike[self.cluster== cl] |
|---|
| 595 | #~ isi = r_[diff(pos)/float(self.fs) , Inf] |
|---|
| 596 | #~ if self.param_database.get_one_param('save_filtered_waveform') : |
|---|
| 597 | #~ fil = self.multiMethod_filtering.get_method() |
|---|
| 598 | #~ karg = self.multiMethod_filtering.get_dict() |
|---|
| 599 | #~ sig_f = fil.compute(self.sig , self.fs , **karg) |
|---|
| 600 | #~ else : |
|---|
| 601 | #~ sig_f = self.sig |
|---|
| 602 | #~ waveform = waveform_extraction(pos,sig_f, self.fs , waveform_size,oversampling) |
|---|
| 603 | #~ for s in range(len(pos)) : |
|---|
| 604 | #~ sp = Spike() |
|---|
| 605 | #~ sp.id_spiketrain = id_spiketrain |
|---|
| 606 | #~ sp.id_electrode = self.elec.id_electrode |
|---|
| 607 | #~ sp.pos = pos[s] |
|---|
| 608 | #~ sp.val_max = sig_f[pos[s]] |
|---|
| 609 | #~ sp.waveform = squeeze(waveform[s,:]) |
|---|
| 610 | #~ sp.isi = isi[s] |
|---|
| 611 | #~ sp.save_to_db() |
|---|
| 612 | |
|---|
| 613 | #~ else: |
|---|
| 614 | #~ # mode all electrode on same serie |
|---|
| 615 | |
|---|
| 616 | #~ # delete old spiketrain and spike in database |
|---|
| 617 | #~ query = """ |
|---|
| 618 | #~ SELECT spiketrain.id_spiketrain |
|---|
| 619 | #~ FROM spiketrain , electrode , trial |
|---|
| 620 | #~ WHERE |
|---|
| 621 | #~ trial.id_trial = electrode.id_trial |
|---|
| 622 | #~ AND electrode.id_electrode = spiketrain.id_electrode |
|---|
| 623 | #~ AND trial.id_serie = %s |
|---|
| 624 | #~ AND electrode.num_channel = %s |
|---|
| 625 | #~ """ |
|---|
| 626 | #~ id_spiketrains, = sql(query , (self.id_serie , self.num_channel )) |
|---|
| 627 | #~ for id_spiketrain in id_spiketrains : |
|---|
| 628 | #~ sptr = SpikeTrain() |
|---|
| 629 | #~ sptr.id_spiketrain = id_spiketrain |
|---|
| 630 | #~ sptr.id_principal = id_spiketrain |
|---|
| 631 | #~ sptr.delete_from_db_and_child(dict_hierarchic_class ) |
|---|
| 632 | |
|---|
| 633 | #~ #create new cells, spiketrain et spike |
|---|
| 634 | #~ for n,cl in enumerate(unique(self.cluster)) : |
|---|
| 635 | #~ cell = Cell() |
|---|
| 636 | #~ cell.id_serie = self.id_serie |
|---|
| 637 | #~ cell.info = u'' |
|---|
| 638 | #~ cell.name = u'Cell %s NumChannel %s' %( n+1 , self.num_channel) |
|---|
| 639 | #~ id_cell = cell.save_to_db() |
|---|
| 640 | |
|---|
| 641 | #~ start = 0 |
|---|
| 642 | #~ for e,elec in enumerate(self.list_elec): |
|---|
| 643 | #~ sptr = SpikeTrain() |
|---|
| 644 | #~ sptr.id_trial = elec.id_trial |
|---|
| 645 | #~ sptr.id_electrode = elec.id_electrode |
|---|
| 646 | #~ sptr.id_cell = id_cell |
|---|
| 647 | #~ sptr.fs = elec.fs |
|---|
| 648 | #~ sptr.shift_t0 = elec.shift_t0 |
|---|
| 649 | #~ sptr.oversampling = oversampling |
|---|
| 650 | #~ sptr.f_low = None |
|---|
| 651 | #~ sptr.f_hight = None |
|---|
| 652 | #~ sptr.label = u'' |
|---|
| 653 | #~ sptr.coment = u'' |
|---|
| 654 | #~ id_spiketrain = sptr.save_to_db() |
|---|
| 655 | |
|---|
| 656 | #~ pos = self.pos_spike[self.cluster== cl] |
|---|
| 657 | #~ pos = pos[ (pos>= start) & (pos<start + elec.signal.size)] |
|---|
| 658 | #~ pos = pos - start |
|---|
| 659 | #~ isi = r_[diff(pos)/float(elec.fs) , Inf] |
|---|
| 660 | #~ if self.param_database.get_one_param('save_filtered_waveform') : |
|---|
| 661 | #~ fil = self.multiMethod_filtering.get_method() |
|---|
| 662 | #~ karg = self.multiMethod_filtering.get_dict() |
|---|
| 663 | #~ sig_f = fil.compute(elec.signal , elec.fs , **karg) |
|---|
| 664 | #~ else : |
|---|
| 665 | #~ sig_f = elec.signal |
|---|
| 666 | #~ waveform = waveform_extraction(pos,sig_f, self.fs , waveform_size,oversampling) |
|---|
| 667 | #~ for s in range(len(pos)) : |
|---|
| 668 | #~ sp = Spike() |
|---|
| 669 | #~ sp.id_spiketrain = id_spiketrain |
|---|
| 670 | #~ sp.id_electrode = elec.id_electrode |
|---|
| 671 | #~ sp.pos = pos[s] |
|---|
| 672 | #~ sp.val_max = sig_f[pos[s]] |
|---|
| 673 | #~ sp.waveform = squeeze(waveform[s,:]) |
|---|
| 674 | #~ sp.isi = isi[s] |
|---|
| 675 | #~ sp.save_to_db() |
|---|
| 676 | |
|---|
| 677 | #~ start += elec.signal.size |
|---|
| 678 | |
|---|
| 679 | |
|---|
| 680 | #~ def reload_from_db(self) : |
|---|
| 681 | #~ if self.id_electrode is not None : |
|---|
| 682 | #~ # mode one electrode |
|---|
| 683 | #~ id_spiketrains, = sql('SELECT id_spiketrain FROM spiketrain WHERE id_electrode = %s' , self.id_electrode) |
|---|
| 684 | #~ self.pos_spike = array([ ],dtype='i') |
|---|
| 685 | #~ self.cluster = array([ ],dtype='i') |
|---|
| 686 | #~ for i,id_spiketrain in enumerate(id_spiketrains) : |
|---|
| 687 | #~ sptr = SpikeTrain() |
|---|
| 688 | #~ sptr.load_from_db(id_spiketrain) |
|---|
| 689 | #~ pos = sptr.pos_spike() |
|---|
| 690 | #~ self.pos_spike = concatenate((self.pos_spike , pos)) |
|---|
| 691 | #~ self.cluster = concatenate((self.cluster , i*ones((len(pos)) , dtype = 'i') )) |
|---|
| 692 | #~ else: |
|---|
| 693 | #~ # mode all electrode on same serie |
|---|
| 694 | #~ query = """ |
|---|
| 695 | #~ SELECT spiketrain.id_spiketrain , cell.id_cell , electrode.id_electrode |
|---|
| 696 | #~ FROM spiketrain , electrode , trial , cell |
|---|
| 697 | #~ WHERE |
|---|
| 698 | #~ trial.id_trial = electrode.id_trial |
|---|
| 699 | #~ AND electrode.id_electrode = spiketrain.id_electrode |
|---|
| 700 | #~ AND cell.id_cell = spiketrain.id_cell |
|---|
| 701 | #~ AND trial.id_serie = %s |
|---|
| 702 | #~ AND electrode.num_channel = %s |
|---|
| 703 | #~ ORDER BY cell.id_cell |
|---|
| 704 | #~ """ |
|---|
| 705 | #~ id_spiketrains,id_cells,id_electrodes = sql(query , (self.id_serie , self.num_channel )) |
|---|
| 706 | #~ self.pos_spike = array([ ],dtype='i') |
|---|
| 707 | #~ self.cluster = array([ ],dtype='i') |
|---|
| 708 | #~ n_cluster = unique(id_cells).size |
|---|
| 709 | |
|---|
| 710 | #~ for i,id_spiketrain in enumerate(id_spiketrains) : |
|---|
| 711 | #~ id_cell,id_electrode = id_cells[i],id_electrodes[i] |
|---|
| 712 | #~ start = 0 |
|---|
| 713 | #~ for e,elec in enumerate(self.list_elec): |
|---|
| 714 | #~ if elec.id_electrode == id_electrode : break |
|---|
| 715 | #~ start += elec.signal.size |
|---|
| 716 | |
|---|
| 717 | #~ sptr = SpikeTrain() |
|---|
| 718 | #~ sptr.load_from_db(id_spiketrain) |
|---|
| 719 | #~ pos = sptr.pos_spike()+start |
|---|
| 720 | #~ self.pos_spike = concatenate((self.pos_spike , pos)) |
|---|
| 721 | |
|---|
| 722 | #~ cluster = where(id_cell == unique(id_cells))[0] |
|---|
| 723 | #~ self.cluster = concatenate((self.cluster , cluster*ones((len(pos)) , dtype = 'i') )) |
|---|
| 724 | |
|---|
| 725 | |
|---|
| 726 | |
|---|
| 727 | |
|---|
| 728 | |
|---|
| 729 | |
|---|
| 730 | |
|---|
| 731 | |
|---|
| 732 | |
|---|
| 733 | |
|---|
| 734 | class SpikeSorting(QDialog) : |
|---|
| 735 | """ |
|---|
| 736 | Scroll area resazible for stacking matplotlib canvas |
|---|
| 737 | |
|---|
| 738 | several modes : |
|---|
| 739 | - spikedetection/spikesorting on recording point (and its group) |
|---|
| 740 | - spikesorting on a list spiketrain |
|---|
| 741 | - |
|---|
| 742 | |
|---|
| 743 | |
|---|
| 744 | |
|---|
| 745 | """ |
|---|
| 746 | def __init__(self , parent = None , |
|---|
| 747 | metadata =None, |
|---|
| 748 | session = None, |
|---|
| 749 | globalApplicationDict = None, |
|---|
| 750 | |
|---|
| 751 | anaSigList = None, |
|---|
| 752 | |
|---|
| 753 | ): |
|---|
| 754 | QDialog.__init__(self, parent) |
|---|
| 755 | self.metadata = metadata |
|---|
| 756 | self.session = session |
|---|
| 757 | self.globalApplicationDict = globalApplicationDict |
|---|
| 758 | |
|---|
| 759 | mainLayout = QVBoxLayout() |
|---|
| 760 | self.setLayout(mainLayout) |
|---|
| 761 | |
|---|
| 762 | self.tab = TabSpikeSorting(metadata = self.metadata, |
|---|
| 763 | session = self.session, |
|---|
| 764 | globalApplicationDict= self.globalApplicationDict, |
|---|
| 765 | |
|---|
| 766 | anaSigList = anaSigList, |
|---|
| 767 | |
|---|
| 768 | ) |
|---|
| 769 | |
|---|
| 770 | mainLayout.addWidget(self.tab) |
|---|
| 771 | |
|---|
| 772 | self.plotWidget = { } |
|---|
| 773 | for name, module, plotWidget in steps: |
|---|
| 774 | |
|---|
| 775 | v = self.tab.vboxes[name] |
|---|
| 776 | but = QPushButton('Compute %s'%name) |
|---|
| 777 | v.addWidget(but) |
|---|
| 778 | #~ self.connect(but , SIGNAL('clicked()') , getattr(self , 'compute%s'%name) ) |
|---|
| 779 | self.connect(but , SIGNAL('clicked()') , self.computeAStep) |
|---|
| 780 | |
|---|
| 781 | h = self.tab.hboxes[name] |
|---|
| 782 | self.plotWidget[name] = plotWidget(parent = self) |
|---|
| 783 | h.addWidget(self.plotWidget[name], 3) |
|---|
| 784 | |
|---|
| 785 | |
|---|
| 786 | |
|---|
| 787 | #~ self.hboxes = { } |
|---|
| 788 | #~ self. = { } |
|---|
| 789 | #~ self.widgetMultimethods = { } |
|---|
| 790 | |
|---|
| 791 | |
|---|
| 792 | |
|---|
| 793 | def computeAStep(self, ): |
|---|
| 794 | name = self.sender().text() |
|---|
| 795 | name = str(name.replace('Compute ', '')) |
|---|
| 796 | print 'compute', name |
|---|
| 797 | |
|---|
| 798 | # launch computation |
|---|
| 799 | getattr(self.tab , 'compute%s'%name)( ) |
|---|
| 800 | |
|---|
| 801 | # refresh plot |
|---|
| 802 | self.plotWidget[name].refresh() |
|---|
| 803 | |
|---|
| 804 | |
|---|
| 805 | #~ def computeFiltering(self) : |
|---|
| 806 | #~ print 'computeFiltering' |
|---|
| 807 | #~ self.tab.computeFiltering() |
|---|
| 808 | |
|---|
| 809 | #~ self.plotWidget[name] |
|---|
| 810 | |
|---|
| 811 | |
|---|
| 812 | #~ def computeDetection(self) : |
|---|
| 813 | #~ print 'computeDetection' |
|---|
| 814 | #~ self.tab.computeDetection() |
|---|
| 815 | |
|---|
| 816 | |
|---|
| 817 | #~ def computeExtraction(self) : |
|---|
| 818 | #~ self.tab.computeExtraction() |
|---|
| 819 | |
|---|
| 820 | |
|---|
| 821 | #~ def computeProjection(self) : |
|---|
| 822 | #~ self.tab.computeProjection() |
|---|
| 823 | |
|---|
| 824 | |
|---|
| 825 | #~ def computeClustering(self) : |
|---|
| 826 | #~ self.tab.computeClustering() |
|---|
| 827 | |
|---|
| 828 | |
|---|
| 829 | |
|---|
| 830 | |
|---|
| 831 | |
|---|
| 832 | |
|---|
| 833 | |
|---|