root/branches/OpenElectrophy-0.2/OpenElectrophy/gui/spikesorting.py @ 233

Revision 233, 29.2 KB (checked in by sgarcia, 3 years ago)

Spike sorting GUI in progress

Line 
1# -*- coding: utf-8 -*-
2
3
4"""
5
6"""
7
8
9
10from PyQt4.QtCore import *
11from PyQt4.QtGui import *
12import numpy
13
14from guiutil.icons import icons
15#~ from guiutil.globalapplicationdict import *
16from guiutil.paramwidget import ParamWidget, LimitWidget
17from enhancedmatplotlib import *
18import numpy
19from numpy import inf, zeros, unique, mean, std, arange
20
21from ..computing.spikesorting import filtering, detection, extraction, projection, clustering
22
23#~ from ..classes import allclasses, Oscillation
24#~ from ..computing.timefrequency import LineDetector, PlotLineDetector
25from enhancedmatplotlib import SimpleCanvasAndTool
26#~ from queryresultbox import QueryResultBox
27
28#~ from sqlalchemy import and_, or_
29
30
31from mpl_toolkits.mplot3d import Axes3D
32
33
34
35
36
37colors = [ 'c' , 'g' , 'r' , 'b' , 'k' , 'm' , 'y']*100
38
39
40class WidgetMultiMethodsParam(QFrame) :
41    """
42    Widget for choosing a method and its parameters.
43    """
44    def __init__(self,  parent = None ,
45                        list_method = [ ],
46                        method_name = '',
47                        globalApplicationDict = None,
48                        ):
49        QFrame.__init__(self, parent)
50       
51        self.list_method = list_method
52        self.method_name = method_name
53        self.globalApplicationDict = globalApplicationDict
54       
55        self.setFrameStyle(QFrame.Raised | QFrame.StyledPanel)
56        self.v1 = QVBoxLayout()
57        v1 = self.v1
58        self.setLayout(v1)
59       
60       
61        v1.addWidget(QLabel(self.method_name))
62        self.comboBox_method = QComboBox()
63        v1.addWidget(self.comboBox_method)
64        self.comboBox_method.addItems([ method.name for  method in list_method ])
65       
66        self.connect(self.comboBox_method,SIGNAL('currentIndexChanged( int  )') , self.comboBoxChangeMethod )
67       
68        self.paramWidget = None
69       
70       
71        self.comboBoxChangeMethod()
72   
73    def comboBoxChangeMethod(self) :
74        pos = self.comboBox_method.currentIndex()
75        if self.paramWidget is not None :
76            self.paramWidget.setVisible(False)
77            self.v1.removeWidget(self.paramWidget)
78            del self.paramWidget
79        method = self.list_method[pos]
80        self.paramWidget = ParamWidget(method.params ,
81                                                        applicationdict = self.globalApplicationDict,
82                                                        keyformemory = 'spikesorting/%s/%s'%(self.method_name,method.name)  ,
83                                                        title = method.name,
84                                                        )
85        self.v1.addWidget(self.paramWidget,1)
86   
87    def get_method(self) :
88        pos = self.comboBox_method.currentIndex()
89        method = self.list_method[pos]()
90        return method
91   
92    def get_dict(self) :
93        return self.paramWidget.get_dict()
94
95
96
97
98class WidgetFiltering(QWidget):
99    """
100    widget to plot the filtering
101    """
102    def __init__(self , parent=None ,):
103        QWidget.__init__(self,parent )
104       
105        self.spikeSortingWin = self.parent()
106        mainLayout = QVBoxLayout()
107        self.setLayout(mainLayout)
108        self.canvas = SimpleCanvasAndTool(orientation = Qt.Horizontal )
109        mainLayout.addWidget(self.canvas)
110        self.fig = self.canvas.fig
111        self.ax1 = self.fig.add_subplot(2,1,1)
112        self.ax2 = self.fig.add_subplot(2,1,2, sharex = self.ax1)
113       
114        self.ax1.clear()
115        for anaSig in self.spikeSortingWin.tab.anaSigList:
116            self.ax1.plot(anaSig.t(), anaSig.signal)
117        self.canvas.draw()
118       
119    def refresh(self):
120        self.ax2.clear()
121        for anaSig in self.spikeSortingWin.tab.anaSigFilteredList:
122            self.ax2.plot(anaSig.t(), anaSig.signal)
123        self.canvas.draw()
124
125
126
127
128class WidgetDetection(QWidget):
129    """
130    Widget to plot the detection
131    """
132    def __init__(self , parent=None ,):
133        QWidget.__init__(self,parent )
134        self.spikeSortingWin = self.parent()
135        mainLayout = QVBoxLayout()
136        self.setLayout(mainLayout)
137        self.canvas = SimpleCanvasAndTool(orientation = Qt.Horizontal )
138        self.fig = self.canvas.fig
139        mainLayout.addWidget(self.canvas)
140       
141        self.axs = None
142        self.lines = [ ]
143       
144    def plotSigs(self):
145        n = len(self.spikeSortingWin.tab.anaSigFilteredList)
146        self.axs = [ ]
147        ax = None
148        for i , anaSig in enumerate(self.spikeSortingWin.tab.anaSigFilteredList):
149            ax = self.fig.add_subplot(n, 1,i+1 , sharex = ax, sharey = ax)
150            self.axs.append(ax)
151            ax.plot(anaSig.t(), anaSig.signal , color = 'b')
152        self.canvas.draw()
153       
154    def refresh(self):
155        if self.axs is None:
156            self.plotSigs()
157       
158        sorted = self.spikeSortingWin.tab.sorted
159       
160        #remove old detection
161        for i in range(len(self.lines)):
162            for l in self.lines[i]:
163                self.axs[i].lines.remove(l)
164        self.lines = [ ]
165       
166        for c in unique(sorted):
167           
168            sp = self.spikeSortingWin.tab.spikePosistion[ c==sorted ]
169            for i , anaSig in enumerate(self.spikeSortingWin.tab.anaSigFilteredList):
170                l = self.axs[i].plot( anaSig.t()[sp] , anaSig.signal[sp], linestyle = 'None', marker = 'o', color = colors[c])
171                self.lines.append( l )
172           
173        self.canvas.draw()
174       
175
176class WidgetExtraction(QWidget):
177    def __init__(self , parent=None ,):
178        QWidget.__init__(self,parent )
179        self.spikeSortingWin = self.parent()
180        mainLayout = QHBoxLayout()
181        self.setLayout(mainLayout)
182        self.canvas = SimpleCanvasAndTool(orientation = Qt.Horizontal )
183        self.fig = self.canvas.fig
184        mainLayout.addWidget(self.canvas)
185       
186        n = len(self.spikeSortingWin.tab.anaSigList)
187        self.ax_moy = [ ]
188        ax = None
189        for i in range(n):
190            ax = self.fig.add_subplot(2,n, i+1, sharex = ax, sharey = ax)
191            self.ax_moy.append(ax)
192           
193        self.ax_all = [ ]
194        ax = None
195        for i in range(n):
196            ax = self.fig.add_subplot(2,n, n+i+1, sharex = ax, sharey = ax)
197            self.ax_all.append(ax)
198       
199    def refresh(self):
200        sorted = self.spikeSortingWin.tab.sorted
201        n = len(self.spikeSortingWin.tab.anaSigList)
202        waveforms = self.spikeSortingWin.tab.waveforms
203        for i in range(n):
204            ax = self.ax_all[i]
205            ax.clear()
206            for c in unique(sorted):
207                ax.plot( waveforms[sorted ==c, i,  :].transpose(), color = colors[c])
208           
209            ax = self.ax_moy[i]
210            ax.clear()
211            for c in unique(sorted):
212                ind = c==sorted
213                m  = mean(waveforms[ind,i,:], axis = 0)
214                sd = std(waveforms[ind,i,:], axis = 0)
215                ax.plot( m, color = colors[ c ]  , linewidth=2)
216                ax.fill_between(arange(m.size), m-sd, m+sd , color = colors[ c ]  , alpha = .3)
217               
218        self.canvas.draw()
219
220
221
222class Widget3DViewer(QWidget):
223    def __init__(self , parent=None ,):
224        QWidget.__init__(self,parent )
225        self.spikeSortingWin = self.parent()
226        mainLayout = QVBoxLayout()
227        self.setLayout(mainLayout)
228       
229        h = QHBoxLayout()
230        mainLayout.addLayout(h)
231        h.addWidget(QLabel('Choose dim'))
232        self.combos = [ ]
233        for i in range(3):
234            cb = QComboBox()
235            self.combos.append(cb)
236            self.connect(cb, SIGNAL('activated(int)'),self.change_dim )
237            h.addWidget(cb)
238       
239        but = QPushButton(QIcon(':/view-refresh.png'), 'refresh')
240        h.addWidget(but)
241        self.connect(but, SIGNAL('clicked()'),  self.change_dim)
242           
243        self.canvas1 = SimpleCanvas()
244        #~ self.canvas1 = SimpleCanvasAndTool()
245        self.ax = Axes3D(self.canvas1.fig)
246        mainLayout.addWidget( self.canvas1 )
247       
248        self.projected = None
249        self.sorted = None       
250   
251    def change_dim(self, index = None):
252        if self.projected is None : return
253        self.ax.clear()
254        vects = [ ]
255        for i in range(3):
256            ind = self.combos[i].currentIndex()
257            vects.append( self.projected[:,ind] )
258       
259        for c in unique(self.sorted):
260            ind = self.sorted==c
261            self.ax.scatter(vects[0][ind], vects[1][ind], vects[2][ind], color = colors[c])
262        self.canvas1.draw()
263   
264    def refresh(self, projected, sorted):
265        ndim = projected.shape[1]
266        for i in range(3):
267            self.combos[i].clear()
268            self.combos[i].addItems( [ str(n) for n in range(ndim) ] )
269            if i<ndim:
270                self.combos[i].setCurrentIndex(i)
271       
272        self.projected = projected
273        self.sorted = sorted
274       
275        self.change_dim()
276       
277
278
279
280class WidgetProjection(QWidget):
281    def __init__(self , parent=None ,):
282        QWidget.__init__(self,parent )
283       
284        self.spikeSortingWin = self.parent()
285        mainLayout = QVBoxLayout()
286        self.setLayout(mainLayout)
287       
288        h = QHBoxLayout()
289        mainLayout.addLayout(h)
290        h.addWidget(QLabel('Choose a view for projection'))
291        self.comboView = QComboBox()
292        h.addWidget(self.comboView)
293        self.stacked = QStackedWidget()
294        mainLayout.addWidget(self.stacked)
295        self.connect(self.comboView, SIGNAL('activated(int)'),self.stacked, SLOT('setCurrentIndex(int)') )
296       
297        # flatened 1D view
298        self.comboView.addItem('flatened 1D view')
299        self.canvas1 = SimpleCanvasAndTool()
300        self.stacked.addWidget(self.canvas1)
301        self.ax1 = self.canvas1.fig.add_subplot(1,1,1)
302       
303        # combinated 2D
304        self.comboView.addItem('combinated 2D')
305        self.canvas2 = SimpleCanvasAndTool()
306        self.stacked.addWidget(self.canvas2)
307       
308        # 3D viewer
309        self.comboView.addItem('3D viewer')
310        self.widget3Dviewer = Widget3DViewer()
311        self.stacked.addWidget(self.widget3Dviewer)
312       
313       
314       
315    def refresh(self):
316        sorted = self.spikeSortingWin.tab.sorted
317        waveforms = self.spikeSortingWin.tab.waveforms
318        projected = self.spikeSortingWin.tab.projected
319        ndim = projected.shape[1]
320       
321        # flatened 1D view
322        self.ax1.clear()
323        for c in unique(sorted):
324            ind = c==sorted
325            self.ax1.plot( projected[ind,:].transpose() , color = colors[c], marker = '.') 
326        self.canvas1.draw()
327       
328        # combinated 2D
329        ndim2 = min(ndim, 16)
330        print 'yep'
331        self.canvas2.fig.clear()
332        if projected.shape[1]>1:
333            for c in unique(sorted):
334                ind = c==sorted
335               
336               
337                for i in range(ndim2):
338                    for j in range(i+1, ndim2):
339                        p = (j-1)*(ndim2-1)+i+1
340                        ax = self.canvas2.fig.add_subplot(ndim2-1, ndim2-1, p)
341                        ax.plot(projected[ind,i], projected[ind,j], color = colors[c], marker = '.', linestyle = 'None') 
342                        #ax.set_title('%d %d'%(i,j))
343                        if i==0:
344                            ax.set_ylabel( str(j) )
345                        if j==ndim-1:
346                            ax.set_xlabel( str(i) )
347                        ax.set_xticks([ ])
348                        ax.set_yticks([ ])
349        self.canvas2.draw()
350       
351       
352        # 3D viewer
353        self.widget3Dviewer.refresh( projected, sorted)
354       
355       
356       
357       
358       
359       
360
361class WidgetClustering(QWidget):
362    def __init__(self , parent=None ,):
363        QWidget.__init__(self,parent )
364    def refresh(self):
365        pass
366
367
368
369
370steps = [ 
371                        ['Filtering' , filtering, WidgetFiltering],
372                        ['Detection' , detection, WidgetDetection],
373                        ['Extraction' , extraction, WidgetExtraction],
374                        ['Projection' , projection, WidgetProjection],
375                        ['Clustering' , clustering, WidgetClustering],
376                   ]
377
378   
379class TabSpikeSorting(QTabWidget) :
380    """
381    Widget displaying all tabs and methods options.
382    Used in :
383            -
384            -
385   
386   
387    """
388       
389    def __init__(self , parent=None ,
390                            metadata =None,
391                            session = None,
392                            globalApplicationDict = None,
393                           
394                            # possibilitty 1
395                            anaSigList = None,
396                           
397                           
398                    ):
399        QTabWidget.__init__(self,parent )
400        self.setTabPosition(QTabWidget.West)
401       
402        #~ self.setAttribute(Qt.WA_DeleteOnClose)
403       
404        self.metadata = metadata
405        self.session = session
406        self.globalApplicationDict = globalApplicationDict
407       
408       
409        # construct all tabs
410
411       
412        self.hboxes = { }
413        self.vboxes = { }
414        self.widgetMultimethods = { }
415       
416        for name, module, plotWidget in steps:
417            w = QWidget()
418            self.addTab(w,name)
419            h= QHBoxLayout()
420            self.hboxes[name] = h
421            w.setLayout(h)
422           
423            v = QVBoxLayout( )
424            h.addLayout( v )
425            self.vboxes[name] = v
426            wMeth = WidgetMultiMethodsParam(  list_method = module.list_method,
427                                                                method_name = 'Choose methd for %s:'%name,
428                                                                globalApplicationDict = self.globalApplicationDict,
429                                                                )
430            self.widgetMultimethods[name] = wMeth
431            v.addWidget(wMeth)
432            v.addStretch(0)
433           
434       
435        # tab for database options
436        w = QWidget()
437        self.addTab(w,'Database option')
438        h= QHBoxLayout()
439        w.setLayout(h)
440        v = QVBoxLayout( )
441        h.addLayout( v )
442       
443        params = [
444                            ( 'save_filtered_waveform' , {'value' : True , 'label' : 'Save filterered waveform' }),
445                        ]
446        self.databaseOptions =  ParamWidget(params,
447                                    applicationdict = self.globalApplicationDict,
448                                    keyformemory = 'spikesorting/databaseoptions'  ,
449                                    title = 'database options',
450                                    )
451        v.addWidget( self.databaseOptions )
452        v.addStretch(0)
453       
454       
455        # variables
456        self.anaSigList = None
457        self.anaSigFilteredList = None
458        self.spikePosistion = None
459        self.spikeSign = None
460        self.left_sweep = None
461        self.right_sweep = None
462        self.waveforms = None
463        self.projected = None
464        self.sorted = None
465       
466       
467        # FIXME :
468        self.anaSigList = anaSigList
469       
470       
471    #~ def load_signal(self) :
472        #~ if self.id_electrode is not None :
473            #~ # mode one electrode
474            #~ self.elec = Electrode()
475            #~ self.elec.load_from_db(self.id_electrode)
476            #~ self.sig = self.elec.signal
477            #~ self.fs = self.elec.fs
478            #~ self.list_elec = None
479        #~ else :
480            #~ # mode all electrode on same serie
481            #~ query = """
482                    #~ SELECT id_electrode
483                    #~ FROM electrode , trial
484                    #~ WHERE
485                    #~ electrode.id_trial = trial.id_trial
486                    #~ AND trial.id_serie = %s
487                    #~ AND num_channel = %s
488                    #~ ORDER BY trial.thedatetime
489                    #~ """
490            #~ self.list_elec = [ ]
491            #~ id_electrodes, = sql(query , (self.id_serie , self.num_channel))
492            #~ self.sig = array([])
493            #~ for id_electrode in id_electrodes :
494                #~ elec = Electrode()
495                #~ elec.load_from_db(id_electrode)
496                #~ self.list_elec.append(elec)
497                #~ self.sig = concatenate((self.sig , elec.signal))
498                #~ self.fs = elec.fs
499       
500        #~ self.t = arange(self.sig.size)/self.fs
501        #~ self.pos_spike = [ ]
502        #~ self.sig_f = [ ]
503        #~ self.waveform = [ ]
504        #~ self.waveform_projected = [ ]
505        #~ self.cluster = [ ]
506        #~ self.waveform_size = None
507        #~ self.oversampling = None
508
509
510       
511    def computeFiltering(self) :
512        m = self.widgetMultimethods['Filtering'].get_method()
513        kargs = self.widgetMultimethods['Filtering'].get_dict()
514       
515        self.anaSigFilteredList = [ ]
516        for i in range(len( self.anaSigList )):
517            self.anaSigFilteredList.append( m.compute( self.anaSigList[i] , **kargs) )
518
519       
520       
521    def computeDetection(self) :
522        m = self.widgetMultimethods['Detection'].get_method()
523        kargs = self.widgetMultimethods['Detection'].get_dict()
524       
525        self.spikeSign = kargs['sign']
526        self.left_sweep = kargs['left_sweep']
527        self.right_sweep = kargs['right_sweep']
528        self.spikePosistion = m.compute(self.anaSigFilteredList, **kargs)
529       
530        self.sorted = zeros(self.spikePosistion.size, dtype = 'i')
531       
532       
533    def computeExtraction(self) :
534        m = self.widgetMultimethods['Extraction'].get_method()
535        kargs = self.widgetMultimethods['Extraction'].get_dict()
536       
537        self.waveforms = m.compute(self.anaSigFilteredList, self.spikePosistion,self.spikeSign, left_sweep = self.left_sweep , right_sweep = self.right_sweep)
538       
539    def computeProjection(self) :
540        m = self.widgetMultimethods['Projection'].get_method()
541        kargs = self.widgetMultimethods['Projection'].get_dict()
542       
543        self.projected = m.compute( self.waveforms, self.anaSigFilteredList[0].sampling_rate, **kargs)
544   
545       
546    def computeClustering(self) :
547        m = self.widgetMultimethods['Clustering'].get_method()
548        kargs = self.widgetMultimethods['Clustering'].get_dict()
549       
550        self.sorted = m.compute( self.projected , self.spikePosistion , **kargs )
551       
552   
553       
554    def recomputeAllSteps(self) :
555        self.computeFiltering()
556        self.computeDetection()
557        self.computeExtraction()
558        self.computeProjection()
559        self.computeClustering()
560       
561       
562    #~ def save_to_db(self) :
563        #~ n_cluster = unique(self.cluster).size
564        #~ waveform_size = self.param_database.get_one_param('waveform_size')
565        #~ oversampling = self.param_database.get_one_param('oversampling')
566       
567        #~ if self.id_electrode is not None :
568            #~ # mode one electrode
569           
570            #~ # delete old spiketrain and spike in database
571            #~ id_spiketrains, = sql('SELECT id_spiketrain FROM spiketrain WHERE id_electrode = %s'  , self.id_electrode)
572            #~ for id_spiketrain in id_spiketrains :
573                #~ sptr = SpikeTrain()
574                #~ sptr.id_spiketrain = id_spiketrain
575                #~ sptr.id_principal = id_spiketrain
576                #~ sptr.delete_from_db_and_child(dict_hierarchic_class )
577           
578            #~ #create new ones
579            #~ for n,cl in enumerate(unique(self.cluster)) :
580               
581                #~ sptr = SpikeTrain()
582                #~ sptr.id_trial = self.elec.id_trial
583                #~ sptr.id_electrode = self.elec.id_electrode
584                #~ sptr.id_cell = None
585                #~ sptr.fs = self.elec.fs
586                #~ sptr.shift_t0 =  self.elec.shift_t0
587                #~ sptr.oversampling = oversampling
588                #~ sptr.f_low = None
589                #~ sptr.f_hight = None
590                #~ sptr.label = u''
591                #~ sptr.coment = u''
592                #~ id_spiketrain = sptr.save_to_db()
593               
594                #~ pos = self.pos_spike[self.cluster== cl]
595                #~ isi = r_[diff(pos)/float(self.fs) , Inf]
596                #~ if self.param_database.get_one_param('save_filtered_waveform') :
597                    #~ fil = self.multiMethod_filtering.get_method()
598                    #~ karg = self.multiMethod_filtering.get_dict()
599                    #~ sig_f = fil.compute(self.sig , self.fs , **karg)
600                #~ else :
601                    #~ sig_f = self.sig
602                #~ waveform = waveform_extraction(pos,sig_f, self.fs , waveform_size,oversampling)
603                #~ for s in range(len(pos)) :
604                    #~ sp = Spike()
605                    #~ sp.id_spiketrain = id_spiketrain
606                    #~ sp.id_electrode = self.elec.id_electrode
607                    #~ sp.pos = pos[s]
608                    #~ sp.val_max = sig_f[pos[s]]
609                    #~ sp.waveform = squeeze(waveform[s,:])
610                    #~ sp.isi = isi[s]
611                    #~ sp.save_to_db()
612       
613        #~ else:
614            #~ # mode all electrode on same serie
615           
616            #~ # delete old spiketrain and spike in database
617            #~ query = """
618                    #~ SELECT spiketrain.id_spiketrain
619                    #~ FROM spiketrain , electrode , trial
620                    #~ WHERE
621                    #~ trial.id_trial = electrode.id_trial
622                    #~ AND electrode.id_electrode = spiketrain.id_electrode
623                    #~ AND trial.id_serie = %s
624                    #~ AND electrode.num_channel = %s
625                    #~ """
626            #~ id_spiketrains, = sql(query  , (self.id_serie , self.num_channel ))
627            #~ for id_spiketrain in id_spiketrains :
628                #~ sptr = SpikeTrain()
629                #~ sptr.id_spiketrain = id_spiketrain
630                #~ sptr.id_principal = id_spiketrain
631                #~ sptr.delete_from_db_and_child(dict_hierarchic_class )
632           
633            #~ #create new cells, spiketrain et spike
634            #~ for n,cl in enumerate(unique(self.cluster)) :
635                #~ cell = Cell()
636                #~ cell.id_serie = self.id_serie
637                #~ cell.info = u''
638                #~ cell.name = u'Cell %s NumChannel %s' %( n+1 , self.num_channel)
639                #~ id_cell = cell.save_to_db()
640               
641                #~ start = 0
642                #~ for e,elec in enumerate(self.list_elec):
643                    #~ sptr = SpikeTrain()
644                    #~ sptr.id_trial = elec.id_trial
645                    #~ sptr.id_electrode = elec.id_electrode
646                    #~ sptr.id_cell = id_cell
647                    #~ sptr.fs = elec.fs
648                    #~ sptr.shift_t0 =  elec.shift_t0
649                    #~ sptr.oversampling = oversampling
650                    #~ sptr.f_low = None
651                    #~ sptr.f_hight = None
652                    #~ sptr.label = u''
653                    #~ sptr.coment = u''
654                    #~ id_spiketrain = sptr.save_to_db()
655                   
656                    #~ pos = self.pos_spike[self.cluster== cl]
657                    #~ pos = pos[ (pos>= start) & (pos<start + elec.signal.size)]
658                    #~ pos = pos - start
659                    #~ isi = r_[diff(pos)/float(elec.fs) , Inf]
660                    #~ if self.param_database.get_one_param('save_filtered_waveform') :
661                        #~ fil = self.multiMethod_filtering.get_method()
662                        #~ karg = self.multiMethod_filtering.get_dict()
663                        #~ sig_f = fil.compute(elec.signal , elec.fs , **karg)
664                    #~ else :
665                        #~ sig_f = elec.signal
666                    #~ waveform = waveform_extraction(pos,sig_f, self.fs , waveform_size,oversampling)
667                    #~ for s in range(len(pos)) :
668                        #~ sp = Spike()
669                        #~ sp.id_spiketrain = id_spiketrain
670                        #~ sp.id_electrode = elec.id_electrode
671                        #~ sp.pos = pos[s]
672                        #~ sp.val_max = sig_f[pos[s]]
673                        #~ sp.waveform = squeeze(waveform[s,:])
674                        #~ sp.isi = isi[s]
675                        #~ sp.save_to_db()
676
677                    #~ start += elec.signal.size
678                   
679       
680    #~ def reload_from_db(self) :
681        #~ if self.id_electrode is not None :
682            #~ # mode one electrode
683            #~ id_spiketrains, = sql('SELECT id_spiketrain FROM spiketrain WHERE id_electrode = %s'  , self.id_electrode)
684            #~ self.pos_spike = array([ ],dtype='i')
685            #~ self.cluster = array([ ],dtype='i')
686            #~ for i,id_spiketrain in enumerate(id_spiketrains) :
687                #~ sptr = SpikeTrain()
688                #~ sptr.load_from_db(id_spiketrain)
689                #~ pos = sptr.pos_spike()
690                #~ self.pos_spike = concatenate((self.pos_spike , pos))
691                #~ self.cluster = concatenate((self.cluster , i*ones((len(pos)) , dtype = 'i') ))
692        #~ else:
693            #~ # mode all electrode on same serie
694            #~ query = """
695                    #~ SELECT spiketrain.id_spiketrain , cell.id_cell , electrode.id_electrode
696                    #~ FROM spiketrain , electrode , trial , cell
697                    #~ WHERE
698                    #~ trial.id_trial = electrode.id_trial
699                    #~ AND electrode.id_electrode = spiketrain.id_electrode
700                    #~ AND cell.id_cell = spiketrain.id_cell
701                    #~ AND trial.id_serie = %s
702                    #~ AND electrode.num_channel = %s
703                    #~ ORDER BY cell.id_cell
704                    #~ """
705            #~ id_spiketrains,id_cells,id_electrodes = sql(query  , (self.id_serie , self.num_channel ))
706            #~ self.pos_spike = array([ ],dtype='i')
707            #~ self.cluster = array([ ],dtype='i')
708            #~ n_cluster = unique(id_cells).size
709           
710            #~ for i,id_spiketrain in enumerate(id_spiketrains) :
711                #~ id_cell,id_electrode = id_cells[i],id_electrodes[i]
712                #~ start = 0
713                #~ for e,elec in enumerate(self.list_elec):
714                    #~ if elec.id_electrode == id_electrode : break
715                    #~ start += elec.signal.size
716
717                #~ sptr = SpikeTrain()
718                #~ sptr.load_from_db(id_spiketrain)
719                #~ pos = sptr.pos_spike()+start
720                #~ self.pos_spike = concatenate((self.pos_spike , pos))
721               
722                #~ cluster = where(id_cell == unique(id_cells))[0]
723                #~ self.cluster = concatenate((self.cluster , cluster*ones((len(pos)) , dtype = 'i') ))
724           
725
726
727
728
729
730
731
732
733
734class SpikeSorting(QDialog) :
735    """
736    Scroll area resazible for stacking matplotlib canvas
737   
738    several modes :
739                                - spikedetection/spikesorting on recording point (and its group)
740                                - spikesorting on a list spiketrain
741                                -
742   
743   
744   
745    """
746    def __init__(self  , parent = None ,
747                            metadata =None,
748                            session = None,
749                            globalApplicationDict = None,
750                           
751                            anaSigList = None,
752                           
753                            ):
754        QDialog.__init__(self, parent)
755        self.metadata = metadata
756        self.session = session
757        self.globalApplicationDict = globalApplicationDict
758
759        mainLayout = QVBoxLayout()
760        self.setLayout(mainLayout)
761
762        self.tab = TabSpikeSorting(metadata = self.metadata,
763                                                                    session = self.session,
764                                                                    globalApplicationDict= self.globalApplicationDict,
765                                                                   
766                                                                    anaSigList = anaSigList,
767                                                                   
768                                                                    )
769       
770        mainLayout.addWidget(self.tab)
771       
772        self.plotWidget = { }
773        for name, module, plotWidget in steps:
774           
775            v = self.tab.vboxes[name]
776            but = QPushButton('Compute %s'%name)
777            v.addWidget(but)
778            #~ self.connect(but , SIGNAL('clicked()') , getattr(self , 'compute%s'%name) )
779            self.connect(but , SIGNAL('clicked()') , self.computeAStep)
780
781            h = self.tab.hboxes[name]
782            self.plotWidget[name] = plotWidget(parent = self)
783            h.addWidget(self.plotWidget[name], 3)
784
785
786
787        #~ self.hboxes = { }
788        #~ self. = { }
789        #~ self.widgetMultimethods = { }
790           
791       
792
793    def computeAStep(self, ):
794        name = self.sender().text()
795        name = str(name.replace('Compute ', ''))
796        print 'compute', name
797       
798        # launch computation
799        getattr(self.tab , 'compute%s'%name)( )
800       
801       # refresh plot
802        self.plotWidget[name].refresh()
803       
804
805    #~ def computeFiltering(self) :
806        #~ print 'computeFiltering'
807        #~ self.tab.computeFiltering()
808       
809        #~ self.plotWidget[name]
810       
811
812    #~ def computeDetection(self) :
813        #~ print 'computeDetection'
814        #~ self.tab.computeDetection()
815       
816       
817    #~ def computeExtraction(self) :
818        #~ self.tab.computeExtraction()
819       
820       
821    #~ def computeProjection(self) :
822        #~ self.tab.computeProjection()
823   
824       
825    #~ def computeClustering(self) :
826        #~ self.tab.computeClustering()
827
828
829
830
831
832
833
Note: See TracBrowser for help on using the browser.