                            JAX Errors

                            This page lists a few of the errors you might encounter when using JAX, along with representative examples of how one might fix them.

                            class jax.errors.ConcretizationTypeError(tracer, context='')

                            This error occurs when a JAX Tracer object is used in a context where a concrete value is required (see Different kinds of JAX values for more on what a Tracer is). In some situations, it can be easily fixed by marking problematic values as static; in others, it may indicate that your program is doing operations that are not directly supported by JAX’s JIT compilation model.

                            Examples:

                            Traced value where static value is expected

                            One common cause of this error is using a traced value where a static value is required. For example:

                            >>> from functools import partial
                            >>> from jax import jit
                            >>> import jax.numpy as jnp
                            >>> @jit
                            ... def func(x, axis):
                            ...   return x.min(axis)
                            
                            >>> func(jnp.arange(4), 0)  
                            Traceback (most recent call last):
                            ConcretizationTypeError: Abstract tracer value encountered where concrete
                            value is expected: axis argument to jnp.min().
                            

                            This can often be fixed by marking the problematic argument as static:

                            >>> @partial(jit, static_argnums=1)
                            ... def func(x, axis):
                            ...   return x.min(axis)
                            >>> func(jnp.arange(4), 0)
                            Array(0, dtype=int32)
                            
                            Shape depends on Traced Value

                            Such an error may also arise when a shape in your JIT-compiled computation depends on the values within a traced quantity. For example:

                            >>> @jit
                            ... def func(x):
                            ...   return jnp.where(x < 0)
                            >>> func(jnp.arange(4))
                            Traceback (most recent call last):
                            ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: The error arose in jnp.nonzero.

                            This is an example of an operation that is incompatible with JAX’s JIT compilation model, which requires array sizes to be known at compile-time. Here the size of the returned array depends on the contents of x, and such code cannot be JIT compiled.
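Note that the restriction applies to array values, not array shapes: a shape is part of the abstract value JAX traces with, so it stays a plain Python tuple even while the array itself is a Tracer. A minimal sketch (pad_to_even is a hypothetical helper, not part of JAX) that branches on a shape with an ordinary Python if inside jit():

```python
import jax
import jax.numpy as jnp


@jax.jit
def pad_to_even(x):
    # x is a Tracer here, but x.shape is a concrete Python tuple known at
    # trace time, so this `if` is resolved once per input shape.
    if x.shape[0] % 2 == 1:
        x = jnp.concatenate([x, jnp.zeros(1, dtype=x.dtype)])
    return x


print(pad_to_even(jnp.arange(3)).shape)  # (4,)
print(pad_to_even(jnp.arange(4)).shape)  # (4,)
```

Because each distinct input shape triggers a separate trace, the branch costs nothing at runtime.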

                            In many cases it is possible to work around this by modifying the logic used in the function; for example here is code with a similar issue:

                            >>> @jit
                            ... def func(x):
                            ...   indices = jnp.where(x > 1)
                            ...   return x[indices].sum()
                            >>> func(jnp.arange(4))
                            Traceback (most recent call last):
                            ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: The error arose in jnp.nonzero.

                            And here is how you might express the same operation in a way that avoids creation of a dynamically-sized index array:

                            >>> @jit
                            ... def func(x):
                            ...   return jnp.where(x > 1, x, 0).sum()
                            >>> func(jnp.arange(4))
                            Array(5, dtype=int32)
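When you actually need the indices and an upper bound on how many there can be is known in advance, another option is the size argument of jax.numpy.nonzero(). It fixes the output length at trace time and pads unused slots with fill_value (extra matches beyond size are dropped). A sketch under that assumption; negative_indices and max_count are illustrative names:

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=["max_count"])
def negative_indices(x, max_count):
    # size= makes the result shape static: exactly max_count entries.
    # Missing matches are padded with -1; surplus matches are truncated.
    (idx,) = jnp.nonzero(x < 0, size=max_count, fill_value=-1)
    return idx


print(negative_indices(jnp.array([3, -1, 2, -5]), max_count=3))  # [ 1  3 -1]
```

Since max_count is static, each distinct bound compiles its own version of the function.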

                            To understand more subtleties having to do with tracers vs. regular values, and concrete vs. abstract values, you may want to read Different kinds of JAX values.

                            Parameters:

                            tracer (Tracer) –

                            context (str) –

                            class jax.errors.NonConcreteBooleanIndexError(tracer)

                            This error occurs when a program attempts to use non-concrete boolean indices in a traced indexing operation. Under JIT compilation, JAX arrays must have static shapes (i.e. shapes that are known at compile-time) and so boolean masks must be used carefully. Some logic implemented via boolean masking is simply not possible in a jax.jit() function; in other cases, the logic can be re-expressed in a JIT-compatible way, often using the three-argument version of where().

                            Following are a few examples of when this error might arise.

                            Constructing arrays via boolean masking

                            This most commonly arises when attempting to create an array via a boolean mask within a JIT context. For example:

                            >>> import jax
                            >>> import jax.numpy as jnp
                            >>> @jax.jit
                            ... def positive_values(x):
                            ...   return x[x > 0]
                            >>> positive_values(jnp.arange(-5, 5))  
                            Traceback (most recent call last):
                            NonConcreteBooleanIndexError: Array boolean indices must be concrete: ShapedArray(bool[10])
                            

                            This function is attempting to return only the positive values in the input array; the size of this returned array cannot be determined at compile-time unless x is marked as static, and so operations like this cannot be performed under JIT compilation.

                            Reexpressible Boolean Logic

                            Although creating dynamically sized arrays is not supported directly, in many cases it is possible to re-express the logic of the computation in terms of a JIT-compatible operation. For example, here is another function that fails under JIT for the same reason:

                            >>> @jax.jit
                            ... def sum_of_positive(x):
                            ...   return x[x > 0].sum()
                            >>> sum_of_positive(jnp.arange(-5, 5))  
                            Traceback (most recent call last):
                            NonConcreteBooleanIndexError: Array boolean indices must be concrete: ShapedArray(bool[10])
                            

                            In this case, however, the problematic array is only an intermediate value, and we can instead express the same logic in terms of the JIT-compatible three-argument version of jax.numpy.where():

                            >>> @jax.jit
                            ... def sum_of_positive(x):
                            ...   return jnp.where(x > 0, x, 0).sum()
                            >>> sum_of_positive(jnp.arange(-5, 5))
                            Array(10, dtype=int32)
                            

                            This pattern of replacing boolean masking with three-argument where() is a common solution to this sort of problem.
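Many jax.numpy reductions also accept a where= mask, which applies the same "reduce only over the selected elements" logic inside the reduction, without ever building a dynamically sized array. A minimal sketch (mean_of_positive is an illustrative name):

```python
import jax
import jax.numpy as jnp


@jax.jit
def mean_of_positive(x):
    mask = x > 0
    # Equivalent to x[x > 0].mean(), but every intermediate keeps the
    # static shape of x, so this traces without error.
    return jnp.mean(x, where=mask)


print(mean_of_positive(jnp.arange(-5, 5)))  # 2.5
```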

                            Boolean indexing into JAX arrays

                            The other situation where this error often arises is when using boolean indices, such as with .at[...].set(...). Here is a simple example:

                            >>> @jax.jit
                            ... def manual_clip(x):
                            ...   return x.at[x < 0].set(0)
                            >>> manual_clip(jnp.arange(-2, 2))  
                            Traceback (most recent call last):
                            NonConcreteBooleanIndexError: Array boolean indices must be concrete: ShapedArray(bool[4])
                            

                            This function is attempting to set values smaller than zero to a scalar fill value. As above, this can be addressed by re-expressing the logic in terms of where():

                            >>> @jax.jit
                            ... def manual_clip(x):
                            ...   return jnp.where(x < 0, 0, x)
                            >>> manual_clip(jnp.arange(-2, 2))
                            Array([0, 0, 0, 1], dtype=int32)

                            class jax.errors.TracerArrayConversionError(tracer)

                            This error occurs when a program attempts to convert a JAX Tracer object into a standard NumPy array (see Different kinds of JAX values for more on what a Tracer is). It typically occurs in one of a few situations.

                            Using non-JAX functions in JAX transformations

                            This error can occur if you attempt to use a non-JAX library like numpy or scipy inside a JAX transformation (jit(), grad(), jax.vmap(), etc.). For example:

                            >>> from jax import jit
                            >>> import numpy as np
                            >>> @jit
                            ... def func(x):
                            ...   return np.sin(x)
                            >>> func(np.arange(4))  
                            Traceback (most recent call last):
                            TracerArrayConversionError: The numpy.ndarray conversion method
                            __array__() was called on traced array with shape int32[4]
                            

                            In this case, you can fix the issue by using jax.numpy.sin() in place of numpy.sin():

                            >>> import jax.numpy as jnp
                            >>> @jit
                            ... def func(x):
                            ...   return jnp.sin(x)
                            >>> func(jnp.arange(4))
                            Array([0.        , 0.84147096, 0.9092974 , 0.14112   ], dtype=float32)
                            

                            See also External Callbacks for options for calling back to host-side computations from transformed JAX code.
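As a sketch of the callback route, jax.pure_callback() can run numpy.sin() on the host with concrete NumPy values, provided you declare the result's shape and dtype up front. This assumes the host function is pure (no side effects); host_sin is an illustrative name. A callback is typically slower than the jax.numpy equivalent, so it is best reserved for host code that has no JAX counterpart.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_sin(x):
    # Runs outside of tracing, on concrete NumPy data.
    return np.sin(x)


@jax.jit
def func(x):
    # The callback body is opaque to JAX, so its output shape and dtype
    # must be declared explicitly.
    result_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_sin, result_spec, x)


out = func(jnp.arange(4.0))
print(out)
```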

                            Indexing a numpy array with a tracer

                            If this error arises on a line that involves array indexing, it may be that the array being indexed x is a standard numpy.ndarray while the indices idx are traced JAX arrays. For example:

                            >>> x = np.arange(10)
                            >>> @jit
                            ... def func(i):
                            ...   return x[i]
                            >>> func(0)  
                            Traceback (most recent call last):
                            TracerArrayConversionError: The numpy.ndarray conversion method
                            __array__() was called on traced array with shape int32[0]
                            

                            Depending on the context, you may fix this by converting the numpy array into a JAX array:

                            >>> @jit
                            ... def func(i):
                            ...   return jnp.asarray(x)[i]
                            >>> func(0)
                            Array(0, dtype=int32)

                            or by declaring the index as a static argument:

                            >>> from functools import partial
                            >>> @partial(jit, static_argnums=(0,))
                            ... def func(i):
                            ...   return x[i]
                            >>> func(0)
                            Array(0, dtype=int32)
                            

                            To understand more subtleties having to do with tracers vs. regular values, and concrete vs. abstract values, you may want to read Different kinds of JAX values.

                            Parameters:

                            tracer (Tracer) –

                            class jax.errors.TracerBoolConversionError(tracer)

                            This error occurs when a traced value in JAX is used in a context where a boolean value is expected (see Different kinds of JAX values for more on what a Tracer is).

                            The boolean cast may be explicit (e.g. bool(x)) or implicit, through the use of control flow (e.g. if x > 0 or while x), the use of Python boolean operators (e.g. z = x and y, z = x or y, z = not x), or functions that use them (e.g. z = max(x, y), z = min(x, y), etc.).
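To make the implicit case concrete: Python's and operator calls bool() on its left operand, so it fails under jit(), whereas the elementwise jax.numpy.logical_and() never asks for a Python boolean. A minimal sketch; both_positive_python and both_positive are illustrative names:

```python
import jax
import jax.numpy as jnp


@jax.jit
def both_positive_python(x, y):
    # `and` calls bool() on the traced comparison `x > 0`.
    return x > 0 and y > 0


@jax.jit
def both_positive(x, y):
    # Elementwise logical op: the result stays a traced boolean array.
    return jnp.logical_and(x > 0, y > 0)


try:
    both_positive_python(1.0, 2.0)
except jax.errors.TracerBoolConversionError:
    print("bool() was called on a Tracer")

print(both_positive(1.0, 2.0), both_positive(1.0, -2.0))
```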

                            In some situations, this problem can be easily fixed by marking traced values as static; in others, it may indicate that your program is doing operations that are not directly supported by JAX’s JIT compilation model.

                            Examples:

                            Traced value used in control flow

                            One case where this often arises is when a traced value is used in Python control flow. For example:

                            >>> from jax import jit
                            >>> import jax.numpy as jnp
                            >>> @jit
                            ... def func(x, y):
                            ...   return x if x.sum() < y.sum() else y
                            >>> func(jnp.ones(4), jnp.zeros(4))  
                            Traceback (most recent call last):
                            TracerBoolConversionError: Attempted boolean conversion of JAX Tracer [...]
                            

                            We could mark both inputs x and y as static, but that would defeat the purpose of using jax.jit() here. Another option is to re-express the if statement in terms of the three-argument jax.numpy.where():

                            >>> @jit
                            ... def func(x, y):
                            ...   return jnp.where(x.sum() < y.sum(), x, y)
                            >>> func(jnp.ones(4), jnp.zeros(4))
                            Array([0., 0., 0., 0.], dtype=float32)

                            For more complicated control flow including loops, see Control flow operators.
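As a brief illustration of those operators: jax.lax.cond() selects between two branch functions, and jax.lax.while_loop() replaces a Python while loop whose condition depends on traced values. Both trace their functions once, so the predicate can remain a Tracer. A sketch; select_smaller and count_halvings are illustrative names:

```python
import jax
import jax.numpy as jnp


@jax.jit
def select_smaller(x, y):
    # Both branches are traced; which one runs is decided at runtime.
    return jax.lax.cond(x.sum() < y.sum(), lambda a, b: a, lambda a, b: b, x, y)


@jax.jit
def count_halvings(x):
    # Equivalent to: while x > 1: x = x / 2; count += 1
    def cond_fun(state):
        value, _ = state
        return value > 1.0

    def body_fun(state):
        value, count = state
        return value / 2.0, count + 1

    _, count = jax.lax.while_loop(cond_fun, body_fun, (x, 0))
    return count


print(select_smaller(jnp.ones(4), jnp.zeros(4)))  # [0. 0. 0. 0.]
print(count_halvings(jnp.float32(8.0)))  # 3
```

The loop state is a tuple whose shapes and dtypes must stay fixed across iterations, which is what allows the loop to compile.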

                            Control flow on traced values

                            Another common cause of this error is if you inadvertently trace over a boolean flag. For example:

                            >>> @jit
                            ... def func(x, normalize=True):
                            ...   if normalize:
                            ...     return x / x.sum()
                            ...   return x
                            >>> func(jnp.arange(5), True)
                            Traceback (most recent call last):
                            TracerBoolConversionError: Attempted boolean conversion of JAX Tracer ...

                            Here, because the flag normalize is traced, it cannot be used in Python control flow. In this situation, the best solution is probably to mark this value as static:

                            >>> from functools import partial
                            >>> @partial(jit, static_argnames=['normalize'])
                            ... def func(x, normalize=True):
                            ...   if normalize:
                            ...     return x / x.sum()
                            ...   return x
                            >>> func(jnp.arange(5), True)
                            Array([0. , 0.1, 0.2, 0.3, 0.4], dtype=float32)
                            

                            For more on static_argnums and static_argnames, see the documentation of jax.jit().

                            Using non-JAX aware functions

                            Another common cause of this error is using non-JAX aware functions within JAX code. For example:

                            >>> @jit
                            ... def func(x):
                            ...   return min(x, 0)
                            >>> func(2)
                            Traceback (most recent call last):
                            TracerBoolConversionError: Attempted boolean conversion of JAX Tracer ...
                            

                            In this case, the error occurs because Python’s built-in min function is not compatible with JAX transforms. This can be fixed by replacing it with jnp.minimum:

                            >>> @jit
                            ... def func(x):
                            ...   return jnp.minimum(x, 0)
                            >>> print(func(2))
                            0

                            To understand more subtleties having to do with tracers vs. regular values, and concrete vs. abstract values, you may want to read Different kinds of JAX values.

                            Parameters:

                            tracer (Tracer) –

                            class jax.errors.TracerIntegerConversionError(tracer)

                            This error can occur when a JAX Tracer object is used in a context where a Python integer is expected (see Different kinds of JAX values for more on what a Tracer is). It typically occurs in a few situations.

                            Passing a tracer in place of an integer

                            This error can occur if you attempt to pass a traced value to a function that requires a static integer argument; for example:

                            >>> from jax import jit
                            >>> import numpy as np
                            >>> @jit
                            ... def func(x, axis):
                            ...   return np.split(x, 2, axis)
                            >>> func(np.arange(4), 0)  
                            Traceback (most recent call last):
                            TracerIntegerConversionError: The __index__() method was called on
                            traced array with shape int32[0]
                            

                            When this happens, the solution is often to mark the problematic argument as static:

                            >>> from functools import partial
                            >>> @partial(jit, static_argnums=1)
                            ... def func(x, axis):
                            ...   return np.split(x, 2, axis)
                            >>> func(np.arange(10), 0)
                            [Array([0, 1, 2, 3, 4], dtype=int32),
                             Array([5, 6, 7, 8, 9], dtype=int32)]
                            

                            An alternative is to apply the transformation to a closure that encapsulates the arguments to be protected, either manually as below or by using functools.partial():

                            >>> jit(lambda arr: np.split(arr, 2, 0))(np.arange(4))
                            [Array([0, 1], dtype=int32), Array([2, 3], dtype=int32)]
                            

                            Note that a new closure is created at every invocation, which defeats the compilation caching mechanism; this is why static_argnums is preferred.
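The caching difference can be observed with a Python side effect that only runs while JAX traces a function. This is a debugging trick used here for illustration, not something to rely on in real code; power and traces are illustrative names:

```python
import jax
import jax.numpy as jnp

traces = []


def power(x, n):
    # Python-level side effect: executes only when JAX (re)traces `power`.
    traces.append(n)
    return x ** n


# static_argnums: one jitted object, cached per (input shape/dtype, value of n).
power_static = jax.jit(power, static_argnums=1)
for _ in range(3):
    power_static(jnp.arange(4.0), 2)
static_traces = len(traces)

traces.clear()
# A fresh lambda on every call is a new function object, so nothing is reused.
for _ in range(3):
    jax.jit(lambda x: power(x, 2))(jnp.arange(4.0))
closure_traces = len(traces)

print(static_traces, closure_traces)
```

The static version is traced once and then served from the cache, while each freshly created closure is traced again.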

                            Indexing a list with a Tracer

                            This error can occur if you attempt to index a Python list with a traced quantity. For example:

                            >>> import jax.numpy as jnp
                            >>> from jax import jit
                            >>> L = [1, 2, 3]
                            >>> @jit
                            ... def func(i):
                            ...   return L[i]
                            >>> func(0)  
                            Traceback (most recent call last):
                            TracerIntegerConversionError: The __index__() method was called on
                            traced array with shape int32[0]
                            

                            Depending on the context, you can generally fix this either by converting the list to a JAX array:

                            >>> @jit
                            ... def func(i):
                            ...   return jnp.array(L)[i]
                            >>> func(0)
                            Array(1, dtype=int32)

                            or by declaring the index as a static argument:

                            >>> from functools import partial
                            >>> @partial(jit, static_argnums=0)
                            ... def func(i):
                            ...   return L[i]
                            >>> func(0)
                            Array(1, dtype=int32, weak_type=True)
                            

                            To understand more subtleties having to do with tracers vs. regular values, and concrete vs. abstract values, you may want to read Different kinds of JAX values.

                            Parameters:

                            tracer (Tracer) –

                            class jax.errors.UnexpectedTracerError(msg)

                            This error occurs when you use a JAX value that has leaked out of a function. What does it mean to leak a value? If you use a JAX transformation on a function f that stores, in some scope outside of f, a reference to an intermediate value, that value is considered to have been leaked. Leaking values is a side effect. (Read more about avoiding side effects in Pure Functions.)

                            JAX detects leaks when you then use the leaked value in another operation later on, at which point it raises an UnexpectedTracerError. To fix this, avoid side effects: if a function computes a value needed in an outer scope, return that value from the transformed function explicitly.

                            Specifically, a Tracer is JAX’s internal representation of a function’s intermediate values during transformations, e.g. within jit(), pmap(), vmap(), etc. Encountering a Tracer outside of a transformation implies a leak.
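A common case is wanting an intermediate value out of a function passed to jax.grad(). Rather than appending it to an outer list, you can return it as auxiliary output and pass has_aux=True. A minimal sketch with an illustrative loss function:

```python
import jax


def loss_with_aux(w):
    pred = w * 3.0
    loss = (pred - 1.0) ** 2
    # Return the intermediate alongside the loss instead of leaking it
    # through a side channel such as outs.append(pred).
    return loss, pred


(loss, pred), grad = jax.value_and_grad(loss_with_aux, has_aux=True)(2.0)
print(loss, pred, grad)  # 25.0 6.0 30.0
```

The auxiliary output comes back as a concrete array, so it is safe to store and reuse outside the transformation.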

                            Life-cycle of a leaked value

                            Consider the following example of a transformed function which leaks a value to an outer scope:

                            >>> from jax import jit
                            >>> import jax.numpy as jnp
                            >>> outs = []
                            >>> @jit                   # 1
                            ... def side_effecting(x):
                            ...   y = x + 1            # 3
                            ...   outs.append(y)       # 4
                            >>> x = 1
                            >>> side_effecting(x)      # 2
                            >>> outs[0] + 1            # 5  
                            Traceback (most recent call last):
                            UnexpectedTracerError: Encountered an unexpected tracer.
                            

                            In this example, we leak a traced value from an inner transformed scope to an outer scope. We get an UnexpectedTracerError when the leaked value is used, not when the value is leaked.

                            This example also demonstrates the life-cycle of a leaked value:

                            1. A function is transformed (in this case, by jit()).
                            2. The transformed function is called (initiating an abstract trace of the function and turning x into a Tracer).
                            3. The intermediate value y, which will later be leaked, is created (an intermediate value of a traced function is also a Tracer).
                            4. The value is leaked (appended to a list in an outer scope, escaping the function through a side-channel).
                            5. The leaked value is used, and an UnexpectedTracerError is raised.

                            The UnexpectedTracerError message tries to point to these locations in your code by including information about each stage. Respectively:

                            1. The name of the transformed function (side_effecting) and which transform kicked off the trace (jit()).
                            2. A reconstructed stack trace of where the leaked Tracer was created, which includes where the transformed function was called (When the Tracer was created, the final 5 stack frames were...).
                            3. From the reconstructed stack trace, the line of code that created the leaked Tracer.
                            4. The leak location is not included in the error message because it is difficult to pin down! JAX can only tell you what the leaked value looks like (what shape it has and where it was created) and what boundary it was leaked over (the name of the transformation and the name of the transformed function).
                            5. The current error’s stack trace points to where the value is used.

                            The error can be fixed by returning the value out of the transformed function:

                            >>> from jax import jit
                            >>> import jax.numpy as jnp
                            >>> outs = []
                            >>> @jit
                            ... def not_side_effecting(x):
                            ...   y = x+1
                            ...   return y
                            >>> x = 1
                            >>> y = not_side_effecting(x)
                            >>> outs.append(y)
                            >>> outs[0] + 1  # all good! no longer a leaked value.
                            Array(3, dtype=int32, weak_type=True)
                            
                            Leak checker

                            As discussed in points 2 and 3 above, JAX shows a reconstructed stack trace which points to where the leaked value was created. This is because JAX only raises an error when the leaked value is used, not when the value is leaked. This is not the most useful place to raise this error, because you need to know the location where the Tracer was leaked to fix the error.

                            To make this location easier to track down, you can use the leak checker. When the leak checker is enabled, an error is raised as soon as a Tracer is leaked. (To be more exact, it raises an error when the transformed function from which the Tracer is leaked returns.)

                            To enable the leak checker, set the JAX_CHECK_TRACER_LEAKS environment variable or use the jax.checking_leaks() context manager.

                            Note that this tool is experimental and may report false positives. It works by disabling some JAX caches, so it will have a negative effect on performance and should only be used when debugging.

                            Example usage:

                            >>> import jax
                            >>> from jax import jit
                            >>> import jax.numpy as jnp
                            >>> outs = []
                            >>> @jit
                            ... def side_effecting(x):
                            ...   y = x+1
                            ...   outs.append(y)
                            >>> x = 1
                            >>> with jax.checking_leaks():
                            ...   y = side_effecting(x)  
                            Traceback (most recent call last):
                            Exception: Leaked Trace
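The same check can also be switched on for a whole session through the jax_check_tracer_leaks configuration option, the config counterpart of the JAX_CHECK_TRACER_LEAKS environment variable. A sketch that enables it temporarily around a leaking function (leaky is an illustrative name); because the checker is experimental, the exact exception type and message are treated as version-dependent and caught broadly:

```python
import jax

outs = []


@jax.jit
def leaky(x):
    y = x + 1
    outs.append(y)  # leaks a Tracer into the enclosing scope
    return y


jax.config.update("jax_check_tracer_leaks", True)
try:
    leaky(1)
    leak_detected = False
except Exception as err:  # exact type/message depends on the JAX version
    leak_detected = True
    print("leak checker:", type(err).__name__)
finally:
    jax.config.update("jax_check_tracer_leaks", False)  # restore the default
    outs.clear()  # drop the leaked Tracer
```

Resetting the option in finally matters because, as noted above, the checker disables some caches and slows everything down.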
                            