
use(backend, **context)

Specify the numerical tensor algebra library to use as the computational backend.

Parameters:

backend (str, required): Currently implemented backends are 'numpy' (NumPy, which supports only forward calculations, not automatic differentiation), 'jax' (JAX), and 'torch' (PyTorch). Dynamic switching between backends is allowed. However, tensors created by the previous backend will not be automatically ported to the new backend.

context (kwargs, default {}): Backend-specific additional arguments. For details, refer to the individual backends.

Examples:

>>> from funfact import use, active_backend as ab
>>> use('numpy')
>>> ab
<backend 'NumpyBackend'>
>>> use('jax')
>>> ab
<backend 'JAXBackend'>
>>> use('torch')
>>> ab
<backend 'PyTorchBackend'>
>>> use('torch', device='cuda:0')
>>> ab
<backend 'PyTorchBackend'>
Source code in funfact/backend/_proxy.py
def use(backend: str, **context):
    '''Specify the numerical tensor algebra library to use as the computational
    backend.

    Args:
        backend (str):
            Currently implemented backends are:

            - `'numpy'`: the [NumPy backend](../backend/_numpy), which supports
            only forward calculations, not automatic differentiation.
            - `'jax'`: [JAX backend](../backend/_jax).
            - `'torch'`: [PyTorch backend](../backend/_torch).

            Dynamic switching between backends is allowed. However, tensors
            created by the previous backend will not be automatically ported to
            the new backend.

        context (kwargs): Backend-specific additional arguments.
            For details, refer to the individual backends.

    Examples:
        >>> from funfact import use, active_backend as ab
        >>> use('numpy')
        >>> ab
        <backend 'NumpyBackend'>

        >>> use('jax')
        >>> ab
        <backend 'JAXBackend'>

        >>> use('torch')
        >>> ab
        <backend 'PyTorchBackend'>

        >>> use('torch', device='cuda:0')
        >>> ab
        <backend 'PyTorchBackend'>
    '''
    global _active_backend
    try:
        with set_context(**context):
            _active_backend = importlib.import_module(
                f'funfact.backend._{backend}'
            )
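    except (ImportError, KeyError) as e:
        # import_module signals a missing backend module with
        # ModuleNotFoundError, a subclass of ImportError.
        raise RuntimeError(f'Backend {backend} cannot be imported.') from e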