Antipatterns in JAX
Common antipatterns in JAX
2023-12-03

Coding in JAX is fairly straightforward, as advertised on its GitHub repo.

However, thinking in JAX takes quite a bit of suffering to get used to. Here are some common antipatterns I have come across while coding in JAX.

For-loops

Let's say I have some function f that I want to apply repeatedly in a loop:

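def many_f(x):
    # If this is wrapped in jax.jit, the Python loop runs at trace time,
    # so the compiled program contains 100 unrolled copies of f.
    for _ in range(100):
        x = f(x)
    return x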
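The usual remedy for this pattern is jax.lax.fori_loop (or jax.lax.scan when you also need per-step outputs), which traces the loop body once instead of unrolling it. The sketch below is an illustration and not necessarily the fix the author goes on to propose: f is a hypothetical stand-in for the article's function, and the names many_f_unrolled and x0 are made up for the comparison.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical stand-in for the article's f: any jittable function
    # whose output has the same shape and dtype as its input.
    return jnp.sin(x) + 0.5 * x


@jax.jit
def many_f_unrolled(x):
    # Antipattern: the Python loop executes during tracing, so the jitted
    # program contains 100 copies of f and compile time grows with the count.
    for _ in range(100):
        x = f(x)
    return x


@jax.jit
def many_f(x):
    # lax.fori_loop(lower, upper, body_fun, init_val) traces body_fun once.
    # body_fun receives the loop index i and the carried value x.
    return jax.lax.fori_loop(0, 100, lambda i, x: f(x), x)


x0 = jnp.arange(4.0)
print(jnp.allclose(many_f(x0), many_f_unrolled(x0)))
```

Note that fori_loop requires the carried value to keep the same shape and dtype on every iteration, which is why the stand-in f maps an array to an array of the same shape.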

In-place replacement