mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-16 14:41:53 +00:00
[test] merge old components to test to model zoo (#4945)
* [test] add custom models in model zoo * [test] update legacy test * [test] update model zoo * [test] update gemini test * [test] remove components to test
This commit is contained in:
4
tests/kit/model_zoo/custom/__init__.py
Normal file
4
tests/kit/model_zoo/custom/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .hanging_param_model import *
|
||||
from .nested_model import *
|
||||
from .repeated_computed_layers import *
|
||||
from .simple_net import *
|
26
tests/kit/model_zoo/custom/base.py
Normal file
26
tests/kit/model_zoo/custom/base.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import torch.nn as nn
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
|
||||
class CheckpointModule(nn.Module):
|
||||
def __init__(self, checkpoint: bool = False):
|
||||
super().__init__()
|
||||
self.checkpoint = checkpoint
|
||||
self._use_checkpoint = checkpoint
|
||||
|
||||
def _forward(self, *args, **kwargs):
|
||||
raise NotImplementedError("CheckpointModule should implement _forward method instead of origin forward")
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if self._use_checkpoint:
|
||||
return checkpoint(self._forward, *args, **kwargs)
|
||||
else:
|
||||
return self._forward(*args, **kwargs)
|
||||
|
||||
def train(self, mode: bool = True):
|
||||
self._use_checkpoint = self.checkpoint
|
||||
return super().train(mode=mode)
|
||||
|
||||
def eval(self):
|
||||
self._use_checkpoint = False
|
||||
return super().eval()
|
48
tests/kit/model_zoo/custom/hanging_param_model.py
Normal file
48
tests/kit/model_zoo/custom/hanging_param_model.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..registry import model_zoo
|
||||
from .base import CheckpointModule
|
||||
|
||||
|
||||
class HangingParamModule(CheckpointModule):
|
||||
"""
|
||||
Hanging Parameter: a parameter dose not belong to a leaf Module.
|
||||
It has subordinate nn.modules and a nn.Parameter.
|
||||
"""
|
||||
|
||||
def __init__(self, checkpoint=False) -> None:
|
||||
super().__init__(checkpoint=checkpoint)
|
||||
self.proj1 = nn.Linear(4, 8)
|
||||
self.weight = nn.Parameter(torch.randn(8, 8))
|
||||
self.proj2 = nn.Linear(8, 4)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.proj1(x)
|
||||
x = F.linear(x, self.weight)
|
||||
x = self.proj2(x)
|
||||
return x
|
||||
|
||||
|
||||
def data_gen():
|
||||
return dict(x=torch.rand(16, 4))
|
||||
|
||||
|
||||
def loss_fn(x):
|
||||
outputs = x["x"]
|
||||
label = torch.randint(low=0, high=2, size=(16,), device=outputs.device)
|
||||
return F.cross_entropy(x["x"], label)
|
||||
|
||||
|
||||
def output_transform(x: torch.Tensor):
|
||||
return dict(x=x)
|
||||
|
||||
|
||||
model_zoo.register(
|
||||
name="custom_hanging_param_model",
|
||||
model_fn=HangingParamModule,
|
||||
data_gen_fn=data_gen,
|
||||
output_transform_fn=output_transform,
|
||||
loss_fn=loss_fn,
|
||||
)
|
53
tests/kit/model_zoo/custom/nested_model.py
Normal file
53
tests/kit/model_zoo/custom/nested_model.py
Normal file
@@ -0,0 +1,53 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..registry import model_zoo
|
||||
from .base import CheckpointModule
|
||||
|
||||
|
||||
class SubNet(nn.Module):
|
||||
def __init__(self, out_features) -> None:
|
||||
super().__init__()
|
||||
self.bias = nn.Parameter(torch.zeros(out_features))
|
||||
|
||||
def forward(self, x, weight):
|
||||
return F.linear(x, weight, self.bias)
|
||||
|
||||
|
||||
class NestedNet(CheckpointModule):
|
||||
def __init__(self, checkpoint=False) -> None:
|
||||
super().__init__(checkpoint)
|
||||
self.fc1 = nn.Linear(5, 5)
|
||||
self.sub_fc = SubNet(5)
|
||||
self.fc2 = nn.Linear(5, 2)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = self.sub_fc(x, self.fc1.weight)
|
||||
x = self.fc1(x)
|
||||
x = self.fc2(x)
|
||||
return x
|
||||
|
||||
|
||||
def data_gen():
|
||||
return dict(x=torch.rand(16, 5))
|
||||
|
||||
|
||||
def loss_fn(x):
|
||||
outputs = x["x"]
|
||||
label = torch.randint(low=0, high=2, size=(16,), device=outputs.device)
|
||||
return F.cross_entropy(x["x"], label)
|
||||
|
||||
|
||||
def output_transform(x: torch.Tensor):
|
||||
return dict(x=x)
|
||||
|
||||
|
||||
model_zoo.register(
|
||||
name="custom_nested_model",
|
||||
model_fn=NestedNet,
|
||||
data_gen_fn=data_gen,
|
||||
output_transform_fn=output_transform,
|
||||
loss_fn=loss_fn,
|
||||
)
|
48
tests/kit/model_zoo/custom/repeated_computed_layers.py
Normal file
48
tests/kit/model_zoo/custom/repeated_computed_layers.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..registry import model_zoo
|
||||
from .base import CheckpointModule
|
||||
|
||||
|
||||
class NetWithRepeatedlyComputedLayers(CheckpointModule):
|
||||
"""
|
||||
This model is to test with layers which go through forward pass multiple times.
|
||||
In this model, the fc1 and fc2 call forward twice
|
||||
"""
|
||||
|
||||
def __init__(self, checkpoint=False) -> None:
|
||||
super().__init__(checkpoint=checkpoint)
|
||||
self.fc1 = nn.Linear(5, 5)
|
||||
self.fc2 = nn.Linear(5, 5)
|
||||
self.fc3 = nn.Linear(5, 2)
|
||||
self.layers = [self.fc1, self.fc2, self.fc1, self.fc2, self.fc3]
|
||||
|
||||
def forward(self, x):
|
||||
for layer in self.layers:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
|
||||
def data_gen():
|
||||
return dict(x=torch.rand(16, 5))
|
||||
|
||||
|
||||
def loss_fn(x):
|
||||
outputs = x["x"]
|
||||
label = torch.randint(low=0, high=2, size=(16,), device=outputs.device)
|
||||
return F.cross_entropy(x["x"], label)
|
||||
|
||||
|
||||
def output_transform(x: torch.Tensor):
|
||||
return dict(x=x)
|
||||
|
||||
|
||||
model_zoo.register(
|
||||
name="custom_repeated_computed_layers",
|
||||
model_fn=NetWithRepeatedlyComputedLayers,
|
||||
data_gen_fn=data_gen,
|
||||
output_transform_fn=output_transform,
|
||||
loss_fn=loss_fn,
|
||||
)
|
53
tests/kit/model_zoo/custom/simple_net.py
Normal file
53
tests/kit/model_zoo/custom/simple_net.py
Normal file
@@ -0,0 +1,53 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..registry import model_zoo
|
||||
from .base import CheckpointModule
|
||||
|
||||
|
||||
class SimpleNet(CheckpointModule):
|
||||
"""
|
||||
In this no-leaf module, it has subordinate nn.modules and a nn.Parameter.
|
||||
"""
|
||||
|
||||
def __init__(self, checkpoint=False) -> None:
|
||||
super().__init__(checkpoint=checkpoint)
|
||||
self.embed = nn.Embedding(20, 4)
|
||||
self.proj1 = nn.Linear(4, 8)
|
||||
self.ln1 = nn.LayerNorm(8)
|
||||
self.proj2 = nn.Linear(8, 4)
|
||||
self.ln2 = nn.LayerNorm(4)
|
||||
self.classifier = nn.Linear(4, 4)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.embed(x)
|
||||
x = self.proj1(x)
|
||||
x = self.ln1(x)
|
||||
x = self.proj2(x)
|
||||
x = self.ln2(x)
|
||||
x = self.classifier(x)
|
||||
return x
|
||||
|
||||
|
||||
def data_gen():
|
||||
return dict(x=torch.randint(low=0, high=20, size=(16,)))
|
||||
|
||||
|
||||
def loss_fn(x):
|
||||
outputs = x["x"]
|
||||
label = torch.randint(low=0, high=2, size=(16,), device=outputs.device)
|
||||
return F.cross_entropy(x["x"], label)
|
||||
|
||||
|
||||
def output_transform(x: torch.Tensor):
|
||||
return dict(x=x)
|
||||
|
||||
|
||||
model_zoo.register(
|
||||
name="custom_simple_net",
|
||||
model_fn=SimpleNet,
|
||||
data_gen_fn=data_gen,
|
||||
output_transform_fn=output_transform,
|
||||
loss_fn=loss_fn,
|
||||
)
|
Reference in New Issue
Block a user