Coding in JAX is fairly straightforward, as advertised on its GitHub repo.
However, learning to think in JAX takes quite a bit of suffering. Here are some common antipatterns I have come across while coding in JAX.
Let's say I have some function `f` that I want to apply repeatedly in a loop:
```python
def many_f(x):
    # Apply some function f to x 100 times in a plain Python loop.
    for i in range(100):
        x = f(x)
    return x
```
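When a function like `many_f` is wrapped in `jax.jit`, the Python `for` loop executes at trace time, so the compiled program contains 100 unrolled copies of `f`, which can make compilation slow for large loop counts. The sketch below shows one common alternative, `jax.lax.fori_loop`, which keeps the loop inside the compiled program. It assumes a hypothetical elementwise `f` chosen only for illustration; any jittable function whose output has the same shape and dtype as its input would work.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical stand-in for "some function"; output shape/dtype match the input.
    return 0.5 * x + 1.0


def many_f(x):
    # Python loop: under jax.jit this is traced 100 times and unrolled
    # into 100 copies of f in the compiled program.
    for i in range(100):
        x = f(x)
    return x


@jax.jit
def many_f_fori(x):
    # lax.fori_loop(lower, upper, body_fun, init_val) traces the body once;
    # body_fun receives (loop index, carry) and returns the new carry.
    return jax.lax.fori_loop(0, 100, lambda i, carry: f(carry), x)


x0 = jnp.arange(3.0)
print(jax.jit(many_f)(x0))
print(many_f_fori(x0))
```

Both versions compute the same values; the difference is in what gets compiled. The carry passed through `fori_loop` must keep a fixed shape and dtype across iterations, which is why `f` here maps a float array to a float array of the same shape.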