[test] merge old components to test to model zoo (#4945)

* [test] add custom models in model zoo

* [test] update legacy test

* [test] update model zoo

* [test] update gemini test

* [test] remove components to test
This commit is contained in:
Hongxin Liu 2023-10-20 10:35:08 +08:00 committed by GitHub
parent 3a41e8304e
commit b8e770c832
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
49 changed files with 461 additions and 914 deletions

View File

@ -9,6 +9,7 @@ from .comparison import (
) )
from .pytest_wrapper import run_on_environment_flag from .pytest_wrapper import run_on_environment_flag
from .utils import ( from .utils import (
DummyDataloader,
clear_cache_before_run, clear_cache_before_run,
free_port, free_port,
parameterize, parameterize,
@ -34,4 +35,5 @@ __all__ = [
"run_on_environment_flag", "run_on_environment_flag",
"check_state_dict_equal", "check_state_dict_equal",
"assert_hf_output_close", "assert_hf_output_close",
"DummyDataloader",
] ]

View File

@ -273,3 +273,24 @@ def clear_cache_before_run():
return _clear_cache return _clear_cache
return _wrap_func return _wrap_func
class DummyDataloader:
def __init__(self, data_gen_fn: Callable, length: int = 10):
self.data_gen_fn = data_gen_fn
self.length = length
self.step = 0
def __iter__(self):
self.step = 0
return self
def __next__(self):
if self.step < self.length:
self.step += 1
return self.data_gen_fn()
else:
raise StopIteration
def __len__(self):
return self.length

View File

@ -1,29 +0,0 @@
from . import (
beit,
bert,
gpt2,
hanging_param_model,
inline_op_model,
nested_model,
repeated_computed_layers,
resnet,
simple_net,
)
from .utils import run_fwd, run_fwd_bwd
from . import albert # isort:skip
__all__ = [
"bert",
"gpt2",
"hanging_param_model",
"inline_op_model",
"nested_model",
"repeated_computed_layers",
"resnet",
"simple_net",
"run_fwd_bwd",
"albert",
"beit",
"run_fwd",
]

View File

@ -1,62 +0,0 @@
import torch
from transformers import AlbertConfig, AlbertForSequenceClassification
from .bert import get_bert_data_loader
from .registry import non_distributed_component_funcs
@non_distributed_component_funcs.register(name="albert")
def get_training_components():
hidden_dim = 8
num_head = 4
sequence_length = 12
num_layer = 2
vocab_size = 32
def bert_model_builder(checkpoint: bool = False):
config = AlbertConfig(
vocab_size=vocab_size,
gradient_checkpointing=checkpoint,
hidden_size=hidden_dim,
intermediate_size=hidden_dim * 4,
num_attention_heads=num_head,
max_position_embeddings=sequence_length,
num_hidden_layers=num_layer,
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
)
print("building AlbertForSequenceClassification model")
# adapting huggingface BertForSequenceClassification for single unittest calling interface
class ModelAdaptor(AlbertForSequenceClassification):
def forward(self, input_ids, labels):
"""
inputs: data, label
outputs: loss
"""
return super().forward(input_ids=input_ids, labels=labels)[0]
model = ModelAdaptor(config)
# if checkpoint and version.parse(transformers.__version__) >= version.parse("4.11.0"):
# model.gradient_checkpointing_enable()
return model
is_distributed = torch.distributed.is_initialized()
trainloader = get_bert_data_loader(
n_class=vocab_size,
batch_size=2,
total_samples=10000,
sequence_length=sequence_length,
is_distributed=is_distributed,
)
testloader = get_bert_data_loader(
n_class=vocab_size,
batch_size=2,
total_samples=10000,
sequence_length=sequence_length,
is_distributed=is_distributed,
)
criterion = None
return bert_model_builder, trainloader, testloader, torch.optim.Adam, criterion

View File

@ -1,44 +0,0 @@
import torch
from timm.models.beit import Beit
from colossalai.utils.cuda import get_current_device
from .registry import non_distributed_component_funcs
from .utils.dummy_data_generator import DummyDataGenerator
class DummyDataLoader(DummyDataGenerator):
img_size = 64
num_channel = 3
num_class = 10
batch_size = 4
def generate(self):
data = torch.randn(
(
DummyDataLoader.batch_size,
DummyDataLoader.num_channel,
DummyDataLoader.img_size,
DummyDataLoader.img_size,
),
device=get_current_device(),
)
label = torch.randint(
low=0, high=DummyDataLoader.num_class, size=(DummyDataLoader.batch_size,), device=get_current_device()
)
return data, label
@non_distributed_component_funcs.register(name="beit")
def get_training_components():
def model_builder(checkpoint=False):
model = Beit(
img_size=DummyDataLoader.img_size, num_classes=DummyDataLoader.num_class, embed_dim=32, depth=2, num_heads=4
)
return model
trainloader = DummyDataLoader()
testloader = DummyDataLoader()
criterion = torch.nn.CrossEntropyLoss()
return model_builder, trainloader, testloader, torch.optim.Adam, criterion

View File

@ -1,88 +0,0 @@
import torch
import transformers
from packaging import version
from torch.utils.data import SequentialSampler
from transformers import BertConfig, BertForSequenceClassification
from .registry import non_distributed_component_funcs
def get_bert_data_loader(
n_class,
batch_size,
total_samples,
sequence_length,
device=torch.device("cpu:0"),
is_distributed=False,
):
train_data = torch.randint(
low=0,
high=n_class,
size=(total_samples, sequence_length),
device=device,
dtype=torch.long,
)
train_label = torch.randint(low=0, high=2, size=(total_samples,), device=device, dtype=torch.long)
train_dataset = torch.utils.data.TensorDataset(train_data, train_label)
if is_distributed:
sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
else:
sampler = SequentialSampler(train_dataset)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, sampler=sampler)
return train_loader
@non_distributed_component_funcs.register(name="bert")
def get_training_components():
hidden_dim = 8
num_head = 4
sequence_length = 12
num_layer = 2
vocab_size = 32
def bert_model_builder(checkpoint: bool = False):
config = BertConfig(
vocab_size=vocab_size,
gradient_checkpointing=checkpoint,
hidden_size=hidden_dim,
intermediate_size=hidden_dim * 4,
num_attention_heads=num_head,
max_position_embeddings=sequence_length,
num_hidden_layers=num_layer,
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
)
# adapting huggingface BertForSequenceClassification for single unittest calling interface
class ModelAdaptor(BertForSequenceClassification):
def forward(self, input_ids, labels):
"""
inputs: data, label
outputs: loss
"""
return super().forward(input_ids=input_ids, labels=labels)[0]
model = ModelAdaptor(config)
if checkpoint and version.parse(transformers.__version__) >= version.parse("4.11.0"):
model.gradient_checkpointing_enable()
return model
is_distributed = torch.distributed.is_initialized()
trainloader = get_bert_data_loader(
n_class=vocab_size,
batch_size=2,
total_samples=10000,
sequence_length=sequence_length,
is_distributed=is_distributed,
)
testloader = get_bert_data_loader(
n_class=vocab_size,
batch_size=2,
total_samples=10000,
sequence_length=sequence_length,
is_distributed=is_distributed,
)
criterion = None
return bert_model_builder, trainloader, testloader, torch.optim.Adam, criterion

View File

@ -1,92 +0,0 @@
import torch
import torch.nn as nn
from transformers import GPT2Config, GPT2LMHeadModel
from colossalai.utils.cuda import get_current_device
from .registry import non_distributed_component_funcs
from .utils.dummy_data_generator import DummyDataGenerator
class DummyDataLoader(DummyDataGenerator):
vocab_size = 128
batch_size = 4
seq_len = 64
def generate(self):
input_ids = torch.randint(
0,
DummyDataLoader.vocab_size,
(DummyDataLoader.batch_size, DummyDataLoader.seq_len),
device=get_current_device(),
)
return input_ids, input_ids
class GPTLMModel(nn.Module):
def __init__(
self,
hidden_size=768,
num_layers=12,
num_attention_heads=12,
max_seq_len=1024,
vocab_size=50304,
checkpoint=False,
):
super().__init__()
self.checkpoint = checkpoint
self.model = GPT2LMHeadModel(
GPT2Config(
n_embd=hidden_size,
n_layer=num_layers,
n_head=num_attention_heads,
n_positions=max_seq_len,
n_ctx=max_seq_len,
vocab_size=vocab_size,
resid_pdrop=0.0,
embd_pdrop=0.0,
attn_pdrop=0.0,
)
)
if checkpoint:
self.model.gradient_checkpointing_enable()
def forward(self, input_ids):
# Only return lm_logits
attention_mask = torch.ones_like(input_ids)
return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0]
def gpt2_micro(checkpoint=True):
return GPTLMModel(
checkpoint=checkpoint, hidden_size=32, num_layers=2, num_attention_heads=4, max_seq_len=64, vocab_size=128
)
def gpt2_s(checkpoint=True):
return GPTLMModel(checkpoint=checkpoint)
def gpt2_m(checkpoint=True):
return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint)
class GPTLMLoss(nn.Module):
def __init__(self):
super().__init__()
self.loss_fn = nn.CrossEntropyLoss()
def forward(self, logits, labels):
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
@non_distributed_component_funcs.register(name="gpt2")
def get_training_components():
trainloader = DummyDataLoader()
testloader = DummyDataLoader()
criterion = GPTLMLoss()
return gpt2_micro, trainloader, testloader, torch.optim.Adam, criterion

View File

