[shardformer] support inplace sharding (#4251)

* [shardformer] embedding support inplace sharding

* [shardformer] linear support inplace sharding

* [shardformer] layernorm support inplace sharding

* [shardformer] qkv support inplace sharding

* [test] update shardformer layer test

* [shardformer] fix shared param sharding

* [shardformer] fix bert policy

* [shardformer] fix bloom policy

* [shardformer] fix llama policy

* [shardformer] fix opt policy

* [shardformer] fix t5 policy

* [shardformer] fix fused qkv linear

* [shardformer] fix bugs

* force sync

* [test] fix bugs

* [test] fix transformer version
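
The sub-commits above share one idea: instead of building a replacement module and copying weights over, the parallel layers now write the shard back into the existing parameter, so any module that shares that parameter stays tied. The following is a minimal, hypothetical sketch of that behaviour in plain PyTorch; the helper name, shapes and rank values are illustrative and are not ShardFormer's API.

import torch.nn as nn

def shard_linear_weight_inplace(linear: nn.Linear, rank: int, world_size: int) -> None:
    # Column-parallel style split along dim 0; the slice is written back into the
    # existing Parameter's .data, so the Parameter object itself is never replaced.
    out_features = linear.weight.shape[0]
    assert out_features % world_size == 0, "out_features must divide evenly"
    chunk = out_features // world_size
    linear.weight.data = linear.weight.data[rank * chunk:(rank + 1) * chunk].clone()
    if linear.bias is not None:
        linear.bias.data = linear.bias.data[rank * chunk:(rank + 1) * chunk].clone()

# Because the Parameter object is reused, a tied module (e.g. an lm_head sharing the
# embedding weight) sees the sharded tensor automatically -- no manual re-binding needed.
embed = nn.Embedding(32, 16)
lm_head = nn.Linear(16, 32, bias=False)
lm_head.weight = embed.weight                      # weight tying
shard_linear_weight_inplace(lm_head, rank=0, world_size=2)
assert lm_head.weight is embed.weight              # still the same Parameter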
Author: Hongxin Liu
Date: 2023-07-20 10:39:06 +08:00
Parent: 2a2eacfaf1
Commit: d921ce8391
26 changed files with 371 additions and 340 deletions


@@ -1,6 +1,5 @@
 from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
-from .._utils import getattr_, setattr_
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
 __all__ = [
@@ -116,19 +115,6 @@ class OPTForCausalLMPolicy(OPTPolicy):
                                                         target_key=OPTForCausalLM)
         return policy
 
-    def postprocess(self):
-        if self.shard_config.enable_tensor_parallelism:
-            binding_map = {
-                'model.decoder.embed_tokens': 'lm_head',
-            }
-            for k, v in binding_map.items():
-                src_mod = getattr_(self.model, k)
-                dst_mod = getattr_(self.model, v)
-                dst_mod.weight = src_mod.weight
-        return self.model
-
 
 class OPTForSequenceClassificationPolicy(OPTPolicy):
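
For context on the deleted postprocess above: before this change, sharding replaced modules outright, which silently untied shared weights such as model.decoder.embed_tokens and lm_head, so each policy had to re-bind them by hand via getattr_. The sketch below is an illustrative reconstruction of that situation in plain PyTorch, not the actual policy code.

import torch.nn as nn

model_embed = nn.Embedding(32, 16)
lm_head = nn.Linear(16, 32, bias=False)
lm_head.weight = model_embed.weight            # tied at construction time

# Old, replace-style sharding: build a fresh module and copy its weight.
new_embed = nn.Embedding(32, 16)
new_embed.weight.data.copy_(model_embed.weight.data)
model_embed = new_embed                        # the tie with lm_head is now broken
assert lm_head.weight is not model_embed.weight

# The removed postprocess restored the tie manually,
# analogous to `dst_mod.weight = src_mod.weight` in the diff:
lm_head.weight = model_embed.weight
assert lm_head.weight is model_embed.weight

# With inplace sharding the original Parameter is never replaced, so this manual
# step is unnecessary and the postprocess override can simply be deleted.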