diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index 40fa67e1a..1cebcb40e 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -59,10 +59,6 @@ class BaseConsumer:
         self.lr_scheduler = None

     def setup(self) -> None:
-        for i in range(self.num_producers):
-            cc.init_collective_group(self.world_size + 1, self.rank + 1, group_name=f"sync_data_{i}")
-        if self.rank == 0:
-            cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")
         launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)

         plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2)
@@ -77,8 +73,24 @@ class BaseConsumer:
         self.booster = Booster(plugin=self.plugin)
         self.dp_rank = dist.get_rank(self.plugin.dp_group)
         self.tp_rank = dist.get_rank(self.plugin.tp_group)
+        self.pp_rank = dist.get_rank(self.plugin.pp_group)

         self.dp_size = dist.get_world_size(self.plugin.dp_group)
+        self.tp_size = dist.get_world_size(self.plugin.tp_group)
+        self.pp_size = dist.get_world_size(self.plugin.pp_group)
+
+        # Init Hybrid ray process group
+        for i in range(self.num_producers):
+            cc.init_collective_group(self.world_size + 1, self.rank + 1, group_name=f"sync_data_{i}")
+        if self.pp_size > 1:
+            # use hybrid tp + pp
+            if self.tp_rank == 0 and self.dp_rank == 0:
+                cc.init_collective_group(
+                    self.num_producers + 1, self.num_producers, group_name=f"sync_model_{self.pp_rank}"
+                )
+        else:
+            if self.rank == 0:
+                cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")

         self.buffer = []
@@ -140,13 +152,27 @@
                             print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")

                 if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
-                    print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
+                    if self.pp_size > 1:
+                        print(
+                            f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}"
+                        )
+                    else:
+                        print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
                     torch.cuda.empty_cache()
                     state_dict = self.state_dict()
-                    if self.rank == 0:
-                        ray_broadcast_tensor_dict(
-                            state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
-                        )
+                    if self.pp_size > 1:
+                        if self.tp_rank == 0 and self.dp_rank == 0:
+                            ray_broadcast_tensor_dict(
+                                state_dict,
+                                src=self.num_producers,
+                                device=self.device,
+                                group_name=f"sync_model_{self.pp_rank}",
+                            )
+                    else:
+                        if self.rank == 0:
+                            ray_broadcast_tensor_dict(
+                                state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
+                            )
                     del state_dict
                     torch.cuda.empty_cache()
diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py
index 4e100ebd1..a346d1d4f 100644
--- a/applications/ColossalChat/coati/distributed/launch.py
+++ b/applications/ColossalChat/coati/distributed/launch.py
@@ -57,7 +57,7 @@ def launch_distributed(
     else:
         core_consumer = ALGO_MAP.get(core_algo, SimpleConsumer)

-    train_dp_size = get_dp_size_fast(num_producers, plugin_config)
+    train_dp_size = get_dp_size_fast(num_consumer_procs, plugin_config)
     assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0

     dataset_path = dataset_config["path"]
@@ -82,6 +82,7 @@
             microbatch_size=inference_microbatch_size,
             backend=inference_backend,
             num_generations=num_generations,
+            consumer_plugin_config=plugin_config,
         )
         procs.append(producer)
     generate_config_consumer = copy.deepcopy(generate_config)
diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py
index fcb1a184a..a2d675870 100644
--- a/applications/ColossalChat/coati/distributed/producer.py
+++ b/applications/ColossalChat/coati/distributed/producer.py
@@ -29,6 +29,7 @@ class BaseProducer:
         tokenizer_config: Optional[Dict[str, Any]] = None,
         microbatch_size: int = 1,
         backend: str = "transformers",
+        consumer_plugin_config: Dict[str, Any] = None,
     ):
         self.producer_idx = producer_idx
         self.num_producers = num_producers
@@ -78,9 +79,15 @@ class BaseProducer:
         else:
             raise ValueError(f"Unexpected backend {backend}")

+        self.consumer_pp_size = consumer_plugin_config["pp_size"]  # consumer pp size
+
     def setup(self) -> None:
         cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_data_{self.producer_idx}")
-        cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name="sync_model")
+        if self.consumer_pp_size > 1:
+            for i in range(self.consumer_pp_size):
+                cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name=f"sync_model_{i}")
+        else:
+            cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name="sync_model")

     def rollout(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
         raise NotImplementedError
@@ -125,15 +132,25 @@
                 ):
                     self.model.llm.sleep()  # revict KV_cache to avoid OOM
                 # don't sync model for last iteration
-                print(
-                    f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}"
-                )
                 torch.cuda.empty_cache()
-                state_dict = ray_broadcast_tensor_dict(
-                    None, self.num_producers, device=self.device, group_name="sync_model"
-                )
-                self.load_state_dict(state_dict)
+                if self.consumer_pp_size > 1:
+                    for pp_idx in range(self.consumer_pp_size):
+                        print(
+                            f"[P{self.producer_idx}] Sync model PP stage {pp_idx} episode {episode} step {(i + 1) // self.num_microbatches - 1}"
+                        )
+                        state_dict = ray_broadcast_tensor_dict(
+                            None, self.num_producers, device=self.device, group_name=f"sync_model_{pp_idx}"
+                        )
+                        self.load_state_dict(state_dict)
+                else:
+                    print(
+                        f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}"
+                    )
+                    state_dict = ray_broadcast_tensor_dict(
+                        None, self.num_producers, device=self.device, group_name="sync_model"
+                    )
+                    self.load_state_dict(state_dict)
                 del state_dict
                 torch.cuda.empty_cache()
                 if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
@@ -170,6 +187,7 @@
         microbatch_size=1,
         backend="transformers",
         num_generations: int = 8,
+        consumer_plugin_config=None,
     ):
         super().__init__(
             producer_idx,
@@ -184,6 +202,7 @@ class SimpleProducer(BaseProducer):
             tokenizer_config,
             microbatch_size,
             backend,
+            consumer_plugin_config,
         )
         self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index 39584750c..788e60c2e 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -58,7 +58,7 @@ if __name__ == "__main__":
         "--master_address", type=str, default=None, help="Master address for multi-node distributed training, Optional"
     )
     parser.add_argument(
-        "--master_port", type=int, default=29506, help="Master port for multi-node distributed training, Optional"
+        "--master_port", type=int, default=29505, help="Master port for multi-node distributed training, Optional"
     )

     # Sampling parameters
@@ -223,7 +223,7 @@ if __name__ == "__main__":
             "zero_stage": 2,
         },  # for zero
         # plugin_config={
-        #     "tp_size": 2,
+        #     "tp_size": 1,
         #     "pp_size": 2,
         #     "microbatch_size": max(
         #         1, args.train_microbatch_size // 2
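
Reviewer note (not part of the patch): below is a minimal, runnable sketch of the
weight-sync group topology the patch sets up. The ConsumerRank dataclass and the two
helper functions are hypothetical names used only for illustration; the group-name
strings and the rank conditions mirror the branches added in consumer.py and
producer.py above.

from dataclasses import dataclass
from typing import List


@dataclass
class ConsumerRank:  # hypothetical helper, not in the patch
    rank: int     # global consumer rank
    dp_rank: int  # data-parallel coordinate
    tp_rank: int  # tensor-parallel coordinate
    pp_rank: int  # pipeline-stage coordinate


def consumer_model_sync_groups(c: ConsumerRank, pp_size: int) -> List[str]:
    # Mirrors BaseConsumer.setup(): with pp_size > 1 each pipeline stage holds a
    # distinct weight shard, so one representative per stage (tp_rank == 0 and
    # dp_rank == 0) joins a per-stage group "sync_model_{pp_rank}"; with
    # pp_size == 1 only global rank 0 joins the single "sync_model" group.
    if pp_size > 1:
        return [f"sync_model_{c.pp_rank}"] if c.tp_rank == 0 and c.dp_rank == 0 else []
    return ["sync_model"] if c.rank == 0 else []


def producer_model_sync_groups(consumer_pp_size: int) -> List[str]:
    # Mirrors BaseProducer.setup(): every producer joins all per-stage groups so it
    # can receive every shard, one broadcast per pipeline stage.
    if consumer_pp_size > 1:
        return [f"sync_model_{i}" for i in range(consumer_pp_size)]
    return ["sync_model"]


if __name__ == "__main__":
    # Example: 4 consumer procs laid out as pp_size=2 x tp_size=2 (dp_size=1).
    pp_size, tp_size = 2, 2
    for r in range(4):
        c = ConsumerRank(rank=r, dp_rank=0, tp_rank=r % tp_size, pp_rank=r // tp_size)
        print(f"consumer {r}: {consumer_model_sync_groups(c, pp_size)}")
    print(f"producer:   {producer_model_sync_groups(pp_size)}")
    # consumer 0 -> ['sync_model_0'], consumer 2 -> ['sync_model_1'];
    # consumers 1 and 3 join no model-sync group; the producer joins both.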
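
A second note, on the launch.py one-liner: train_dp_size must be derived from the
consumer process count, not the producer count, since data parallelism is a property
of the consumer mesh. The helper below is a hypothetical reconstruction of
get_dp_size_fast (the real one lives in the repo) that shows the arithmetic behind
the batch-divisibility assert:

def get_dp_size_fast(num_procs: int, plugin_config: dict) -> int:
    # dp_size is whatever remains of the process mesh after the tensor- and
    # pipeline-parallel dimensions are carved out.
    tp_size = plugin_config.get("tp_size", 1)
    pp_size = plugin_config.get("pp_size", 1)
    return num_procs // (tp_size * pp_size)


# With the commented example config from rl_example.py (tp_size=1, pp_size=2) and
# 4 consumer procs: dp_size = 4 // 2 = 2. Passing num_producers here instead would
# tie (inference_batch_size * num_producers) % (train_batch_size * dp_size) == 0
# to the wrong process count.
assert get_dp_size_fast(4, {"tp_size": 1, "pp_size": 2}) == 2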