[shardformer] adapted T5 and LLaMa test to use kit (#4049)

* [shardformer] adapted T5 and LLaMa test to use kit

* polish code
Frank Lee 2023-06-21 09:32:46 +08:00
parent 4021b9a8a2
commit 58df720570
24 changed files with 239 additions and 168 deletions

View File

@@ -65,13 +65,14 @@ class Embedding1D(ParallelModule):
                  dtype: torch.dtype = None,
                  device: torch.device = None,
                  process_group: ProcessGroup = None,
+                 gather_output: bool = True,
                  weight_initializer: Callable = init.normal_(),
                  *args,
                  **kwargs):
         super().__init__()
 
         self.num_embeddings = num_embeddings
-        self.embed_dim = embedding_dim
+        self.embedding_dim = embedding_dim
         self.process_group = process_group
         self.num_partitions = dist.get_world_size(process_group)
         self.embed_dim_per_partition = divide(embedding_dim, self.num_partitions)
@@ -79,7 +80,7 @@ class Embedding1D(ParallelModule):
         self.padding_idx = padding_idx
         self.embed_args = args
         self.embed_kwargs = kwargs
-        # self.gather_output = gather_output
+        self.gather_output = gather_output
 
         if device is None:
             device = get_current_device()
@@ -95,7 +96,9 @@ class Embedding1D(ParallelModule):
 
     @staticmethod
     def from_native_module(module: nn.Embedding,
-                           process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "Embedding1D":
+                           process_group: Union[ProcessGroup, List[ProcessGroup]] = None,
+                           *args,
+                           **kwargs) -> "Embedding1D":
         r"""
         Build a 1D parallelized Embedding from a native nn.Embedding module.
         """
@@ -123,7 +126,9 @@ class Embedding1D(ParallelModule):
                                 max_norm=max_norm,
                                 norm_type=norm_type,
                                 scale_grad_by_freq=scale_grad_by_freq,
-                                sparse=sparse)
+                                sparse=sparse,
+                                *args,
+                                **kwargs)
 
         # copy the weight
         with torch.no_grad():
@@ -133,7 +138,7 @@ class Embedding1D(ParallelModule):
         return embedding
 
     def reset_parameters(self, weight_initializer) -> None:
-        fan_in, fan_out = self.num_embeddings, self.embed_dim
+        fan_in, fan_out = self.num_embeddings, self.embedding_dim
         weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
         self._fill_padding_idx_with_zero()
@@ -144,6 +149,9 @@ class Embedding1D(ParallelModule):
 
     def forward(self, input_: Tensor) -> Tensor:
         output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
-        output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
-        return output
+        if self.gather_output:
+            output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
+            return output
+        else:
+            return output_parallel
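Note on the hunk above: `Embedding1D` now takes a `gather_output` flag, and `from_native_module` forwards extra `*args`/`**kwargs` to the constructor, so callers can keep the forward output sharded along the embedding dimension instead of always all-gathering it. A minimal usage sketch, assuming a tensor-parallel process group `pg` has already been initialized (the import path follows the policy changes below):

```python
import torch.nn as nn

from colossalai.shardformer.layer import Embedding1D

# `pg` is assumed to be an already-initialized 1D tensor-parallel ProcessGroup
embedding = nn.Embedding(1000, 128)

# default (gather_output=True) keeps the old behaviour: forward all-gathers the
# partial embeddings so the output carries the full embedding dimension
gathered = Embedding1D.from_native_module(embedding, pg)

# the new kwargs pass-through lets the flag travel through from_native_module,
# leaving the output sharded (last dim = embedding_dim / world_size)
sharded = Embedding1D.from_native_module(embedding, pg, gather_output=False)
```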

View File

@@ -4,7 +4,7 @@ import torch.nn as nn
 from transformers import LlamaForCausalLM, LlamaForSequenceClassification
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
 
-from colossalai.shardformer.layer.layers import Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
+from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
 
 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

View File

@@ -11,8 +11,7 @@ from transformers.models.t5.modeling_t5 import (
     T5Stack,
 )
 
-from colossalai.shardformer.layer.dropout import Dropout1D
-from colossalai.shardformer.layer.layers import Embedding1D, Linear1D_Col, Linear1D_Row
+from colossalai.shardformer.layer import Dropout1D, Embedding1D, Linear1D_Col, Linear1D_Row
 
 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

View File

@@ -185,7 +185,14 @@ class ModelSharder(object):
             if description.ignore_if_not_exist and native_sub_module is None:
                 continue
 
-            replace_layer = target_module.from_native_module(native_sub_module, self.pg_manager.pg_store['tp1d'],
-                                                             **kwargs)
+            try:
+                replace_layer = target_module.from_native_module(native_sub_module, self.pg_manager.pg_store['tp1d'],
+                                                                 **kwargs)
+            except Exception as e:
+                raise RuntimeError(
+                    f"Failed to replace {suffix} of type {native_sub_module.__class__.__qualname__}"
+                    f" with {target_module.__qualname__} with the exception: {e}. "
+                    "Please check your model configuration or sharding policy, you can set up an issue for us to help you as well."
+                )
 
             setattr_(org_layer, suffix, replace_layer)

View File

@@ -98,6 +98,6 @@ def assert_hf_output_close(out1: Any,
             raise AssertionError(f"{track_name}: shape mismatch: {out1.shape} vs {out2.shape}")
         assert torch.allclose(
             out1, out2, atol=atol, rtol=rtol
-        ), f"{track_name}: tensor value mismatch\nvalue 1: {out1}\nvalue 2: {out2}, mean error: {torch.abs(out1 - out2).mean()}"
+        ), f"{track_name}: tensor value mismatch\nvalue 1: {out1}\nvalue 2: {out2}, \nmean error: {torch.abs(out1 - out2).mean()}"
     else:
         assert out1 == out2, f"{track_name}: value mismatch.\nout1: {out1}\nout2: {out2}"

View File

@@ -28,27 +28,35 @@ class ModelZooRegistry(dict):
                  model_fn: Callable,
                  data_gen_fn: Callable,
                  output_transform_fn: Callable,
+                 loss_fn: Callable = None,
                  model_attribute: ModelAttribute = None):
         """
         Register a model and data generation function.
 
         Examples:
-            >>> # Register
-            >>> model_zoo = ModelZooRegistry()
-            >>> model_zoo.register('resnet18', resnet18, resnet18_data_gen)
-            >>> # Run the model
-            >>> data = resnet18_data_gen() # do not input any argument
-            >>> model = resnet18() # do not input any argument
-            >>> out = model(**data)
+
+        ```python
+        # normal forward workflow
+        model = resnet18()
+        data = resnet18_data_gen()
+        output = model(**data)
+        transformed_output = output_transform_fn(output)
+        loss = loss_fn(transformed_output)
+
+        # Register
+        model_zoo = ModelZooRegistry()
+        model_zoo.register('resnet18', resnet18, resnet18_data_gen, output_transform_fn, loss_fn)
+        ```
 
         Args:
             name (str): Name of the model.
-            model_fn (callable): A function that returns a model. **It must not contain any arguments.**
-            output_transform_fn (callable): A function that transforms the output of the model into Dict.
-            data_gen_fn (callable): A function that returns a data sample in the form of Dict. **It must not contain any arguments.**
+            model_fn (Callable): A function that returns a model. **It must not contain any arguments.**
+            data_gen_fn (Callable): A function that returns a data sample in the form of Dict. **It must not contain any arguments.**
+            output_transform_fn (Callable): A function that transforms the output of the model into Dict.
+            loss_fn (Callable): a function to compute the loss from the given output. Defaults to None
             model_attribute (ModelAttribute): Attributes of the model. Defaults to None.
         """
-        self[name] = (model_fn, data_gen_fn, output_transform_fn, model_attribute)
+        self[name] = (model_fn, data_gen_fn, output_transform_fn, loss_fn, model_attribute)
 
     def get_sub_registry(self, keyword: str):
         """

View File

@@ -1,5 +1,6 @@
 from .albert import *
 from .bert import *
 from .gpt import *
+from .llama import *
 from .opt import *
 from .t5 import *

View File

@@ -0,0 +1,76 @@
import torch
import transformers

from ..registry import ModelAttribute, model_zoo

try:
    from transformers import LlamaConfig, LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel
    HAS_LLAMA = True
except ImportError:
    HAS_LLAMA = False

if HAS_LLAMA:
    # ===============================
    # Register LLaMA
    # ===============================

    def data_gen():
        # the input ids are corresponding to the sentence
        # 'Hello, my dog is cute'
        #
        # the code is give below:
        # -----------------------------------
        # from transformers import LlamaTokenizerFast
        # tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
        # input = 'Hello, my dog is cute'
        # tokenized_input = tokenizer(input, return_tensors='pt').to('cuda')
        # -----------------------------------
        input_ids = torch.Tensor([[1, 15043, 29892, 590, 11203, 338, 274, 1082]]).long()
        attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1]]).long()
        return dict(input_ids=input_ids, attention_mask=attention_mask)

    # label is needed for casual lm
    def data_gen_for_casual_lm():
        data = data_gen()
        labels = data['input_ids'].clone()
        data['labels'] = labels
        return data

    # transform the output to a dict
    output_transform_fn = lambda x: x

    # function to get the loss
    loss_fn = lambda output: output.last_hidden_state.mean()
    loss_fn_for_casual_lm = lambda output: output.loss
    loss_fn_for_seq_classification = lambda output: output.logits.mean()

    config = LlamaConfig(num_hidden_layers=4,
                         hidden_size=128,
                         intermediate_size=256,
                         num_attention_heads=4,
                         max_position_embeddings=128,
                         num_labels=16)

    # register the following models
    # transformers.LlamaModel,
    # transformers.LlamaForCausalLM,
    # transformers.LlamaForSequenceClassification,
    model_zoo.register(name='transformers_llama',
                       model_fn=lambda: transformers.LlamaModel(config),
                       data_gen_fn=data_gen,
                       output_transform_fn=output_transform_fn,
                       loss_fn=loss_fn,
                       model_attribute=ModelAttribute(has_control_flow=True))
    model_zoo.register(name='transformers_llama_for_casual_lm',
                       model_fn=lambda: transformers.LlamaForCausalLM(config),
                       data_gen_fn=data_gen_for_casual_lm,
                       output_transform_fn=output_transform_fn,
                       loss_fn=loss_fn_for_casual_lm,
                       model_attribute=ModelAttribute(has_control_flow=True))
    model_zoo.register(name='transformers_llama_for_sequence_classification',
                       model_fn=lambda: transformers.LlamaForSequenceClassification(config),
                       data_gen_fn=data_gen,
                       output_transform_fn=output_transform_fn,
                       loss_fn=loss_fn_for_seq_classification,
                       model_attribute=ModelAttribute(has_control_flow=True))
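With these entries registered, the reworked shardformer tests below no longer construct LLaMA configs or tokenizers by hand; everything comes from the kit. A condensed sketch of how a registered entry can be consumed on its own (single process, no sharding, tiny CPU-only config):

```python
from tests.kit.model_zoo import model_zoo

sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')

for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
    model = model_fn()                   # tiny 4-layer config defined above
    data = data_gen_fn()                 # 'Hello, my dog is cute' token ids
    output = output_transform_fn(model(**data))
    loss = loss_fn(output)
    loss.backward()
```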

View File

@@ -6,24 +6,50 @@ from ..registry import ModelAttribute, model_zoo
 # ===============================
 # Register single-sentence T5
 # ===============================
-BATCH_SIZE = 2
-SEQ_LENGTH = 16
-
-
-def data_gen():
-    input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
-    decoder_input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
-    return dict(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
 
 
+# define data gen function
 def data_gen_for_encoder_only():
-    input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+    # Generated from following code snippet
+    #
+    # from transformers import T5Config, T5Tokenizer
+    # config = T5Config(decoder_start_token_id=0)
+    # tokenizer = T5Tokenizer.from_pretrained("t5-small")
+    # input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids
+    input_ids = torch.Tensor([[13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 1]]).long()
     return dict(input_ids=input_ids)
 
 
+def data_gen_for_conditional_generation():
+    # labels is generated with the following code
+    #
+    # labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids
+    data = data_gen_for_encoder_only()
+    labels = torch.Tensor([[644, 4598, 229, 19250, 5, 1]]).long()
+    data['labels'] = labels
+    return data
+
+
+def data_gen_for_t5_model():
+    # decoder_inputs_ids is obtained with the following code
+    #
+    # decoder_input_ids = model._shift_right(input_ids)
+    data = data_gen_for_encoder_only()
+    decoder_input_ids = torch.Tensor([[0, 13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5]]).long()
+    data['decoder_input_ids'] = decoder_input_ids
+    return data
+
+
+# output transform function
 output_transform_fn = lambda x: x
 
-config = transformers.T5Config(d_model=128, num_layers=2)
+# define loss funciton
+loss_fn_for_t5_model = lambda x: x.last_hidden_state.mean()
+loss_fn_for_encoder_only = lambda x: x.last_hidden_state.mean()
+loss_fn_for_conditional_generation = lambda x: x.loss
+
+# define model config
+config = transformers.T5Config(d_model=128, num_layers=2, dropout_rate=0, decoder_start_token_id=0)
 
 # register the following models
 # transformers.T5Model,
@@ -31,16 +57,19 @@ config = transformers.T5Config(d_model=128, num_layers=2)
 # transformers.T5EncoderModel,
 model_zoo.register(name='transformers_t5',
                    model_fn=lambda: transformers.T5Model(config),
-                   data_gen_fn=data_gen,
+                   data_gen_fn=data_gen_for_t5_model,
                    output_transform_fn=output_transform_fn,
+                   loss_fn=loss_fn_for_t5_model,
                    model_attribute=ModelAttribute(has_control_flow=True))
 model_zoo.register(name='transformers_t5_for_conditional_generation',
                    model_fn=lambda: transformers.T5ForConditionalGeneration(config),
-                   data_gen_fn=data_gen,
+                   data_gen_fn=data_gen_for_conditional_generation,
                    output_transform_fn=output_transform_fn,
+                   loss_fn=loss_fn_for_conditional_generation,
                    model_attribute=ModelAttribute(has_control_flow=True))
 model_zoo.register(name='transformers_t5_encoder_model',
                    model_fn=lambda: transformers.T5EncoderModel(config),
                    data_gen_fn=data_gen_for_encoder_only,
                    output_transform_fn=output_transform_fn,
+                   loss_fn=loss_fn_for_encoder_only,
                    model_attribute=ModelAttribute(has_control_flow=True))

View File

@@ -11,7 +11,7 @@ def run_torch_amp(rank, world_size, port):
     # init dist env
     colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
     sub_model_zoo = model_zoo.get_sub_registry('timm')
-    for name, (model_fn, data_gen_fn, output_transform_fn, _) in sub_model_zoo.items():
+    for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in sub_model_zoo.items():
         # dlrm_interactionarch has not parameters, so skip
         if name == 'dlrm_interactionarch':
             continue

View File

@@ -71,7 +71,7 @@ def check_gemini_plugin(init_method: str = 'none', early_stop: bool = True):
     passed_models = []
     failed_info = {}    # (model_name, error) pair
 
-    for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
+    for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items():
         # These models lead to CUDA error
         if name in ('diffusers_auto_encoder_kl', 'diffusers_vq_model', 'diffusers_unet2d_model', 'timm_resmlp',
                     'timm_gmixer_12_224', 'timm_gmlp_b16_224', 'timm_mixer_b16_224', 'timm_convnext'):

View File

@@ -61,7 +61,7 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True):
     ignore_models = _AMP_ERR_MODELS + _LOW_LEVEL_ZERO_ERR_MODELS + _STUCK_MODELS
     skipped_models = []
 
-    for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
+    for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items():
         # FIXME(ver217): fix these models
         if name in ignore_models:
             skipped_models.append(name)

View File

@@ -40,7 +40,7 @@ def run_fn(model_fn, data_gen_fn, output_transform_fn):
 
 def check_torch_ddp_plugin():
-    for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
+    for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items():
         if name == 'dlrm_interactionarch':
             continue
         run_fn(model_fn, data_gen_fn, output_transform_fn)

View File

@@ -42,7 +42,7 @@ def run_fn(model_fn, data_gen_fn, output_transform_fn):
 
 def check_torch_fsdp_plugin():
-    for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
+    for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items():
         if any(element in name for element in [
                 'diffusers', 'deepfm_sparsearch', 'dlrm_interactionarch', 'torchvision_googlenet',
                 'torchvision_inception_v3'

View File

@@ -47,7 +47,7 @@ def test_diffusers():
     sub_model_zoo = model_zoo.get_sub_registry('diffusers')
 
-    for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items():
+    for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items():
         data = data_gen_fn()
         trace_and_compare(model_fn, data, output_transform_fn)
         torch.cuda.synchronize()
@@ -60,7 +60,7 @@ def test_torch_diffusers():
     sub_model_zoo = model_zoo.get_sub_registry('diffusers')
 
-    for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items():
+    for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items():
         data = data_gen_fn()
         model = model_fn()
         output = model(**data)

View File

@@ -56,7 +56,7 @@ def test_timm_models():
     sub_model_zoo = model_zoo.get_sub_registry('timm')
 
-    for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items():
+    for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items():
         data = data_gen_fn()
         if attribute is not None and attribute.has_control_flow:
             meta_args = {k: v.to('meta') for k, v in data.items()}

View File

@@ -16,7 +16,7 @@ def test_torchaudio_models():
     sub_model_zoo = model_zoo.get_sub_registry('torchaudio')
 
-    for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items():
+    for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items():
         model = model_fn()
         trace_and_compare(model,
                           data_gen_fn,

View File

@@ -60,7 +60,7 @@ def assert_forward_equal(m1: torch.nn.Module, m2: torch.nn.Module, data_gen_fn:
 
 def check_lazy_init(entry: TestingEntry, seed: int = 42, verbose: bool = False, check_forward: bool = False) -> None:
-    model_fn, data_gen_fn, output_transform_fn, model_attr = entry
+    model_fn, data_gen_fn, output_transform_fn, _, model_attr = entry
     _MyTensor._pre_op_fn = lambda *args: set_seed(seed)
     LazyTensor._pre_op_fn = lambda *args: set_seed(seed)
     ctx = LazyInitContext(tensor_cls=_MyTensor)

View File

@@ -78,7 +78,7 @@ def run_dist_lazy_init(subset, seed: int = 42):
         if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base'):
             continue
         print_rank_0(name)
-        model_fn, data_gen_fn, output_transform_fn, model_attr = entry
+        model_fn, data_gen_fn, output_transform_fn, _, model_attr = entry
         ctx = LazyInitContext(tensor_cls=_MyTensor)
         with ctx:
             model = model_fn()

View File

@@ -0,0 +1,38 @@
import copy

from colossalai.shardformer import ShardConfig, ShardFormer


def build_model(world_size, model_fn):
    # create new model
    org_model = model_fn().cuda()

    # shard model
    shard_config = ShardConfig(tensor_parallel_size=world_size)
    model_copy = copy.deepcopy(org_model)
    shard_former = ShardFormer(shard_config=shard_config)
    shard_former.init_distributed()
    sharded_model = shard_former.shard_model(model_copy)

    return org_model, sharded_model


def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
    # prepare input
    data = data_gen_fn()
    data = {k: v.cuda() for k, v in data.items()}

    # switch to train mode
    original_model.train()
    sharded_model.train()

    # run forward
    org_output = original_model(**data)
    org_output = output_transform_fn(org_output)
    org_loss = loss_fn(org_output)

    shard_output = sharded_model(**data)
    shard_output = output_transform_fn(shard_output)
    shard_loss = loss_fn(shard_output)

    return org_output, org_loss, shard_output, shard_loss
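These two helpers are the shared scaffolding for the rewritten tests that follow. A condensed sketch of how they compose with the kit registry inside an already-launched distributed test (the loss comparison here is illustrative only; the real tests compare full outputs and gradients):

```python
import torch

from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, run_forward


def check_entries(world_size, keyword='transformers_t5'):
    # assumes colossalai.launch(...) has already initialized the process group
    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in model_zoo.get_sub_registry(keyword).items():
        org_model, sharded_model = build_model(world_size, model_fn)
        org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
                                                                     output_transform_fn, loss_fn)
        assert torch.allclose(org_loss, shard_loss, atol=1e-5), f'{name}: loss mismatch'
```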

View File

@@ -1,64 +1,22 @@
-import copy
 import os
-import random
 
 import pytest
 import torch
-from transformers import LlamaConfig, LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaTokenizerFast
 
 import colossalai
 from colossalai.logging import disable_existing_loggers
-from colossalai.shardformer import ShardConfig, ShardFormer
 from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn
+from tests.kit.model_zoo import model_zoo
+from tests.test_shardformer.test_model._utils import build_model, run_forward
 
 os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
-tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
 
 
-def build_model(world_size, model_fn):
-    # create new model
-    config = LlamaConfig(num_hidden_layers=4,
-                         hidden_size=128,
-                         intermediate_size=256,
-                         num_attention_heads=4,
-                         max_position_embeddings=128)
-    org_model = model_fn(config).cuda()
-
-    # shard model
-    shard_config = ShardConfig(tensor_parallel_size=world_size)
-    model_copy = copy.deepcopy(org_model)
-    shard_former = ShardFormer(shard_config=shard_config)
-    shard_former.init_distributed()
-    sharded_model = shard_former.shard_model(model_copy)
-
-    return org_model, sharded_model
-
-
-def check_forward_backward(org_model, sharded_model):
-    # prepare input
-    input = 'Hello, my dog is cute'
-    tokenized_input = tokenizer(input, return_tensors='pt').to('cuda')
-    del tokenized_input["token_type_ids"]
-    del tokenized_input["attention_mask"]
-
-    # switch to train mode
-    org_model.train()
-    sharded_model.train()
-
-    if isinstance(org_model, (LlamaModel, LlamaForSequenceClassification)):
-        org_output = org_model(**tokenized_input)
-        org_loss = org_output.last_hidden_state.mean()
-        shard_output = sharded_model(**tokenized_input)
-        shard_loss = shard_output.last_hidden_state.mean()
-    elif isinstance(org_model, LlamaForCausalLM):
-        labels = tokenized_input['input_ids'].clone()
-        labels[labels == tokenizer.pad_token_id] = -100
-        tokenized_input['labels'] = labels
-        org_output = org_model(**tokenized_input)
-        org_loss = org_output.loss
-        shard_output = sharded_model(**tokenized_input)
-        shard_loss = shard_output.loss
-
-    # forward check
+def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
+    org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
+                                                                 output_transform_fn, loss_fn)
     assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'], rtol=1e-4)
 
     # run backward
@@ -66,12 +24,12 @@ def check_forward_backward(org_model, sharded_model):
     shard_loss.backward()
 
     # check grad
-    if isinstance(org_model, LlamaModel):
-        llama_model = org_model
-        shard_llama_model = sharded_model
-    else:
+    if hasattr(org_model, 'model'):
         llama_model = org_model.model
         shard_llama_model = sharded_model.model
+    else:
+        llama_model = org_model
+        shard_llama_model = sharded_model
 
     org_grad = llama_model.layers[0].self_attn.q_proj.weight.grad
     shard_grad = shard_llama_model.layers[0].self_attn.q_proj.weight.grad
@@ -89,17 +47,11 @@ def check_llama(rank, world_size, port):
     disable_existing_loggers()
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
 
-    model_list = [
-        LlamaModel,
-        # LlamaForCausalLM,
-        # TODO: do not work yet
-        # LlamaForSequenceClassification
-    ]
-
-    for model_fn in model_list:
+    sub_model_zoo = model_zooo.get_sub_registry('transformers_llama') if False else model_zoo.get_sub_registry('transformers_llama')
+
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
         org_model, sharded_model = build_model(world_size, model_fn)
-        check_forward_backward(org_model, sharded_model)
+        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
 
     torch.cuda.empty_cache()

View File

@@ -1,64 +1,20 @@
-import copy
 import os
 
 import pytest
 import torch
-from transformers import T5Config, T5EncoderModel, T5ForConditionalGeneration, T5Model, T5Tokenizer, T5TokenizerFast
 
 import colossalai
 from colossalai.logging import disable_existing_loggers
-from colossalai.shardformer.shard import ShardConfig, ShardFormer
 from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn
+from tests.kit.model_zoo import model_zoo
+from tests.test_shardformer.test_model._utils import build_model, run_forward
 
 os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
-CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),)
-tokenizer = T5Tokenizer.from_pretrained("t5-small")
 
 
-def build_model(world_size, model_fn):
-    config = T5Config(decoder_start_token_id=0)
-    config.dropout_rate = 0
-    org_model = model_fn(config=config).to('cuda')
-    shard_config = ShardConfig(tensor_parallel_size=world_size)
-
-    # shard model
-    shard_config = ShardConfig(tensor_parallel_size=world_size)
-    model_copy = copy.deepcopy(org_model)
-    shard_former = ShardFormer(shard_config=shard_config)
-    shard_former.init_distributed()
-    sharded_model = shard_former.shard_model(model_copy)
-    return org_model, sharded_model
-
-
-def check_forward_backward(org_model, sharded_model):
-    # prepare input
-    input_ids = tokenizer("translate English to German: The house is wonderful.",
-                          return_tensors="pt").input_ids.to('cuda')
-    labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids.to('cuda')
-
-    # switch to train mode
-    org_model.train()
-    sharded_model.train()
-
-    if isinstance(org_model, T5ForConditionalGeneration):
-        org_output = org_model(input_ids=input_ids, labels=labels)
-        org_loss = org_output.loss
-        shard_output = sharded_model(input_ids=input_ids, labels=labels)
-        shard_loss = shard_output.loss
-    elif isinstance(org_model, T5Model):
-        decoder_input_ids = org_model._shift_right(input_ids)
-        org_output = org_model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
-        org_loss = org_output.last_hidden_state.mean()
-        shard_output = sharded_model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
-        shard_loss = shard_output.last_hidden_state.mean()
-    elif isinstance(org_model, T5EncoderModel):
-        org_output = org_model(input_ids=input_ids)
-        org_loss = org_output.last_hidden_state.mean()
-        shard_output = sharded_model(input_ids=input_ids)
-        shard_loss = shard_output.last_hidden_state.mean()
-
-    # key is sharded, so we ignore
+def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
+    # check forward
+    org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
+                                                                 output_transform_fn, loss_fn)
+
+    # the value "past_key_values" is sharded, so we ignore
    assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'])
 
     # do backward
@@ -81,17 +37,14 @@ def check_forward_backward(org_model, sharded_model):
 def check_t5(rank, world_size, port):
     disable_existing_loggers()
-    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
 
-    model_fn_list = [
-        T5Model,
-        T5ForConditionalGeneration,
-        T5EncoderModel,
-    ]
-
-    for model_fn in model_fn_list:
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
+
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
         org_model, sharded_model = build_model(world_size, model_fn)
-        check_forward_backward(org_model, sharded_model)
+        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
 
     torch.cuda.empty_cache()