New AGLU activation module (#14644)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
parent c6db604fe1
commit b7c90526c8

4 changed files with 40 additions and 0 deletions
ultralytics/nn/modules/activation.py (new file, 22 additions)
@@ -0,0 +1,22 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+"""Activation modules."""
+
+import torch
+import torch.nn as nn
+
+
+class AGLU(nn.Module):
+    """Unified activation function module from https://github.com/kostas1515/AGLU."""
+
+    def __init__(self, device=None, dtype=None) -> None:
+        """Initialize the Unified activation function."""
+        super().__init__()
+        self.act = nn.Softplus(beta=-1.0)
+        self.lambd = nn.Parameter(nn.init.uniform_(torch.empty(1, device=device, dtype=dtype)))  # lambda parameter
+        self.kappa = nn.Parameter(nn.init.uniform_(torch.empty(1, device=device, dtype=dtype)))  # kappa parameter
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Compute the forward pass of the Unified activation function."""
+        lam = torch.clamp(self.lambd, min=0.0001)
+        y = torch.exp((1 / lam) * self.act((self.kappa * x) - torch.log(lam)))
+        return y  # for AGLU simply return y * input
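For reference, here is what the new forward pass computes in closed form (my reading of the code above, not text from the commit): nn.Softplus with beta = -1 evaluates -log(1 + e^{-z}), so substituting z = kappa * x - log(lambda) gives

\[
\mathrm{AGLU}(x) = \exp\!\Big(\tfrac{1}{\lambda}\,\mathrm{softplus}_{-1}(\kappa x - \log\lambda)\Big) = \big(1 + \lambda\, e^{-\kappa x}\big)^{-1/\lambda}, \qquad \lambda \ge 10^{-4},
\]

which recovers the sigmoid \(\sigma(\kappa x)\) at \(\lambda = 1\) and tends to the Gompertz function \(\exp(-e^{-\kappa x})\) as \(\lambda \to 0\). The clamp on lambd keeps the \(1/\lambda\) and \(\log\lambda\) terms finite.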
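A minimal usage sketch of the new module, assuming it is importable from the path this diff adds (the remaining changed files are not shown here, so any re-export location is an assumption):

# Usage sketch, not part of the commit: exercise the new AGLU module.
import torch

from ultralytics.nn.modules.activation import AGLU  # path per this diff

act = AGLU()  # learnable lambd and kappa, each initialized uniformly in [0, 1)
x = torch.randn(2, 3)
y = act(x)
print(y.shape)  # torch.Size([2, 3]); outputs lie in (0, 1) since lambd is clamped positive

Because lambd and kappa are nn.Parameter tensors, they are trained by the optimizer along with the rest of the network, which is the point of this parametric activation.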