diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index ba7d882c9..21da67161 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -1,4 +1,3 @@
-import os
 from contextlib import nullcontext
 from typing import Any, Dict, Optional
 
@@ -7,11 +6,13 @@ import ray.util.collective as cc
 import torch
 import torch.distributed as dist
 from coati.distributed.profiling_utils import CustomProfiler
+from coati.utils import save_checkpoint
 from tqdm import tqdm
 from transformers import AutoModelForCausalLM
 
 from colossalai.booster import Booster
 from colossalai.booster.plugin import HybridParallelPlugin
+from colossalai.cluster import DistCoordinator
 from colossalai.initialize import launch
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.utils import get_current_device
@@ -55,6 +56,7 @@ class BaseConsumer:
         self.enable_profiling = enable_profiling
         assert batch_size % minibatch_size == 0, "batch_size should be divisible by microbatch_size"
         self.num_microbatches = batch_size // minibatch_size
+        self.checkpoint_path = model_config.pop("checkpoint_path", None)
         self.model_config = model_config
         self.plugin_config = plugin_config
 
@@ -62,9 +64,11 @@ class BaseConsumer:
         self.device = get_current_device()
         self.lr_scheduler = None
         self.n_behind = n_behind
+        self.total_prompt_trained = 0  # for setting start index when resuming training
 
     def setup(self) -> None:
         launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)
+        self.coordinator = DistCoordinator()
 
         plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2)
         if (
@@ -143,6 +147,26 @@ class BaseConsumer:
         return effective_group_to_raw_group_mapping
 
     def loop(self) -> None:
+        self.profiler.enter("sync_model")
+        torch.cuda.empty_cache()
+        state_dict = self.state_dict()
+        if self.pp_size > 1:
+            if self.tp_rank == 0 and self.dp_rank == 0:
+                ray_broadcast_tensor_dict(
+                    state_dict,
+                    src=self.num_producers,
+                    device=self.device,
+                    group_name=f"sync_model_{self.pp_rank}",
+                )
+        else:
+            if self.rank == 0:
+                ray_broadcast_tensor_dict(
+                    state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
+                )
+        del state_dict
+        torch.cuda.empty_cache()
+        self.profiler.exit("sync_model")
+
         print(
             f"Consumer{self.rank} num_update: {self.num_update_per_episode}, num_recv: {self.num_recv_per_update}, nmb: {self.num_microbatches}"
         )
@@ -208,6 +232,7 @@ class BaseConsumer:
                             for k, v in raw_batch.items()
                         }
                         # [batch_size, num_generations] -> [batch_size]
+                        self.total_prompt_trained += raw_batch["reward"].size(0)
                         reward = raw_batch["reward"][:, :, 0]
                         format_acc = raw_batch["format_acc"][:, :, 0]
                         ans_acc = raw_batch["ans_acc"][:, :, 0]
@@ -285,10 +310,19 @@ class BaseConsumer:
                 if (step + 1) % self.save_interval == 0 or (step + 1) == self.num_update_per_episode:
                     if self.rank == 0:
                         print(f"Start saving policy model at step {step + 1}.")
-                    save_path = os.path.join(self.save_dir, f"modeling-episode-{episode}-step-{step + 1}")
-                    self.booster.save_model(self.policy_model, save_path, shard=True)
+                    save_checkpoint(
+                        save_dir=self.save_dir,
+                        booster=self.booster,
+                        model=self.policy_model,
+                        optimizer=self.optimizer,
+                        lr_scheduler=self.lr_scheduler,
+                        epoch=episode,
+                        step=step,
+                        batch_size=int(self.total_prompt_trained / step),  # for setting start index when resuming training
+                        coordinator=self.coordinator,
+                    )
                     if self.rank == 0:
-                        print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")
+                        print(f"Saved model checkpoint at step {step + 1} in folder {self.save_dir}")
 
                 if (episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1) and (
                     episode != 0 or step >= self.n_behind
diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index 424d46098..ee72e0290 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -8,6 +8,7 @@ from coati.distributed.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
 from coati.distributed.utils import entropy_from_logits, memory_efficient_logprob
 from coati.trainer.utils import all_reduce_mean, all_reduce_sum
+from coati.utils import load_checkpoint
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
@@ -157,6 +158,14 @@ class GRPOConsumer(BaseConsumer):
         )
         if self.policy_loss_fn.beta > 0:
             self.reference_model, *_ = self.booster.boost(self.reference_model)
+        if self.checkpoint_path is not None:
+            load_checkpoint(
+                self.checkpoint_path,
+                self.booster,
+                self.policy_model,
+                self.optimizer,
+                self.lr_scheduler,
+            )
         self.plugin.logger.set_level("ERROR")
 
     def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py
index 11fb5d3aa..fbec2319b 100644
--- a/applications/ColossalChat/coati/distributed/producer.py
+++ b/applications/ColossalChat/coati/distributed/producer.py
@@ -8,10 +8,12 @@ import ray.util.collective as cc
 import torch
 import tqdm
 import wandb
+from coati.dataset import StatefulDistributedSampler
 from coati.dataset.loader import RawConversationDataset, collate_fn_grpo
 from coati.distributed.profiling_utils import CustomProfiler
 from coati.distributed.reward.reward_fn import boxed_math_reward_fn, code_reward_fn, math_reward_fn
 from coati.distributed.reward.verifiable_reward import VerifiableReward
+from coati.utils import load_checkpoint
 from ray.util.collective import allreduce
 from ray.util.collective.types import Backend, ReduceOp
 from torch.utils.data import DataLoader, DistributedSampler
@@ -68,6 +70,7 @@ class BaseProducer:
         self.profiler = CustomProfiler(f"P{self.producer_idx}", disabled=not enable_profiling)
 
         self.train_dataset_config = train_dataset_config
+        self.checkpoint_path = model_config.pop("checkpoint_path", None)
         self.model_config = model_config
         self.generate_config = generate_config
         self.tokenizer_config = tokenizer_config
@@ -121,7 +124,7 @@
         self.train_dataloader = DataLoader(
             self.train_dataset,
             batch_size=microbatch_size,
-            sampler=DistributedSampler(
+            sampler=StatefulDistributedSampler(
                 self.train_dataset,
                 num_replicas=num_producers,
                 rank=producer_idx,
@@ -133,6 +136,13 @@
             drop_last=True,
             collate_fn=collate_fn_grpo,
         )
+        if self.checkpoint_path is not None:
+            # resume training from checkpoint
+            start_epoch, start_step, sampler_start_idx = load_checkpoint(self.checkpoint_path, None, None, None, None)
+            self.train_dataloader.sampler.set_start_index(sampler_start_idx)
+            print(
+                f"[P{self.producer_idx}] Resume training from checkpoint {self.checkpoint_path}, start epoch {start_epoch}, start step {start_step}, sampler start index {sampler_start_idx}"
+            )
         if grpo_config["reward_fn_type"] == "think_answer_tags":
             self.evaluation_function = math_reward_fn
         elif grpo_config["reward_fn_type"] == "boxed":
@@ -203,6 +213,29 @@
         raise NotImplementedError
 
     def loop(self) -> None:
+
+        torch.cuda.empty_cache()
+        self.profiler.enter("sync_model")
+        if self.consumer_pp_size > 1:
+            for pp_idx in range(self.consumer_pp_size):
+                state_dict = ray_broadcast_tensor_dict(
+                    None, self.num_producers, device=self.device, group_name=f"sync_model_{pp_idx}"
+                )
+                if "consumer_global_step" in state_dict:
+                    self.consumer_global_step = state_dict.pop("consumer_global_step").item()
+                self.load_state_dict(state_dict)
+        else:
+            state_dict = ray_broadcast_tensor_dict(
+                None, self.num_producers, device=self.device, group_name="sync_model"
+            )
+            if "consumer_global_step" in state_dict:
+                self.consumer_global_step = state_dict.pop("consumer_global_step").item()
+            self.load_state_dict(state_dict)
+        self.profiler.exit("sync_model")
+        print(f"[P{self.producer_idx}] Sync initial model done.")
+        del state_dict
+        torch.cuda.empty_cache()
+
         num_update_per_episode = len(self.train_dataloader) // self.num_microbatches
         num_valid_microbatches = num_update_per_episode * self.num_microbatches
 
diff --git a/applications/ColossalChat/coati/utils/ckpt_io.py b/applications/ColossalChat/coati/utils/ckpt_io.py
index 5b804f0ac..9164c02d9 100755
--- a/applications/ColossalChat/coati/utils/ckpt_io.py
+++ b/applications/ColossalChat/coati/utils/ckpt_io.py
@@ -81,9 +81,12 @@ def load_checkpoint(
     """
 
     # Update booster params states.
-    booster.load_model(model=model, checkpoint=os.path.join(load_dir, "modeling"))
-    booster.load_optimizer(optimizer=optimizer, checkpoint=os.path.join(load_dir, "optimizer"))
-    booster.load_lr_scheduler(lr_scheduler=lr_scheduler, checkpoint=os.path.join(load_dir, "lr_scheduler"))
+    if model is not None:
+        booster.load_model(model=model, checkpoint=os.path.join(load_dir, "modeling"))
+    if optimizer is not None:
+        booster.load_optimizer(optimizer=optimizer, checkpoint=os.path.join(load_dir, "optimizer"))
+    if lr_scheduler is not None:
+        booster.load_lr_scheduler(lr_scheduler=lr_scheduler, checkpoint=os.path.join(load_dir, "lr_scheduler"))
 
     running_states = load_json(file_path=os.path.join(load_dir, "running_states.json"))
     return (
diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index 58148b67e..42ec582f6 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -18,6 +18,13 @@ os.environ["no_proxy"] = "127.0.0.1,localhost"
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
+    parser.add_argument(
+        "-cp",
+        "--checkpoint-path",
+        type=str,
+        default=None,
+        help="Path to the checkpoint to load the model from. If not provided, the model will be loaded from the model path.",
+    )
     parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
     parser.add_argument(
         "-ed",
@@ -226,8 +233,10 @@ if __name__ == "__main__":
 
     os.environ["TOKENIZERS_PARALLELISM"] = "false"  # Disable tokenizers parallelism to avoid deadlock
 
-    inference_model_config = dict(path=args.model)
-    train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False)
+    inference_model_config = dict(path=args.model, checkpoint_path=args.checkpoint_path)
+    train_model_config = dict(
+        path=args.model, use_flash_attention_2=True, use_cache=False, checkpoint_path=args.checkpoint_path
+    )
     generate_config = dict(top_k=args.top_k, top_p=args.top_p, temperature=args.temperature)
 
     if args.backend == "transformers":
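Usage sketch (illustrative; the flag names come from rl_example.py above, and the checkpoint folder name is a placeholder): a run can be resumed by pointing the new -cp/--checkpoint-path flag at a checkpoint directory previously written by save_checkpoint, which per ckpt_io.py holds the modeling, optimizer, and lr_scheduler states plus running_states.json:

    # Resume GRPO training from a previously saved checkpoint directory;
    # all other flags should match the original run.
    python rl_example.py -m Qwen/Qwen2.5-7B -d data.jsonl -cp /path/to/save_dir/<checkpoint_folder>

On resume, the producer passes None for booster/model/optimizer/lr_scheduler and reads only running_states.json to restore the sampler start index, while the consumer reloads the full model, optimizer, and lr_scheduler states through the booster.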