diff --git a/docs/en/reference/nn/modules/activation.md b/docs/en/reference/nn/modules/activation.md new file mode 100644 index 00000000..09dd92ed --- /dev/null +++ b/docs/en/reference/nn/modules/activation.md @@ -0,0 +1,16 @@ +--- +description: Explore activation functions in Ultralytics, including the Unified activation function and other custom implementations for neural networks. +keywords: ultralytics, activation functions, neural networks, Unified activation, AGLU, SiLU, ReLU, PyTorch, deep learning, custom activations +--- + +# Reference for `ultralytics/nn/modules/activation.py` + +!!! Note + + This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/nn/modules/activation.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/nn/modules/activation.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/nn/modules/activation.py) 🛠️. Thank you 🙏! + +
+ +## ::: ultralytics.nn.modules.activation.AGLU + +

diff --git a/docs/mkdocs_github_authors.yaml b/docs/mkdocs_github_authors.yaml index 91d29605..bfaddb3a 100644 --- a/docs/mkdocs_github_authors.yaml +++ b/docs/mkdocs_github_authors.yaml @@ -12,6 +12,7 @@ 48149018+zhixuwei@users.noreply.github.com: zhixuwei 49699333+dependabot[bot]@users.noreply.github.com: dependabot 52826299+Chayanonjackal@users.noreply.github.com: Chayanonjackal +53246858+hasanghaffari93@users.noreply.github.com: hasanghaffari93 61612323+Laughing-q@users.noreply.github.com: Laughing-q 62214284+Burhan-Q@users.noreply.github.com: Burhan-Q 68285002+Kayzwer@users.noreply.github.com: Kayzwer diff --git a/mkdocs.yml b/mkdocs.yml index e3f38ce1..ad11736f 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -536,6 +536,7 @@ nav: - nn: - autobackend: reference/nn/autobackend.md - modules: + - activation: reference/nn/modules/activation.md - block: reference/nn/modules/block.md - conv: reference/nn/modules/conv.md - head: reference/nn/modules/head.md diff --git a/ultralytics/nn/modules/activation.py b/ultralytics/nn/modules/activation.py new file mode 100644 index 00000000..25cca2a5 --- /dev/null +++ b/ultralytics/nn/modules/activation.py @@ -0,0 +1,22 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license +"""Activation modules.""" + +import torch +import torch.nn as nn + + +class AGLU(nn.Module): + """Unified activation function module from https://github.com/kostas1515/AGLU.""" + + def __init__(self, device=None, dtype=None) -> None: + """Initialize the Unified activation function.""" + super().__init__() + self.act = nn.Softplus(beta=-1.0) + self.lambd = nn.Parameter(nn.init.uniform_(torch.empty(1, device=device, dtype=dtype))) # lambda parameter + self.kappa = nn.Parameter(nn.init.uniform_(torch.empty(1, device=device, dtype=dtype))) # kappa parameter + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Compute the forward pass of the Unified activation function.""" + lam = torch.clamp(self.lambd, min=0.0001) + y = torch.exp((1 / lam) * self.act((self.kappa * x) - torch.log(lam))) + return y # for AGLU simply return y * input