Numba: JIT for Speed!

Numba is one of the most exciting things to happen to Python. It is a library that takes a Python function, converts the bytecode to LLVM IR, compiles it, and runs it at full machine speed!

import numba
import numpy as np
import matplotlib.pyplot as plt

First example

def f1(a, b):
    return 2 * a**3 + 3 * b**0.5


@numba.vectorize
def f2(a, b):
    return 2 * a**3 + 3 * b**0.5
a, b = np.random.random_sample(size=(2, 100_000))
%%timeit
c = f1(a, b)
1.85 ms ± 192 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%%time
c = f2(a, b)
CPU times: user 425 ms, sys: 24.9 ms, total: 450 ms
Wall time: 810 ms

This probably took a bit longer. The very first time you call a JIT-compiled function, Numba has to compile it, and that takes time. Numba compiles quickly, but you still pay that cost once for each new set of argument types. You can control when compilation happens, but you cannot avoid it entirely.

%%timeit
c = f2(a, b)
73.5 μs ± 1.02 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

Numba took the function we defined, pulled it apart, turned it into LLVM (Low Level Virtual Machine) code, and compiled it. No special strings or special syntax; it is just a (large) subset of Python and NumPy. Users and libraries can extend it, too. It also supports:

  • Vectorized, general vectorized, or regular functions

  • Ahead of time compilation, JIT, or dynamic JIT

  • Parallelized targets

  • GPU targets via CUDA or ROCm

  • Nesting

  • Creating cfunction callbacks

It is almost always as fast as or faster than any other compiled solution (minus the JIT time). A couple of years ago it became much easier to install (via pip, thanks to llvmlite’s lightweight and independent LLVM build).

JIT example

The example above uses @numba.vectorize to make “ufunc”-like functions. These can take array(s) of any (broadcastable) size and produce an output array. It’s similar to @numpy.vectorize, which just loops in Python. Let’s try controlling the looping ourselves, using an ODE solver:
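Before moving on, here is a small sketch of what "any broadcastable size" means for a ufunc-like function. It uses `numpy.vectorize` (the pure-Python looping version) so it runs without compilation; the names `scalar_f`, `col`, and `row` are made up for this illustration.

```python
import numpy as np


def scalar_f(a, b):
    # Same formula as f1/f2, written for scalars
    return 2 * a**3 + 3 * b**0.5


# np.vectorize wraps the scalar function in a Python-level loop
f_np = np.vectorize(scalar_f)

col = np.array([[0.0], [1.0], [2.0]])  # shape (3, 1)
row = np.array([0.0, 1.0, 4.0, 9.0])  # shape (4,)

# The two shapes broadcast to (3, 4): every a is paired with every b
out = f_np(col, row)
print(out.shape)
```

A function made with `@numba.vectorize` follows the same broadcasting rules, but the loop runs in compiled code instead of Python.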

Problem setup

Let’s set up an ODE function to solve. We can write our ODE as a system of linear first-order ODEs:

The harmonic motion equation can be written in terms of $\mathbf{f}(t, \mathbf{y}) = \dot{\mathbf{y}}$, where this is in the standard form:

$$
\mathbf{y} = \left( \begin{matrix} \dot{x} \\ x \end{matrix} \right)
$$

$$
\mathbf{f}(t, \mathbf{y}) = \dot{\mathbf{y}} = \left( \begin{matrix} \ddot{x} \\ \dot{x} \end{matrix} \right) = \left( \begin{matrix} -\frac{k}{m} x \\ \dot{x} \end{matrix} \right) = \left( \begin{matrix} -\frac{k}{m} y_1 \\ y_0 \end{matrix} \right)
$$

x_max = 1  # Maximum displacement (initial position)
v_0 = 0  # Initial velocity
koverm = 1  # k / m


def f(t, y):
    "y has two elements: y[0] = v (velocity) and y[1] = x (position)"
    return np.array([-koverm * y[1], y[0]])
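As a quick sanity check of this setup, we can plug in the known solution $x = \cos t$ (so $v = -\sin t$) and confirm that `f` returns its time derivative. This is only a sketch: the sample time `t` is arbitrary, and `koverm` is set to 1 as above.

```python
import numpy as np

koverm = 1  # k / m, as in the setup above


def f(t, y):
    # y[0] is the velocity v, y[1] is the position x
    return np.array([-koverm * y[1], y[0]])


t = 0.7  # arbitrary sample time (assumption for this check)

# Known solution for k/m = 1 with x(0) = 1, v(0) = 0, ordered as (v, x)
y_exact = np.array([-np.sin(t), np.cos(t)])

# Its time derivative, (dv/dt, dx/dt)
dy_exact = np.array([-np.cos(t), -np.sin(t)])

print(np.allclose(f(t, y_exact), dy_exact))
```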

Runge-Kutta introduction

Note that $h = t_{n+1} - t_n$.

$$
\begin{aligned}
\dot{y} &= f(t,y) \\
\implies y &= \int f(t,y) \, dt \\
\implies y_{n+1} &= y_{n} + \int_{t_n}^{t_{n+1}} f(t,y) \, dt
\end{aligned}
$$

Now, expand $f$ in a Taylor series around the midpoint of the interval:

$$
f(t,y) \approx f(t_{n+\frac{1}{2}},y_{n+\frac{1}{2}}) + \left( t - t_{n+\frac{1}{2}}\right) \dot{f}(t_{n+\frac{1}{2}}) + \mathcal{O}(h^2)
$$

The second term here is antisymmetric about the midpoint, so it integrates to zero over the interval, and only the first term survives the integral:

$$
\int_{t_n}^{t_{n+1}} f(t,y) \, dt \approx h\, f(t_{n+\frac{1}{2}},y_{n+\frac{1}{2}}) + \mathcal{O}(h^3)
$$
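The $\mathcal{O}(h^3)$ error of this midpoint approximation can be checked numerically. The sketch below uses $f(t) = e^t$ on $[0, h]$ as a stand-in integrand (an assumption for illustration, with no $y$ dependence), where the exact integral is $e^h - 1$. If the error scales as $h^3$, halving $h$ should shrink it by roughly a factor of 8.

```python
import math


def midpoint_error(h):
    # Exact integral of exp(t) from 0 to h
    exact = math.expm1(h)
    # Midpoint approximation: h * f(midpoint)
    approx = h * math.exp(h / 2)
    return abs(exact - approx)


for h in (0.1, 0.05, 0.025):
    print(h, midpoint_error(h))

# Halving h should reduce the error by about 2**3 = 8
ratio = midpoint_error(0.1) / midpoint_error(0.05)
print(ratio)
```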

Substituting back into the original statement, we get:

$$
y_{n+1} \approx \color{blue}{ y_{n} + h\, f(t_{n+\frac{1}{2}},y_{n+\frac{1}{2}}) } + \mathcal{O}(h^3) \tag{rk2}
$$

We’ve got one more problem! How do we calculate $f(t_{n+\frac{1}{2}},y_{n+\frac{1}{2}})$? We can use Euler’s algorithm that we saw last time:

$$
y_{n+\frac{1}{2}} \approx y_n + \frac{1}{2} h \dot{y} = \color{red}{ y_n + \frac{1}{2} h f(t_{n},y_{n}) }
$$

Putting it together, this is our RK2 algorithm:

$$
\mathbf{y}_{n+1} \approx \color{blue}{ \mathbf{y}_{n} + \mathbf{k}_2 } \tag{1.0}
$$

$$
\mathbf{k}_1 = h \mathbf{f}(t_n,\, \mathbf{y}_n) \tag{1.1}
$$

$$
\mathbf{k}_2 = h \mathbf{f}(t_n + \frac{h}{2},\, \color{red}{\mathbf{y}_n + \frac{\mathbf{k}_1}{2}}) \tag{1.2}
$$

We’ve picked up bold face to indicate that we can have a vector of ODEs.
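Equations 1.0 to 1.2 translate almost line for line into code. The following is a minimal sketch; the name `rk2_ivp`, the test problem with $k/m = 1$, and the time grid are choices made for this illustration.

```python
import numpy as np


def rk2_ivp(f, init_y, t):
    steps = len(t)
    order = len(init_y)

    y = np.empty((steps, order))
    y[0] = init_y

    for n in range(steps - 1):
        h = t[n + 1] - t[n]

        k1 = h * f(t[n], y[n])  # 1.1
        k2 = h * f(t[n] + h / 2, y[n] + k1 / 2)  # 1.2

        y[n + 1] = y[n] + k2  # 1.0

    return y


def f(t, y):
    # Harmonic oscillator with k/m = 1; y is ordered (v, x)
    return np.array([-y[1], y[0]])


ts = np.linspace(0, 10, 1000 + 1)
y = rk2_ivp(f, [0.0, 1.0], ts)  # v(0) = 0, x(0) = 1

# Largest deviation of x(t) from the exact solution cos(t)
err = np.max(np.abs(y[:, 1] - np.cos(ts)))
print(err)
```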

We can get the RK4 algorithm by keeping another non-zero term in the Taylor series:

$$
\mathbf{y}_{n+1} \approx \mathbf{y}_{n} + \frac{1}{6} (\mathbf{k}_1 + 2 \mathbf{k}_2 + 2 \mathbf{k}_3 + \mathbf{k}_4 ) \tag{2.0}
$$

$$
\mathbf{k}_1 = h \mathbf{f}(t_n,\, \mathbf{y}_n) \tag{2.1}
$$

$$
\mathbf{k}_2 = h \mathbf{f}(t_n + \frac{h}{2},\, \mathbf{y}_n + \frac{\mathbf{k}_1}{2}) \tag{2.2}
$$

$$
\mathbf{k}_3 = h \mathbf{f}(t_n + \frac{h}{2},\, \mathbf{y}_n + \frac{\mathbf{k}_2}{2}) \tag{2.3}
$$

$$
\mathbf{k}_4 = h \mathbf{f}(t_n + h,\, \mathbf{y}_n + \mathbf{k}_3) \tag{2.4}
$$

def rk4_ivp(f, init_y, t):
    steps = len(t)
    order = len(init_y)

    y = np.empty((steps, order))
    y[0] = init_y

    for n in range(steps - 1):
        h = t[n + 1] - t[n]

        k1 = h * f(t[n], y[n])  # 2.1
        k2 = h * f(t[n] + h / 2, y[n] + k1 / 2)  # 2.2
        k3 = h * f(t[n] + h / 2, y[n] + k2 / 2)  # 2.3
        k4 = h * f(t[n] + h, y[n] + k3)  # 2.4

        y[n + 1] = y[n] + 1 / 6 * (k1 + 2 * k2 + 2 * k3 + k4)  # 2.0

    return y

Let’s plot this:

ts = np.linspace(0, 40, 100 + 1)
y = rk4_ivp(f, [v_0, x_max], ts)  # y is ordered (v, x), matching f
plt.plot(ts, np.cos(ts))
plt.plot(ts, y[:, 1], "--");  # x(t) is the second column
<Figure size 640x480 with 1 Axes>
%%timeit
ts = np.linspace(0, 40, 1000 + 1)
y = rk4_ivp(f, [v_0, x_max], ts)
17 ms ± 108 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
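Beyond timing, it is worth checking that the solver is as accurate as a fourth-order method should be: doubling the number of steps should reduce the global error by roughly $2^4 = 16$. The sketch below repeats `rk4_ivp` so that it runs on its own; the helper `max_error` and the step counts are choices made for this illustration.

```python
import numpy as np


def rk4_ivp(f, init_y, t):
    steps = len(t)
    order = len(init_y)

    y = np.empty((steps, order))
    y[0] = init_y

    for n in range(steps - 1):
        h = t[n + 1] - t[n]

        k1 = h * f(t[n], y[n])  # 2.1
        k2 = h * f(t[n] + h / 2, y[n] + k1 / 2)  # 2.2
        k3 = h * f(t[n] + h / 2, y[n] + k2 / 2)  # 2.3
        k4 = h * f(t[n] + h, y[n] + k3)  # 2.4

        y[n + 1] = y[n] + 1 / 6 * (k1 + 2 * k2 + 2 * k3 + k4)  # 2.0

    return y


def f(t, y):
    # Harmonic oscillator with k/m = 1; y is ordered (v, x)
    return np.array([-y[1], y[0]])


def max_error(n_steps):
    # Largest deviation of x(t) from cos(t) on [0, 40]
    ts = np.linspace(0, 40, n_steps + 1)
    y = rk4_ivp(f, [0.0, 1.0], ts)  # v(0) = 0, x(0) = 1
    return np.max(np.abs(y[:, 1] - np.cos(ts)))


e_coarse = max_error(1000)
e_fine = max_error(2000)

# For a fourth-order method this ratio should be close to 16
ratio = e_coarse / e_fine
print(e_coarse, e_fine, ratio)
```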

Adding Numba

Normally, you’d use a decorator here, but I’m lazy and don’t want to rewrite the function, so I’ll just manually apply the decorator, since we covered what the syntax actually does.

f_jit = numba.njit(f)
rk4_ivp_jit = numba.njit(rk4_ivp)
%%timeit
ts = np.linspace(0, 40, 1000 + 1)
y = rk4_ivp_jit(f_jit, np.array([v_0, x_max]), ts)
366 μs ± 93.4 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)

After running the function once, you can inspect the types Numba inferred, which is useful if you’d like to add explicit type signatures:

f_jit.inspect_types()
f (float64, Array(float64, 1, 'C', False, aligned=True))
--------------------------------------------------------------------------------
# File: /tmp/ipykernel_10153/396120679.py
# --- LINE 6 --- 
# label 0
#   t = arg(0, name=t)  :: float64
#   del t
#   y = arg(1, name=y)  :: array(float64, 1d, C)

def f(t, y):

    # --- LINE 7 --- 

    "Y has two elements, x and v"

    # --- LINE 8 --- 
    #   $4load_global.0 = global(np: <module 'numpy' from '/home/runner/work/level-up-your-python/level-up-your-python/.pixi/envs/default/lib/python3.14/site-packages/numpy/__init__.py'>)  :: Module(<module 'numpy' from '/home/runner/work/level-up-your-python/level-up-your-python/.pixi/envs/default/lib/python3.14/site-packages/numpy/__init__.py'>)
    #   $14load_attr.1 = getattr(value=$4load_global.0, attr=array)  :: Function(<built-in function array>)
    #   del $4load_global.0
    #   del $14load_attr.1
    #   $34load_global.3 = global(koverm: 1)  :: Literal[int](1)
    #   $44unary_negative.4 = unary(fn=<built-in function neg>, value=$34load_global.3)  :: int64
    #   del $34load_global.3
    #   $const48.6.1 = const(int, 1)  :: Literal[int](1)
    #   $50binary_op.7 = static_getitem(value=y, index=1, index_var=$const48.6.1, fn=<built-in function getitem>)  :: float64
    #   del $const48.6.1
    #   $binop_mul62.8 = $44unary_negative.4 * $50binary_op.7  :: float64
    #   del $50binary_op.7
    #   del $44unary_negative.4
    #   $const76.10.0 = const(int, 0)  :: Literal[int](0)
    #   $78binary_op.11 = static_getitem(value=y, index=0, index_var=$const76.10.0, fn=<built-in function getitem>)  :: float64
    #   del y
    #   del $const76.10.0
    #   size = const(int, 2)  :: int64
    #   size_tuple = build_tuple(items=[Var(size, 396120679.py:8)])  :: UniTuple(int64 x 1)
    #   del size_tuple
    #   empty_func = global(empty: <built-in function empty>)  :: Function(<built-in function empty>)
    #   $np_g_var = global(np: <module 'numpy' from '/home/runner/work/level-up-your-python/level-up-your-python/.pixi/envs/default/lib/python3.14/site-packages/numpy/__init__.py'>)  :: Module(<module 'numpy' from '/home/runner/work/level-up-your-python/level-up-your-python/.pixi/envs/default/lib/python3.14/site-packages/numpy/__init__.py'>)
    #   $np_typ_var = getattr(value=$np_g_var, attr=float64)  :: dtype(float64)
    #   del $np_g_var
    #   $92call.13 = call empty_func(size, $np_typ_var, func=empty_func, args=[Var(size, 396120679.py:8), Var($np_typ_var, 396120679.py:8)], kws={}, vararg=None, varkwarg=None, target=None)  :: (int64, dtype(float64)) -> array(float64, 1d, C)
    #   del size
    #   del empty_func
    #   del $np_typ_var
    #   index = const(int, 0)  :: int64
    #   $92call.13[index] = $binop_mul62.8  :: (Array(float64, 1, 'C', False, aligned=True), int64, float64) -> none
    #   del index
    #   del $binop_mul62.8
    #   index.1 = const(int, 1)  :: int64
    #   $92call.13[index.1] = $78binary_op.11  :: (Array(float64, 1, 'C', False, aligned=True), int64, float64) -> none
    #   del index.1
    #   del $78binary_op.11
    #   $100return_value.14 = cast(value=$92call.13)  :: array(float64, 1d, C)
    #   del $92call.13
    #   return $100return_value.14

    return np.array([-koverm * y[1], y[0]])


================================================================================