From fcb60de3f859252af3a5ca19b6154a2af2f24960 Mon Sep 17 00:00:00 2001
From: YeAnbang
Date: Tue, 29 Apr 2025 16:20:24 +0800
Subject: [PATCH] Revert "fix pp state dict incomplete issue"

This reverts commit 6c1b3b694f4898c04575c891d5ccc205a82a3c40.
---
 .../coati/distributed/consumer.py             | 23 +++++------
 .../coati/distributed/grpo_consumer.py        | 17 --------
 .../ColossalChat/coati/distributed/launch.py  | 18 ++-------
 .../coati/distributed/producer.py             | 40 ++++----------------
 4 files changed, 20 insertions(+), 78 deletions(-)

diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index 9c3df97fd..40fa67e1a 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -59,8 +59,13 @@ class BaseConsumer:
         self.lr_scheduler = None
 
     def setup(self) -> None:
+        for i in range(self.num_producers):
+            cc.init_collective_group(self.world_size + 1, self.rank + 1, group_name=f"sync_data_{i}")
+        if self.rank == 0:
+            cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")
         launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)
-        plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2)  # default config
+
+        plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2)
         if (
             self.plugin_config.get("pp_size", 1) > 1
             and "num_microbatches" not in self.plugin_config
@@ -74,16 +79,10 @@ class BaseConsumer:
         self.tp_rank = dist.get_rank(self.plugin.tp_group)
 
         self.dp_size = dist.get_world_size(self.plugin.dp_group)
-        self.tp_size = dist.get_world_size(self.plugin.tp_group)
 
         self.buffer = []
 
         self.recv_cnt = 0
-        for i in range(self.num_producers):
-            cc.init_collective_group(self.world_size + 1, self.rank + 1, group_name=f"sync_data_{i}")
-        if self.dp_rank == 0 and self.tp_rank == 0:
-            group_name = f"sync_model_pp_stage_{self.plugin.stage_manager.stage}"
-            cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name=group_name)
 
     def state_dict(self) -> Dict[str, torch.Tensor]:
         raise NotImplementedError
@@ -141,13 +140,12 @@ class BaseConsumer:
                             print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")
 
                 if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
+                    print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
                     torch.cuda.empty_cache()
                     state_dict = self.state_dict()
-                    if self.dp_rank == 0 and self.tp_rank == 0:
-                        group_name = f"sync_model_pp_stage_{self.plugin.stage_manager.stage}"
-                        print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
+                    if self.rank == 0:
                         ray_broadcast_tensor_dict(
-                            state_dict, src=self.num_producers, device=self.device, group_name=group_name
+                            state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
                         )
                     del state_dict
                     torch.cuda.empty_cache()
@@ -193,9 +191,6 @@ class SimpleConsumer(BaseConsumer):
         self.optimizer = HybridAdam(self.model.parameters(), lr=1e-3)
         self.accum_loss = torch.zeros(1, device=self.device)
 
-    def get_plugin(self):
-        return self.plugin
-
     def setup(self):
         super().setup()
         self.model, self.optimizer, *_ = self.booster.boost(self.model, self.optimizer)
diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index eb8e24c46..877ff98ec 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -75,7 +75,6 @@ class GRPOConsumer(BaseConsumer):
         )
         path = model_config.pop("path")
         self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
-        self.orig_state_dict_key = [k for k in self.policy_model.state_dict()]
         self.policy_model.train()
         self.policy_model.gradient_checkpointing_enable()
         self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6))
@@ -151,22 +150,6 @@ class GRPOConsumer(BaseConsumer):
             eta_min=0.1 * grpo_config.get("lr", 1e-6),
         )
 
