The minimum jaxlib version is now 0.4.19.
Released wheels are now built with clang instead of gcc.
Deprecations
The previously-deprecated sym_pos argument has been removed from jax.scipy.linalg.solve(). Use assume_a='pos' instead.
Passing None to jax.numpy.array() or jax.numpy.asarray(), either directly or within a list or tuple, is deprecated and now raises a FutureWarning. It is currently converted to NaN, and in the future will raise a TypeError.
Passing the condition, x, and y parameters to jax.numpy.where by keyword arguments has been deprecated, to match numpy.where.
In preparation for adding distributed CPU support, JAX now treats CPU devices identically to GPU and TPU devices, that is:
jax.devices() includes all devices present in a distributed job, even those not local to the current process. jax.local_devices() still only includes devices local to the current process, so if the change to jax.devices() breaks you, you most likely want to use jax.local_devices() instead.
CPU devices now receive a globally unique ID number within a distributed job; previously CPU devices would receive a process-local ID number.
The process_index of each CPU device will now match any GPU or TPU devices within the same process; previously the process_index of a CPU device was always 0.
On NVIDIA GPU, JAX now prefers a Jacobi SVD solver for matrices up to 1024x1024. The Jacobi solver appears faster than the non-Jacobi version.
Bug fixes
Fixed error/hang when an array with non-finite values is passed to a non-symmetric eigendecomposition (#18226). Arrays with non-finite values now produce arrays full of NaNs as outputs.
Added jax.typing.DTypeLike, which can be used to annotate objects that are convertible to JAX dtypes.
Added jax.numpy.fill_diagonal.
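As an illustration, here is a minimal sketch of annotating a dtype-accepting function with the new alias (the helper name astype_checked is hypothetical):
import jax
import jax.numpy as jnp
from jax.typing import ArrayLike, DTypeLike

def astype_checked(x: ArrayLike, dtype: DTypeLike) -> jax.Array:
    # DTypeLike covers strings like 'float32', NumPy dtypes, and scalar types.
    return jnp.asarray(x, dtype=dtype)

astype_checked([1, 2, 3], "float32")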
Changes
JAX now requires SciPy 1.9 or newer.
Bug fixes
Only process 0 in a multicontroller distributed JAX program will write persistent compilation cache entries. This fixes write contention if the cache is placed on a network filesystem such as GCS.
The version check for cusolver and cufft no longer considers the patch versions when determining if the installed version of these libraries is at least as new as the versions against which JAX was built.
jaxlib will now always prefer pip-installed NVIDIA CUDA libraries (nvidia-… packages) over any other CUDA installation if they are installed, including installations named in LD_LIBRARY_PATH. If this causes problems and the intent is to use a system-installed CUDA, the fix is to remove the pip-installed CUDA library packages.
CUDA jaxlibs now depend on the user to install a compatible NCCL version. If using the recommended cuda12_pip installation, NCCL should be installed automatically. Currently, NCCL 2.16 or newer is required.
We now provide Linux aarch64 wheels, both with and without NVIDIA GPU support.
Deprecations
A number of internal utilities and inadvertent exports in jax.lax have been deprecated, and will be removed in a future release.
jax.lax.dtypes: use jax.dtypes instead.
jax.lax.itertools: use itertools instead.
naryop, naryop_dtype_rule, standard_abstract_eval, standard_naryop, standard_primitive, standard_unop, unop, and unop_dtype_rule are internal utilities, now deprecated without replacement.
Added the new jax.numpy.bitwise_count() function, matching the API of the similar function recently added to NumPy.
Deprecations
Removed the deprecated module jax.abstract_arrays and all its contents.
Named key constructors in jax.random are deprecated. Pass the impl argument to jax.random.PRNGKey() or jax.random.key() instead:
random.threefry2x32_key(seed) becomes random.PRNGKey(seed, impl='threefry2x32')
random.rbg_key(seed) becomes random.PRNGKey(seed, impl='rbg')
random.unsafe_rbg_key(seed) becomes random.PRNGKey(seed, impl='unsafe_rbg')
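For example, a minimal sketch of the replacement spelling (assuming a JAX version where the impl argument is available):
import jax

key = jax.random.PRNGKey(0, impl='threefry2x32')  # replaces random.threefry2x32_key(0)
rbg_key = jax.random.key(0, impl='rbg')           # new-style typed key, replaces random.rbg_key(0)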
CUDA: JAX now verifies that the CUDA libraries it finds are at least as new as the CUDA libraries that JAX was built against. If older libraries are found, JAX raises an exception since that is preferable to mysterious failures and crashes.
Removed the "No GPU/TPU found" warning. Instead, JAX warns if, on Linux, an NVIDIA GPU or a Google TPU is found but not used and --jax_platforms was not specified.
jax.scipy.stats.mode() now returns a 0 count if the mode is taken across a size-0 axis, matching the behavior of scipy.stats.mode in SciPy 1.11.
Most jax.numpy functions and attributes now have fully-defined type stubs. Previously many of these were treated as Any by static type checkers like mypy and pytype.
Python 3.12 wheels were added in this release.
The CUDA 12 wheels now require CUDA 12.2 or newer and cuDNN 8.9.4 or newer.
Bug fixes:
Fixed log spam from ABSL when the JAX CPU backend was initialized.
Added jax.numpy.ufunc, as well as jax.numpy.frompyfunc(), which can convert any scalar-valued function into a numpy.ufunc-like object, with methods such as outer(), reduce(), accumulate(), at(), and reduceat() (#17054).
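A minimal sketch of the new API; the lambda and identity value are illustrative:
import jax.numpy as jnp

add = jnp.frompyfunc(lambda x, y: x + y, nin=2, nout=1, identity=0)
print(add.outer(jnp.arange(3), jnp.arange(4)))  # 3x4 table of pairwise sums
print(add.reduce(jnp.arange(10)))               # 45, via repeated application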
Added jax.scipy.integrate.trapezoid().
When not running under IPython: when an exception is raised, JAX now filters out the entirety of its internal frames from tracebacks, without the "unfiltered stack trace" that previously appeared. This should produce much friendlier-looking tracebacks. See here for an example. This behavior can be changed by setting JAX_TRACEBACK_FILTERING=remove_frames (for two separate unfiltered/filtered tracebacks, which was the old behavior) or JAX_TRACEBACK_FILTERING=off (for one unfiltered traceback).
jax2tf default serialization version is now 7, which introduces new shape safety assertions.
Devices passed to jax.sharding.Mesh should be hashable. This specifically applies to mock devices or user-created devices; devices from jax.devices() are already hashable.
Breaking changes:
jax2tf now uses native serialization by default. See the jax2tf documentation for details and for mechanisms to override the default.
The option --jax_coordination_service has been removed. It is now always True.
jax.jaxpr_util has been removed from the public JAX namespace.
JAX_USE_PJRT_C_API_ON_TPU no longer has an effect (i.e. it always defaults to true).
The backwards compatibility flag --jax_host_callback_ad_transforms, introduced in December 2021, has been removed.
Deprecations:
Several jax.numpy APIs have been deprecated following NumPy NEP-52:
jax.numpy.NINF has been deprecated. Use -jax.numpy.inf instead.
jax.numpy.PZERO has been deprecated. Use 0.0 instead.
jax.numpy.NZERO has been deprecated. Use -0.0 instead.
jax.numpy.issubsctype(x, t) has been deprecated. Use jax.numpy.issubdtype(x.dtype, t) instead.
jax.numpy.row_stack has been deprecated. Use jax.numpy.vstack instead.
jax.numpy.in1d has been deprecated. Use jax.numpy.isin instead.
jax.numpy.trapz has been deprecated. Use jax.scipy.integrate.trapezoid instead.
jax.scipy.linalg.tril and jax.scipy.linalg.triu have been deprecated, following SciPy. Use jax.numpy.tril and jax.numpy.triu instead.
jax.lax.prod has been removed after being deprecated in JAX v0.4.11. Use the built-in math.prod instead.
A number of exports from jax.interpreters.xla related to defining HLO lowering rules for custom JAX primitives have been deprecated. Custom primitives should be defined using the StableHLO lowering utilities in jax.interpreters.mlir instead.
The following previously-deprecated functions have been removed after a three-month deprecation period:
jax.abstract_arrays.ShapedArray: use jax.core.ShapedArray.
jax.abstract_arrays.raise_to_shaped: use jax.core.raise_to_shaped.
jax.numpy.alltrue: use jax.numpy.all.
jax.numpy.sometrue: use jax.numpy.any.
jax.numpy.product: use jax.numpy.prod.
jax.numpy.cumproduct: use jax.numpy.cumprod.
Deprecations/removals:
The internal submodule jax.prng is now deprecated. Its contents are available at jax.extend.random.
The internal submodule path jax.linear_util has been deprecated. Use jax.extend.linear_util instead (part of jax.extend: a module for extensions).
jax.random.PRNGKeyArray and jax.random.KeyArray are deprecated. Use jax.Array for type annotations, and jax.dtypes.issubdtype(arr.dtype, jax.dtypes.prng_key) for runtime detection of typed prng keys.
The method PRNGKeyArray.unsafe_raw_array is deprecated. Use jax.random.key_data() instead.
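For instance, a small sketch of the suggested runtime check:
import jax

typed_key = jax.random.key(0)       # new-style typed key
legacy_key = jax.random.PRNGKey(0)  # legacy uint32 key
print(jax.dtypes.issubdtype(typed_key.dtype, jax.dtypes.prng_key))   # True
print(jax.dtypes.issubdtype(legacy_key.dtype, jax.dtypes.prng_key))  # False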
jax.experimental.pjit.with_sharding_constraint is deprecated. Use jax.lax.with_sharding_constraint instead.
The internal utilities jax.core.is_opaque_dtype and jax.core.has_opaque_dtype have been removed. Opaque dtypes have been renamed to extended dtypes; use jnp.issubdtype(dtype, jax.dtypes.extended) instead (available since jax v0.4.14).
The utility jax.interpreters.xla.register_collective_primitive has been removed. This utility did nothing useful in recent JAX releases and calls to it can be safely removed.
Sparse CSR matrix multiplications via the experimental jax sparse APIs no longer use a deterministic algorithm on NVIDIA GPUs. This change was made to improve compatibility with CUDA 12.2.1.
Bug fixes:
Fixed a crash on Windows due to a fatal LLVM error related to out-of-order sections and IMAGE_REL_AMD64_ADDR32NB relocations (https://github.com/openxla/xla/commit/cb732a921f0c4184995cbed82394931011d12bd4).
jax.jit now takes donate_argnames as an argument. Its semantics are similar to static_argnames. If neither donate_argnums nor donate_argnames is provided, no arguments are donated. If donate_argnums is not provided but donate_argnames is, or vice versa, JAX uses inspect.signature(fun) to find any positional arguments that correspond to donate_argnames (or vice versa). If both donate_argnums and donate_argnames are provided, inspect.signature is not used, and only actual parameters listed in either donate_argnums or donate_argnames will be donated.
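A minimal sketch of donating by name (the update function is illustrative; donation takes effect on backends that support it):
import jax
import jax.numpy as jnp

def update(params, grads):
    return params - 0.1 * grads

step = jax.jit(update, donate_argnames='params')  # params buffer may be reused for the output
new_params = step(jnp.ones(4), jnp.ones(4))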
jax.random.gamma() has been re-factored to a more efficient algorithm with more robust endpoint behavior (#16779). This means that the sequence of values returned for a given key will change between JAX v0.4.13 and v0.4.14 for gamma and related samplers (including jax.random.ball(), jax.random.beta(), jax.random.chisquare(), jax.random.dirichlet(), jax.random.generalized_normal(), jax.random.loggamma(), jax.random.t()).
Deletions
in_axis_resources and out_axis_resources have been deleted from pjit since it has been more than 3 months since their deprecation. Please use in_shardings and out_shardings as the replacement. This is a safe and trivial name replacement; it does not change any of the current pjit semantics and doesn't break any code. You can still pass PartitionSpecs to in_shardings and out_shardings.
Deprecations
Python 3.8 support has been dropped as per https://jax.readthedocs.io/en/latest/deprecation.html
JAX now requires NumPy 1.22 or newer as per https://jax.readthedocs.io/en/latest/deprecation.html
Passing optional arguments to jax.numpy.ndarray.at() by position is no longer supported, after being deprecated in JAX version 0.4.7. For example, instead of x.at[i].get(True), use x.at[i].get(indices_are_sorted=True).
The following jax.Array methods have been removed, after being deprecated in JAX v0.4.5:
jax.Array.broadcast: use jax.lax.broadcast() instead.
jax.Array.broadcast_in_dim: use jax.lax.broadcast_in_dim() instead.
jax.Array.split: use jax.numpy.split() instead.
The following APIs have been removed after previous deprecation:
jax.ad: use jax.interpreters.ad.
jax.curry: use curry = lambda f: partial(partial, f).
jax.partial_eval: use jax.interpreters.partial_eval.
jax.pxla: use jax.interpreters.pxla.
jax.xla: use jax.interpreters.xla.
jax.ShapedArray: use jax.core.ShapedArray.
jax.interpreters.pxla.device_put: use jax.device_put().
jax.interpreters.pxla.make_sharded_device_array: use jax.make_array_from_single_device_arrays().
jax.interpreters.pxla.ShardedDeviceArray: use jax.Array.
jax.numpy.DeviceArray: use jax.Array.
jax.stages.Compiled.compiler_ir: use jax.stages.Compiled.as_text().
JAX now requires ml_dtypes version 0.2.0 or newer.
To fix a corner case, calls to jax.lax.cond() with five arguments will always resolve to the "common operands" cond behavior (as documented) if the second and third arguments are callable, even if other operands are callable as well. See #16413.
The deprecated config options jax_array and jax_jit_pjit_api_merge, which did nothing, have been removed. These options have been true by default for many releases.
New features
JAX now supports a configuration flag --jax_serialization_version and a JAX_SERIALIZATION_VERSION environment variable to control the serialization version (#16746).
jax2tf in presence of shape polymorphism now generates code that checks certain shape constraints, if the serialization version is at least 7. See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism.
jax.jit now allows None to be passed to in_shardings and out_shardings. The semantics are as follows:
For in_shardings, JAX will mark the input as replicated, but this behavior can change in the future.
For out_shardings, we will rely on the XLA GSPMD partitioner to determine the output shardings.
jax.experimental.pjit.pjit also allows None to be passed to in_shardings and out_shardings. The semantics are as follows:
If the mesh context manager is not provided, JAX has the freedom to choose whatever sharding it wants. For in_shardings, JAX will mark the input as replicated, but this behavior can change in the future. For out_shardings, we will rely on the XLA GSPMD partitioner to determine the output shardings.
If the mesh context manager is provided, None will imply that the value will be replicated on all devices of the mesh.
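As a sketch of the None semantics described above (the function is illustrative):
import jax
import jax.numpy as jnp

# Inputs are treated as replicated (for now); output shardings are left
# to the XLA GSPMD partitioner.
f = jax.jit(lambda x: x * 2, in_shardings=None, out_shardings=None)
f(jnp.arange(4))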
Executable.cost_analysis() now works on Cloud TPU.
Added a warning if a non-allowlisted jaxlib plugin is in use.
Added jax.tree_util.tree_leaves_with_path.
None is not a valid input to jax.experimental.multihost_utils.host_local_array_to_global_array or jax.experimental.multihost_utils.global_array_to_host_local_array. Please use jax.sharding.PartitionSpec() if you want to replicate your input.
Bug fixes
Fixed incorrect wheel name in CUDA 12 releases (#16362); the correct wheel is named cudnn89 instead of cudnn88.
Deprecations
The native_serialization_strict_checks parameter to jax.experimental.jax2tf.convert() is deprecated in favor of the new native_serialization_disabled_checks parameter (#16347).
__cuda_array_interface__ was broken in previous jaxlib versions and is now fixed (#16440).
Concurrent CUDA kernel tracing is now enabled by default on NVIDIA GPUs.
jax.abstract_arrays and its contents are now deprecated. See related functionality in jax.core.
jax.numpy.alltrue: use jax.numpy.all. This follows the deprecation of numpy.alltrue in NumPy version 1.25.0.
jax.numpy.sometrue: use jax.numpy.any. This follows the deprecation of numpy.sometrue in NumPy version 1.25.0.
jax.numpy.product: use jax.numpy.prod. This follows the deprecation of numpy.product in NumPy version 1.25.0.
jax.numpy.cumproduct: use jax.numpy.cumprod. This follows the deprecation of numpy.cumproduct in NumPy version 1.25.0.
jax.sharding.OpShardingSharding has been removed since it has been 3 months since it was deprecated.
Includes PTX/SASS for Hopper (SM version 9.0+) GPUs. Previous versions of jaxlib should work on Hopper but would have a long JIT-compilation delay the first time a JAX operation was executed.
Bug fixes
Fixes incorrect source line information in JAX-generated Python tracebacks under Python 3.11.
Fixes crash when printing local variables of frames in JAX-generated Python tracebacks (#16027).
The following APIs have been removed after a 3 month deprecation period, in accordance with the API compatibility policy:
jax.experimental.PartitionSpec: use jax.sharding.PartitionSpec.
jax.experimental.maps.Mesh: use jax.sharding.Mesh.
jax.experimental.pjit.NamedSharding: use jax.sharding.NamedSharding.
jax.experimental.pjit.PartitionSpec: use jax.sharding.PartitionSpec.
jax.experimental.pjit.FROM_GDA: instead pass sharded jax.Array objects as input and remove the optional in_shardings argument to pjit.
jax.interpreters.pxla.PartitionSpec: use jax.sharding.PartitionSpec.
jax.interpreters.pxla.Mesh: use jax.sharding.Mesh.
jax.interpreters.xla.Buffer: use jax.Array.
jax.interpreters.xla.Device: use jax.Device.
jax.interpreters.xla.DeviceArray: use jax.Array.
jax.interpreters.xla.device_put: use jax.device_put.
jax.interpreters.xla.xla_call_p: use jax.experimental.pjit.pjit_p.
The axis_resources argument of with_sharding_constraint has been removed. Please use shardings instead.
Added a memory_stats() method to Devices. If supported, this returns a dict of string stat names with int values, e.g. "bytes_in_use", or None if the platform doesn't support memory statistics. The exact stats returned may vary across platforms. Currently only implemented on Cloud TPU.
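A small sketch of querying the new method; stat names other than bytes_in_use are platform-dependent:
import jax

dev = jax.devices()[0]
stats = dev.memory_stats()  # None on platforms without memory statistics
if stats is not None:
    print(stats.get('bytes_in_use'))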
Re-added support for the Python buffer protocol (memoryview) on CPU devices.
The flags experimental_cpp_jit, experimental_cpp_pjit and experimental_cpp_pmap have been removed. They are now always on.
Accuracy of singular value decomposition (SVD) on TPU has been improved (requires jaxlib 0.4.9).
Deprecations
jax.experimental.gda_serialization is deprecated and has been renamed to jax.experimental.array_serialization. Please change your imports to use jax.experimental.array_serialization.
The in_axis_resources and out_axis_resources arguments of pjit have been deprecated. Please use in_shardings and out_shardings respectively.
The function jax.numpy.msort has been removed. It has been deprecated since JAX v0.4.1. Use jnp.sort(a, axis=0) instead.
The in_parts and out_parts arguments have been removed from jax.xla_computation since they were only used with sharded_jit, and sharded_jit is long gone.
The instantiate_const_outputs argument has been removed from jax.xla_computation since it has been unused for a very long time.
Breaking changes
A major component of the Cloud TPU runtime has been upgraded. This enables the following new features on Cloud TPU:
jax.debug.print(), jax.debug.callback(), and jax.debug.breakpoint() now work on Cloud TPU.
Automatic TPU memory defragmentation.
jax.experimental.host_callback() is no longer supported on Cloud TPU with the new runtime component. Please file an issue on the JAX issue tracker if the new jax.debug APIs are insufficient for your use case.
The old runtime component will be available for at least the next three months by setting the environment variable JAX_USE_PJRT_C_API_ON_TPU=false. If you find you need to disable the new runtime for any reason, please let us know on the JAX issue tracker.
Changes
The minimum jaxlib version has been bumped from 0.4.6 to 0.4.7.
Deprecations
CUDA 11.4 support has been dropped. JAX GPU wheels only support CUDA 11.8 and CUDA 12. Older CUDA versions may work if jaxlib is built from source.
The global_arg_shapes argument of pmap only worked with sharded_jit and has been removed from pmap. Please migrate to pjit and remove global_arg_shapes from pmap.
Changes
As per https://jax.readthedocs.io/en/latest/jax_array_migration.html#jax-array-migration:
jax.config.jax_array cannot be disabled anymore.
jax.config.jax_jit_pjit_api_merge cannot be disabled anymore.
jax.experimental.jax2tf.convert() now supports the native_serialization parameter to use JAX's native lowering to StableHLO to obtain a StableHLO module for the entire JAX function instead of lowering each JAX primitive to a TensorFlow op. This simplifies the internals and increases the confidence that what you serialize matches the JAX native semantics. See documentation. As part of this change the config flag --jax2tf_default_experimental_native_lowering has been renamed to --jax2tf_native_serialization.
JAX now depends on ml_dtypes, which contains definitions of NumPy types like bfloat16. These definitions were previously internal to JAX, but have been split into a separate package to facilitate sharing them with other projects.
JAX now requires NumPy 1.21 or newer and SciPy 1.7 or newer.
Deprecations
The type jax.numpy.DeviceArray is deprecated. Use jax.Array instead, for which it is an alias.
The type jax.interpreters.pxla.ShardedDeviceArray is deprecated. Use jax.Array instead.
Passing additional arguments to jax.numpy.ndarray.at() by position is deprecated. For example, instead of x.at[i].get(True), use x.at[i].get(indices_are_sorted=True).
jax.interpreters.xla.device_put is deprecated. Please use jax.device_put.
jax.interpreters.pxla.device_put is deprecated. Please use jax.device_put.
jax.experimental.pjit.FROM_GDA is deprecated. Please pass in sharded jax.Arrays as input and remove the in_shardings argument to pjit since it is optional.
jaxlib now depends on ml_dtypes, which contains definitions of NumPy types like bfloat16. These definitions were previously internal to JAX, but have been split into a separate package to facilitate sharing them with other projects.
Changes
jax.tree_util now contains a set of APIs that allow users to define keys for their custom pytree nodes. This includes:
tree_flatten_with_path, which flattens a tree and returns not only each leaf but also its key path.
tree_map_with_path, which can map a function that takes the key path as an argument.
register_pytree_with_keys, to register how the key paths and leaves should look for a custom pytree node.
keystr, which pretty-prints a key path.
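A minimal sketch of the key-path APIs on a plain dict pytree:
import jax

params = {'layer1': {'w': 1.0, 'b': 2.0}}
leaves_with_paths, treedef = jax.tree_util.tree_flatten_with_path(params)
for path, leaf in leaves_with_paths:
    # keystr renders a path such as "['layer1']['b']"
    print(jax.tree_util.keystr(path), leaf)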
jax2tf.call_tf() has a new parameter output_shape_dtype (default None) that can be used to declare the output shape and type of the result. This enables jax2tf.call_tf() to work in the presence of shape polymorphism (#14734).
Deprecations
The old key-path APIs in jax.tree_util are deprecated and will be removed 3 months from Mar 10 2023:
register_keypaths: use jax.tree_util.register_pytree_with_keys() instead.
AttributeKeyPathEntry: use GetAttrKey instead.
GetitemKeyPathEntry: use SequenceKey or DictKey instead.
Deprecations
jax.sharding.OpShardingSharding has been renamed to jax.sharding.GSPMDSharding. jax.sharding.OpShardingSharding will be removed in 3 months from Feb 17, 2023.
The following jax.Array methods are deprecated and will be removed 3 months from Feb 23 2023:
jax.Array.broadcast: use jax.lax.broadcast() instead.
jax.Array.broadcast_in_dim: use jax.lax.broadcast_in_dim() instead.
jax.Array.split: use jax.numpy.split() instead.
The implementation of jit and pjit has been merged. Merging jit and pjit changes the internals of JAX without affecting the public API of JAX. Before, jit was a final-style primitive. Final style means that the creation of jaxpr was delayed as much as possible and transformations were stacked on top of each other. With the jit-pjit implementation merge, jit becomes an initial-style primitive, which means that we trace to jaxpr as early as possible. For more information see this section in autodidax. Moving to initial style should simplify JAX's internals and make development of features like dynamic shapes, etc. easier. You can disable it only via the environment variable, i.e. os.environ['JAX_JIT_PJIT_API_MERGE'] = '0'. The merge must be disabled via an environment variable since it affects JAX at import time, so it needs to be disabled before jax is imported.
The axis_resources argument of with_sharding_constraint is deprecated. Please use shardings instead. There is no change needed if you were using axis_resources as an arg. If you were using it as a kwarg, then please use shardings instead. axis_resources will be removed after 3 months from Feb 13, 2023.
Added the jax.typing module, with tools for type annotations of JAX functions.
The following names have been deprecated:
jax.xla.Device and jax.interpreters.xla.Device: use jax.Device.
jax.experimental.maps.Mesh: use jax.sharding.Mesh instead.
jax.experimental.pjit.NamedSharding: use jax.sharding.NamedSharding.
jax.experimental.pjit.PartitionSpec: use jax.sharding.PartitionSpec.
jax.interpreters.pxla.Mesh: use jax.sharding.Mesh.
jax.interpreters.pxla.PartitionSpec: use jax.sharding.PartitionSpec.
The initial argument to reduction functions like jax.numpy.sum is now required to be a scalar, consistent with the corresponding NumPy API. The previous behavior of broadcasting the output against non-scalar initial values was an unintentional implementation detail (#14446).
Support for NVIDIA Kepler series GPUs has been removed from the default jaxlib builds. If Kepler support is needed, it is still possible to build jaxlib from source with Kepler support (via the --cuda_compute_capabilities=sm_35 option to build.py); however, note that CUDA 12 has completely dropped support for Kepler GPUs.
Deleted jax.scipy.linalg.polar_unitary(), which was a deprecated JAX extension to the scipy API. Use jax.scipy.linalg.polar() instead.
Changes
Added jax.scipy.stats.rankdata().
jax.Array now has the non-blocking is_ready() method, which returns True if the array is ready (see also jax.block_until_ready()).
Breaking changes
Deleted jax.experimental.callback.
Operations with dimensions in presence of jax2tf shape polymorphism have been generalized to work in more scenarios, by converting the symbolic dimension to JAX arrays. Operations involving symbolic dimensions and np.ndarray now can raise errors when the result is used as a shape value (#14106).
jaxpr objects now raise an error on attribute setting in order to avoid problematic mutations (#14102).
Changes
jax2tf.call_tf() has a new parameter has_side_effects (default True) that can be used to declare whether an instance can be removed or replicated by JAX optimizations such as dead-code elimination (#13980).
Added more support for floordiv and mod for jax2tf shape polymorphism. Previously, certain division operations resulted in errors in presence of symbolic dimensions (#14108).
Support for Python 3.7 has been dropped, in accordance with JAX's Python and NumPy version support policy.
We introduce jax.Array, a unified array type that subsumes the DeviceArray, ShardedDeviceArray, and GlobalDeviceArray types in JAX. The jax.Array type helps make parallelism a core feature of JAX, simplifies and unifies JAX internals, and allows us to unify jit and pjit. jax.Array has been enabled by default in JAX 0.4 and makes some breaking changes to the pjit API. The jax.Array migration guide can help you migrate your codebase to jax.Array. You can also look at the Distributed arrays and automatic parallelization tutorial to understand the new concepts.
PartitionSpec and Mesh are now out of experimental. The new API endpoints are jax.sharding.PartitionSpec and jax.sharding.Mesh. jax.experimental.maps.Mesh and jax.experimental.PartitionSpec are deprecated and will be removed in 3 months.
with_sharding_constraint's new public endpoint is jax.lax.with_sharding_constraint.
If using ABSL flags together with jax.config, the ABSL flag values are no longer read or written after the JAX configuration options are initially populated from the ABSL flags. This change improves performance of reading jax.config options, which are used pervasively in JAX.
The jax2tf.call_tf function now uses, for TF lowering, the first TF device of the same platform as the one used by the enclosing JAX computation. Before, it was using the 0th device for the JAX-default backend.
A number of jax.numpy functions now have their arguments marked as positional-only, matching NumPy.
jnp.msort is now deprecated, following the deprecation of np.msort in numpy 1.24. It will be removed in a future release, in accordance with the API compatibility policy. It can be replaced with jnp.sort(a, axis=0).
Support for Python 3.7 has been dropped, in accordance with JAX’s Python and NumPy version support policy .
The behavior of XLA_PYTHON_CLIENT_MEM_FRACTION=.XX has been changed to allocate XX% of the total GPU memory instead of the previous behavior of using currently available GPU memory to calculate preallocation. Please refer to GPU memory allocation for more details.
The deprecated method .block_host_until_ready() has been removed. Use .block_until_ready() instead.
jax.numpy.linalg.pinv() now supports the hermitian option.
jax.scipy.linalg.hessenberg() is now supported on CPU only. Requires jaxlib > 0.3.24.
New functions jax.lax.linalg.hessenberg(), jax.lax.linalg.tridiagonal(), and jax.lax.linalg.householder_product() were added. Householder reduction is currently CPU-only and tridiagonal reductions are supported on CPU and GPU only.
The gradients of svd and jax.numpy.linalg.pinv are now computed more economically for non-square matrices.
Breaking Changes
Deleted the jax_experimental_name_stack config option.
A string axis_names argument to the jax.experimental.maps.Mesh constructor is now converted into a singleton tuple instead of being unpacked into a sequence of character axis names.
Added support for tridiagonal reductions on CPU and GPU.
Added support for upper Hessenberg reductions on CPU.
Fixed a bug that meant that frames in tracebacks captured by JAX were incorrectly mapped to source lines under Python 3.10+
JAX should be faster to import. We now import scipy lazily, which accounted for a significant fraction of JAX’s import time.
Setting the env var JAX_PERSISTENT_CACHE_MIN_COMPILE_TIME_SECS=$N can be used to limit the number of cache entries written to the persistent cache. By default, computations that take 1 second or more to compile will be cached.
Added jax.scipy.stats.mode().
The default device order used by pmap on TPU if no order is specified now matches jax.devices() for single-process jobs. Previously the two orderings differed, which could lead to unnecessary copies or out-of-memory errors. Requiring the orderings to agree simplifies matters.
Breaking Changes
jax.numpy.gradient() now behaves like most other functions in jax.numpy, and forbids passing lists or tuples in place of arrays (#12958).
Functions in jax.numpy.linalg and jax.numpy.fft now uniformly require inputs to be array-like: i.e. lists and tuples cannot be used in place of arrays. Part of #7737.
Deprecations
jax.sharding.MeshPspecSharding has been renamed to jax.sharding.NamedSharding. The jax.sharding.MeshPspecSharding name will be removed in 3 months.
Added JAX_PLATFORMS=tpu,cpu as the default setting in TPU initialization, so JAX will raise an error if TPU cannot be initialized instead of falling back to CPU. Set JAX_PLATFORMS='' to override this behavior and automatically choose an available backend (the original default), or set JAX_PLATFORMS=cpu to always use CPU regardless of whether the TPU is available.
Deprecations
Several test utilities deprecated in JAX v0.3.8 are now removed from jax.test_util.
The persistent compilation cache will now warn instead of raising an exception on error (#12582), so program execution can continue if something goes wrong with the cache. Set JAX_RAISE_PERSISTENT_CACHE_ERRORS=true to revert this behavior.
Bug fixes:
Adds .pyi files that were missing from the previous release (#12536).
Fixes an incompatibility between jax 0.3.19 and the libtpu version it pinned (#12550). Requires jaxlib 0.3.20.
Fix incorrect pip url in setup.py comment (#12528).
Fixes support for limiting the visible CUDA devices via jax_cuda_visible_devices in distributed jobs. This functionality is needed for the JAX/SLURM integration on GPU (#12533).
Ahead-of-time lowering and compilation functionality (tracked in #7733) is stable and public. See the overview and the API docs for jax.stages.
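For example, a minimal sketch of the ahead-of-time workflow via jax.stages:
import jax
import jax.numpy as jnp

def f(x):
    return x @ x.T

lowered = jax.jit(f).lower(jnp.ones((4, 4)))  # a jax.stages.Lowered
compiled = lowered.compile()                  # a jax.stages.Compiled
print(compiled.as_text()[:200])               # inspect the compiled text
compiled(jnp.ones((4, 4)))                    # execute without re-tracing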
Introduced jax.Array, intended to be used for both isinstance checks and type annotations for array types in JAX. Notice that this included some subtle changes to how isinstance works for jax.numpy.ndarray for jax-internal objects, as jax.numpy.ndarray is now a simple alias of jax.Array.
Breaking changes
jax._src is no longer imported into the public jax namespace. This may break users that were using JAX internals.
jax.soft_pmap has been deleted. Please use pjit or xmap instead. jax.soft_pmap is undocumented. If it were documented, a deprecation period would have been provided.
jax.checkpoint(), also known as jax.remat(), no longer supports the concrete option, following the previous version's deprecation; see JEP 11830.
Changes
Added jax.pure_callback() that enables calling back to pure Python functions from compiled functions (e.g. functions decorated with jax.jit or jax.pmap).
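A minimal sketch of calling back into host-side NumPy from a jitted function (host_log is an illustrative name):
import jax
import jax.numpy as jnp
import numpy as np

def host_log(x):
    # Runs as ordinary NumPy on the host; must be a pure function.
    return np.log(x)

@jax.jit
def f(x):
    out_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_log, out_shape, x)

f(jnp.arange(1.0, 4.0))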
Deprecations:
The deprecated DeviceArray.tile() method has been removed. Use jax.numpy.tile() (#11944).
DeviceArray.to_py() has been deprecated. Use np.asarray(x) instead.
Support for NumPy 1.19 has been dropped, per the deprecation policy. Please upgrade to NumPy 1.20 or newer.
Changes
Added jax.debug, which includes utilities for runtime value debugging such as jax.debug.print() and jax.debug.breakpoint().
Added new documentation for runtime value debugging.
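For example, a minimal sketch of printing a runtime value under jit:
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    jax.debug.print('x = {x}', x=x)  # works under jit, unlike Python's print
    return x * 2

f(jnp.arange(3))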
Deprecations
The jax.mask() and jax.shapecheck() APIs have been removed. See #11557.
jax.experimental.loops has been removed. See #10278 for an alternative API.
jax.tree_util.tree_multimap() has been removed. It has been deprecated since JAX release 0.3.5, and jax.tree_util.tree_map() is a direct replacement.
Removed jax.experimental.stax; it has long been a deprecated alias of jax.example_libraries.stax.
Removed jax.experimental.optimizers; it has long been a deprecated alias of jax.example_libraries.optimizers.
jax.checkpoint(), also known as jax.remat(), has a new implementation switched on by default, meaning the old implementation is deprecated; see JEP 11830.
JaxTestCase and JaxTestLoader have been removed from jax.test_util. These classes have been deprecated since v0.3.1 (#11248).
Added jax.scipy.gaussian_kde (#11237).
Binary operations between JAX arrays and built-in collections (dict, list, set, tuple) now raise a TypeError in all cases. Previously some cases (particularly equality and inequality) would return boolean scalars inconsistent with similar operations in NumPy (#11234).
Several jax.tree_util routines accessed as top-level JAX package imports are now deprecated, and will be removed in a future JAX release in accordance with the API compatibility policy:
jax.treedef_is_leaf() is deprecated in favor of jax.tree_util.treedef_is_leaf().
jax.tree_flatten() is deprecated in favor of jax.tree_util.tree_flatten().
jax.tree_leaves() is deprecated in favor of jax.tree_util.tree_leaves().
jax.tree_structure() is deprecated in favor of jax.tree_util.tree_structure().
jax.tree_transpose() is deprecated in favor of jax.tree_util.tree_transpose().
jax.tree_unflatten() is deprecated in favor of jax.tree_util.tree_unflatten().
The sym_pos argument of jax.scipy.linalg.solve() is deprecated in favor of assume_a='pos', following a similar deprecation in scipy.linalg.solve().
jax.experimental.compilation_cache.initialize_cache() no longer supports max_cache_size_bytes and will not accept it as an input.
JAX_PLATFORMS now raises an exception when platform initialization fails.
Changes
Fixed compatibility problems with NumPy 1.23.
jax.numpy.linalg.slogdet() now accepts an optional method argument that allows selection between an LU-decomposition based implementation and an implementation based on QR decomposition.
jax.numpy.linalg.qr() now supports mode="raw".
pickle, copy.copy, and copy.deepcopy now have more complete support when used on jax arrays (#10659). In particular:
pickle and deepcopy previously returned np.ndarray objects when used on a DeviceArray; now DeviceArray objects are returned. For deepcopy, the copied array is on the same device as the original. For pickle, the deserialized array will be on the default device.
Within function transformations (i.e. traced code), deepcopy and copy previously were no-ops. Now they use the same mechanism as DeviceArray.copy().
Calling pickle on a traced array now results in an explicit ConcretizationTypeError.
The implementation of singular value decomposition (SVD) and symmetric/Hermitian eigendecomposition should be significantly faster on TPU, especially for matrices above 1000x1000 or so. Both now use a spectral divide-and-conquer algorithm for eigendecomposition (QDWH-eig).
jax.numpy.ldexp() no longer silently promotes all inputs to float64; instead it promotes to float32 for integer inputs of size int32 or smaller (#10921).
Added a create_perfetto_link option to jax.profiler.start_trace() and jax.profiler.trace(). When used, the profiler will generate a link to the Perfetto UI to view the trace.
Changed the semantics of jax.profiler.start_server(...) to store the keepalive globally, rather than requiring the user to keep a reference to it.
Added jax.random.generalized_normal().
Added jax.random.ball().
Added jax.default_device().
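A small sketch of jax.default_device as a context manager:
import jax
import jax.numpy as jnp

cpu0 = jax.devices('cpu')[0]
with jax.default_device(cpu0):
    x = jnp.ones(3)  # created on cpu0 rather than the default backend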
Added a python -m jax.collect_profile script to manually capture program traces as an alternative to the Tensorboard UI.
Added a jax.named_scope context manager that adds profiler metadata to Python programs (similar to jax.named_call).
In scatter-update operations (i.e. jax.numpy.ndarray.at), unsafe implicit dtype casts are deprecated, and now result in a FutureWarning. In a future release, this will become an error. An example of an unsafe implicit cast is jnp.zeros(4, dtype=int).at[0].set(1.5), in which 1.5 previously was silently truncated to 1.
jax.experimental.compilation_cache.initialize_cache() now supports a gcs bucket path as input.
Added jax.scipy.stats.gennorm().
jax.numpy.roots() is now better behaved when strip_zeros=False and coefficients have leading zeros (#11215).
x86-64 Mac wheels now require Mac OS 10.14 (Mojave) or newer. Mac OS 10.14 was released in 2018, so this should not be a very onerous requirement.
The bundled version of NCCL was updated to 2.12.12, fixing some deadlocks.
The Python flatbuffers package is no longer a dependency of jaxlib.
jax.lax.eigh() now accepts an optional sort_eigenvalues argument that allows users to opt out of eigenvalue sorting on TPU.
Deprecations
Non-array arguments to functions in jax.lax.linalg are now marked keyword-only. As a backward-compatibility step, passing keyword-only arguments positionally yields a warning, but in a future JAX release passing keyword-only arguments positionally will fail. However, most users should prefer to use jax.numpy.linalg instead.
jax.scipy.linalg.polar_unitary(), which was a JAX extension to the scipy API, is deprecated. Use jax.scipy.linalg.polar() instead.
jax.numpy.linalg.svd() on TPUs uses a qdwh-svd solver.
jax.numpy.linalg.cond() on TPUs now accepts complex input.
jax.numpy.linalg.pinv() on TPUs now accepts complex input.
jax.numpy.linalg.matrix_rank() on TPUs now accepts complex input.
jax.scipy.cluster.vq.vq() has been added.
jax.experimental.maps.mesh has been deleted. Please use jax.experimental.maps.Mesh. Please see https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh for more information.
jax.scipy.linalg.qr() now returns a length-1 tuple rather than the raw array when mode='r', in order to match the behavior of scipy.linalg.qr (#10452).
jax.numpy.take_along_axis() now takes an optional mode parameter that specifies the behavior of out-of-bounds indexing. By default, invalid values (e.g., NaN) will be returned for out-of-bounds indices. In previous versions of JAX, invalid indices were clamped into range. The previous behavior can be restored by passing mode="clip".
jax.numpy.take() now defaults to mode="fill", which returns invalid values (e.g., NaN) for out-of-bounds indices.
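A small sketch contrasting the new default with the old clamping behavior:
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])
print(jnp.take(x, jnp.array([0, 5])))               # mode="fill": [1., nan]
print(jnp.take(x, jnp.array([0, 5]), mode='clip'))  # old behavior: [1., 3.]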
Scatter operations, such as x.at[...].set(...), now have "drop" semantics. This has no effect on the scatter operation itself, but it means that when differentiated, the gradient of a scatter will yield zero cotangents for out-of-bounds indices. Previously out-of-bounds indices were clamped into range for the gradient, which was not mathematically correct.
jax.numpy.take_along_axis() now raises a TypeError if its indices are not of an integer type, matching the behavior of numpy.take_along_axis(). Previously non-integer indices were silently cast to integers.
jax.numpy.ravel_multi_index() now raises a TypeError if its dims argument is not of an integer type, matching the behavior of numpy.ravel_multi_index(). Previously non-integer dims was silently cast to integers.
jax.numpy.split() now raises a TypeError if its axis argument is not of an integer type, matching the behavior of numpy.split(). Previously non-integer axis was silently cast to integers.
jax.numpy.indices() now raises a TypeError if its dimensions are not of an integer type, matching the behavior of numpy.indices(). Previously non-integer dimensions were silently cast to integers.
jax.numpy.diag() now raises a TypeError if its k argument is not of an integer type, matching the behavior of numpy.diag(). Previously non-integer k was silently cast to integers.
Added jax.random.orthogonal().
Deprecations
Many functions and objects available in jax.test_util are now deprecated and will raise a warning on import. This includes cases_from_list, check_close, check_eq, device_under_test, format_shape_dtype_string, rand_uniform, skip_on_devices, with_config, xla_bridge, and _default_tolerance (#10389). These, along with previously-deprecated JaxTestCase, JaxTestLoader, and BufferDonationTestCase, will be removed in a future JAX release. Most of these utilities can be replaced by calls to standard python & numpy testing utilities found in e.g. unittest, absl.testing, numpy.testing, etc. JAX-specific functionality such as device checking can be replaced through the use of public APIs such as jax.devices(). Many of the deprecated utilities will still exist in jax._src.test_util, but these are not public APIs and as such may be changed or removed without notice in future releases.
Fixed a performance problem if the indices passed to jax.numpy.take_along_axis() were broadcasted (#10281).
jax.scipy.special.expit() and jax.scipy.special.logit() now require their arguments to be scalars or JAX arrays. They also now promote integer arguments to floating point.
The DeviceArray.tile() method is deprecated, because numpy arrays do not have a tile() method. As a replacement for this, use jax.numpy.tile() (#10266).
Upgraded libtpu wheel to a version that fixes a hang when initializing a TPU pod. Fixes #10218.
Deprecations:
jax.experimental.loops is being deprecated. See #10278 for an alternative API.
Added jax.random.loggamma() and improved behavior of jax.random.beta() and jax.random.dirichlet() for small parameter values (#9906).
The private lax_numpy submodule is no longer exposed in the jax.numpy namespace (#10029).
Added array creation routines jax.numpy.frombuffer(), jax.numpy.fromfunction(), and jax.numpy.fromstring() (#10049).
DeviceArray.copy() now returns a DeviceArray rather than a np.ndarray (#10069).
jax.experimental.sharded_jit has been deprecated and will be removed soon.
Deprecations:
jax.nn.normalize() is being deprecated. Use jax.nn.standardize() instead (#9899).
jax.tree_util.tree_multimap() is deprecated. Use jax.tree_util.tree_map() instead (#5746).
jax.experimental.sharded_jit is deprecated. Use pjit instead.
Fixed a bug where double-precision complex-to-real IRFFTs would mutate their input buffers on GPU (#9946).
Fixed incorrect constant-folding of complex scatters (#10159).
The functions jax.ops.index_update and jax.ops.index_add, which were deprecated in 0.2.22, have been removed. Please use the .at property on JAX arrays instead, e.g., x.at[idx].set(y).
Moved jax.experimental.ann.approx_*_k into jax.lax. These functions are optimized alternatives to jax.lax.top_k.
jax.numpy.broadcast_arrays() and jax.numpy.broadcast_to() now require scalar or array-like inputs, and will fail if they are passed lists (part of #7737).
The standard jax[tpu] install can now be used with Cloud TPU v4 VMs.
pjit now works on CPU (in addition to previous TPU and GPU support).
XlaComputation.as_hlo_text() now supports printing large constants by passing the boolean flag print_large_constants=True.
Deprecations:
The .block_host_until_ready() method on JAX arrays has been deprecated. Use .block_until_ready() instead.
jax.test_util.JaxTestCase and jax.test_util.JaxTestLoader are now deprecated. The suggested replacement is to use parametrized.TestCase directly. For tests that rely on custom asserts such as JaxTestCase.assertAllClose(), the suggested replacement is to use standard numpy testing utilities such as numpy.testing.assert_allclose(), which work directly with JAX arrays (#9620).
jax.test_util.JaxTestCase now sets jax_numpy_rank_promotion='raise' by default (#9562). To recover the previous behavior, use the new jax.test_util.with_config decorator:
@jtu.with_config(jax_numpy_rank_promotion='allow')
class MyTestCase(jtu.JaxTestCase):
  ...
Added jax.scipy.linalg.schur(), jax.scipy.linalg.sqrtm(), jax.scipy.signal.csd(), jax.scipy.signal.stft(), and jax.scipy.signal.welch().
Bazel 5.0.0 is now required to build jaxlib.
The jaxlib version has been bumped to 0.3.0. Please see the design doc for the explanation.
jax.jit(f).lower(...).compiler_ir() now defaults to the MHLO dialect if no dialect= is passed.
jax.jit(f).lower(...).compiler_ir(dialect='mhlo') now returns an MLIR ir.Module object instead of its string representation.
Includes precompiled SASS for NVidia compute capability 8.0 GPUs (e.g. A100). Removes precompiled SASS for compute capability 6.1 so as not to increase the number of compute capabilities: GPUs with compute capability 6.1 can use the 6.0 SASS.
With jaxlib 0.1.76, JAX uses the MHLO MLIR dialect as its primary target compiler IR by default.
Breaking changes
Support for NumPy 1.18 has been dropped, per the deprecation policy. Please upgrade to a supported NumPy version.
Bug fixes
Fixed a bug where apparently identical pytreedef objects constructed by different routes did not compare as equal (#9066).
The JAX jit cache now requires two static arguments to have identical types for a cache hit (#9311).
Support for NumPy 1.18 has been dropped, per the deprecation policy. Please upgrade to a supported NumPy version.
The host_callback primitives have been simplified to drop the special autodiff handling for hcb.id_tap and id_print. From now on, only the primals are tapped. The old behavior can be obtained (for a limited time) by setting the JAX_HOST_CALLBACK_AD_TRANSFORMS environment variable, or the --jax_host_callback_ad_transforms flag. Additionally, added documentation for how to implement the old behavior using JAX custom AD APIs (#8678).
Sorting now matches the behavior of NumPy for 0.0 and NaN regardless of the bit representation. In particular, 0.0 and -0.0 are now treated as equivalent, where previously -0.0 was treated as less than 0.0. Additionally, all NaN representations are now treated as equivalent and sorted to the end of the array. Previously negative NaN values were sorted to the front of the array, and NaN values with different internal bit representations were not treated as equivalent, and were sorted according to those bit patterns (#9178).
jax.numpy.unique() now treats NaN values in the same way as np.unique in NumPy versions 1.21 and newer: at most one NaN value will appear in the uniquified output (#9184).
Bug fixes:
host_callback now supports ad_checkpoint.checkpoint (#8907).
New features:
Added jax.block_until_ready (#8941).
Added a new debugging flag/environment variable JAX_DUMP_IR_TO=/path. If set, JAX dumps the MHLO/HLO IR it generates for each computation to a file under the given path.
Added jax.ensure_compile_time_eval to the public API (#7987).
jax2tf now supports a flag jax2tf_associative_scan_reductions to change the lowering for associative reductions, e.g., jnp.cumsum, to behave like JAX on CPU and GPU (to use an associative scan). See the jax2tf README for more details (#9189).
Out-of-bounds indices to jax.ops.segment_sum will now be handled with FILL_OR_DROP semantics, as documented. This primarily affects the reverse-mode derivative, where gradients corresponding to out-of-bounds indices will now be returned as 0 (#8634).
jax2tf will force the converted code to use XLA for the code fragments under jax.jit, e.g., most jax.numpy functions (#7839).
jaxlib 0.1.74 (Nov 17, 2021)
Enabled peer-to-peer copies between GPUs. Previously, GPU copies were bounced via the host, which is usually slower.
Added experimental MLIR Python bindings for use by JAX.
jax 0.2.25 (Nov 10, 2021)
New features:
(Experimental) jax.distributed.initialize exposes the multi-host GPU backend.
jax.random.permutation supports a new independent keyword argument (#8430).
Breaking changes
Moved jax.experimental.stax to jax.example_libraries.stax.
Moved jax.experimental.optimizers to jax.example_libraries.optimizers.
New features:
Added jax.lax.linalg.qdwh.
jax.random.choice and jax.random.permutation now support multidimensional arrays and an optional axis argument (#8158).
Breaking changes:
jax.numpy.take and jax.numpy.take_along_axis now require array-like inputs (see #7737).
jaxlib 0.1.73 (Oct 18, 2021)
Multiple cuDNN versions are now supported for jaxlib GPU cuda11 wheels:
cuDNN 8.2 or newer. We recommend using the cuDNN 8.2 wheel if your cuDNN installation is new enough, since it supports additional functionality.
cuDNN 8.0.5 or newer.
Breaking changes:
The install commands for GPU jaxlib are as follows:
pip install --upgrade pip
# Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
# Installs the wheel compatible with Cuda 11 and cudnn 8.2 or newer.
pip install jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_releases.html
# Installs the wheel compatible with Cuda 11 and cudnn 8.0.5 or newer.
pip install jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_releases.html
Static arguments to jax.pmap must now be hashable. Unhashable static arguments have long been disallowed on jax.jit, but they were still permitted on jax.pmap; jax.pmap compared unhashable static arguments using object identity. This behavior is a footgun, since comparing arguments using object identity leads to recompilation each time the object identity changes. Instead, we now ban unhashable arguments: if a user of jax.pmap wants to compare static arguments by object identity, they can define __hash__ and __eq__ methods on their objects that do that, or wrap their objects in an object that has those operations with object identity semantics. Another option is to use functools.partial to encapsulate the unhashable static arguments into the function object, as sketched below.
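A minimal sketch of the functools.partial workaround (shown with jax.jit; the same idea applies to jax.pmap):
import functools
import jax
import jax.numpy as jnp

def scale(x, cfg):
    return x * cfg['factor']

cfg = {'factor': 2.0}  # dicts are unhashable, so cfg cannot be a static argument
f = jax.jit(functools.partial(scale, cfg=cfg))  # fold it into the function instead
print(f(jnp.arange(3.0)))  # [0. 2. 4.]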
jax.util.partial was an accidental export that has now been removed. Use functools.partial from the Python standard library instead.
Deprecations
The functions jax.ops.index_update, jax.ops.index_add etc. are deprecated and will be removed in a future JAX release. Please use the .at property on JAX arrays instead, e.g., x.at[idx].set(y). For now, these functions produce a DeprecationWarning.
New features:
An optimized C++ code path improving the dispatch time for pmap is now the default when using jaxlib 0.1.72 or newer. The feature can be disabled using the --experimental_cpp_pmap flag (or the JAX_CPP_PMAP environment variable).
jax.numpy.unique now supports an optional fill_value argument (#8121).
Breaking changes:
Support for CUDA 10.2 and CUDA 10.1 has been dropped. Jaxlib now supports CUDA 11.1+.
Bug fixes:
Fixes https://github.com/google/jax/issues/7461, which caused wrong outputs on all platforms due to incorrect buffer aliasing inside the XLA compiler.
jax.api has been removed. Functions that were available as jax.api.* were aliases for functions in jax.*; please use the functions in jax.* instead.
jax.partial and jax.lax.partial were accidental exports that have now been removed. Use functools.partial from the Python standard library instead.
Boolean scalar indices now raise a TypeError; previously this silently returned wrong results (#7925).
Many more jax.numpy functions now require array-like inputs, and will error if passed a list (#7747, #7802, #7907). See #7737 for a discussion of the rationale behind this change.
When inside a transformation such as jax.jit, jax.numpy.array always stages the array it produces into the traced computation. Previously jax.numpy.array would sometimes produce an on-device array, even under a jax.jit decorator. This change may break code that used JAX arrays to perform shape or index computations that must be known statically; the workaround is to perform such computations using classic NumPy arrays instead.
jnp.ndarray is now a true base class for JAX arrays. In particular, this means that for a standard numpy array x, isinstance(x, jnp.ndarray) will now return False (#7927).
New features:
Added jax.numpy.insert() implementation (#7936).
jnp.poly* functions now require array-like inputs (#7732).
jnp.unique and other set-like operations now require array-like inputs (#7662).
Breaking changes:
Support for CUDA 11.0 and CUDA 10.1 has been dropped. Jaxlib now supports CUDA 10.2 and CUDA 11.1+.
Support for NumPy 1.17 has been dropped, per the deprecation policy. Please upgrade to a supported NumPy version.
The jit decorator has been added around the implementation of a number of operators on JAX arrays. This speeds up dispatch times for common operators such as +. This change should largely be transparent to most users. However, there is one known behavioral change, which is that large integer constants may now produce an error when passed directly to a JAX operator (e.g., x + 2**40). The workaround is to cast the constant to an explicit type (e.g., np.float64(2**40)).
New features:
Improved the support for shape polymorphism in jax2tf for operations that need to use a dimension size in array computation, e.g., jnp.mean (#7317).
Bug fixes:
Fixed some leaked trace errors from the previous release (#7613).
Support for Python 3.6 has been dropped, per the deprecation policy. Please upgrade to a supported Python version.
Support for NumPy 1.17 has been dropped, per the deprecation policy. Please upgrade to a supported NumPy version.
The host_callback mechanism now uses one thread per local device for
making the calls to the Python callbacks. Previously there was a single
thread for all devices. This means that the callbacks may now be called
interleaved. The callbacks corresponding to one device will still be
called in sequence.
Support for Python 3.6 has been dropped, per the deprecation policy. Please upgrade to a supported Python version.
The minimum jaxlib version is now 0.1.69.
The backend argument to jax.dlpack.from_dlpack() has been removed.
New features:
Added a polar decomposition (jax.scipy.linalg.polar()).
Bug fixes:
Tightened the checks for lax.argmin and lax.argmax to ensure they are not used with an invalid axis value, or with an empty reduction dimension (#7196).
Default to the older “stream_executor” CPU runtime for jaxlib <= 0.1.68
to work around #7229, which caused wrong outputs on CPU due to a concurrency
problem.
New features:
New SciPy function jax.scipy.special.sph_harm().
Reverse-mode autodiff functions (jax.grad(), jax.value_and_grad(), jax.vjp(), and jax.linear_transpose()) support a parameter that indicates which named axes should be summed over in the backward pass if they were broadcasted over in the forward pass. This enables use of these APIs in a non-per-example way inside maps (initially only jax.experimental.maps.xmap()) (#6950).
#7042 Turned on TFRT CPU backend
with significant dispatch performance improvements on CPU.
jax2tf.convert() supports inequalities and min/max for booleans (#6956).
New SciPy function jax.scipy.special.lpmn_values().
Breaking changes:
Support for NumPy 1.16 has been dropped, per the
deprecation policy.
Bug fixes:
Fixed a bug that prevented round-tripping from JAX to TF and back: jax2tf.call_tf(jax2tf.convert) (#6947).
jax2tf.convert() now has support for pjit and sharded_jit.
A new configuration option JAX_TRACEBACK_FILTERING controls how JAX filters
tracebacks.
A new traceback filtering mode using __tracebackhide__ is now enabled by default in sufficiently recent versions of IPython.
jax2tf.convert() supports shape polymorphism even when the unknown dimensions are used in arithmetic operations, e.g., jnp.reshape(-1) (#6827).
jax2tf.convert() generates custom attributes with location information in TF ops. The code that XLA generates after jax2tf has the same location information as JAX/XLA.
New SciPy function jax.scipy.special.lpmn().
Bug fixes:
jax2tf.convert() now ensures that it uses the same typing rules for Python scalars and for choosing 32-bit vs. 64-bit computations as JAX (#6883).
jax2tf.convert() now scopes the enable_xla conversion parameter properly, so that it applies only during the just-in-time conversion (#6720).
jax2tf.convert() now converts lax.dot_general using the XlaDot TensorFlow op, for better fidelity w.r.t. JAX numerical precision (#6717).
jax2tf.convert() now has support for inequality comparisons and min/max for complex numbers (#6892).
CUDA 11.1 wheels are now supported on all CUDA 11 versions 11.1 or higher. NVIDIA now promises compatibility between CUDA minor releases starting with CUDA 11.1. This means that JAX can release a single CUDA 11.1 wheel that is compatible with CUDA 11.2 and 11.3.
There is no longer a separate jaxlib release for CUDA 11.2 (or higher); use
the CUDA 11.1 wheel for those versions (cuda111).
Jaxlib now bundles libdevice.10.bc in CUDA wheels. There should be no need to point JAX to a CUDA installation to find this file.
Added automatic support for static keyword arguments to the jit() implementation.
Added support for pretransformation exception traces.
Initial support for pruning unused arguments from jit()-transformed computations. Pruning is still a work in progress.
Improved the string representation of PyTreeDef objects.
Added support for XLA’s variadic ReduceWindow.
Bug fixes:
Fixed a bug in the remote cloud TPU support when large numbers of arguments
are passed to a computation.
Fixed a bug that meant that JAX garbage collection was not triggered by jit()-transformed functions.
When combined with jaxlib 0.1.66, jax.jit() now supports static keyword arguments. A new static_argnames option has been added to specify keyword arguments as static.
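A minimal sketch of the new option (the repeat function here is hypothetical):

```python
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames='n')
def repeat(x, n):
    # n is treated as a compile-time constant, so it may determine output shapes
    return jnp.tile(x, n)

repeat(jnp.arange(3), n=2)  # recompiles once per distinct value of n
```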
jax.numpy.nonzero() has a new optional size argument that allows it to be used within jit (#6501).
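A sketch of how size makes the output shape static (the values and helper name are illustrative):

```python
import jax
import jax.numpy as jnp

@jax.jit
def first_nonzero_indices(x):
    # size fixes the output shape, making nonzero traceable under jit;
    # if fewer than `size` entries are found, the result is padded.
    (idx,) = jnp.nonzero(x, size=3)
    return idx

first_nonzero_indices(jnp.array([0, 5, 0, 7]))  # e.g. [1, 3, 0]
```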
jax.numpy.unique() now supports the axis argument (#6532).
jax.experimental.host_callback.call() now supports pjit.pjit (#6569).
Added jax.scipy.linalg.eigh_tridiagonal() that computes the eigenvalues of a tridiagonal matrix. Only eigenvalues are supported at present.
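A short usage sketch (the matrix values are illustrative):

```python
import jax.numpy as jnp
from jax.scipy.linalg import eigh_tridiagonal

d = jnp.array([2.0, 2.0, 2.0])   # main diagonal
e = jnp.array([-1.0, -1.0])      # off-diagonal
# Only eigenvalues are supported at present, hence eigvals_only=True.
w = eigh_tridiagonal(d, e, eigvals_only=True)
```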
The order of the filtered and unfiltered stack traces in exceptions has been changed. The traceback attached to an exception thrown from JAX-transformed code is now filtered, with an UnfilteredStackTrace exception containing the original trace as the __cause__ of the filtered exception.
Filtered stack traces now also work with Python 3.6.
If an exception is thrown by code that has been transformed by reverse-mode automatic differentiation, JAX now attempts to attach as a __cause__ of the exception a JaxStackTraceBeforeTransformation object that contains the stack trace that created the original operation in the forward pass. Requires jaxlib 0.1.66.
Breaking changes:
The following function names have changed. There are still aliases, so this
should not break existing code, but the aliases will eventually be removed
so please change your code.
host_id –> process_index()
host_count –> process_count()
host_ids –> range(jax.process_count())
Similarly, the argument to local_devices() has been renamed from host_id to process_index.
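A migration sketch for the renames above (old aliases shown in comments):

```python
import jax

# jax.host_id()     -->  jax.process_index()
# jax.host_count()  -->  jax.process_count()
# jax.host_ids()    -->  range(jax.process_count())
print(f"process {jax.process_index()} of {jax.process_count()}")
```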
Arguments to jax.jit() other than the function are now marked as keyword-only. This change is to prevent accidental breakage when arguments are added to jit.
Bug fixes:
jax2tf.convert() now works in the presence of gradients for functions with integer inputs (#6360).
Fixed an assertion failure in jax2tf.call_tf() when used with captured tf.Variable (#6572).
New profiling APIs: jax.profiler.start_trace(), jax.profiler.stop_trace(), and jax.profiler.trace().
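A minimal usage sketch of the context-manager form (the log directory is arbitrary):

```python
import jax
import jax.numpy as jnp

x = jnp.ones((1000, 1000))
# Equivalent to calling start_trace(...) and stop_trace() around the block;
# the resulting trace can be viewed in TensorBoard.
with jax.profiler.trace("/tmp/jax-trace"):
    jnp.dot(x, x).block_until_ready()
```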
jax.lax.reduce() is now differentiable.
Breaking changes:
The minimum jaxlib version is now 0.1.64.
Some profiler API names have been changed. There are still aliases, so this should not break existing code, but the aliases will eventually be removed so please change your code.
TraceContext –> TraceAnnotation()
StepTraceContext –> StepTraceAnnotation()
trace_function –> annotate_function()
Omnistaging can no longer be disabled. See omnistaging for more information.
Python integers larger than the maximum int64 value will now lead to an overflow in all cases, rather than being silently converted to uint64 in some cases (#6047).
Outside X64 mode, Python integers outside the range representable by int32 will now lead to an OverflowError rather than having their value silently truncated.
Bug fixes:
host_callback now supports empty arrays in arguments and results (#6262).
jax.random.randint() clips rather than wraps out-of-bounds limits, and can now generate integers in the full range of the specified dtype (#5868).
#6112 added context managers: jax.enable_checks, jax.check_tracer_leaks, jax.debug_nans, jax.debug_infs, jax.log_compiles.
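A sketch of one of these context managers (a minimal example, assuming the documented debug_nans behavior of raising on NaN outputs):

```python
import jax
import jax.numpy as jnp

# Temporarily turn on NaN checking for a region of code only.
with jax.debug_nans(True):
    jnp.log(jnp.array(-1.0))  # raises FloatingPointError instead of returning nan
```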
#6085 added jnp.delete.
Bug fixes:
#6136 generalized jax.flatten_util.ravel_pytree to handle integer dtypes.
#6129 fixed a bug with handling some constants like enum.IntEnums.
#6145 fixed batching issues with incomplete beta functions.
#6014 fixed H2D transfers during tracing.
#6165 avoids OverflowErrors when converting some large Python integers to floats.
Breaking changes:
The minimum jaxlib version is now 0.1.62.
New features:
jax.scipy.stats.chi2() is now available as a distribution with logpdf and pdf methods.
jax.scipy.stats.betabinom() is now available as a distribution with logpmf and pmf methods.
Added jax.experimental.jax2tf.call_tf() to call TensorFlow functions from JAX (#5627 and README).
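A minimal sketch of the new API (tf_sin is a hypothetical TensorFlow function):

```python
import tensorflow as tf
import jax.numpy as jnp
from jax.experimental import jax2tf

def tf_sin(x):
    return tf.math.sin(x)        # ordinary TensorFlow code

jax_sin = jax2tf.call_tf(tf_sin)  # now callable from JAX
y = jax_sin(jnp.float32(0.5))
```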
Extended the batching rule for lax.pad to support batching of the padding values.
Bug fixes:
jax.numpy.take() properly handles negative indices (#5768).
Breaking changes:
JAX’s promotion rules were adjusted to make promotion more consistent and invariant to JIT. In particular, binary operations can now result in weakly-typed values when appropriate. The main user-visible effect of the change is that some operations result in outputs of different precision than before; for example the expression jnp.bfloat16(1) + 0.1 * jnp.arange(10) previously returned a float64 array, and now returns a bfloat16 array. JAX’s type promotion behavior is described at Type promotion semantics.
jax.numpy.linspace() now computes the floor of integer values, i.e., rounding towards -inf rather than 0. This change was made to match NumPy 1.20.0.
jax.numpy.i0() no longer accepts complex numbers. Previously the function computed the absolute value of complex arguments. This change was made to match the semantics of NumPy 1.20.0.
Several jax.numpy functions no longer accept tuples or lists in place of array arguments: jax.numpy.pad(), jax.numpy.ravel(), jax.numpy.repeat(), jax.numpy.reshape(). In general, jax.numpy functions should be used with scalars or array arguments.
jaxlib wheels are now built to require AVX instructions on x86-64 machines by default. If you want to use JAX on a machine that doesn’t support AVX, you can build a jaxlib from source using the --target_cpu_features flag to build.py. --target_cpu_features also replaces --enable_march_native.
Fixed a memory leak when converting CPU DeviceArrays to NumPy arrays. The
memory leak was present in jaxlib releases 0.1.58 and 0.1.59.
bool, int8, and uint8 are now considered safe to cast to the bfloat16 NumPy extension type.
Extend the jax.experimental.loops module with support for pytrees. Improved error checking and error messages.
Add jax.experimental.enable_x64() and jax.experimental.disable_x64(). These are context managers which allow X64 mode to be temporarily enabled/disabled within a session.
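A usage sketch of the new context managers (assuming X64 mode is off by default):

```python
import jax.numpy as jnp
from jax.experimental import enable_x64

with enable_x64():
    x = jnp.arange(3)       # int64 while the context is active
print(jnp.arange(3).dtype)  # back to the default (int32) outside it
```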
Breaking changes:
jax.ops.segment_sum() now drops segment IDs that are out of range rather than wrapping them into the segment ID space. This was done for performance reasons.
Add jax.closure_convert() for use with higher-order custom derivative functions (#5244).
Add jax.experimental.host_callback.call() to call a custom Python function on the host and return a result to the device computation (#5243).
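A minimal sketch (host_sin is a hypothetical host-side function; result_shape declares the output’s shape and dtype to the compiler):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import host_callback as hcb

def host_sin(x):          # runs on the host, receives a NumPy array
    return np.sin(x)

@jax.jit
def f(x):
    return hcb.call(host_sin, x,
                    result_shape=jax.ShapeDtypeStruct(x.shape, x.dtype))

f(jnp.arange(3.0))
```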
Bug fixes:
jax.numpy.arccosh now returns the same branch as numpy.arccosh for complex inputs (#5156).
host_callback.id_tap now works for jax.pmap also. There is an optional parameter for id_tap and id_print to request that the device from which the value is tapped be passed as a keyword argument to the tap function (#5182).
Breaking changes:
jax.numpy.pad now takes keyword arguments. Positional argument constant_values has been removed. In addition, passing unsupported keyword arguments raises an error.
Changes for jax.experimental.host_callback.id_tap() (#5243):
Removed support for kwargs for jax.experimental.host_callback.id_tap(). (This support has been deprecated for a few months.)
Changed the printing of tuples for jax.experimental.host_callback.id_print() to use '(' instead of '['.
Changed jax.experimental.host_callback.id_print() in the presence of JVP to print a pair of primal and tangent values. Previously, there were two separate print operations for the primals and the tangent.
host_callback.outfeed_receiver has been removed (it is not necessary, and was deprecated a few months ago).
Add jax.device_put_replicated
Add multi-host support to jax.experimental.sharded_jit
Add support for differentiating eigenvalues computed by jax.numpy.linalg.eig
Add support for building on Windows platforms
Add support for general in_axes and out_axes in jax.pmap
Add complex support for jax.numpy.linalg.slogdet
Bug fixes:
Fix higher-than-second order derivatives of jax.numpy.sinc at zero
Fix some hard-to-hit bugs around symbolic zeros in transpose rules
Breaking changes:
jax.experimental.optix has been deleted, in favor of the standalone optax Python package.
Indexing of JAX arrays with non-tuple sequences now raises a TypeError. This type of indexing has been deprecated in NumPy since v1.16, and in JAX since v0.2.4. See #4564.
New Features:
Add support for shape-polymorphic tracing for the jax.experimental.jax2tf converter.
See README.md.
Breaking change cleanup
Raise an error on non-hashable static arguments for jax.jit and
xla_computation. See cb48f42.
Improve consistency of type promotion behavior (#4744):
Adding a complex Python scalar to a JAX floating point number respects the precision of the JAX float. For example, jnp.float32(1) + 1j now returns complex64, where previously it returned complex128.
Results of type promotion with 3 or more terms involving uint64, a signed int, and a third type are now independent of the order of arguments. For example: jnp.result_type(jnp.uint64, jnp.int64, jnp.float16) and jnp.result_type(jnp.float16, jnp.uint64, jnp.int64) both return float16, where previously the first returned float64 and the second returned float16.
The contents of the (undocumented) jax.lax_linalg linear algebra module are now exposed publicly as jax.lax.linalg.
jax.random.PRNGKey now produces the same results in and out of JIT compilation (#4877). This required changing the result for a given seed in a few particular cases:
With jax_enable_x64=False, negative seeds passed as Python integers now return a different result outside JIT mode. For example, jax.random.PRNGKey(-1) previously returned [4294967295, 4294967295], and now returns [0, 4294967295]. This matches the behavior in JIT.
Seeds outside the range representable by int64 outside JIT now result in an OverflowError rather than a TypeError. This matches the behavior in JIT.
To recover the keys returned previously for negative integers with jax_enable_x64=False outside JIT, you can use:
key = random.PRNGKey(-1).at[0].set(0xFFFFFFFF)
DeviceArray now raises RuntimeError instead of ValueError when trying to access its value after it has been deleted.
jaxlib 0.1.58 (January 12ish 2021)#
Fixed a bug that meant JAX sometimes returned platform-specific types (e.g., np.cint) instead of standard types (e.g., np.int32) (#4903).
Fixed a crash when constant-folding certain int16 operations. (#4971)
Added an is_leaf predicate to pytree.flatten().
jaxlib 0.1.57 (November 12 2020)#
Fixed manylinux2010 compliance issues in GPU wheels.
Switched the CPU FFT implementation from Eigen to PocketFFT.
Fixed a bug where the hash of bfloat16 values was not correctly initialized
and could change (#4651).
Add support for retaining ownership when passing arrays to DLPack (#4636).
Fixed a bug for batched triangular solves with sizes greater than 128 but not
a multiple of 128.
Fixed a bug when performing concurrent FFTs on multiple GPUs (#3518).
Fixed a bug in profiler where tools are missing (#4427).
Dropped support for CUDA 10.0.
jax 0.2.5 (October 27 2020)#
Improvements:
Ensure that check_jaxpr does not perform FLOPS. See #4650.
Expanded the set of JAX primitives converted by jax2tf.
See primitives_with_limited_support.md.
Deprecations
Indexing with non-tuple sequences is now deprecated, following a similar deprecation in Numpy.
In a future release, this will result in a TypeError. See #4564.
The reason for another release so soon is that we need to temporarily roll back a new jit fastpath while we look into a performance degradation.
jax 0.2.2 (October 13 2020)#
jax 0.2.1 (October 6 2020)#
Improvements:
As a benefit of omnistaging, the host_callback functions are executed (in program order) even if the result of the jax.experimental.host_callback.id_print() / jax.experimental.host_callback.id_tap() is not used in the computation.
Prevent some accidental dtype warnings (#3874)
Fix a multi-threading bug in custom derivatives (#3845, #3869)
Improvements:
Faster searchsorted implementation (#3873)
Better test coverage for jax.numpy sorting algorithms (#3836)
jax.numpy.intersect1d (#3726)
jax.numpy.lexsort (#3812)
lax.scan and the scan primitive support an unroll parameter for loop unrolling when lowering to XLA (#3738).
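A usage sketch of the unroll parameter (the cumulative-sum body is illustrative):

```python
import jax
import jax.numpy as jnp

def step(carry, x):
    carry = carry + x
    return carry, carry

# unroll=4 unrolls the scan body four steps per XLA loop iteration.
total, partials = jax.lax.scan(step, 0.0, jnp.arange(8.0), unroll=4)
```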
Bug Fixes:
Fix reduction repeated axis error (#3618)
Fix shape rule for lax.pad for input dimensions of size 0. (#3608)
make psum transpose handle zero cotangents (#3653)
Fix shape error when taking JVP of reduce-prod over size 0 axis. (#3729)
Support differentiation through jax.lax.all_to_all (#3733)
address nan issue in jax.scipy.special.zeta (#3777)
Improvements:
Many improvements to jax2tf
Reimplement argmin/argmax using a single pass variadic reduction. (#3611)
Enable XLA SPMD partitioning by default. (#3151)
Add support for 0d transpose convolution (#3643)
Make LU gradient work for low-rank matrices (#3610)
support multiple_results and custom JVPs in jet (#3657)
Generalize reduce-window padding to support (lo, hi) pairs. (#3728)
Implement complex convolutions on CPU and GPU. (#3735)
Make jnp.take work for empty slices of empty arrays. (#3751)
Relax dimension ordering rules for dot_general. (#3778)
Enable buffer donation for GPU. (#3800)
Add support for base dilation and window dilation to reduce window op… (#3803)
Allow jax.experimental.ode.odeint dynamics functions to close over values with respect to which we’re differentiating #3562.
Add support for CUDA 11.0.
Drop support for CUDA 9.2 (we only maintain support for the last four CUDA
versions.)
Update XLA.
jaxlib 0.1.49 (June 19, 2020)#
Bug fixes:
Fix build issue that could result in slow compiles
(tensorflow/tensorflow)
Adds support for fast traceback collection.
Adds preliminary support for on-device heap profiling.
Implements np.nextafter for bfloat16 types.
Complex128 support for FFTs on CPU and GPU.
Bugfixes:
Improved float64 tanh accuracy on GPU.
float64 scatters on GPU are much faster.
Complex matrix multiplication on CPU should be much faster.
Stable sorts on CPU should actually be stable now.
Concurrency bug fix in CPU backend.
lax.switch introduces indexed conditionals with multiple branches, together with a generalization of the cond primitive #3318.
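A minimal sketch of the new primitive (the branch functions are illustrative):

```python
import jax
import jax.numpy as jnp

branches = [lambda x: x + 1.0,   # index 0
            lambda x: x * 2.0,   # index 1
            lambda x: -x]        # index 2
y = jax.lax.switch(1, branches, jnp.float32(3.0))  # selects branches[1] -> 6.0
```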
Notable changes:
The format of the transforms keyword for the jax.experimental.host_callback.id_tap() primitive has changed #3132.
New features:
Support for reduction over subsets of a pmapped axis using axis_index_groups #2382.
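A sketch of a grouped reduction (assumes at least four devices; the group layout is illustrative):

```python
import jax

# psum over the mapped axis, but only within the listed device groups.
f = jax.pmap(
    lambda x: jax.lax.psum(x, 'i', axis_index_groups=[[0, 1], [2, 3]]),
    axis_name='i')
```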
Experimental support for printing and calling host-side Python functions from compiled code. See id_print and id_tap (#3006).
Notable changes:
The visibility of names exported from jax.numpy has been tightened. This may break code that was making use of names that were previously exported accidentally.
jaxlib 0.1.46 (May 5, 2020)#
Fixes crash for linear algebra functions on Mac OS X (#432).
Fixes an illegal instruction crash caused by using AVX512 instructions when
an operating system or hypervisor disabled them (#2906).
jax 0.1.65 (April 30, 2020)#
New features:
Differentiation of determinants of singular matrices #2809.
Bug fixes:
Fix odeint() differentiation with respect to time of ODEs with time-dependent dynamics #2817, also add ODE CI testing.
Fix lax_linalg.qr() differentiation #2867.
Add syntactic sugar for functional indexed updates #2684.
Add jax.numpy.unique() #2760.
Add jax.numpy.rint() #2724.
Add more primitive rules for jax.experimental.jet().
Bug fixes:
Fix logaddexp() and logaddexp2() differentiation at zero #2107.
Improve memory usage in reverse-mode autodiff without jit() #2719.
Better errors:
Improves error message for reverse-mode differentiation of lax.while_loop() #2129.
jaxlib 0.1.44 (April 16, 2020)#
Fixes a bug where if multiple GPUs of different models were present, JAX
would only compile programs suitable for the first GPU.
Bugfix for batch_group_count convolutions.
Added precompiled SASS for more GPU versions to avoid startup PTX compilation
hang.
jax 0.1.63 (April 12, 2020)#
Added jax.custom_jvp and jax.custom_vjp from #2026, see the tutorial notebook. Deprecated jax.custom_transforms and removed it from the docs (though it still works).
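A minimal custom_jvp sketch in the spirit of the tutorial (the function f is illustrative):

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(x):
    return jnp.sin(x)

@f.defjvp
def f_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    return f(x), jnp.cos(x) * t   # custom forward-mode rule

jax.grad(f)(0.0)  # 1.0
```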
Add scipy.sparse.linalg.cg #2566.
Changed how Tracers are printed to show more useful information for debugging #2591.
Made jax.numpy.isclose handle nan and inf correctly #2501.
Added several new rules for jax.experimental.jet #2537.
Fixed jax.experimental.stax.BatchNorm when scale/center isn’t provided.
Fix some missing cases of broadcasting in jax.numpy.einsum #2512.
Implement jax.numpy.cumsum and jax.numpy.cumprod in terms of a parallel prefix scan #2596, and make reduce_prod differentiable to arbitrary order #2597.
Add batch_group_count to conv_general_dilated #2635.
Add docstring for test_util.check_grads #2656.
Add callback_transform #2665.
Implement rollaxis, convolve/correlate 1d & 2d, copysign, trunc, roots, and quantile/percentile interpolation options.
jaxlib 0.1.43 (March 31, 2020)#
Fixed a performance regression for Resnet-50 on GPU.
jax 0.1.62 (March 21, 2020)#
JAX has dropped support for Python 3.5. Please upgrade to Python 3.6 or newer.
Removed the internal function lax._safe_mul, which implemented the convention 0. * nan == 0. This change means that some programs, when differentiated, will produce nans where they previously produced correct values, though it ensures that nans rather than silently incorrect results are produced for other programs. See #2447 and #1052 for details.
Added an all_gather parallel convenience function.
More type annotations in core code.
jaxlib 0.1.42 (March 19, 2020)#
jaxlib 0.1.41 broke cloud TPU support due to an API incompatibility. This
release fixes it again.
JAX has dropped support for Python 3.5. Please upgrade to Python 3.6 or newer.
jax 0.1.61 (March 17, 2020)#
Fixes Python 3.5 support. This will be the last JAX or jaxlib release that
supports Python 3.5.
jax 0.1.60 (March 17, 2020)#
New features:
jax.pmap() has a static_broadcast_argnums argument which allows the user to specify arguments that should be treated as compile-time constants and should be broadcasted to all devices. It works analogously to static_argnums in jax.jit().
Improved error messages for when tracers are mistakenly saved in global state.
Added jax.nn.one_hot() utility function.
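A quick usage sketch:

```python
import jax
import jax.numpy as jnp

jax.nn.one_hot(jnp.array([0, 1, 2]), num_classes=3)
# [[1., 0., 0.],
#  [0., 1., 0.],
#  [0., 0., 1.]]
```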
Added jax.experimental.jet for exponentially faster higher-order automatic differentiation.
Added more correctness checking to arguments of jax.lax.broadcast_in_dim().
The minimum jaxlib version is now 0.1.41.
jaxlib 0.1.40 (March 4, 2020)#
Adds experimental support in Jaxlib for TensorFlow profiler, which allows
tracing of CPU and GPU computations from TensorBoard.
Includes prototype support for multihost GPU computations that communicate via
NCCL.
Improves performance of NCCL collectives on GPU.
Adds TopK, CustomCallWithoutLayout, CustomCallWithLayout, IGammaGradA and
RandomGamma implementations.
Supports device assignments known at XLA compilation time.
jax 0.1.59 (February 11, 2020)#
Breaking changes
The minimum jaxlib version is now 0.1.38.
Simplified Jaxpr by removing the Jaxpr.freevars and Jaxpr.bound_subjaxprs fields. The call primitives (xla_call, xla_pmap, sharded_call, and remat_call) get a new parameter call_jaxpr with a fully-closed (no constvars) jaxpr. Also, added a new field call_primitive to primitives.
New features:
Reverse-mode automatic differentiation (e.g. grad) of lax.cond, making it now differentiable in both modes (#2091).
JAX now supports DLPack, which allows sharing CPU and GPU arrays in a zero-copy way with other libraries, such as PyTorch.
JAX GPU DeviceArrays now support __cuda_array_interface__, which is another zero-copy protocol for sharing GPU arrays with other libraries such as CuPy and Numba.
JAX CPU device buffers now implement the Python buffer protocol, which allows zero-copy buffer sharing between JAX and NumPy.
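A round-trip sketch of the DLPack support (consuming a capsule you just created is only for illustration; normally the capsule is handed to another library such as PyTorch):

```python
import jax.dlpack
import jax.numpy as jnp

x = jnp.arange(4.0)
capsule = jax.dlpack.to_dlpack(x)    # zero-copy export of the array
y = jax.dlpack.from_dlpack(capsule)  # zero-copy import back into JAX
```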
Added a JAX_SKIP_SLOW_TESTS environment variable to skip tests known to be slow.
Breaking changes
JAX has dropped Python 2 support, because Python 2 reached its end of life on
January 1, 2020. Please update to Python 3.5 or newer.
New features
Forward-mode automatic differentiation (jvp) of while loop (#1980).
New NumPy and SciPy functions:
Batched Cholesky decomposition on GPU now uses a more efficient batched
kernel.
Notable bug fixes#
With the Python 3 upgrade, JAX no longer depends on fastcache, which should help with installation.