[shardformer] adapted T5 and LLaMa test to use kit (#4049)

* [shardformer] adapted T5 and LLaMa test to use kit

* polish code
Frank Lee
2023-06-21 09:32:46 +08:00
parent 4021b9a8a2
commit 58df720570
24 changed files with 239 additions and 168 deletions


@@ -65,13 +65,14 @@ class Embedding1D(ParallelModule):
                  dtype: torch.dtype = None,
                  device: torch.device = None,
                  process_group: ProcessGroup = None,
+                 gather_output: bool = True,
                  weight_initializer: Callable = init.normal_(),
                  *args,
                  **kwargs):
         super().__init__()
         self.num_embeddings = num_embeddings
-        self.embed_dim = embedding_dim
+        self.embedding_dim = embedding_dim
         self.process_group = process_group
         self.num_partitions = dist.get_world_size(process_group)
         self.embed_dim_per_partition = divide(embedding_dim, self.num_partitions)
@@ -79,7 +80,7 @@ class Embedding1D(ParallelModule):
         self.padding_idx = padding_idx
         self.embed_args = args
         self.embed_kwargs = kwargs
-        # self.gather_output = gather_output
+        self.gather_output = gather_output
         if device is None:
             device = get_current_device()
@@ -95,7 +96,9 @@ class Embedding1D(ParallelModule):
     @staticmethod
     def from_native_module(module: nn.Embedding,
-                           process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "Embedding1D":
+                           process_group: Union[ProcessGroup, List[ProcessGroup]] = None,
+                           *args,
+                           **kwargs) -> "Embedding1D":
         r"""
         Build a 1D parallelized Embedding from a native nn.Embedding module.
         """
@@ -123,7 +126,9 @@ class Embedding1D(ParallelModule):
                                 max_norm=max_norm,
                                 norm_type=norm_type,
                                 scale_grad_by_freq=scale_grad_by_freq,
-                                sparse=sparse)
+                                sparse=sparse,
+                                *args,
+                                **kwargs)
         # copy the weight
         with torch.no_grad():
@@ -133,7 +138,7 @@ class Embedding1D(ParallelModule):
         return embedding
     def reset_parameters(self, weight_initializer) -> None:
-        fan_in, fan_out = self.num_embeddings, self.embed_dim
+        fan_in, fan_out = self.num_embeddings, self.embedding_dim
         weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
         self._fill_padding_idx_with_zero()
@@ -144,6 +149,9 @@ class Embedding1D(ParallelModule):
     def forward(self, input_: Tensor) -> Tensor:
         output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
-        output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
-        return output
+        if self.gather_output:
+            output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
+            return output
+        else:
+            return output_parallel
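
Taken together, the hunks above add a gather_output switch to Embedding1D and let from_native_module forward extra arguments to the constructor. A minimal usage sketch, not part of this commit, assuming torch.distributed is already initialized with a tensor-parallel group and the embedding dimension divides evenly across ranks (the sizes below are illustrative):

import torch.nn as nn
from colossalai.shardformer.layer import Embedding1D

# a plain PyTorch embedding to be sharded along the embedding dimension
embedding = nn.Embedding(num_embeddings=32000, embedding_dim=4096)

# extra kwargs are forwarded through from_native_module to the Embedding1D
# constructor, so gather_output can be chosen at conversion time
sharded = Embedding1D.from_native_module(embedding, process_group=None, gather_output=False)

# with gather_output=False the forward pass returns only this rank's slice
# (embedding_dim / world_size features); the default True all-gathers along dim=-1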


@@ -4,7 +4,7 @@ import torch.nn as nn
 from transformers import LlamaForCausalLM, LlamaForSequenceClassification
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
-from colossalai.shardformer.layer.layers import Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
+from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription


@@ -11,8 +11,7 @@ from transformers.models.t5.modeling_t5 import (
     T5Stack,
 )
-from colossalai.shardformer.layer.dropout import Dropout1D
-from colossalai.shardformer.layer.layers import Embedding1D, Linear1D_Col, Linear1D_Row
+from colossalai.shardformer.layer import Dropout1D, Embedding1D, Linear1D_Col, Linear1D_Row
 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
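
Both policy files now import the parallel layers from the package root instead of colossalai.shardformer.layer.layers and colossalai.shardformer.layer.dropout. A hedged sketch of the kind of re-export in colossalai/shardformer/layer/__init__.py that this relies on (the actual file may expose more symbols):

# colossalai/shardformer/layer/__init__.py (illustrative)
from .dropout import Dropout1D
from .layers import Embedding1D, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D

__all__ = ['Dropout1D', 'Embedding1D', 'Linear1D_Col', 'Linear1D_Row', 'VocabParallelEmbedding1D']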


@@ -185,7 +185,14 @@ class ModelSharder(object):
             if description.ignore_if_not_exist and native_sub_module is None:
                 continue
-            replace_layer = target_module.from_native_module(native_sub_module, self.pg_manager.pg_store['tp1d'],
-                                                             **kwargs)
+            try:
+                replace_layer = target_module.from_native_module(native_sub_module, self.pg_manager.pg_store['tp1d'],
+                                                                 **kwargs)
+            except Exception as e:
+                raise RuntimeError(
+                    f"Failed to replace {suffix} of type {native_sub_module.__class__.__qualname__}"
+                    f" with {target_module.__qualname__} with the exception: {e}. "
+                    "Please check your model configuration or sharding policy, you can set up an issue for us to help you as well."
+                )
             setattr_(org_layer, suffix, replace_layer)
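
Since the sharder forwards **kwargs into from_native_module, a policy entry can tune the replacement layer per sub-module. A hedged sketch of such an entry, written as it would appear inside one of the policy modules above; the field names follow how this commit's code uses the description object, and 'shared' is just an illustrative sub-module name:

from colossalai.shardformer.layer import Embedding1D
from .basepolicy import SubModuleReplacementDescription

# suffix names the sub-module to swap out; the kwargs dict is what ends up as
# **kwargs in the from_native_module call wrapped by the try/except above
description = SubModuleReplacementDescription(
    suffix='shared',
    target_module=Embedding1D,
    kwargs={'gather_output': True},
)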


@@ -98,6 +98,6 @@ def assert_hf_output_close(out1: Any,
             raise AssertionError(f"{track_name}: shape mismatch: {out1.shape} vs {out2.shape}")
         assert torch.allclose(
             out1, out2, atol=atol, rtol=rtol
-        ), f"{track_name}: tensor value mismatch\nvalue 1: {out1}\nvalue 2: {out2}, mean error: {torch.abs(out1 - out2).mean()}"
+        ), f"{track_name}: tensor value mismatch\nvalue 1: {out1}\nvalue 2: {out2}, \nmean error: {torch.abs(out1 - out2).mean()}"
     else:
         assert out1 == out2, f"{track_name}: value mismatch.\nout1: {out1}\nout2: {out2}"
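
The change above only moves the mean error onto its own line in the failure message. For context, a hedged sketch of how a test might call this helper when comparing an original model against its sharded counterpart (the model and input names are placeholders, and the tolerances are illustrative):

org_output = org_model(**inputs)
shard_output = sharded_model(**inputs)

# recursively compares HuggingFace-style outputs; on a tensor mismatch the
# assertion message now reports the mean absolute error on its own line
assert_hf_output_close(org_output, shard_output, atol=1e-5, rtol=1e-5)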