Why are JAX arrays immutable but NumPy arrays are not?

JAX is designed around the principles of functional programming, which make a program's behavior easier to reason about and avoid common bugs caused by mutable state. The immutability of JAX arrays follows from this design: in functional programming, functions do not modify their inputs but instead return new objects as outputs. Accordingly, JAX array operations are pure functions that take input arrays and return new arrays, leaving the originals unchanged. This improves the readability and maintainability of your code, and it is also what lets JAX transformations such as jax.jit and jax.grad, which are designed to work on pure functions, transform your code safely.
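As a minimal sketch of what "pure" means in practice (the function name normalize is a hypothetical example, not part of JAX), the function below computes its result only from its input and returns a new array. Because it has no side effects, it can be passed to jax.jit as-is and gives the same result as the plain call:

```python
import jax
import jax.numpy as jnp


def normalize(x):
    # Pure: the result depends only on x, and x itself is never modified;
    # a new array is built and returned instead.
    return (x - x.mean()) / x.std()


x = jnp.arange(5.0)
y = normalize(x)                # plain call returns a new array
y_jit = jax.jit(normalize)(x)   # same function, compiled with jax.jit

print(x)                        # the input array is unchanged
print(jnp.allclose(y, y_jit))   # both versions agree
```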

# JAX: immutable arrays
import jax.numpy as jnp

x = jnp.arange(10)
x[0] = 10  # raises TypeError

# ---------------------------------------------------------------------------
# TypeError                                 Traceback (most recent call last)
# <ipython-input-7-6b90817377fe> in <module>()
#       1 # JAX: immutable arrays
#       2 x = jnp.arange(10)
# ----> 3 x[0] = 10

# TypeError: '<class 'jax.interpreters.xla._DeviceArray'>' object does not support item assignment. JAX arrays are immutable; perhaps you want jax.ops.index_update or jax.ops.index_add instead?

Instead, JAX provides the indexed update syntax x.at[idx].set(value), which returns an updated copy and leaves the original array unchanged:

y = x.at[0].set(10)
print(x)
print(y)

# [0 1 2 3 4 5 6 7 8 9]
# [10  1  2  3  4  5  6  7  8  9]
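The same .at[] syntax covers other kinds of updates and works with slices and index arrays, not just single elements. These methods replaced the jax.ops.index_update and jax.ops.index_add functions suggested by the older error message above. A short sketch:

```python
import jax.numpy as jnp

x = jnp.arange(10)

# Each .at[...] method returns a new array; x itself is never modified.
a = x.at[0].add(5)                   # like x[0] += 5
b = x.at[1:4].set(0)                 # slices work too: like x[1:4] = 0
c = x.at[::2].multiply(10)           # every other element multiplied by 10
d = x.at[jnp.array([2, 7])].max(5)   # elementwise max(x[i], 5) at indices 2 and 7

print(x)
print(a, b, c, d, sep="\n")
```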

In contrast, NumPy arrays are mutable by default, which means you can modify the contents of an array directly without creating a new one. While convenient, this can make a program's behavior harder to reason about and lead to bugs. For example, when several threads modify the same NumPy array without synchronization, race conditions can occur.

# NumPy: mutable arrays
import numpy as np

x = np.arange(10)
x[0] = 10  # modifies x in place
print(x)

# [10  1  2  3  4  5  6  7  8  9]
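Mutability also makes aliasing possible. In the sketch below (using only standard NumPy and jax.numpy calls), a basic slice of a NumPy array is a view that shares memory with the original, so writing to the slice silently changes the original array as well. The JAX version has no such hidden coupling:

```python
import numpy as np
import jax.numpy as jnp

# NumPy: basic slicing returns a view that shares memory with the original.
a = np.arange(5)
view = a[:3]
view[0] = 100          # writes through the view...
print(a)               # ...so a has changed too

# JAX: slices cannot be written to, so there is no hidden aliasing.
b = jnp.arange(5)
b_head = b[:3].at[0].set(100)   # a new array; b is untouched
print(b, b_head)
```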

However, copying a large array on every update can be expensive, so performance sometimes calls for in-place modification, especially with large datasets. JAX addresses this without giving up immutability. Its update operations, such as the .at[] methods (which replaced the older jax.ops.index_update and jax.ops.index_add functions) and lower-level primitives in the jax.lax module, read like in-place writes but are still pure functions: they return a new array and leave the original unchanged. Inside a jax.jit-compiled function, the XLA compiler can often turn such an update into a true in-place write when the original array is not used afterwards, so the functional style need not cost performance.
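A minimal sketch of both styles under jax.jit, assuming a small 2-D buffer; write_row and write_block are hypothetical helper names. Whether XLA actually reuses the input buffer is a compiler decision, so the comments describe what may happen rather than a guarantee:

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def write_row(buf, row, i):
    # Functionally: return a copy of buf with row i replaced by row.
    # Under jit, XLA may perform this in place if buf is not needed afterwards.
    return buf.at[i].set(row)


@jax.jit
def write_block(buf, block, i, j):
    # Lower-level jax.lax primitive: overwrite the window of buf
    # that starts at (i, j) with block.
    return lax.dynamic_update_slice(buf, block, (i, j))


buf = jnp.zeros((3, 4))
out1 = write_row(buf, jnp.ones(4), 1)
out2 = write_block(buf, jnp.full((2, 2), 7.0), 1, 2)

print(buf)    # still all zeros: the original is unchanged
print(out1)
print(out2)
```

Outside of jit, each call simply produces a new array; the potential in-place optimization only applies to compiled code.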

In conclusion, JAX follows the principles of functional programming: its arrays are immutable, so the original array always remains unchanged, which makes code easier to read and maintain. JAX's update operations look like in-place modification, but they are still pure functions, and when performance matters the compiler can often carry them out in place under jit.

Reposted from: blog.csdn.net/bigbaojian/article/details/129734739