mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-06 12:07:00 +00:00
* [elixir] add elixir * [elixir] add unit tests * remove useless code * fix python 3.8 issue * fix typo * add test skip * add docstrings * add docstrings * add readme * fix typo
35 lines
784 B
Python
35 lines
784 B
Python
import torch
|
|
import torch.nn as nn
|
|
|
|
from tests.test_elixir.utils.registry import TEST_MODELS
|
|
|
|
|
|
def mlp_data_fn():
    """Create one random input batch for the ``mlp`` test model.

    Returns:
        dict: a single entry ``'x'`` holding a ``(4, 16)`` tensor drawn from
        ``torch.randn``. The trailing 16 matches ``MlpModel``'s default
        ``hidden_dim``.
    """
    batch = torch.randn(4, 16)
    return {'x': batch}
|
|
|
|
|
|
class MlpModule(nn.Module):
    """Residual two-layer feed-forward block.

    Computes ``x + proj2(gelu(proj1(x)))``. The hidden layer is four times
    wider than ``hidden_dim``.

    Args:
        hidden_dim: size of the input/output feature (last) dimension.
    """

    def __init__(self, hidden_dim: int = 16) -> None:
        super().__init__()
        expanded = 4 * hidden_dim
        # Registration order (proj1, act, proj2) fixes parameter naming/order
        # and the order in which weights are randomly initialised.
        self.proj1 = nn.Linear(hidden_dim, expanded)
        self.act = nn.GELU()
        self.proj2 = nn.Linear(expanded, hidden_dim)

    def forward(self, x):
        """Apply the MLP to ``x`` (last dim == hidden_dim) and add ``x`` back."""
        hidden = self.proj1(x)
        hidden = self.act(hidden)
        hidden = self.proj2(hidden)
        return x + hidden
|
|
|
|
|
|
class MlpModel(nn.Module):
    """Test model that wraps a single ``MlpModule`` and reduces to a scalar.

    Args:
        hidden_dim: feature size forwarded to the inner ``MlpModule``.
    """

    def __init__(self, hidden_dim: int = 16) -> None:
        super().__init__()
        self.mlp = MlpModule(hidden_dim)

    def forward(self, x):
        """Run the inner MLP on ``x`` and return the sum of every output element."""
        return self.mlp(x).sum()
|
|
|
|
|
|
# Register the model class and its input-batch factory with TEST_MODELS under
# the name 'mlp' (presumably looked up by name in the elixir tests; see
# tests.test_elixir.utils.registry to confirm).
TEST_MODELS.register('mlp', MlpModel, mlp_data_fn)
|