Backends and Acceleration
FunFact delegates numerical linear algebra operations, automatic differentiation, and hardware acceleration to external linear algebra packages that conform to the NumPy API. The currently supported backends are JAX, PyTorch, and NumPy.
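Delegating to packages that "conform to the NumPy API" means numerical code can be written once against a generic array namespace and run on whichever package is plugged in. The sketch below illustrates that idea only; gram_matrix is a hypothetical helper, not part of FunFact, and it is exercised here with NumPy alone.

```python
import numpy as np


def gram_matrix(xp, a):
    """Compute a^T a using only functions from the array namespace `xp`.

    `xp` can be any module exposing NumPy-style matmul/transpose,
    e.g. numpy or jax.numpy; the body never names a specific package.
    """
    return xp.matmul(xp.transpose(a), a)


x = np.arange(6.0).reshape(3, 2)  # [[0, 1], [2, 3], [4, 5]]
g = gram_matrix(np, x)
print(g)  # [[20. 26.]
          #  [26. 35.]]
```

With JAX installed, gram_matrix(jax.numpy, x) would run the same body through JAX instead. Note that not every package mirrors NumPy exactly (for instance, torch.transpose requires explicit dimension arguments), which is one reason a backend layer is useful.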
Note
The NumPy backend only supports forward evaluation.
The list of available backends can be queried by:
The backend can be selected with the use method, e.g. ff.use('jax').
The backend that is currently in use can be retrieved through funfact.active_backend.
The active backend can be imported and used as if it were the underlying numerical linear algebra (NLA) package:
from funfact import active_backend as ab

ab.eye(3)         # 3x3 identity matrix
ab.zeros((2, 3))  # 2x3 array of zeros
# any other function defined by the underlying package (np, jnp, torch)
# can be called the same way, e.g. ab.ones((2, 3))
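One common way to make an object behave "as if it were" another module is to forward unknown attribute lookups to it. The following is a minimal sketch assuming a __getattr__-based design; ActiveBackendProxy is a hypothetical name and this is not FunFact's actual implementation.

```python
import numpy as np


class ActiveBackendProxy:
    """Hypothetical stand-in for an 'active backend' object.

    Attributes the proxy does not define itself are forwarded to the
    wrapped NumPy-compatible module.
    """

    def __init__(self, module):
        self._module = module

    def __getattr__(self, name):
        # Python calls __getattr__ only when normal lookup fails, i.e. for
        # names such as 'eye' or 'zeros' that live on the wrapped package.
        return getattr(self._module, name)


ab = ActiveBackendProxy(np)
identity = ab.eye(3)       # forwarded to np.eye
zeros = ab.zeros((2, 3))   # forwarded to np.zeros
```

Because the forwarding is dynamic, swapping the wrapped module changes what every ab.<name> call resolves to without touching the calling code.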
Besides forwarding calls to the underlying package, a FunFact backend implements a few additional methods, such as:
ab.tensor(...) # create native tensor from array-like data
ab.to_numpy(...) # convert native tensor to NumPy array
ab.normal(...) # normally distributed random data
ab.uniform(...) # uniformly distributed random data
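To make the purpose of these helpers concrete, here is a NumPy-only sketch of what they might do. The class name NumpyBackendExtras and all parameter lists are assumptions for illustration; FunFact's real signatures are not documented on this page.

```python
import numpy as np


class NumpyBackendExtras:
    """Hypothetical NumPy versions of the extra backend helpers.

    Method names mirror the list above; signatures are assumptions.
    """

    def __init__(self, seed=None):
        self._rng = np.random.default_rng(seed)

    def tensor(self, data, dtype=None):
        # Create a native (here: NumPy) array from array-like data.
        return np.asarray(data, dtype=dtype)

    def to_numpy(self, x):
        # For a NumPy backend the native tensor already is a NumPy array.
        return np.asarray(x)

    def normal(self, mean, std, shape):
        # Normally distributed random data.
        return self._rng.normal(mean, std, size=shape)

    def uniform(self, low, high, shape):
        # Uniformly distributed random data in [low, high).
        return self._rng.uniform(low, high, size=shape)


ext = NumpyBackendExtras(seed=0)
t = ext.tensor([[1, 2], [3, 4]], dtype=float)
n = ext.normal(0.0, 1.0, (2, 3))
u = ext.uniform(0.0, 1.0, (4,))
```

In a JAX- or PyTorch-based variant, tensor and to_numpy are where conversion between the native tensor type and NumPy arrays would take place.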
Switching backends
Backends may be switched dynamically during the lifetime of a process:
import funfact as ff

ff.use('numpy')
a = ff.tensor(3, 2)
b = ff.tensor(2, 3)
tsrex = a @ b  # tensor expression with NumPy backend
# ...
ff.use('jax')
c = ff.tensor(3, 4)
d = ff.tensor(5, 6)
tsrex = c & d  # tensor expression with JAX backend
Warning
Dynamic switching of the backends will not automatically port the data in an existing tensor expression/model to the new backend.
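The warning can be illustrated with a toy backend registry: switching only changes which package future calls go to, while objects created earlier keep their original type. Everything below is hypothetical (use, active, and the "lists" backend are invented for this sketch, with "lists" standing in for a second package such as JAX).

```python
import types

import numpy as np

# Hypothetical registry mapping backend names to array namespaces.
# "lists" is a toy backend whose zeros() returns nested Python lists.
_BACKENDS = {
    "numpy": np,
    "lists": types.SimpleNamespace(
        zeros=lambda shape: [[0.0] * shape[1] for _ in range(shape[0])]
    ),
}
_state = {"active": "numpy"}


def use(name):
    """Switch the active backend (hypothetical sketch, not FunFact's code)."""
    if name not in _BACKENDS:
        raise ValueError(f"unknown backend {name!r}; choose from {sorted(_BACKENDS)}")
    _state["active"] = name


def active():
    """Return the namespace of the currently active backend."""
    return _BACKENDS[_state["active"]]


use("numpy")
x = active().zeros((2, 2))  # created while NumPy is active
use("lists")
y = active().zeros((2, 2))  # created after switching

# Switching did not convert x: it is still a NumPy array.
print(type(x).__name__, type(y).__name__)  # ndarray list
```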
Note
Some properties of a backend can only be set once, when that backend is loaded for the first time in a process. For example, with the JAX backend, the enable_x64 flag takes effect only on the first load; calling use a second time will not affect this behavior.
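The "set once on first load" behavior resembles a cached loader that records options the first time and ignores them afterwards. This is a hypothetical sketch (load_backend is an invented name and no JAX import happens here); JAX itself exposes 64-bit mode through its jax_enable_x64 configuration option, which its documentation recommends setting at startup.

```python
# Hypothetical cache of loaded backends: options passed to the first
# load are recorded; later loads of the same backend ignore new options.
_loaded = {}


def load_backend(name, **options):
    if name not in _loaded:
        # First load in this process: a real backend would apply its
        # options here (for JAX, e.g., turning on 64-bit mode).
        _loaded[name] = dict(options)
    return _loaded[name]


first = load_backend("jax", enable_x64=True)
second = load_backend("jax", enable_x64=False)  # new option is ignored
print(second["enable_x64"])  # True
```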
Hardware acceleration
The JAX and PyTorch backends support hardware acceleration on GPUs.
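Whether a GPU is actually usable depends on the underlying package installation. The sketch below queries the libraries directly through jax.devices() and torch.cuda.is_available(); gpu_status is a hypothetical helper, and it says nothing about how FunFact itself will expose device selection, which this page does not yet document. Packages that are not installed are simply skipped.

```python
import importlib.util


def gpu_status():
    """Report GPU visibility for each installed backend package (hypothetical helper)."""
    status = {}
    if importlib.util.find_spec("jax") is not None:
        import jax
        # Each JAX device reports a platform string such as "cpu" or "gpu".
        status["jax"] = any(d.platform == "gpu" for d in jax.devices())
    if importlib.util.find_spec("torch") is not None:
        import torch
        status["torch"] = torch.cuda.is_available()
    return status


print(gpu_status())  # e.g. {} if neither package is installed
```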
Warning
TODO: complete this section once this functionality is implemented in use.