Source code for mmf.utils.mmf_plot

r"""Plotting tools.

This module includes some additional tools for making publication
quality figures for LaTeX documents.  The main goal is for good
control over sizing, fonts, etc.

To Do
-----
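.. todo:: Fix this bug: the following call to :meth:`Figure.adjust` does
   not work properly.

   >>> import matplotlib.pyplot as plt
   >>> from mmf.utils.mmf_plot import Figure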
   >>> fig = Figure()
   >>> plt.plot([0,1],[1,1.118])
   >>> fig.adjust()

   This has something to do with the yticks.  If they are set, then
   things work fine:

   >>> plt.clf()
   >>> plt.plot([0,1],[1,1.118])
   >>> plt.yticks(plt.yticks()[0])
   >>> fig.adjust()

"""
from __future__ import division

import logging

import numpy as np
import scipy.stats
import scipy as sp
import matplotlib.collections
import matplotlib.artist
import matplotlib.pyplot as plt

from mmf.objects import StateVars, process_vars
from mmf.objects import Computed, ClassVar, Excluded

# Floating-point characteristics of the builtin `float`, as reported by numpy.
_FINFO = np.finfo(float)
# Machine epsilon of `float`; used as the default `factor` of
# `Figure._shrink_bb`.
_EPS = _FINFO.eps

def plot_errorbars(x, y, dx=None, dy=None, colour='', linestyle='',
                   pointstyle='.', barwidth=0.5, **kwargs):
    r"""Plot the points `(x, y)` and draw error bars one at a time.

    Parameters
    ----------
    x, y : array-like
       Coordinates of the data points.  They are converted with
       :func:`numpy.asarray`, so plain lists are accepted.
    dx, dy : array-like or float, optional
       Half-lengths of the horizontal (`dx`) and vertical (`dy`) error
       bars.  A bar runs from `x - dx` to `x + dx` (resp. `y - dy` to
       `y + dy`).  If `None`, the corresponding bars are not drawn.
    colour : str
       Colour code appended to every matplotlib format string.
    linestyle, pointstyle : str
       Format codes for the data curve.  If both are empty, the data is
       only drawn (as dots) when no error bars are requested.
    barwidth : float
       Line width (`lw`) used for the error bars.
    **kwargs
       Passed to the :func:`matplotlib.pyplot.plot` call that draws the
       data (not to the calls that draw the bars).
    """
    # Arrays make `x + dx` an elementwise sum even if lists are passed.
    x = np.asarray(x)
    y = np.asarray(y)

    if pointstyle != '' or linestyle != '':
        plt.plot(x, y, pointstyle + linestyle + colour, **kwargs)
    elif dx is None and dy is None:
        # Plot points if both error bars are not drawn.
        plt.plot(x, y, '.' + colour, **kwargs)

    if dx is not None:
        xmax = x + dx
        xmin = x - dx
    if dy is not None:
        ymax = y + dy
        ymin = y - dy

    # The bars are drawn point by point (x bar, then y bar) so that the
    # order of the `plot` calls is the same as it has always been.
    # `range` (not `xrange`) works on both Python 2 and Python 3.
    for n in range(len(x)):
        if dx is not None:
            plt.plot([xmin[n], xmax[n]], [y[n], y[n]],
                     '-|' + colour, lw=barwidth)
        if dy is not None:
            plt.plot([x[n], x[n]], [ymin[n], ymax[n]],
                     '-_' + colour, lw=barwidth)
def _plot_err_columns(x, y, yerr, xerr, plot_axis, kwarg):
    """Draw one errorbar curve per column of `y` once `plot_axis` (the
    axis of `y` that runs along `x`) has been swapped to the front."""
    y = y.swapaxes(0, plot_axis)
    if yerr is not None:
        yerr = yerr.swapaxes(0, plot_axis)
    Nx, Ny = y.shape
    for n in range(Ny):
        plt.errorbar(x, y[:, n], xerr=xerr,
                     yerr=None if yerr is None else yerr[:, n],
                     **kwarg)


def plot_err(x, y, yerr=None, xerr=None, **kwarg):
    """Plot x vs. y with errorbars.

    The following cases are handled, in this order:

    * `x` 1D and `y` 1D: a single :func:`matplotlib.pyplot.errorbar`
      call.
    * `x` 1D and `y` with more than one dimension: the first axis of
      `y` whose length equals `len(x)` is taken to run along `x` and
      one curve is drawn for each index of the other axis.
    * `x` with a single axis of length greater than one (for example a
      column vector): `x` is flattened and `y` is treated as above,
      using the long axis of `x`.
    * anything else: a plain :func:`matplotlib.pyplot.plot` call
      without error bars.

    Parameters
    ----------
    x, y : array
       Data to plot; both must provide `shape`.
    yerr, xerr : array, optional
       Error bars.  In the multi-curve cases `yerr` must have the same
       shape as `y`, while `xerr` is passed unchanged to every curve.
    **kwarg
       Passed on to the matplotlib plotting call.
    """
    if 1 == len(x.shape) and 1 == len(y.shape):
        plt.errorbar(x, y, yerr=yerr, xerr=xerr, **kwarg)
    elif 1 == len(x.shape) and 1 < len(y.shape):
        # This branch used to iterate over the integer `Ny` (a
        # TypeError) and silently dropped the error bars; it now shares
        # the column-wise code with the branch below.
        plot_axis = np.where(np.array(y.shape) == len(x))[0][0]
        _plot_err_columns(x, y, yerr, xerr, plot_axis, kwarg)
    elif max(x.shape) == np.prod(x.shape):
        plot_axis = np.argmax(x.shape)
        _plot_err_columns(x.ravel(), y, yerr, xerr, plot_axis, kwarg)
    else:
        plt.plot(x, y, **kwarg)
def error_line(x, y, dy, fgc='k', bgc='w', N=20, fill=True):
    """Draw the curve (x, y) with its gaussian error `dy` shown as
    shading that fades out at 5 `dy`.

    The range from 0 to 5 standard deviations is split into `N - 1`
    steps.  Each step is drawn with an opacity equal to the normal
    density at that many standard deviations relative to the density
    at zero.  With `fill` the bands between successive steps are
    filled in the foreground colour `fgc`; otherwise only the bounding
    curves are drawn, the lower one in a blend of `fgc` and the
    background colour `bgc`.
    """
    gaussian = sp.stats.norm().pdf
    as_rgb = plt.matplotlib.colors.ColorConverter().to_rgb
    background = np.array(as_rgb(bgc))
    foreground = np.array(as_rgb(fgc))
    peak = gaussian(0.0)

    # Edges of the previously drawn band, starting on the curve itself.
    upper_prev = y
    lower_prev = y
    for nsigma in np.linspace(0, 5, N)[1:]:
        upper = y + dy*nsigma
        lower = y - dy*nsigma
        weight = gaussian(nsigma)/peak

        if not fill:
            plt.plot(x, upper, color=foreground, alpha=weight)
            plt.plot(x, lower,
                     color=foreground*weight + (1.0 - weight)*background)
        else:
            # Closed outlines: forward along one edge, back along the other.
            outline_x = np.hstack((x, np.flipud(x)))
            outline_y = np.hstack((upper_prev, np.flipud(upper)))
            plt.fill(outline_x, outline_y,
                     fc=foreground, ec=foreground, lw=0, alpha=weight)
            outline_x = np.hstack((x, np.flipud(x)))
            outline_y = np.hstack((lower_prev, np.flipud(lower)))
            plt.fill(outline_x, outline_y,
                     fc=foreground, ec=foreground, lw=0, alpha=weight)

        lower_prev = lower
        upper_prev = upper
class LaTeXPlotProperties(StateVars):
    r"""Instances of this class provide a description of properties of
    a plot based on numbers extracted from a LaTeX file.  Insert
    the following code into the section where the plot is to appear in
    order to extract the appropriate parameters and then use the
    reported values to initialize this class::

       \showthe\textwidth
       \showthe\columnwidth
       \showthe\baselineskip

    .. note:: We assume that the document is typeset using the
       Computer Modern fonts.
    """
    _state_vars = [
        # The descriptions containing LaTeX commands must be raw strings:
        # in a normal string the "\t" of \textwidth is a tab and the "\b"
        # of \baselineskip is a backspace.
        ('textwidth_pt', 332.89723,
         r"From LaTeX \showthe\textwidth"),
        ('columnwidth_pt', 332.89723,
         r"From LaTeX \showthe\columnwidth"),
        ('baselineskip_pt', 12.0,
         r"From LaTeX \showthe\baselineskip"),
        ('tick_fontsize', 'footnotesize',
         "Ticks etc. will be typeset in this font"),
        ('usetex', True,
         r"""If `True`, then LaTeX will be used to typeset labels
         etc.  Otherwise, labels etc. will be left as plain text
         that can be replaced with the ``\psfrag{}{}`` command in
         the LaTeX file.

         As of matplotlib version 1.0.1, psfrag replacements do not
         work, so the default is now to use LaTeX."""),

        # The following are "constants" that you should typically not
        # have to adjust unless you use a different font package.
        # (Raw string: "\u" starts a unicode escape on Python 3.)
        ('font_info', {'euler': ('zeur', r'\usepackage{eulervm}')}),
        ('font', {'family': 'serif',
                  'serif': ['computer modern roman'],
                  'sans-serif': ['computer modern sans serif'],
                  'monospace': ['computer modern typewriter']},
         r"`dict` of args passed to `matplotlib.rc('font')"),
        # NOTE(review): 'font' is declared a second time here with
        # different fonts.  Both entries are kept as found; which one
        # `process_vars` uses is not visible from this file -- confirm.
        ('font', {'family': 'serif',
                  'serif': ['palatino'],
                  'sans-serif': ['bera sans serif'],
                  'monospace': ['computer modern typewriter']},
         r"`dict` of args passed to `matplotlib.rc('font')"),
        ('latex_preamble', [r"\usepackage{amsmath}"],
         r"""List of strings to add to LaTeX preamble.  Add any
         ``\usepackage{}`` commands here.

         .. note:: Don't forget to use raw strings to prevent
            escaping of characters.  Thus use something like the
            default value: `[r"\usepackage{amsmath}"]`"""),
        ('latex_preview', True,
         "If `True`, use LaTeX preview package"),
        ('golden_mean', (np.sqrt(5) - 1)/2),
        ('font_size_pt', 10),
        ('font_factors', {'small': 9/10,
                          'footnotesize': 8/10},
         """Font size reduction factors for latex fonts."""),

        # Some units.  These can appear in expressions.
        ('inches_per_pt', 1.0/72.27),
        ('inches', 1.0),
        'pt=inches_per_pt',

        ('textwidth', Computed),
        ('columnwidth', Computed),
        ('baselineskip', Computed),
        ('tick_font', Computed),
        ]
    process_vars()

    def __init__(self, *v, **kw):
        """Compute the derived quantities: the three lengths converted
        from points to inches and the tick font size in points.

        NOTE(review): the arguments are not used here; presumably the
        `StateVars` machinery has already assigned the state variables
        before this is called -- confirm against `mmf.objects`.
        """
        self.textwidth = self.textwidth_pt*self.inches_per_pt
        self.columnwidth = self.columnwidth_pt*self.inches_per_pt
        self.baselineskip = self.baselineskip_pt*self.inches_per_pt
        self.tick_font = self.font_size_pt*self.font_factors[
            self.tick_fontsize]

    def initialize_matplotlib(self):
        r""":class:`Figure` calls this.

        Applies these properties to the global matplotlib rc settings:
        LaTeX usage, font families, LaTeX preamble and preview, base
        font size and the PostScript font type.
        """
        matplotlib.rc('text', usetex=self.usetex)
        matplotlib.rc('font', **self.font)
        matplotlib.rc('text.latex',
                      preamble=self.latex_preamble,
                      preview=self.latex_preview,
                      )
        matplotlib.rc('font', size=self.font_size_pt)
        # Use TT fonts
        matplotlib.rc('ps', fonttype=42)


# Default global instance.
# Shared instance with the default properties; `Figure.__init__` uses it
# whenever a figure's `plot_properties` is None.
_PLOT_PROPERTIES = LaTeXPlotProperties()
def xticks(ticks):
    """Put the x ticks at `ticks`, using the very same values as the
    tick labels.

    Supplying the labels explicitly replaces the ticks with real text
    so that psfrag works.  There was an API change somewhere along the
    line that broke this...
    """
    labels = ticks
    plt.xticks(ticks, labels)
def yticks(ticks):
    """Put the y ticks at `ticks`, using the very same values as the
    tick labels."""
    labels = ticks
    plt.yticks(ticks, labels)
class Figure(StateVars):
    r"""This class represents a single figure and allows customization
    of properties, as well as providing plotting facilities.

    Notes
    -----
    Units are either pts (for fonts) or inches (for linear
    measurements).

    Examples
    --------
    Here is an example of a figure suitable for a half of a page in a
    normal LaTeX book.  First we run the following file through
    LaTeX::

       \documentclass{book}
       \begin{document}
       \showthe\textwidth
       \showthe\columnwidth
       \showthe\baselineskip
       \end{document}

    This gives::

       > 345.0pt.
       l.3 \showthe\textwidth

       ?
       > 345.0pt.
       l.4 \showthe\columnwidth

       ?
       > 12.0pt.
       l.5 \showthe\baselineskip

    .. plot::
       :include-source:

       x = np.linspace(0,1.01,100)
       y = np.sin(x)
       plot_prop = LaTeXPlotProperties(textwidth_pt=345.0,
                                       columnwidth_pt=345.0,
                                       baselineskip_pt=12.0)
       fig = Figure(filename='tst_book.eps',
                    width='0.5*textwidth',
                    plot_properties=plot_prop)
       plt.plot(x, y, label="r'\sin(x)'")
       plt.axis([-0.02,1.02,-0.02,1.02])
       plt.ylabel(
           r'$\int_{0}^{x}\left(\frac{\cos(\tilde{x})}{1}\right)d{\tilde{x}}$')
       #fig.savefig()

    Here is another example using a two-column article::

       \documentclass[twocolumn]{article}
       \begin{document}
       \showthe\textwidth
       \showthe\columnwidth
       \showthe\baselineskip
       \end{document}

    This gives::

       > 469.0pt.
       l.3 \showthe\textwidth

       ?
       > 229.5pt.
       l.4 \showthe\columnwidth

       ?
       > 12.0pt.
       l.5 \showthe\baselineskip

    .. plot::
       :include-source:

       x = np.linspace(0,1.01,100)
       y = np.sin(x)
       plot_prop = LaTeXPlotProperties(textwidth_pt=469.0,
                                       columnwidth_pt=229.5,
                                       baselineskip_pt=12.0)
       fig = Figure(filename='tst_article.eps',
                    plot_properties=plot_prop)
       plt.plot(x, y, label="r'\sin(x)'")
       plt.axis([-0.02,1.02,-0.02,1.02])
       plt.ylabel(
           r'$\int_{0}^{x}\left(\frac{\cos(\tilde{x})}{1}\right)d{\tilde{x}}$')
       #fig.savefig()
    """
    _state_vars = [
        ('num', None, "Figure number"),
        ('filename', None, "Filename for figure."),
        ('width', 'columnwidth',
         r"Expression involving 'columnwidth' and/or 'textwidth'"),
        ('height', 1.0, r"Fraction of `golden_mean*width`"),
        ('plot_properties', None),
        ('axes_dict', dict(labelsize='medium')),
        ('tick_dict', dict(labelsize='small')),
        ('legend_dict', dict(fontsize='medium')),
        ('margin_factors', dict(top=0.5, left=2.8, bot=3, right=0.5),
         """These allocate extra space for labels etc."""),
        ('autoadjust', False,
         r"""Attempt to autoadjust for labels, otherwise you can do
         this manually by calling :meth:`adjust`."""),
        ('figure_manager', Computed),
        ('_inset_axes', Computed,
         "Set of axes to be excluded from adjustment"),
        ('figures', ClassVar({}), "Dictonary of computed figures."),
        ('on_draw_id', Excluded(None),
         "Id associated with 'on_draw' event"),
        ('footnotesize', Computed,
         "Size matching corresponding LaTeX font"),
        ('small', Computed,
         "Size matching corresponding LaTeX font"),
        ]
    process_vars()

    def __init__(self, *v, **kw):
        """Apply the plot properties and, when `num` or `filename` is
        among the keyword arguments, create the matplotlib figure and
        its main axes.

        NOTE(review): presumably the `StateVars` machinery calls this
        with the names of the variables that were set in `kw` --
        confirm against `mmf.objects`.
        """
        if self.plot_properties is None:
            self.plot_properties = _PLOT_PROPERTIES

        pp = self.plot_properties
        pp.initialize_matplotlib()
        self._inset_axes = set()

        # Expose the LaTeX font sizes (in points) as attributes.
        for _size in pp.font_factors:
            setattr(self, _size, pp.font_size_pt*pp.font_factors[_size])

        if 'num' in kw or 'filename' in kw:
            # `width` is an expression in terms of the attributes of the
            # plot properties (e.g. '0.5*textwidth').
            # NOTE: `eval` runs arbitrary code, so `width` must never
            # come from an untrusted source.  A copy of the attribute
            # dictionary is passed because `eval` inserts `__builtins__`
            # into the globals it is given.
            width = eval(self.width, dict(pp.__dict__))
            fig_width = width
            fig_height = self.height*width*pp.golden_mean

            # Margins are multiples of the font size (in inches).
            size = pp.font_size_pt*pp.inches_per_pt
            # top space = 1/2 font
            space_top = self.margin_factors['top']*size
            space_left = self.margin_factors['left']*size
            space_bottom = self.margin_factors['bot']*size
            space_right = self.margin_factors['right']*size

            # Compute axes size:
            axes_left = space_left/fig_width
            axes_bottom = space_bottom/fig_height
            axes_width = 1.0 - (space_left + space_right)/fig_width
            axes_height = 1.0 - (space_bottom + space_top)/fig_height
            axes_size = [axes_left, axes_bottom,
                         axes_width, axes_height]

            plt.rc('font', size=pp.font_size_pt)
            plt.rc('axes', **self.axes_dict)
            plt.rc('xtick', **self.tick_dict)
            plt.rc('ytick', **self.tick_dict)
            plt.rc('legend', **self.legend_dict)

            plt.figure(num=self.num,
                       figsize=(fig_width, fig_height))
            self.figure_manager = plt.get_current_fig_manager()
            self.num = self.figure_manager.num
            self.figures[self.num] = self.figure_manager
            plt.clf()
            a = plt.axes(axes_size)

            # Automatic adjustment on every draw is deliberately
            # switched off (note the `and False`); call :meth:`adjust`
            # instead.  The code is kept for when it is re-enabled.
            if self.autoadjust and False:
                # This makes the axis full frame.  Use adjust to shrink.
                a.set_position([0, 0, 1, 1])
                self.start_adjusting()
            elif False:
                self.stop_adjusting()

    def activate(self):
        """Make this the current pyplot figure and return it."""
        return plt.figure(self.num)

    def start_adjusting(self):
        """Connect :meth:`on_draw` to the canvas 'draw_event', first
        dropping any connection made earlier."""
        if self.on_draw_id:
            self.figure_manager.canvas.mpl_disconnect(self.on_draw_id)
        self.on_draw_id = self.figure_manager.canvas.mpl_connect(
            'draw_event', self.on_draw)

    def stop_adjusting(self):
        """Disconnect :meth:`on_draw` from the canvas (if connected)."""
        if self.on_draw_id:
            self.figure_manager.canvas.mpl_disconnect(self.on_draw_id)
        self.on_draw_id = 0

    def new_inset_axes(self, rect):
        r"""Return a new axes set inside the main axis.

        The new axes are remembered so that :meth:`adjust` leaves them
        alone.

        Parameters
        ----------
        rect : [left, bottom, width or right, height or top]
           This is the rectangle for the new axes (the labels etc. will be
           outside).  Coordinates may be either floating point numbers which
           specify the location of the inset in terms of a fraction between 0
           and 1 of the current axis.

           One may also specify the coordinates in the data units of the
           actual corners by specifying the data as an imaginary number.  This
           will be transformed into relative axis coordinates using the
           current axis limits (the subplot will not subsequently move).
           (Not implemented yet.)
        """
        ax = plt.axes(rect)
        self._inset_axes.add(ax)
        return ax

    def axis(self, *v, **kw):
        r"""Wrapper for :func:`pyplot.axis` function that applies the
        transformation to each axis (useful if :func:`pyplot.twinx` or
        :func:`pyplot.twiny` has been used)."""
        fig = self.figure_manager.canvas.figure
        for _a in fig.axes:
            _a.axis(*v, **kw)

    def adjust(self, full=True, padding=0.05):
        r"""Adjust the axes so that all text lies within the figure.
        Optionally, add some padding.

        Parameters
        ----------
        full : bool
           If `True`, every axes is first reset to fill the whole
           figure; the adjustment then shrinks it until the text fits.
        padding : float
           If positive, each adjusted axes (inset axes excluded) is
           finally shrunk about its centre by this fraction of its
           width and height.
        """
        plt.ioff()
        plt.figure(self.num)
        canvas = self.figure_manager.canvas
        # `fig` is needed below whatever the value of `full` (it used to
        # be bound only when `full` was true).
        fig = canvas.figure
        if full:
            # Reset axis to full size.
            for _a in fig.axes:
                _a.set_position([0, 0, 1, 1])

        # Let the draw below trigger the adjustment, then disconnect
        # again even if drawing fails.
        on_draw_id = canvas.mpl_connect('draw_event', self.on_draw)
        try:
            plt.ion()
            plt.draw()
        finally:
            canvas.mpl_disconnect(on_draw_id)

        adjustable_axes = [_a for _a in fig.axes
                           if _a not in self._inset_axes]

        if 0 < padding:
            for _a in adjustable_axes:
                bb_a = _a.get_position()
                dx = bb_a.width*padding/2
                dy = bb_a.height*padding/2
                bb_a.x0 += dx
                bb_a.x1 -= dx
                bb_a.y0 += dy
                bb_a.y1 -= dy
                _a.set_position(bb_a)

    @staticmethod
    def _shrink_bb(bb, factor=_EPS):
        r"""Shrink the bounding box bb by factor in order to prevent
        unneeded work due to rounding.

        Each side moves inwards by `factor` times the width (left and
        right) or the height (bottom and top).  `bb` is modified in
        place and returned.
        """
        p = bb.get_points()     # [[x0, y0], [x1, y1]]
        # `np.diff(p, axis=0)` is [[width, height]]; the first corner
        # moves by +factor*size and the second by -factor*size.
        p += factor*np.diff(p, axis=0)*np.array([[1], [-1]])
        bb.set_points(p)
        return bb

    def _adjust(self,
                logger=logging.getLogger("mmf.utils.mmf_plot.Figure._adjust")):
        r"""Adjust the axes to make sure all text is inside the box.

        Returns
        -------
        adjusted : bool
           `True` if the position of some axes was changed, in which
           case the figure should be redrawn and checked again.
        """
        fig = self.figure_manager.canvas.figure

        # The figure transform goes from relative coords->pixels and we
        # want the inverse of that.  (`transformed(t.inverted())` is what
        # `inverse_transformed(t)` did; the latter is gone from newer
        # matplotlib.)
        to_figure = fig.transFigure.inverted()
        bb_f = fig.get_window_extent().transformed(to_figure)
        logger.debug("Fig  bb %s" % (" ".join(str(bb_f).split()),))

        texts = []
        adjustable_axes = [_a for _a in fig.axes
                           if _a not in self._inset_axes]
        for _a in adjustable_axes:
            texts.extend(_a.texts)
            texts.append(_a.title)
            texts.extend(_a.get_xticklabels())
            texts.extend(_a.get_yticklabels())
            texts.append(_a.xaxis.get_label())
            texts.append(_a.yaxis.get_label())

        bboxes = []
        for t in texts:
            if not t.get_text():
                # Ignore empty text!
                continue
            bboxes.append(t.get_window_extent().transformed(to_figure))

        if not bboxes:
            # No text at all, so nothing can stick out of the figure.
            return False

        # this is the bbox that bounds all the bboxes, again in
        # relative figure coords
        bbox = self._shrink_bb(matplotlib.transforms.Bbox.union(bboxes))
        adjusted = False
        if not np.all([bb_f.contains(*c) for c in bbox.corners()]):
            # Adjust axes position: pull in each side by the amount
            # the text sticks out of the figure on that side.
            for _a in adjustable_axes:
                bb_a = _a.get_position()
                logger.debug("Text bb %s" % (" ".join(str(bbox).split()),))
                logger.debug("Axis bb %s" % (" ".join(str(bb_a).split()),))
                bb_a.x0 += max(0, bb_f.xmin - bbox.xmin)
                bb_a.x1 += min(0, bb_f.xmax - bbox.xmax)
                bb_a.y0 += max(0, bb_f.ymin - bbox.ymin)
                bb_a.y1 += min(0, bb_f.ymax - bbox.ymax)
                logger.debug("New  bb %s" % (" ".join(str(bb_a).split()),))
                _a.set_position(bb_a)
                adjusted = True
        return adjusted

    def on_draw(self, event, _adjusting=[False]):
        """We register this to perform processing after the figure is
        drawn, like adjusting the margins so that the labels fit.

        `_adjusting` is not meant to be passed: the mutable default is a
        flag shared by all calls which stops the redraws issued here
        from triggering the adjustment again.
        """
        fig = self.figure_manager.canvas.figure
        logger = logging.getLogger("mmf.utils.mmf_plot.Figure.on_draw")
        if _adjusting[0]:
            # Don't recurse!
            return

        if event is None:
            # If called interactively...
            import pdb
            pdb.set_trace()

        _adjusting[0] = True
        try:
            _max_adjust = 10
            adjusted = False
            for _n in range(_max_adjust):
                adjusted = self._adjust(logger=logger)
                if adjusted:
                    fig.canvas.draw()
                else:
                    break
            if adjusted:
                # Even after _max_adjust steps we still needed adjusting:
                logger.warning("Still need adjustment after %i steps"
                               % (_max_adjust,))
        finally:
            _adjusting[0] = False

    def adjust_axis(self, extents=None,
                    xl=None, xh=None, yl=None, yh=None,
                    extend_x=0.0, extend_y=0.0):
        """Set the limits of the current axes and return the result of
        the final :func:`pyplot.axis` call.

        Parameters
        ----------
        extents : optional
           If given, first passed to :func:`pyplot.axis`.
        xl, xh, yl, yh : float, optional
           Individual limits replacing the current ones.
        extend_x, extend_y : float
           The x (y) range is widened on both sides by this fraction of
           its length.
        """
        if extents is not None:
            plt.axis(extents)
        xl_, xh_, yl_, yh_ = plt.axis()
        if xl is not None:
            xl_ = xl
        if xh is not None:
            xh_ = xh
        if yl is not None:
            yl_ = yl
        if yh is not None:
            yh_ = yh
        plt.axis([xl_, xh_, yl_, yh_])
        dx = extend_x*(xh_ - xl_)
        dy = extend_y*(yh_ - yl_)
        return plt.axis([xl_ - dx, xh_ + dx,
                         yl_ - dy, yh_ + dy])

    def savefig(self, filename=None):
        """Save the figure to `filename` (default: `self.filename`),
        reporting progress on standard output."""
        if not filename:
            filename = self.filename
        print("Saving plot as %r..." % (filename,))
        plt.figure(self.num)
        plt.ion()               # Do this to ensure autoadjustments
        plt.draw()              # are made!
        plt.savefig(filename)
        print("Saving plot as %r. Done." % (filename,))

    def __del__(self):
        """Destructor: make sure we unregister the autoadjustor.

        NOTE(review): this only sets `autoadjust`; any disconnecting
        must then happen as a consequence of that assignment -- confirm.
        """
        self.autoadjust = False


# Here we monkeypatch mpl_toolkits.axes_grid.inset_locator to allow for
# independent x and y zoom factors.
def monkey_patch_inset_locator():
    r"""Replace `AnchoredZoomLocator.get_extent` so that the zoom of an
    inset may be different in the x and y directions.

    The replacement multiplies the (width, height) of the inset's view
    limits, as transformed by the parent's data transform, by
    `self.zoom`, which may therefore be a pair `(zoom_x, zoom_y)` as
    well as a single number.
    """
    import matplotlib.transforms
    try:
        # Current home of the class.
        from mpl_toolkits.axes_grid1.inset_locator import AnchoredZoomLocator
    except ImportError:
        # Fall back to the package this module originally imported from.
        from mpl_toolkits.axes_grid.inset_locator import AnchoredZoomLocator

    def get_extent(self, renderer):
        bb = matplotlib.transforms.TransformedBbox(
            self.axes.viewLim, self.parent_axes.transData)

        x, y, w, h = bb.bounds
        xd, yd = 0, 0

        fontsize = renderer.points_to_pixels(self.prop.get_size_in_points())
        pad = self.pad * fontsize

        # Elementwise product so that `zoom` may hold two factors.
        wh = np.array([w, h])
        return tuple((wh*self.zoom + 2*pad).tolist()
                     + [xd + pad, yd + pad])

    AnchoredZoomLocator.get_extent = get_extent
class ListCollection(matplotlib.collections.Collection):
    r"""A :class:`matplotlib.collections.Collection` that simply wraps
    a list of artists and draws them one after the other.

    Grouping the artists in a single collection allows all of them to
    be rasterized together.  Used by my custom :func:`contourf`
    function.
    """

    def __init__(self, collections, **kwargs):
        """Store `collections`; `kwargs` go to the base class."""
        super(ListCollection, self).__init__(**kwargs)
        self.set_collections(collections)

    def set_collections(self, collections):
        """Set the list of wrapped artists."""
        self._collections = collections

    def get_collections(self):
        """Return the list of wrapped artists."""
        return self._collections

    @matplotlib.artist.allow_rasterization
    def draw(self, renderer):
        """Draw each wrapped artist with `renderer`, in list order."""
        for member in self._collections:
            member.draw(renderer)
def contourf(*v, **kw):
    r"""Replacement for :func:`matplotlib.pyplot.contourf` that
    supports the `rasterized` keyword.

    The contours are computed by :func:`matplotlib.pyplot.contourf`
    (which receives all arguments), then removed from the axes and
    added back wrapped in a single :class:`ListCollection` carrying
    the `rasterized` setting.

    Returns
    -------
    contour_set
       The object returned by :func:`matplotlib.pyplot.contourf`.
    """
    was_interactive = matplotlib.is_interactive()
    # No redraws while the collections are being shuffled around.
    matplotlib.interactive(False)
    try:
        contour_set = plt.contourf(*v, **kw)
        for _c in contour_set.collections:
            _c.remove()
        collection = ListCollection(
            contour_set.collections,
            rasterized=kw.get('rasterized', None))
        ax = plt.gca()
        ax.add_artist(collection)
    finally:
        # Restore the interactive state even if plotting failed.
        matplotlib.interactive(was_interactive)
    return contour_set
def imcontourf(x, y, z, contours=None, *v, **kw):
    r"""Like :func:`matplotlib.pyplot.contourf` but does not actually
    find contours.  Just displays `z` using
    :func:`matplotlib.pyplot.imshow` which is much faster and uses
    exactly the information available.

    Parameters
    ----------
    x, y, z : array-like
       Assumes that `z` is ordered as `z[x,y]`.  If `x` and `y` have the
       same shape as `z`, then `x = x[:,0]` and `y = y[0,:]` are used.
       Otherwise, `z.shape == (len(x), len(y))`.  `x` and `y` must be
       equally spaced.
    contours
       Not used.
    *v, **kw
       Passed on to :func:`matplotlib.pyplot.imshow`; `aspect` is always
       set to ``'auto'``.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    z = np.asarray(z)

    # Reduce full coordinate grids to their 1D axes.
    if x.shape == z.shape:
        x = x[:, 0]
    if y.shape == z.shape:
        y = y[0, :]

    grid_shape = (len(x), len(y))
    assert z.shape[:2] == grid_shape
    # Equal spacing means that the second differences vanish.
    for abscissa in (x, y):
        assert np.allclose(np.diff(np.diff(abscissa)), 0)

    options = dict(kw, aspect='auto')
    bounds = (x[0], x[-1], y[0], y[-1])
    # imshow wants the image indexed as [row (y), column (x)].
    image = np.rollaxis(z, 0, 2)
    return plt.imshow(image, *v, origin='lower', extent=bounds, **options)
def plot3d(x, y, z, wireframe=False, **kw):
    r"""Wrapper to generate 3d surface plots.

    One-dimensional `x` and `y` are broadcast against each other to
    the shape of `z`.  With `wireframe` the data is drawn with
    `plot_wireframe`, otherwise with `plot_surface`; `kw` is passed on
    after filling in defaults for `cmap`, `rstride`, `cstride` (and
    `antialiased` for surfaces).  Returns whatever the plotting method
    returned.
    """
    # Move these out once fixed.
    from mpl_toolkits.mplot3d import Axes3D
    from matplotlib import cm

    if len(x.shape) == 1:
        x = x[:, None]
    if len(y.shape) == 1:
        y = y[None, :]
    if x.shape != z.shape:
        x = x + 0*y
    if y.shape != z.shape:
        y = y + 0*x

    for grid in (x, y):
        assert z.shape == grid.shape
    for grid in (x, y):
        assert np.allclose(np.diff(np.diff(grid)), 0)

    axes = plt.gcf().gca(projection='3d')

    kw.setdefault('cmap', cm.jet)
    if wireframe:
        render = axes.plot_wireframe
        defaults = (('rstride', 10), ('cstride', 10))
    else:
        render = axes.plot_surface
        defaults = (('rstride', 1), ('cstride', 1), ('antialiased', False))
    for key, value in defaults:
        kw.setdefault(key, value)

    surf = render(x, y, z, **kw)

    #fig.colorbar(surf, shrink=0.5, aspect=5)
    plt.draw_if_interactive()
    return surf
def plot3dmpl(X, Y, Z, zmin=-np.inf, zmax=np.inf,
              xlabel=None, ylabel=None,
              abs_parts=False, **kw):
    r"""Use MayaVi2 to plot the surface.

    Data with a non-zero imaginary part is drawn as two surfaces: the
    real part in green and the imaginary part in red.

    Parameters
    ----------
    X, Y, Z : array
       Passed to `mlab.surf` after `Z` has been limited to the range
       `[zmin, zmax]`.
    zmin, zmax : float
       Values of `Z` below `zmin` or above `zmax` are replaced by the
       respective limit.
    xlabel, ylabel : str, optional
       Axis labels.
    abs_parts : bool
       If `True`, then plot `abs(real)` and `-abs(imag)`.
    **kw
       Passed to `mlab.surf`; these take precedence over the colormap
       and opacity chosen here.

    Returns
    -------
    s
       The surface returned by `mlab.surf`, or the pair
       `(real, imaginary)` of surfaces for complex data.
    """
    from mayavi import mlab

    def draw(z, **defaults):
        # The caller's options override the defaults of each surface.
        options = dict(defaults)
        options.update(kw)
        return mlab.surf(X, Y, np.maximum(np.minimum(z, zmax), zmin),
                         **options)

    if np.any(np.iscomplex(Z)):
        if abs_parts:
            s = (draw(abs(Z.real), colormap='Greens', opacity=1.0),
                 draw(-abs(Z.imag), colormap='Reds', opacity=1.0))
        else:
            s = (draw(Z.real, colormap='Greens', opacity=0.5),
                 draw(Z.imag, colormap='Reds', opacity=0.5))
    else:
        s = draw(Z)

    mlab.axes()
    if xlabel:
        mlab.xlabel(xlabel)
    if ylabel:
        # This used to call `mlab.xlabel`, overwriting the x label.
        mlab.ylabel(ylabel)
    return s