Pybindnumba: pybind11 vs. Numba escape-time fractal benchmark
%load_ext ipybind
import numpy
import numba
%%pybind11

#include <complex>
#include <vector>
#include <pybind11/numpy.h>

py::array_t<int> quick(int height, int width, int maxiterations) {
    // Escape-time fractal on a (height, width) grid. Pixel (row, col) uses
    //   c = ((row-1)/height - 1) + 1.5i * ((col-1)/width - 1)
    // and stores the zero-based iteration at which |z| first exceeded 2
    // under z -> z*z + c (starting from z = c), or maxiterations if it never did.
    py::array_t<int> fractal({height, width});
    auto out = fractal.mutable_unchecked<2>();

    // Iteration count for a single point c.
    auto escape_time = [maxiterations](std::complex<double> c) {
        std::complex<double> z = c;
        for (int i = 0; i < maxiterations; ++i) {
            z = z * z + c;
            if (std::abs(z) > 2) {
                return i;
            }
        }
        return maxiterations;
    };

    for (int row = 0; row < height; ++row) {
        // The real part depends only on the row, so compute it once per row.
        const double re = double(row - 1) / height - 1;
        for (int col = 0; col < width; ++col) {
            const double im = 1.5 * (double(col - 1) / width - 1);
            out(row, col) = escape_time({re, im});
        }
    }

    return fractal;
}

PYBIND11_MODULE(py11fractal, m) {
    // Name the parameters and give maxiterations a default of 20 (the same
    // default run_numba_2 uses). Without a default, the two-argument call
    // quick(8000, 12000) is rejected with a TypeError. Keyword names also
    // allow quick(height=..., width=..., maxiterations=...).
    m.def("quick", &quick,
          "Escape-time fractal iteration counts as a (height, width) integer array.",
          py::arg("height"), py::arg("width"), py::arg("maxiterations") = 20);
}
%%time
quick(8000, 12000)
CPU times: user 7.45 s, sys: 230 ms, total: 7.68 s
Wall time: 7.71 s
array([[ 0,  0,  0, ..., 20, 20, 20],
       [ 0,  0,  0, ..., 20, 20, 20],
       [ 0,  0,  0, ..., 20, 20, 20],
       ...,
       [ 0,  0,  0, ..., 20, 20, 20],
       [ 0,  0,  0, ..., 20, 20, 20],
       [ 0,  0,  0, ..., 20, 20, 20]], dtype=int32)
@numba.vectorize
def as_ufunc(c, maxiterations):
    """Return the escape-time count for one complex point ``c``.

    Starting from ``z = c``, repeatedly applies ``z -> z**2 + c``. The result
    is the zero-based iteration index at which ``abs(z)`` first exceeds 2, or
    ``maxiterations`` if that never happens. ``numba.vectorize`` compiles this
    scalar kernel into a ufunc that broadcasts over array arguments.
    """
    z = c
    step = 0
    while step < maxiterations:
        z = z**2 + c
        if abs(z) > 2:
            return step
        step += 1
    return maxiterations

def run_numba_2(height, width, maxiterations=20):
    """Apply ``as_ufunc`` to a ``(height, width)`` grid of complex points.

    Rows sample the imaginary part over [-1, 0] and columns sample the real
    part over [-1.5, 0]. Both ranges use ``numpy.ogrid`` with a complex step,
    which gives a point count with both endpoints included.

    NOTE(review): ``run_numba`` / ``run_numba_p`` put the real part along rows
    and the imaginary part along columns, so this grid (and the per-pixel work
    being timed) differs from theirs.
    """
    imag_part, real_part = numpy.ogrid[-1:0:height * 1j, -1.5:0:width * 1j]
    points = real_part + imag_part * 1j  # broadcasts to (height, width)
    return as_ufunc(points, maxiterations)
%%time
# Time the numba.vectorize version with its default maxiterations=20.
# It first materializes the full complex grid with numpy.ogrid and then
# applies the ufunc to it.
run_numba_2(8000, 12000)
CPU times: user 7.27 s, sys: 1.13 s, total: 8.4 s
Wall time: 8.5 s
array([[ 0,  0,  0, ..., 10, 10, 20],
       [ 0,  0,  0, ...,  9, 10, 10],
       [ 0,  0,  0, ...,  9,  9,  9],
       ...,
       [20, 20, 20, ..., 20, 20, 20],
       [20, 20, 20, ..., 20, 20, 20],
       [20, 20, 20, ..., 20, 20, 20]])
