root/trunk/src/plotting.py

Revision 186, 5.4 kB (checked in by apdavison, 4 weeks ago)

In the signals module:

  • renamed SpikeTrain.rescale() to SpikeTrain.relative_times(). This seems clearer to me since 'rescale' implies multiplication whereas we only do a subtraction.
  • added SpikeTrain.merge(), which adds the spike times from another spike train to this one.
  • fixed a bug in iterating over a SpikeList object.
  • added SpikeList.f1f0_ratios(). This might be a bit too vision-specific, but since oscillatory patterns in spike trains are not uncommon, I think it belongs here.

In the plotting module: minor improvement to the spacing between subplots, which is now proportional to the subplot size and may differ between the vertical and horizontal directions.

  • Property svn:executable set to *
Line 
1
2 import numpy #, pylab
3 import sys
4 from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
5 from matplotlib.figure import Figure
6 from matplotlib.lines import Line2D
7 """
8 Plotting helpers
9 """
10
def pylab_params(fig_width_pt=246.0,
                 ratio=(numpy.sqrt(5)-1.0)/2.0,  # golden-mean aspect ratio
                 text_fontsize=10, tick_labelsize=8):
    r"""
    Return a dictionary of matplotlib parameters for publication figures.

    fig_width_pt    : figure width in TeX points; obtain it in LaTeX with
                      \showthe\columnwidth.
    ratio           : height / width ratio of the figure (golden mean by
                      default).
    text_fontsize   : currently unused; no font-size parameter is set.
    tick_labelsize  : currently unused; no tick-size parameter is set.

    The returned dict has exactly two keys: 'text.usetex' (always False) and
    'figure.figsize' ([width, height] in inches). Presumably intended to be
    passed to matplotlib's rcParams.update -- confirm against callers.
    """
    # TODO: also accept widths in cm (1 cm = 1/2.54 inch).
    inches_per_tex_pt = 1.0/72.27  # one TeX point is 1/72.27 inch
    width = fig_width_pt*inches_per_tex_pt
    height = width*ratio

    params = dict()
    # LaTeX text rendering is kept off (the original note on this setting
    # mentions an SVG-output problem with usetex=True).
    params['text.usetex'] = False
    params['figure.figsize'] = [width, height]
    return params
38
39
def raster_plot(spike_list, output=None):
    """
    Draw a raster plot (one dot per spike) of *spike_list* with pylab.

    spike_list : object providing ``as_list_id_list_time()`` plus the
                 attributes ``t_start``, ``t_stop`` and ``N`` used for the axis
                 limits. Element 0 of the ``as_list_id_list_time()`` result is
                 plotted on the y-axis (neuron IDs) and element 1 on the x-axis
                 (spike times).
    output     : if not None, the figure is saved with ``pylab.savefig(output)``.

    The plot is drawn into pylab's current figure; the figure is not shown.
    Returns None.
    """
    import pylab  # local import, presumably to avoid importing pylab at module load

    data = spike_list.as_list_id_list_time()
    ids = data[0]
    times = data[1]
    pylab.plot(times, ids, '.')
    pylab.ylabel('neuron ID')
    # NOTE(review): the label says seconds; confirm the units of the spike
    # times and of t_start/t_stop.
    pylab.xlabel('time (s)')
    pylab.axis([spike_list.t_start, spike_list.t_stop, 0, spike_list.N])
    if output is not None:
        pylab.savefig(output)
51
52
53
54
def set_frame(ax, boollist, linewidth=2):
    """
    Draw only a chosen subset of the four sides of the frame of *ax*.

    ax        : matplotlib Axes to modify.
    boollist  : four flags in the order (left, bottom, right, top); a true
                flag means that side is drawn.
    linewidth : width of the drawn sides.

    If all four flags are true, the default frame is left untouched.
    Otherwise the default frame is switched off, and a black line in axes
    coordinates is added for each requested side.
    """
    assert len(boollist) == 4
    if all(boollist):
        return
    # (xdata, ydata) in axes coordinates, listed in boollist order.
    sides = (
        ([0, 0], [0, 1]),  # left
        ([0, 1], [0, 0]),  # bottom
        ([1, 1], [0, 1]),  # right (was [1, 0], [1, 1]: a copy of the top edge)
        ([0, 1], [1, 1]),  # top
    )
    # anti-aliased?
    ax.set_frame_on(False)
    for (xdata, ydata), draw in zip(sides, boollist):
        if draw:
            ax.add_line(Line2D(xdata, ydata, transform=ax.transAxes,
                               linewidth=linewidth, color='k'))
