grgrjax
This library provides additional tools for JAX, such as a multivariate Newton method, val_and_jacfwd, and vectorized JVP and VJP functions.
The module documentation can be found on Readthedocs.
Installation with pip
It’s as simple as typing
pip install grgrjax
in the terminal (Unix/Mac) or the Anaconda Prompt (Win).
Functions
- grgrjax.newton_jax(func, init, maxit=30, tol=1e-08, relaxation=1, rtol=None, solver=None, verbose=True, verbose_jac=False)
Newton method for root finding of func using automatic differentiation with jax. The argument func must be jittable with jax. newton_jax itself is not jittable; for a jittable version, use newton_jax_jit.
- Parameters:
func (callable) – Function f for which f(x)=0 should be found. It is assumed to return a pair (value, jacobian) or a triple (value, jacobian, aux). If it does not, val_and_jacfwd will be applied to the function, in which case the function must be jittable with jax.
init (array) – Initial values of x
maxit (int, optional) – Maximum number of iterations, defaults to 30
tol (float, optional) – Required tolerance, defaults to 1e-8
relaxation (float, optional) – Relaxation factor applied to each Newton iteration, defaults to 1
solver (callable, optional) – Custom linear solver solver(J, f) that solves J @ x = f for x, defaults to jax.numpy.linalg.solve
verbose (bool, optional) – Whether to display messages, defaults to True
verbose_jac (bool, optional) – Whether to report additional information on the determinant of the Jacobian (computationally more costly), defaults to False
- Returns:
res – A dictionary of results similar to the output from scipy.optimize.root
- Return type:
dict
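The iteration behind newton_jax can be illustrated with a short sketch in plain JAX. This is not grgrjax's implementation: the function newton_sketch, the helper residual, the result keys and the convergence test on max|f(x)| are assumptions made for illustration. The sketch obtains the Jacobian with jax.jacfwd and, as described above, solves J @ step = f(x) with jax.numpy.linalg.solve unless a custom solver is passed.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # double precision, so tol=1e-8 is attainable


def newton_sketch(func, init, maxit=30, tol=1e-8, relaxation=1.0, solver=None):
    """Hypothetical stand-in for newton_jax (plain Python loop, not jittable).

    `func` returns only the residual f(x); the Jacobian is computed here with
    jax.jacfwd, whereas newton_jax would apply val_and_jacfwd.
    """
    solver = jnp.linalg.solve if solver is None else solver
    f_fn = jax.jit(func)
    jac_fn = jax.jit(jax.jacfwd(func))
    x = jnp.asarray(init, dtype=jnp.float64)
    f, niter = f_fn(x), 0
    while jnp.max(jnp.abs(f)) >= tol and niter < maxit:
        step = solver(jac_fn(x), f)  # solve J @ step = f
        x = x - relaxation * step    # (relaxed) Newton update
        f, niter = f_fn(x), niter + 1
    success = bool(jnp.max(jnp.abs(f)) < tol)
    # result keys loosely modeled on scipy.optimize.root; grgrjax's keys may differ
    return {"x": x, "fun": f, "success": success, "niter": niter}


def residual(x):
    # roots lie on the circle of radius 2 and on the line x0 == x1
    return jnp.array([x[0] ** 2 + x[1] ** 2 - 4.0, x[0] - x[1]])


res = newton_sketch(residual, [1.0, 0.5])
print(res["x"], res["success"], res["niter"])
```

Any callable with the same (J, f) signature could be passed as solver, for instance one built on jax.numpy.linalg.lstsq when the Jacobian is close to singular.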
- grgrjax.newton_jax_jit(func, init, maxit=30, tol=1e-08, relaxation=1, verbose=True)
Newton method for root finding of func using automatic differentiation with jax. Unlike newton_jax, it is itself jitted and can also be called from within jitted jax code.
- Parameters:
func (callable) – Function returning (y, jac), where y = f(x) is the value for which f(x) = 0 should be found and jac is the Jacobian. Must be jittable with jax; it could, e.g., be the output of val_and_jacfwd. The function must itself be a valid jax type (i.e. use jax.tree_util.Partial instead of lambda where applicable).
init (array) – Initial values of x
maxit (int, optional) – Maximum number of iterations, defaults to 30
tol (float, optional) – Required tolerance, defaults to 1e-8
relaxation (float, optional) – Relaxation factor applied to each Newton iteration, defaults to 1
verbose (bool, optional) – Whether to display messages, defaults to True
- Returns:
xopt (array) – Solution value x for f
(fopt, jacopt) (tuple of arrays) – Value (fopt) and Jacobian (jacopt) of func at xopt
niter (int) – Number of iterations
success (bool) – Whether the convergence criterion was reached
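A jittable variant has to express the loop with jax.lax.while_loop, and the function it iterates on has to be passed in a form that jax.jit accepts as an argument. The sketch below, with the hypothetical names newton_jit_sketch, residual and val_and_jac, mirrors the documented inputs and outputs of newton_jax_jit under those assumptions; it is not the library's code.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial

jax.config.update("jax_enable_x64", True)


@jax.jit
def newton_jit_sketch(func, init, maxit=30, tol=1e-8, relaxation=1.0):
    """Hypothetical jittable Newton loop in the spirit of newton_jax_jit.

    `func` must return (value, jacobian) and be a valid JAX argument,
    e.g. a jax.tree_util.Partial instead of a bare function or lambda.
    """

    def cond(carry):
        _, (f, _), it = carry
        return (jnp.max(jnp.abs(f)) >= tol) & (it < maxit)

    def body(carry):
        x, (f, jac), it = carry
        x = x - relaxation * jnp.linalg.solve(jac, f)  # relaxed Newton step
        return x, func(x), it + 1

    x0 = jnp.asarray(init, dtype=jnp.float64)
    carry = (x0, func(x0), jnp.int32(0))
    xopt, (fopt, jacopt), niter = jax.lax.while_loop(cond, body, carry)
    success = jnp.max(jnp.abs(fopt)) < tol
    return xopt, (fopt, jacopt), niter, success


def residual(x):
    return jnp.array([x[0] ** 2 + x[1] ** 2 - 4.0, x[0] - x[1]])


def val_and_jac(x):
    # returns (value, Jacobian), the pair newton_jax_jit expects from func
    return residual(x), jax.jacfwd(residual)(x)


func = Partial(val_and_jac)  # a pytree, so it can be passed into jax.jit
xopt, (fopt, jacopt), niter, success = newton_jit_sketch(func, jnp.array([1.0, 0.5]))
```

Passing val_and_jac directly instead of Partial(val_and_jac) makes jax.jit raise a TypeError, because a plain Python function is not a valid JAX type; jax.tree_util.Partial is a pytree that treats the wrapped function as static data.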
- grgrjax.amax(x, return_arg=False)
Return the maximum absolute value.
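What amax computes can be sketched in a few lines of jax.numpy. The name amax_sketch is hypothetical, and the meaning of return_arg used here (also returning the flat index of the largest entry) is an assumption, since it is not documented above.

```python
import jax.numpy as jnp


def amax_sketch(x, return_arg=False):
    """Hypothetical re-implementation of amax: the largest |x_i|.

    With return_arg=True it also returns the flat index of that entry
    (an assumption about what `return_arg` means).
    """
    absx = jnp.abs(jnp.asarray(x))
    if return_arg:
        idx = jnp.argmax(absx)  # jnp.argmax flattens the array by default
        return absx.ravel()[idx], idx
    return jnp.max(absx)


print(amax_sketch([1.0, -3.0, 2.0]))
```

Taking the absolute value first means a large negative entry counts as the maximum, which is what makes such a function useful as an error norm in iterative solvers.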
- grgrjax.callback_func(cnt, err, *args, fev=None, ltime=None, verbose=True)
Print a formatted online update for an iterative process.
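A progress formatter of the kind callback_func describes might look like the following sketch. The semantics assumed here (cnt as an iteration counter, err as the current error, fev as a count of function evaluations, ltime as the time stamp of the previous update) are guesses from the parameter names; the extra positional *args are left out, and progress_line is a hypothetical name.

```python
import time


def progress_line(cnt, err, fev=None, ltime=None, verbose=True):
    """Hypothetical formatter in the spirit of callback_func.

    cnt: iteration counter, err: current error measure,
    fev: optional number of function evaluations,
    ltime: optional time stamp (from time.time()) of the previous update.
    """
    parts = [f"{cnt:4d}", f"error: {err:1.4e}"]
    if fev is not None:
        parts.append(f"fevals: {fev}")
    if ltime is not None:
        parts.append(f"dt: {time.time() - ltime:1.3f}s")  # seconds since last update
    line = " | ".join(parts)
    if verbose:
        print(line)
    return line


start = time.time()
progress_line(1, 0.5, fev=4, ltime=start)
```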
- grgrjax.jax_print(w)
Print from within jax-compiled functions. Wrapper around jax.experimental.host_callback.id_print.
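jax.experimental.host_callback has been deprecated in recent JAX releases; jax.debug.print is the built-in way to print traced values from inside jitted code. The sketch below uses it directly and does not show how jax_print itself is implemented; scaled_sum is a hypothetical example function.

```python
import jax
import jax.numpy as jnp


@jax.jit
def scaled_sum(x):
    s = jnp.sum(x)
    # a plain print() would only show a tracer at trace time;
    # jax.debug.print emits the actual value each time the compiled function runs
    jax.debug.print("running sum: {s}", s=s)
    return 2.0 * s


scaled_sum(jnp.arange(3.0))
```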
- grgrjax.jvp_vmap(fun: Callable, argnums=0)
Vectorized (forward-mode) Jacobian-vector product of fun. This is largely adopted from the implementation of jacfwd in jax._src.api.
- Parameters:
fun – Function whose value and Jacobian are to be computed.
argnums – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
- Returns:
A function with the same arguments as fun that evaluates the value and Jacobian of fun using forward-mode automatic differentiation.
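The idea behind jvp_vmap, batching jax.jvp over several tangent vectors with jax.vmap as jacfwd does internally, can be sketched as follows. The calling convention wrapped(x, V), with tangents stacked along the first axis of V, is an assumption for illustration and not necessarily the one grgrjax uses; jvp_vmap_sketch and f are hypothetical names.

```python
import jax
import jax.numpy as jnp


def jvp_vmap_sketch(fun):
    """Hypothetical: value of fun at x plus J @ v for every tangent v in V.

    Tangents are stacked along the leading axis of V; results are stacked as
    columns (last axis), the same layout jax.jacfwd uses for the Jacobian.
    """

    def wrapped(x, V):
        push = lambda v: jax.jvp(fun, (x,), (v,))
        # the primal output does not depend on v, so it is not batched (None)
        return jax.vmap(push, out_axes=(None, -1))(V)

    return wrapped


def f(x):
    return jnp.array([x[0] * x[1], jnp.sin(x[0])])


x = jnp.array([1.0, 2.0])
y, JV = jvp_vmap_sketch(f)(x, jnp.eye(2))  # unit tangents recover the full Jacobian
```

Using the unit vectors as tangents reproduces the Jacobian column by column; any other set of tangents gives the corresponding directional derivatives at the same cost per vector.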
- grgrjax.vjp_vmap(fun: Callable, argnums=0)
Vectorized (reverse-mode) vector-Jacobian product of fun. This is largely adopted from the implementation of jacrev in jax._src.api.
- Parameters:
fun – Function whose value and Jacobian are to be computed.
argnums – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
- Returns:
A function with the same arguments as fun that evaluates the value and Jacobian of fun using reverse-mode automatic differentiation.
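The reverse-mode counterpart can be sketched in the same spirit: jax.vjp runs the forward pass once and returns a pullback, which jax.vmap then applies to a whole batch of cotangent vectors. As above, the calling convention wrapped(x, U) and the names vjp_vmap_sketch and f are assumptions for illustration.

```python
import jax
import jax.numpy as jnp


def vjp_vmap_sketch(fun):
    """Hypothetical: value of fun at x plus u @ J for every cotangent u in U."""

    def wrapped(x, U):
        y, pullback = jax.vjp(fun, x)  # one forward pass, reused for all cotangents
        (UJ,) = jax.vmap(pullback)(U)  # pullback returns a tuple (one entry per primal)
        return y, UJ                   # row i of UJ is U[i] @ J

    return wrapped


def f(x):
    return jnp.array([x[0] * x[1], jnp.sin(x[0])])


x = jnp.array([1.0, 2.0])
y, UJ = vjp_vmap_sketch(f)(x, jnp.eye(2))  # unit cotangents recover the full Jacobian
```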
- grgrjax.val_and_jacfwd(fun: Callable, argnums=0, has_aux: bool = False, holomorphic: bool = False) -> Callable
Value and Jacobian of fun evaluated column-by-column using forward-mode AD. Apart from returning the function value, this is adopted one-to-one from jax._src.api.
- Parameters:
fun – Function whose value and Jacobian are to be computed.
argnums – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
has_aux – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
holomorphic – Optional, bool. Indicates whether fun is promised to be holomorphic. Default False.
- Returns:
A function with the same arguments as fun that evaluates the value and Jacobian of fun using forward-mode automatic differentiation. If has_aux is True, a pair of (jacobian, auxiliary_data) is returned.
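A value-and-Jacobian function in forward mode can also be sketched with jax.linearize, which returns the function value together with a linear JVP function that jax.vmap can push through the unit vectors. Note that this swaps in jax.linearize rather than the jacfwd internals that val_and_jacfwd is adopted from; val_and_jacfwd_sketch and f are hypothetical names, and the sketch assumes a one-dimensional input array and ignores argnums, has_aux and holomorphic.

```python
import jax
import jax.numpy as jnp


def val_and_jacfwd_sketch(fun):
    """Hypothetical forward-mode (value, Jacobian) for a 1-D input array."""

    def wrapped(x):
        x = jnp.asarray(x)
        y, f_jvp = jax.linearize(fun, x)             # value plus linear JVP function
        basis = jnp.eye(x.size, dtype=x.dtype)       # unit tangent vectors e_j
        jac = jax.vmap(f_jvp, out_axes=-1)(basis)    # column j is J @ e_j
        return y, jac

    return wrapped


def f(x):
    return jnp.array([x[0] * x[1], jnp.sin(x[0]), x[1] ** 2])


y, J = val_and_jacfwd_sketch(f)(jnp.array([1.0, 2.0]))  # J has shape (3, 2)
```

The value comes at no extra cost because the forward pass that produces the tangents also produces the primal output; this is the reason for returning both from one call.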
- grgrjax.val_and_jacrev(fun: Callable, argnums=0, has_aux: bool = False, holomorphic: bool = False, allow_int: bool = False) -> Callable
Value and Jacobian of fun evaluated row-by-row using reverse-mode AD. Apart from returning the function value, this is adopted one-to-one from jax._src.api.
- Parameters:
fun – Function whose value and Jacobian are to be computed.
argnums – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
has_aux – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
holomorphic – Optional, bool. Indicates whether fun is promised to be holomorphic. Default False.
allow_int – Optional, bool. Whether to allow differentiating with respect to integer valued inputs. The gradient of an integer input will have a trivial vector-space dtype (float0). Default False.
- Returns:
A function with the same arguments as fun that evaluates the value and Jacobian of fun using reverse-mode automatic differentiation. If has_aux is True, a pair of (jacobian, auxiliary_data) is returned.
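A reverse-mode counterpart can be sketched with jax.vjp: one forward pass yields the value and a pullback, which jax.vmap applies to each unit cotangent to build the Jacobian row by row. The name val_and_jacrev_sketch, the restriction to one-dimensional inputs and outputs, and the (value, jacobian, aux) return order when has_aux=True are assumptions of this sketch, not the documented behavior of val_and_jacrev.

```python
import jax
import jax.numpy as jnp


def val_and_jacrev_sketch(fun, has_aux=False):
    """Hypothetical reverse-mode (value, Jacobian[, aux]) for 1-D inputs/outputs."""

    def wrapped(x):
        if has_aux:
            # fun returns (output, aux); only the output is differentiated
            y, pullback, aux = jax.vjp(fun, x, has_aux=True)
        else:
            y, pullback = jax.vjp(fun, x)
        basis = jnp.eye(y.size, dtype=y.dtype)  # unit cotangent vectors e_i
        (jac,) = jax.vmap(pullback)(basis)      # row i is e_i @ J
        return (y, jac, aux) if has_aux else (y, jac)

    return wrapped


def f(x):
    return jnp.array([x[0] * x[1], jnp.sin(x[0])])


def f_with_aux(x):
    return f(x), jnp.sum(x)  # auxiliary data is passed through untouched


x = jnp.array([1.0, 2.0])
y, J = val_and_jacrev_sketch(f)(x)
y2, J2, aux = val_and_jacrev_sketch(f_with_aux, has_aux=True)(x)
```

Reverse mode needs one pullback per output component, so it tends to pay off when fun has fewer outputs than inputs; forward mode is the cheaper choice in the opposite case.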