diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index eb8e24c46..38d233e7d 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -151,19 +151,6 @@ class GRPOConsumer(BaseConsumer): eta_min=0.1 * grpo_config.get("lr", 1e-6), ) - def get_device_mesh_mapping(self): - return { - "rank": self.rank, - "tp_rank": self.tp_rank, - "tp_size": self.tp_size, - "dp_size": self.dp_size, - "dp_rank": self.dp_rank, - "pp_size": self.booster.plugin.stage_manager.num_stages, - "pp_stage": self.booster.plugin.stage_manager.stage, - "is_last_stage": self.booster.plugin.stage_manager.is_last_stage(), - "world_size": self.world_size, - } - def get_model_state_dict_keys(self): return self.orig_state_dict_key diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py index a351dfc64..053b88657 100644 --- a/applications/ColossalChat/coati/distributed/launch.py +++ b/applications/ColossalChat/coati/distributed/launch.py @@ -116,11 +116,10 @@ def launch_distributed( consumer_procs.append(consumer) # setup the consumer procs first ray.get([p.setup.remote() for p in consumer_procs]) - # get the device mesh mapping from consumer procs - consumer_device_mesh_mapping = ray.get([p.get_device_mesh_mapping.remote() for p in consumer_procs]) + # get state dict key for checking syncing integrity model_state_dict_keys = ray.get(consumer_procs[0].get_model_state_dict_keys.remote()) # setup the producer procs - ray.get([p.setup.remote(consumer_device_mesh_mapping, model_state_dict_keys) for p in producer_procs]) + ray.get([p.setup.remote(model_state_dict_keys) for p in producer_procs]) # loop procs = producer_procs + consumer_procs ray.get([p.loop.remote() for p in procs]) diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py index e9aff6f7b..6755b247d 100644 --- a/applications/ColossalChat/coati/distributed/producer.py +++ b/applications/ColossalChat/coati/distributed/producer.py @@ -80,18 +80,12 @@ class BaseProducer: else: raise ValueError(f"Unexpected backend {backend}") - def setup(self, consumer_device_mesh_mapping: Dict[str, Any] = None, model_state_dict_keys: List = None) -> None: - self.consumer_device_mesh_mapping = consumer_device_mesh_mapping + def setup(self, model_state_dict_keys: List = None) -> None: self.model_state_dict_keys = model_state_dict_keys cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_data_{self.producer_idx}") - # cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name="sync_model_pp_stage_0") - for i in range(self.num_consumer_procs): - device_mesh_mapping = self.consumer_device_mesh_mapping[i] - device_mesh_mapping["rank"] - # TODO: support ep, sp - if device_mesh_mapping["dp_rank"] == 0 and device_mesh_mapping["tp_rank"] == 0: - group_name = f"sync_model_pp_stage_{device_mesh_mapping['pp_stage']}" - cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name=group_name) + for pp_stage in range(self.consumer_plugin_config.get("pp_size", 1)): + group_name = f"sync_model_pp_stage_{pp_stage}" + cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name=group_name) def rollout(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]: raise NotImplementedError @@ -141,17 +135,13 @@ class BaseProducer: print( f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}" ) - for consumer_rank_id in range(self.num_consumer_procs): - device_mesh_mapping = self.consumer_device_mesh_mapping[consumer_rank_id] - device_mesh_mapping["rank"] - # TODO: support ep, sp - if device_mesh_mapping["dp_rank"] == 0 and device_mesh_mapping["tp_rank"] == 0: - group_name = f"sync_model_pp_stage_{device_mesh_mapping['pp_stage']}" - state_dict.update( - ray_broadcast_tensor_dict( - None, src=self.num_producers, device=self.device, group_name=group_name - ) + for pp_stage in range(self.consumer_plugin_config.get("pp_size", 1)): + group_name = f"sync_model_pp_stage_{pp_stage}" + state_dict.update( + ray_broadcast_tensor_dict( + None, src=self.num_producers, device=self.device, group_name=group_name ) + ) # check model sync integrity assert len(state_dict) == len( self.model_state_dict_keys