diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py index 40fa67e1a..9c3df97fd 100644 --- a/applications/ColossalChat/coati/distributed/consumer.py +++ b/applications/ColossalChat/coati/distributed/consumer.py @@ -59,13 +59,8 @@ class BaseConsumer: self.lr_scheduler = None def setup(self) -> None: - for i in range(self.num_producers): - cc.init_collective_group(self.world_size + 1, self.rank + 1, group_name=f"sync_data_{i}") - if self.rank == 0: - cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model") launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0) - - plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2) + plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2) # default config if ( self.plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config @@ -79,10 +74,16 @@ class BaseConsumer: self.tp_rank = dist.get_rank(self.plugin.tp_group) self.dp_size = dist.get_world_size(self.plugin.dp_group) + self.tp_size = dist.get_world_size(self.plugin.tp_group) self.buffer = [] self.recv_cnt = 0 + for i in range(self.num_producers): + cc.init_collective_group(self.world_size + 1, self.rank + 1, group_name=f"sync_data_{i}") + if self.dp_rank == 0 and self.tp_rank == 0: + group_name = f"sync_model_pp_stage_{self.plugin.stage_manager.stage}" + cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name=group_name) def state_dict(self) -> Dict[str, torch.Tensor]: raise NotImplementedError @@ -140,12 +141,13 @@ class BaseConsumer: print(f"Saved model checkpoint at step {step + 1} in folder {save_path}") if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1: - print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}") torch.cuda.empty_cache() state_dict = self.state_dict() - if self.rank == 0: + if self.dp_rank == 0 and self.tp_rank == 0: + group_name = f"sync_model_pp_stage_{self.plugin.stage_manager.stage}" + print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}") ray_broadcast_tensor_dict( - state_dict, src=self.num_producers, device=self.device, group_name="sync_model" + state_dict, src=self.num_producers, device=self.device, group_name=group_name ) del state_dict torch.cuda.empty_cache() @@ -191,6 +193,9 @@ class SimpleConsumer(BaseConsumer): self.optimizer = HybridAdam(self.model.parameters(), lr=1e-3) self.accum_loss = torch.zeros(1, device=self.device) + def get_plugin(self): + return self.plugin + def setup(self): super().setup() self.model, self.optimizer, *_ = self.booster.boost(self.model, self.optimizer) diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index 877ff98ec..eb8e24c46 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -75,6 +75,7 @@ class GRPOConsumer(BaseConsumer): ) path = model_config.pop("path") self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config) + self.orig_state_dict_key = [k for k in self.policy_model.state_dict()] self.policy_model.train() self.policy_model.gradient_checkpointing_enable() self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6)) @@ -150,6 +151,22 @@ class GRPOConsumer(BaseConsumer): eta_min=0.1 * grpo_config.get("lr", 1e-6), ) + def get_device_mesh_mapping(self): + return { + "rank": self.rank, + "tp_rank": self.tp_rank, + "tp_size": self.tp_size, + "dp_size": self.dp_size, + "dp_rank": self.dp_rank, + "pp_size": self.booster.plugin.stage_manager.num_stages, + "pp_stage": self.booster.plugin.stage_manager.stage, + "is_last_stage": self.booster.plugin.stage_manager.is_last_stage(), + "world_size": self.world_size, + } + + def get_model_state_dict_keys(self): + return self.orig_state_dict_key + def setup(self): super().setup() if self.use_wandb and ( diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py index 4e100ebd1..a351dfc64 100644 --- a/applications/ColossalChat/coati/distributed/launch.py +++ b/applications/ColossalChat/coati/distributed/launch.py @@ -66,7 +66,7 @@ def launch_distributed( num_update_per_episode = num_samples // global_inference_batch_size num_recv_per_update = inference_batch_size // inference_microbatch_size - procs = [] + producer_procs = [] for i in range(num_producers): producer = SimpleProducer.options(num_gpus=num_proc_per_producer).remote( producer_idx=i, @@ -78,18 +78,20 @@ def launch_distributed( dataloaders_config=dataloaders_config, model_config=inference_model_config, generate_config=generate_config, + consumer_plugin_config=plugin_config, tokenizer_config=tokenizer_config, microbatch_size=inference_microbatch_size, backend=inference_backend, num_generations=num_generations, ) - procs.append(producer) + producer_procs.append(producer) generate_config_consumer = copy.deepcopy(generate_config) generate_config_consumer.update( dict( backend=inference_backend, ) ) + consumer_procs = [] for i in range(num_consumer_procs): consumer = core_consumer.options(num_gpus=1).remote( num_producers=num_producers, @@ -111,6 +113,14 @@ def launch_distributed( save_interval=save_interval, save_dir=save_dir, ) - procs.append(consumer) - ray.get([p.setup.remote() for p in procs]) + consumer_procs.append(consumer) + # setup the consumer procs first + ray.get([p.setup.remote() for p in consumer_procs]) + # get the device mesh mapping from consumer procs + consumer_device_mesh_mapping = ray.get([p.get_device_mesh_mapping.remote() for p in consumer_procs]) + model_state_dict_keys = ray.get(consumer_procs[0].get_model_state_dict_keys.remote()) + # setup the producer procs + ray.get([p.setup.remote(consumer_device_mesh_mapping, model_state_dict_keys) for p in producer_procs]) + # loop + procs = producer_procs + consumer_procs ray.get([p.loop.remote() for p in procs]) diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py index fcb1a184a..e9aff6f7b 100644 --- a/applications/ColossalChat/coati/distributed/producer.py +++ b/applications/ColossalChat/coati/distributed/producer.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional import ray import ray.util.collective as cc @@ -26,6 +26,7 @@ class BaseProducer: dataloaders_config: Dict[str, Any], model_config: Dict[str, Any], generate_config: Dict[str, Any], + consumer_plugin_config: Dict[str, Any] = None, tokenizer_config: Optional[Dict[str, Any]] = None, microbatch_size: int = 1, backend: str = "transformers", @@ -43,6 +44,7 @@ class BaseProducer: self.model_config = model_config self.generate_config = generate_config self.tokenizer_config = tokenizer_config + self.consumer_plugin_config = consumer_plugin_config # init tokenizer if tokenizer_config is None: @@ -78,9 +80,18 @@ class BaseProducer: else: raise ValueError(f"Unexpected backend {backend}") - def setup(self) -> None: + def setup(self, consumer_device_mesh_mapping: Dict[str, Any] = None, model_state_dict_keys: List = None) -> None: + self.consumer_device_mesh_mapping = consumer_device_mesh_mapping + self.model_state_dict_keys = model_state_dict_keys cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_data_{self.producer_idx}") - cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name="sync_model") + # cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name="sync_model_pp_stage_0") + for i in range(self.num_consumer_procs): + device_mesh_mapping = self.consumer_device_mesh_mapping[i] + device_mesh_mapping["rank"] + # TODO: support ep, sp + if device_mesh_mapping["dp_rank"] == 0 and device_mesh_mapping["tp_rank"] == 0: + group_name = f"sync_model_pp_stage_{device_mesh_mapping['pp_stage']}" + cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name=group_name) def rollout(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]: raise NotImplementedError @@ -125,14 +136,27 @@ class BaseProducer: ): self.model.llm.sleep() # revict KV_cache to avoid OOM # don't sync model for last iteration + torch.cuda.empty_cache() + state_dict = {} print( f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}" ) - torch.cuda.empty_cache() + for consumer_rank_id in range(self.num_consumer_procs): + device_mesh_mapping = self.consumer_device_mesh_mapping[consumer_rank_id] + device_mesh_mapping["rank"] + # TODO: support ep, sp + if device_mesh_mapping["dp_rank"] == 0 and device_mesh_mapping["tp_rank"] == 0: + group_name = f"sync_model_pp_stage_{device_mesh_mapping['pp_stage']}" + state_dict.update( + ray_broadcast_tensor_dict( + None, src=self.num_producers, device=self.device, group_name=group_name + ) + ) + # check model sync integrity + assert len(state_dict) == len( + self.model_state_dict_keys + ), f"state dict keys has {len(state_dict)} unique keys not equal original model with {len(self.model_state_dict_keys)} keys. Missing keys: {set(self.model_state_dict_keys)-set(state_dict.keys())}. Please kindly inform the developer." - state_dict = ray_broadcast_tensor_dict( - None, self.num_producers, device=self.device, group_name="sync_model" - ) self.load_state_dict(state_dict) del state_dict torch.cuda.empty_cache() @@ -166,6 +190,7 @@ class SimpleProducer(BaseProducer): dataloaders_config, model_config, generate_config, + consumer_plugin_config=None, tokenizer_config=None, microbatch_size=1, backend="transformers", @@ -181,6 +206,7 @@ class SimpleProducer(BaseProducer): dataloaders_config, model_config, generate_config, + consumer_plugin_config, tokenizer_config, microbatch_size, backend,