@ -1,48 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from colossalai.legacy.nn import CheckpointModule
from .registry import non_distributed_component_funcs
from .utils.dummy_data_generator import DummyDataGenerator
class HangingParamModule(CheckpointModule):
"""
Hanging Parameter: a parameter dose not belong to a leaf Module.
It has subordinate nn.modules and a nn.Parameter.
"""
def __init__(self, checkpoint=False) -> None:
super().__init__(checkpoint=checkpoint)
self.proj1 = nn.Linear(4, 8)
self.weight = nn.Parameter(torch.randn(8, 8))
self.proj2 = nn.Linear(8, 4)
def forward(self, x):
x = self.proj1(x)
x = F.linear(x, self.weight)
x = self.proj2(x)
return x
class DummyDataLoader(DummyDataGenerator):
def generate(self):
data = torch.rand(16, 4)
label = torch.randint(low=0, high=2, size=(16,))
return data, label
@non_distributed_component_funcs.register(name="hanging_param_model")
def get_training_components():
def model_builder(checkpoint=False):
return HangingParamModule(checkpoint)
trainloader = DummyDataLoader()
testloader = DummyDataLoader()
criterion = torch.nn.CrossEntropyLoss()
from colossalai.nn.optimizer import HybridAdam
return model_builder, trainloader, testloader, HybridAdam, criterion

View File

@ -1,49 +0,0 @@
import torch
import torch.nn as nn
from colossalai.legacy.nn import CheckpointModule
from .registry import non_distributed_component_funcs
from .utils.dummy_data_generator import DummyDataGenerator
class InlineOpModule(CheckpointModule):
"""
a module with inline Ops
"""
def __init__(self, checkpoint=False) -> None:
super().__init__(checkpoint=checkpoint)
self.proj1 = nn.Linear(4, 8)
self.proj2 = nn.Linear(8, 8)
def forward(self, x):
x = self.proj1(x)
# inline add_
x.add_(10)
x = self.proj2(x)
# inline relu_
x = torch.relu_(x)
x = self.proj2(x)
return x
class DummyDataLoader(DummyDataGenerator):
def generate(self):
data = torch.rand(16, 4)
label = torch.randint(low=0, high=2, size=(16,))
return data, label
@non_distributed_component_funcs.register(name="inline_op_model")
def get_training_components():
def model_builder(checkpoint=False):
return InlineOpModule(checkpoint)
trainloader = DummyDataLoader()
testloader = DummyDataLoader()
criterion = torch.nn.CrossEntropyLoss()
from colossalai.nn.optimizer import HybridAdam
return model_builder, trainloader, testloader, HybridAdam, criterion

View File

@ -1,38 +0,0 @@
#!/usr/bin/env python
class Registry:
def __init__(self):
self._registry = dict()
def register(self, name):
assert name not in self._registry
def _register(callable_):
self._registry[name] = callable_
return _register
def get_callable(self, name: str):
return self._registry[name]
def __iter__(self):
self._idx = 0
self._len = len(self._registry)
self._names = list(self._registry.keys())
return self
def __next__(self):
if self._idx < self._len:
key = self._names[self._idx]
callable_ = self._registry[key]
self._idx += 1
return callable_
else:
raise StopIteration
non_distributed_component_funcs = Registry()
model_parallel_component_funcs = Registry()
__all__ = ["non_distributed_component_funcs", "model_parallel_component_funcs"]

View File

@ -1,47 +0,0 @@
#!/usr/bin/env python
import torch
import torch.nn as nn
from colossalai.legacy.nn import CheckpointModule
from .registry import non_distributed_component_funcs
from .utils.dummy_data_generator import DummyDataGenerator
class NetWithRepeatedlyComputedLayers(CheckpointModule):
"""
This model is to test with layers which go through forward pass multiple times.
In this model, the fc1 and fc2 call forward twice
"""
def __init__(self, checkpoint=False) -> None:
super().__init__(checkpoint=checkpoint)
self.fc1 = nn.Linear(5, 5)
self.fc2 = nn.Linear(5, 5)
self.fc3 = nn.Linear(5, 2)
self.layers = [self.fc1, self.fc2, self.fc1, self.fc2, self.fc3]
def forward(self, x):
for layer in self.layers:
x = layer(x)
return x
class DummyDataLoader(DummyDataGenerator):
def generate(self):
data = torch.rand(16, 5)
label = torch.randint(low=0, high=2, size=(16,))
return data, label
@non_distributed_component_funcs.register(name="repeated_computed_layers")
def get_training_components():
def model_builder(checkpoint=False):
return NetWithRepeatedlyComputedLayers(checkpoint)
trainloader = DummyDataLoader()
testloader = DummyDataLoader()
criterion = torch.nn.CrossEntropyLoss()
return model_builder, trainloader, testloader, torch.optim.Adam, criterion

View File

@ -1,37 +0,0 @@
import os
from pathlib import Path
import torch
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18
from torchvision.transforms import transforms
from colossalai.legacy.utils import get_dataloader
from .registry import non_distributed_component_funcs
def get_cifar10_dataloader(train):
# build dataloaders
dataset = CIFAR10(
root=Path(os.environ["DATA"]),
download=True,
train=train,
transform=transforms.Compose(
[transforms.ToTensor(), transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))]
),
)
dataloader = get_dataloader(dataset=dataset, shuffle=True, batch_size=16, drop_last=True)
return dataloader
@non_distributed_component_funcs.register(name="resnet18")
def get_resnet_training_components():
def model_builder(checkpoint=False):
return resnet18(num_classes=10)
trainloader = get_cifar10_dataloader(train=True)
testloader = get_cifar10_dataloader(train=False)
criterion = torch.nn.CrossEntropyLoss()
return model_builder, trainloader, testloader, torch.optim.Adam, criterion

View File

@ -1,53 +0,0 @@
import torch
import torch.nn as nn
from colossalai.legacy.nn import CheckpointModule
from colossalai.utils.cuda import get_current_device
from .registry import non_distributed_component_funcs
from .utils.dummy_data_generator import DummyDataGenerator
class SimpleNet(CheckpointModule):
"""
In this no-leaf module, it has subordinate nn.modules and a nn.Parameter.
"""
def __init__(self, checkpoint=False) -> None:
super().__init__(checkpoint=checkpoint)
self.embed = nn.Embedding(20, 4)
self.proj1 = nn.Linear(4, 8)
self.ln1 = nn.LayerNorm(8)
self.proj2 = nn.Linear(8, 4)
self.ln2 = nn.LayerNorm(4)
self.classifier = nn.Linear(4, 4)
def forward(self, x):
x = self.embed(x)
x = self.proj1(x)
x = self.ln1(x)
x = self.proj2(x)
x = self.ln2(x)
x = self.classifier(x)
return x
class DummyDataLoader(DummyDataGenerator):
def generate(self):
data = torch.randint(low=0, high=20, size=(16,), device=get_current_device())
label = torch.randint(low=0, high=2, size=(16,), device=get_current_device())
return data, label
@non_distributed_component_funcs.register(name="simple_net")
def get_training_components():
def model_builder(checkpoint=False):
return SimpleNet(checkpoint)
trainloader = DummyDataLoader()
testloader = DummyDataLoader()
criterion = torch.nn.CrossEntropyLoss()
from colossalai.nn.optimizer import HybridAdam
return model_builder, trainloader, testloader, HybridAdam, criterion

View File

@ -1,2 +0,0 @@
from .dummy_data_generator import DummyDataGenerator
from .executor import run_fwd, run_fwd_bwd

View File

@ -1,24 +0,0 @@
from abc import ABC, abstractmethod
class DummyDataGenerator(ABC):
def __init__(self, length=10):
self.length = length
@abstractmethod
def generate(self):
pass
def __iter__(self):
self.step = 0
return self
def __next__(self):
if self.step < self.length:
self.step += 1
return self.generate()
else:
raise StopIteration
def __len__(self):
return self.length

View File

@ -1,4 +1,5 @@
from . import diffusers, timm, torchaudio, torchrec, torchvision, transformers from . import custom, diffusers, timm, torchaudio, torchrec, torchvision, transformers
from .executor import run_fwd, run_fwd_bwd
from .registry import model_zoo from .registry import model_zoo
__all__ = ["model_zoo"] __all__ = ["model_zoo", "run_fwd", "run_fwd_bwd"]

View File

@ -0,0 +1,4 @@
from .hanging_param_model import *
from .nested_model import *
from .repeated_computed_layers import *
from .simple_net import *

View File

@ -0,0 +1,26 @@
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
class CheckpointModule(nn.Module):
def __init__(self, checkpoint: bool = False):
super().__init__()
self.checkpoint = checkpoint
self._use_checkpoint = checkpoint
def _forward(self, *args, **kwargs):
raise NotImplementedError("CheckpointModule should implement _forward method instead of origin forward")
def forward(self, *args, **kwargs):
if self._use_checkpoint:
return checkpoint(self._forward, *args, **kwargs)
else:
return self._forward(*args, **kwargs)
def train(self, mode: bool = True):
self._use_checkpoint = self.checkpoint
return super().train(mode=mode)
def eval(self):
self._use_checkpoint = False
return super().eval()

View File

@ -0,0 +1,48 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..registry import model_zoo
from .base import CheckpointModule
class HangingParamModule(CheckpointModule):
"""
Hanging Parameter: a parameter dose not belong to a leaf Module.
It has subordinate nn.modules and a nn.Parameter.
"""
def __init__(self, checkpoint=False) -> None:
super().__init__(checkpoint=checkpoint)
self.proj1 = nn.Linear(4, 8)
self.weight = nn.Parameter(torch.randn(8, 8))
self.proj2 = nn.Linear(8, 4)
def forward(self, x):
x = self.proj1(x)
x = F.linear(x, self.weight)
x = self.proj2(x)
return x
def data_gen():
return dict(x=torch.rand(16, 4))
def loss_fn(x):
outputs = x["x"]
label = torch.randint(low=0, high=2, size=(16,), device=outputs.device)
return F.cross_entropy(x["x"], label)
def output_transform(x: torch.Tensor):
return dict(x=x)
model_zoo.register(
name="custom_hanging_param_model",
model_fn=HangingParamModule,
data_gen_fn=data_gen,
output_transform_fn=output_transform,
loss_fn=loss_fn,
)

