
Backends and Acceleration

FunFact delegates numerical linear algebra operations, automatic differentiation, and hardware acceleration to external linear algebra packages that conform to the NumPy API. Currently supported backends include JAX, PyTorch, and NumPy.

Note

The NumPy backend only supports forward evaluation.

The list of available backends can be queried by:

import funfact as ff
ff.available_backends
['jax', 'torch', 'numpy']

The backend can be selected with the use function. The backend that is currently in use can be retrieved through active_backend:

from funfact import use, active_backend as ab
use('numpy')
ab
<backend 'NumpyBackend'>

from funfact import use, active_backend as ab
use('jax')
ab
<backend 'JAXBackend'>

from funfact import use, active_backend as ab
use('torch')
ab
<backend 'PyTorchBackend'>

The active backend can be imported and used as if it were the underlying NLA package:

from funfact import active_backend as ab
ab.eye(...)
ab.zeros(...)
# uses any method already defined by the underlying package (np, jnp, torch)
ab.<method>(...)
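
For instance, with the NumPy backend active, the familiar constructors behave exactly as their NumPy counterparts. A minimal sketch; the concrete array type returned depends on which backend is active:

import funfact as ff
from funfact import active_backend as ab

ff.use('numpy')
I = ab.eye(3)             # 3x3 identity, a native array of the active backend
Z = ab.zeros((2, 4))      # 2x4 array of zeros
print(type(I), Z.shape)   # <class 'numpy.ndarray'> (2, 4)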

Besides this, a FunFact backend implements a few additional methods such as:

ab.tensor(...)          # create native tensor from array-like data
ab.to_numpy(...)        # convert native tensor to NumPy array
ab.normal(...)          # normally distributed random data
ab.uniform(...)         # uniformly distributed random data
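
A short round trip illustrates the first two. A sketch; the exact signatures of normal and uniform are not spelled out here, so they are omitted and should be checked against the API reference:

import funfact as ff
from funfact import active_backend as ab

ff.use('torch')
t = ab.tensor([[1.0, 2.0], [3.0, 4.0]])   # native tensor, a torch.Tensor here
a = ab.to_numpy(t)                        # back to a plain NumPy ndarray
print(type(t), type(a))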

Switching backends

Backends may be switched dynamically during the lifetime of a process:

ff.use('numpy')
a = ff.tensor(3, 2)
b = ff.tensor(2, 3)
tsrex = a @ b           # tensor expression with NumPy backend
...
ff.use('jax')
c = ff.tensor(3, 4)
d = ff.tensor(4, 6)
tsrex = c @ d           # tensor expression with JAX backend

Warning

Dynamically switching backends will not automatically port the data in an existing tensor expression or model to the new backend.
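
A simple way to carry data across a switch is to keep the original array-like data around and rebuild the tensors under the new backend. A sketch; it assumes ff.tensor also accepts array-like initial data, which should be verified against the tensor API:

import numpy as np
import funfact as ff

data = np.random.rand(3, 2)

ff.use('numpy')
a = ff.tensor(data)      # assumption: ff.tensor accepts array-like data

ff.use('jax')
a = ff.tensor(data)      # rebuilt from the same data under the JAX backend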

Note

Some properties of a backend can only be set when that backend is loaded for the first time in a process. For example, with the JAX backend, the enable_x64 flag:

ff.use('jax', enable_x64=True)

can only be set once; running use a second time will not change this behavior.
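
One way to verify whether the flag took effect is to inspect the dtype of a freshly created array. A minimal sketch; float64 only appears if x64 was enabled when the JAX backend was first loaded in the process:

import funfact as ff
from funfact import active_backend as ab

ff.use('jax', enable_x64=True)
print(ab.eye(2).dtype)   # float64 if x64 took effect, float32 otherwise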

Hardware acceleration

The JAX and PyTorch backends support hardware acceleration on the GPU.

Warning

TODO: complete this section once this functionality is implemented in use.
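
In the meantime, device availability can be queried directly through the underlying packages; the calls below belong to JAX and PyTorch themselves, not to FunFact:

import jax
print(jax.devices())              # e.g. [CudaDevice(id=0)] on a GPU machine

import torch
print(torch.cuda.is_available())  # True if a CUDA device is visible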
