# Source code for mmf.utils.mmf_sympy
r"""Tools for using sympy to generate code."""
import numpy as np
import math
import sympy
# Prebuilt execution namespaces for generated code, keyed by the names that
# ``lambdify_cse`` accepts for its ``modules`` argument.  Each entry is a copy
# of the module's attribute dictionary plus the module itself under its usual
# alias(es), so generated code can write either ``sqrt(x)`` or ``np.sqrt(x)``.
_ENVS = {
    'math': dict(vars(math), math=math),
    'numpy': dict(vars(np), numpy=np, np=np),
}
def lambdify_cse(var, expr, modules='numpy', env=None):
    r"""Return a function evaluating `expr`, using CSE to share subexpressions.

    The expression is passed through :func:`sympy.cse`, and Python source for
    a function ``f`` is generated in which each common subexpression is
    assigned to a local variable once before the result is returned.  That
    source is executed in a namespace built from `modules` and `env`, and
    the resulting ``f`` is returned.

    Parameters
    ----------
    var : sympy object
        Argument specification of the generated function.  It is inserted
        verbatim via ``str(var)``, so it must render as a valid Python
        parameter list (e.g. a single :class:`sympy.Symbol`).
    expr : sympy expression or sequence of expressions
        What the generated function computes.  If :func:`sympy.cse` reduces
        it to a single expression, ``f`` returns that value; otherwise ``f``
        returns a list with one entry per reduced expression.
    modules : str, module, or iterable of modules
        Namespace the generated code runs in.  A string selects one of the
        predefined namespaces in ``_ENVS`` (``'numpy'`` or ``'math'``);
        otherwise the ``__dict__`` of each module is merged in order, with
        later modules overriding earlier ones.
    env : dict, optional
        Extra names merged on top of the module namespace; they take
        precedence over names provided by `modules`.

    Returns
    -------
    function
        The generated function ``f``.

    Raises
    ------
    ValueError
        If `modules` is a string that is not a predefined namespace name.
    """
    import types

    replacements, reduced = sympy.cse(expr)
    if len(reduced) == 1:
        reduced = reduced[0]

    # NOTE(review): expressions are rendered with str(), so e.g. Rational(1, 2)
    # becomes "1/2" and is evaluated with the running Python's `/` semantics
    # (integer division for int operands on Python 2).
    lines = ["def f({0}):".format(str(var))]
    lines.extend("    {0} = {1}".format(str(sym), str(sub))
                 for sym, sub in replacements)
    lines.append("    return {0}".format(str(reduced)))
    code = "\n".join(lines)

    if isinstance(modules, types.ModuleType):
        modules = [modules]
    try:
        namespace = dict(_ENVS[modules])
    except (KeyError, TypeError):
        # Not a predefined name.  A TypeError means `modules` is unhashable
        # (e.g. a list), which is still a valid collection of modules.
        if isinstance(modules, str):
            raise ValueError(
                "Unknown modules {0!r}; expected one of {1}".format(
                    modules, sorted(_ENVS)))
        namespace = {}
        for module in modules:
            namespace.update(module.__dict__)

    if env:
        # Caller-supplied names override anything from `modules`.
        namespace.update(env)

    # The function-call form of exec is valid on both Python 2.7 and 3.
    exec(code, namespace)
    return namespace['f']