17. Numba: JIT for Speed!#

Numba is one of the most exciting things to happen to Python. It is a library that takes a Python function, converts its bytecode to LLVM, compiles it, and runs it at full machine speed!

import numba
import numpy as np
import matplotlib.pyplot as plt

17.1. First example#

def f1(a, b):
    return 2 * a**3 + 3 * b**0.5


@numba.vectorize
def f2(a, b):
    return 2 * a**3 + 3 * b**0.5
a, b = np.random.random_sample(size=(2, 100_000))
%%timeit
c = f1(a, b)
1.61 ms ± 5.26 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
%%time
c = f2(a, b)
CPU times: user 218 ms, sys: 11.9 ms, total: 230 ms
Wall time: 229 ms

This probably took a bit longer. The very first time you JIT compile something, you pay the cost of the compilation itself. Numba compiles quickly, but it is not free. There are things you can do to control when this happens (an example follows below), but the cost has to be paid somewhere.

%%timeit
c = f2(a, b)
73.2 µs ± 65.4 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
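For example (a sketch, assuming the float64 arrays we are using here; f2_eager is just an illustrative name), giving @numba.vectorize an explicit signature makes it compile eagerly when the decorator runs, so the first call does not pay the JIT cost. For @numba.njit functions, cache=True similarly saves the compiled result to disk between sessions.

@numba.vectorize(["float64(float64, float64)"])
def f2_eager(a, b):
    # Compiled for float64 inputs at decoration time, not at first call
    return 2 * a**3 + 3 * b**0.5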

It took the function we defined, pulled it apart, turned it into Low Level Virtual Machine (LLVM) code, and compiled it. No special strings or special syntax; it is just a (large) subset of Python and NumPy, and users and libraries can extend it too. It also supports:

  • Vectorized, general vectorized, or regular functions

  • Ahead of time compilation, JIT, or dynamic JIT

  • Parallelized targets (see the sketch after this list)

  • GPU targets via CUDA or ROCm

  • Nesting

  • Creating cfunction callbacks
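As a quick taste of the parallel target, here is a minimal sketch (parallel_sum is just an illustrative name): numba.njit(parallel=True) together with numba.prange spreads a loop over threads, with the scalar accumulation handled as a reduction.

@numba.njit(parallel=True)
def parallel_sum(x):
    # prange tells Numba this loop is safe to split across threads
    total = 0.0
    for i in numba.prange(x.shape[0]):
        total += x[i] ** 2
    return total


parallel_sum(np.random.random_sample(1_000_000))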

It is almost always as fast as or faster than any other compiled solution (minus the JIT time). A couple of years ago it also became much easier to install (via pip, thanks to llvmlite's lightweight and independent LLVM build).

17.2. JIT example#

The example above used @numba.vectorize to make ufunc-like functions. These can take any (broadcastable) size of array(s) and produce an output array. It's similar to np.vectorize, which just loops in Python; a quick comparison sketch follows. After that, let's try controlling the looping ourselves, using an ODE solver.
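Just to see the difference (a small sketch; f1_loop and c_slow are illustrative names), wrapping the plain function in np.vectorize gives the same broadcasting convenience but keeps the element-by-element Python loop, so it is much slower than the Numba version:

f1_loop = np.vectorize(f1)  # broadcasts like a ufunc, but calls f1 once per element in Python
c_slow = f1_loop(a, b)  # same values as f2(a, b), without the speed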

17.2.1. Problem setup#

Let’s set up an ODE to solve. We can write our ODE as a system of linear first-order ODEs:

The harmonic motion equation can be written in terms of \(\mathbf{f}(t, \mathbf{y}) = \dot{\mathbf{y}}\), where this is in the standard form:

\[\begin{split} \mathbf{y} = \left( \begin{matrix} \dot{x} \\ x \end{matrix} \right) \end{split}\]
\[\begin{split} \mathbf{f}(t, \mathbf{y}) = \dot{\mathbf{y}} = \left( \begin{matrix} \ddot{x} \\ \dot{x} \end{matrix} \right) = \left( \begin{matrix} -\frac{k}{m} x \\ \dot{x} \end{matrix} \right) = \left( \begin{matrix} -\frac{k}{m} y_1 \\ y_0 \end{matrix} \right) \end{split}\]
x_max = 1  # Size of x max
v_0 = 0
koverm = 1  # k / m


def f(t, y):
    "Y has two elements, x and v"
    return np.array([-koverm * y[1], y[0]])

17.2.2. Runge-Kutta introduction#

Note that \(h = t_{n+1} - t_n \).

\[ \dot{y} = f(t,y) \]
\[ \implies y = \int f(t,y) \, dt \]
\[ \implies y_{n+1} = y_{n} + \int_{t_n}^{t_{n+1}} f(t,y) \, dt \]

Now, expand \(f\) in a Taylor series around the midpoint of the interval:

\[ f(t,y) \approx f(t_{n+\frac{1}{2}},y_{n+\frac{1}{2}}) + \left( t - t_{n+\frac{1}{2}}\right) \dot{f}(t_{n+\frac{1}{2}}) + \mathcal{O}(h^2) \]

The second term is antisymmetric about the midpoint of the interval, so it integrates to zero; all we have left is the first term in the integral:
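Explicitly, \(\dot{f}(t_{n+\frac{1}{2}})\) is a constant as far as the integral over \(t\) is concerned, and the factor multiplying it vanishes:

\[ \int_{t_n}^{t_{n+1}} \left( t - t_{n+\frac{1}{2}} \right) dt = 0 \]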

\[ \int_{t_n}^{t_{n+1}} f(t,y) \, dt \approx h\, f(t_{n+\frac{1}{2}},y_{n+\frac{1}{2}}) + \mathcal{O}(h^3) \]

Back into the original statement, we get:

\[ y_{n+1} \approx \color{blue}{ y_{n} + h\, f(t_{n+\frac{1}{2}},y_{n+\frac{1}{2}}) } + \mathcal{O}(h^3) \tag{rk2} \]

We’ve got one more problem! How do we calculate \(f(t_{n+\frac{1}{2}},y_{n+\frac{1}{2}})\)? We can use Euler’s algorithm that we saw last time:

\[ y_{n+\frac{1}{2}} \approx y_n + \frac{1}{2} h \dot{y} = \color{red}{ y_n + \frac{1}{2} h f(t_{n},y_{n}) } \]

Putting it together, this is our RK2 algorithm:

\[ \mathbf{y}_{n+1} \approx \color{blue}{ \mathbf{y}_{n} + \mathbf{k}_2 } \tag{1.0} \]
\[ \mathbf{k}_1 = h \mathbf{f}(t_n,\, \mathbf{y}_n) \tag{1.1} \]
\[ \mathbf{k}_2 = h \mathbf{f}(t_n + \frac{h}{2},\, \color{red}{\mathbf{y}_n + \frac{\mathbf{k}_1}{2}}) \tag{1.2} \]

We’ve switched to boldface to indicate that we can have a vector of ODEs.
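As a sketch of how this looks in code (rk2_ivp is just an illustrative name, written in the same style as the rk4_ivp function below), RK2 translates almost line for line:

def rk2_ivp(f, init_y, t):
    steps = len(t)
    order = len(init_y)

    y = np.empty((steps, order))
    y[0] = init_y

    for n in range(steps - 1):
        h = t[n + 1] - t[n]

        k1 = h * f(t[n], y[n])  # 1.1
        k2 = h * f(t[n] + h / 2, y[n] + k1 / 2)  # 1.2

        y[n + 1] = y[n] + k2  # 1.0

    return y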

We can get the RK4 algorithm by keeping another non-zero term in the Taylor series:

\[ \mathbf{y}_{n+1} \approx \mathbf{y}_{n} + \frac{1}{6} (\mathbf{k}_1 + 2 \mathbf{k}_2 + 2 \mathbf{k}_3 + \mathbf{k}_4 ) \tag{2.0} \]
\[ \mathbf{k}_1 = h \mathbf{f}(t_n,\, \mathbf{y}_n) \tag{2.1} \]
\[ \mathbf{k}_2 = h \mathbf{f}(t_n + \frac{h}{2},\, \mathbf{y}_n + \frac{\mathbf{k}_1}{2}) \tag{2.2} \]
\[ \mathbf{k}_3 = h \mathbf{f}(t_n + \frac{h}{2},\, \mathbf{y}_n + \frac{\mathbf{k}_2}{2}) \tag{2.3} \]
\[ \mathbf{k}_4 = h \mathbf{f}(t_n + h,\, \mathbf{y}_n + \mathbf{k}_3) \tag{2.4} \]
def rk4_ivp(f, init_y, t):
    steps = len(t)
    order = len(init_y)

    y = np.empty((steps, order))
    y[0] = init_y

    for n in range(steps - 1):
        h = t[n + 1] - t[n]

        k1 = h * f(t[n], y[n])  # 2.1
        k2 = h * f(t[n] + h / 2, y[n] + k1 / 2)  # 2.2
        k3 = h * f(t[n] + h / 2, y[n] + k2 / 2)  # 2.3
        k4 = h * f(t[n] + h, y[n] + k3)  # 2.4

        y[n + 1] = y[n] + 1 / 6 * (k1 + 2 * k2 + 2 * k3 + k4)  # 2.0

    return y

Let’s plot this:

ts = np.linspace(0, 40, 100 + 1)
y = rk4_ivp(f, [x_max, v_0], ts)
plt.plot(ts, np.cos(ts))
plt.plot(ts, y[:, 0], "--");
[Figure: the RK4 solution (dashed) lies on top of the analytic cos(t) curve.]
%%timeit
ts = np.linspace(0, 40, 1000 + 1)
y = rk4_ivp(f, [x_max, v_0], ts)
24.8 ms ± 2.43 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

17.2.3. Adding Numba#

Normally, you’d use a decorator here, but I’m lazy and don’t want to rewrite the functions, so I’ll just apply the decorator manually, since we covered what the decorator syntax actually does.

f_jit = numba.njit(f)
rk4_ivp_jit = numba.njit(rk4_ivp)
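For reference, the decorator form would have looked like this sketch (f_decorated is just an illustrative name; it produces the same thing as the manual call above):

@numba.njit
def f_decorated(t, y):
    # identical body to f, but compiled by Numba on first call
    return np.array([-koverm * y[1], y[0]])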
%%timeit
ts = np.linspace(0, 40, 1000 + 1)
y = rk4_ivp_jit(f_jit, np.array([x_max, v_0]), ts)
349 µs ± 47.9 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

Once you’ve run the function (and therefore compiled it), you can inspect the types Numba inferred, which is handy if you’d like to add explicit signatures:

f_jit.inspect_types()
f (float64, Array(float64, 1, 'C', False, aligned=True))
--------------------------------------------------------------------------------
# File: /tmp/ipykernel_3040/396120679.py
# --- LINE 6 --- 
# label 0
#   t = arg(0, name=t)  :: float64
#   del t
#   y = arg(1, name=y)  :: array(float64, 1d, C)

def f(t, y):

    # --- LINE 7 --- 

    "Y has two elements, x and v"

    # --- LINE 8 --- 
    #   $2load_global.0 = global(np: <module 'numpy' from '/usr/share/miniconda3/envs/level-up-your-python/lib/python3.10/site-packages/numpy/__init__.py'>)  :: Module(<module 'numpy' from '/usr/share/miniconda3/envs/level-up-your-python/lib/python3.10/site-packages/numpy/__init__.py'>)
    #   $4load_method.1 = getattr(value=$2load_global.0, attr=array)  :: Function(<built-in function array>)
    #   del $4load_method.1
    #   del $2load_global.0
    #   $6load_global.2 = global(koverm: 1)  :: Literal[int](1)
    #   $8unary_negative.3 = unary(fn=<built-in function neg>, value=$6load_global.2)  :: int64
    #   del $6load_global.2
    #   $const12.5 = const(int, 1)  :: Literal[int](1)
    #   $14binary_subscr.6 = static_getitem(value=y, index=1, index_var=$const12.5, fn=<built-in function getitem>)  :: float64
    #   del $const12.5
    #   $16binary_multiply.7 = $8unary_negative.3 * $14binary_subscr.6  :: float64
    #   del $8unary_negative.3
    #   del $14binary_subscr.6
    #   $const20.9 = const(int, 0)  :: Literal[int](0)
    #   $22binary_subscr.10 = static_getitem(value=y, index=0, index_var=$const20.9, fn=<built-in function getitem>)  :: float64
    #   del y
    #   del $const20.9
    #   size = const(int, 2)  :: int64
    #   size_tuple = build_tuple(items=[Var(size, 396120679.py:8)])  :: UniTuple(int64 x 1)
    #   del size_tuple
    #   empty_func = global(empty: <built-in function empty>)  :: Function(<built-in function empty>)
    #   $np_g_var = global(np: <module 'numpy' from '/usr/share/miniconda3/envs/level-up-your-python/lib/python3.10/site-packages/numpy/__init__.py'>)  :: Module(<module 'numpy' from '/usr/share/miniconda3/envs/level-up-your-python/lib/python3.10/site-packages/numpy/__init__.py'>)
    #   $np_typ_var = getattr(value=$np_g_var, attr=float64)  :: dtype(float64)
    #   del $np_g_var
    #   $26call_method.12 = call empty_func(size, $np_typ_var, func=empty_func, args=[Var(size, 396120679.py:8), Var($np_typ_var, 396120679.py:8)], kws={}, vararg=None, varkwarg=None, target=None)  :: (int64, dtype(float64)) -> array(float64, 1d, C)
    #   del size
    #   del empty_func
    #   del $np_typ_var
    #   index = const(int, 0)  :: int64
    #   $26call_method.12[index] = $16binary_multiply.7  :: (Array(float64, 1, 'C', False, aligned=True), int64, float64) -> none
    #   del index
    #   del $16binary_multiply.7
    #   index.1 = const(int, 1)  :: int64
    #   $26call_method.12[index.1] = $22binary_subscr.10  :: (Array(float64, 1, 'C', False, aligned=True), int64, float64) -> none
    #   del index.1
    #   del $22binary_subscr.10
    #   $28return_value.13 = cast(value=$26call_method.12)  :: array(float64, 1d, C)
    #   del $26call_method.12
    #   return $28return_value.13

    return np.array([-koverm * y[1], y[0]])


================================================================================

17.3. See also:#