mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 19:13:01 +00:00
[feat] Sync shard model (#6289)
* [feat] support hybrid parallel model sync * update consumer and producer * update files * update producer * remove print * update --------- Co-authored-by: duanjunwen <935724073@qq.com> Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com> Co-authored-by: Tong Li <tong.li35271158@gmail.com>
This commit is contained in:
@@ -29,6 +29,7 @@ class BaseProducer:
|
||||
tokenizer_config: Optional[Dict[str, Any]] = None,
|
||||
microbatch_size: int = 1,
|
||||
backend: str = "transformers",
|
||||
consumer_plugin_config: Dict[str, Any] = None,
|
||||
):
|
||||
self.producer_idx = producer_idx
|
||||
self.num_producers = num_producers
|
||||
@@ -78,9 +79,15 @@ class BaseProducer:
|
||||
else:
|
||||
raise ValueError(f"Unexpected backend {backend}")
|
||||
|
||||
self.consumer_pp_size = consumer_plugin_config["pp_size"] # consumer pp size
|
||||
|
||||
def setup(self) -> None:
|
||||
cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_data_{self.producer_idx}")
|
||||
cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name="sync_model")
|
||||
if self.consumer_pp_size > 1:
|
||||
for i in range(self.consumer_pp_size):
|
||||
cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name=f"sync_model_{i}")
|
||||
else:
|
||||
cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name="sync_model")
|
||||
|
||||
def rollout(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
@@ -125,15 +132,25 @@ class BaseProducer:
|
||||
):
|
||||
self.model.llm.sleep() # revict KV_cache to avoid OOM
|
||||
# don't sync model for last iteration
|
||||
print(
|
||||
f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}"
|
||||
)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
state_dict = ray_broadcast_tensor_dict(
|
||||
None, self.num_producers, device=self.device, group_name="sync_model"
|
||||
)
|
||||
self.load_state_dict(state_dict)
|
||||
if self.consumer_pp_size > 1:
|
||||
for pp_idx in range(self.consumer_pp_size):
|
||||
print(
|
||||
f"[P{self.producer_idx}] Sync model PP stage {pp_idx} episode {episode} step {(i + 1) // self.num_microbatches - 1}"
|
||||
)
|
||||
state_dict = ray_broadcast_tensor_dict(
|
||||
None, self.num_producers, device=self.device, group_name=f"sync_model_{pp_idx}"
|
||||
)
|
||||
self.load_state_dict(state_dict)
|
||||
else:
|
||||
print(
|
||||
f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}"
|
||||
)
|
||||
state_dict = ray_broadcast_tensor_dict(
|
||||
None, self.num_producers, device=self.device, group_name="sync_model"
|
||||
)
|
||||
self.load_state_dict(state_dict)
|
||||
del state_dict
|
||||
torch.cuda.empty_cache()
|
||||
if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
|
||||
@@ -170,6 +187,7 @@ class SimpleProducer(BaseProducer):
|
||||
microbatch_size=1,
|
||||
backend="transformers",
|
||||
num_generations: int = 8,
|
||||
consumer_plugin_config=None,
|
||||
):
|
||||
super().__init__(
|
||||
producer_idx,
|
||||
@@ -184,6 +202,7 @@ class SimpleProducer(BaseProducer):
|
||||
tokenizer_config,
|
||||
microbatch_size,
|
||||
backend,
|
||||
consumer_plugin_config,
|
||||
)
|
||||
self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
|
||||
|
||||
|
Reference in New Issue
Block a user