jax.numpy.vectorize(pyfunc, *, excluded=frozenset(), signature=None)
Define a vectorized function with broadcasting.
vectorize() is a convenience wrapper for defining vectorized functions with broadcasting, in the style of NumPy's generalized universal functions.
It lets you define functions that are automatically repeated across any leading dimensions, so the function body never has to handle higher-dimensional inputs itself.
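As a minimal sketch of the default behavior (no signature given), the hypothetical clipped_ratio below is written for scalar inputs only, and vectorize repeats it over the broadcast leading dimensions of its arguments. The function name and the input data are illustrative assumptions, not part of the JAX API.

```python
import jax.numpy as jnp

# Hypothetical example function, written for scalar (0-d) inputs only.
@jnp.vectorize
def clipped_ratio(x, y):
    # With no signature, each call receives one element from each input.
    assert x.ndim == 0 and y.ndim == 0
    return jnp.minimum(x / y, 1.0)

x = jnp.arange(6.0).reshape(2, 3)  # shape (2, 3)
y = jnp.array([1.0, 2.0, 4.0])     # shape (3,), broadcast against x
out = clipped_ratio(x, y)          # shape (2, 3)
```

The assert never fires, because the per-element calls are generated by vectorize rather than by the function body.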
jax.numpy.vectorize() has the same interface as numpy.vectorize, but it is syntactic sugar for an auto-batching transformation (vmap()) rather than a Python loop. This should be considerably more efficient, but the wrapped function must be written in terms of operations that act on JAX arrays.
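To make the relationship with vmap() concrete, here is a sketch that compares a vectorized function with a hand-written jax.vmap of the same hypothetical core function dot. It only illustrates the equivalence for these shapes; JAX's own implementation also has to handle broadcasting and any number of leading dimensions.

```python
import jax
import jax.numpy as jnp

# Hypothetical core function: the dot product of two 1-D vectors.
def dot(a, b):
    return jnp.sum(a * b)

# vectorize maps `dot` over the leading dimensions implied by the signature.
batched_dot = jnp.vectorize(dot, signature='(n),(n)->()')

# Hand-written equivalent for two (k, n) inputs: map over axis 0 of both.
manual_dot = jax.vmap(dot, in_axes=(0, 0))

a = jnp.arange(12.0).reshape(4, 3)
b = jnp.ones((4, 3))
assert jnp.allclose(batched_dot(a, b), manual_dot(a, b))

# Broadcasting: a single vector v is reused for every row of `a`.
# With vmap, the same effect needs in_axes=None for the unmapped argument.
v = jnp.array([1.0, 0.0, -1.0])
assert jnp.allclose(batched_dot(a, v), jax.vmap(dot, in_axes=(0, None))(a, v))
```

With vectorize, the in_axes bookkeeping is derived from the signature and the input shapes instead of being written by hand.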
Parameters:
pyfunc – function to vectorize.
excluded – optional set of integers representing positional arguments for which the function will not be vectorized. These will be passed directly to pyfunc unmodified.
signature – optional generalized universal function signature, e.g., (m,n),(n)->(m) for vectorized matrix-vector multiplication. If provided, pyfunc will be called with (and expected to return) arrays with shapes given by the size of the corresponding core dimensions. By default, pyfunc is assumed to take scalar arrays as input and output.
Returns:
Vectorized version of the given function.
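A sketch of the excluded parameter, assuming a hypothetical poly_eval that takes a plain Python tuple of coefficients as its second positional argument. Excluding index 1 keeps that tuple out of vectorization, so it reaches the function unmodified; without excluded it would be converted to an array and split element by element along with x.

```python
import jax.numpy as jnp

def poly_eval(x, coeffs):
    # Hypothetical function: evaluates coeffs[0] + coeffs[1]*x + ... with
    # Horner's rule. `coeffs` is a Python tuple passed through unmodified.
    result = 0.0
    for c in reversed(coeffs):
        result = result * x + c
    return result

# Positional argument 1 (coeffs) is excluded from vectorization.
vpoly = jnp.vectorize(poly_eval, excluded={1})

x = jnp.array([0.0, 1.0, 2.0])
y = vpoly(x, (1.0, 2.0, 3.0))  # evaluates 1 + 2x + 3x^2 at each element of x
```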
Here are a few examples of how one could write vectorized linear algebra routines using vectorize():
>>> from functools import partial
>>> import jax.numpy as jnp
>>> @partial(jnp.vectorize, signature='(k),(k)->(k)')
... def cross_product(a, b):
...   assert a.shape == b.shape and a.ndim == b.ndim == 1
...   return jnp.array([a[1] * b[2] - a[2] * b[1],
...                     a[2] * b[0] - a[0] * b[2],
...                     a[0] * b[1] - a[1] * b[0]])
>>> @partial(jnp.vectorize, signature='(n,m),(m)->(n)')
... def matrix_vector_product(matrix, vector):
...   assert matrix.ndim == 2 and matrix.shape[1:] == vector.shape
...   return matrix @ vector
These functions are only written to handle 1D or 2D arrays (the assert statements will never be violated), but with vectorize they support arbitrary-dimensional inputs with NumPy-style broadcasting, e.g.,
>>> cross_product(jnp.ones(3), jnp.ones(3)).shape
(3,)
>>> cross_product(jnp.ones((2, 3)), jnp.ones(3)).shape
(2, 3)
>>> cross_product(jnp.ones((1, 2, 3)), jnp.ones((2, 1, 3))).shape
(2, 2, 3)
>>> matrix_vector_product(jnp.ones(3), jnp.ones(3))
Traceback (most recent call last):
ValueError: input with shape (3,) does not have enough dimensions for all
core dimensions ('n', 'm') on vectorized function with excluded=frozenset()
and signature='(n,m),(m)->(n)'
>>> matrix_vector_product(jnp.ones((2, 3)), jnp.ones(3)).shape
(2,)
>>> matrix_vector_product(jnp.ones((2, 3)), jnp.ones((4, 3))).shape
(4, 2)
Note that this has different semantics from jnp.matmul, which treats a 2D second argument as a matrix rather than as a stack of vectors:
>>> jnp.matmul(jnp.ones((2, 3)), jnp.ones((4, 3)))
Traceback (most recent call last):
TypeError: dot_general requires contracting dimensions to have the same shape, got [3] and [4].
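One way to see the difference is to spell out what matrix_vector_product computes for these shapes. The sketch below redefines the function so the block runs on its own (the arrays A and V are arbitrary illustrative data) and checks the result against jnp.einsum and an ordinary matrix product with the transpose.

```python
from functools import partial

import jax.numpy as jnp

@partial(jnp.vectorize, signature='(n,m),(m)->(n)')
def matrix_vector_product(matrix, vector):
    return matrix @ vector

A = jnp.arange(6.0).reshape(2, 3)   # a single (n, m) = (2, 3) matrix
V = jnp.arange(12.0).reshape(4, 3)  # a batch of four vectors of length m = 3

out = matrix_vector_product(A, V)   # A applied to each row of V, shape (4, 2)

# Index notation: out[k, n] = sum over m of A[n, m] * V[k, m].
assert jnp.allclose(out, jnp.einsum('nm,km->kn', A, V))
# The same thing as a plain matrix product with the transpose of A.
assert jnp.allclose(out, V @ A.T)
```

Here V's leading dimension of size 4 is a loop dimension, whereas jnp.matmul(A, V) tries to contract A's last axis (size 3) with V's first axis (size 4), hence the shape error.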