r"""Plotting tools.
This module includes some additional tools for making publication
quality figures for LaTeX documents. The main goal is for good
control over sizing, fonts, etc.
To Do
-----
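.. todo:: Fix this bug: :meth:`Figure.adjust` misbehaves in the following
   example unless the y ticks have been set explicitly.
>>> import matplotlib.pyplot as plt
>>> from mmf.utils.mmf_plot import Figure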
>>> fig = Figure()
>>> plt.plot([0,1],[1,1.118])
>>> fig.adjust()
This has something to do with the yticks. If they are set, then
things work fine:
>>> plt.clf()
>>> plt.plot([0,1],[1,1.118])
>>> plt.yticks(plt.yticks()[0])
>>> fig.adjust()
"""
from __future__ import division
import logging
import numpy as np
import scipy.stats
import scipy as sp
import matplotlib.collections
import matplotlib.artist
import matplotlib.pyplot as plt
from mmf.objects import StateVars, process_vars
from mmf.objects import Computed, ClassVar, Excluded
# Machine limits for double-precision floats.  ``_EPS`` (the relative
# machine epsilon) is the default shrink factor of ``Figure._shrink_bb``.
_FINFO = np.finfo(np.float64)
_EPS = _FINFO.eps
def plot_errorbars(x, y, dx=None, dy=None, colour='', linestyle='',
                   pointstyle='.', barwidth=0.5, **kwargs):
    r"""Plot the points ``(x, y)`` with hand-drawn error bars.

    Parameters
    ----------
    x, y : array-like
       Point coordinates.  They are converted with :func:`numpy.asarray`
       so that plain lists work (``x + dx`` on lists would concatenate
       rather than add).
    dx, dy : array-like, scalar or None
       Half-widths of the horizontal / vertical error bars.  ``None``
       suppresses the bars in that direction.
    colour, linestyle, pointstyle : str
       Pieces of a matplotlib format string used for the points
       (``pointstyle + linestyle + colour``).  ``colour`` is also appended
       to the bar format strings.
    barwidth : float
       Line width of the error bars.
    **kwargs :
       Passed to :func:`matplotlib.pyplot.plot` for the points only.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    if pointstyle != '' or linestyle != '':
        plt.plot(x, y, pointstyle + linestyle + colour, **kwargs)
    elif dx is None and dy is None:
        # Plot points if both error bars are not drawn.
        plt.plot(x, y, '.' + colour, **kwargs)
    if dx is not None:
        xmax = x + dx
        xmin = x - dx
    if dy is not None:
        ymax = y + dy
        ymin = y - dy
    # One plot call per bar, interleaving x and y bars point by point; this
    # order matters for matplotlib's colour cycle when `colour` is ''.
    for n in range(len(x)):
        if dx is not None:
            plt.plot([xmin[n], xmax[n]], [y[n], y[n]],
                     '-|' + colour, lw=barwidth)
        if dy is not None:
            plt.plot([x[n], x[n]], [ymin[n], ymax[n]],
                     '-_' + colour, lw=barwidth)
def _errorbar_columns(x, y, plot_axis, yerr, xerr, **kwarg):
    """Plot each column of 2D `y` (after moving `plot_axis` first) vs. `x`.

    If `yerr` has the same number of dimensions as `y` it is transposed the
    same way and sliced per column; otherwise (scalar or 1D) it is passed
    unchanged to every :func:`matplotlib.pyplot.errorbar` call.  `xerr` is
    always passed unchanged.
    """
    per_column_yerr = yerr is not None and np.ndim(yerr) == np.ndim(y)
    y = y.swapaxes(0, plot_axis)
    if per_column_yerr:
        yerr = np.asarray(yerr).swapaxes(0, plot_axis)
    # Unpacking enforces that `y` is 2D (raises ValueError otherwise).
    _, n_cols = y.shape
    for n in range(n_cols):
        _yerr = yerr[:, n] if per_column_yerr else yerr
        plt.errorbar(x, y[:, n], xerr=xerr, yerr=_yerr, **kwarg)


def plot_err(x, y, yerr=None, xerr=None, **kwarg):
    """Plot x vs. y with errorbars.

    Right now we support the following cases:

    * `x` 1D, `y` 1D: a single :func:`matplotlib.pyplot.errorbar` call.
    * `x` 1D, `y` 2D: `y` is plotted along its first axis whose length
      equals ``len(x)``, one curve per remaining column.
    * `x` effectively 1D (a single non-trivial axis, e.g. shape
      ``(N, 1)``): `x` is flattened and `y` is plotted along that axis,
      one curve per column.
    * anything else: falls back to :func:`matplotlib.pyplot.plot` and the
      errors are ignored.

    In the multi-curve cases `yerr` may be a scalar, a 1D array (shared by
    all curves) or an array shaped like `y` (sliced per curve).
    """
    x = np.asarray(x)
    y = np.asarray(y)
    if 1 == x.ndim and 1 == y.ndim:
        plt.errorbar(x, y, yerr=yerr, xerr=xerr, **kwarg)
    elif 1 == x.ndim and 1 < y.ndim:
        plot_axis = np.where(np.array(y.shape) == len(x))[0][0]
        _errorbar_columns(x, y, plot_axis, yerr, xerr, **kwarg)
    elif max(x.shape) == np.prod(x.shape):
        plot_axis = np.argmax(x.shape)
        _errorbar_columns(x.ravel(), y, plot_axis, yerr, xerr, **kwarg)
    else:
        plt.plot(x, y, **kwarg)
def error_line(x, y, dy, fgc='k', bgc='w', N=20, fill=True, max_sigma=5):
    """Plot a curve ``(x, y)`` with gaussian errors `dy` shown as shading.

    The region between ``y - max_sigma*dy`` and ``y + max_sigma*dy`` is
    split at ``N`` equally spaced levels ``sigma`` (starting at 0).  Each
    strip/line at level ``sigma`` gets weight ``c = pdf(sigma)/pdf(0)`` of
    the standard normal distribution, so the shading fades outwards.

    Parameters
    ----------
    x, y, dy : array-like
       Abscissa, curve and one-sigma error (converted with
       :func:`numpy.asarray`).
    fgc, bgc : colour spec
       Foreground (shading) colour and background colour.  With
       ``fill=False`` the line colours are ``c*fgc + (1-c)*bgc``.
    N : int
       Number of sigma levels including 0, i.e. ``N - 1`` bands per side.
    fill : bool
       If `True`, draw filled strips with opacity ``c``; otherwise draw
       the boundary lines of each level.
    max_sigma : float
       Outermost level in units of `dy` (default 5, as before).
    """
    x = np.asarray(x)
    y = np.asarray(y)
    dy = np.asarray(dy)
    pdf = sp.stats.norm().pdf
    to_rgb = plt.matplotlib.colors.ColorConverter().to_rgb
    bg_colour = np.array(to_rgb(bgc))
    fg_colour = np.array(to_rgb(fgc))
    # Every strip shares the same x outline: forwards then backwards.
    X = np.hstack((x, np.flipud(x)))
    peak = pdf(0.0)
    yp0 = ym0 = y
    for sigma in np.linspace(0, max_sigma, N)[1:]:
        yp = y + dy*sigma
        ym = y - dy*sigma
        c = pdf(sigma)/peak
        if fill:
            plt.fill(X, np.hstack((yp0, np.flipud(yp))),
                     fc=fg_colour, ec=fg_colour, lw=0, alpha=c)
            plt.fill(X, np.hstack((ym0, np.flipud(ym))),
                     fc=fg_colour, ec=fg_colour, lw=0, alpha=c)
        else:
            # Use the same (background-blended) colour for both the upper
            # and the lower line; previously the upper line used alpha and
            # the lower one the blend, which was inconsistent.
            line_colour = fg_colour*c + (1.0 - c)*bg_colour
            plt.plot(x, yp, color=line_colour)
            plt.plot(x, ym, color=line_colour)
        yp0 = yp
        ym0 = ym
class LaTeXPlotProperties(StateVars):
    r"""Instances of this class provide a description of properties of
    a plot based on numbers extracted from a LaTeX file.  Insert the
    following code into the section where the plot is to appear in
    order to extract the appropriate parameters and then use the
    reported values to initialize this class::

        \showthe\textwidth
        \showthe\columnwidth
        \showthe\baselineskip

    .. note:: We assume that the document is typeset using the
       Computer Modern fonts.
    """
    # Raw strings are used wherever a backslash appears: in a plain string
    # "\textwidth" contains a TAB (\t) and "\baselineskip" a backspace (\b),
    # and '\usepackage' is a SyntaxError on Python 3.
    _state_vars = [
        ('textwidth_pt', 332.89723, r"From LaTeX \showthe\textwidth"),
        ('columnwidth_pt', 332.89723, r"From LaTeX \showthe\columnwidth"),
        ('baselineskip_pt', 12.0, r"From LaTeX \showthe\baselineskip"),
        ('tick_fontsize', 'footnotesize',
         "Ticks etc. will be typeset in this font"),
        ('usetex', True,
         r"""If `True`, then LaTeX will be used to typeset labels
         etc.  Otherwise, labels etc. will be left as plain text
         that can be replaced with the ``\psfrag{}{}`` command in
         the LaTeX file.

         As of matplotlib version 1.0.1, psfrag replacements do not work, so
         the default is now to use LaTeX."""),
        # The following are "constants" that you should typically not
        # have to adjust unless you use a different font package.
        ('font_info', {'euler': ('zeur', r'\usepackage{eulervm}')}),
        # NOTE(review): 'font' is listed twice; which entry takes effect
        # depends on how StateVars/process_vars treats duplicates
        # (presumably the later, Palatino one) -- confirm.
        ('font', {'family': 'serif',
                  'serif': ['computer modern roman'],
                  'sans-serif': ['computer modern sans serif'],
                  'monospace': ['computer modern typewriter']},
         r"`dict` of args passed to `matplotlib.rc('font')`"),
        ('font', {'family': 'serif',
                  'serif': ['palatino'],
                  'sans-serif': ['bera sans serif'],
                  'monospace': ['computer modern typewriter']},
         r"`dict` of args passed to `matplotlib.rc('font')`"),
        ('latex_preamble', [r"\usepackage{amsmath}"],
         r"""List of strings to add to LaTeX preamble.  Add any
         ``\usepackage{}`` commands here.

         .. note:: Don't forget to use raw strings to prevent
            escaping of characters.  Thus use something like the
            default value: `[r"\usepackage{amsmath}"]`"""),
        ('latex_preview', True, "If `True`, use LaTeX preview package"),
        ('golden_mean', (np.sqrt(5) - 1)/2),
        ('font_size_pt', 10),
        ('font_factors', {'small': 9/10,
                          'footnotesize': 8/10},
         """Font size reduction factors for latex fonts."""),
        # Some units.  These can appear in expressions.
        ('inches_per_pt', 1.0/72.27),
        ('inches', 1.0),
        'pt=inches_per_pt',
        ('textwidth', Computed),
        ('columnwidth', Computed),
        ('baselineskip', Computed),
        ('tick_font', Computed),
    ]
    process_vars()

    def __init__(self, *v, **kw):
        """Derive the computed quantities from the `*_pt` inputs.

        `textwidth`, `columnwidth` and `baselineskip` are converted from
        points to inches; `tick_font` is the tick font size in points
        (`font_size_pt` scaled by the factor for `tick_fontsize`).
        """
        self.textwidth = self.textwidth_pt*self.inches_per_pt
        self.columnwidth = self.columnwidth_pt*self.inches_per_pt
        self.baselineskip = self.baselineskip_pt*self.inches_per_pt
        self.tick_font = self.font_size_pt*self.font_factors[self.tick_fontsize]

    def initialize_matplotlib(self):
        r"""Push these settings into matplotlib's rc parameters.

        :class:`Figure` calls this.  Sets TeX usage, the font family,
        the LaTeX preamble/preview options, the base font size and
        TrueType (type 42) fonts for PostScript output.
        """
        matplotlib.rc('text', usetex=self.usetex)
        matplotlib.rc('font', **self.font)
        matplotlib.rc('text.latex',
                      preamble=self.latex_preamble,
                      preview=self.latex_preview,
                      )
        matplotlib.rc('font', size=self.font_size_pt)
        # Use TT fonts
        matplotlib.rc('ps', fonttype=42)
# Default global instance.  Figure.__init__ falls back to this object when
# no `plot_properties` are supplied, so changes to it affect every such
# Figure created afterwards.
_PLOT_PROPERTIES = LaTeXPlotProperties()
def xticks(ticks):
    """Set the x ticks to `ticks`, labelling each tick with its own value.

    The labels are supplied explicitly (the tick values themselves) so
    that real text ends up in the figure, which is what psfrag needs.
    There was an API change somewhere along the line that broke this...
    """
    labels = ticks
    plt.xticks(ticks, labels)
class Figure(StateVars):
    r"""This class represents a single figure and allows customization
    of properties, as well as providing plotting facilities.

    Notes
    -----
    Units are either pts (for fonts) or inches (for linear
    measurements).

    A matplotlib figure is only created when `num` or `filename` is
    passed as a keyword argument to the constructor.

    Examples
    --------
    Here is an example of a figure suitable for a half of a page in a
    normal LaTeX book.  First we run the following file through
    LaTeX::

        \documentclass{book}
        \begin{document}
        \showthe\textwidth
        \showthe\columnwidth
        \showthe\baselineskip
        \end{document}

    This gives::

        > 345.0pt.
        l.3 \showthe\textwidth
        ?
        > 345.0pt.
        l.4 \showthe\columnwidth
        ?
        > 12.0pt.
        l.5 \showthe\baselineskip

    .. plot::
       :include-source:

       x = np.linspace(0,1.01,100)
       y = np.sin(x)
       plot_prop = LaTeXPlotProperties(textwidth_pt=345.0,
                                       columnwidth_pt=345.0,
                                       baselineskip_pt=12.0)
       fig = Figure(filename='tst_book.eps',
                    width='0.5*textwidth',
                    plot_properties=plot_prop)
       plt.plot(x, y, label=r'$\sin(x)$')
       plt.axis([-0.02,1.02,-0.02,1.02])
       plt.ylabel(
           r'$\int_{0}^{x}\left(\frac{\cos(\tilde{x})}{1}\right)d{\tilde{x}}$')
       #fig.savefig()

    Here is another example using a two-column article::

        \documentclass[twocolumn]{article}
        \begin{document}
        \showthe\textwidth
        \showthe\columnwidth
        \showthe\baselineskip
        \end{document}

    This gives::

        > 469.0pt.
        l.3 \showthe\textwidth
        ?
        > 229.5pt.
        l.4 \showthe\columnwidth
        ?
        > 12.0pt.
        l.5 \showthe\baselineskip

    .. plot::
       :include-source:

       x = np.linspace(0,1.01,100)
       y = np.sin(x)
       plot_prop = LaTeXPlotProperties(textwidth_pt=469.0,
                                       columnwidth_pt=229.5,
                                       baselineskip_pt=12.0)
       fig = Figure(filename='tst_article.eps',
                    plot_properties=plot_prop)
       plt.plot(x, y, label=r'$\sin(x)$')
       plt.axis([-0.02,1.02,-0.02,1.02])
       plt.ylabel(
           r'$\int_{0}^{x}\left(\frac{\cos(\tilde{x})}{1}\right)d{\tilde{x}}$')
       #fig.savefig()
    """
    _state_vars = [
        ('num', None, "Figure number"),
        ('filename', None, "Filename for figure."),
        ('width', 'columnwidth',
         r"Expression involving 'columnwidth' and/or 'textwidth'"),
        ('height', 1.0, r"Fraction of `golden_mean*width`"),
        ('plot_properties', None),
        ('axes_dict', dict(labelsize='medium')),
        ('tick_dict', dict(labelsize='small')),
        ('legend_dict', dict(fontsize='medium')),
        ('margin_factors', dict(top=0.5,
                                left=2.8,
                                bot=3,
                                right=0.5),
         """These allocate extra space for labels etc."""),
        ('autoadjust', False, r"""Attempt to autoadjust for labels,
         otherwise you can do this manually by calling
         :meth:`adjust`."""),
        ('figure_manager', Computed),
        ('_inset_axes', Computed,
         "Set of axes to be excluded from adjustment"),
        ('figures', ClassVar({}), "Dictionary of computed figures."),
        ('on_draw_id', Excluded(None),
         "Id associated with 'on_draw' event"),
        ('footnotesize', Computed, "Size matching corresponding LaTeX font"),
        ('small', Computed, "Size matching corresponding LaTeX font"),
    ]
    process_vars()

    def __init__(self, *v, **kw):
        """Set up fonts and, if `num`/`filename` are given, the figure.

        The figure size is ``width`` (an expression evaluated against the
        plot properties, e.g. ``'0.5*textwidth'``) by
        ``height*golden_mean*width`` inches, and a single axes is created
        leaving margins of `margin_factors` times the font size.
        """
        if self.plot_properties is None:
            self.plot_properties = _PLOT_PROPERTIES
        pp = self.plot_properties
        pp.initialize_matplotlib()
        self._inset_axes = set()
        # Expose the LaTeX relative font sizes (e.g. self.small) in points.
        for _size in pp.font_factors:
            setattr(self, _size, pp.font_size_pt*pp.font_factors[_size])
        if 'num' in kw or 'filename' in kw:
            # NOTE: `width` goes through eval(), so it must be a trusted
            # expression.  Evaluate in a *copy* of the namespace: eval()
            # inserts `__builtins__` into the globals dict it is given and
            # would otherwise pollute pp.__dict__.
            width = eval(self.width, dict(pp.__dict__))
            fig_width = width
            fig_height = self.height*width*pp.golden_mean
            # Font size in inches; the margins are multiples of this.
            size = pp.font_size_pt*pp.inches_per_pt
            # top space = 1/2 font
            space_top = self.margin_factors['top']*size
            space_left = self.margin_factors['left']*size
            space_bottom = self.margin_factors['bot']*size
            space_right = self.margin_factors['right']*size
            # Compute axes size (figure-relative coordinates):
            axes_left = space_left/fig_width
            axes_bottom = space_bottom/fig_height
            axes_width = 1.0 - (space_left + space_right)/fig_width
            axes_height = 1.0 - (space_bottom + space_top)/fig_height
            axes_size = [axes_left, axes_bottom,
                         axes_width, axes_height]
            plt.rc('font', size=pp.font_size_pt)
            plt.rc('axes', **self.axes_dict)
            plt.rc('xtick', **self.tick_dict)
            plt.rc('ytick', **self.tick_dict)
            plt.rc('legend', **self.legend_dict)
            plt.figure(
                num=self.num,
                figsize=(fig_width, fig_height))
            self.figure_manager = plt.get_current_fig_manager()
            self.num = self.figure_manager.num
            self.figures[self.num] = self.figure_manager
            plt.clf()
            plt.axes(axes_size)
            # NOTE: automatic adjustment through a 'draw_event' hook (see
            # start_adjusting) is currently disabled, so `autoadjust` has
            # no effect here; call adjust() explicitly instead.

    def activate(self):
        """Make this figure current and return the matplotlib figure."""
        return plt.figure(self.num)

    def start_adjusting(self):
        """(Re)connect :meth:`on_draw` to the canvas 'draw_event'."""
        if self.on_draw_id:
            self.figure_manager.canvas.mpl_disconnect(self.on_draw_id)
        self.on_draw_id = self.figure_manager.canvas.mpl_connect(
            'draw_event', self.on_draw)

    def stop_adjusting(self):
        """Disconnect the hook installed by :meth:`start_adjusting`."""
        if self.on_draw_id:
            self.figure_manager.canvas.mpl_disconnect(self.on_draw_id)
        self.on_draw_id = 0

    def new_inset_axes(self, rect):
        r"""Return a new axes set inside the main axis.

        Parameters
        ----------
        rect : [left, bottom, width or right, height or top]
           This is the rectangle for the new axes (the labels etc. will be
           outside).  Coordinates may be either floating point numbers which
           specify the location of the inset in terms of a fraction between 0
           and 1 of the current axis.

           One may also specify the coordinates in the data units of the
           actual corners by specifying the data as an imaginary number.
           This will be transformed into relative axis coordinates using the
           current axis limits (the subplot will not subsequently move).
           (Not implemented yet.)

        The new axes is excluded from :meth:`adjust`.
        """
        ax = plt.axes(rect)
        self._inset_axes.add(ax)
        return ax

    def axis(self, *v, **kw):
        r"""Wrapper for :func:`pyplot.axis` function that applies the
        transformation to each axis (useful if :func:`pyplot.twinx` or
        :func:`pyplot.twiny` has been used)."""
        fig = self.figure_manager.canvas.figure
        for _a in fig.axes:
            _a.axis(*v, **kw)

    def adjust(self, full=True,
               padding=0.05):
        r"""Adjust the axes so that all text lies within the figure.

        Parameters
        ----------
        full : bool
           If `True`, first expand every axes to the full figure so the
           adjustment starts from the largest possible area.
        padding : float
           Afterwards shrink each non-inset axes by this fraction of its
           width/height (half on each side).  Ignored if not positive.

        Note that this leaves matplotlib in interactive mode.
        """
        plt.ioff()
        plt.figure(self.num)
        # Needed below whether or not `full` is set (previously `fig` was
        # only bound inside the `if`, so full=False raised NameError).
        fig = self.figure_manager.canvas.figure
        if full:
            # Reset axis to full size.
            for _a in fig.axes:
                _a.set_position([0, 0, 1, 1])
        on_draw_id = self.figure_manager.canvas.mpl_connect(
            'draw_event', self.on_draw)
        try:
            plt.ion()
            plt.draw()
        finally:
            # Always drop the temporary hook, even if drawing fails.
            self.figure_manager.canvas.mpl_disconnect(on_draw_id)

        adjustable_axes = [_a for _a in fig.axes
                           if _a not in self._inset_axes]
        if 0 < padding:
            for _a in adjustable_axes:
                bb_a = _a.get_position()
                dx = bb_a.width*padding/2
                dy = bb_a.height*padding/2
                bb_a.x0 += dx
                bb_a.x1 -= dx
                bb_a.y0 += dy
                bb_a.y1 -= dy
                _a.set_position(bb_a)

    @staticmethod
    def _shrink_bb(bb, factor=_EPS):
        r"""Shrink the bounding box bb by factor in order to prevent unneeded
        work due to rounding.

        Each edge moves inwards by `factor` times the box's width (for the
        x edges) or height (for the y edges).  `bb` is modified in place and
        returned.
        """
        p = np.array(bb.get_points(), dtype=float)   # [[x0, y0], [x1, y1]]
        extent = np.diff(p, axis=0)                  # [[width, height]]
        p += factor*np.array([[1.0], [-1.0]])*extent
        bb.set_points(p)
        return bb

    def _adjust(self,
                logger=logging.getLogger("mmf.utils.mmf_plot.Figure._adjust")):
        r"""Adjust the axes to make sure all text is inside the box.

        Collects the window extents of all non-empty texts (titles, tick
        labels, axis labels, annotations) of the non-inset axes, and if
        their union sticks out of the figure, moves the axes edges inwards
        by the overhang.  Returns `True` if any axes were moved.
        """
        fig = self.figure_manager.canvas.figure
        # Pixels -> figure-relative coordinates.  (`inverse_transformed`
        # was removed from matplotlib; `transformed(t.inverted())` is the
        # equivalent that works across versions.)
        to_figure = fig.transFigure.inverted()
        bb_f = fig.get_window_extent().transformed(to_figure)
        logger.debug("Fig bb %s" % (" ".join(str(bb_f).split()),))
        adjustable_axes = [_a for _a in fig.axes
                           if _a not in self._inset_axes]
        texts = []
        for _a in adjustable_axes:
            texts.extend(_a.texts)
            texts.append(_a.title)
            texts.extend(_a.get_xticklabels())
            texts.extend(_a.get_yticklabels())
            texts.append(_a.xaxis.get_label())
            texts.append(_a.yaxis.get_label())

        # Ignore empty text!
        bboxes = [t.get_window_extent().transformed(to_figure)
                  for t in texts if t.get_text()]
        if not bboxes:
            # Nothing to fit (Bbox.union would raise on an empty list).
            return False

        # this is the bbox that bounds all the bboxes, again in
        # relative figure coords
        bbox = self._shrink_bb(matplotlib.transforms.Bbox.union(bboxes))
        adjusted = False
        if not np.all([bb_f.contains(*c) for c in bbox.corners()]):
            # Adjust axes position
            for _a in adjustable_axes:
                bb_a = _a.get_position()
                logger.debug("Text bb %s"
                             % (" ".join(str(bbox).split()),))
                logger.debug("Axis bb %s"
                             % (" ".join(str(bb_a).split()),))
                bb_a.x0 += max(0, bb_f.xmin - bbox.xmin)
                bb_a.x1 += min(0, bb_f.xmax - bbox.xmax)
                bb_a.y0 += max(0, bb_f.ymin - bbox.ymin)
                bb_a.y1 += min(0, bb_f.ymax - bbox.ymax)
                logger.debug("New bb %s"
                             % (" ".join(str(bb_a).split()),))
                _a.set_position(bb_a)
                adjusted = True
        return adjusted

    def on_draw(self, event, _adjusting=[False]):
        """We register this to perform processing after the figure is
        drawn, like adjusting the margins so that the labels fit.

        `event` is not used.  The mutable default `_adjusting` is a
        deliberate flag shared by all calls: redrawing inside this handler
        fires 'draw_event' again, and the flag stops that recursion.
        At most 10 adjust/redraw rounds are attempted.
        """
        fig = self.figure_manager.canvas.figure
        logger = logging.getLogger("mmf.utils.mmf_plot.Figure.on_draw")
        if _adjusting[0]:
            # Don't recurse!
            return
        _adjusting[0] = True
        try:
            _max_adjust = 10
            adjusted = False
            for _n in range(_max_adjust):
                adjusted = self._adjust(logger=logger)
                if not adjusted:
                    break
                fig.canvas.draw()
            if adjusted:
                # Even after _max_adjust steps we still needed adjusting:
                logger.warning("Still need adjustment after %i steps"
                               % (_max_adjust,))
        finally:
            _adjusting[0] = False

    def adjust_axis(self, extents=None,
                    xl=None, xh=None, yl=None, yh=None,
                    extend_x=0.0, extend_y=0.0):
        """Set the current axis limits, optionally widening them.

        `extents` (if given) is applied first; then any of `xl`, `xh`,
        `yl`, `yh` override the corresponding limit.  Finally the x (y)
        range is extended by `extend_x` (`extend_y`) times its width on
        each side.  Returns the result of the final :func:`pyplot.axis`.
        """
        if extents is not None:
            plt.axis(extents)
        xl_, xh_, yl_, yh_ = plt.axis()
        if xl is not None:
            xl_ = xl
        if xh is not None:
            xh_ = xh
        if yl is not None:
            yl_ = yl
        if yh is not None:
            yh_ = yh
        plt.axis([xl_, xh_, yl_, yh_])
        dx = extend_x*(xh_ - xl_)
        dy = extend_y*(yh_ - yl_)
        return plt.axis([xl_ - dx, xh_ + dx,
                         yl_ - dy, yh_ + dy])

    def savefig(self, filename=None):
        """Draw and save the figure to `filename` (default `self.filename`)."""
        if not filename:
            filename = self.filename
        print("Saving plot as %r..."%(filename,))
        plt.figure(self.num)
        plt.ion()   # Do this to ensure autoadjustments
        plt.draw()  # are made!
        plt.savefig(filename)
        print("Saving plot as %r. Done."%(filename,))

    def __del__(self):
        """Destructor: make sure we unregister the autoadjustor."""
        self.autoadjust = False
        # Actually disconnect any hook installed by start_adjusting().  The
        # attributes may be missing (no figure was created), the canvas may
        # already be gone, and exceptions cannot propagate out of __del__
        # anyway, so this is strictly best-effort.
        try:
            if (getattr(self, 'on_draw_id', None)
                    and getattr(self, 'figure_manager', None) is not None):
                self.stop_adjusting()
        except Exception:
            pass
# Here we monkeypatch mpl_toolkits.axes_grid1.inset_locator to allow for
# independent x and y zoom factors.
def monkey_patch_inset_locator():
    r"""Patch ``AnchoredZoomLocator.get_extent`` for separate x/y zooms.

    After patching, the locator's ``zoom`` may be a scalar or a pair
    ``(zoom_x, zoom_y)``: the width and height of the zoomed region are
    multiplied by it elementwise (numpy broadcasting).

    The class is imported from :mod:`mpl_toolkits.axes_grid1.inset_locator`
    and, if that fails, from the legacy
    :mod:`mpl_toolkits.axes_grid.inset_locator` (removed from recent
    matplotlib releases).

    NOTE(review): newer matplotlib versions size offset boxes through
    ``get_bbox`` rather than ``get_extent``; there this patch may have no
    effect -- confirm against the installed version.
    """
    try:
        from mpl_toolkits.axes_grid1.inset_locator import AnchoredZoomLocator
    except ImportError:
        from mpl_toolkits.axes_grid.inset_locator import AnchoredZoomLocator
    import matplotlib.transforms

    def get_extent(self, renderer):
        # Pixel size of the part of the parent axes spanned by this
        # axes' view limits.
        bb = matplotlib.transforms.TransformedBbox(
            self.axes.viewLim, self.parent_axes.transData)
        _, _, w, h = bb.bounds
        xd, yd = 0, 0
        fontsize = renderer.points_to_pixels(self.prop.get_size_in_points())
        pad = self.pad * fontsize
        wh = np.array([w, h])
        # (width, height, xdescent, ydescent); `zoom` may be a pair.
        return tuple((wh*self.zoom + 2*pad).tolist() + [xd + pad, yd + pad])

    AnchoredZoomLocator.get_extent = get_extent
class ListCollection(matplotlib.collections.Collection):
    r"""Provide a simple :class:`matplotlib.collections.Collection` of a list of
    artists.  Provided so that this collection of artists can be simultaneously
    rasterized.  Used by my custom :func:`contourf` function.

    The wrapped artists are drawn in list order by :meth:`draw`, which is
    wrapped with :func:`matplotlib.artist.allow_rasterization` so that this
    collection's `rasterized` setting applies to all of them at once.
    """

    def __init__(self, collections, **kwargs):
        matplotlib.collections.Collection.__init__(self, **kwargs)
        # `set_collections` was called here but never defined, so
        # construction used to fail with AttributeError.
        self.set_collections(collections)

    def set_collections(self, collections):
        """Replace the wrapped artists with `collections` (any iterable)."""
        self._collections = list(collections)

    def get_collections(self):
        """Return the list of wrapped artists."""
        return self._collections

    @matplotlib.artist.allow_rasterization
    def draw(self, renderer):
        """Draw every wrapped artist with `renderer`."""
        if not self.get_visible():
            return
        for _c in self._collections:
            _c.draw(renderer)
def contourf(*v, **kw):
    r"""Replacement for :func:`matplotlib.pyplot.contourf` that supports the
    `rasterized` keyword.

    The filled contour collections produced by :func:`pyplot.contourf` are
    removed from the current axes and re-added wrapped in a single
    :class:`ListCollection`, so that they all share its `rasterized`
    setting.  Interactive mode is switched off during this and restored
    afterwards, even if an error occurs.

    (The ``allow_rasterization`` decorator that used to sit on this
    function expects a ``draw(artist, renderer)`` method and broke normal
    calls; it belongs on :meth:`ListCollection.draw`.)

    Returns
    -------
    contour_set :
       The object returned by :func:`matplotlib.pyplot.contourf`.
    """
    was_interactive = matplotlib.is_interactive()
    matplotlib.interactive(False)
    try:
        contour_set = plt.contourf(*v, **kw)
        for _c in contour_set.collections:
            _c.remove()
        collection = ListCollection(
            contour_set.collections,
            rasterized=kw.get('rasterized', None))
        ax = plt.gca()
        ax.add_artist(collection)
    finally:
        matplotlib.interactive(was_interactive)
    return contour_set
def imcontourf(x, y, z, contours=None, *v, **kw):
    r"""Fast stand-in for :func:`matplotlib.pyplot.contourf`.

    No contours are computed: `z` is shown directly with
    :func:`matplotlib.pyplot.imshow`, which is much faster and displays
    exactly the sampled data.

    Parameters
    ----------
    x, y, z : array-like
       `z` is indexed as `z[x,y]`.  If `x` (`y`) has the same shape as
       `z` it is reduced to `x[:,0]` (`y[0,:]`); otherwise
       ``z.shape[:2] == (len(x), len(y))`` is required.  `x` and `y` must
       be equally spaced (checked with assertions).
    contours :
       Accepted for call compatibility with `contourf`; unused.
    *v, **kw :
       Forwarded to `imshow`; ``aspect`` is always forced to ``'auto'``.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    z = np.asarray(z)
    if z.shape == x.shape:
        x = x[:, 0]
    if z.shape == y.shape:
        y = y[0, :]
    assert (len(x), len(y)) == z.shape[:2]
    for _coords in (x, y):
        assert np.allclose(np.diff(np.diff(_coords)), 0)
    options = dict(kw, aspect='auto')
    # imshow uses the first array axis for rows (vertical), so move the
    # x axis of z (axis 0) behind the y axis.
    image = np.rollaxis(z, 0, 2)
    return plt.imshow(image, origin='lower',
                      extent=(x[0], x[-1], y[0], y[-1]), *v, **options)
def plot3d(x, y, z, wireframe=False,
           **kw):
    r"""Wrapper to generate 3d surface plots.

    Parameters
    ----------
    x, y : array-like
       1D coordinate vectors (broadcast to the shape of `z`) or arrays of
       the same shape as `z`.
    z : array-like
       Surface heights.
    wireframe : bool
       If `True` use ``plot_wireframe`` (default strides 10), otherwise
       ``plot_surface`` (default strides 1, no antialiasing).
    **kw :
       Passed on to the plotting call; ``cmap`` defaults to ``cm.jet``.

    Returns the artist created by the plotting call.
    """
    # Move these out once fixed.
    from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection
    from matplotlib import cm
    x = np.asarray(x)
    y = np.asarray(y)
    z = np.asarray(z)
    if 1 == len(x.shape):
        x = x[:, None]
    if 1 == len(y.shape):
        y = y[None, :]
    if x.shape != z.shape:
        x = x + 0*y
    if y.shape != z.shape:
        y = y + 0*x
    assert z.shape == x.shape
    assert z.shape == y.shape
    assert np.allclose(np.diff(np.diff(x)), 0)
    assert np.allclose(np.diff(np.diff(y)), 0)
    fig = plt.gcf()
    # Reuse the current axes if it is already 3D, otherwise create one.
    # (Figure.gca(projection=...) was removed in matplotlib 3.6.)
    ax = fig.gca() if fig.axes else None
    if ax is None or getattr(ax, 'name', None) != '3d':
        ax = fig.add_subplot(111, projection='3d')
    kw.setdefault('cmap', cm.jet)
    if wireframe:
        kw.setdefault('rstride', 10)
        kw.setdefault('cstride', 10)
        surf = ax.plot_wireframe(x, y, z, **kw)
    else:
        kw.setdefault('rstride', 1)
        kw.setdefault('cstride', 1)
        kw.setdefault('antialiased', False)
        surf = ax.plot_surface(x, y, z, **kw)
    plt.draw_if_interactive()
    return surf
def plot3dmpl(X, Y, Z, zmin=-np.inf, zmax=np.inf,
              xlabel=None, ylabel=None,
              abs_parts=False, **kw):
    r"""Use MayaVi2 to plot the surface.

    Values of `Z` are clipped to ``[zmin, zmax]`` before plotting.

    Parameters
    ----------
    X, Y, Z :
       Passed to :func:`mayavi.mlab.surf`.
    zmin, zmax : float
       Clipping range for the plotted values.
    xlabel, ylabel : str or None
       Axis labels (skipped when empty).
    abs_parts : bool
       If `True`, then plot `abs(real)` and `-abs(imag)` for complex `Z`.
    **kw :
       Extra arguments for :func:`mayavi.mlab.surf`; they override the
       per-surface defaults (colormap, opacity).

    Returns the surface object, or a ``(real, imag)`` pair for complex `Z`.
    """
    from mayavi import mlab

    def draw(z, kw=dict(kw), **_kw):
        # User keywords take precedence over the per-call defaults.
        _kw.update(kw)
        return mlab.surf(X, Y, np.maximum(np.minimum(z, zmax), zmin), **_kw)

    if np.any(np.iscomplex(Z)):
        if abs_parts:
            s = (draw(abs(Z.real), colormap='Greens', opacity=1.0),
                 draw(-abs(Z.imag), colormap='Reds', opacity=1.0))
        else:
            s = (draw(Z.real, colormap='Greens', opacity=0.5),
                 draw(Z.imag, colormap='Reds', opacity=0.5))
    else:
        s = draw(Z)
    mlab.axes()
    if xlabel:
        mlab.xlabel(xlabel)
    if ylabel:
        # Previously passed to mlab.xlabel by mistake.
        mlab.ylabel(ylabel)
    return s