Merge pull request #6378 from hpcaitech/grpo-latest-rebase-fix-resume

[feat] fix resume training
YeAnbang 2025-08-18 17:09:53 +08:00 committed by GitHub
commit 4ac2227488
5 changed files with 98 additions and 10 deletions

View File

@@ -1,4 +1,3 @@
import os
from contextlib import nullcontext
from typing import Any, Dict, Optional
@@ -7,11 +6,13 @@ import ray.util.collective as cc
import torch
import torch.distributed as dist
from coati.distributed.profiling_utils import CustomProfiler
from coati.utils import save_checkpoint
from tqdm import tqdm
from transformers import AutoModelForCausalLM
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.cluster import DistCoordinator
from colossalai.initialize import launch
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
@@ -55,6 +56,7 @@ class BaseConsumer:
self.enable_profiling = enable_profiling
assert batch_size % minibatch_size == 0, "batch_size should be divisible by minibatch_size"
self.num_microbatches = batch_size // minibatch_size
self.checkpoint_path = model_config.pop("checkpoint_path", None)
self.model_config = model_config
self.plugin_config = plugin_config
@@ -62,9 +64,11 @@
self.device = get_current_device()
self.lr_scheduler = None
self.n_behind = n_behind
self.total_prompt_trained = 0  # for setting the start index when resuming training
def setup(self) -> None:
launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)
self.coordinator = DistCoordinator()
plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2)
if (
@@ -143,6 +147,26 @@
return effective_group_to_raw_group_mapping
def loop(self) -> None:
self.profiler.enter("sync_model")
torch.cuda.empty_cache()
state_dict = self.state_dict()
if self.pp_size > 1:
if self.tp_rank == 0 and self.dp_rank == 0:
ray_broadcast_tensor_dict(
state_dict,
src=self.num_producers,
device=self.device,
group_name=f"sync_model_{self.pp_rank}",
)
else:
if self.rank == 0:
ray_broadcast_tensor_dict(
state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
)
del state_dict
torch.cuda.empty_cache()
self.profiler.exit("sync_model")
print(
f"Consumer{self.rank} num_update: {self.num_update_per_episode}, num_recv: {self.num_recv_per_update}, nmb: {self.num_microbatches}"
)
@@ -208,6 +232,7 @@
for k, v in raw_batch.items()
}
# [batch_size, num_generations] -> [batch_size]
self.total_prompt_trained += raw_batch["reward"].size(0)
reward = raw_batch["reward"][:, :, 0]
format_acc = raw_batch["format_acc"][:, :, 0]
ans_acc = raw_batch["ans_acc"][:, :, 0]
@@ -285,10 +310,19 @@
if (step + 1) % self.save_interval == 0 or (step + 1) == self.num_update_per_episode:
if self.rank == 0:
print(f"Start saving policy model at step {step + 1}.")
save_path = os.path.join(self.save_dir, f"modeling-episode-{episode}-step-{step + 1}")
self.booster.save_model(self.policy_model, save_path, shard=True)
save_checkpoint(
save_dir=self.save_dir,
booster=self.booster,
model=self.policy_model,
optimizer=self.optimizer,
lr_scheduler=self.lr_scheduler,
epoch=episode,
step=step,
batch_size=int(self.total_prompt_trained / step),
coordinator=self.coordinator,
) # for setting start index when resuming training
if self.rank == 0:
print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")
print(f"Saved model checkpoint at step {step + 1} in folder {self.save_dir}")
if (episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1) and (
episode != 0 or step >= self.n_behind
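The batch_size handed to save_checkpoint above is derived from the running prompt counter rather than the configured batch size, so that the checkpoint metadata can later be turned back into a sampler start index. A minimal sketch of that bookkeeping, under the assumption (not shown in this diff) that save_checkpoint records roughly step * batch_size as the resume point in running_states.json:

# Sketch only: the variable names and the start-index formula are illustrative assumptions.
total_prompt_trained = 6400          # prompts consumed so far (tracked in the loop above)
step = 50                            # last completed update step
prompts_per_step = int(total_prompt_trained / step)   # what the diff passes as batch_size
sample_start_index = step * prompts_per_step          # ~ total_prompt_trained, i.e. how many
                                                      # prompts the producer should skip on resume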

View File

@@ -8,6 +8,7 @@ from coati.distributed.consumer import BaseConsumer
from coati.distributed.loss import PolicyLoss
from coati.distributed.utils import entropy_from_logits, memory_efficient_logprob
from coati.trainer.utils import all_reduce_mean, all_reduce_sum
from coati.utils import load_checkpoint
from transformers import AutoModelForCausalLM, AutoTokenizer
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
@@ -157,6 +158,14 @@ class GRPOConsumer(BaseConsumer):
)
if self.policy_loss_fn.beta > 0:
self.reference_model, *_ = self.booster.boost(self.reference_model)
if self.checkpoint_path is not None:
load_checkpoint(
self.checkpoint_path,
self.booster,
self.policy_model,
self.optimizer,
self.lr_scheduler,
)
self.plugin.logger.set_level("ERROR")
def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
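Here load_checkpoint restores the full training state into the already-boosted objects; the producer (next file) calls the same helper only to read the saved counters. A compressed sketch of both call patterns, using just the signature visible in this PR (the argument names are illustrative placeholders):

from coati.utils import load_checkpoint

def resume_consumer(checkpoint_path, booster, policy_model, optimizer, lr_scheduler):
    # Full restore: weights, optimizer states and LR scheduler; the return values are unused here.
    load_checkpoint(checkpoint_path, booster, policy_model, optimizer, lr_scheduler)

def resume_producer(checkpoint_path):
    # Metadata-only read; the None arguments rely on the new guards added to load_checkpoint below.
    return load_checkpoint(checkpoint_path, None, None, None, None)  # (start_epoch, start_step, sampler_start_idx)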

View File

@@ -8,10 +8,12 @@ import ray.util.collective as cc
import torch
import tqdm
import wandb
from coati.dataset import StatefulDistributedSampler
from coati.dataset.loader import RawConversationDataset, collate_fn_grpo
from coati.distributed.profiling_utils import CustomProfiler
from coati.distributed.reward.reward_fn import boxed_math_reward_fn, code_reward_fn, math_reward_fn
from coati.distributed.reward.verifiable_reward import VerifiableReward
from coati.utils import load_checkpoint
from ray.util.collective import allreduce
from ray.util.collective.types import Backend, ReduceOp
from torch.utils.data import DataLoader, DistributedSampler
@@ -68,6 +70,7 @@ class BaseProducer:
self.profiler = CustomProfiler(f"P{self.producer_idx}", disabled=not enable_profiling)
self.train_dataset_config = train_dataset_config
self.checkpoint_path = model_config.pop("checkpoint_path", None)
self.model_config = model_config
self.generate_config = generate_config
self.tokenizer_config = tokenizer_config
@@ -121,7 +124,7 @@ class BaseProducer:
self.train_dataloader = DataLoader(
self.train_dataset,
batch_size=microbatch_size,
sampler=DistributedSampler(
sampler=StatefulDistributedSampler(
self.train_dataset,
num_replicas=num_producers,
rank=producer_idx,
@@ -133,6 +136,13 @@
drop_last=True,
collate_fn=collate_fn_grpo,
)
if self.checkpoint_path is not None:
# resume training from checkpoint
start_epoch, start_step, sampler_start_idx = load_checkpoint(self.checkpoint_path, None, None, None, None)
self.train_dataloader.sampler.set_start_index(sampler_start_idx)
print(
f"[P{self.producer_idx}] Resume training from checkpoint {self.checkpoint_path}, start epoch {start_epoch}, start step {start_step}, sampler start index {sampler_start_idx}"
)
if grpo_config["reward_fn_type"] == "think_answer_tags":
self.evaluation_function = math_reward_fn
elif grpo_config["reward_fn_type"] == "boxed":
@@ -203,6 +213,29 @@
raise NotImplementedError
def loop(self) -> None:
torch.cuda.empty_cache()
self.profiler.enter("sync_model")
if self.consumer_pp_size > 1:
for pp_idx in range(self.consumer_pp_size):
state_dict = ray_broadcast_tensor_dict(
None, self.num_producers, device=self.device, group_name=f"sync_model_{pp_idx}"
)
if "consumer_global_step" in state_dict:
self.consumer_global_step = state_dict.pop("consumer_global_step").item()
self.load_state_dict(state_dict)
else:
state_dict = ray_broadcast_tensor_dict(
None, self.num_producers, device=self.device, group_name="sync_model"
)
if "consumer_global_step" in state_dict:
self.consumer_global_step = state_dict.pop("consumer_global_step").item()
self.load_state_dict(state_dict)
self.profiler.exit("sync_model")
print(f"[P{self.producer_idx}] Sync initial model done.")
del state_dict
torch.cuda.empty_cache()
num_update_per_episode = len(self.train_dataloader) // self.num_microbatches
num_valid_microbatches = num_update_per_episode * self.num_microbatches
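Two things happen on the producer: the dataloader switches to StatefulDistributedSampler so a resumed producer can skip prompts it has already served, and the initial sync_model broadcast (paired with the consumer-side broadcast earlier in this PR) means generation starts from the checkpoint-restored weights rather than the base model. A minimal sketch of the sampler side, assuming set_start_index(n) makes the sampler skip the first n samples of the current epoch (the class and method come from this diff; the dataset, worker counts and numbers are illustrative):

from coati.dataset import StatefulDistributedSampler
from torch.utils.data import DataLoader

def build_resumable_loader(dataset, num_producers, producer_idx, microbatch_size, sampler_start_idx=0):
    sampler = StatefulDistributedSampler(dataset, num_replicas=num_producers, rank=producer_idx, shuffle=True)
    if sampler_start_idx > 0:
        # Assumption: skip prompts that were already consumed before the checkpoint was written.
        sampler.set_start_index(sampler_start_idx)
    return DataLoader(dataset, batch_size=microbatch_size, sampler=sampler, drop_last=True)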

View File

@@ -81,9 +81,12 @@ def load_checkpoint(
"""
# Update booster params states.
booster.load_model(model=model, checkpoint=os.path.join(load_dir, "modeling"))
booster.load_optimizer(optimizer=optimizer, checkpoint=os.path.join(load_dir, "optimizer"))
booster.load_lr_scheduler(lr_scheduler=lr_scheduler, checkpoint=os.path.join(load_dir, "lr_scheduler"))
if model is not None:
booster.load_model(model=model, checkpoint=os.path.join(load_dir, "modeling"))
if optimizer is not None:
booster.load_optimizer(optimizer=optimizer, checkpoint=os.path.join(load_dir, "optimizer"))
if lr_scheduler is not None:
booster.load_lr_scheduler(lr_scheduler=lr_scheduler, checkpoint=os.path.join(load_dir, "lr_scheduler"))
running_states = load_json(file_path=os.path.join(load_dir, "running_states.json"))
return (
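With these None checks, load_checkpoint no longer insists on restoring every component: the producer's metadata-only call above works, and partial restores become possible too. An illustrative weights-only restore built on the guarded calls (booster and model stand for whatever objects the caller has already boosted):

from coati.utils import load_checkpoint

def restore_weights_only(load_dir, booster, model):
    # Skips the optimizer and LR scheduler branches; still reads running_states.json,
    # so the returned (epoch, step, sampler start index) can be used to resume or label the run.
    return load_checkpoint(load_dir, booster, model, None, None)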

View File

@@ -18,6 +18,13 @@ os.environ["no_proxy"] = "127.0.0.1,localhost"
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
parser.add_argument(
"-cp",
"--checkpoint-path",
type=str,
default=None,
help="Path to the checkpoint to load the model from. If not provided, the model will be loaded from the model path.",
)
parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
parser.add_argument(
"-ed",
@@ -226,8 +233,10 @@
os.environ["TOKENIZERS_PARALLELISM"] = "false" # Disable tokenizers parallelism to avoid deadlock
inference_model_config = dict(path=args.model)
train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False)
inference_model_config = dict(path=args.model, checkpoint_path=args.checkpoint_path)
train_model_config = dict(
path=args.model, use_flash_attention_2=True, use_cache=False, checkpoint_path=args.checkpoint_path
)
generate_config = dict(top_k=args.top_k, top_p=args.top_p, temperature=args.temperature)
if args.backend == "transformers":
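End to end, resuming stays opt-in: the new --checkpoint-path flag rides inside both model config dicts, and BaseProducer/BaseConsumer pop it back out before the configs are used to build the models, presumably so the extra key never reaches the Hugging Face loaders. A small sketch of that plumbing (the path values are placeholders):

# checkpoint_path travels inside the config dict and is popped again on the other side.
train_model_config = dict(
    path="Qwen/Qwen2.5-7B", use_flash_attention_2=True, use_cache=False,
    checkpoint_path="./checkpoints/run1",                # placeholder path
)
checkpoint_path = train_model_config.pop("checkpoint_path", None)  # done in BaseProducer / BaseConsumer

In a shell this corresponds to adding -cp ./checkpoints/run1 (a placeholder path) to the existing launch command; omitting the flag trains from scratch as before.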