17. Numba: JIT for Speed!
Numba is one of the most exciting things to happen to Python. It is a library that takes a Python function, converts its bytecode to LLVM intermediate representation, compiles it, and runs it at full machine speed!
import numba
import numpy as np
import matplotlib.pyplot as plt
17.1. First example
def f1(a, b):
    return 2 * a**3 + 3 * b**0.5

@numba.vectorize
def f2(a, b):
    return 2 * a**3 + 3 * b**0.5
a, b = np.random.random_sample(size=(2, 100_000))
%%timeit
c = f1(a, b)
1.61 ms ± 5.26 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
%%time
c = f2(a, b)
CPU times: user 218 ms, sys: 11.9 ms, total: 230 ms
Wall time: 229 ms
This took much longer than the plain NumPy version. The very first time you call a JIT-compiled function, Numba compiles it for the argument types you passed, and that compilation time is included in the measurement above. Numba compiles quickly, but you still pay this one-time cost. You can control when it happens, but you can’t avoid it entirely.
%%timeit
c = f2(a, b)
73.2 µs ± 65.4 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
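It took the function we defined, pulled it apart, turned it into LLVM intermediate representation (LLVM originally stood for Low Level Virtual Machine), and compiled that to machine code. Once compiled, it runs about 20 times faster than the plain NumPy version here (73.2 µs versus 1.61 ms), partly because it evaluates the whole expression in a single loop instead of creating a temporary array for each operation. There are no special strings or special syntax; it is just a (large) subset of Python and NumPy, and users and libraries can extend it too. It also supports: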
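Vectorized (@vectorize), generalized vectorized (@guvectorize), or regular (@jit) functions
Ahead-of-time compilation, JIT, or dynamic JIT
Parallelized targets
GPU targets via CUDA (older releases also had an AMD ROCm target)
Nesting (calling one compiled function from another)
Creating C callbacks (@cfunc)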
For numerical code it is usually about as fast as other compiled solutions (not counting the JIT time). A couple of years ago it also became much easier to install: pip can install it directly, thanks to llvmlite’s lightweight, independent LLVM build.
17.2. JIT example
The example above used @numba.vectorize to make “ufunc”-like functions. These accept arrays of any broadcast-compatible shapes and produce an output array. This is similar to numpy.vectorize, except that numpy.vectorize just loops in Python. Next, let’s control the looping ourselves, using an ODE solver.
17.2.1. Problem setup
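Let’s set up an ODE to solve: simple harmonic motion of a mass \(m\) on a spring with constant \(k\),
\[ \ddot{x} = -\frac{k}{m} x \]
We can write this second-order ODE as a system of linear first-order ODEs:
\[ \dot{x} = v, \qquad \dot{v} = -\frac{k}{m} x \]
The harmonic motion equation can then be written in terms of \(\mathbf{f}(t, \mathbf{y}) = \dot{\mathbf{y}}\), which is the standard form:
\[ \mathbf{y} = \begin{pmatrix} v \\ x \end{pmatrix}, \qquad \mathbf{f}(t, \mathbf{y}) = \dot{\mathbf{y}} = \begin{pmatrix} \dot{v} \\ \dot{x} \end{pmatrix} = \begin{pmatrix} -\frac{k}{m} x \\ v \end{pmatrix} \]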
x_max = 1  # amplitude: the maximum displacement x
v_0 = 0  # initial velocity
koverm = 1  # k / m
def f(t, y):
    "y has two elements, v and x (in that order)"
    return np.array([-koverm * y[1], y[0]])
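17.2.2. Runge-Kutta introduction
Integrating \(\dot{y} = f(t, y)\) over one step gives the exact update:
\[ y_{n+1} = y_n + \int_{t_n}^{t_{n+1}} f(t, y)\, dt \]
Note that \(h = t_{n+1} - t_n\).
Now, expand \(f\) in a Taylor series around the midpoint of the interval, \(t_{n+\frac{1}{2}} = t_n + \frac{h}{2}\):
\[ f(t, y) = f(t_{n+\frac{1}{2}}, y_{n+\frac{1}{2}}) + \left(t - t_{n+\frac{1}{2}}\right) \frac{df}{dt}\bigg|_{t_{n+\frac{1}{2}}} + \mathcal{O}(h^2) \]
The second term is odd about the midpoint, so it integrates to zero over the interval; all we have left is the first term in the integral:
\[ \int_{t_n}^{t_{n+1}} f(t, y)\, dt = h\, f(t_{n+\frac{1}{2}}, y_{n+\frac{1}{2}}) + \mathcal{O}(h^3) \]
Back into the original statement, we get:
\[ y_{n+1} = y_n + h\, f(t_{n+\frac{1}{2}}, y_{n+\frac{1}{2}}) + \mathcal{O}(h^3) \]
We’ve got one more problem! How do we calculate \(f(t_{n+\frac{1}{2}},y_{n+\frac{1}{2}})\)? We can use Euler’s algorithm, which we saw last time, to estimate the midpoint value:
\[ y_{n+\frac{1}{2}} \approx y_n + \frac{h}{2} f(t_n, y_n) \]
Putting it together, this is our RK2 algorithm:
\[ \begin{aligned} \mathbf{k}_1 &= h\,\mathbf{f}(t_n, \mathbf{y}_n) \\ \mathbf{k}_2 &= h\,\mathbf{f}\left(t_n + \tfrac{h}{2},\, \mathbf{y}_n + \tfrac{\mathbf{k}_1}{2}\right) \\ \mathbf{y}_{n+1} &= \mathbf{y}_n + \mathbf{k}_2 \end{aligned} \]
We’ve switched to bold face to indicate that \(\mathbf{y}\) and \(\mathbf{f}\) can be vectors, that is, a system of ODEs.
We can get the RK4 algorithm by keeping another non-zero term in the Taylor series:
\[ \begin{aligned} \mathbf{y}_{n+1} &= \mathbf{y}_n + \tfrac{1}{6}\left(\mathbf{k}_1 + 2\mathbf{k}_2 + 2\mathbf{k}_3 + \mathbf{k}_4\right) && \text{(2.0)} \\ \mathbf{k}_1 &= h\,\mathbf{f}(t_n, \mathbf{y}_n) && \text{(2.1)} \\ \mathbf{k}_2 &= h\,\mathbf{f}\left(t_n + \tfrac{h}{2},\, \mathbf{y}_n + \tfrac{\mathbf{k}_1}{2}\right) && \text{(2.2)} \\ \mathbf{k}_3 &= h\,\mathbf{f}\left(t_n + \tfrac{h}{2},\, \mathbf{y}_n + \tfrac{\mathbf{k}_2}{2}\right) && \text{(2.3)} \\ \mathbf{k}_4 &= h\,\mathbf{f}(t_n + h,\, \mathbf{y}_n + \mathbf{k}_3) && \text{(2.4)} \end{aligned} \]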
def rk4_ivp(f, init_y, t):
    "Integrate y' = f(t, y) over the times in t with classic RK4; returns shape (len(t), len(init_y))."
    steps = len(t)
    order = len(init_y)
    y = np.empty((steps, order))
    y[0] = init_y
    for n in range(steps - 1):
        h = t[n + 1] - t[n]
        k1 = h * f(t[n], y[n])  # eq. (2.1)
        k2 = h * f(t[n] + h / 2, y[n] + k1 / 2)  # eq. (2.2)
        k3 = h * f(t[n] + h / 2, y[n] + k2 / 2)  # eq. (2.3)
        k4 = h * f(t[n] + h, y[n] + k3)  # eq. (2.4)
        y[n + 1] = y[n] + 1 / 6 * (k1 + 2 * k2 + 2 * k3 + k4)  # eq. (2.0)
    return y
Let’s plot this:
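ts = np.linspace(0, 40, 100 + 1)
# f uses y = [v, x], so the initial state is [v_0, x_max] and x(t) is column 1
y = rk4_ivp(f, [v_0, x_max], ts)
plt.plot(ts, np.cos(ts))  # exact solution for k/m = 1, x(0) = 1, v(0) = 0
plt.plot(ts, y[:, 1], "--");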
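To put a number on the agreement in the plot, here is a small self-contained check. It assumes the setup above (k/m = 1, y = [v, x], x(0) = 1, v(0) = 0), repeats the RK4 loop so the snippet runs on its own, and measures both the largest deviation from cos(t) and the drift in v² + (k/m)x², which is proportional to the oscillator’s energy and should stay constant:

```python
import numpy as np

koverm = 1  # k / m, as in the text
x_max, v_0 = 1, 0


def f(t, y):
    "y = [v, x]"
    return np.array([-koverm * y[1], y[0]])


def rk4_ivp(f, init_y, t):
    # Same classic RK4 loop as in the text
    steps = len(t)
    y = np.empty((steps, len(init_y)))
    y[0] = init_y
    for n in range(steps - 1):
        h = t[n + 1] - t[n]
        k1 = h * f(t[n], y[n])
        k2 = h * f(t[n] + h / 2, y[n] + k1 / 2)
        k3 = h * f(t[n] + h / 2, y[n] + k2 / 2)
        k4 = h * f(t[n] + h, y[n] + k3)
        y[n + 1] = y[n] + (k1 + 2 * k2 + 2 * k3 + k4) / 6
    return y


ts = np.linspace(0, 40, 100 + 1)
y = rk4_ivp(f, [v_0, x_max], ts)

# Largest deviation of x(t) from the exact solution cos(t) (valid for k/m = 1)
max_error = np.max(np.abs(y[:, 1] - np.cos(ts)))

# v**2 + (k/m) * x**2 is proportional to the energy; it starts at 1
energy = y[:, 0] ** 2 + koverm * y[:, 1] ** 2
energy_drift = energy[-1] - energy[0]
print(max_error, energy_drift)
```

For this oscillator, classic RK4 is very slightly dissipative, so the energy drift comes out small and negative rather than exactly zero.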
%%timeit
ts = np.linspace(0, 40, 1000 + 1)
y = rk4_ivp(f, [v_0, x_max], ts)  # initial state y = [v, x]
24.8 ms ± 2.43 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
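For comparison, the RK2 (midpoint) algorithm from the derivation needs only two evaluations of f per step. Here is a sketch with the same interface as rk4_ivp; rk2_ivp is a hypothetical name, and the setup (k/m = 1, y = [v, x]) is assumed to match the text. Halving the step size should cut the error by roughly a factor of 4, as expected for a second-order method:

```python
import numpy as np

koverm = 1  # k / m, as in the text
x_max, v_0 = 1, 0


def f(t, y):
    "y = [v, x]"
    return np.array([-koverm * y[1], y[0]])


def rk2_ivp(f, init_y, t):
    "Midpoint (RK2) method: k1 is an Euler step, k2 uses the slope at the Euler midpoint."
    steps = len(t)
    y = np.empty((steps, len(init_y)))
    y[0] = init_y
    for n in range(steps - 1):
        h = t[n + 1] - t[n]
        k1 = h * f(t[n], y[n])  # Euler estimate over the whole step
        k2 = h * f(t[n] + h / 2, y[n] + k1 / 2)  # slope at the midpoint estimate
        y[n + 1] = y[n] + k2
    return y


# Compare against the exact x(t) = cos(t) for two step sizes (h = 0.02 and 0.01)
errors = []
for n_steps in (2000, 4000):
    ts = np.linspace(0, 40, n_steps + 1)
    y = rk2_ivp(f, [v_0, x_max], ts)
    errors.append(np.max(np.abs(y[:, 1] - np.cos(ts))))
print(errors, errors[0] / errors[1])
```

RK4 gets higher accuracy per step at the price of four evaluations of f, which is usually the better trade for smooth problems like this one.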
17.2.3. Adding Numba
Normally, you’d use a decorator here, but I’m lazy and don’t want to rewrite the functions, so I’ll apply the decorator manually; as we covered earlier, the @ syntax is just shorthand for calling the decorator on the function.
f_jit = numba.njit(f)
rk4_ivp_jit = numba.njit(rk4_ivp)
%%timeit
ts = np.linspace(0, 40, 1000 + 1)
y = rk4_ivp_jit(f_jit, np.array([v_0, x_max]), ts)  # initial state y = [v, x]
349 µs ± 47.9 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
After a function has run once, you can inspect the types Numba inferred for it, which is handy if you’d like to add explicit type signatures:
f_jit.inspect_types()
f (float64, Array(float64, 1, 'C', False, aligned=True))
--------------------------------------------------------------------------------
# File: /tmp/ipykernel_3040/396120679.py
# --- LINE 6 ---
# label 0
# t = arg(0, name=t) :: float64
# del t
# y = arg(1, name=y) :: array(float64, 1d, C)
def f(t, y):
# --- LINE 7 ---
"Y has two elements, x and v"
# --- LINE 8 ---
# $2load_global.0 = global(np: <module 'numpy' from '/usr/share/miniconda3/envs/level-up-your-python/lib/python3.10/site-packages/numpy/__init__.py'>) :: Module(<module 'numpy' from '/usr/share/miniconda3/envs/level-up-your-python/lib/python3.10/site-packages/numpy/__init__.py'>)
# $4load_method.1 = getattr(value=$2load_global.0, attr=array) :: Function(<built-in function array>)
# del $4load_method.1
# del $2load_global.0
# $6load_global.2 = global(koverm: 1) :: Literal[int](1)
# $8unary_negative.3 = unary(fn=<built-in function neg>, value=$6load_global.2) :: int64
# del $6load_global.2
# $const12.5 = const(int, 1) :: Literal[int](1)
# $14binary_subscr.6 = static_getitem(value=y, index=1, index_var=$const12.5, fn=<built-in function getitem>) :: float64
# del $const12.5
# $16binary_multiply.7 = $8unary_negative.3 * $14binary_subscr.6 :: float64
# del $8unary_negative.3
# del $14binary_subscr.6
# $const20.9 = const(int, 0) :: Literal[int](0)
# $22binary_subscr.10 = static_getitem(value=y, index=0, index_var=$const20.9, fn=<built-in function getitem>) :: float64
# del y
# del $const20.9
# size = const(int, 2) :: int64
# size_tuple = build_tuple(items=[Var(size, 396120679.py:8)]) :: UniTuple(int64 x 1)
# del size_tuple
# empty_func = global(empty: <built-in function empty>) :: Function(<built-in function empty>)
# $np_g_var = global(np: <module 'numpy' from '/usr/share/miniconda3/envs/level-up-your-python/lib/python3.10/site-packages/numpy/__init__.py'>) :: Module(<module 'numpy' from '/usr/share/miniconda3/envs/level-up-your-python/lib/python3.10/site-packages/numpy/__init__.py'>)
# $np_typ_var = getattr(value=$np_g_var, attr=float64) :: dtype(float64)
# del $np_g_var
# $26call_method.12 = call empty_func(size, $np_typ_var, func=empty_func, args=[Var(size, 396120679.py:8), Var($np_typ_var, 396120679.py:8)], kws={}, vararg=None, varkwarg=None, target=None) :: (int64, dtype(float64)) -> array(float64, 1d, C)
# del size
# del empty_func
# del $np_typ_var
# index = const(int, 0) :: int64
# $26call_method.12[index] = $16binary_multiply.7 :: (Array(float64, 1, 'C', False, aligned=True), int64, float64) -> none
# del index
# del $16binary_multiply.7
# index.1 = const(int, 1) :: int64
# $26call_method.12[index.1] = $22binary_subscr.10 :: (Array(float64, 1, 'C', False, aligned=True), int64, float64) -> none
# del index.1
# del $22binary_subscr.10
# $28return_value.13 = cast(value=$26call_method.12) :: array(float64, 1d, C)
# del $26call_method.12
# return $28return_value.13
return np.array([-koverm * y[1], y[0]])
================================================================================