[moe] support optimizer checkpoint (#5015)

* Refactor MoE Manager setup method

* unshard optim ckpt

* optim io

* update transformer version

* update requirements

* update ckpt

* update ckpt

* update ckpt

* fix engine

* fix engine
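
For readers who want to see where the new optimizer checkpoint support plugs in, the sketch below is a minimal, illustrative round trip through ColossalAI's generic Booster checkpoint API. The plugin choice, model, and paths are placeholders and are not taken from this commit; a MoE run would substitute its own plugin/checkpoint IO.

# Minimal, illustrative sketch of a sharded optimizer checkpoint round trip with
# ColossalAI's Booster API. The plugin, model, and paths are placeholders and are
# NOT taken from this commit; a MoE run would substitute its own plugin.
import colossalai
import torch
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

colossalai.launch_from_torch(config={})

model = torch.nn.Linear(8, 8).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

booster = Booster(plugin=LowLevelZeroPlugin())
model, optimizer, *_ = booster.boost(model, optimizer)

# ... training steps ...

# Save the optimizer state as a sharded checkpoint directory.
booster.save_optimizer(optimizer, "ckpt/optimizer", shard=True)

# After rebuilding and boosting model/optimizer the same way, restore the state.
booster.load_optimizer(optimizer, "ckpt/optimizer")

Run the sketch under torchrun so that launch_from_torch can pick up the distributed environment.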
Author: Xuanlei Zhao
Date: 2023-11-08 23:07:03 +08:00 (committed by GitHub)
Commit: f71e63b0f3
Parent: 67f5331754
20 changed files with 738 additions and 150 deletions


@@ -79,13 +79,15 @@ class TPInferEngine:
         self.multi_query_group_num = model.config.num_attention_heads
         # default to attention_heads
-        self.multi_query_attention = model.config.multi_query_attention
+        if hasattr(model.config, "multi_query_attention"):
+            self.multi_query_attention = getattr(model.config, "multi_query_attention")
         if hasattr(model.config, "multi_query_group_num"):
-            self.multi_query_group_num = model.config.multi_query_group_num
+            self.multi_query_group_num = getattr(model.config, "multi_query_group_num")
         if hasattr(model.config, "num_key_value_heads"):
-            self.multi_query_group_num = model.config.num_key_value_heads
+            self.multi_query_group_num = getattr(model.config, "num_key_value_heads")
         self.tp_size = -1  # to be set with given shard config in self.prepare_shard_config
         self.cache_manager = None
@@ -108,7 +110,7 @@ class TPInferEngine:
         assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}"
         self.head_num //= self.tp_size  # update sharded number of heads
-        if self.multi_query_attention:
+        if hasattr(self, "multi_query_attention"):
             # NOTE the logic of MQA tensor parallelism should be specified.
             assert (
                 self.multi_query_group_num % self.tp_size == 0
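
Both hunks apply the same fix: optional attributes are read from the HuggingFace model config only when they exist, and the later check in the second hunk uses hasattr on self because multi_query_attention is now set only conditionally. A standalone illustration of that fallback pattern follows; the SimpleNamespace config and its values are made up for the example.

# Standalone illustration of the optional-config-attribute fallback used above.
# SimpleNamespace stands in for model.config; the attribute values are made up.
from types import SimpleNamespace

config = SimpleNamespace(num_attention_heads=32, num_key_value_heads=8)

# Default to the attention-head count, then prefer the more specific
# attributes only when the config actually defines them.
multi_query_group_num = config.num_attention_heads
if hasattr(config, "multi_query_group_num"):
    multi_query_group_num = getattr(config, "multi_query_group_num")
if hasattr(config, "num_key_value_heads"):
    multi_query_group_num = getattr(config, "num_key_value_heads")

assert multi_query_group_num == 8  # num_key_value_heads wins when present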