67
class SimpleMultiplot(object):
    """
    A figure consisting of multiple panels, all with the same datatype and
    the same x-range.

    Panels form an ``nrows`` x ``ncolumns`` grid. They are stored in
    ``self.axes`` (also reachable as ``self.all_panels``) in column-major
    order: index ``j*nrows + i`` is row ``i`` (counted from the top) of
    column ``j``. Panels handed out by ``next_panel`` and ``panel`` get a
    ``plot1`` attribute bound to the Axes plotting method selected by
    ``scaling``.
    """

    # (x scaling, y scaling) -> name of the Axes plotting method to use.
    _PLOT_FUNCTIONS = {
        ("linear", "linear"): "plot",
        ("log", "log"): "loglog",
        ("log", "linear"): "semilogx",
        ("linear", "log"): "semilogy",
    }

    def __init__(self, nrows, ncolumns, title="", xlabel=None, ylabel=None,
                 scaling=('linear', 'linear')):
        """
        nrows, ncolumns : grid dimensions.
        title           : text drawn at the top of the figure by finalise().
        xlabel          : x-axis label, set on the bottom panel of the first
                          column.
        ylabel          : label drawn vertically, centred in the left margin.
        scaling         : (x, y) pair, each 'linear' or 'log' (tuple or list).

        Raises ValueError (a subclass of the Exception raised previously)
        if *scaling* is not one of the supported combinations.
        """
        # Validate first so that a bad argument fails before the figure is built.
        try:
            self.plot_function = self._PLOT_FUNCTIONS[tuple(scaling)]
        except (KeyError, TypeError):
            raise ValueError("Invalid value for scaling parameter")

        self.fig = Figure()
        self.canvas = FigureCanvas(self.fig)
        self.axes = []
        self.all_panels = self.axes
        self.nrows = nrows
        self.ncolumns = ncolumns
        self.n = nrows*ncolumns
        self._curr_panel = 0
        self.title = title
        self._title_text = None  # title Text artist, created by finalise()

        # Margins and panel geometry, as fractions of the figure size.
        topmargin = 0.06
        rightmargin = 0.02
        bottommargin = 0.1
        leftmargin = 0.1
        # Gaps between panels are 10% of one grid slot in each direction.
        v_panelsep = 0.1*(1 - topmargin - bottommargin)/nrows
        h_panelsep = 0.1*(1 - leftmargin - rightmargin)/ncolumns
        panelheight = (1 - topmargin - bottommargin - (nrows-1)*v_panelsep)/nrows
        panelwidth = (1 - leftmargin - rightmargin - (ncolumns-1)*h_panelsep)/ncolumns
        assert panelheight > 0 and panelwidth > 0

        bottomlist = [bottommargin + i*v_panelsep + i*panelheight for i in range(nrows)]
        leftlist = [leftmargin + j*h_panelsep + j*panelwidth for j in range(ncolumns)]
        bottomlist.reverse()  # row 0 is the top row
        for left in leftlist:
            for bottom in bottomlist:
                ax = self.fig.add_axes([left, bottom, panelwidth, panelheight])
                set_frame(ax, [True, True, False, False])  # left and bottom sides only
                ax.xaxis.tick_bottom()
                ax.yaxis.tick_left()
                self.axes.append(ax)
        if xlabel:
            self.axes[self.nrows-1].set_xlabel(xlabel)
        if ylabel:
            self.fig.text(0.5*leftmargin, 0.5, ylabel,
                          rotation='vertical',
                          horizontalalignment='center',
                          verticalalignment='center')

    def _axes_without_xticklabels(self):
        """Return every panel except the bottom one of each column."""
        return [ax for k, ax in enumerate(self.axes)
                if k % self.nrows != self.nrows - 1]

    def finalise(self):
        """
        Adjustments to be made after all panels have been plotted.

        Draws the title (created on the first call; later calls only update
        its text, so repeated saves do not stack duplicate titles) and turns
        off the x tick labels of every panel except the bottom one of each
        column.
        """
        title_text = getattr(self, '_title_text', None)
        if title_text is None:
            self._title_text = self.fig.text(0.5, 0.99, self.title,
                                             horizontalalignment='center',
                                             verticalalignment='top')
        else:
            title_text.set_text(self.title)
        for ax in self._axes_without_xticklabels():
            ax.xaxis.set_ticklabels([])

    def save(self, filename):
        """Save/print the figure to file (calls finalise() first)."""
        self.finalise()
        self.canvas.print_figure(filename)

    def _prepare(self, ax):
        """Bind ``ax.plot1`` to the plotting method chosen by ``scaling``."""
        ax.plot1 = getattr(ax, self.plot_function)
        return ax

    def next_panel(self):
        """Return the current panel and advance, wrapping back to panel 0."""
        ax = self.axes[self._curr_panel]
        self._curr_panel += 1
        if self._curr_panel >= self.n:
            self._curr_panel = 0
        return self._prepare(ax)

    def panel(self, i):
        """Return panel i (column-major index)."""
        return self._prepare(self.axes[i])
149    
Note: See TracBrowser for help on using the browser.