factorize(tsrex, target, optimizer='Adam', loss='MSE', lr=0.1, tol=1e-06, max_steps=10000, vec_size=1, vec_axis=0, stop_by='first', returns='best', checkpoint_freq=50, dtype=None, penalty_weight=1.0, plugins=[])
ΒΆ
Factorize a target tensor using the given tensor expression. The solution is found by minimizing the loss function between the original and approximate tensors using stochastic gradient descent.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `tsrex` | `TsrEx` | A tensor expression. | *required* |
| `target` | `tensor` | The original tensor to approximate. | *required* |
| `optimizer` | `str` or `callable` | If a `str`, must be one of the optimizers defined in `funfact.optim`; if a `callable`, any object that implements the interface of `funfact.optim.Optimizer`. | `'Adam'` |
| `loss` | `str` or `callable` | If a `str`, must be one of the loss functions defined in `funfact.loss`; if a `callable`, any object that implements the interface of `funfact.loss.Loss` (see the sketch after this table). | `'MSE'` |
| `lr` | `float` | SGD learning rate. | `0.1` |
| `tol` | `float` | Convergence tolerance. | `1e-06` |
| `max_steps` | `int` | Maximum number of SGD steps to run. | `10000` |
| `vec_size` | `int` | Number of parallel instances to compute. | `1` |
| `vec_axis` | `0` or `-1` | The position of the vectorization dimension. | `0` |
| `stop_by` | `'first'`, `int >= 1`, or `None` | If `'first'`, stop as soon as one instance finds a solution with loss below `tol`; if an integer `n`, stop after `n` instances have found such solutions; if `None`, always run for `max_steps` steps. | `'first'` |
| `returns` | `'best'`, `int >= 1`, or `'all'` | If `'best'`, return the solution with the smallest loss; if an integer `n` or `'all'`, return a list of the top `n` or of all instances, sorted in ascending order by loss. | `'best'` |
| `checkpoint_freq` | `int >= 1` | The frequency of convergence checking. | `50` |
| `dtype` | `None` or `ab.dtype` | If `None`, use the same data type as the target tensor; if a concrete dtype (`float32`, `float64`, `complex64`, `complex128`), use that type. | `None` |
| `penalty_weight` | `float` | Weight of penalties relative to the loss. | `1.0` |
| `plugins` | `list` | Additional methods to be inserted into the gradient descent loop (see the sketch after the source code). | `[]` |
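Besides the built-in names, `loss` accepts any callable with a compatible signature. The sketch below is hypothetical: it is inferred from how the source code listed further down invokes the loss, namely `loss(model(), target, sum_vec=..., vectorized_along_last=...)`. The NumPy arithmetic is for illustration only; a real replacement would need to use the active backend's operations so that gradients can flow, and `funfact.loss.Loss` remains the authoritative interface.

```python
import numpy as np

# Hypothetical mean-absolute-error loss matching the calling convention
# seen in the factorize() source; a sketch, not funfact's documented API.
def mae(approx, target, sum_vec=True, vectorized_along_last=False):
    approx = np.asarray(approx)
    target = np.asarray(target)
    if approx.ndim == target.ndim + 1:
        # insert the vectorization axis into the target for broadcasting
        target = np.expand_dims(target, -1 if vectorized_along_last else 0)
    err = np.abs(approx - target)
    if sum_vec:
        # collapse everything, including the vectorization axis
        return err.mean()
    # otherwise return one loss value per vectorized instance
    vec_axis = (err.ndim - 1) if vectorized_along_last else 0
    return err.mean(axis=tuple(ax for ax in range(err.ndim) if ax != vec_axis))
```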
Returns:

| Type | Description |
|---|---|
| `*` | If `returns == 'best'`, a factorization object of type `funfact.Factorization` representing the best solution found; if `returns == n` (an integer), a list of factorization objects representing the best `n` solutions; if `returns == 'all'`, a vectorized factorization object representing all solutions. |
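A minimal usage sketch, assuming the top-level funfact API exposes `tensor` and `indices` for building tensor expressions as in the project's introductory examples; the shapes and rank here are illustrative only:

```python
import numpy as np
import funfact as ff

# illustrative target: an exactly rank-2 matrix
target = np.random.randn(30, 2) @ np.random.randn(2, 40)

# declare abstract factors and a rank-2 matrix factorization expression
a = ff.tensor('a', 30, 2)
b = ff.tensor('b', 2, 40)
i, j, k = ff.indices('i, j, k')
tsrex = a[i, k] * b[k, j]

# run 8 parallel SGD instances and keep the best solution
fac = ff.factorize(tsrex, target, vec_size=8, max_steps=2000)
print(fac.factors)  # the learned leaf tensors (attribute per the source)
```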
Source code in `funfact/algorithm.py`:

```python
def factorize(
tsrex, target, optimizer='Adam', loss='MSE', lr=0.1, tol=1e-6,
max_steps=10000, vec_size=1, vec_axis=0, stop_by='first', returns='best',
checkpoint_freq=50, dtype=None, penalty_weight=1.0, plugins=[]
):
'''Factorize a target tensor using the given tensor expression. The
solution is found by minimizing the loss function between the original and
approximate tensors using stochastic gradient descent.
Args:
tsrex (TsrEx): A tensor expression.
target (tensor): The original tensor to approximate.
optimizer (str or callable):
- If `str`, must be one of the optimizers defined in
[funfact.optim]().
- If `callable`, can be any object that implements the interface of
[funfact.optim.Optimizer]().
loss (str or callable):
- If `str`, must be one of the loss functions defined in
[funfact.loss]().
- If `callable`, can be any object that implements the interface of
[funfact.loss.Loss]().
lr (float): SGD learning rate.
tol (float): convergence tolerance.
max_steps (int): maximum number of SGD steps to run.
vec_size (int): Number of parallel instances to compute.
vec_axis (0 or -1): The position of the vectorization dimension.
stop_by ('first', int >= 1, or None):
- If 'first', stop optimization as soon as one solution is
found whose loss is less than `tol` when running multiple parallel
instances.
- If int `n`, stop optimization after n instances
have found solutions with losses less than `tol`.
- If None, always optimize for `max_steps` steps.
returns ('best', int >= 1, or 'all'):
- If 'best', returns the solution with the smallest loss.
- If int `n` or 'all', returns a list of the top `n` or all of the
instances sorted in ascending order by loss.
checkpoint_freq (int >= 1): The frequency of convergence checking.
dtype: The datatype of the factorization model (None, ab.dtype):
- If None, the same data type as the target tensor is used.
- If concrete dtype (float32, float64, complex64, complex128),
that data type is used.
penalty_weight (float) : Weight of penalties relative to loss.
plugins (list):
Additional methods to be inserted into the gradient descent
loop.
Returns:
*:
- If `returns == 'best'`, return a factorization object of type
[funfact.Factorization]() representing the best solution found.
- If `returns == n`, return a list of factorization
objects representing the best `n` solutions found.
- If `returns == 'all'`, return a vectorized factorization object
that represents all the solutions.
'''
'''process arguments'''
assert vec_axis in [0, -1], "Vectorization axis must be either 0 or -1."
    append = (vec_axis == -1)
if dtype is None:
target = ab.tensor(target)
dtype = target.dtype
else:
target = ab.tensor(target, dtype=dtype)
fac = ab.add_autograd(Factorization).from_tsrex(
tsrex, dtype=dtype, vec_size=vec_size, vec_axis=vec_axis
)
if isinstance(optimizer, str):
try:
optimizer = getattr(funfact.optim, optimizer)
except AttributeError:
            raise RuntimeError(
                f'The optimizer \'{optimizer}\' does not exist in '
                'funfact.optim.'
            )
try:
opt = optimizer(fac.factors, lr=lr)
    except Exception as e:
        raise RuntimeError(
            f'Invalid optimization algorithm:\n{e}'
        )
if isinstance(loss, str):
try:
loss = getattr(funfact.loss, loss)
except AttributeError:
            raise RuntimeError(
                f'The loss function \'{loss}\' does not exist in '
                'funfact.loss.'
            )
if isinstance(loss, type):
loss = loss()
try:
loss(target, target)
except Exception as e:
raise RuntimeError(
f'A loss function must accept two arguments:\n{e}'
)
def loss_and_penalty(model, target, sum_vec=True):
loss_val = loss(
model(), target, sum_vec=sum_vec, vectorized_along_last=append
)
if penalty_weight > 0:
return loss_val + penalty_weight * \
model.penalty(sum_leafs=True, sum_vec=sum_vec)
else:
return loss_val
loss_and_grad = ab.loss_and_grad(loss_and_penalty, fac, target)
if stop_by == 'first':
stop_by = 1
if not any((
stop_by is None, isinstance(stop_by, int) and stop_by > 0
)):
raise RuntimeError(f'Invalid argument value for stop_by: {stop_by}')
if not any((
returns in ['best', 'all'], isinstance(returns, int) and returns > 0
)):
raise RuntimeError(f'Invalid argument value for returns: {returns}')
'''define plugin for saving the best factorization model'''
best_factors = [np.zeros_like(ab.to_numpy(x)) for x in fac.factors]
best_loss = np.ones(vec_size) * np.inf
@gradient_descent_plugin(every=checkpoint_freq)
def save_best(state: GradientDescentState, best_loss=best_loss):
# TODO: use external validation set
validation_loss = ab.to_numpy(
loss_and_penalty(fac, target, sum_vec=False)
)
better = np.flatnonzero(validation_loss < best_loss)
        # update in place: plain rebinding would not persist across
        # checkpoints, since this array is shared via the default argument
        np.minimum(best_loss, validation_loss, out=best_loss)
for b, o in zip(best_factors, fac.factors):
if append:
b[..., better] = ab.to_numpy(o[..., better])
else:
b[better, ...] = ab.to_numpy(o[better, ...])
'''define plugin for convergence test'''
converged = np.zeros(vec_size, dtype=np.bool_)
@gradient_descent_plugin(every=checkpoint_freq)
def convergence_check(state: GradientDescentState, converged=converged):
# TODO: use external validation set
validation_loss = ab.to_numpy(
loss_and_penalty(fac, target, sum_vec=False)
)
        converged |= validation_loss < tol
'''define early-exit conditions'''
def exit_condition(state: GradientDescentState):
if stop_by is not None:
return np.count_nonzero(converged) >= stop_by
else:
return False
'''run the gradient descent loop'''
gradient_descent(
lambda: loss_and_grad(fac, target), opt, max_steps, exit_condition,
plugins=plugins + [
save_best,
convergence_check
]
)
'''collect results'''
best_factors = [ab.tensor(x) for x in best_factors]
if returns == 'best':
return view(
best_factors,
Factorization.from_tsrex(tsrex, dtype=dtype),
np.argmin(best_loss), append
)
else:
if isinstance(returns, int):
instances = np.argsort(best_loss)[:returns]
elif returns == 'all':
instances = np.argsort(best_loss)
return [
view(
best_factors,
Factorization.from_tsrex(tsrex, dtype=dtype),
i, append
) for i in instances
        ]
```
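The internal checkpointing hooks above (`save_best`, `convergence_check`) are built with the `gradient_descent_plugin` decorator and receive a `GradientDescentState`. A user-supplied entry in `plugins` could plausibly follow the same pattern. The sketch below is hypothetical: both names appear in the source above, but their import path and the attributes of the state object are assumptions, not documented API.

```python
# Hypothetical plugin passed via `plugins=[...]`; the import path below
# is an assumption based on the source listing above.
from funfact.algorithm import GradientDescentState, gradient_descent_plugin

@gradient_descent_plugin(every=500)
def log_checkpoint(state: GradientDescentState):
    # `step` as a state attribute is an assumption, hence the getattr guard
    print('checkpoint at step', getattr(state, 'step', '<unknown>'))

# usage: factorize(tsrex, target, plugins=[log_checkpoint])
```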