Releases: jax-ml/jax

JAX v0.10.0

16 Apr 13:09

  • New features:

    • Added ResizeMethod.CUBIC_PYTORCH to jax.image.resize to match
      PyTorch's bicubic resize (#15768).
    • We now support differentiation of jax.lax.linalg.qr for wide
      matrices and when full_matrices is True.
    • LAPACK operations are now parallelized along the batch dimension on CPU.
    • Added perturb_singular argument to
      jax.lax.linalg.tridiagonal_solve to handle singular matrices by
      perturbing near-zero pivots in the LU decomposition. This is useful for
      solving numerically singular systems when computing eigenvectors by inverse
      iteration.
    • jax.scipy.linalg.eigh_tridiagonal now supports computing
      eigenvectors on CPU and GPU.
    • Added the jax.numpy.ndarray.byteswap method.
  • Breaking changes:

    • PartitionSpec objects no longer report themselves to be equal to tuples.
      Convert tuples to PartitionSpec objects before testing equality.
    • The .vma property has been removed from jax.core.ShapedArray. Use
      .manual_axis_type.varying instead.
    • JAX CPU devices now report their names as cpu:0, cpu:1, etc. instead of
      TFRT_CPU_0, TFRT_CPU_1.
    • The config state jax_pmap_shmap_merge has been removed. jax.pmap
      will now always use the new implementation that wraps
      jax.jit(jax.shard_map). Please see
      https://docs.jax.dev/en/latest/migrate_pmap.html for more information.
    • jax.device_put_sharded and jax.device_put_replicated have been removed
      from the public API and now raise an AttributeError when accessed.
      Please see
      https://docs.jax.dev/en/latest/migrate_pmap.html#drop-in-replacements for
      drop-in replacements.
    • The C++ pmap infrastructure has been removed. The following public APIs
      are no longer available:
      • jax.sharding.PmapSharding
      • From jaxlib.xla_extension: PmapFunction, pmap,
        NoSharding, Chunked, Unstacked, ShardedAxis, Replicated,
        ShardingSpec.
      • From jax.interpreters.pxla: MapTracer, PmapExecutable,
        parallel_callable, shard_args, xla_pmap_p, Chunked,
        NoSharding, Replicated, ShardedAxis, ShardingSpec,
        Unstacked, spec_to_indices.
    • The deprecated keyword arguments a, a_min, and a_max to
      jax.numpy.clip have been removed.
    • Functions jax.numpy.hstack, jax.numpy.vstack, jax.numpy.dstack,
      jax.numpy.column_stack, jax.numpy.atleast_1d, jax.numpy.atleast_2d,
      and jax.numpy.atleast_3d no longer accept non-ArrayLike inputs.
      Doing so previously issued a DeprecationWarning.
    • jax.scipy.stats.rankdata now returns floating point values in
      all cases, following a similar change in the SciPy 1.18 release.
  • Deprecations:

    • A number of internal APIs in jax.core have been newly deprecated and
      some have been moved to jax.extend.core. These include CallPrimitive,
      DebugInfo, DropVar, Effect, Effects, InconclusiveDimensionOperation,
      JaxprTypeError, check_jaxpr, concrete_or_error, find_top_trace,
      gensym, get_opaque_trace_state, jaxprs_in_params, new_jaxpr_eqn,
      no_effects, nonempty_axis_env_DO_NOT_USE, primal_dtype_to_tangent_dtype,
      unsafe_am_i_under_a_jit_DO_NOT_USE, unsafe_am_i_under_a_vmap_DO_NOT_USE,
      unsafe_get_axis_names_DO_NOT_USE, valid_jaxtype, JaxprPpContext,
      JaxprPpSettings, OutputType, abstract_token, aval_mapping_handlers,
      call, concretization_function_error, custom_typechecks, is_concrete,
      is_constant_dim, is_constant_shape, literalable_types, no_axis_name,
      pytype_aval_mappings, and trace_ctx.
  • Changes:

    • The minimum supported SciPy version is now 1.14.
    • The vma parameter of jax.ShapeDtypeStruct has been replaced with
      manual_axis_type: jax.sharding.ManualAxisType. The .vma property has
      been replaced with .manual_axis_type.varying.
    • Removed the experimental jax.experimental.custom_dce.custom_dce.
    • jax.scipy.linalg.cho_solve, jax.scipy.linalg.lu_solve, and
      jax.scipy.linalg.solve_triangular now show a deprecation warning for
      batched 1D solves with b.ndim > 1. In the future these will be treated as
      batched 2D solves.
    • Added version 10 of the jax.export serialization format. It optimizes
      serialization when the same abstract value, abstract mesh, or sharding
      occurs multiple times.
  • Bug fixes:

    • Fixed a bug that led to differing output between CPU and GPU for
      non-symmetric multidimensional IRFFTs (#29325).
    • Fixed an error when tiny matrices were passed to
      jax.lax.linalg.tridiagonal_solve on GPU (#32487).
    • Fixed a bug in jax.scipy.fft.dctn and idctn where axes=None
      incorrectly defaulted to all axes when s was specified, instead of the
      last len(s) axes to match SciPy behavior (#29426).
    • Fixed a bug where calling jax.distributed.initialize() on a GCE TPU
      Managed Instance Group raised an IndexError (#36593). When
      jax.distributed.initialize() is called on a GCE VM, it uses the GCE
      metadata server to learn the addresses of all participating tasks. On
      Managed Instance Groups this metadata uses a format JAX did not expect,
      which led to the exception. We now parse this format correctly.
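
The migrations for three of the breaking changes above can be sketched as follows. This is an illustrative example, not official guidance: it assumes a JAX release recent enough to accept the min= and max= keywords of jax.numpy.clip, and the variable names (legacy_spec, rows, and so on) are hypothetical.

```python
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P

# PartitionSpec no longer compares equal to a plain tuple, so convert the
# tuple to a PartitionSpec before testing equality.
legacy_spec = ("data", None)  # hypothetical spec stored as a tuple
spec = P("data", None)
specs_match = spec == P(*legacy_spec)

# jnp.clip no longer accepts a=, a_min= or a_max=: pass the array
# positionally and the bounds as min= and max=.
x = jnp.arange(6)
clipped = jnp.clip(x, min=1, max=4)

# hstack (and vstack, dstack, column_stack, atleast_*d) no longer accept
# non-ArrayLike inputs such as nested Python lists, so convert each element
# with jnp.asarray first.
rows = [[1, 2], [3]]
stacked = jnp.hstack([jnp.asarray(r) for r in rows])

print(specs_match)       # True
print(clipped.tolist())  # [1, 1, 2, 3, 4, 4]
print(stacked.tolist())  # [1, 2, 3]
```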

JAX v0.9.2

18 Mar 23:40

JAX 0.9.2 (March 2, 2026)

  • Changes:
    • The semi-private type jax._src.literals.TypedNdArray is now a subclass of
      np.ndarray, rather than a duck type of it.
    • jax.numpy.arange with step specified no longer generates the array
      on host. The benefit is more efficient code, though this can lead to less
      precise outputs for narrow-width floats (e.g. bfloat16). To recover the
      previous behavior in this case, use jnp.array(np.arange(...)).
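
A minimal sketch of the two ways to build a bfloat16 range after this change, assuming the workaround above is applied with an explicit dtype= argument; the start, stop, and step values are arbitrary.

```python
import numpy as np
import jax.numpy as jnp

# Previous behavior: NumPy builds the range on the host (in float64) and
# the result is cast to bfloat16 once, at the end.
host_range = jnp.array(np.arange(0.0, 1.0, 0.1), dtype=jnp.bfloat16)

# New behavior: the range is generated by compiled code rather than on the
# host, which can be less precise for narrow-width floats such as bfloat16.
device_range = jnp.arange(0.0, 1.0, 0.1, dtype=jnp.bfloat16)
```

Prefer the host-side form only where the extra precision matters, since it gives up the efficiency gain described above.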

JAX v0.9.1

02 Mar 11:13

  • Changes:

    • JAX tracers that are not of Array type (e.g., of Ref type) will no
      longer report themselves to be instances of Array.
    • Using jax.shard_map in Explicit mode will raise an error
      if the PartitionSpec of an input does not match the PartitionSpec
      specified in in_specs. In other words, it will act like an assert
      instead of an implicit reshard.
      in_specs is an optional argument, so you can omit it and shard_map
      will infer the PartitionSpec from the argument. If you want to
      reshard your inputs, use jax.reshard on the arguments and then pass
      them to shard_map.
  • New features:

    • Added a debug config jax_compilation_cache_check_contents. If set, the
      cache reports a miss when get() is called on a value that has not been
      put() by the current process, even if the value is actually in the disk
      cache. When a value is put(), we verify that its contents match.

JAX v0.9.0.1

05 Feb 18:51

JAX v0.9.0.1 is identical to v0.9.0 with the commits from the following four PRs patched in:

JAX v0.8.3

29 Jan 23:10

JAX v0.8.3 is identical to v0.8.2 with the following two bug fixes patched in:

JAX v0.9.0

20 Jan 23:23

  • New features:

    • Added jax.thread_guard, a context manager that detects when devices
      are used by multiple threads in multi-controller JAX.
  • Bug fixes:

    • Fixed a workspace size calculation error for pivoted QR (magma_zgeqp3_gpu)
      in MAGMA 2.9.0 when using use_magma=True and pivoting=True.
      (#34145).
  • Deprecations:

    • The flag jax_collectives_common_channel_id was removed.
    • The jax_pmap_no_rank_reduction config state has been removed. The
      no-rank-reduction behavior is now the only supported behavior: a
      jax.pmapped function f sees inputs of the same rank as the input to
      jax.pmap(f). For example, if jax.pmap(f) receives shape (8, 128) on
      8 devices, then f receives shape (1, 128).
    • Setting the jax_pmap_shmap_merge config state is deprecated in JAX v0.9.0
      and will be removed in JAX v0.10.0.
    • jax.numpy.fix is deprecated, anticipating the deprecation of
      numpy.fix in NumPy v2.5.0. jax.numpy.trunc is a drop-in
      replacement.
  • Changes:

    • jax.export now supports explicit sharding. This required a new
      export serialization format version that includes the NamedSharding,
      i.e., the abstract mesh and the partition spec. As part of this
      change we have added a restriction on the use of exported modules: when
      calling them, the abstract mesh must match the one used at export time,
      including the axis names. Previously, only the number of devices
      mattered.
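
Replacing the deprecated jax.numpy.fix with jax.numpy.trunc needs no other change, because both round toward zero. A minimal sketch with arbitrary input values:

```python
import jax.numpy as jnp

x = jnp.array([-2.5, -0.5, 0.5, 2.5])

# Round toward zero, as jnp.fix did: -2.5 -> -2 and 2.5 -> 2.
rounded = jnp.trunc(x)
```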

JAX v0.8.2

18 Dec 18:50

  • Deprecations:

    • jax.lax.pvary has been deprecated.
      Please use jax.lax.pcast(..., to='varying') as the replacement.
    • Complex arguments passed to jax.numpy.arange now result in a
      deprecation warning, because the output is poorly-defined.
    • From jax.core a number of symbols are newly deprecated including:
      call_impl, get_aval, mapped_aval, subjaxprs, set_current_trace,
      take_current_trace, traverse_jaxpr_params, unmapped_aval,
      AbstractToken, and TraceTag.
    • All symbols in jax.interpreters.pxla are deprecated. These are
      primarily JAX internal APIs, and users should not rely on them.
  • Changes:

    • jax's Tracer no longer inherits from jax.Array at runtime. However,
      jax.Array now uses a custom metaclass such that isinstance(x, Array) is true
      if an object x represents a traced Array. Only some Tracers represent
      Arrays, so it is not correct for Tracer to inherit from Array.

      For the moment, during Python type checking, we continue to declare Tracer
      as a subclass of Array; however, we expect to remove this in a future
      release.

    • jax.experimental.si_vjp has been deleted.
      jax.vjp subsumes its functionality.
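
The isinstance behavior described above can be checked with a small sketch: inside a jitted function the argument is a tracer that represents an Array, so isinstance(x, jax.Array) still reports True. The function name double and the observed list are hypothetical; on releases before v0.8.2 the check also returns True, because Tracer was then a subclass of Array.

```python
import jax
import jax.numpy as jnp

observed = []  # hypothetical record of what the check returned while tracing

@jax.jit
def double(x):
    # Under jit, x is a tracer. It represents a traced Array, so the
    # metaclass makes isinstance report True even though Tracer no longer
    # inherits from jax.Array at runtime.
    observed.append(isinstance(x, jax.Array))
    return x * 2

result = double(jnp.arange(3))
```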

JAX v0.8.1

18 Nov 18:45

  • New features:

    • jax.jit now supports the decorator factory pattern; i.e., instead of
      writing
      @functools.partial(jax.jit, static_argnames=['n'])
      def f(x, n):
        ...
      you may write
      @jax.jit(static_argnames=['n'])
      def f(x, n):
        ...
  • Changes:

    • jax.lax.linalg.eigh now accepts an implementation argument to
      select between QR (CPU/GPU), Jacobi (GPU/TPU), and QDWH (TPU)
      implementations. The EighImplementation enum is publicly exported from
      jax.lax.linalg.

    • jax.lax.linalg.svd now implements an algorithm that uses the polar
      decomposition on CUDA GPUs. This is also an alias for the existing algorithm
      on TPUs.

  • Bug fixes:

    • Fixed a bug introduced in JAX 0.7.2 where eigh failed for large matrices on
      GPU (#33062).
  • Deprecations:

    • jax.sharding.PmapSharding is now deprecated. Please use
      jax.NamedSharding instead.
    • jax.device_put_replicated is now deprecated. Please use jax.device_put
      with the appropriate sharding instead.
    • jax.device_put_sharded is now deprecated. Please use jax.device_put with
      the appropriate sharding instead.
    • The default axis_types of jax.make_mesh will change in JAX v0.9.0 to
      jax.sharding.AxisType.Explicit. Leaving axis_types unspecified will
      raise a DeprecationWarning.
    • jax.cloud_tpu_init and its contents were deprecated. There is no reason
      for a user to import or use the contents of this module; JAX handles
      this for you automatically if needed.
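
A minimal sketch of replacing the deprecated device_put helpers with jax.device_put and an explicit sharding. It assumes a single host with a 1D mesh over all local devices; the axis name "d" and the shard contents are hypothetical, and the pmap migration guide linked in the v0.10.0 notes lists the official drop-in replacements. Both helpers returned an array stacked along a new leading axis with one slice per device, which is what this sketch reproduces.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = jax.devices()
n = len(devices)

# 1D mesh over every local device; "d" is a hypothetical axis name.
mesh = Mesh(np.array(devices), ("d",))
per_device = NamedSharding(mesh, P("d"))  # leading axis split across devices

# Rough analogue of jax.device_put_sharded(shards, devices):
# stack one piece per device and shard the leading axis.
shards = [jnp.full((4,), float(i)) for i in range(n)]
sharded = jax.device_put(jnp.stack(shards), per_device)

# Rough analogue of jax.device_put_replicated(x, devices):
# every device holds one copy of x along a new leading axis.
x = jnp.arange(4.0)
replicated = jax.device_put(jnp.broadcast_to(x, (n,) + x.shape), per_device)

# If the leading device axis is not needed, fully replicate x instead.
fully_replicated = jax.device_put(x, NamedSharding(mesh, P()))
```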

JAX v0.8.0

15 Oct 23:38

  • Breaking changes:

    • JAX is changing the default jax.pmap implementation to one implemented in
      terms of jax.jit and jax.shard_map. jax.pmap is in maintenance mode
      and we encourage all new code to use jax.shard_map directly. See the
      migration guide for
      more information.
    • The auto= parameter of jax.experimental.shard_map.shard_map has been
      removed. This means that jax.experimental.shard_map.shard_map no longer
      supports nesting. If you want to nest shard_map calls, please use
      jax.shard_map.
    • JAX no longer allows passing objects that support __jax_array__ directly
      to, e.g., jit-ed functions. Call jax.numpy.asarray on them first.
    • jax.numpy.cov now returns NaN for empty arrays (#32305),
      and matches NumPy 2.2 behavior for single-row design matrices (#32308).
    • JAX no longer accepts Array values where a dtype value is expected. Call
      .dtype on these values first.
    • The deprecated function jax.interpreters.mlir.custom_call was
      removed.
    • The jax.util, jax.extend.ffi, and jax.experimental.host_callback
      modules have been removed. All public APIs within these modules were
      deprecated and removed in v0.7.0 or earlier.
    • The deprecated symbol jax.custom_derivatives.custom_jvp_call_jaxpr_p
      was removed.
    • jax.experimental.multihost_utils.process_allgather raises an error when
      the input is a jax.Array and not fully-addressable and tiled=False. To fix
      this, pass tiled=True to your process_allgather invocation.
    • From jax.experimental.compilation_cache, the deprecated symbols
      is_initialized and initialize_cache were removed.
    • The deprecated function jax.interpreters.xla.canonicalize_dtype
      was removed.
    • jaxlib.hlo_helpers has been removed. Use jax.ffi instead.
    • The option jax_cpu_enable_gloo_collectives has been removed. Use
      jax_cpu_collectives_implementation instead.
    • The previously-deprecated interpolation argument to
      jax.numpy.percentile and jax.numpy.quantile has been
      removed; use method instead.
    • The JAX-internal for_loop primitive was removed. Its functionality,
      reading from and writing to refs in the loop body, is now directly
      supported by jax.lax.fori_loop. If you need help updating your
      code, please file a bug.
    • jax.numpy.trimzeros now errors for non-1D input.
    • The where argument to jax.numpy.sum and other reductions is now
      required to be boolean. Non-boolean values have resulted in a
      DeprecationWarning since JAX v0.5.0.
    • The deprecated functions in jax.dlpack, jax.errors,
      jax.lib.xla_bridge, jax.lib.xla_client, and
      jax.lib.xla_extension were removed.
    • jax.interpreters.mlir.dense_bool_array was removed. Use MLIR APIs to
      construct attributes instead.
  • Changes:

    • jax.numpy.linalg.eig now returns a namedtuple (with attributes
      eigenvalues and eigenvectors) instead of a plain tuple.
    • jax.grad and jax.vjp will now always round primals to
      float32 if float64 mode is not enabled.
    • jax.dlpack.from_dlpack now accepts arrays with non-default layouts,
      for example, transposed.
    • The default nonsymmetric eigendecomposition on NVIDIA GPUs now uses
      cusolver. The magma and LAPACK implementations are still available via the
      new implementation argument to jax.lax.linalg.eig
      (#27265). The use_magma argument is now deprecated in favor
      of implementation.
    • jax.numpy.trim_zeros now follows NumPy 2.2 in supporting
      multi-dimensional inputs.
  • Deprecations:

    • jax.experimental.enable_x64 and jax.experimental.disable_x64
      are deprecated in favor of the new non-experimental context manager
      jax.enable_x64.
    • jax.experimental.shard_map.shard_map is deprecated; going forward use
      jax.shard_map.
    • jax.experimental.pjit.pjit is deprecated; going forward use
      jax.jit.
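
Two of the removals above have short migrations, sketched here with hypothetical data: the interpolation choice for jax.numpy.percentile and jax.numpy.quantile is now passed as method=, and a numeric mask must be converted to a boolean one before being passed as where= to jax.numpy.sum.

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0, 4.0])

# interpolation= was removed; pass the same choice as method=.
median = jnp.percentile(x, 50, method="linear")  # halfway between 2 and 3
upper = jnp.quantile(x, 0.75, method="nearest")  # nearest sample to index 2.25

# where= must be boolean, so compare a numeric mask explicitly.
weights = jnp.array([0, 1, 1, 0])  # hypothetical 0/1 mask
total = jnp.sum(x, where=weights != 0)  # sums only 2.0 and 3.0
```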

JAX v0.7.2

16 Sep 17:19

  • Breaking changes:

    • jax.dlpack.from_dlpack no longer accepts a DLPack capsule. This
      behavior was deprecated and is now removed. The function must be called
      with an array implementing __dlpack__ and __dlpack_device__.
  • Changes:

    • The minimum supported NumPy version is now 2.0. Since SciPy 1.13 is required
      for NumPy 2.0 support, the minimum supported SciPy version is now 1.13.

    • JAX now represents constants in its internal jaxpr representation as a
      LiteralArray, which is a private JAX type that duck types as a
      numpy.ndarray. This type may be exposed to users via custom_jvp rules,
      for example, and may break code that uses isinstance(x, np.ndarray). If
      this breaks your code, you may convert these arrays to classic NumPy arrays
      using np.asarray(x).

  • Bug fixes:

    • arr.view(dtype=None) now returns the array unchanged, matching NumPy's
      semantics. Previously it returned the array with a float dtype.
    • jax.random.randint now produces a less-biased distribution for 8-bit and
      16-bit integer types (#27742). To restore the previous biased
      behavior, you may temporarily set the jax_safer_randint configuration to
      False, but note this is a temporary config that will be removed in a
      future release.
  • Deprecations:

    • The parameters enable_xla and native_serialization for jax2tf.convert
      are deprecated and will be removed in a future version of JAX. These were
      used for jax2tf with non-native serialization, which has now been removed.
    • Setting the config state jax_pmap_no_rank_reduction to False is
      deprecated. By default, jax_pmap_no_rank_reduction will be set to True
      and jax.pmap shards will not have their rank reduced, keeping the same
      rank as their enclosing array.
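
Following the jax.dlpack.from_dlpack change above, the producer array itself is passed instead of a capsule. A minimal sketch, assuming a NumPy array on CPU as the producer (NumPy arrays implement __dlpack__ and __dlpack_device__); the array contents are arbitrary.

```python
import numpy as np
import jax.dlpack

host = np.arange(6, dtype=np.float32).reshape(2, 3)

# Pass the array object itself. A capsule such as host.__dlpack__()
# is no longer accepted.
x = jax.dlpack.from_dlpack(host)
```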