Homework 1#
Instructions: Download this Jupyter notebook (button at the top of the page) and complete the code where needed. Once you have completed the code, run the entire notebook, save a PDF, and submit it on Canvas.
import abc
from typing import Callable
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import functools
import cvxpy as cp
Problem 1#
In this problem, you will familiarize yourself with the basic operations with dynamical systems and useful JAX operations.
Consider the Dynamics class given below.
class Dynamics(metaclass=abc.ABCMeta):
    dynamics_func: Callable
    state_dim: int
    control_dim: int

    def __init__(self, dynamics_func, state_dim, control_dim):
        self.dynamics_func = dynamics_func
        self.state_dim = state_dim
        self.control_dim = control_dim

    def __call__(self, state, control, time=0):
        return self.dynamics_func(state, control, time)
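As a quick illustration (not part of the assignment), a Dynamics instance is callable just like the underlying function; the scalar system \(\dot{x} = -x + u\) below is a made-up example.
toy_dynamics = Dynamics(lambda x, u, t: -x + u, state_dim=1, control_dim=1)
print(toy_dynamics(jnp.array([1.0]), jnp.array([0.5])))  # time defaults to 0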
(a) Setting up dynamics#
Using the Dynamics class, construct the continuous-time dynamics for the dynamically extended unicycle model.
def dynamic_unicycle_ode(state, control, time):
    x, y, theta, v = state
    omega, a = control
    x_dot = v * jnp.cos(theta)
    y_dot = v * jnp.sin(theta)
    theta_dot = omega
    v_dot = a
    return jnp.array([x_dot, y_dot, theta_dot, v_dot])
state_dim = 4
control_dim = 2
continuous_dynamics = Dynamics(dynamic_unicycle_ode, state_dim, control_dim)
(b) Obtaining discrete-time dynamics#
With the continuous time dynamics, we can obtain the discrete time dynamics by integrating over a time step \(\Delta t\).
Implement both Euler integration and Runge-Kutta integration to obtain the discrete-time dynamics.
def euler_integrate(dynamics, dt):
    # zero-order hold
    def integrator(x, u, t):
        x_next = x + dt * dynamics(x, u, t)
        return x_next
    return integrator
def runge_kutta_integrator(dynamics, dt=0.1):
    # zero-order hold
    def integrator(x, u, t):
        dt2 = dt / 2.0
        k1 = dynamics(x, u, t)
        k2 = dynamics(x + dt2 * k1, u, t + dt2)
        k3 = dynamics(x + dt2 * k2, u, t + dt2)
        k4 = dynamics(x + dt * k3, u, t + dt)
        return x + (dt / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)
    return integrator
dt = 0.1
discrete_dynamics_euler = Dynamics(
    euler_integrate(continuous_dynamics, dt), state_dim, control_dim
)
discrete_dynamics_rk = Dynamics(
    runge_kutta_integrator(continuous_dynamics, dt), state_dim, control_dim
)
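As an optional sanity check (not part of the assignment), one discrete step from the same state with each integrator should be very close for this small \(\Delta t\); the test state and control below are arbitrary values chosen for illustration.
x_test = jnp.array([0.0, 0.0, 0.1, 1.0])
u_test = jnp.array([0.2, 0.5])
print("Euler step:", discrete_dynamics_euler(x_test, u_test, 0.0))
print("RK4 step:  ", discrete_dynamics_rk(x_test, u_test, 0.0))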
(c) Simulating dynamics#
Simulate your dynamics over 5 seconds for different values of \(\Delta t\) and compare the trajectories.
Show, on the same plot, the simulated trajectories for the following cases:
Discrete-time dynamics with Euler integration, \(\Delta t = 0.01\)
Discrete-time dynamics with Euler integration, \(\Delta t = 0.5\)
Discrete-time dynamics with RK integration, \(\Delta t = 0.01\)
Discrete-time dynamics with RK integration, \(\Delta t = 0.5\)
How does the choice of integration scheme and time step size influence the resulting trajectories? Feel free to try different time step values, but you don’t need to submit plots of them.
def simulate(dynamics, initial_state, controls, dt):
    xs = [initial_state]
    time = 0
    for u in controls:
        xs.append(dynamics(xs[-1], u, time))
        time += dt
    return jnp.stack(xs)
initial_state = jnp.array([0.0, 0.0, 0.0, 0.0])
control = jnp.array([2.0, 1.0]) # constant control over the 5 second duration.
duration = 5.0
dts = [0.01, 0.5]
for dt in dts:
    num_steps = int(duration / dt)
    controls = [control] * num_steps
    discrete_dynamics_euler = Dynamics(
        euler_integrate(continuous_dynamics, dt), state_dim, control_dim
    )
    discrete_dynamics_rk = Dynamics(
        runge_kutta_integrator(continuous_dynamics, dt), state_dim, control_dim
    )
    xs_euler = simulate(discrete_dynamics_euler, initial_state, controls, dt)
    xs_rk = simulate(discrete_dynamics_rk, initial_state, controls, dt)
    plt.plot(xs_euler[:, 0], xs_euler[:, 1], label=f"dt = {dt} Euler")
    plt.plot(xs_rk[:, 0], xs_rk[:, 1], label=f"dt = {dt} RK")
plt.legend()
plt.grid(alpha=0.4)
plt.axis("equal")
plt.xlabel("x [m]")
plt.ylabel("y [m]")
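Optionally (a sketch, not required for submission), the effect of the integrator and step size can be quantified by comparing each endpoint against a fine-step RK reference trajectory; the reference step of 0.001 below is an arbitrary choice.
reference_dt = 0.001  # fine step used only to generate a reference solution
reference = simulate(
    Dynamics(runge_kutta_integrator(continuous_dynamics, reference_dt), state_dim, control_dim),
    initial_state, [control] * int(duration / reference_dt), reference_dt,
)[-1]
for dt in dts:
    n = int(duration / dt)
    x_euler_end = simulate(
        Dynamics(euler_integrate(continuous_dynamics, dt), state_dim, control_dim),
        initial_state, [control] * n, dt,
    )[-1]
    x_rk_end = simulate(
        Dynamics(runge_kutta_integrator(continuous_dynamics, dt), state_dim, control_dim),
        initial_state, [control] * n, dt,
    )[-1]
    err_euler = float(jnp.linalg.norm(x_euler_end - reference))
    err_rk = float(jnp.linalg.norm(x_rk_end - reference))
    print(f"dt = {dt}: Euler endpoint error = {err_euler:.4f}, RK endpoint error = {err_rk:.4f}")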

(d) jax.vmap#
Suppose now that you want to simulate many trajectories. Rather than wrapping the simulate function in a for loop, we can use jax.vmap, a vectorized map that lets us apply a function, in this case simulate, over multiple inputs.
An example of how to use jax.vmap is shown below. Notice that we can specify which arguments should be vectorized and along which dimension.
def foo(x, y, z):
    return x + y + z
N = 1000
x = jnp.array(np.random.randn(N))
y = jnp.array(np.random.randn(N))
z = jnp.array(np.random.randn(N))
xs = jnp.array(np.random.randn(N, N))
ys = jnp.array(np.random.randn(N, N))
zs = jnp.array(np.random.randn(N, N))
foo(x, y, z) # non-vectorized version
# vectorized version for all inputs, 0 is the batch dimension for all inputs
jax.vmap(foo, in_axes=[0, 0, 0])(xs, ys, zs)
# x not batched, but ys and zs are with 0 as the batch dimension
jax.vmap(foo, in_axes=[None, 0, 0])(x, ys, zs)
# y not batched, but xs and zs are with 0 as the batch dimension
jax.vmap(foo, in_axes=[0, None, 0])(xs, y, zs)
# z not batched, but xs and ys are with 0 as the batch dimension
jax.vmap(foo, in_axes=[0, 0, None])(xs, ys, z)
# x and y not batched, but zs is with 0 as the batch dimension
jax.vmap(foo, in_axes=[None, None, 0])(x, y, zs)
# vectorized version for all inputs, batch dimension for xs is 1,
# while 0 is the batch dimension for ys and zs
jax.vmap(foo, in_axes=[1, 0, 0])(xs, ys, zs)
Apply jax.vmap to the simulate function for the following batch of initial states and control inputs. Choose \(\Delta t = 0.1\). Use the values below to simulate multiple trajectories using jax.vmap.
state_dim = continuous_dynamics.state_dim
control_dim = continuous_dynamics.control_dim
N = 1000
n_time_steps = 50
initial_states = jnp.array(np.random.randn(N, state_dim))
controls = jnp.array(np.random.randn(N, n_time_steps, control_dim))
# rebuild the RK dynamics with dt = 0.1, as specified above
dt = 0.1
discrete_dynamics_rk = Dynamics(runge_kutta_integrator(continuous_dynamics, dt), state_dim, control_dim)
trajs = jax.vmap(simulate, in_axes=[None, 0, 0, None])(
    discrete_dynamics_rk, initial_states, controls, jnp.array(dt)
)
[plt.plot(tj[:, 0], tj[:, 1], alpha=0.3) for tj in trajs]
plt.grid(alpha=0.3)
plt.axis("equal")

(e) jax.jit (optional reading)#
Bleh! You notice that it takes some time to run. And if you increased the duration or the number of trajectories to simulate, the computation time would grow.
If only we could compile the code to help reduce computation time. With JAX, you can! We can use the jax.jit function, which performs just-in-time compilation: JAX figures out the expected sizes of the input arrays and compiles (and allocates memory for) a version of the function specialized to those sizes.
There are a number of ways to use jax.jit, and it can get a bit tricky as your code becomes more complex; it is best to read the JAX documentation for more information. But for relatively simple functions, you can usually apply jax.jit without any fuss and get a significant speedup in your code.
# without jitting
%timeit jax.vmap(simulate, in_axes=[None, 0, 0, None])(discrete_dynamics_rk, initial_states, controls, jnp.array(dt)).block_until_ready()
462 ms ± 27 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# method 1: directly apply jax.jit over the jax.vmap function
# mark the first argument as static via static_argnums, since it is a function input and not an array input
sim_jit = jax.jit(jax.vmap(simulate, in_axes=[None, 0, 0, None]), static_argnums=0)
# time the run
%timeit sim_jit(discrete_dynamics_rk, initial_states, controls, jnp.array(dt)).block_until_ready()
13.4 ms ± 279 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)
# method 2: apply jax.jit over the simulate function and then apply jax.vmap
sim_jit = jax.jit(simulate, static_argnums=0)
sim_jit_vmap = jax.vmap(sim_jit, in_axes=[None, 0, 0, None])
%timeit sim_jit_vmap(discrete_dynamics_rk, initial_states, controls, jnp.array(dt)).block_until_ready()
13.5 ms ± 341 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)
# Method 3: apply jax.jit when the function is defined, and then apply jax.vmap
@functools.partial(jax.jit, static_argnames=("dynamics",))
def simulate(dynamics, initial_state, controls, dt):
    xs = [initial_state]
    time = 0
    for u in controls:
        xs.append(dynamics(xs[-1], u, time))
        time += dt
    return jnp.stack(xs)
sim_jit_vmap = jax.vmap(simulate, in_axes=[None, 0, 0, None])
%timeit sim_jit_vmap(discrete_dynamics_rk, initial_states, controls, jnp.array(dt)).block_until_ready()
15.5 ms ± 294 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)
Problem 2#
We continue to consider the dynamically-extended unicycle model and investigate a way to linearize the dynamics around any state. First, we will perform the linearization analytically, and then leverage modern computation tools which will be incredibly helpful especially if the dynamics are complicated! We can efficiently compute gradients via automatic differentiation. JAX is an automatic differentiation library.
(a) Linearize dynamics analytically#
Linearize the dynamics given in Problem 1 part (a) about a point \((\mathbf{x}_0, \mathbf{u}_0)\). That is, for linearized dynamics of the form \(\dot{\mathbf{x}} \approx A\mathbf{x}+ B\mathbf{u} + C\), give expressions for \(A\), \(B\), and \(C\).
\(A=\ldots\)
\(B=\ldots\)
\(C=\ldots\)
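For reference, the Jacobians implemented in linearize_unicycle_continuous_time_analytic below, evaluated at a point with heading \(\theta_0\) and speed \(v_0\), are
\[
A = \begin{bmatrix} 0 & 0 & -v_0\sin\theta_0 & \cos\theta_0 \\ 0 & 0 & v_0\cos\theta_0 & \sin\theta_0 \\ 0 & 0 & 0 & 0 \\ 0 & 0 & 0 & 0 \end{bmatrix}, \qquad
B = \begin{bmatrix} 0 & 0 \\ 0 & 0 \\ 1 & 0 \\ 0 & 1 \end{bmatrix}, \qquad
C = f(\mathbf{x}_0, \mathbf{u}_0) - A\mathbf{x}_0 - B\mathbf{u}_0.
\]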
Also code up your analytic expression in linearize_unicycle_continuous_time_analytic.
def linearize_unicycle_continuous_time_analytic(state, control, time):
    '''
    Linearizes the continuous-time dynamics of the dynamic unicycle using the analytic expression.
    Inputs:
        state   : A jax.numpy array of size (n,)
        control : A jax.numpy array of size (m,)
        time    : A real scalar
    Outputs:
        A : A jax.numpy array of size (n, n)
        B : A jax.numpy array of size (n, m)
        C : A jax.numpy array of size (n,)
    '''
    x, y, theta, v = state
    w, a = control
    A = jnp.array([[0.0, 0.0, -v * jnp.sin(theta), jnp.cos(theta)],
                   [0.0, 0.0, v * jnp.cos(theta), jnp.sin(theta)],
                   [0.0, 0.0, 0.0, 0.0],
                   [0.0, 0.0, 0.0, 0.0]])
    B = jnp.array([[0.0, 0.0],
                   [0.0, 0.0],
                   [1.0, 0.0],
                   [0.0, 1.0]])
    C = dynamic_unicycle_ode(state, control, time) - A @ state - B @ control
    return A, B, C
(b) Evaluate linearized dynamics (analytic)#
Using your answer from 2(a), evaluate \(A\), \(B\), and \(C\) for \(\mathbf{x}_0 = [0, 0, \frac{\pi}{4}, 2.]^T\) and \(\mathbf{u}_0 = [0.1, 1.]^T\). Give your answer to 2 decimal places.
state = jnp.array([0.0, 0.0, jnp.pi/4, 2.])
control = jnp.array([0.1, 1.])
time = 0.
A, B, C = linearize_unicycle_continuous_time_analytic(state, control, time)
print("A:\n", A)
print("B:\n", B)
print("C:\n", C)
A:
[[ 0. 0. -1.4142137 0.70710677]
[ 0. 0. 1.4142135 0.7071068 ]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]]
B:
[[0. 0.]
[0. 0.]
[1. 0.]
[0. 1.]]
C:
[ 1.1107209 -1.1107206 0. 0. ]
(c) Linearize dynamics using JAX autodiff#
Time to test out JAX's autodifferentiation capabilities! JAX has an Autodiff Cookbook that provides more details about the various autodiff functions, forward- vs. reverse-mode autodiff, Jacobians, Hessians, and so forth. You are strongly encouraged to read through it.
Using JAX and its built-in jax.jacobian function, fill in the linearize_autodiff function that takes in a dynamics function and a state and control to linearize about, and returns the \(A\), \(B\), and \(C\) matrices describing the linearized dynamics. Test your function using the continuous-time dynamics with \(\mathbf{x}_0 = [0.0, 0.0, \frac{\pi}{4}, 2.0]^T\) and \(\mathbf{u}_0 = [0.1, 1.0]^T\), and use the provided test code to verify that the outputs you get from your function are the same as the values you get from linearize_unicycle_continuous_time_analytic.
Side note for the curious: the jnp.allclose function tests if all corresponding elements of two arrays are within a small tolerance of each other. When working with finite-precision machine arithmetic, you can almost never test two numbers for exact equality directly, because different rounding errors in different computations very often result in very slightly different values even when the two calculations should theoretically result in the same number. For this reason, real numbers in software (which on almost all modern hardware are represented in IEEE 754 floating-point format) are usually considered to be equal if they are close enough that their difference could be reasonably explained by rounding errors.
def linearize_autodiff(function_name, state, control, time):
    '''
    Linearizes any dynamics function using JAX autodiff.
    Inputs:
        function_name : function to be linearized. Takes state, control, and time as inputs.
        state   : A jax.numpy array of size (n,); the state to linearize about
        control : A jax.numpy array of size (m,); the control to linearize about
        time    : A real scalar; the time to linearize about
    Outputs:
        A : A jax.numpy array of size (n, n)
        B : A jax.numpy array of size (n, m)
        C : A jax.numpy array of size (n,)
    '''
    A, B = jax.jacobian(function_name, argnums=(0, 1))(state, control, time)
    C = function_name(state, control, time) - A @ state - B @ control
    return A, B, C
# test code:
state = jnp.array([0.0, 0.0, jnp.pi/4, 2.])
control = jnp.array([0.1, 1.])
time = 0.
A_autodiff, B_autodiff, C_autodiff = linearize_autodiff(continuous_dynamics, state, control, time)
A_analytic, B_analytic, C_analytic = linearize_unicycle_continuous_time_analytic(state, control, time)
print('A matrices match:', jnp.allclose(A_autodiff, A_analytic))
print('B matrices match:', jnp.allclose(B_autodiff, B_analytic))
print('C matrices match:', jnp.allclose(C_autodiff, C_analytic))
A matrices match: True
B matrices match: True
C matrices match: True
(d) Linearize discrete-time dynamics#
Assuming your answer from 2(c) matched 2(b) and that you are convinced of the power of automatic differentiation, use your linearize_autodiff function on discrete_dynamics_euler and discrete_dynamics_rk with \(\mathbf{x}_0 = [0.0, 0.0, \frac{\pi}{4}, 2.0]^T\) and \(\mathbf{u}_0 = [0.1, 1.0]^T\). (Imagine trying to differentiate these expressions analytically! It would be tedious!)
dt = 0.1
discrete_dynamics_euler = Dynamics(
    euler_integrate(continuous_dynamics, dt), state_dim, control_dim
)
discrete_dynamics_rk = Dynamics(
    runge_kutta_integrator(continuous_dynamics, dt), state_dim, control_dim
)
A_euler_autodiff, B_euler_autodiff, C_euler_autodiff = linearize_autodiff(discrete_dynamics_euler, state, control, time)
A_rk_autodiff, B_rk_autodiff, C_rk_autodiff = linearize_autodiff(discrete_dynamics_rk, state, control, time)
print("Euler integration")
print('A, Euler:')
print(A_euler_autodiff)
print('B, Euler:')
print(B_euler_autodiff)
print('C, Euler:')
print(C_euler_autodiff)
print("\n\n Runge-Kutta integration")
print('A, RK:')
print(A_rk_autodiff)
print('B, RK:')
print(B_rk_autodiff)
print('C, RK:')
print(C_rk_autodiff)
Euler integration
A, Euler:
[[ 1. 0. -0.14142136 0.07071068]
[ 0. 1. 0.14142136 0.07071068]
[ 0. 0. 1. 0. ]
[ 0. 0. 0. 1. ]]
B, Euler:
[[0. 0. ]
[0. 0. ]
[0.1 0. ]
[0. 0.1]]
C, Euler:
Runge-Kutta integration
A, RK:
[[ 1. 0. -0.1456851 0.07035595]
[ 0. 1. 0.14422378 0.07106306]
[ 0. 0. 1. 0. ]
[ 0. 0. 0. 1. ]]
B, RK:
[[-0.00735549 0.00351188]
[ 0.00725768 0.00355902]
[ 0.10000001 0. ]
[ 0. 0.10000001]]
C, RK:
[ 1.1515637e-01 -1.1399886e-01 -1.0244548e-08 -1.0430813e-07]
(e) Applying vmap to linearize over multiple points#
Now, try to linearize your dynamics over multiple state-control values using vmap!
key = jax.random.PRNGKey(42) # Set a fixed seed
n_samples = 1000
state_dim = 4 # 4-dimensional state
ctrl_dim = 2 # 2-dimensional control
time = 0.0
random_states = jax.random.normal(key, shape=(n_samples, state_dim))
random_controls = jax.random.normal(key, shape=(n_samples, ctrl_dim))
As, Bs, Cs = jax.vmap(linearize_autodiff, in_axes=[None, 0, 0, None])(discrete_dynamics_rk, random_states, random_controls, time)
As.shape, Bs.shape, Cs.shape
((1000, 4, 4), (1000, 4, 2), (1000, 4))
Problem 3 Unconstrained optimization#
The goal of this problem is to introduce some basic optimization theory and to show how automatic differentiation can be used to solve unconstrained optimization problems (this is essentially what powers deep learning libraries under the hood).
In this problem, we will implement a simple version of the gradient descent algorithm to solve an unconstrained optimization problem. You will also learn about the log-barrier method, which casts a constrained optimization problem into an unconstrained one that can then be solved using gradient descent.
Some background on optimization#
In mathematics, optimization is the process of finding which value of an argument to a function makes the output of that function as high or as low as possible. In other words, given a function \(f(x)\) (called the cost function or objective function), which value of \(x\) makes \(f(x)\) as high (or as low) as possible?
In controls, we often consider the objective as a cost function (e.g., total fuel or control effort, distance traveled) and thus aim to make \(f(x)\) as low as possible; in other words, we often only consider the problem of minimizing \(f(x)\). In reinforcement learning, by contrast, the objective is typically viewed as a reward, and the aim is to maximize \(f(x)\). But maximizing \(f(x)\) is just the same thing as minimizing \(-f(x)\), so maximization and minimization turn out to be basically the same mathematical problem; the convention simply depends on the community.
Optimization is very useful in engineering; for example, if \(x\) is a variable that somehow represents the design of an airplane (maybe it's a vector of design parameters such as wing aspect ratio, engine thrust, cargo space, etc.), and if we can find a function \(f(x)\) that computes the fuel consumption of the final airplane as a function of the design parameters, then finding the \(x\) which minimizes \(f(x)\) amounts to designing the most fuel-efficient airplane possible under the design assumptions. We often denote the value of \(x\) which optimizes \(f(x)\) as \(x^*\), and the optimal value of \(f(x)\) is then \(f(x^*)\).
You may remember from undergrad calculus that we can minimize \(f(x)\) by taking its derivative \(f'(x)\), setting \(f'(x^*) = 0\), and solving for \(x^*\). If \(f(x)\) is scalar-valued but has a vector-valued input, we would generalize this to taking the gradient \(\nabla f(x)\), setting it to zero, and solving for \(x^*\). However, in many practical applications \(\nabla f(x)\) is far too complicated to solve \(\nabla f(x^*) = 0\) analytically; just think how complicated our fuel consumption function in the example above would be. Thus, in practice we often use iterative numerical optimization algorithms such as gradient descent.
Gradient descent (for a scalar function of a single scalar argument) works as follows. First, assume some initial guess \(x_0\). Then, iteratively improve successive guesses \(x_k\) like so:
\[
x_{k+1} = x_k - \alpha f'(x_k),
\]
where \(\alpha\) is some small positive fixed "step size" parameter which we choose. This has the effect of moving our successive guesses "down the hill" of the function \(f(x)\). This process is repeated until successive iterates are closer together than some specified convergence tolerance \(\epsilon\):
\[
|x_{k+1} - x_k| < \epsilon.
\]
The final \(x\) iterate we take to be our optimal \(x^*\).
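For example (a small worked illustration, not part of the problem statement): with \(f(x) = x^2\), initial guess \(x_0 = 1\), and \(\alpha = 0.1\), one gradient step gives \(x_1 = x_0 - \alpha \cdot 2x_0 = 1 - 0.2 = 0.8\), which is indeed closer to the true minimizer \(x^* = 0\).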
You may notice that your solution, or whether you converge at all, depends on your choice of step size. There are techniques to choose and adapt the step size; we won't discuss them further in this course, but Adam is a popular approach.
(a) Gradient descent on unconstrained optimization problem#
Consider the objective function \(f(x) = (x + 2)^2 + 5\tanh (x)\), defined in code and plotted below. The value of \(x\) which minimizes \(f(x)\) is \(x^* = -2.13578\), and \(f(x^*) = -4.84389\). In the next code cell, fill out the minimize_with_gradient_descent function; you can use JAX's autodiff capabilities to compute \(f'(x)\). The starter code in that cell will use your gradient descent function with an initial guess of \(x = 5\), a step size of \(0.1\), and a convergence tolerance of \(10^{-8}\) to minimize \(f(x)\); verify that your algorithm does indeed find the correct minimizing \(x^*\) and minimal value \(f(x^*)\).
Fun fact: with \(f(x) = (x + 2)^2 + 5\tanh (x)\), then \(f'(x)\) is a transcendental function, and \(f'(x) = 0\) can’t be solved algebraically. If your gradient descent code worked, congratulations! You’ve written code that numerically solved a math problem in milliseconds that can’t be solved analytically at all.
def f(x):
    return (x + 2)**2 + 5 * jnp.tanh(x)
args = np.arange(-6,4,0.01)
plt.figure(figsize=(8,6))
plt.plot(args, f(args))
plt.xlabel('x')
plt.ylabel('f(x)')
plt.title('Objective function')
plt.grid(alpha=0.3)
plt.show()

def minimize_with_gradient_descent(func, initial_guess, step_size, convergence_tol=1e-8):
    '''
    Minimizes a scalar function of a single variable using gradient descent.
    Inputs:
        func            : function to be minimized; takes a scalar x as input
        initial_guess   : a real number
        step_size       : gradient descent step size
        convergence_tol : convergence tolerance; when current and next guesses of the optimal x are
                          closer together than this, the algorithm terminates and returns the current estimate
    Outputs:
        cur_x : current best estimate of the x which minimizes func(x)
    '''
    grad_f = jax.jacobian(func)
    cur_x = initial_guess
    abs_delta_val = None  # no value initially; we haven't started iterating yet
    started = False
    while (not started) or (abs_delta_val > convergence_tol):
        started = True
        next_x = cur_x - step_size * grad_f(cur_x)
        abs_delta_val = np.abs(next_x - cur_x)
        cur_x = next_x
    return cur_x
x_opt = minimize_with_gradient_descent(f, 5.0, 0.1)
# output and plot:
print('optimal x:', x_opt)
print('optimal value of f(x):', f(x_opt))
args = np.arange(-6,4,0.01)
plt.figure(figsize=(8,6))
plt.plot(args, f(args), label='f(x)')
plt.scatter(x_opt, f(x_opt), zorder=2, color='red', label='optimal point')
plt.title('x_opt = {:.4f}, f(x_opt) = {:.4f}'.format(x_opt, f(x_opt)))
plt.grid(alpha=0.3)
plt.xlabel('x')
plt.ylabel('f(x)')
plt.legend()
plt.show()
optimal x: -2.135782
optimal value of f(x): -4.8438854

(b) Applying log-barrier for solving constrained optimization problems#
Now suppose we want to add a constraint to our optimization problem; that is, a restriction on the set of \(x\) values we’ll consider acceptable. In particular, in this case we’ll say that we only want values of \(x\) such that, for some specified constraint function \(g(x)\), we have \(g(x) < 0\); this is called an inequality constraint.
The gradient descent algorithm described above has no way to enforce a constraint like this; \(x\) is allowed to wander wherever the gradient takes it. Thus we must modify the algorithm to allow it to find constrained optima under inequality constraints. There are several ways to do this, but in this problem we will be using the log-barrier method.
In this method, we construct from the objective function \(f(x)\) a different objective function \(\phi (x)\) that has the following properties:
When \(x\) is such that \(g(x)\) is far away from \(0\), then \(\phi(x)\) must approximate \(f(x)\), so that minimizing \(\phi(x)\) approximately minimizes \(f(x)\).
\(\phi(x)\) must grow to infinity as \(g(x)\) approaches \(0\). This will prevent \(x\) from crossing over the boundary of \(g(x) < 0\).
To accomplish these goals, we construct \(\phi(x)\) like so:
\[
\phi(x) = f(x) - t\ln(-g(x)),
\]
where \(t\) is a weighting parameter that we choose.
Suppose we want to minimize the same objective function as before, \(f(x) = (x + 2)^2 + 5\tanh (x)\), but now we want a constraint \(x > 1\). In the code block below, fill out the functions g and phi. The starter code will plot \(f(x)\), as well as \(\phi(x)\) for \(t = 0.5,\ 2,\ 5\). Comment on how \(\phi(x)\) changes with changing \(t\). Note that since \(\ln(y)\) is not defined (in the real numbers) for \(y \le 0\), the domain of \(\phi(x)\) is restricted to \(x > 1\).
Hint: for your phi function to work well with the next parts of this problem, be sure to use jnp.log instead of np.log.
# fill out g(x) so that the statement g(x) < 0 is equivalent to the statement x > 1
def g(x):
    return 1 - x

def phi(f, x, g, t):
    '''
    Computes phi(x).
    Inputs:
        f : name of f(x) function; takes x as input
        x : variable to be optimized
        g : constraint function; we want to ensure g(x) <= 0
        t : log-barrier weighting parameter
    Outputs:
        phi(x)
    '''
    return f(x) - t * jnp.log(-g(x))
x_upper = 4
dx = 0.01
f_x_domain = np.arange(-6, x_upper, dx)
phi_x_domain = np.arange(1.00001, x_upper, dx)
plt.figure(figsize=(8,6))
plt.plot(f_x_domain, f(f_x_domain), label='f(x)')
plt.plot(phi_x_domain, phi(f, phi_x_domain, g, 5), label='phi(x), t = 5')
plt.plot(phi_x_domain, phi(f, phi_x_domain, g, 2), label='phi(x), t = 2')
plt.plot(phi_x_domain, phi(f, phi_x_domain, g, 0.5), label='phi(x), t = 0.5')
plt.vlines(1, -10, 40, linestyles='dashed', label='x = 1', color='black')
plt.xlabel('x')
plt.grid(alpha=0.3)
plt.ylabel('f(x), phi(x)')
plt.title('f(x) and phi(x) vs x')
plt.legend(loc='upper left')
# plt.ylim(-10, 40)
plt.show()

(c)#
Now, we can use our minimize_with_gradient_descent function to minimize \(\phi(x)\) instead of \(f(x)\). Do so below, for a few different values of \(t\), and comment on the results. The true constrained optimal point is \(x^* = 1\), \(f(x^*) = 12.808\).
Does the log-barrier method find the true constrained optimal point?
How does the point found by the log-barrier method change as you change \(t\)?
You may find that as you decrease \(t\), you also need to decrease the gradient descent step size to get good behavior (in particular, to avoid “overshooting” the constraint and finding an \(x^*\) estimate which does not respect the constraint). Also, for the log-barrier method to work, your initial guess must itself respect the constraint (in this case, your initial guess must be greater than one).
Hint: the minimize_with_gradient_descent function requires an objective function that takes only \(x\) as an argument: \(f(x)\). But our phi Python function defined above takes f, x, g, and t; you may find Python's "lambda function" capabilities or functools.partial useful.
# ========================== hint: lambdas ==============================
def hint_func(arg1, arg2, arg3):
    return arg1 + 2 * arg2 + arg3

def hint_func_caller(func, arg):
    return func(arg)

foo = 42
bar = 100
lambda x: hint_func(foo, x, bar) # this essentially defines a function of x only, that calls hint_func(42, x, 100)
# hint_func_caller expects a function that takes only one argument. But we can use a lambda expression to
# "prepopulate" all but one argument of hint_func, "turning it into" an argument of one function:
hint_func_caller(lambda x: hint_func(foo, x, bar), 5) # this will work and give 152
# ^^^^^^^^^^^^^^^^^^^^^^^^^^ hint: lambdas ^^^^^^^^^^^^^^^^^^^^^^^^^^
# ========================== hint: functools.partial ==============================
new_func = functools.partial(hint_func, foo, arg3=bar) # this is equivalent to lambda x: hint_func(foo, x, bar)
new_func(5) # this will give 152
# OR
new_func = functools.partial(hint_func, arg1=foo, arg3=bar) # this is equivalent to lambda x: hint_func(foo, x, bar)
new_func(arg2=5) # this will give 152
# ^^^^^^^^^^^^^^^^^^^^^^^^^^ hint: functools.partial ^^^^^^^^^^^^^^^^^^^^^^^^^^
t = 0.5
# the two calls below are equivalent; one uses a lambda, the other functools.partial
x_opt = minimize_with_gradient_descent(lambda x: phi(f, x, g, t), 5.0, 0.01)
x_opt = minimize_with_gradient_descent(functools.partial(phi, f, g=g, t=t), 5.0, 0.01)
# output and plot:
print('optimal x:', x_opt)
print('optimal value of f(x):', f(x_opt))
args = np.arange(-6,4,0.01)
plt.figure(figsize=(8,6))
plt.plot(args, f(args), label='f(x)')
plt.scatter(x_opt, f(x_opt), zorder=2, color='red', label='optimal point')
plt.plot(phi_x_domain, phi(f, phi_x_domain, g, t), label='phi(x), t = {:.1f}'.format(t))
plt.vlines(1, -10, 40, linestyles='dashed', label='x = 1', color='black')
plt.title('x_opt = {:.4f}, f(x_opt) = {:.4f}'.format(x_opt, f(x_opt)))
plt.xlabel('x')
plt.ylabel('f(x)')
plt.legend()
plt.grid(alpha=0.3)
plt.show()
optimal x: 1.0622556
optimal value of f(x): 13.310038

Problem 4 (cvxpy)#
In this problem we will explore the basics of cvxpy, a Python package for solving convex optimization problems. cvxpy has a good tutorial here, so read that page before proceeding with this problem (the section on "parameters", while useful, is not important for this homework, so consider that section optional for now).
Consider a vector variable \(x = [x_1, x_2, x_3]^T\). Use cvxpy to compute the minimizer of the following objective function
\[
x_1^2 + 2x_2^2 + 3.5x_3^2,
\]
subject to the constraint \(Mx \le b\) (elementwise), where \(M\) and \(b\) are given in the code below.
Print the optimal \(x\) and the optimal value of the objective function.
Notice how much easier it is to use cvxpy than to write our own optimization algorithm from scratch!
x = cp.Variable(3)
objective = cp.Minimize(x[0]**2 + 2*x[1]**2 + 3.5*x[2]**2)
M = np.array([[-0.707, -0.707, 0],
              [-1, 0, 0],
              [0, -1, 0],
              [0, 0, -1]])
b = [-2, -1, -1, -3]
constraints = [M @ x <= b]
problem = cp.Problem(objective, constraints)
problem.solve(solver='CLARABEL')
print('x =',x.value)
print('objective function value =', problem.value)
x = [1.82885425 1.00000006 3. ]
objective function value = 36.844708132970986
Problem 5 (Control Barrier Function and CBF-QP)#
In this problem, using the concepts and code you have developed in the previous problems, you will implement a CBF-QP safety filter. We will restrict ourselves to a simple 1D system so that you can verify your results analytically if needed, but note that the theory extends to higher dimensional problems, and try to keep your code general so that it could still work on a different system and choice of CBF.
Consider the following 1D single integrator dynamics \( \dot{x} = u \) and let \(b(x) = x^2 - 1\).
(a) Computing the Lie derivative#
What is the expression for \(\nabla b(x)^T f(x,u)\)?
\(\nabla b(x) = \ldots\)
\(f(x,u) = \ldots\)
\(\nabla b(x)^Tf(x,u) = \ldots\)
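For reference, these quantities follow directly from the single integrator dynamics and \(b(x)\) above, and they match the constraint coded in part (b):
\[
\nabla b(x) = 2x, \qquad f(x,u) = u, \qquad \nabla b(x)^T f(x,u) = 2xu.
\]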
(b) Solving the CBF-QP#
Suppose that your desired control is \(u_\mathrm{des}=0.5\), which is to move in the positive \(x\)-direction at constant velocity. But the safety filter enforces that \(x^2 \geq 1\), i.e., the system must stay at least 1 unit away from the origin.
The CBF safety filter essentially chooses a control input that is as close to \(u_\mathrm{des}\) as possible while satisfying the CBF inequality constraint.
This is referred to as the CBF-QP since it is a quadratic program (quadratic objective and linear constraints). For simplicity, we assume there are no other constraints on controls.
Let \(\alpha(z) = az\) with \(a=0.5\).
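In other words, with \(\alpha(z) = az\), the safety filter solves the following quadratic program at each state (this is exactly the problem set up with cvxpy in the code below):
\[
u_\mathrm{safe} = \arg\min_{u} \ (u - u_\mathrm{des})^2 \quad \text{subject to} \quad 2xu + a\,(x^2 - 1) \ge 0.
\]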
Use cvxpy to solve the CBF-QP for \(x=-3\), \(x=-2\), and \(x=-1.1\). Report the corresponding safe control values.
Hint: You may find the Parameter variable in cvxpy helpful. It allows you to update certain parameter values without needing to reconstruct the optimization problem each time.
u = cp.Variable(1)
x = cp.Parameter(1)
a = cp.Parameter(1)
u_des = cp.Parameter(1)
objective = cp.Minimize((u - u_des)**2)
constraints = [2 * x @ u + a * (x @ x - 1) >= 0]
problem = cp.Problem(objective, constraints)
a.value = np.array([0.5])
u_des.value = np.array([0.5])
xs = np.array([-3., -2., -1.1])
for x_val in xs:
    x.value = np.array([x_val])
    problem.solve()
    print("x = {:.2f}, u_safe = {:.2f}".format(x_val, u.value[0]))
x = -3.00, u_safe = 0.50
x = -2.00, u_safe = 0.38
x = -1.10, u_safe = 0.05
/Users/karen/venvs/aa548/lib/python3.13/site-packages/cvxpy/reductions/solvers/solving_chain.py:245: UserWarning: You are solving a parameterized problem that is not DPP. Because the problem is not DPP, subsequent solves will not be faster than the first one. For more information, see the documentation on Disciplined Parametrized Programming, at https://www.cvxpy.org/tutorial/dpp/index.html
warnings.warn(DPP_ERROR_MSG)
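The warning above appears because the constraint contains products of Parameters (the term a * (x @ x - 1)), which is not DPP-compliant. An optional workaround, sketched below, is to precompute the constraint coefficient \(2x\) and the constant \(a(x^2 - 1)\) in plain NumPy and pass them in as Parameters, so the problem is affine in its Parameters and repeated solves can be faster; the variable and parameter names here are illustrative and not part of the assignment.
# Optional sketch: a DPP-compliant version of the same CBF-QP.
u_var = cp.Variable()
grad_b_coeff = cp.Parameter()   # will hold 2 * x (the coefficient of u)
alpha_b_const = cp.Parameter()  # will hold a * (x**2 - 1)
u_des_param = cp.Parameter()
dpp_problem = cp.Problem(
    cp.Minimize(cp.square(u_var - u_des_param)),
    [grad_b_coeff * u_var + alpha_b_const >= 0],
)
u_des_param.value = 0.5
a_val = 0.5
for x_val in [-3.0, -2.0, -1.1]:
    grad_b_coeff.value = 2 * x_val
    alpha_b_const.value = a_val * (x_val**2 - 1)
    dpp_problem.solve()
    print("x = {:.2f}, u_safe = {:.2f}".format(x_val, float(u_var.value)))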
(c) Applying the CBF safety filter#
Now, simulate the system starting from \(x=-5\) and with \(u_\mathrm{des}=0.5\), but the CBF safety filter is applied. Plot the state and control sequence for \(a=2\), \(a=1\), \(a=0.5\) and \(a=0.1\) where \(\alpha(z) = az\). Comment on how the trajectory changes as \(a\) changes. What is your interpretation of \(a\)?
Use \(\Delta t = 0.05\) and simulate for 500 steps.
Note that the CBF theory is for continuous-time dynamics, but when we simulate, we use discrete-time dynamics. There are some practical issues that we need to be careful about (see this paper for more details), but right now, let’s stick to using a reasonably small time step.
# define continuous-time dynamics
# construct discrete-time dynamics
# simulate system, during which the CBF-QP safety filter is applied.
# loop over various values of a.
def single_integrator_ode(state, control, time):
    return control
state_dim = 1
control_dim = 1
continuous_dynamics = Dynamics(single_integrator_ode, state_dim, control_dim)
dt = 0.05
discrete_dynamics_rk = Dynamics(
    runge_kutta_integrator(continuous_dynamics, dt), state_dim, control_dim
)
n_steps = 500
x0 = -5.0
xs = []
us = []
as_val = [2., 1., 0.5, 0.1]
for (i, a_val) in enumerate(as_val):
    x_val = x0
    a.value = np.array([a_val])
    xs.append([x_val])
    us.append([])
    for k in range(n_steps):
        x.value = np.array([x_val])
        problem.solve()
        u_val = u.value[0]
        x_val = discrete_dynamics_rk(x_val, u_val, 0)
        xs[i].append(x_val)
        us[i].append(u_val)
xs = jnp.array(xs)
us = jnp.array(us)
plt.figure(figsize=(10,4))
plt.subplot(1,2,1)
plt.title('Trajectories')
plt.xlabel('Time step')
plt.ylabel('x')
plt.plot(xs.T)
plt.grid(alpha=0.3)
plt.legend(['a = {}'.format(a_val) for a_val in as_val])
plt.subplot(1,2,2)
plt.title('Control inputs')
plt.xlabel('Time step')
plt.ylabel('u')
plt.plot(us.T)
plt.grid(alpha=0.3)
plt.legend(['a = {}'.format(a_val) for a_val in as_val])
plt.tight_layout()
