Source code for prototorch.functions.activations

"""ProtoTorch activation functions."""

import torch

ACTIVATIONS = dict()


def register_activation(fn):
    """Add the activation function to the registry."""
    name = fn.__name__
    ACTIVATIONS[name] = fn
    return fn


[docs]@register_activation def identity(x, beta=0.0): """Identity activation function. Definition: :math:`f(x) = x` Keyword Arguments: beta (`float`): Ignored. """ return x
[docs]@register_activation def sigmoid_beta(x, beta=10.0): r"""Sigmoid activation function with scaling. Definition: :math:`f(x) = \frac{1}{1 + e^{-\beta x}}` Keyword Arguments: beta (`float`): Scaling parameter :math:`\beta` """ out = 1.0 / (1.0 + torch.exp(-1.0 * beta * x)) return out
[docs]@register_activation def swish_beta(x, beta=10.0): r"""Swish activation function with scaling. Definition: :math:`f(x) = \frac{x}{1 + e^{-\beta x}}` Keyword Arguments: beta (`float`): Scaling parameter :math:`\beta` """ out = x * sigmoid_beta(x, beta=beta) return out
def get_activation(funcname): """Deserialize the activation function.""" if callable(funcname): return funcname if funcname in ACTIVATIONS: return ACTIVATIONS.get(funcname) raise NameError(f"Activation {funcname} was not found.")