Jax code#
A simple jax
based implementation looks like this.
import jax.numpy as jnp
The core implementation looks familiar, but actually differs quite significantly, as jax
does not allow inplace modifications.
So instead of array modified mask
based approach, we use the triple argument version of .where
, which by the way would also have been possible in the numpy
version of the code.
Here new arrays will be generated for each iteration discarding the old ones. Just a small overhead though.
def mandel(c, max_iterations):
c0 = c.copy()
iterations = jnp.zeros_like(c, dtype=jnp.uint32)
for iteration in range(max_iterations):
mask = jnp.abs(c) < 2.0
c = jnp.where(mask, c**2 + c0, c)
iterations = jnp.where(mask, iteration, iterations)
return iterations