Dynamic slice in JAX
Published: 2024-11-30
Basic slicing in JAX
If you have a JAX array and want to update one of its values, the normal NumPy way of assigning in place does not work, because JAX arrays are immutable:
array[index] = new_value
Instead, you have to use the at property, which returns an updated copy:
array = array.at[index].set(new_value)
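A small sketch of what this looks like in practice (the names x and y are only illustrative): the at update returns a new array and leaves the original untouched.

```python
import jax.numpy as jnp

x = jnp.zeros(5)

# .at[...].set(...) returns an updated copy; x itself is not modified
y = x.at[2].set(7.0)

print(x)  # original array, still all zeros
print(y)  # copy with the entry at position 2 replaced
```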
However, if you try to update a slice of the array in the same way inside a jit-compiled function, with slice bounds that depend on a traced value, say the following,
array = array.at[index: index+1].set(new_value)
you may run into the following error:
IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).
Dynamic slicing in JAX
To update a slice at a position that is only known at run time, you have to use the lax.dynamic_update_slice function instead. Here is an example of how you can update a slice of the array:
import jax
import jax.numpy as jnp

array = jnp.zeros(10)
@jax.jit
def update_slice(array, index, new_value):
    # new_value is written into array starting at position index
    return jax.lax.dynamic_update_slice(array, new_value, (index,))

array = update_slice(array, 3, jnp.array([1.]))
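As a quick sanity check of the example above, the sketch below (with the illustrative names write_at, base and patch) writes a two-element update at a position passed in at run time and compares the result against a static at update done outside of jit. Only the value of start changes between the two calls, not any shape.

```python
import jax
import jax.numpy as jnp

@jax.jit
def write_at(base, start, patch):
    # start is a traced value here; the shape of patch is static
    return jax.lax.dynamic_update_slice(base, patch, (start,))

base = jnp.zeros(10)
patch = jnp.array([1.0, 2.0])

out_a = write_at(base, 3, patch)  # writes positions 3 and 4
out_b = write_at(base, 6, patch)  # writes positions 6 and 7

# Outside of jit the same result can be obtained with static slice bounds
expected_a = base.at[3:5].set(patch)
print(jnp.array_equal(out_a, expected_a))
```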
There is also a convenience function, lax.dynamic_update_slice_in_dim, for updating a slice along one specific dimension. Here is an example of how you can update a slice along the first dimension of the array:
array = jnp.zeros((10, 2))
@jax.jit
def update_slice(array, index, new_value):
    # the last argument is the axis along which index is applied
    return jax.lax.dynamic_update_slice_in_dim(array, new_value, index, 0)

array = update_slice(array, 3, jnp.array([[1., 2.]]))
This is useful when you need to modify values in an array at a position you do not know ahead of time.
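To make the relationship between the two functions concrete, here is a sketch (the names row_update_in_dim and row_update_general are made up for illustration) that performs the same row update both ways. With dynamic_update_slice you pass one start index per dimension, while dynamic_update_slice_in_dim only needs the start index along the chosen axis.

```python
import jax
import jax.numpy as jnp

@jax.jit
def row_update_in_dim(table, row, new_row):
    # only the start index along axis 0 is given
    return jax.lax.dynamic_update_slice_in_dim(table, new_row, row, axis=0)

@jax.jit
def row_update_general(table, row, new_row):
    # one start index per dimension: (row, 0)
    return jax.lax.dynamic_update_slice(table, new_row, (row, 0))

table = jnp.zeros((10, 2))
new_row = jnp.array([[1.0, 2.0]])

a = row_update_in_dim(table, 3, new_row)
b = row_update_general(table, 3, new_row)
print(jnp.array_equal(a, b))
```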
Use cases and Caveats
I ran into this issue while writing an optimizer in JAX, which requires swapping some numbers in and out of an array. This may also be useful down the road for my research on building an adaptive mesh refinement solver for PDEs in JAX. While JAX doesn't support dynamic allocation, as long as the slice does not change its size, this is a way to mutate the values of an array. There is also a sister operation, jax.lax.dynamic_index_in_dim, which is useful if you want to index instead of slice.
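As a rough illustration of the "swap numbers in and out" use case, the sketch below combines dynamic_index_in_dim with dynamic_update_slice to swap two entries whose positions are only known at run time. The function name swap_entries and the test data are assumptions for the example, not the optimizer code itself.

```python
import jax
import jax.numpy as jnp

@jax.jit
def swap_entries(values, i, j):
    # read single entries at run-time positions (keepdims=True keeps shape (1,))
    vi = jax.lax.dynamic_index_in_dim(values, i, axis=0, keepdims=True)
    vj = jax.lax.dynamic_index_in_dim(values, j, axis=0, keepdims=True)
    # write them back in swapped order; no shape ever changes
    values = jax.lax.dynamic_update_slice(values, vj, (i,))
    values = jax.lax.dynamic_update_slice(values, vi, (j,))
    return values

values = jnp.arange(5.0)
swapped = swap_entries(values, 1, 3)
print(swapped)
```

For single entries, values.at[i].set(...) with a traced integer index also works under jit; the error shown earlier is specific to slices with traced bounds.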
The caveat is that one may want to do the following to get a slice of the array:
jax.lax.dynamic_slice(array, (start_index,), (slice_size,))
where slice_size is also a run-time value. JAX does not allow this inside a jit-compiled function, since the shape of the result would then only be known at run time. In this case, your best bet is probably to make slice_size a static value, which means the function is recompiled every time the size of the slice changes. This is not ideal, but it is what it is.
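One way to mark the slice size as static is the static_argnames option of jax.jit. The sketch below assumes a one-dimensional array and uses the hypothetical function name take_window; the start index stays a run-time value while slice_size is treated as a compile-time constant.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames="slice_size")
def take_window(array, start_index, slice_size):
    # start_index may be traced; slice_size must be a concrete Python int
    return jax.lax.dynamic_slice(array, (start_index,), (slice_size,))

array = jnp.arange(10.0)

w3 = take_window(array, 2, slice_size=3)        # compiles for slice_size=3
w3_again = take_window(array, 5, slice_size=3)  # new start, same size
w4 = take_window(array, 2, slice_size=4)        # different size, compiles again
print(w3, w3_again, w4)
```

Changing only start_index keeps the output shape the same, whereas each new slice_size produces a differently shaped result and therefore a separate compilation.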