diff --git a/sf3d/models/network.py b/sf3d/models/network.py index 3886778..92359a2 100644 --- a/sf3d/models/network.py +++ b/sf3d/models/network.py @@ -80,46 +80,38 @@ def backward(ctx, g): # pylint: disable=arguments-differ trunc_exp = _TruncExp.apply +activation_functions = { + "none": lambda x: x, + "linear": lambda x: x, + "identity": lambda x: x, + "lin2srgb": lambda x: torch.where( + x > 0.0031308, + torch.pow(torch.clamp(x, min=0.0031308), 1.0 / 2.4) * 1.055 - 0.055, + 12.92 * x, + ).clamp(0.0, 1.0), + "exp": torch.exp, + "shifted_exp": lambda x: torch.exp(x - 1.0), + "trunc_exp": trunc_exp, + "shifted_trunc_exp": lambda x: trunc_exp(x - 1.0), + "sigmoid": torch.sigmoid, + "tanh": torch.tanh, + "shifted_softplus": lambda x: F.softplus(x - 1.0), + "scale_-11_01": lambda x: x * 0.5 + 0.5, + "negative": lambda x: -x, + "normalize_channel_last": normalize, + "normalize_channel_first": lambda x: normalize(x, dim=1) +} + def get_activation(name) -> Callable: if name is None: return lambda x: x name = name.lower() - if name == "none" or name == "linear" or name == "identity": - return lambda x: x - elif name == "lin2srgb": - return lambda x: torch.where( - x > 0.0031308, - torch.pow(torch.clamp(x, min=0.0031308), 1.0 / 2.4) * 1.055 - 0.055, - 12.92 * x, - ).clamp(0.0, 1.0) - elif name == "exp": - return lambda x: torch.exp(x) - elif name == "shifted_exp": - return lambda x: torch.exp(x - 1.0) - elif name == "trunc_exp": - return trunc_exp - elif name == "shifted_trunc_exp": - return lambda x: trunc_exp(x - 1.0) - elif name == "sigmoid": - return lambda x: torch.sigmoid(x) - elif name == "tanh": - return lambda x: torch.tanh(x) - elif name == "shifted_softplus": - return lambda x: F.softplus(x - 1.0) - elif name == "scale_-11_01": - return lambda x: x * 0.5 + 0.5 - elif name == "negative": - return lambda x: -x - elif name == "normalize_channel_last": - return lambda x: normalize(x) - elif name == "normalize_channel_first": - return lambda x: normalize(x, dim=1) - else: - try: - return getattr(F, name) - except AttributeError: - raise ValueError(f"Unknown activation function: {name}") - + if name in activation_functions: + return activation_functions[name] + try: + return getattr(F, name) + except AttributeError: + raise ValueError(f"Unknown activation function: {name}") @dataclass class HeadSpec: @@ -193,3 +185,4 @@ def forward( } return out +