diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py index cd6c3cfa4..1cebcb40e 100644 --- a/applications/ColossalChat/coati/distributed/consumer.py +++ b/applications/ColossalChat/coati/distributed/consumer.py @@ -81,18 +81,16 @@ class BaseConsumer: # Init Hybrid ray process group for i in range(self.num_producers): - cc.init_collective_group(self.world_size + 1, self.rank + 1, backend="hccl", group_name=f"sync_data_{i}") + cc.init_collective_group(self.world_size + 1, self.rank + 1, group_name=f"sync_data_{i}") if self.pp_size > 1: # use hybrid tp + pp if self.tp_rank == 0 and self.dp_rank == 0: cc.init_collective_group( - self.num_producers + 1, self.num_producers, backend="hccl", group_name=f"sync_model_{self.pp_rank}" + self.num_producers + 1, self.num_producers, group_name=f"sync_model_{self.pp_rank}" ) else: if self.rank == 0: - cc.init_collective_group( - self.num_producers + 1, self.num_producers, backend="hccl", group_name="sync_model" - ) + cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model") self.buffer = [] @@ -154,7 +152,12 @@ class BaseConsumer: print(f"Saved model checkpoint at step {step + 1} in folder {save_path}") if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1: - print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}") + if self.pp_size > 1: + print( + f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}" + ) + else: + print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}") torch.cuda.empty_cache() state_dict = self.state_dict() if self.pp_size > 1: diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py index 948f3bc50..945bb05dc 100644 --- a/applications/ColossalChat/coati/distributed/producer.py +++ b/applications/ColossalChat/coati/distributed/producer.py @@ -82,16 +82,12 @@ class BaseProducer: self.consumer_pp_size = consumer_plugin_config["pp_size"] # consumer pp size def setup(self) -> None: - cc.init_collective_group( - 1 + self.num_consumer_procs, 0, backend="hccl", group_name=f"sync_data_{self.producer_idx}" - ) + cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_data_{self.producer_idx}") if self.consumer_pp_size > 1: for i in range(self.consumer_pp_size): - cc.init_collective_group( - self.num_producers + 1, self.producer_idx, backend="hccl", group_name=f"sync_model_{i}" - ) + cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name=f"sync_model_{i}") else: - cc.init_collective_group(self.num_producers + 1, self.producer_idx, backend="hccl", group_name="sync_model") + cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name="sync_model") def rollout(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]: raise NotImplementedError @@ -142,8 +138,8 @@ class BaseProducer: torch.cuda.empty_cache() if self.consumer_pp_size > 1: - # TODO: loop load for i in range(self.consumer_pp_size): + print(f"[P{self.producer_idx}] Sync model PP stage {i}") state_dict = ray_broadcast_tensor_dict( None, self.num_producers, device=self.device, group_name=f"sync_model_{i}" ) diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index a3ed00f88..788e60c2e 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -58,7 +58,7 @@ if __name__ == "__main__": "--master_address", type=str, default=None, help="Master address for multi-node distributed training, Optional" ) parser.add_argument( - "--master_port", type=int, default=29506, help="Master port for multi-node distributed training, Optional" + "--master_port", type=int, default=29505, help="Master port for multi-node distributed training, Optional" ) # Sampling parameters @@ -129,7 +129,7 @@ if __name__ == "__main__": args.top_k = -1 inference_model_config = dict(path=args.model) - train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False, attn_implementation="eager") + train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False) generate_config = dict(top_k=args.top_k, top_p=args.top_p, temperature=args.temperature) if args.backend == "transformers": @@ -155,7 +155,7 @@ if __name__ == "__main__": enforce_eager=True, enable_chunked_prefill=True, max_model_len=args.max_new_tokens + args.max_prompt_tokens, - tensor_parallel_size=2, + tensor_parallel_size=1, ) ) generate_config.update( @@ -223,7 +223,7 @@ if __name__ == "__main__": "zero_stage": 2, }, # for zero # plugin_config={ - # "tp_size": 2, + # "tp_size": 1, # "pp_size": 2, # "microbatch_size": max( # 1, args.train_microbatch_size // 2