-    def get_device_mesh_mapping(self):
-        return {
-            "rank": self.rank,
-            "tp_rank": self.tp_rank,
-            "tp_size": self.tp_size,
-            "dp_size": self.dp_size,
-            "dp_rank": self.dp_rank,
-            "pp_size": self.booster.plugin.stage_manager.num_stages,
-            "pp_stage": self.booster.plugin.stage_manager.stage,
-            "is_last_stage": self.booster.plugin.stage_manager.is_last_stage(),
-            "world_size": self.world_size,
-        }
-
-    def get_model_state_dict_keys(self):
-        return self.orig_state_dict_key
-
     def setup(self):
         super().setup()
         if self.use_wandb and (
diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py
index a351dfc64..4e100ebd1 100644
--- a/applications/ColossalChat/coati/distributed/launch.py
+++ b/applications/ColossalChat/coati/distributed/launch.py
@@ -66,7 +66,7 @@ def launch_distributed(
     num_update_per_episode = num_samples // global_inference_batch_size
     num_recv_per_update = inference_batch_size // inference_microbatch_size
 
-    producer_procs = []
+    procs = []
     for i in range(num_producers):
         producer = SimpleProducer.options(num_gpus=num_proc_per_producer).remote(
             producer_idx=i,
@@ -78,20 +78,18 @@ def launch_distributed(
             dataloaders_config=dataloaders_config,
             model_config=inference_model_config,
             generate_config=generate_config,
-            consumer_plugin_config=plugin_config,
             tokenizer_config=tokenizer_config,
             microbatch_size=inference_microbatch_size,
             backend=inference_backend,
             num_generations=num_generations,
         )
-        producer_procs.append(producer)
+        procs.append(producer)
     generate_config_consumer = copy.deepcopy(generate_config)
     generate_config_consumer.update(
         dict(
             backend=inference_backend,
         )
     )
-    consumer_procs = []
     for i in range(num_consumer_procs):
         consumer = core_consumer.options(num_gpus=1).remote(
             num_producers=num_producers,
@@ -113,14 +111,6 @@ def launch_distributed(
             save_interval=save_interval,
             save_dir=save_dir,
         )
-        consumer_procs.append(consumer)
-    # setup the consumer procs first
-    ray.get([p.setup.remote() for p in consumer_procs])
-    # get the device mesh mapping from consumer procs
-    consumer_device_mesh_mapping = ray.get([p.get_device_mesh_mapping.remote() for p in consumer_procs])
-    model_state_dict_keys = ray.get(consumer_procs[0].get_model_state_dict_keys.remote())
-    # setup the producer procs
-    ray.get([p.setup.remote(consumer_device_mesh_mapping, model_state_dict_keys) for p in producer_procs])
-    # loop
-    procs = producer_procs + consumer_procs
+        procs.append(consumer)
+    ray.get([p.setup.remote() for p in procs])
     ray.get([p.loop.remote() for p in procs])
diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py
index e9aff6f7b..fcb1a184a 100644
--- a/applications/ColossalChat/coati/distributed/producer.py
+++ b/applications/ColossalChat/coati/distributed/producer.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, Optional
 
 import ray
 import ray.util.collective as cc
@@ -26,7 +26,6 @@ class BaseProducer:
         dataloaders_config: Dict[str, Any],
         model_config: Dict[str, Any],
         generate_config: Dict[str, Any],
-        consumer_plugin_config: Dict[str, Any] = None,
         tokenizer_config: Optional[Dict[str, Any]] = None,
         microbatch_size: int = 1,
         backend: str = "transformers",
@@ -44,7 +43,6 @@ class BaseProducer:
         self.model_config = model_config
         self.generate_config = generate_config
         self.tokenizer_config = tokenizer_config
-        self.consumer_plugin_config = consumer_plugin_config
 
         # init tokenizer
         if tokenizer_config is None:
@@ -80,18 +78,9 @@ class BaseProducer:
         else:
             raise ValueError(f"Unexpected backend {backend}")
 
-    def setup(self, consumer_device_mesh_mapping: Dict[str, Any] = None, model_state_dict_keys: List = None) -> None:
-        self.consumer_device_mesh_mapping = consumer_device_mesh_mapping
-        self.model_state_dict_keys = model_state_dict_keys
+    def setup(self) -> None:
         cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_data_{self.producer_idx}")
-        # cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name="sync_model_pp_stage_0")
-        for i in range(self.num_consumer_procs):
-            device_mesh_mapping = self.consumer_device_mesh_mapping[i]
-            device_mesh_mapping["rank"]
-            # TODO: support ep, sp
-            if device_mesh_mapping["dp_rank"] == 0 and device_mesh_mapping["tp_rank"] == 0:
-                group_name = f"sync_model_pp_stage_{device_mesh_mapping['pp_stage']}"
-                cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name=group_name)
+        cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name="sync_model")
 
     def rollout(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
         raise NotImplementedError
@@ -136,27 +125,14 @@ class BaseProducer:
                 ):
                     self.model.llm.sleep()  # revict KV_cache to avoid OOM
                     # don't sync model for last iteration
-                    torch.cuda.empty_cache()
-                    state_dict = {}
                     print(
                         f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}"
                     )
 
-                    for consumer_rank_id in range(self.num_consumer_procs):
-                        device_mesh_mapping = self.consumer_device_mesh_mapping[consumer_rank_id]
-                        device_mesh_mapping["rank"]
-                        # TODO: support ep, sp
-                        if device_mesh_mapping["dp_rank"] == 0 and device_mesh_mapping["tp_rank"] == 0:
-                            group_name = f"sync_model_pp_stage_{device_mesh_mapping['pp_stage']}"
-                            state_dict.update(
-                                ray_broadcast_tensor_dict(
-                                    None, src=self.num_producers, device=self.device, group_name=group_name
-                                )
-                            )
-                    # check model sync integrity
-                    assert len(state_dict) == len(
-                        self.model_state_dict_keys
-                    ), f"state dict keys has {len(state_dict)} unique keys not equal original model with {len(self.model_state_dict_keys)} keys. Missing keys: {set(self.model_state_dict_keys)-set(state_dict.keys())}. Please kindly inform the developer."
+                    torch.cuda.empty_cache()
+                    state_dict = ray_broadcast_tensor_dict(
+                        None, self.num_producers, device=self.device, group_name="sync_model"
+                    )
                     self.load_state_dict(state_dict)
                     del state_dict
                     torch.cuda.empty_cache()
@@ -190,7 +166,6 @@ class SimpleProducer(BaseProducer):
         dataloaders_config,
         model_config,
         generate_config,
-        consumer_plugin_config=None,
        tokenizer_config=None,
         microbatch_size=1,
         backend="transformers",
@@ -206,7 +181,6 @@ class SimpleProducer(BaseProducer):
             dataloaders_config,
             model_config,
             generate_config,
-            consumer_plugin_config,
             tokenizer_config,
             microbatch_size,
             backend,