[test] merge old components to test to model zoo (#4945)

* [test] add custom models in model zoo

* [test] update legacy test

* [test] update model zoo

* [test] update gemini test

* [test] remove components to test
This commit is contained in:
Hongxin Liu
2023-10-20 10:35:08 +08:00
committed by GitHub
parent 3a41e8304e
commit b8e770c832
49 changed files with 461 additions and 914 deletions

View File

@@ -1,4 +1,5 @@
from . import diffusers, timm, torchaudio, torchrec, torchvision, transformers
from . import custom, diffusers, timm, torchaudio, torchrec, torchvision, transformers
from .executor import run_fwd, run_fwd_bwd
from .registry import model_zoo
__all__ = ["model_zoo"]
__all__ = ["model_zoo", "run_fwd", "run_fwd_bwd"]

View File

@@ -0,0 +1,4 @@
from .hanging_param_model import *
from .nested_model import *
from .repeated_computed_layers import *
from .simple_net import *

View File

@@ -0,0 +1,26 @@
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
class CheckpointModule(nn.Module):
    """Base module that can optionally run its forward pass under checkpointing.

    Subclasses are expected to implement ``_forward``; ``forward`` dispatches to
    it either directly or through ``torch.utils.checkpoint.checkpoint``.

    Args:
        checkpoint (bool): whether checkpointing should be used while the
            module is in training mode.
    """

    def __init__(self, checkpoint: bool = False):
        super().__init__()
        # User-requested setting; never modified after construction.
        self.checkpoint = checkpoint
        # Effective setting, toggled by train()/eval().
        self._use_checkpoint = checkpoint

    def _forward(self, *args, **kwargs):
        raise NotImplementedError("CheckpointModule should implement _forward method instead of origin forward")

    def forward(self, *args, **kwargs):
        if self._use_checkpoint:
            return checkpoint(self._forward, *args, **kwargs)
        return self._forward(*args, **kwargs)

    def train(self, mode: bool = True):
        """Set training mode; checkpointing is only active when ``mode`` is True."""
        # nn.Module.eval() is implemented as train(False), so the mode must be
        # honoured here; otherwise eval() would silently re-enable checkpointing.
        self._use_checkpoint = bool(self.checkpoint and mode)
        return super().train(mode=mode)

    def eval(self):
        """Switch to evaluation mode and disable checkpointing."""
        self._use_checkpoint = False
        return super().eval()

View File

@@ -0,0 +1,48 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..registry import model_zoo
from .base import CheckpointModule
class HangingParamModule(CheckpointModule):
    """
    Hanging parameter: a parameter that does not belong to a leaf module.
    This module owns an ``nn.Parameter`` directly while also holding child
    ``nn.Module`` instances.
    """

    def __init__(self, checkpoint=False) -> None:
        super().__init__(checkpoint=checkpoint)
        self.proj1 = nn.Linear(4, 8)
        # The "hanging" parameter: registered on this non-leaf module itself.
        self.weight = nn.Parameter(torch.randn(8, 8))
        self.proj2 = nn.Linear(8, 4)

    # NOTE(review): ``forward`` is overridden directly, so the checkpoint
    # dispatch in ``CheckpointModule.forward`` (which routes through
    # ``_forward``) is bypassed -- confirm this is intended.
    def forward(self, x):
        hidden = self.proj1(x)
        hidden = F.linear(hidden, self.weight)
        return self.proj2(hidden)
def data_gen():
    """Build the keyword inputs for the model: a random ``(16, 4)`` tensor under ``"x"``."""
    batch = torch.rand(16, 4)
    return {"x": batch}
def loss_fn(x):
    """Cross-entropy of the logits in ``x["x"]`` against random binary labels.

    Args:
        x: dict holding the model output under the key ``"x"``.

    Returns:
        The scalar cross-entropy loss.
    """
    outputs = x["x"]
    # One random label in {0, 1} per sample; sizing from the logits keeps the
    # loss usable for any batch size instead of only 16.
    label = torch.randint(low=0, high=2, size=(outputs.shape[0],), device=outputs.device)
    return F.cross_entropy(outputs, label)
def output_transform(x: torch.Tensor):
    """Wrap the raw model output tensor in a dict under the key ``"x"``."""
    return {"x": x}
# Register HangingParamModule in the model zoo under "custom_hanging_param_model",
# together with its data generator, output transform and loss function.
model_zoo.register(
name="custom_hanging_param_model",
model_fn=HangingParamModule,
data_gen_fn=data_gen,
output_transform_fn=output_transform,
loss_fn=loss_fn,
)

View File

@@ -0,0 +1,53 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..registry import model_zoo
from .base import CheckpointModule
class SubNet(nn.Module):
    """Linear layer that owns only a bias; the weight is supplied by the caller."""

    def __init__(self, out_features) -> None:
        super().__init__()
        initial_bias = torch.zeros(out_features)
        self.bias = nn.Parameter(initial_bias)

    def forward(self, x, weight):
        # The weight comes from outside this module; only the bias is local.
        return F.linear(x, weight, bias=self.bias)
class NestedNet(CheckpointModule):
    """Model whose sub-module reuses ``fc1``'s weight and which applies ``fc1`` twice."""

    def __init__(self, checkpoint=False) -> None:
        super().__init__(checkpoint)
        self.fc1 = nn.Linear(5, 5)
        self.sub_fc = SubNet(5)
        self.fc2 = nn.Linear(5, 2)

    # NOTE(review): ``forward`` is overridden directly, so the checkpoint
    # dispatch in ``CheckpointModule.forward`` (which routes through
    # ``_forward``) is bypassed -- confirm this is intended.
    def forward(self, x):
        hidden = self.fc1(x)
        # sub_fc has only a bias of its own; it borrows fc1's weight.
        hidden = self.sub_fc(hidden, self.fc1.weight)
        hidden = self.fc1(hidden)
        return self.fc2(hidden)
def data_gen():
    """Build the keyword inputs for the model: a random ``(16, 5)`` tensor under ``"x"``."""
    batch = torch.rand(16, 5)
    return {"x": batch}
def loss_fn(x):
    """Cross-entropy of the logits in ``x["x"]`` against random binary labels.

    Args:
        x: dict holding the model output under the key ``"x"``.

    Returns:
        The scalar cross-entropy loss.
    """
    outputs = x["x"]
    # One random label in {0, 1} per sample; sizing from the logits keeps the
    # loss usable for any batch size instead of only 16.
    label = torch.randint(low=0, high=2, size=(outputs.shape[0],), device=outputs.device)
    return F.cross_entropy(outputs, label)
def output_transform(x: torch.Tensor):
    """Wrap the raw model output tensor in a dict under the key ``"x"``."""
    return {"x": x}
# Register NestedNet in the model zoo under "custom_nested_model",
# together with its data generator, output transform and loss function.
model_zoo.register(
name="custom_nested_model",
model_fn=NestedNet,
data_gen_fn=data_gen,
output_transform_fn=output_transform,
loss_fn=loss_fn,
)

View File

@@ -0,0 +1,48 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..registry import model_zoo
from .base import CheckpointModule
class NetWithRepeatedlyComputedLayers(CheckpointModule):
    """
    Model for testing layers that run more than once in a single forward pass.
    ``fc1`` and ``fc2`` are each applied twice before ``fc3`` produces the output.
    """

    def __init__(self, checkpoint=False) -> None:
        super().__init__(checkpoint=checkpoint)
        self.fc1 = nn.Linear(5, 5)
        self.fc2 = nn.Linear(5, 5)
        self.fc3 = nn.Linear(5, 2)
        # Execution order: fc1 -> fc2 -> fc1 -> fc2 -> fc3.
        self.layers = [self.fc1, self.fc2] * 2 + [self.fc3]

    # NOTE(review): ``forward`` is overridden directly, so the checkpoint
    # dispatch in ``CheckpointModule.forward`` (which routes through
    # ``_forward``) is bypassed -- confirm this is intended.
    def forward(self, x):
        hidden = x
        for stage in self.layers:
            hidden = stage(hidden)
        return hidden
def data_gen():
    """Build the keyword inputs for the model: a random ``(16, 5)`` tensor under ``"x"``."""
    batch = torch.rand(16, 5)
    return {"x": batch}
def loss_fn(x):
    """Cross-entropy of the logits in ``x["x"]`` against random binary labels.

    Args:
        x: dict holding the model output under the key ``"x"``.

    Returns:
        The scalar cross-entropy loss.
    """
    outputs = x["x"]
    # One random label in {0, 1} per sample; sizing from the logits keeps the
    # loss usable for any batch size instead of only 16.
    label = torch.randint(low=0, high=2, size=(outputs.shape[0],), device=outputs.device)
    return F.cross_entropy(outputs, label)
def output_transform(x: torch.Tensor):
    """Wrap the raw model output tensor in a dict under the key ``"x"``."""
    return {"x": x}
# Register NetWithRepeatedlyComputedLayers in the model zoo under
# "custom_repeated_computed_layers", together with its data generator,
# output transform and loss function.
model_zoo.register(
name="custom_repeated_computed_layers",
model_fn=NetWithRepeatedlyComputedLayers,
data_gen_fn=data_gen,
output_transform_fn=output_transform,
loss_fn=loss_fn,
)

View File

@@ -0,0 +1,53 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..registry import model_zoo
from .base import CheckpointModule
class SimpleNet(CheckpointModule):
    """
    Small sequential model: an embedding followed by two Linear + LayerNorm
    stages and a final linear classifier.
    """

    def __init__(self, checkpoint=False) -> None:
        super().__init__(checkpoint=checkpoint)
        self.embed = nn.Embedding(20, 4)
        self.proj1 = nn.Linear(4, 8)
        self.ln1 = nn.LayerNorm(8)
        self.proj2 = nn.Linear(8, 4)
        self.ln2 = nn.LayerNorm(4)
        self.classifier = nn.Linear(4, 4)

    # NOTE(review): ``forward`` is overridden directly, so the checkpoint
    # dispatch in ``CheckpointModule.forward`` (which routes through
    # ``_forward``) is bypassed -- confirm this is intended.
    def forward(self, x):
        pipeline = (self.embed, self.proj1, self.ln1, self.proj2, self.ln2, self.classifier)
        hidden = x
        for stage in pipeline:
            hidden = stage(hidden)
        return hidden
def data_gen():
    """Build the keyword inputs for the model: 16 random integer ids in ``[0, 20)`` under ``"x"``."""
    token_ids = torch.randint(low=0, high=20, size=(16,))
    return {"x": token_ids}
def loss_fn(x):
    """Cross-entropy of the logits in ``x["x"]`` against random binary labels.

    Args:
        x: dict holding the model output under the key ``"x"``.

    Returns:
        The scalar cross-entropy loss.
    """
    outputs = x["x"]
    # One random label in {0, 1} per sample; sizing from the logits keeps the
    # loss usable for any batch size instead of only 16.
    label = torch.randint(low=0, high=2, size=(outputs.shape[0],), device=outputs.device)
    return F.cross_entropy(outputs, label)
def output_transform(x: torch.Tensor):
    """Wrap the raw model output tensor in a dict under the key ``"x"``."""
    return {"x": x}
# Register SimpleNet in the model zoo under "custom_simple_net",
# together with its data generator, output transform and loss function.
model_zoo.register(
name="custom_simple_net",
model_fn=SimpleNet,
data_gen_fn=data_gen,
output_transform_fn=output_transform,
loss_fn=loss_fn,
)

View File

@@ -0,0 +1,58 @@
from typing import Callable, Dict, Optional, Union
import torch
from torch.nn import Module
from torch.optim import Optimizer
from colossalai.interface import OptimizerWrapper
def run_fwd(
    model: Module, data: Dict, output_transform_fn: Callable, criterion: Optional[Callable] = None
) -> torch.Tensor:
    """Run a forward pass and reduce the model output to a loss.

    Args:
        model (torch.nn.Module): a PyTorch model, called as ``model(**data)``
        data (Dict): keyword arguments forwarded to the model
        output_transform_fn (Callable): converts the raw model output into a dict of outputs
        criterion (Optional[Callable]): maps the transformed outputs to a loss. If ``None``,
            the sum of the first value of the transformed outputs is used instead.

    Returns:
        torch.Tensor: loss of fwd
    """
    outputs = output_transform_fn(model(**data))
    # Compare against None instead of relying on truthiness: a callable that
    # defines __len__/__bool__ could be falsy and would otherwise be skipped.
    if criterion is not None:
        return criterion(outputs)
    return next(iter(outputs.values())).sum()
def run_fwd_bwd(
    model: Module,
    data: Dict,
    output_transform_fn: Callable,
    criterion: Optional[Callable] = None,
    optimizer: Optional[Union[Optimizer, OptimizerWrapper]] = None,
) -> torch.Tensor:
    """Run a forward pass followed by a backward pass.

    Args:
        model (torch.nn.Module): a PyTorch model, called as ``model(**data)``
        data (Dict): keyword arguments forwarded to the model
        output_transform_fn (Callable): converts the raw model output into a dict of outputs
        criterion (Optional[Callable]): maps the transformed outputs to a loss
        optimizer (Optional[Union[Optimizer, OptimizerWrapper]]): if it exposes a
            ``backward`` method, that method drives the backward pass; otherwise
            ``loss.backward()`` is called.

    Returns:
        torch.Tensor: loss of fwd
    """
    loss = run_fwd(model, data, output_transform_fn, criterion)
    # A plain torch.optim.Optimizer has no ``backward`` method, so only delegate
    # when the optimizer actually provides one; otherwise use autograd directly.
    backward = getattr(optimizer, "backward", None)
    if backward is not None:
        backward(loss)
    else:
        loss.backward()
    return loss

View File

@@ -359,9 +359,9 @@ output_transform_fn = lambda x: x
# define loss function
loss_fn_for_bert_model = lambda x: torch.nn.functional.mse_loss(
x.last_hidden_state, torch.ones_like(x.last_hidden_state)
x["last_hidden_state"], torch.ones_like(x["last_hidden_state"])
)
loss_fn = lambda x: x.loss
loss_fn = lambda x: x["loss"]
config = transformers.BertConfig(
hidden_size=128,

View File

@@ -35,7 +35,7 @@ def data_gen():
output_transform_fn = lambda x: x
# define loss function
loss_fn_blip2_model = lambda x: x.loss
loss_fn_blip2_model = lambda x: x["loss"]
config = transformers.Blip2Config()
config.vision_config.patch_size = 14

View File

@@ -69,11 +69,11 @@ output_transform_fn = lambda x: x
# define loss function
loss_fn_for_bloom_model = lambda x: torch.nn.functional.mse_loss(
x.last_hidden_state, torch.ones_like(x.last_hidden_state)
x["last_hidden_state"], torch.ones_like(x["last_hidden_state"])
)
loss_fn_for_causal_lm = lambda x: x.loss
loss_fn_for_classification = lambda x: x.loss
loss_fn_for_question_answering = lambda x: x.loss
loss_fn_for_causal_lm = lambda x: x["loss"]
loss_fn_for_classification = lambda x: x["loss"]
loss_fn_for_question_answering = lambda x: x["loss"]
config = transformers.BloomConfig(
n_layer=2, n_head=4, vocab_size=250880, hidden_dropout=0, attention_dropout=0, hidden_size=64, pad_token_id=50256

View File

@@ -30,9 +30,9 @@ output_transform_fn = lambda x: x
# define loss function
loss_fn_for_chatglm_model = lambda x: torch.nn.functional.mse_loss(
x.last_hidden_state, torch.ones_like(x.last_hidden_state)
x["last_hidden_state"], torch.ones_like(x["last_hidden_state"])
)
loss_fn = lambda x: x.loss
loss_fn = lambda x: x["loss"]
config = ChatGLMConfig(
num_layers=2,

View File

@@ -87,13 +87,14 @@ output_transform_fn = lambda x: x
# define loss function
loss_fn_for_gpt2_model = lambda x: torch.nn.functional.mse_loss(
x.last_hidden_state, torch.ones_like(x.last_hidden_state)
x["last_hidden_state"], torch.ones_like(x["last_hidden_state"])
)
loss_fn = lambda x: x.loss
loss_fn = lambda x: x["loss"]
config = transformers.GPT2Config(
n_layer=2,
n_head=4,
n_embd=128,
vocab_size=50258,
attn_pdrop=0,
embd_pdrop=0,

View File

@@ -42,9 +42,9 @@ if HAS_LLAMA:
output_transform_fn = lambda x: x
# function to get the loss
loss_fn = lambda output: output.last_hidden_state.mean()
loss_fn_for_casual_lm = lambda output: output.loss
loss_fn_for_seq_classification = lambda output: output.logits.mean()
loss_fn = lambda output: output["last_hidden_state"].mean()
loss_fn_for_casual_lm = lambda output: output["loss"]
loss_fn_for_seq_classification = lambda output: output["logits"].mean()
config = LlamaConfig(
num_hidden_layers=4,

View File

@@ -45,9 +45,9 @@ def data_gen_for_question_answering():
output_transform_fn = lambda x: x
loss_fn_for_opt_model = lambda x: torch.nn.functional.mse_loss(
x.last_hidden_state, torch.ones_like(x.last_hidden_state)
x["last_hidden_state"], torch.ones_like(x["last_hidden_state"])
)
loss_fn_for_lm = lambda x: x.loss
loss_fn_for_lm = lambda x: x["loss"]
config = transformers.OPTConfig(
hidden_size=128,
num_hidden_layers=2,

View File

@@ -40,7 +40,7 @@ def data_gen():
output_transform_fn = lambda x: x
# define loss function
loss_fn = lambda x: x.iou_scores.mean()
loss_fn = lambda x: x["iou_scores"].mean()
config = transformers.SamConfig()
config.vision_config.num_hidden_layers = 2

View File

@@ -44,9 +44,9 @@ def data_gen_for_t5_model():
output_transform_fn = lambda x: x
# define loss function
loss_fn_for_t5_model = lambda x: x.last_hidden_state.mean()
loss_fn_for_encoder_only = lambda x: x.last_hidden_state.mean()
loss_fn_for_conditional_generation = lambda x: x.loss
loss_fn_for_t5_model = lambda x: x["last_hidden_state"].mean()
loss_fn_for_encoder_only = lambda x: x["last_hidden_state"].mean()
loss_fn_for_conditional_generation = lambda x: x["loss"]
# define model config
config = transformers.T5Config(d_model=128, num_layers=2, dropout_rate=0, decoder_start_token_id=0)

View File

@@ -34,9 +34,9 @@ def data_gen_for_masked_image_modeling():
output_transform_fn = lambda x: x
# function to get the loss
loss_fn_for_vit_model = lambda x: x.pooler_output.mean()
loss_fn_for_image_classification = lambda x: x.logits.mean()
loss_fn_for_masked_image_modeling = lambda x: x.loss
loss_fn_for_vit_model = lambda x: x["pooler_output"].mean()
loss_fn_for_image_classification = lambda x: x["logits"].mean()
loss_fn_for_masked_image_modeling = lambda x: x["loss"]
# register the following models
# transformers.ViTModel,

View File

@@ -53,8 +53,8 @@ def data_gen_for_audio_classification():
output_transform_fn = lambda x: x
# define loss function
loss_fn = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state, torch.ones_like(x.last_hidden_state))
loss_fn_attr = lambda x: x.loss
loss_fn = lambda x: torch.nn.functional.mse_loss(x["last_hidden_state"], torch.ones_like(x["last_hidden_state"]))
loss_fn_attr = lambda x: x["loss"]
config = transformers.WhisperConfig(
classifier_proj_size=256,