[shardformer] support inplace sharding (#4251)

* [shardformer] embedding support inplace sharding

* [shardformer] linear support inplace sharding

* [shardformer] layernorm support inplace sharding

* [shardformer] qkv support inplace sharding

* [test] update shardformer layer test

* [shardformer] fix shared param sharding

* [shardformer] fix bert policy

* [shardformer] fix bloom policy

* [shardformer] fix llama policy

* [shardformer] fix opt policy

* [shardformer] fix t5 policy

* [shardformer] fix fused qkv linear

* [shardformer] fix bugs

* force sync

* [test] fix bugs

* [test] fix transformer version
Author: Hongxin Liu
Date:   2023-07-20 10:39:06 +08:00
Parent: 2a2eacfaf1
Commit: d921ce8391
26 changed files with 371 additions and 340 deletions
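As context for the diff below, a minimal sketch of what "inplace sharding" means here (generic PyTorch, not ColossalAI's actual API): rather than constructing a new parallel layer and swapping it into the model, the existing layer's parameters are narrowed to their local shard in place, so the module object and any references to its parameters are preserved.

```python
# Illustrative sketch only; shard_linear_col_ is a hypothetical helper, not the shardformer API.
import torch.nn as nn

def shard_linear_col_(linear: nn.Linear, rank: int, world_size: int) -> nn.Linear:
    """Column-parallel sharding done in place: keep the module object, shrink its tensors."""
    linear.weight.data = linear.weight.data.chunk(world_size, dim=0)[rank].clone()
    if linear.bias is not None:
        linear.bias.data = linear.bias.data.chunk(world_size, dim=0)[rank].clone()
    linear.out_features //= world_size
    return linear  # same nn.Linear instance, now holding only this rank's partition

layer = nn.Linear(1024, 4096)
sharded = shard_linear_col_(layer, rank=0, world_size=4)
assert sharded is layer and tuple(layer.weight.shape) == (1024, 1024)
```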


@@ -8,7 +8,6 @@ from colossalai.shardformer.layer import (
 )
 
 from colossalai.shardformer.policies.base_policy import ModulePolicyDescription
-from .._utils import getattr_, setattr_
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
 __all__ = ["T5ModelPolicy", "T5ForConditionalGenerationPolicy", "T5EncoderPolicy"]
@@ -53,7 +52,7 @@ class T5BasePolicy(Policy):
             ),
             SubModuleReplacementDescription(
                 suffix="embed_tokens",
-                target_module=Embedding1D,
+                target_module=VocabParallelEmbedding1D,
             )
         ])
         policy[T5LayerSelfAttention] = ModulePolicyDescription(sub_module_replacement=[
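The hunk above switches the sharded module for T5's `embed_tokens` from Embedding1D to VocabParallelEmbedding1D. As a rough sketch with plain tensor slicing (a simplification, not the shardformer implementation): Embedding1D splits the weight along the hidden dimension, while vocab-parallel sharding splits along the vocabulary dimension, which lines up with how a tied lm_head is partitioned under tensor parallelism.

```python
# Illustrative only: how a [vocab_size, hidden] embedding weight is partitioned per rank.
import torch

vocab_size, hidden, world_size, rank = 32000, 512, 4, 1
weight = torch.randn(vocab_size, hidden)

# Embedding1D-style: every rank keeps all vocabulary rows but only a slice of the hidden dim.
hidden_shard = weight.chunk(world_size, dim=1)[rank]   # shape [32000, 128]

# VocabParallelEmbedding1D-style: every rank keeps a contiguous block of vocabulary rows.
vocab_shard = weight.chunk(world_size, dim=0)[rank]    # shape [8000, 512]

# A vocab-parallel lookup masks token ids outside this rank's block and all-reduces the
# partial results, so the full embedding output is reconstructed across ranks.
```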
@@ -165,12 +164,6 @@ class T5BasePolicy(Policy):
         return policy
 
     def postprocess(self):
-        if self.shard_config.enable_tensor_parallelism:
-            binding_map = [["shared", "encoder.embed_tokens"], ["shared", "decoder.embed_tokens"]]
-            for k, v in binding_map:
-                mod = getattr_(self.model, k)
-                setattr_(self.model, v, mod)
         return self.model
@@ -211,18 +204,6 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):
                                                         target_key=T5ForConditionalGeneration)
         return policy
 
-    def postprocess(self):
-        super().postprocess()
-        if self.shard_config.enable_tensor_parallelism:
-            binding_map = {"shared": "lm_head"}
-            for k, v in binding_map.items():
-                src_mod = getattr_(self.model, k)
-                dst_mod = getattr_(self.model, v)
-                dst_mod.weight = src_mod.weight
-        return self.model
 
 class T5EncoderPolicy(T5BasePolicy):
@@ -239,14 +220,3 @@ class T5EncoderPolicy(T5BasePolicy):
                                                         policy=base_policy,
                                                         target_key=T5EncoderModel)
         return base_policy
-
-    def postprocess(self):
-        if self.shard_config.enable_tensor_parallelism:
-            binding_map = [
-                ["shared", "encoder.embed_tokens"],
-            ]
-            for k, v in binding_map:
-                mod = getattr_(self.model, k)
-                setattr_(self.model, v, mod)
-        return self.model
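The deleted `postprocess` hooks above existed to re-tie shared parameters after module replacement: swapping a newly built parallel embedding in for `shared` broke the tie to `encoder.embed_tokens`, `decoder.embed_tokens`, and (when word embeddings are tied) `lm_head`, so the binding maps restored those references by hand. With inplace sharding the original Parameter object survives, the tie never breaks, and `postprocess` can simply return the model. A hedged before/after sketch with hypothetical helper names, not the actual policy code:

```python
# Hypothetical illustration of why the binding maps are no longer needed.
import torch.nn as nn

def retie_after_module_swap(model: nn.Module) -> nn.Module:
    # Old style: a brand-new parallel embedding replaced `shared`, so postprocess had to
    # point the tied modules back at it by hand (what the deleted binding maps did).
    model.encoder.embed_tokens = model.shared
    model.decoder.embed_tokens = model.shared
    model.lm_head.weight = model.shared.weight
    return model

def shard_shared_inplace(model: nn.Module, rank: int, world_size: int) -> nn.Module:
    # New style: the shared weight is sharded in place (along the vocab dimension), and
    # encoder.embed_tokens, decoder.embed_tokens, and lm_head still reference the very
    # same Parameter, so no re-tying is required afterwards.
    w = model.shared.weight
    w.data = w.data.chunk(world_size, dim=0)[rank].clone()
    return model
```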