grgrjax#

This library provides additional tools for JAX, such as a multivariate Newton method, val_and_jacfwd, and vectorized JVP and VJP functions.

The module documentation can be found on Read the Docs.

Installation with pip#

It’s as simple as typing

pip install grgrjax

in the terminal (Unix/Mac) or the Anaconda Prompt (Win).

Functions#

grgrjax.newton_jax(func, init, maxit=30, tol=1e-08, rtol=None, solver=None, verbose=True, verbose_jac=False)#

Newton method for root finding of func using automatic differentiation with jax. The argument func must be jittable with jax. newton_jax itself is not jittable; use newton_jax_jit if a jittable solver is needed.

Parameters:
  • func (callable) – Function f for which f(x)=0 should be found. Is assumed to return a pair (value, jacobian) or (value, jacobian, aux). If not, val_and_jacfwd will be applied to the function, in which case the function must be jittable with jax.

  • init (array) – Initial values of x

  • maxit (int, optional) – Maximum number of iterations

  • tol (float, optional) – Required tolerance. Defaults to 1e-8

  • solver (callable, optional) – Custom solver solver(J, f) for the linear system J@x = f. Defaults to jax.numpy.linalg.solve

  • verbose (bool, optional) – Whether to display messages

  • verbose_jac (bool, optional) – Whether to supply additional information on the determinant of the jacobian (computationally more costly).

Returns:

res – A dictionary of results similar to the output from scipy.optimize.root

Return type:

dict
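
The iteration described above can be sketched in plain JAX. This is a minimal illustration, not grgrjax's implementation: the helper name newton_sketch and the result keys (x, fun, niter, success) are assumptions loosely modelled on scipy.optimize.root, and the Jacobian is obtained with jax.jacfwd.

```python
import jax
import jax.numpy as jnp


def newton_sketch(func, init, maxit=30, tol=1e-8):
    """Plain Newton iteration (hypothetical helper, not grgrjax's code)."""
    jac_fn = jax.jacfwd(func)
    x = jnp.asarray(init)
    success = False
    niter = 0
    for niter in range(1, maxit + 1):
        f = func(x)
        # Newton step: solve J @ dx = f, then subtract dx from x
        dx = jnp.linalg.solve(jac_fn(x), f)
        x = x - dx
        # stop once the largest absolute residual at the new point is below tol
        if float(jnp.max(jnp.abs(func(x)))) < tol:
            success = True
            break
    # result keys are an assumption, chosen to resemble scipy.optimize.root
    return {"x": x, "fun": func(x), "niter": niter, "success": success}


def f(x):
    # two decoupled equations with roots sqrt(2) and sqrt(3)
    return x**2 - jnp.array([2.0, 3.0])


# JAX defaults to float32, so a looser tolerance than 1e-8 is used here
res = newton_sketch(f, jnp.array([1.0, 1.0]), tol=1e-5)
```

Because the loop uses Python control flow on concrete values, this sketch cannot be wrapped in jax.jit as a whole, which mirrors the distinction between newton_jax and newton_jax_jit.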

grgrjax.newton_jax_jit(func, init, maxit=30, tol=1e-08, verbose=True)#

Newton method for root finding of func using automatic differentiation with jax. Unlike newton_jax, this variant is itself jittable and can be called from within jitted code.

Parameters:
  • func (callable) – Function returning (y, jac) where f(x)=y=0 should be found and jac is the jacobian. Must be jittable with jax. Could e.g. be the output of val_and_jacfwd.

  • init (array) – Initial values of x

  • maxit (int, optional) – Maximum number of iterations

  • tol (float, optional) – Required tolerance. Defaults to 1e-8

Returns:

  • xopt (array) – Solution value x for f

  • (fopt, jacopt) (tuple of arrays) – Value (fopt) and Jacobian (jacopt) of func at xopt

  • niter (int) – Number of iterations

  • success (bool) – Whether the convergence criterion was reached
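
A jittable Newton loop can be written with jax.lax.while_loop. The following is a sketch under assumptions, not the library's code: the name newton_jit_sketch is hypothetical, the stopping rule (largest absolute residual below tol) is assumed, and the return layout simply follows the list above.

```python
import jax
import jax.numpy as jnp
from jax import lax


def newton_jit_sketch(func, init, maxit=30, tol=1e-5):
    """Newton loop built on lax.while_loop (hypothetical helper)."""

    def cond(state):
        x, f, jac, niter = state
        # continue while the residual is too large and iterations remain
        return jnp.logical_and(jnp.max(jnp.abs(f)) >= tol, niter < maxit)

    def body(state):
        x, f, jac, niter = state
        # Newton step using the Jacobian stored in the loop state
        x = x - jnp.linalg.solve(jac, f)
        f, jac = func(x)
        return x, f, jac, niter + 1

    f0, jac0 = func(init)
    x, f, jac, niter = lax.while_loop(cond, body, (init, f0, jac0, jnp.int32(0)))
    return x, (f, jac), niter, jnp.max(jnp.abs(f)) < tol


def f(x):
    return x**2 - jnp.array([2.0, 3.0])


def f_and_jac(x):
    # func must return the pair (value, jacobian)
    return f(x), jax.jacfwd(f)(x)


solve = jax.jit(lambda x0: newton_jit_sketch(f_and_jac, x0))
xopt, (fopt, jacopt), niter, success = solve(jnp.array([1.0, 1.0]))
```

The loop state carries the value and Jacobian along with x, so func is evaluated once per iteration and the final (fopt, jacopt) pair comes for free.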

grgrjax.amax(x, return_arg=False)#

Return the maximum absolute value.
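
As a rough illustration, the maximum absolute value can be computed as follows. The helper amax_sketch is hypothetical, and the behaviour of return_arg (additionally returning the flat index of the maximum, in this order) is an assumption that the description above does not confirm.

```python
import jax.numpy as jnp


def amax_sketch(x, return_arg=False):
    """Largest absolute entry of x (hypothetical stand-in)."""
    absx = jnp.abs(jnp.asarray(x))
    if return_arg:
        # assumption: also return the flat index where the maximum occurs
        return absx.max(), absx.argmax()
    return absx.max()


x = jnp.array([1.0, -3.0, 2.0])
val = amax_sketch(x)
val2, idx = amax_sketch(x, return_arg=True)
```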

grgrjax.callback_func(cnt, err, *args, fev=None, ltime=None, verbose=True)#

Print a formatted on-line update for an iterative process.

grgrjax.jax_print(w)#

Print in jax compiled functions. Wrapper around jax.experimental.host_callback.id_print.
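
For comparison, printing from compiled code can also be done with jax.debug.print, which is part of JAX's public API. The sketch below uses that function directly rather than jax_print; the function step is purely illustrative.

```python
import jax
import jax.numpy as jnp


@jax.jit
def step(x):
    y = x * 2.0
    # prints the runtime value each time the compiled function executes
    jax.debug.print("y = {y}", y=y)
    return y + 1.0


out = step(jnp.array(1.5))
```

A plain Python print inside a jitted function would only run once, at trace time, and would show a tracer instead of the value.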

grgrjax.jvp_vmap(fun: Callable, argnums=0)#

Vectorized (forward-mode) jacobian-vector product of fun. This is by and large adapted from the implementation of jacfwd in jax._src.api.

Parameters:
  • fun – Function whose value and Jacobian are to be computed.

  • argnums – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).

Returns:

A function with the same arguments as fun, that evaluates the value and Jacobian of fun using forward-mode automatic differentiation.
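
The idea of a vectorized JVP can be sketched with the public functions jax.jvp and jax.vmap. The helper batched_jvp is hypothetical and its calling convention is an assumption; it is not the signature of jvp_vmap.

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.array([x[0] * x[1], jnp.sin(x[0])])


def batched_jvp(fun, x, tangents):
    # push every row of `tangents` through the linearization of fun at x
    push = lambda v: jax.jvp(fun, (x,), (v,))[1]
    return jax.vmap(push)(tangents)


x = jnp.array([1.0, 2.0])
# three tangent directions, one per row
tangents = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
out = batched_jvp(f, x, tangents)  # row i holds J @ tangents[i]
```

Passing the standard basis as tangents recovers the Jacobian column by column, which is how forward-mode Jacobians are assembled.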

grgrjax.vjp_vmap(fun: Callable, argnums=0)#

Vectorized (reverse-mode) vector-jacobian product of fun. This is by and large adapted from the implementation of jacrev in jax._src.api.

Parameters:
  • fun – Function whose value and Jacobian are to be computed.

  • argnums – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).

Returns:

A function with the same arguments as fun, that evaluates the value and Jacobian of fun using reverse-mode automatic differentiation.
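
The reverse-mode counterpart can be sketched with jax.vjp and jax.vmap. Again, batched_vjp is a hypothetical helper with an assumed calling convention, not the signature of vjp_vmap.

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.array([x[0] * x[1], jnp.sin(x[0]), x[1] ** 2])


def batched_vjp(fun, x, cotangents):
    # one forward pass yields the value and a reusable pullback
    y, pullback = jax.vjp(fun, x)
    # pullback returns a tuple with one entry per primal argument
    rows = jax.vmap(lambda u: pullback(u)[0])(cotangents)
    return y, rows


x = jnp.array([1.0, 2.0])
# two cotangent vectors, one per row
cotangents = jnp.array([[1.0, 0.0, 0.0], [1.0, 1.0, 1.0]])
y, rows = batched_vjp(f, x, cotangents)  # row i holds cotangents[i] @ J
```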

grgrjax.val_and_jacfwd(fun: Callable, argnums=0, has_aux: bool = False, holomorphic: bool = False) → Callable#

Value and Jacobian of fun evaluated column-by-column using forward-mode AD. Apart from returning the function value, this is one-to-one adopted from jax._src.api.

Parameters:
  • fun – Function whose value and Jacobian are to be computed.

  • argnums – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).

  • has_aux – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.

  • holomorphic – Optional, bool. Indicates whether fun is promised to be holomorphic. Default False.

  • allow_int – Optional, bool. Whether to allow differentiating with respect to integer valued inputs. The gradient of an integer input will have a trivial vector-space dtype (float0). Default False.

Returns:

A function with the same arguments as fun, that evaluates the value and Jacobian of fun using forward-mode automatic differentiation. If has_aux is True then a pair of (jacobian, auxiliary_data) is returned.
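
Getting the value and the forward-mode Jacobian in one pass can be sketched as below. The helper value_and_jacfwd_sketch is hypothetical and only handles a single 1-D array argument. The point is that jax.jvp already returns the primal output, so no separate call to fun is needed.

```python
import jax
import jax.numpy as jnp


def value_and_jacfwd_sketch(fun):
    """Value and forward-mode Jacobian for a 1-D input (hypothetical helper)."""

    def wrapped(x):
        push = lambda v: jax.jvp(fun, (x,), (v,))
        basis = jnp.eye(x.size, dtype=x.dtype)
        # the primal output does not depend on the tangent, so it is left
        # unbatched (out_axes=None); tangent outputs become Jacobian columns
        y, jac = jax.vmap(push, out_axes=(None, -1))(basis)
        return y, jac

    return wrapped


def f(x):
    return jnp.array([x[0] * x[1], jnp.sin(x[0])])


x = jnp.array([1.0, 2.0])
y, jac = value_and_jacfwd_sketch(f)(x)
```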

grgrjax.val_and_jacrev(fun: Callable, argnums=0, has_aux: bool = False, holomorphic: bool = False, allow_int: bool = False) → Callable#

Value and Jacobian of fun evaluated row-by-row using reverse-mode AD. Apart from returning the function value, this is one-to-one adopted from jax._src.api.

Parameters:
  • fun – Function whose value and Jacobian are to be computed.

  • argnums – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).

  • has_aux – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.

  • holomorphic – Optional, bool. Indicates whether fun is promised to be holomorphic. Default False.

  • allow_int – Optional, bool. Whether to allow differentiating with respect to integer valued inputs. The gradient of an integer input will have a trivial vector-space dtype (float0). Default False.

Returns:

A function with the same arguments as fun, that evaluates the value and Jacobian of fun using reverse-mode automatic differentiation. If has_aux is True then a pair of (jacobian, auxiliary_data) is returned.
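
A reverse-mode analogue, including auxiliary output, can be sketched with jax.vjp and its has_aux flag. The helper name and the return order (value, jacobian, aux) are assumptions made for illustration; they are not taken from the library's source.

```python
import jax
import jax.numpy as jnp


def value_and_jacrev_aux_sketch(fun):
    """Value, reverse-mode Jacobian and aux data (hypothetical helper)."""

    def wrapped(x):
        # with has_aux=True, jax.vjp returns (output, pullback, aux)
        y, pullback, aux = jax.vjp(fun, x, has_aux=True)
        basis = jnp.eye(y.size, dtype=y.dtype)
        # each cotangent basis vector yields one row of the Jacobian
        (jac,) = jax.vmap(pullback)(basis)
        return y, jac, aux

    return wrapped


def f_with_aux(x):
    y = jnp.array([x[0] * x[1], jnp.sin(x[0]), x[1] ** 2])
    # the second element is auxiliary data and is not differentiated
    return y, {"norm": jnp.linalg.norm(x)}


x = jnp.array([1.0, 2.0])
y, jac, aux = value_and_jacrev_aux_sketch(f_with_aux)(x)
```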