View File

@ -2,10 +2,8 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from colossalai.legacy.nn import CheckpointModule from ..registry import model_zoo
from .base import CheckpointModule
from .registry import non_distributed_component_funcs
from .utils import DummyDataGenerator
class SubNet(nn.Module): class SubNet(nn.Module):
@ -32,20 +30,24 @@ class NestedNet(CheckpointModule):
return x return x
class DummyDataLoader(DummyDataGenerator): def data_gen():
def generate(self): return dict(x=torch.rand(16, 5))
data = torch.rand(16, 5)
label = torch.randint(low=0, high=2, size=(16,))
return data, label
@non_distributed_component_funcs.register(name="nested_model") def loss_fn(x):
def get_training_components(): outputs = x["x"]
def model_builder(checkpoint=False): label = torch.randint(low=0, high=2, size=(16,), device=outputs.device)
return NestedNet(checkpoint) return F.cross_entropy(x["x"], label)
trainloader = DummyDataLoader()
testloader = DummyDataLoader()
criterion = torch.nn.CrossEntropyLoss() def output_transform(x: torch.Tensor):
return model_builder, trainloader, testloader, torch.optim.Adam, criterion return dict(x=x)
model_zoo.register(
name="custom_nested_model",
model_fn=NestedNet,
data_gen_fn=data_gen,
output_transform_fn=output_transform,
loss_fn=loss_fn,
)

View File

@ -0,0 +1,48 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..registry import model_zoo
from .base import CheckpointModule
class NetWithRepeatedlyComputedLayers(CheckpointModule):
"""
This model is to test with layers which go through forward pass multiple times.
In this model, the fc1 and fc2 call forward twice
"""
def __init__(self, checkpoint=False) -> None:
super().__init__(checkpoint=checkpoint)
self.fc1 = nn.Linear(5, 5)
self.fc2 = nn.Linear(5, 5)
self.fc3 = nn.Linear(5, 2)
self.layers = [self.fc1, self.fc2, self.fc1, self.fc2, self.fc3]
def forward(self, x):
for layer in self.layers:
x = layer(x)
return x
def data_gen():
return dict(x=torch.rand(16, 5))
def loss_fn(x):
outputs = x["x"]
label = torch.randint(low=0, high=2, size=(16,), device=outputs.device)
return F.cross_entropy(x["x"], label)
def output_transform(x: torch.Tensor):
return dict(x=x)
model_zoo.register(
name="custom_repeated_computed_layers",
model_fn=NetWithRepeatedlyComputedLayers,
data_gen_fn=data_gen,
output_transform_fn=output_transform,
loss_fn=loss_fn,
)

View File

@ -0,0 +1,53 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..registry import model_zoo
from .base import CheckpointModule
class SimpleNet(CheckpointModule):
"""
In this no-leaf module, it has subordinate nn.modules and a nn.Parameter.
"""
def __init__(self, checkpoint=False) -> None:
super().__init__(checkpoint=checkpoint)
self.embed = nn.Embedding(20, 4)
self.proj1 = nn.Linear(4, 8)
self.ln1 = nn.LayerNorm(8)
self.proj2 = nn.Linear(8, 4)
self.ln2 = nn.LayerNorm(4)
self.classifier = nn.Linear(4, 4)
def forward(self, x):
x = self.embed(x)
x = self.proj1(x)
x = self.ln1(x)
x = self.proj2(x)
x = self.ln2(x)
x = self.classifier(x)
return x
def data_gen():
return dict(x=torch.randint(low=0, high=20, size=(16,)))
def loss_fn(x):
outputs = x["x"]
label = torch.randint(low=0, high=2, size=(16,), device=outputs.device)
return F.cross_entropy(x["x"], label)
def output_transform(x: torch.Tensor):
return dict(x=x)
model_zoo.register(
name="custom_simple_net",
model_fn=SimpleNet,
data_gen_fn=data_gen,
output_transform_fn=output_transform,
loss_fn=loss_fn,
)

View File

@ -1,7 +1,15 @@
from typing import Callable, Dict, Optional, Union
import torch import torch
from torch.nn import Module
from torch.optim import Optimizer
from colossalai.interface import OptimizerWrapper
def run_fwd(model, data, label, criterion) -> torch.Tensor: def run_fwd(
model: Module, data: Dict, output_transform_fn: Callable, criterion: Optional[Callable] = None
) -> torch.Tensor:
"""run_fwd """run_fwd
run fwd for the model run fwd for the model
@ -14,18 +22,22 @@ def run_fwd(model, data, label, criterion) -> torch.Tensor:
Returns: Returns:
torch.Tensor: loss of fwd torch.Tensor: loss of fwd
""" """
outputs = model(**data)
outputs = output_transform_fn(outputs)
if criterion: if criterion:
y = model(data) loss = criterion(outputs)
y = y.float()
loss = criterion(y, label)
else: else:
loss = model(data, label) loss = next(iter(outputs.values())).sum()
loss = loss.float()
return loss return loss
def run_fwd_bwd(model, data, label, criterion, optimizer=None) -> torch.Tensor: def run_fwd_bwd(
model: Module,
data: Dict,
output_transform_fn: Callable,
criterion: Optional[Callable] = None,
optimizer: Optional[Union[Optimizer, OptimizerWrapper]] = None,
) -> torch.Tensor:
"""run_fwd_bwd """run_fwd_bwd
run fwd and bwd for the model run fwd and bwd for the model
@ -38,7 +50,7 @@ def run_fwd_bwd(model, data, label, criterion, optimizer=None) -> torch.Tensor:
Returns: Returns:
torch.Tensor: loss of fwd torch.Tensor: loss of fwd
""" """
loss = run_fwd(model, data, label, criterion) loss = run_fwd(model, data, output_transform_fn, criterion)
if optimizer: if optimizer:
optimizer.backward(loss) optimizer.backward(loss)
else: else:

View File

