Skip to content

Factorization ¶

A factorization model realizes a tensor expression to approximate a target tensor.

Note

Please use one of the from_* class methods to construct a factorization model. The __init__ method is NOT recommended for direct usage.

Parameters:

Name Type Description Default
tsrex TsrEx

A tensor expression.

required
extra_attributes kwargs

Extra attributes to be stored verbatim.

{}

Examples:

>>> import funfact as ff
>>> a = ff.tensor('a', 2, 3)
>>> b = ff.tensor('b', 3, 4)
>>> i, j, k = ff.indices(3)
>>> ff.Factorization.from_tsrex(a[i, j] * b[j, k])
<funfact.model._factorization.Factorization object at 0x7f5838105ee0>
Source code in funfact/model/_factorization.py
class Factorization:
    '''A factorization model realizes a tensor expression to approximate a
    target tensor.

    !!! note
        Please use one of the `from_*` class methods to construct a
        factorization model. The `__init__` method is **NOT** recommended for
        direct usage.

    Args:
        tsrex (TsrEx): A tensor expression.
        extra_attributes (kwargs): extra attributes to be stored verbatim.

    Examples:
        >>> import funfact as ff
        >>> a = ff.tensor('a', 2, 3)
        >>> b = ff.tensor('b', 3, 4)
        >>> i, j, k = ff.indices(3)
        >>> ff.Factorization.from_tsrex(a[i, j] * b[j, k])
        <funfact.model._factorization.Factorization object at 0x7f5838105ee0>
    '''

    # Token that the `from_*` constructors hand to `__init__`; any other
    # value makes `__init__` refuse to build the object.
    _SECRET = '50A-2117'

    def __init__(self, tsrex, _secret=None, **extra_attributes):
        if _secret != self._SECRET:
            raise RuntimeError(
                'Please use one of the `from_*` methods to create a '
                'factorization from a tensor expression'
            )
        self.tsrex = tsrex
        # Extra keyword arguments become plain instance attributes.
        self.__dict__.update(**extra_attributes)

    @classmethod
    def from_tsrex(
        cls, tsrex, dtype=None, vec_size=None, vec_axis=0, initialize=True
    ):
        '''Construct a factorization model from a tensor expression.

        Args:
            tsrex (TsrEx): The tensor expression.
            dtype: numerical data type, defaults to float32.
            vec_size (int):
                Size passed to `vectorize` to create parallel instances of
                the tensor expression. A falsy value (`None` or 0) disables
                vectorization.
            vec_axis (0 or -1): The position of the vectorization dimension.
            initialize (bool):
                Whether or not to fill abstract tensors with actual data.

        Returns:
            Factorization: a model wrapping the processed expression.
        '''
        if vec_size:
            # The vectorization dimension is appended only for
            # `vec_axis == -1`; every other value takes the non-append path.
            tsrex = vectorize(tsrex, vec_size, append=(vec_axis == -1))
        tsrex = tsrex | IndexnessAnalyzer() | TypeDeducer() | EinopCompiler()
        if initialize:
            tsrex = tsrex | LeafInitializer(dtype)
        return cls(tsrex, _secret=cls._SECRET)

    @classmethod
    def _from_jax_flatten(cls, tsrex, factors):
        '''Build a model from a tensor expression and a flat list of factor
        data.

        The expression is only passed through `IndexnessAnalyzer`, after
        which `factors` is assigned through the `factors` setter.

        NOTE(review): judging by the name this is presumably the unflatten
        hook for JAX pytree registration -- confirm against the caller.
        '''
        tsrex = tsrex | IndexnessAnalyzer()
        fac = cls(tsrex, _secret=cls._SECRET)
        fac.factors = factors
        return fac

    @staticmethod
    def _is_optimizable_factor(n):
        '''Tell whether a node is a (parametrized) tensor flagged as
        optimizable.'''
        return (
            n.name in ['tensor', 'parametrized_tensor'] and
            n.decl.optimizable
        )

    def _find_nodes(self, predicate):
        '''Iterate over the unique nodes of the expression that satisfy
        `predicate`, in the order produced by `dfs_filter`.'''
        return unique(dfs_filter(predicate, self.tsrex.root))

    @property
    def factors(self):
        '''A flattened list of optimizable factors in the model.

        Examples:
            >>> import funfact as ff
            >>> a = ff.tensor('a', 2, 3, optimizable=False)
            >>> b = ff.tensor('b', 3, 4)
            >>> i, j, k = ff.indices(3)
            >>> fac = ff.Factorization.from_tsrex(
            ...     a[i, j] * b[j, k],
            ...     initialize=True
            ... )
            >>> fac.factors
            <'data' field of tensor b>
            >>> fac.factors[0]
            DeviceArray([[ 0.5920733 ,  0.17746426, -1.8907379 , -0.10324025],
                         [ 0.05991533,  2.5538554 ,  0.05718338,  0.8887682 ],
                         [ 0.54816544,  2.3392196 ,  1.1973379 ,  0.04005199]],
                          dtype=float32)
        '''
        return self._NodeView(
            'data', list(self._find_nodes(self._is_optimizable_factor))
        )

    @factors.setter
    def factors(self, tensors):
        # `tensors` is indexed in the traversal order used by the getter.
        for i, n in enumerate(self._find_nodes(self._is_optimizable_factor)):
            n.data = tensors[i]

    @property
    def all_factors(self):
        '''A flattened list of all factors in the model.

        Examples:
            >>> import funfact as ff
            >>> a = ff.tensor('a', 2, 3, optimizable=False)
            >>> b = ff.tensor('b', 3, 4)
            >>> i, j, k = ff.indices(3)
            >>> fac = ff.Factorization.from_tsrex(
            ...     a[i, j] * b[j, k],
            ...     initialize=True
            ... )
            >>> fac.all_factors
            <'data' fields of tensors a, b>
            >>> fac.all_factors[0]
            DeviceArray([[[ 0.2509914 ],
                          [-0.5063717 ],
                          [-1.0069973 ]],
                         [[ 1.1088423 ],
                          [ 0.31595513],
                          [-0.11492359]]], dtype=float32)
        '''
        # NOTE(review): only 'tensor' nodes are collected here, whereas
        # `factors` also matches 'parametrized_tensor' -- confirm intended.
        return self._NodeView(
            'data', list(self._find_nodes(lambda n: n.name == 'tensor'))
        )

    @property
    def tsrex(self):
        '''The underlying tensor expression.'''
        return self._tsrex

    @tsrex.setter
    def tsrex(self, tsrex):
        '''Setting the underlying tensor expression.'''
        self._tsrex = tsrex

    @property
    def shape(self):
        '''The shape of the result tensor.'''
        return self.tsrex.shape

    @property
    def ndim(self):
        '''The dimensionality of the result tensor.'''
        return self.tsrex.ndim

    def penalty(self, sum_leafs: bool = True, sum_vec=False):
        '''The penalty of the result tensor.

        Args:
            sum_leafs (bool): sum the penalties over the leafs of the model.
            sum_vec (bool): sum the penalties over the vectorization dimension.

        Returns:
            The per-leaf penalties stacked into one tensor, or their sum
            over the leaf axis when `sum_leafs` is true.
        '''
        # The leaf axis is the first one when the vectorization dimension
        # has been summed away, and the last one otherwise.
        leaf_axis = 0 if sum_vec else -1
        leafs = list(self._find_nodes(
            lambda n: n.name == 'tensor' and n.decl.optimizable
        ))
        penalties = ab.stack(
            [f.decl.prefer(f.data, sum_vec) for f in leafs], leaf_axis
        )
        if sum_leafs:
            return ab.sum(penalties, leaf_axis)
        return penalties

    def __call__(self):
        '''Shorthand for :py:meth:`forward`.'''
        return self.forward()

    def forward(self):
        '''Evaluate the tensor expression and return the result.'''
        return self.tsrex | Evaluator()

    @staticmethod
    def _as_slice(i, axis):
        '''Normalize the index given for a single axis.

        Args:
            i (int): position of the axis, only used in the error message.
            axis: a slice, an integer, an iterable of indices or `...`.

        Returns:
            A slice is returned unchanged, an integer becomes a length-one
            slice, an iterable becomes a tuple and `...` becomes `None`.

        Raises:
            RuntimeError: if `axis` is none of the supported kinds.
        '''
        if isinstance(axis, slice):
            return axis
        if isinstance(axis, Integral):
            # `slice(-1, 0)` would be empty, so the last element needs an
            # open-ended slice.
            if axis != -1:
                return slice(axis, axis + 1)
            return slice(axis, None)
        if hasattr(axis, '__iter__'):
            return tuple(axis)
        if axis is Ellipsis:
            return None
        raise RuntimeError(
            f'Invalid index for axis {i}: {axis}'
        )

    @classmethod
    def _expand_key(cls, key, ndim):
        '''Turn a subscript into one normalized index per dimension.

        Args:
            key: a sequence of per-axis indices, or a single integer, slice
                or `...` standing for a one-element sequence.
            ndim (int): number of dimensions that must be indexed.

        Returns:
            tuple: `ndim` entries as produced by `_as_slice`, with `...`
            replaced by as many full slices as needed.

        Raises:
            IndexError: if more than one `...` is given, or if the number
                of indices does not match `ndim`.
        '''
        # A lone index is not iterable; wrap it so that `fac[0]`,
        # `fac[1:3]` and `fac[...]` work like their one-tuple forms.
        if isinstance(key, (Integral, slice)) or key is Ellipsis:
            key = (key,)

        indices = [cls._as_slice(i, axis) for i, axis in enumerate(key)]

        # Search by identity: `==` against an array-like index may raise
        # or yield something that is not a plain boolean.
        ellipses = [i for i, axis in enumerate(key) if axis is Ellipsis]
        if len(ellipses) > 1:
            raise IndexError(
                "An index can only have a single ellipsis ('...')."
            )
        if ellipses:
            at = ellipses[0]
            # The ellipsis entry itself is replaced, hence the `+ 1`.
            fill = [slice(None)] * (ndim - len(indices) + 1)
            indices = [*indices[:at], *fill, *indices[at + 1:]]

        if len(indices) != ndim:
            raise IndexError(
                f'Wrong number of indices: expected {ndim}, '
                f'got {len(indices)}.'
            )
        return tuple(indices)

    def _get_elements(self, key):
        '''Get elements at index of tensor expression.'''
        indices = self._expand_key(key, self.ndim)
        return self.tsrex | SlicingPropagator(indices) \
                          | ElementwiseEvaluator()

    def __getitem__(self, idx):
        '''Access a factor tensor by name or output elements by index.

        Args:
            idx: a string naming a factor tensor, or anything else, which
                is treated as an index into the result tensor.

        Raises:
            AttributeError: if `idx` is a string and no tensor has that
                symbol.
        '''
        if isinstance(idx, str):
            # The first matching node wins.
            for n in self._find_nodes(
                lambda n: n.name == 'tensor' and str(n.decl.symbol) == idx
            ):
                return n.data
            raise AttributeError(f'No factor tensor named {idx}.')
        return self._get_elements(idx)

    def __setitem__(self, name, data):
        '''Replace the data of the factor tensor called `name`.

        Raises:
            AttributeError: if no tensor has that symbol.
        '''
        # Only the first matching node is updated.
        for n in self._find_nodes(
            lambda n: n.name == 'tensor' and str(n.decl.symbol) == name
        ):
            return setattr(n, 'data', data)
        raise AttributeError(f'No factor tensor named {name}.')

    class _NodeView:
        '''Sequence-like view of one attribute across a list of nodes.

        Reads and writes go straight to the underlying nodes.
        '''

        def __init__(self, attribute: str, nodes):
            self.attribute = attribute
            self.nodes = nodes

        def __repr__(self):
            return '<{attr} field{pl} of tensor{pl} {tensors}>'.format(
                attr=repr(self.attribute),
                tensors=', '.join([str(n.decl) for n in self.nodes]),
                pl='s' if len(self.nodes) > 1 else ''
            )

        def __getitem__(self, i):
            return getattr(self.nodes[i], self.attribute)

        def __setitem__(self, i, value):
            setattr(self.nodes[i], self.attribute, value)

        def __iter__(self):
            for n in self.nodes:
                yield getattr(n, self.attribute)

        def __len__(self):
            return len(self.nodes)
Back to top