[shardformer] support lazy init (#4202)

* [shardformer] support lazy init

* [shardformer] linear support lazy init

* [shardformer] embedding support lazy init

* [shardformer] norm support lazy init

* [shardformer] fused linear support lazy init

* [test] update shardformer test layer

* [test] shardformer with lazy init fit ddp

* [lazy] hotfix deepcopy of param

* [shardformer] fix bert policy and update test

* [shardformer] fix bloom policy and update test

* [shardformer] fix opt policy and update test

* [shardformer] fix t5 policy and update test

* [shardformer] fix gpt2 policy and update test

* [shardformer] fix llama policy and update test
Author: Hongxin Liu
Date:   2023-07-10 10:48:53 +08:00
Parent: f3bcc292c8
Commit: 890774b2fb

25 changed files with 263 additions and 157 deletions
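The commit message above covers the feature end to end: modules are built under a lazy-init context so no real parameter storage is allocated, and the shardformer policies can then shard or materialize them later. Below is a minimal sketch of the lazy-init side only; it assumes the colossalai.lazy.LazyInitContext API (a context manager plus a materialize() helper, as exercised by the shardformer tests) and uses a plain torch.nn model for illustration rather than one of the policies touched by this commit.

# Sketch only: assumes colossalai.lazy.LazyInitContext offers a context
# manager and a materialize() helper.
import torch.nn as nn
from colossalai.lazy import LazyInitContext

# Construct the model without allocating real parameter storage.
with LazyInitContext():
    model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))

# A shardformer policy could consume the lazy model directly; for a plain
# (non-sharded) run, materialize() turns the lazy tensors into real ones.
model = LazyInitContext.materialize(model)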

(diff shown for one of the 25 changed files: the T5 shardformer policy)

@@ -24,11 +24,12 @@ class T5BasePolicy(Policy):
r"""
Reshape the Embedding layer to make the embedding dimension divisible by world_size
"""
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
if self.shard_config.enable_tensor_parallelism:
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
return self.model
def module_policy(self):
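The hunk above only changes when the vocabulary resize runs (it is now gated on enable_tensor_parallelism); the padding rule itself is unchanged. It can be checked in isolation with hypothetical numbers, with no ColossalAI imports involved:

# Round vocab_size up to the next multiple of world_size, mirroring the
# arithmetic in preprocess() above; the concrete values are hypothetical.
vocab_size, world_size = 32100, 8          # T5-like vocab, 8-way tensor parallelism
if vocab_size % world_size != 0:
    vocab_size = vocab_size + world_size - vocab_size % world_size
assert vocab_size % world_size == 0        # 32104 for the values above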
@@ -164,11 +165,12 @@ class T5BasePolicy(Policy):
         return policy

     def postprocess(self):
-        binding_map = [["shared", "encoder.embed_tokens"], ["shared", "decoder.embed_tokens"]]
+        if self.shard_config.enable_tensor_parallelism:
+            binding_map = [["shared", "encoder.embed_tokens"], ["shared", "decoder.embed_tokens"]]

-        for k, v in binding_map:
-            mod = getattr_(self.model, k)
-            setattr_(self.model, v, mod)
+            for k, v in binding_map:
+                mod = getattr_(self.model, k)
+                setattr_(self.model, v, mod)
         return self.model
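getattr_ and setattr_ resolve dotted attribute paths, so this postprocess() simply rebinds encoder.embed_tokens and decoder.embed_tokens to the shared embedding module. The snippet below is an illustrative stand-in for that pattern, not ColossalAI's actual helpers:

from functools import reduce
import torch.nn as nn

# Illustrative stand-ins for the dotted-path helpers used above
# (not ColossalAI's real getattr_/setattr_ implementation).
def getattr_(obj, path):
    return reduce(getattr, path.split("."), obj)

def setattr_(obj, path, value):
    parent, _, name = path.rpartition(".")
    setattr(getattr_(obj, parent) if parent else obj, name, value)

# Rebind a nested embedding to the shared one, as the binding_map above does.
class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.shared = nn.Embedding(10, 4)
        self.encoder = nn.Module()
        self.encoder.embed_tokens = nn.Embedding(10, 4)

m = Toy()
setattr_(m, "encoder.embed_tokens", getattr_(m, "shared"))
assert m.encoder.embed_tokens is m.shared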
@@ -211,13 +213,13 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):
     def postprocess(self):
         super().postprocess()
-        binding_map = {"shared": "lm_head"}
+        if self.shard_config.enable_tensor_parallelism:
+            binding_map = {"shared": "lm_head"}

-        for k, v in binding_map.items():
-            src_mod = getattr_(self.model, k)
-            dst_mod = getattr_(self.model, v)
-            dst_mod.weight = src_mod.weight
+            for k, v in binding_map.items():
+                src_mod = getattr_(self.model, k)
+                dst_mod = getattr_(self.model, v)
+                dst_mod.weight = src_mod.weight
         return self.model
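For the generation head the binding is weight-level rather than module-level: lm_head stays a separate Linear but reuses the shared embedding's parameter. In plain PyTorch terms the effect of dst_mod.weight = src_mod.weight is the usual weight tying, sketched here with hypothetical sizes:

import torch.nn as nn

# Weight tying equivalent to dst_mod.weight = src_mod.weight above;
# the sizes are hypothetical (roughly T5-small shaped).
vocab_size, d_model = 32128, 512
shared = nn.Embedding(vocab_size, d_model)
lm_head = nn.Linear(d_model, vocab_size, bias=False)
lm_head.weight = shared.weight             # both modules now hold the same Parameter
assert lm_head.weight is shared.weight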
@@ -239,11 +241,12 @@ class T5EncoderPolicy(T5BasePolicy):
         return base_policy

     def postprocess(self):
-        binding_map = [
-            ["shared", "encoder.embed_tokens"],
-        ]
+        if self.shard_config.enable_tensor_parallelism:
+            binding_map = [
+                ["shared", "encoder.embed_tokens"],
+            ]

-        for k, v in binding_map:
-            mod = getattr_(self.model, k)
-            setattr_(self.model, v, mod)
+            for k, v in binding_map:
+                mod = getattr_(self.model, k)
+                setattr_(self.model, v, mod)
         return self.model