@ -359,9 +359,9 @@ output_transform_fn = lambda x: x
# define loss funciton # define loss funciton
loss_fn_for_bert_model = lambda x: torch.nn.functional.mse_loss( loss_fn_for_bert_model = lambda x: torch.nn.functional.mse_loss(
x.last_hidden_state, torch.ones_like(x.last_hidden_state) x["last_hidden_state"], torch.ones_like(x["last_hidden_state"])
) )
loss_fn = lambda x: x.loss loss_fn = lambda x: x["loss"]
config = transformers.BertConfig( config = transformers.BertConfig(
hidden_size=128, hidden_size=128,

View File

@ -35,7 +35,7 @@ def data_gen():
output_transform_fn = lambda x: x output_transform_fn = lambda x: x
# define loss funciton # define loss funciton
loss_fn_blip2_model = lambda x: x.loss loss_fn_blip2_model = lambda x: x["loss"]
config = transformers.Blip2Config() config = transformers.Blip2Config()
config.vision_config.patch_size = 14 config.vision_config.patch_size = 14

View File

@ -69,11 +69,11 @@ output_transform_fn = lambda x: x
# define loss function # define loss function
loss_fn_for_bloom_model = lambda x: torch.nn.functional.mse_loss( loss_fn_for_bloom_model = lambda x: torch.nn.functional.mse_loss(
x.last_hidden_state, torch.ones_like(x.last_hidden_state) x["last_hidden_state"], torch.ones_like(x["last_hidden_state"])
) )
loss_fn_for_causal_lm = lambda x: x.loss loss_fn_for_causal_lm = lambda x: x["loss"]
loss_fn_for_classification = lambda x: x.loss loss_fn_for_classification = lambda x: x["loss"]
loss_fn_for_question_answering = lambda x: x.loss loss_fn_for_question_answering = lambda x: x["loss"]
config = transformers.BloomConfig( config = transformers.BloomConfig(
n_layer=2, n_head=4, vocab_size=250880, hidden_dropout=0, attention_dropout=0, hidden_size=64, pad_token_id=50256 n_layer=2, n_head=4, vocab_size=250880, hidden_dropout=0, attention_dropout=0, hidden_size=64, pad_token_id=50256

View File

@ -30,9 +30,9 @@ output_transform_fn = lambda x: x
# define loss function # define loss function
loss_fn_for_chatglm_model = lambda x: torch.nn.functional.mse_loss( loss_fn_for_chatglm_model = lambda x: torch.nn.functional.mse_loss(
x.last_hidden_state, torch.ones_like(x.last_hidden_state) x["last_hidden_state"], torch.ones_like(x["last_hidden_state"])
) )
loss_fn = lambda x: x.loss loss_fn = lambda x: x["loss"]
config = ChatGLMConfig( config = ChatGLMConfig(
num_layers=2, num_layers=2,

View File

@ -87,13 +87,14 @@ output_transform_fn = lambda x: x
# define loss function # define loss function
loss_fn_for_gpt2_model = lambda x: torch.nn.functional.mse_loss( loss_fn_for_gpt2_model = lambda x: torch.nn.functional.mse_loss(
x.last_hidden_state, torch.ones_like(x.last_hidden_state) x["last_hidden_state"], torch.ones_like(x["last_hidden_state"])
) )
loss_fn = lambda x: x.loss loss_fn = lambda x: x["loss"]
config = transformers.GPT2Config( config = transformers.GPT2Config(
n_layer=2, n_layer=2,
n_head=4, n_head=4,
n_embd=128,
vocab_size=50258, vocab_size=50258,
attn_pdrop=0, attn_pdrop=0,
embd_pdrop=0, embd_pdrop=0,

View File

@ -42,9 +42,9 @@ if HAS_LLAMA:
output_transform_fn = lambda x: x output_transform_fn = lambda x: x
# function to get the loss # function to get the loss
loss_fn = lambda output: output.last_hidden_state.mean() loss_fn = lambda output: output["last_hidden_state"].mean()
loss_fn_for_casual_lm = lambda output: output.loss loss_fn_for_casual_lm = lambda output: output["loss"]
loss_fn_for_seq_classification = lambda output: output.logits.mean() loss_fn_for_seq_classification = lambda output: output["logits"].mean()
config = LlamaConfig( config = LlamaConfig(
num_hidden_layers=4, num_hidden_layers=4,

View File

@ -45,9 +45,9 @@ def data_gen_for_question_answering():
output_transform_fn = lambda x: x output_transform_fn = lambda x: x
loss_fn_for_opt_model = lambda x: torch.nn.functional.mse_loss( loss_fn_for_opt_model = lambda x: torch.nn.functional.mse_loss(
x.last_hidden_state, torch.ones_like(x.last_hidden_state) x["last_hidden_state"], torch.ones_like(x["last_hidden_state"])
) )
loss_fn_for_lm = lambda x: x.loss loss_fn_for_lm = lambda x: x["loss"]
config = transformers.OPTConfig( config = transformers.OPTConfig(
hidden_size=128, hidden_size=128,
num_hidden_layers=2, num_hidden_layers=2,

View File

@ -40,7 +40,7 @@ def data_gen():
output_transform_fn = lambda x: x output_transform_fn = lambda x: x
# define loss funciton # define loss funciton
loss_fn = lambda x: x.iou_scores.mean() loss_fn = lambda x: x["iou_scores"].mean()
config = transformers.SamConfig() config = transformers.SamConfig()
config.vision_config.num_hidden_layers = 2 config.vision_config.num_hidden_layers = 2

View File

@ -44,9 +44,9 @@ def data_gen_for_t5_model():
output_transform_fn = lambda x: x output_transform_fn = lambda x: x
# define loss function # define loss function
loss_fn_for_t5_model = lambda x: x.last_hidden_state.mean() loss_fn_for_t5_model = lambda x: x["last_hidden_state"].mean()
loss_fn_for_encoder_only = lambda x: x.last_hidden_state.mean() loss_fn_for_encoder_only = lambda x: x["last_hidden_state"].mean()
loss_fn_for_conditional_generation = lambda x: x.loss loss_fn_for_conditional_generation = lambda x: x["loss"]
# define model config # define model config
config = transformers.T5Config(d_model=128, num_layers=2, dropout_rate=0, decoder_start_token_id=0) config = transformers.T5Config(d_model=128, num_layers=2, dropout_rate=0, decoder_start_token_id=0)

View File

@ -34,9 +34,9 @@ def data_gen_for_masked_image_modeling():
output_transform_fn = lambda x: x output_transform_fn = lambda x: x
# function to get the loss # function to get the loss
loss_fn_for_vit_model = lambda x: x.pooler_output.mean() loss_fn_for_vit_model = lambda x: x["pooler_output"].mean()
loss_fn_for_image_classification = lambda x: x.logits.mean() loss_fn_for_image_classification = lambda x: x["logits"].mean()
loss_fn_for_masked_image_modeling = lambda x: x.loss loss_fn_for_masked_image_modeling = lambda x: x["loss"]
# register the following models # register the following models
# transformers.ViTModel, # transformers.ViTModel,

View File

@ -53,8 +53,8 @@ def data_gen_for_audio_classification():
output_transform_fn = lambda x: x output_transform_fn = lambda x: x
# define loss funciton # define loss funciton
loss_fn = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state, torch.ones_like(x.last_hidden_state)) loss_fn = lambda x: torch.nn.functional.mse_loss(x["last_hidden_state"], torch.ones_like(x["last_hidden_state"]))
loss_fn_attr = lambda x: x.loss loss_fn_attr = lambda x: x["loss"]
config = transformers.WhisperConfig( config = transformers.WhisperConfig(
classifier_proj_size=256, classifier_proj_size=256,

View File

@ -6,7 +6,7 @@ import torch
import colossalai import colossalai
from colossalai.legacy.amp import convert_to_apex_amp, convert_to_naive_amp from colossalai.legacy.amp import convert_to_apex_amp, convert_to_naive_amp
from colossalai.testing import assert_close_loose, clear_cache_before_run, rerun_if_address_is_in_use, spawn from colossalai.testing import assert_close_loose, clear_cache_before_run, rerun_if_address_is_in_use, spawn
from tests.components_to_test.registry import non_distributed_component_funcs from tests.kit.model_zoo import model_zoo
def check_equal(a, b): def check_equal(a, b):
@ -25,13 +25,12 @@ def run_naive_amp():
torch.backends.cudnn.deterministic = True torch.backends.cudnn.deterministic = True
# create layer # create layer
test_models = ["repeated_computed_layers", "nested_model", "resnet18"] test_models = ["custom_repeated_computed_layers", "custom_nested_model", "torchvision_resnet18"]
for test_name in test_models: for test_name in test_models:
get_component_func = non_distributed_component_funcs.get_callable(test_name) model_builder, data_gen_fn, *_ = next(iter(model_zoo.get_sub_registry(test_name).values()))
model_builder, train_dataloader, _, optim_class, _ = get_component_func()
# create model # create model
naive_amp_model = model_builder(checkpoint=True).cuda() naive_amp_model = model_builder().cuda()
apex_amp_model = copy.deepcopy(naive_amp_model) apex_amp_model = copy.deepcopy(naive_amp_model)
# create optimizer # create optimizer
@ -48,13 +47,12 @@ def run_naive_amp():
apex_amp_model, apex_amp_optimizer = convert_to_apex_amp(apex_amp_model, apex_amp_optimizer, apex_amp_config) apex_amp_model, apex_amp_optimizer = convert_to_apex_amp(apex_amp_model, apex_amp_optimizer, apex_amp_config)
# create data # create data
data_iter = iter(train_dataloader) data = data_gen_fn()
data, label = next(data_iter) data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
data = data.cuda()
# forward pass # forward pass
naive_amp_output = naive_amp_model(data) naive_amp_output = naive_amp_model(**data)
apex_amp_output = apex_amp_model(data) apex_amp_output = apex_amp_model(**data)
assert_close_loose(naive_amp_output, apex_amp_output) assert_close_loose(naive_amp_output, apex_amp_output)
# backward # backward

View File

@ -6,7 +6,7 @@ import torch
import colossalai import colossalai
from colossalai.legacy.amp import convert_to_apex_amp, convert_to_torch_amp from colossalai.legacy.amp import convert_to_apex_amp, convert_to_torch_amp
from colossalai.testing import assert_close_loose, clear_cache_before_run, rerun_if_address_is_in_use, spawn from colossalai.testing import assert_close_loose, clear_cache_before_run, rerun_if_address_is_in_use, spawn
from tests.components_to_test.registry import non_distributed_component_funcs from tests.kit.model_zoo import model_zoo
def run_torch_amp(): def run_torch_amp():
@ -18,13 +18,12 @@ def run_torch_amp():
torch.backends.cudnn.deterministic = True torch.backends.cudnn.deterministic = True
# create layer # create layer
test_models = ["resnet18", "simple_net"] test_models = ["torchvision_resnet18", "custom_simple_net"]
for test_name in test_models: for test_name in test_models:
get_component_func = non_distributed_component_funcs.get_callable(test_name) model_builder, data_gen_fn, *_ = next(iter(model_zoo.get_sub_registry(test_name).values()))
model_builder, train_dataloader, _, optim_class, _ = get_component_func()
# create model # create model
torch_amp_model = model_builder(checkpoint=True).cuda() torch_amp_model = model_builder().cuda()
apex_amp_model = copy.deepcopy(torch_amp_model) apex_amp_model = copy.deepcopy(torch_amp_model)
# create optimizer # create optimizer
@ -41,13 +40,12 @@ def run_torch_amp():
apex_amp_model, apex_amp_optimizer = convert_to_apex_amp(apex_amp_model, apex_amp_optimizer, apex_amp_config) apex_amp_model, apex_amp_optimizer = convert_to_apex_amp(apex_amp_model, apex_amp_optimizer, apex_amp_config)
# create data # create data
data_iter = iter(train_dataloader) data = data_gen_fn()
data, label = next(data_iter) data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
data = data.cuda()
# forward pass # forward pass
torch_amp_output = torch_amp_model(data) torch_amp_output = torch_amp_model(**data)
apex_amp_output = apex_amp_model(data) apex_amp_output = apex_amp_model(**data)
assert_close_loose(torch_amp_output, apex_amp_output) assert_close_loose(torch_amp_output, apex_amp_output)
for torch_amp_param, apex_amp_param in zip(torch_amp_model.parameters(), apex_amp_model.parameters()): for torch_amp_param, apex_amp_param in zip(torch_amp_model.parameters(), apex_amp_model.parameters()):

View File

@ -1,10 +1,11 @@
import pytest import pytest
import torch
import colossalai import colossalai
from colossalai.legacy.amp import AMP_TYPE from colossalai.legacy.amp import AMP_TYPE
from colossalai.legacy.core import global_context as gpc from colossalai.legacy.core import global_context as gpc
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
from tests.components_to_test.registry import non_distributed_component_funcs from tests.kit.model_zoo import model_zoo
CONFIG = dict( CONFIG = dict(
parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)), fp16=dict(mode=None), clip_grad_norm=1.0 parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)), fp16=dict(mode=None), clip_grad_norm=1.0
@ -15,29 +16,29 @@ CONFIG = dict(
@parameterize("amp_mode", [AMP_TYPE.APEX, AMP_TYPE.TORCH, AMP_TYPE.NAIVE, None]) @parameterize("amp_mode", [AMP_TYPE.APEX, AMP_TYPE.TORCH, AMP_TYPE.NAIVE, None])
def run_train(model_name, amp_mode): def run_train(model_name, amp_mode):
# FIXME: test bert # FIXME: test bert
get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, data_gen_fn, *_ = next(iter(model_zoo.get_sub_registry(model_name).values()))
train_dataloader = DummyDataloader(data_gen_fn)
criterion = lambda x: x.sum()
gpc.config.fp16["mode"] = amp_mode gpc.config.fp16["mode"] = amp_mode
model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
model = model_builder(checkpoint=False) model = model_builder()
engine, train_dataloader, *args = colossalai.legacy.initialize( engine, train_dataloader, *args = colossalai.legacy.initialize(
model=model, model=model,
optimizer=optimizer_class(model.parameters(), lr=1e-3), optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
criterion=criterion, criterion=criterion,
train_dataloader=train_dataloader, train_dataloader=train_dataloader,
) )
try: try:
engine.train() engine.train()
for data, label in train_dataloader: for data in train_dataloader:
engine.zero_grad() engine.zero_grad()
data = data.cuda() data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
label = label.cuda()
if criterion: if criterion:
output = engine(data) output = engine(**data)
loss = engine.criterion(output, label) loss = engine.criterion(output)
else: else:
loss = engine(data, label) loss = engine(**data)
engine.backward(loss) engine.backward(loss)
engine.step() engine.step()
break break

View File

@ -5,9 +5,9 @@ import colossalai
from colossalai.legacy.amp.amp_type import AMP_TYPE from colossalai.legacy.amp.amp_type import AMP_TYPE
from colossalai.legacy.trainer import Trainer from colossalai.legacy.trainer import Trainer
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import MultiTimer from colossalai.utils import MultiTimer
from tests.components_to_test.registry import non_distributed_component_funcs from tests.kit.model_zoo import model_zoo
BATCH_SIZE = 4 BATCH_SIZE = 4
IMG_SIZE = 32 IMG_SIZE = 32
@ -16,12 +16,14 @@ NUM_EPOCHS = 200
CONFIG = dict(fp16=dict(mode=AMP_TYPE.TORCH)) CONFIG = dict(fp16=dict(mode=AMP_TYPE.TORCH))
@parameterize("model_name", ["repeated_computed_layers", "resnet18", "nested_model"]) @parameterize("model_name", ["custom_repeated_computed_layers", "torchvision_resnet18", "custom_nested_model"])
def run_trainer(model_name): def run_trainer(model_name):
get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, data_gen_fn, *_ = next(iter(model_zoo.get_sub_registry(model_name).values()))
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
model = model_builder() model = model_builder()
optimizer = optimizer_class(model.parameters(), lr=1e-3) optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
train_dataloader = DummyDataloader(data_gen_fn)
test_dataloader = DummyDataloader(data_gen_fn)
criterion = lambda x: x.sum()
engine, train_dataloader, *_ = colossalai.legacy.initialize( engine, train_dataloader, *_ = colossalai.legacy.initialize(
model=model, optimizer=optimizer, criterion=criterion, train_dataloader=train_dataloader model=model, optimizer=optimizer, criterion=criterion, train_dataloader=train_dataloader
) )

View File

@ -2,7 +2,7 @@ import torch
from colossalai.nn.optimizer import CPUAdam, HybridAdam from colossalai.nn.optimizer import CPUAdam, HybridAdam
from colossalai.testing import clear_cache_before_run, parameterize from colossalai.testing import clear_cache_before_run, parameterize
from tests.components_to_test.registry import non_distributed_component_funcs from tests.kit.model_zoo import model_zoo
def move_some_params_to_cuda(model, torch_model): def move_some_params_to_cuda(model, torch_model):
@ -22,8 +22,7 @@ def check_params_equal(model, torch_model):
@parameterize("nvme_offload_dir", ["./offload", None]) @parameterize("nvme_offload_dir", ["./offload", None])
@parameterize("adam_cls", [CPUAdam, HybridAdam]) @parameterize("adam_cls", [CPUAdam, HybridAdam])
def test_nvme_adam(nvme_offload_fraction, nvme_offload_dir, adam_cls): def test_nvme_adam(nvme_offload_fraction, nvme_offload_dir, adam_cls):
get_components_func = non_distributed_component_funcs.get_callable("simple_net") model_builder, data_gen_fn, *_ = next(iter(model_zoo.get_sub_registry("custom_simple_net").values()))
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
model = model_builder() model = model_builder()
torch_model = model_builder() torch_model = model_builder()
move_some_params_to_cuda(model, torch_model) move_some_params_to_cuda(model, torch_model)

View File

@ -12,8 +12,7 @@ from colossalai.utils import set_seed
from colossalai.utils.cuda import get_current_device from colossalai.utils.cuda import get_current_device
from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.components_to_test import run_fwd_bwd from tests.kit.model_zoo import model_zoo, run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs
PLACEMENT_CONFIGS = [ PLACEMENT_CONFIGS = [
{"placement_policy": "static", "shard_param_frac": 0.0}, # zero2 {"placement_policy": "static", "shard_param_frac": 0.0}, # zero2
@ -38,7 +37,7 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
@parameterize("placement_config", PLACEMENT_CONFIGS) @parameterize("placement_config", PLACEMENT_CONFIGS)
@parameterize("keep_gather", [False, True]) @parameterize("keep_gather", [False, True])
@parameterize("model_name", ["gpt2", "bert"]) @parameterize("model_name", ["transformers_gpt_lm"])
@parameterize("use_grad_checkpoint", [False, True]) @parameterize("use_grad_checkpoint", [False, True])
@parameterize("master_weights", [False, True]) @parameterize("master_weights", [False, True])
def exam_gpt_fwd_bwd( def exam_gpt_fwd_bwd(
@ -49,17 +48,22 @@ def exam_gpt_fwd_bwd(
master_weights: bool = True, master_weights: bool = True,
): ):
init_device = get_current_device() init_device = get_current_device()
get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() iter(model_zoo.get_sub_registry(model_name).values())
)
set_seed(42) set_seed(42)
model = model_builder(use_grad_checkpoint) model = model_builder()
set_seed(42) set_seed(42)
torch_model = model_builder(use_grad_checkpoint).cuda() torch_model = model_builder().cuda()
for torch_p, p in zip(torch_model.parameters(), model.parameters()): for torch_p, p in zip(torch_model.parameters(), model.parameters()):
torch_p.data.copy_(p.data) torch_p.data.copy_(p.data)
if use_grad_checkpoint:
model.gradient_checkpointing_enable()
torch_model.gradient_checkpointing_enable()
world_size = torch.distributed.get_world_size() world_size = torch.distributed.get_world_size()
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100) config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
config_dict[world_size]["chunk_size"] = 5000 config_dict[world_size]["chunk_size"] = 5000
@ -77,25 +81,22 @@ def exam_gpt_fwd_bwd(
torch_model = DDP(torch_model, device_ids=[rank]) torch_model = DDP(torch_model, device_ids=[rank])
set_seed(rank) set_seed(rank)
for i, (input_ids, label) in enumerate(train_dataloader):
# you can only test a single fwd + bwd.
# after bwd param is grad for Gemini, due to the chunk reuse optimization.
if i > 0:
break
input_ids, label = input_ids.cuda(), label.cuda()
torch_optim.zero_grad() data = data_gen_fn()
zero_optim.zero_grad() data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
# set random seed is same as torch_model.eval() torch_optim.zero_grad()
set_seed(42) zero_optim.zero_grad()
torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim)
set_seed(42)
loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim)
assert torch.equal(torch_loss, loss) # set random seed is same as torch_model.eval()
set_seed(42)
torch_loss = run_fwd_bwd(torch_model, data, output_transform_fn, loss_fn, optimizer=torch_optim)
set_seed(42)
loss = run_fwd_bwd(model, data, output_transform_fn, loss_fn, optimizer=zero_optim)
check_grad(model, torch_model) assert_close(torch_loss.float(), loss.float())
check_grad(model, torch_model)
def run_dist(rank, world_size, port): def run_dist(rank, world_size, port):

View File

@ -3,38 +3,34 @@ import torch
import torch.distributed as dist import torch.distributed as dist
import colossalai import colossalai
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import set_seed from colossalai.utils import set_seed
from colossalai.zero import GeminiDDP from colossalai.zero import GeminiDDP
from colossalai.zero.gemini.chunk import search_chunk_configuration from colossalai.zero.gemini.chunk import search_chunk_configuration
from colossalai.zero.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer from colossalai.zero.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer
from tests.components_to_test import run_fwd_bwd from tests.kit.model_zoo import model_zoo, run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs
# run gemini use the runtime memory tracer # run gemini use the runtime memory tracer
@parameterize("placement_policy", ["auto"]) @parameterize("placement_policy", ["auto"])
@parameterize("keep_gather", [False]) @parameterize("keep_gather", [False])
@parameterize("model_name", ["repeated_computed_layers", "bert", "albert", "gpt2"]) @parameterize("model_name", ["transformers_bert_for_sequence_classification"])
@parameterize("use_grad_checkpoint", [False, True]) @parameterize("use_grad_checkpoint", [False, True])
def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_checkpoint: bool = False): def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_checkpoint: bool = False):
set_seed(42) set_seed(42)
get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, data_gen_fn, output_transform_fn, *_ = next(iter(model_zoo.get_sub_registry(model_name).values()))
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
model = model_builder(use_grad_checkpoint).cuda() model = model_builder().cuda()
if use_grad_checkpoint:
model.gradient_checkpointing_enable()
print(f"model_name {model_name}") print(f"model_name {model_name}")
runtime_mem_tracer = RuntimeMemTracer(model)
for i, (input_ids, label) in enumerate(train_dataloader):
if i > 0:
break
input_ids, label = input_ids.cuda(), label.cuda()
# mem tracing runtime_mem_tracer = RuntimeMemTracer(model)
if i == 0: data = data_gen_fn()
run_fwd_bwd(runtime_mem_tracer, input_ids, label, criterion, runtime_mem_tracer) data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
run_fwd_bwd(runtime_mem_tracer, data, output_transform_fn, optimizer=runtime_mem_tracer)
memstats = runtime_mem_tracer.memstats() memstats = runtime_mem_tracer.memstats()
runtime_tracer_non_model_data = runtime_mem_tracer._memstats._non_model_data_cuda_list runtime_tracer_non_model_data = runtime_mem_tracer._memstats._non_model_data_cuda_list
print("runtime tracer non model data points: ", len(runtime_tracer_non_model_data)) print("runtime tracer non model data points: ", len(runtime_tracer_non_model_data))
@ -62,16 +58,17 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_
) )
set_seed(dist.get_rank()) set_seed(dist.get_rank())
for i, (input_ids, label) in enumerate(train_dataloader): train_dataloader = DummyDataloader(data_gen_fn)
for i, data in enumerate(train_dataloader):
# you can only test a single fwd + bwd. # you can only test a single fwd + bwd.
# after bwd param is grad for Gemini, due to the chunk reuse optimization. # after bwd param is grad for Gemini, due to the chunk reuse optimization.
# print(f'iteration {i}') # print(f'iteration {i}')
if i > 4: if i > 4:
break break
input_ids, label = input_ids.cuda(), label.cuda() data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
set_seed(42) set_seed(42)
run_fwd_bwd(model, input_ids, label, criterion, model) run_fwd_bwd(model, data, output_transform_fn, optimizer=model)
gemini_non_model_data = model.gemini_manager._mem_stats_collector._memstats.non_model_data_list("cuda") gemini_non_model_data = model.gemini_manager._mem_stats_collector._memstats.non_model_data_list("cuda")

View File

@ -7,13 +7,12 @@ from torch.testing import assert_close
import colossalai import colossalai
from colossalai.nn.optimizer import HybridAdam from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import set_seed from colossalai.utils import set_seed
from colossalai.utils.cuda import get_current_device from colossalai.utils.cuda import get_current_device
from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.components_to_test import run_fwd from tests.kit.model_zoo import model_zoo, run_fwd
from tests.components_to_test.registry import non_distributed_component_funcs
PLACEMENT_CONFIGS = [ PLACEMENT_CONFIGS = [
{"placement_policy": "static", "shard_param_frac": 0.0}, # zero2 {"placement_policy": "static", "shard_param_frac": 0.0}, # zero2
@ -38,7 +37,7 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
# Compare gradients. # Compare gradients.
for p0, p1 in zip(model.parameters(), torch_model.parameters()): for p0, p1 in zip(model.parameters(), torch_model.parameters()):
assert_close(p0, p1.grad, rtol=1e-3, atol=5e-5) assert_close(p0, p1.grad, rtol=2e-3, atol=2e-2)
# Release gradient chunks and move them to gradient device. # Release gradient chunks and move them to gradient device.
for grad_chunk, device in zip(grad_chunk_list, device_list): for grad_chunk, device in zip(grad_chunk_list, device_list):
@ -48,21 +47,19 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
@parameterize("placement_config", PLACEMENT_CONFIGS) @parameterize("placement_config", PLACEMENT_CONFIGS)
@parameterize("keep_gathered", [False, True]) @parameterize("keep_gathered", [False, True])
@parameterize("model_name", ["gpt2", "bert"]) @parameterize("model_name", ["transformers_gpt_lm"])
@parameterize("use_grad_checkpoint", [False, True])
@parameterize("master_weights", [False, True]) @parameterize("master_weights", [False, True])
def exam_gemini_grad_acc( def exam_gemini_grad_acc(placement_config, keep_gathered: bool, model_name: str, master_weights: bool):
placement_config, keep_gathered: bool, model_name: str, use_grad_checkpoint: bool, master_weights: bool
):
init_device = get_current_device() init_device = get_current_device()
get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
model_builder, train_dataloader, _, _, criterion = get_components_func() iter(model_zoo.get_sub_registry(model_name).values())
)
set_seed(42) set_seed(42)
gemini_model = model_builder(use_grad_checkpoint) gemini_model = model_builder()
set_seed(42) set_seed(42)
torch_model = model_builder(use_grad_checkpoint).cuda() torch_model = model_builder().cuda()
for torch_p, p in zip(torch_model.parameters(), gemini_model.parameters()): for torch_p, p in zip(torch_model.parameters(), gemini_model.parameters()):
torch_p.data.copy_(p.data) torch_p.data.copy_(p.data)
@ -94,22 +91,23 @@ def exam_gemini_grad_acc(
set_seed(rank) set_seed(rank)
accum_iter = 4 accum_iter = 4
for i, (input_ids, label) in enumerate(train_dataloader): train_dataloader = DummyDataloader(data_gen_fn)
for i, data in enumerate(train_dataloader):
delay_unscale = False if (i + 1) % accum_iter == 0 else True delay_unscale = False if (i + 1) % accum_iter == 0 else True
input_ids, label = input_ids.cuda(), label.cuda() data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
set_seed(42 + rank) set_seed(42 + rank)
torch_loss = run_fwd(torch_model, input_ids, label, criterion) torch_loss = run_fwd(torch_model, data, output_transform_fn, loss_fn)
torch_loss = torch_loss / accum_iter torch_loss = torch_loss / accum_iter
with amp.scale_loss(torch_loss, torch_optim, delay_unscale=delay_unscale) as scaled_loss: with amp.scale_loss(torch_loss, torch_optim, delay_unscale=delay_unscale) as scaled_loss:
scaled_loss.backward() scaled_loss.backward()
set_seed(42 + rank) set_seed(42 + rank)
gemini_loss = run_fwd(gemini_model, input_ids, label, criterion) gemini_loss = run_fwd(gemini_model, data, output_transform_fn, loss_fn)
gemini_loss = gemini_loss / accum_iter gemini_loss = gemini_loss / accum_iter
gemini_optim.backward(gemini_loss) gemini_optim.backward(gemini_loss)
assert torch.allclose(torch_loss, gemini_loss, rtol=1e-3, atol=1e-5) assert torch.allclose(torch_loss.float(), gemini_loss.float(), rtol=1e-3, atol=1e-5)
check_grad(gemini_model, torch_model) check_grad(gemini_model, torch_model)

View File

@ -7,12 +7,11 @@ from torch.testing import assert_close
import colossalai import colossalai
from colossalai.legacy.amp import convert_to_apex_amp from colossalai.legacy.amp import convert_to_apex_amp
from colossalai.nn.optimizer import HybridAdam from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import set_seed from colossalai.utils import set_seed
from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.components_to_test import run_fwd_bwd from tests.kit.model_zoo import model_zoo, run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs
PLACEMENT_CONFIGS = [ PLACEMENT_CONFIGS = [
{ {
@ -51,12 +50,13 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module):
@parameterize("placement_config", PLACEMENT_CONFIGS) @parameterize("placement_config", PLACEMENT_CONFIGS)
@parameterize("model_name", ["gpt2"]) @parameterize("model_name", ["transformers_gpt_lm"])
@parameterize("master_weights", [True, False]) @parameterize("master_weights", [True, False])
def exam_grad_clipping(placement_config, model_name: str, master_weights: bool): def exam_grad_clipping(placement_config, model_name: str, master_weights: bool):
set_seed(1912) set_seed(1912)
get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() iter(model_zoo.get_sub_registry(model_name).values())
)
torch_model = model_builder().cuda() torch_model = model_builder().cuda()
amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=32) amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=32)
@ -94,21 +94,17 @@ def exam_grad_clipping(placement_config, model_name: str, master_weights: bool):
torch_model.train() torch_model.train()
set_seed(dist.get_rank() * 3 + 128) set_seed(dist.get_rank() * 3 + 128)
for i, (data, label) in enumerate(train_dataloader): train_dataloader = DummyDataloader(data_gen_fn)
for i, data in enumerate(train_dataloader):
if i > 2: if i > 2:
break break
data = data.cuda() data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
label = label.cuda()
zero_optim.zero_grad() zero_optim.zero_grad()
torch_optim.zero_grad() torch_optim.zero_grad()
torch_loss = run_fwd_bwd(torch_model, data, label, criterion, torch_optim) run_fwd_bwd(torch_model, data, output_transform_fn, loss_fn, optimizer=torch_optim)
loss = run_fwd_bwd(model, data, label, criterion, zero_optim) run_fwd_bwd(model, data, output_transform_fn, loss_fn, optimizer=zero_optim)
# as no master weights leads to error accumulation, we don't check the loss
if master_weights:
assert_close(torch_loss, loss)
import apex.amp as apex_amp import apex.amp as apex_amp

View File

@ -9,13 +9,12 @@ from torch.testing import assert_close
import colossalai import colossalai
from colossalai.legacy.amp import convert_to_apex_amp from colossalai.legacy.amp import convert_to_apex_amp
from colossalai.nn.optimizer import HybridAdam from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import set_seed from colossalai.utils import set_seed
from colossalai.utils.cuda import get_current_device from colossalai.utils.cuda import get_current_device
from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.components_to_test import run_fwd_bwd from tests.kit.model_zoo import model_zoo, run_fwd, run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs
PLACEMENT_CONFIGS = [ PLACEMENT_CONFIGS = [
{"placement_policy": "static", "shard_param_frac": 0.0}, # zero2 {"placement_policy": "static", "shard_param_frac": 0.0}, # zero2
@ -53,12 +52,11 @@ def single_chunk_init(model: torch.nn.Module, placement_config: dict):
@parameterize("placement_config", PLACEMENT_CONFIGS) @parameterize("placement_config", PLACEMENT_CONFIGS)
@parameterize("model_name", ["gpt2"]) @parameterize("model_name", ["transformers_gpt_lm"])
@parameterize("model_init_func", [single_chunk_init, multi_chunk_init]) @parameterize("model_init_func", [single_chunk_init, multi_chunk_init])
def exam_inference(placement_config: dict, model_name: str, model_init_func: Callable): def exam_inference(placement_config: dict, model_name: str, model_init_func: Callable):
set_seed(19360226) set_seed(19360226)
get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, data_gen_fn, output_transform_fn, *_ = next(iter(model_zoo.get_sub_registry(model_name).values()))
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
torch_model = model_builder().cuda() torch_model = model_builder().cuda()
amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=128) amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=128)
@ -79,29 +77,27 @@ def exam_inference(placement_config: dict, model_name: str, model_init_func: Cal
torch_model.eval() torch_model.eval()
set_seed(dist.get_rank() * 3 + 128) set_seed(dist.get_rank() * 3 + 128)
train_dataloader = iter(train_dataloader) train_dataloader = iter(DummyDataloader(data_gen_fn))
def train_iter(): def train_iter():
input_ids, label = next(train_dataloader) data = next(train_dataloader)
input_ids, label = input_ids.cuda(), label.cuda() data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
zero_optim.zero_grad() zero_optim.zero_grad()
torch_optim.zero_grad() torch_optim.zero_grad()
torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim) torch_loss = run_fwd_bwd(torch_model, data, output_transform_fn, optimizer=torch_optim)
loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim) loss = run_fwd_bwd(model, data, output_transform_fn, optimizer=zero_optim)
assert_close(torch_loss, loss, rtol=1e-5, atol=1e-5) assert_close(torch_loss.float(), loss.float(), rtol=1e-5, atol=1e-5)
zero_optim.step() zero_optim.step()
torch_optim.step() torch_optim.step()
check_param(model, torch_model) check_param(model, torch_model)
def inference_iter(): def inference_iter():
input_ids, label = next(train_dataloader) data = next(train_dataloader)
input_ids, label = input_ids.cuda(), label.cuda() data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
with torch.no_grad(): with torch.no_grad():
torch_output = torch_model(input_ids) torch_loss = run_fwd(torch_model, data, output_transform_fn)
torch_loss = criterion(torch_output.float(), label) zero_loss = run_fwd(model, data, output_transform_fn)
zero_output = model(input_ids) assert_close(torch_loss.float(), zero_loss.float(), rtol=1e-5, atol=1e-5)
zero_loss = criterion(zero_output.float(), label)
assert_close(torch_loss, zero_loss)
train_iter() train_iter()
inference_iter() inference_iter()

View File

@ -1,20 +1,18 @@
import pytest import pytest
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from packaging.version import Version
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close from torch.testing import assert_close
import colossalai import colossalai
from colossalai.legacy.amp import convert_to_apex_amp from colossalai.legacy.amp import convert_to_apex_amp
from colossalai.nn.optimizer import HybridAdam from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import set_seed from colossalai.utils import set_seed
from colossalai.utils.cuda import get_current_device from colossalai.utils.cuda import get_current_device
from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.components_to_test import run_fwd_bwd from tests.kit.model_zoo import model_zoo, run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs
PLACEMENT_CONFIGS = [ PLACEMENT_CONFIGS = [
{"placement_policy": "static", "shard_param_frac": 0.0, "offload_optim_frac": 0.0}, # zero2 {"placement_policy": "static", "shard_param_frac": 0.0, "offload_optim_frac": 0.0}, # zero2
@ -32,14 +30,17 @@ PLACEMENT_CONFIGS = [
] ]
# this model is large enough to slice to chunks # this model is large enough to slice to chunks
TEST_MODELS = ["gpt2"] TEST_MODELS = ["transformers_gpt_lm"]
# these models are too small, all parameters in these models are compacted into one chunk # these models are too small, all parameters in these models are compacted into one chunk
EXAMPLE_MODELS = ["albert", "beit", "bert", "hanging_param_model", "nested_model", "repeated_computed_layers"] EXAMPLE_MODELS = [
"transformers_bert_for_sequence_classification",
"custom_hanging_param_model",
"custom_nested_model",
"custom_repeated_computed_layers",
]
# bfloat16 cannot represent them exactly # bfloat16 cannot represent them exactly
BF16_IGNORED_KEYS = [ BF16_IGNORED_KEYS = [
"albert.embeddings.word_embeddings.weight",
"albert.embeddings.position_embeddings.weight",
"masked_bias", "masked_bias",
] ]
@ -55,7 +56,7 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dty
temp_zero_value = zero_dict[key].to(device=value.device) temp_zero_value = zero_dict[key].to(device=value.device)
if dtype is torch.bfloat16 and any(k in key for k in BF16_IGNORED_KEYS): if dtype is torch.bfloat16 and any(k in key for k in BF16_IGNORED_KEYS):
continue continue
rtol, atol = 1e-3, 4e-3 rtol, atol = 2e-3, 6e-3
if dtype is torch.bfloat16: if dtype is torch.bfloat16:
rtol, atol = 4e-3, 8e-3 rtol, atol = 4e-3, 8e-3
# debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value))) # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))
@ -74,8 +75,9 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dty
@parameterize("master_weights", [True, False]) @parameterize("master_weights", [True, False])
def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dtype, master_weights: bool): def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dtype, master_weights: bool):
set_seed(42) set_seed(42)
get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() iter(model_zoo.get_sub_registry(model_name).values())
)
torch_model = model_builder().cuda() torch_model = model_builder().cuda()
# apex no master weights leads to nan, so we don't use it # apex no master weights leads to nan, so we don't use it
@ -104,19 +106,20 @@ def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dt
torch_model.eval() torch_model.eval()
set_seed(dist.get_rank() * 3 + 128) set_seed(dist.get_rank() * 3 + 128)
rtol, atol = 1e-4, 1e-5 rtol, atol = 4e-2, 4e-2
for i, (input_ids, label) in enumerate(train_dataloader): train_dataloader = iter(DummyDataloader(data_gen_fn))
for i, data in enumerate(train_dataloader):
if i > 2: if i > 2:
break break
input_ids, label = input_ids.cuda(), label.cuda() data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
zero_optim.zero_grad() zero_optim.zero_grad()
torch_optim.zero_grad() torch_optim.zero_grad()
torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim) torch_loss = run_fwd_bwd(torch_model, data, output_transform_fn, loss_fn, optimizer=torch_optim)
loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim) loss = run_fwd_bwd(model, data, output_transform_fn, loss_fn, optimizer=zero_optim)
# as no master weights leads to error accumulation, we don't check the loss # as no master weights leads to error accumulation, we don't check the loss
if master_weights: if master_weights:
assert_close(torch_loss, loss, rtol=rtol, atol=atol) assert_close(torch_loss.float(), loss.float(), rtol=rtol, atol=atol)
zero_optim.step() zero_optim.step()
torch_optim.step() torch_optim.step()
@ -125,13 +128,14 @@ def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dt
check_param(model, torch_model, mixed_precision) check_param(model, torch_model, mixed_precision)
@parameterize("placement_config", PLACEMENT_CONFIGS) @parameterize("placement_config", [PLACEMENT_CONFIGS[3]])
@parameterize("model_name", EXAMPLE_MODELS) @parameterize("model_name", EXAMPLE_MODELS)
@parameterize("mixed_precision", [torch.half, torch.bfloat16]) @parameterize("mixed_precision", [torch.half])
def exam_tiny_example(placement_config, model_name: str, mixed_precision: torch.dtype): def exam_tiny_example(placement_config, model_name: str, mixed_precision: torch.dtype):
set_seed(2008) set_seed(2008)
get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() iter(model_zoo.get_sub_registry(model_name).values())
)
torch_model = model_builder().cuda() torch_model = model_builder().cuda()
amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=2) amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=2)
@ -159,26 +163,19 @@ def exam_tiny_example(placement_config, model_name: str, mixed_precision: torch.
torch_model.eval() torch_model.eval()
set_seed(dist.get_rank() * 3 + 128) set_seed(dist.get_rank() * 3 + 128)
rtol, atol = 1.5e-6, 2e-5
if mixed_precision is torch.bfloat16:
rtol, atol = 2e-3, 2e-3
elif Version(torch.__version__) >= Version("2.0.0"):
rtol, atol = 4e-5, 3e-5
for i, (input_ids, label) in enumerate(train_dataloader): train_dataloader = DummyDataloader(data_gen_fn)
for i, data in enumerate(train_dataloader):
if i > 2: if i > 2:
break break
input_ids = input_ids.cuda() data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
label = label.cuda()
zero_optim.zero_grad() zero_optim.zero_grad()
torch_optim.zero_grad() torch_optim.zero_grad()
torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim) run_fwd_bwd(torch_model, data, output_transform_fn, loss_fn, optimizer=torch_optim)
loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim) run_fwd_bwd(model, data, output_transform_fn, loss_fn, optimizer=zero_optim)
assert_close(torch_loss, loss, rtol=rtol, atol=atol) # atol should be 2e-5 for torch lower than 1.12
zero_optim.step() zero_optim.step()
torch_optim.step() torch_optim.step()

View File

@ -4,10 +4,9 @@ import numpy as np
import pytest import pytest
import torch import torch
from colossalai.testing import clear_cache_before_run from colossalai.testing import DummyDataloader, clear_cache_before_run
from colossalai.zero.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer from colossalai.zero.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer
from tests.components_to_test import run_fwd_bwd from tests.kit.model_zoo import model_zoo, run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs
@pytest.mark.skip("this is not used") @pytest.mark.skip("this is not used")
@ -16,21 +15,22 @@ def test_runtime_mem_tracer():
test_models = ["gpt2", "bert", "simple_net", "repeated_computed_layers", "nested_model", "albert"] test_models = ["gpt2", "bert", "simple_net", "repeated_computed_layers", "nested_model", "albert"]
for model_name in test_models: for model_name in test_models:
get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, data_gen_fn, output_transform_fn, *_ = next(
model_builder, train_dataloader, _, _, criterion = get_components_func() iter(model_zoo.get_sub_registry(model_name).values())
)
model = model_builder(checkpoint=False).cuda() model = model_builder().cuda()
model_bk = deepcopy(model) model_bk = deepcopy(model)
runtime_mem_tracer = RuntimeMemTracer(model) runtime_mem_tracer = RuntimeMemTracer(model)
for i, (data, label) in enumerate(train_dataloader): train_dataloader = DummyDataloader(data_gen_fn)
for i, data in enumerate(train_dataloader):
if i > 1: if i > 1:
break break
data = data.cuda() data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
label = label.cuda()
run_fwd_bwd(runtime_mem_tracer, data, label, criterion, optimizer=runtime_mem_tracer) run_fwd_bwd(runtime_mem_tracer, data, output_transform_fn, optimizer=runtime_mem_tracer)
for p1, p2 in zip(model_bk.parameters(), model.parameters()): for p1, p2 in zip(model_bk.parameters(), model.parameters()):
torch.allclose(p1.to(torch.half), p2) torch.allclose(p1.to(torch.half), p2)

View File

@ -5,40 +5,37 @@ import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from colossalai.zero.gemini.chunk import init_chunk_manager, search_chunk_configuration from colossalai.zero.gemini.chunk import init_chunk_manager, search_chunk_configuration
from tests.components_to_test.registry import non_distributed_component_funcs from tests.kit.model_zoo import model_zoo
def exam_search_chunk_size(): def exam_search_chunk_size():
world_size = torch.distributed.get_world_size() model_builder, data_gen_fn, output_transform_fn, *_ = next(
iter(model_zoo.get_sub_registry("transformers_gpt_lm").values())
get_components_func = non_distributed_component_funcs.get_callable("gpt2") )
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
# make sure torch_model and model has the same parameter values # make sure torch_model and model has the same parameter values
model = model_builder() model = model_builder()
config_dict, *_ = search_chunk_configuration( config_dict, *_ = search_chunk_configuration(
model, search_range_m=1, search_interval=16, min_chunk_size_m=0, filter_exlarge_params=True model, search_range_m=1, search_interval=128, min_chunk_size_m=0, filter_exlarge_params=True
) )
for key in config_dict: for key in config_dict:
chunk_size = config_dict[key]["chunk_size"] chunk_size = config_dict[key]["chunk_size"]
if world_size == 1 or True: assert chunk_size == 527872
assert chunk_size == 31616
else:
assert chunk_size == 1024
def exam_chunk_manager(): def exam_chunk_manager():
world_size = torch.distributed.get_world_size() world_size = torch.distributed.get_world_size()
get_components_func = non_distributed_component_funcs.get_callable("gpt2") model_builder, data_gen_fn, output_transform_fn, *_ = next(
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() iter(model_zoo.get_sub_registry("transformers_gpt_lm").values())
)
sharded_ddp_model = model_builder() sharded_ddp_model = model_builder()
chunk_manager = init_chunk_manager( chunk_manager = init_chunk_manager(
sharded_ddp_model, sharded_ddp_model,
get_current_device(), get_current_device(),
hidden_dim=16, hidden_dim=128,
search_range_m=1, search_range_m=1,
min_chunk_size_m=0, min_chunk_size_m=0,
filter_exlarge_params=True, filter_exlarge_params=True,
@ -46,7 +43,7 @@ def exam_chunk_manager():
) )
config_dict = chunk_manager.dp_degree_chunk_size_dict config_dict = chunk_manager.dp_degree_chunk_size_dict
assert len(config_dict) == 1 assert len(config_dict) == 1
assert config_dict[world_size] == 31616 assert config_dict[world_size] == 527872
def run_dist(rank, world_size, port): def run_dist(rank, world_size, port):

View File

@ -7,7 +7,7 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import set_seed from colossalai.utils import set_seed
from colossalai.zero import GeminiDDP from colossalai.zero import GeminiDDP
from colossalai.zero.gemini.chunk import search_chunk_configuration from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.components_to_test.registry import non_distributed_component_funcs from tests.kit.model_zoo import model_zoo
PLACEMENT_CONFIGS = [ PLACEMENT_CONFIGS = [
{"placement_policy": "static", "shard_param_frac": 0.0}, # zero2 {"placement_policy": "static", "shard_param_frac": 0.0}, # zero2
@ -26,15 +26,16 @@ def ignore_the_first_parameter(model: torch.nn.Module):
@parameterize("placement_config", PLACEMENT_CONFIGS) @parameterize("placement_config", PLACEMENT_CONFIGS)
@parameterize("keep_gathered", [True, False]) @parameterize("keep_gathered", [True, False])
@parameterize("model_name", ["gpt2", "bert"]) @parameterize("model_name", ["transformers_gpt_lm", "transformers_bert_for_sequence_classification"])
@parameterize("master_weights", [False, True]) @parameterize("master_weights", [False, True])
def exam_state_dict(placement_config, keep_gathered, model_name: str, master_weights: bool): def exam_state_dict(placement_config, keep_gathered, model_name: str, master_weights: bool):
set_seed(431) set_seed(431)
get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, data_gen_fn, output_transform_fn, *_ = next(iter(model_zoo.get_sub_registry(model_name).values()))
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
model = model_builder() model = model_builder()
model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2
torch_model = model_builder() torch_model = model_builder()
for torch_p, p in zip(torch_model.parameters(), model.parameters()): for torch_p, p in zip(torch_model.parameters(), model.parameters()):
torch_p.data.copy_(p.data) torch_p.data.copy_(p.data)
@ -54,29 +55,7 @@ def exam_state_dict(placement_config, keep_gathered, model_name: str, master_wei
temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype) temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-5) assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-5)
# check load state dict
@parameterize("placement_config", PLACEMENT_CONFIGS)
@parameterize("keep_gathered", [True, False])
@parameterize("model_name", ["gpt2", "bert"])
@parameterize("master_weights", [False, True])
def exam_load_state_dict(placement_config, keep_gathered, model_name: str, master_weights: bool):
set_seed(431)
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
model = model_builder()
set_seed(451)
torch_model = model_builder() # get a different model
world_size = torch.distributed.get_world_size()
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
config_dict[world_size]["chunk_size"] = 5000
config_dict[world_size]["keep_gathered"] = keep_gathered
model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True, master_weights=master_weights)
torch_dict = torch_model.state_dict()
model.load_state_dict(torch_dict, strict=False) model.load_state_dict(torch_dict, strict=False)
zero_dict = model.state_dict(only_rank_0=False) zero_dict = model.state_dict(only_rank_0=False)
@ -85,23 +64,7 @@ def exam_load_state_dict(placement_config, keep_gathered, model_name: str, maste
temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype) temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-5) assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-5)
# check state dict shard
@parameterize("placement_config", PLACEMENT_CONFIGS)
@parameterize("model_name", ["gpt2", "bert"])
@parameterize("master_weights", [False, True])
def exam_state_dict_shard(placement_config, model_name: str, master_weights: bool):
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
model = model_builder()
model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
model = GeminiDDP(model, config_dict, **placement_config, master_weights=master_weights)
model.train()
zero_dict = model.state_dict(only_rank_0=False)
accumulated_keys = set() accumulated_keys = set()
# ensure number of shards > 1 # ensure number of shards > 1
for shard, _ in model.state_dict_shard(max_shard_size=(model_size / 3), only_rank_0=False): for shard, _ in model.state_dict_shard(max_shard_size=(model_size / 3), only_rank_0=False):
@ -116,8 +79,6 @@ def run_dist(rank, world_size, port):
config = {} config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
exam_state_dict() exam_state_dict()
exam_load_state_dict()
exam_state_dict_shard()
@pytest.mark.dist @pytest.mark.dist