@numba.njit
def run_numba(height, width, maxiterations):
    """Compute the escape-time fractal serially; returns an int32 array.

    Pixel ``(row, col)`` uses
    ``c = ((row - 1)/height - 1) + 1.5j*((col - 1)/width - 1)`` and stores the
    zero-based iteration at which ``abs(z)`` first exceeded 2 under
    ``z -> z**2 + c`` (starting from ``z = c``), or ``maxiterations`` if it
    never did.
    """
    fractal = numpy.empty((height, width), dtype=numpy.int32)
    for row in range(height):
        # The real part depends only on the row.
        re = (row - 1) / height - 1
        for col in range(width):
            c = re + 1.5j * ((col - 1) / width - 1)
            z = c
            count = maxiterations
            for i in range(maxiterations):
                z = z**2 + c
                if abs(z) > 2:
                    count = i
                    break
            fractal[row, col] = count
    return fractal
%%time
# Time the serial @numba.njit version: same grid size and iteration cap
# (20) as the other timing cells.
run_numba(8000, 12000, 20)
CPU times: user 5.92 s, sys: 177 ms, total: 6.1 s
Wall time: 6.1 s
array([[ 0,  0,  0, ..., 20, 20, 20],
       [ 0,  0,  0, ..., 20, 20, 20],
       [ 0,  0,  0, ..., 20, 20, 20],
       ...,
       [ 0,  0,  0, ..., 20, 20, 20],
       [ 0,  0,  0, ..., 20, 20, 20],
       [ 0,  0,  0, ..., 20, 20, 20]], dtype=int32)
@numba.njit(parallel=True)
def run_numba_p(height, width, maxiterations):
    """Compute the escape-time fractal in parallel; returns an int32 array.

    Same per-pixel formula as the serial version:
    ``c = ((row - 1)/height - 1) + 1.5j*((col - 1)/width - 1)``, iterating
    ``z -> z**2 + c`` from ``z = c`` until ``abs(z) > 2`` or ``maxiterations``
    steps. Rows are distributed across threads via ``numba.prange``. Each
    row writes only its own slice of ``fractal``.
    """
    fractal = numpy.empty((height, width), dtype=numpy.int32)
    for row in numba.prange(height):
        # The real part depends only on the row.
        re = (row - 1) / height - 1
        for col in range(width):
            c = re + 1.5j * ((col - 1) / width - 1)
            z = c
            count = maxiterations
            for i in range(maxiterations):
                z = z**2 + c
                if abs(z) > 2:
                    count = i
                    break
            fractal[row, col] = count
    return fractal
%%time
# Time the parallel (numba.prange over rows) version.
# NOTE: this is the first visible call, so it likely includes one-off JIT
# compilation. The next cell repeats the same call.
run_numba_p(8000, 12000, 20)
CPU times: user 9.57 s, sys: 299 ms, total: 9.87 s
Wall time: 3.07 s
array([[ 0,  0,  0, ..., 20, 20, 20],
       [ 0,  0,  0, ..., 20, 20, 20],
       [ 0,  0,  0, ..., 20, 20, 20],
       ...,
       [ 0,  0,  0, ..., 20, 20, 20],
       [ 0,  0,  0, ..., 20, 20, 20],
       [ 0,  0,  0, ..., 20, 20, 20]], dtype=int32)
%%time
# Repeat of the previous parallel call, presumably measuring the run without
# any first-call compilation cost.
run_numba_p(8000, 12000, 20)
CPU times: user 9.02 s, sys: 326 ms, total: 9.35 s
Wall time: 2.86 s
array([[ 0,  0,  0, ..., 20, 20, 20],
       [ 0,  0,  0, ..., 20, 20, 20],
       [ 0,  0,  0, ..., 20, 20, 20],
       ...,
       [ 0,  0,  0, ..., 20, 20, 20],
       [ 0,  0,  0, ..., 20, 20, 20],
       [ 0,  0,  0, ..., 20, 20, 20]], dtype=int32)