Skip to content

Commit

Permalink
store cls instead of an obj
Browse files Browse the repository at this point in the history
  • Loading branch information
SunMarc committed Feb 19, 2025
1 parent f5929e0 commit 79d3096
Showing 1 changed file with 9 additions and 10 deletions.
19 changes: 9 additions & 10 deletions src/diffusers/models/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@
if is_torch_npu_available():
import torch_npu

ACTIVATION_FUNCTIONS = {
"swish": nn.SiLU(),
"silu": nn.SiLU(),
"mish": nn.Mish(),
"gelu": nn.GELU(),
"relu": nn.ReLU(),
ACT2CLS = {
"swish": nn.SiLU,
"silu": nn.SiLU,
"mish": nn.Mish,
"gelu": nn.GELU,
"relu": nn.ReLU,
}


Expand All @@ -44,11 +44,10 @@ def get_activation(act_fn: str) -> nn.Module:
"""

act_fn = act_fn.lower()
if act_fn in ACTIVATION_FUNCTIONS:
return ACTIVATION_FUNCTIONS[act_fn]
if act_fn in ACT2CLS:
return ACT2CLS[act_fn]()
else:
raise ValueError(f"Unsupported activation function: {act_fn}")

raise ValueError(f"activation function {act_fn} not found in ACT2FN mapping {list(ACT2CLS.keys())}")

class FP32SiLU(nn.Module):
r"""
Expand Down

0 comments on commit 79d3096

Please sign in to comment.