View File

@ -8,7 +8,7 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import set_seed from colossalai.utils import set_seed
from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.components_to_test.registry import non_distributed_component_funcs from tests.kit.model_zoo import model_zoo
PLACEMENT_CONFIGS = [ PLACEMENT_CONFIGS = [
{"placement_policy": "static", "shard_param_frac": 0.0, "offload_optim_frac": 0.0}, # zero2 {"placement_policy": "static", "shard_param_frac": 0.0, "offload_optim_frac": 0.0}, # zero2
@ -22,8 +22,9 @@ PLACEMENT_CONFIGS = [
@parameterize("keep_gathered", [True, False]) @parameterize("keep_gathered", [True, False])
def exam_zero_optim_state_dict(placement_config, keep_gathered): def exam_zero_optim_state_dict(placement_config, keep_gathered):
set_seed(431) set_seed(431)
get_components_func = non_distributed_component_funcs.get_callable("gpt2") model_builder, data_gen_fn, output_transform_fn, *_ = next(
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() iter(model_zoo.get_sub_registry("transformers_gpt_lm").values())
)
model = model_builder() model = model_builder()
@ -41,15 +42,15 @@ def exam_zero_optim_state_dict(placement_config, keep_gathered):
set_seed(dist.get_rank() * 3 + 128) set_seed(dist.get_rank() * 3 + 128)
model.train() model.train()
for i, (input_ids, label) in enumerate(train_dataloader): data = data_gen_fn()
if i > 0: data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
break
optim.zero_grad() optim.zero_grad()
logits = model(input_ids) outputs = model(**data)
logits = logits.float() outputs = output_transform_fn(outputs)
loss = criterion(logits, input_ids) loss = next(iter(outputs.values())).sum()
optim.backward(loss) optim.backward(loss)
optim.step() optim.step()
optim_state_dict = optim.state_dict() optim_state_dict = optim.state_dict()
optim.load_state_dict(optim_state_dict) optim.load_state_dict(optim_state_dict)