Mirror of https://github.com/hpcaitech/ColossalAI.git
[moe] support optimizer checkpoint (#5015)
* Refactor MoE Manager setup method
* unshard optim ckpt
* optim io
* update transformer version
* update requirements
* update ckpt
* update ckpt
* update ckpt
* fix engine
* fix engine
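For context on what "unshard optim ckpt" and "optim io" refer to, below is a minimal usage sketch of saving and reloading an optimizer checkpoint through ColossalAI's Booster checkpoint I/O. The toy model, file path, and the plain `Booster()` without an MoE-aware plugin are illustrative assumptions, not taken from this commit.

```python
# Minimal sketch (assumed usage, not from this commit): saving an optimizer
# checkpoint via ColossalAI's Booster API. Run under torchrun; a real MoE
# setup would pass an MoE-aware plugin to Booster instead of the default.
import torch
import colossalai
from colossalai.booster import Booster

colossalai.launch_from_torch(config={})

model = torch.nn.Linear(8, 8).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

booster = Booster()
model, optimizer, *_ = booster.boost(model, optimizer)

# shard=False requests a single, unsharded optimizer checkpoint file.
booster.save_optimizer(optimizer, "optimizer.pt", shard=False)

# Restore the saved state into the boosted optimizer.
booster.load_optimizer(optimizer, "optimizer.pt")
```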
@@ -79,13 +79,15 @@ class TPInferEngine:
             self.multi_query_group_num = model.config.num_attention_heads
             # default to attention_heads
-            self.multi_query_attention = model.config.multi_query_attention
+            if hasattr(model.config, "multi_query_attention"):
+                self.multi_query_attention = getattr(model.config, "multi_query_attention")

             if hasattr(model.config, "multi_query_group_num"):
-                self.multi_query_group_num = model.config.multi_query_group_num
+                self.multi_query_group_num = getattr(model.config, "multi_query_group_num")

             if hasattr(model.config, "num_key_value_heads"):
-                self.multi_query_group_num = model.config.num_key_value_heads
+                self.multi_query_group_num = getattr(model.config, "num_key_value_heads")

         self.tp_size = -1  # to be set with given shard config in self.prepare_shard_config
         self.cache_manager = None
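The hunk above lets the engine tolerate model configs that name their key/value-head count differently (ChatGLM-style `multi_query_group_num` vs. LLaMA-style `num_key_value_heads`), defaulting to the full attention-head count. Below is a standalone sketch of that fallback pattern, using an illustrative dummy config rather than a real transformers config object.

```python
# Minimal sketch of the attribute-fallback pattern used in the hunk above.
# `DummyConfig` is illustrative; the engine reads a transformers config.
from dataclasses import dataclass


@dataclass
class DummyConfig:
    num_attention_heads: int = 32
    num_key_value_heads: int = 8  # LLaMA-2-style grouped-query attention


def resolve_kv_group_num(config) -> int:
    # Default to the full attention-head count (plain multi-head attention).
    group_num = config.num_attention_heads
    # ChatGLM-style configs expose `multi_query_group_num`.
    if hasattr(config, "multi_query_group_num"):
        group_num = getattr(config, "multi_query_group_num")
    # LLaMA-style configs expose `num_key_value_heads`.
    if hasattr(config, "num_key_value_heads"):
        group_num = getattr(config, "num_key_value_heads")
    return group_num


print(resolve_kv_group_num(DummyConfig()))  # -> 8
```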
@@ -108,7 +110,7 @@ class TPInferEngine:
         assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}"
         self.head_num //= self.tp_size  # update sharded number of heads

-        if self.multi_query_attention:
+        if hasattr(self, "multi_query_attention"):
             # NOTE the logic of MQA tensor parallelism should be specified.
             assert (
                 self.multi_query_group_num % self.tp_size == 0
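The second hunk only enters the MQA branch when the attribute actually exists, and then asserts that the key/value-group count divides evenly across tensor-parallel ranks. Below is a toy sketch of that divisibility check and the resulting per-rank head counts; the function name and numbers are illustrative, not part of TPInferEngine's API.

```python
# Toy sketch of the tensor-parallel head sharding checked in the hunk above.
# All names and values are illustrative.

def shard_heads(head_num: int, kv_group_num: int, tp_size: int) -> tuple[int, int]:
    # Every rank must receive the same whole number of query heads...
    assert head_num % tp_size == 0, f"Cannot shard {head_num} heads with tp size {tp_size}"
    # ...and, under multi-query/grouped-query attention, the same whole
    # number of KV-head groups, otherwise KV caches would be uneven.
    assert kv_group_num % tp_size == 0, f"Cannot shard {kv_group_num} KV groups with tp size {tp_size}"
    return head_num // tp_size, kv_group_num // tp_size


print(shard_heads(head_num=32, kv_group_num=8, tp_size=4))  # -> (8, 2)
```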