Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-17 23:18:36 +00:00
[shardformer] adapted T5 and LLaMa test to use kit (#4049)
* [shardformer] adapted T5 and LLaMa test to use kit
* polish code
@@ -65,13 +65,14 @@ class Embedding1D(ParallelModule):
                  dtype: torch.dtype = None,
                  device: torch.device = None,
                  process_group: ProcessGroup = None,
+                 gather_output: bool = True,
                  weight_initializer: Callable = init.normal_(),
                  *args,
                  **kwargs):
         super().__init__()

         self.num_embeddings = num_embeddings
-        self.embed_dim = embedding_dim
+        self.embedding_dim = embedding_dim
         self.process_group = process_group
         self.num_partitions = dist.get_world_size(process_group)
         self.embed_dim_per_partition = divide(embedding_dim, self.num_partitions)
@@ -79,7 +80,7 @@ class Embedding1D(ParallelModule):
         self.padding_idx = padding_idx
         self.embed_args = args
         self.embed_kwargs = kwargs
-        # self.gather_output = gather_output
+        self.gather_output = gather_output

         if device is None:
             device = get_current_device()
@@ -95,7 +96,9 @@ class Embedding1D(ParallelModule):

     @staticmethod
     def from_native_module(module: nn.Embedding,
-                           process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "Embedding1D":
+                           process_group: Union[ProcessGroup, List[ProcessGroup]] = None,
+                           *args,
+                           **kwargs) -> "Embedding1D":
         r"""
         Build a 1D parallelized Embedding from a native nn.Embedding module.
         """
@@ -123,7 +126,9 @@ class Embedding1D(ParallelModule):
                                 max_norm=max_norm,
                                 norm_type=norm_type,
                                 scale_grad_by_freq=scale_grad_by_freq,
-                                sparse=sparse)
+                                sparse=sparse,
+                                *args,
+                                **kwargs)

         # copy the weight
         with torch.no_grad():
@@ -133,7 +138,7 @@ class Embedding1D(ParallelModule):
         return embedding

     def reset_parameters(self, weight_initializer) -> None:
-        fan_in, fan_out = self.num_embeddings, self.embed_dim
+        fan_in, fan_out = self.num_embeddings, self.embedding_dim
         weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
         self._fill_padding_idx_with_zero()
@@ -144,6 +149,9 @@ class Embedding1D(ParallelModule):

     def forward(self, input_: Tensor) -> Tensor:
         output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
-        output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
-
-        return output
+        if self.gather_output:
+            output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
+            return output
+        else:
+            return output_parallel
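Taken together, the Embedding1D hunks add a gather_output switch and let from_native_module forward extra arguments into the constructor. Below is a minimal usage sketch that is not part of this commit; it assumes torch.distributed and the tensor-parallel process group have already been initialized (for example via colossalai.launch), and the sizes are placeholders.

import torch
import torch.nn as nn

from colossalai.shardformer.layer import Embedding1D

# Placeholder sizes; embedding_dim should be divisible by the tensor-parallel world size.
native_embedding = nn.Embedding(num_embeddings=1000, embedding_dim=128)

# Extra keyword arguments such as gather_output are now forwarded by
# from_native_module into the Embedding1D constructor.
sharded_embedding = Embedding1D.from_native_module(
    native_embedding,
    process_group=None,    # None falls back to the default process group
    gather_output=False,   # new flag: keep the column-sharded output on each rank
)

tokens = torch.randint(0, 1000, (4, 16))
out = sharded_embedding(tokens)
# gather_output=True  -> last dim is the full embedding_dim (128)
# gather_output=False -> last dim is embedding_dim // tensor-parallel world size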
@@ -4,7 +4,7 @@ import torch.nn as nn
 from transformers import LlamaForCausalLM, LlamaForSequenceClassification
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel

-from colossalai.shardformer.layer.layers import Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
+from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D

 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@@ -11,8 +11,7 @@ from transformers.models.t5.modeling_t5 import (
     T5Stack,
 )

-from colossalai.shardformer.layer.dropout import Dropout1D
-from colossalai.shardformer.layer.layers import Embedding1D, Linear1D_Col, Linear1D_Row
+from colossalai.shardformer.layer import Dropout1D, Embedding1D, Linear1D_Col, Linear1D_Row

 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@@ -185,7 +185,14 @@ class ModelSharder(object):
             if description.ignore_if_not_exist and native_sub_module is None:
                 continue

-            replace_layer = target_module.from_native_module(native_sub_module, self.pg_manager.pg_store['tp1d'],
-                                                             **kwargs)
+            try:
+                replace_layer = target_module.from_native_module(native_sub_module, self.pg_manager.pg_store['tp1d'],
+                                                                 **kwargs)
+            except Exception as e:
+                raise RuntimeError(
+                    f"Failed to replace {suffix} of type {native_sub_module.__class__.__qualname__}"
+                    f" with {target_module.__qualname__} with the exception: {e}. "
+                    "Please check your model configuration or sharding policy, you can set up an issue for us to help you as well."
+                )

             setattr_(org_layer, suffix, replace_layer)
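With this change, a failing from_native_module call surfaces as a RuntimeError that names the submodule and the target parallel layer. The snippet below is a self-contained, hypothetical illustration of that wrap-and-raise pattern; the helper and class names are invented for the example and do not appear in the commit.

import torch.nn as nn


class BrokenParallelLinear:
    """Hypothetical target module whose converter always fails."""

    @staticmethod
    def from_native_module(module: nn.Module, process_group=None, **kwargs):
        raise ValueError("weight shape not divisible by the tensor-parallel size")


def replace_submodule(target_module, native_sub_module, suffix, process_group=None, **kwargs):
    # Same wrap-and-raise pattern as the ModelSharder hunk above.
    try:
        return target_module.from_native_module(native_sub_module, process_group, **kwargs)
    except Exception as e:
        raise RuntimeError(
            f"Failed to replace {suffix} of type {native_sub_module.__class__.__qualname__}"
            f" with {target_module.__qualname__} with the exception: {e}. "
            "Please check your model configuration or sharding policy, "
            "you can set up an issue for us to help you as well.")


try:
    replace_submodule(BrokenParallelLinear, nn.Linear(16, 16), suffix="attention.dense")
except RuntimeError as err:
    print(err)    # the user-facing message produced by the new error handling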
@@ -98,6 +98,6 @@ def assert_hf_output_close(out1: Any,
             raise AssertionError(f"{track_name}: shape mismatch: {out1.shape} vs {out2.shape}")
         assert torch.allclose(
             out1, out2, atol=atol, rtol=rtol
-        ), f"{track_name}: tensor value mismatch\nvalue 1: {out1}\nvalue 2: {out2}, mean error: {torch.abs(out1 - out2).mean()}"
+        ), f"{track_name}: tensor value mismatch\nvalue 1: {out1}\nvalue 2: {out2}, \nmean error: {torch.abs(out1 - out2).mean()}"
     else:
         assert out1 == out2, f"{track_name}: value mismatch.\nout1: {out1}\nout2: {out2}"
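For context, the assertion whose message gains a line break boils down to an elementwise torch.allclose check plus a mean-absolute-error report. A small sketch of that check on plain tensors follows; the values and tolerances are placeholders, and the real helper also walks HuggingFace output objects, which is omitted here.

import torch

track_name = "logits"      # placeholder label for the tracked output field
atol, rtol = 1e-5, 1e-5    # placeholder tolerances

out1 = torch.tensor([1.0, 2.0, 3.0])
out2 = out1 + 1e-6         # small perturbation that stays within tolerance

assert out1.shape == out2.shape, f"{track_name}: shape mismatch: {out1.shape} vs {out2.shape}"
assert torch.allclose(
    out1, out2, atol=atol, rtol=rtol
), f"{track_name}: tensor value mismatch\nvalue 1: {out1}\nvalue 2: {out2}, \nmean error: {torch.abs(out1 - out2).mean()}"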