Merge pull request #6378 from hpcaitech/grpo-latest-rebase-fix-resume

[feat] fix resume training
YeAnbang 2025-08-18 17:09:53 +08:00 committed by GitHub
commit 4ac2227488
5 changed files with 98 additions and 10 deletions


@@ -1,4 +1,3 @@
-import os
 from contextlib import nullcontext
 from typing import Any, Dict, Optional

@@ -7,11 +6,13 @@ import ray.util.collective as cc
 import torch
 import torch.distributed as dist
 from coati.distributed.profiling_utils import CustomProfiler
+from coati.utils import save_checkpoint
 from tqdm import tqdm
 from transformers import AutoModelForCausalLM

 from colossalai.booster import Booster
 from colossalai.booster.plugin import HybridParallelPlugin
+from colossalai.cluster import DistCoordinator
 from colossalai.initialize import launch
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.utils import get_current_device
@@ -55,6 +56,7 @@ class BaseConsumer:
         self.enable_profiling = enable_profiling
         assert batch_size % minibatch_size == 0, "batch_size should be divisible by microbatch_size"
         self.num_microbatches = batch_size // minibatch_size
+        self.checkpoint_path = model_config.pop("checkpoint_path", None)
         self.model_config = model_config
         self.plugin_config = plugin_config
@@ -62,9 +64,11 @@
         self.device = get_current_device()
         self.lr_scheduler = None
         self.n_behind = n_behind
+        self.total_prompt_trained = 0  # for setting start index when resume training

     def setup(self) -> None:
         launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)
+        self.coordinator = DistCoordinator()

         plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2)
         if (
@@ -143,6 +147,26 @@
         return effective_group_to_raw_group_mapping

     def loop(self) -> None:
+        self.profiler.enter("sync_model")
+        torch.cuda.empty_cache()
+        state_dict = self.state_dict()
+        if self.pp_size > 1:
+            if self.tp_rank == 0 and self.dp_rank == 0:
+                ray_broadcast_tensor_dict(
+                    state_dict,
+                    src=self.num_producers,
+                    device=self.device,
+                    group_name=f"sync_model_{self.pp_rank}",
+                )
+        else:
+            if self.rank == 0:
+                ray_broadcast_tensor_dict(
+                    state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
+                )
+        del state_dict
+        torch.cuda.empty_cache()
+        self.profiler.exit("sync_model")
         print(
             f"Consumer{self.rank} num_update: {self.num_update_per_episode}, num_recv: {self.num_recv_per_update}, nmb: {self.num_microbatches}"
         )
@@ -208,6 +232,7 @@
                         for k, v in raw_batch.items()
                     }
                     # [batch_size, num_generations] -> [batch_size]
+                    self.total_prompt_trained += raw_batch["reward"].size(0)
                     reward = raw_batch["reward"][:, :, 0]
                     format_acc = raw_batch["format_acc"][:, :, 0]
                     ans_acc = raw_batch["ans_acc"][:, :, 0]
@@ -285,10 +310,19 @@
                 if (step + 1) % self.save_interval == 0 or (step + 1) == self.num_update_per_episode:
                     if self.rank == 0:
                         print(f"Start saving policy model at step {step + 1}.")
-                    save_path = os.path.join(self.save_dir, f"modeling-episode-{episode}-step-{step + 1}")
-                    self.booster.save_model(self.policy_model, save_path, shard=True)
+                    save_checkpoint(
+                        save_dir=self.save_dir,
+                        booster=self.booster,
+                        model=self.policy_model,
+                        optimizer=self.optimizer,
+                        lr_scheduler=self.lr_scheduler,
+                        epoch=episode,
+                        step=step,
+                        batch_size=int(self.total_prompt_trained / step),
+                        coordinator=self.coordinator,
+                    )  # for setting start index when resuming training
                     if self.rank == 0:
-                        print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")
+                        print(f"Saved model checkpoint at step {step + 1} in folder {self.save_dir}")
                 if (episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1) and (
                     episode != 0 or step >= self.n_behind
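Note on the new bookkeeping: instead of booster.save_model alone, the consumer now calls save_checkpoint, which also persists the optimizer, LR scheduler, and episode/step counters, plus an average per-step prompt count derived from total_prompt_trained. A rough, self-contained sketch of that arithmetic follows; the numbers and the start-index formula are illustrative assumptions, not values from this diff.

# Illustrative sketch only: how the batch_size passed to save_checkpoint can be read.
# The per-step prompt count and the start-index formula are assumptions for illustration.
total_prompt_trained = 0
for step in range(1, 11):              # 10 optimizer steps (hypothetical)
    prompts_this_step = 128            # e.g. raw_batch["reward"].size(0) for this step
    total_prompt_trained += prompts_this_step

avg_prompts_per_step = int(total_prompt_trained / step)  # 128, stored as batch_size
resumed_start_index = avg_prompts_per_step * step        # assumed way to recover sampler progress
print(avg_prompts_per_step, resumed_start_index)         # 128 1280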


@@ -8,6 +8,7 @@ from coati.distributed.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
 from coati.distributed.utils import entropy_from_logits, memory_efficient_logprob
 from coati.trainer.utils import all_reduce_mean, all_reduce_sum
+from coati.utils import load_checkpoint
 from transformers import AutoModelForCausalLM, AutoTokenizer

 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
@@ -157,6 +158,14 @@
         )
         if self.policy_loss_fn.beta > 0:
             self.reference_model, *_ = self.booster.boost(self.reference_model)
+        if self.checkpoint_path is not None:
+            load_checkpoint(
+                self.checkpoint_path,
+                self.booster,
+                self.policy_model,
+                self.optimizer,
+                self.lr_scheduler,
+            )
         self.plugin.logger.set_level("ERROR")

     def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:


@@ -8,10 +8,12 @@ import ray.util.collective as cc
 import torch
 import tqdm
 import wandb
+from coati.dataset import StatefulDistributedSampler
 from coati.dataset.loader import RawConversationDataset, collate_fn_grpo
 from coati.distributed.profiling_utils import CustomProfiler
 from coati.distributed.reward.reward_fn import boxed_math_reward_fn, code_reward_fn, math_reward_fn
 from coati.distributed.reward.verifiable_reward import VerifiableReward
+from coati.utils import load_checkpoint
 from ray.util.collective import allreduce
 from ray.util.collective.types import Backend, ReduceOp
 from torch.utils.data import DataLoader, DistributedSampler
@@ -68,6 +70,7 @@
         self.profiler = CustomProfiler(f"P{self.producer_idx}", disabled=not enable_profiling)

         self.train_dataset_config = train_dataset_config
+        self.checkpoint_path = model_config.pop("checkpoint_path", None)
         self.model_config = model_config
         self.generate_config = generate_config
         self.tokenizer_config = tokenizer_config
@@ -121,7 +124,7 @@
         self.train_dataloader = DataLoader(
             self.train_dataset,
             batch_size=microbatch_size,
-            sampler=DistributedSampler(
+            sampler=StatefulDistributedSampler(
                 self.train_dataset,
                 num_replicas=num_producers,
                 rank=producer_idx,
@@ -133,6 +136,13 @@
             drop_last=True,
             collate_fn=collate_fn_grpo,
         )
+        if self.checkpoint_path is not None:
+            # resume training from checkpoint
+            start_epoch, start_step, sampler_start_idx = load_checkpoint(self.checkpoint_path, None, None, None, None)
+            self.train_dataloader.sampler.set_start_index(sampler_start_idx)
+            print(
+                f"[P{self.producer_idx}] Resume training from checkpoint {self.checkpoint_path}, start epoch {start_epoch}, start step {start_step}, sampler start index {sampler_start_idx}"
+            )
         if grpo_config["reward_fn_type"] == "think_answer_tags":
             self.evaluation_function = math_reward_fn
         elif grpo_config["reward_fn_type"] == "boxed":
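The DistributedSampler-to-StatefulDistributedSampler swap above is what makes the set_start_index call possible: the sampler can skip samples that were already consumed before the checkpoint was written. A minimal illustrative sampler with that behavior is sketched below; it is not the coati implementation, and it assumes the start index counts samples already consumed by this replica.

# Minimal sketch of a resumable distributed sampler (illustration, not the coati class).
from torch.utils.data import DistributedSampler


class TinyStatefulSampler(DistributedSampler):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.start_index = 0  # number of samples this replica should skip

    def set_start_index(self, start_index: int) -> None:
        self.start_index = start_index

    def __iter__(self):
        indices = list(super().__iter__())        # this replica's shard for the current epoch
        return iter(indices[self.start_index:])   # resume mid-epoch by skipping consumed samples

    def __len__(self) -> int:
        return self.num_samples - self.start_index

A resumed run would presumably reset the start index to zero once the first, partial epoch finishes, so later epochs iterate over the full shard again.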
@@ -203,6 +213,29 @@
         raise NotImplementedError

     def loop(self) -> None:
+        torch.cuda.empty_cache()
+        self.profiler.enter("sync_model")
+        if self.consumer_pp_size > 1:
+            for pp_idx in range(self.consumer_pp_size):
+                state_dict = ray_broadcast_tensor_dict(
+                    None, self.num_producers, device=self.device, group_name=f"sync_model_{pp_idx}"
+                )
+                if "consumer_global_step" in state_dict:
+                    self.consumer_global_step = state_dict.pop("consumer_global_step").item()
+                self.load_state_dict(state_dict)
+        else:
+            state_dict = ray_broadcast_tensor_dict(
+                None, self.num_producers, device=self.device, group_name="sync_model"
+            )
+            if "consumer_global_step" in state_dict:
+                self.consumer_global_step = state_dict.pop("consumer_global_step").item()
+            self.load_state_dict(state_dict)
+        self.profiler.exit("sync_model")
+        print(f"[P{self.producer_idx}] Sync initial model done.")
+        del state_dict
+        torch.cuda.empty_cache()
         num_update_per_episode = len(self.train_dataloader) // self.num_microbatches
         num_valid_microbatches = num_update_per_episode * self.num_microbatches


@@ -81,8 +81,11 @@ def load_checkpoint(
     """
     # Update booster params states.
-    booster.load_model(model=model, checkpoint=os.path.join(load_dir, "modeling"))
-    booster.load_optimizer(optimizer=optimizer, checkpoint=os.path.join(load_dir, "optimizer"))
-    booster.load_lr_scheduler(lr_scheduler=lr_scheduler, checkpoint=os.path.join(load_dir, "lr_scheduler"))
+    if model is not None:
+        booster.load_model(model=model, checkpoint=os.path.join(load_dir, "modeling"))
+    if optimizer is not None:
+        booster.load_optimizer(optimizer=optimizer, checkpoint=os.path.join(load_dir, "optimizer"))
+    if lr_scheduler is not None:
+        booster.load_lr_scheduler(lr_scheduler=lr_scheduler, checkpoint=os.path.join(load_dir, "lr_scheduler"))

     running_states = load_json(file_path=os.path.join(load_dir, "running_states.json"))
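These None-guards let one helper serve both callers introduced above: the GRPO consumer passes real objects so model, optimizer, and LR-scheduler states are restored, while the producer passes None for all of them and only needs the returned counters. A simplified sketch of that contract is below; the running_states.json key names are assumptions for illustration and may not match the actual file.

import json
import os


def load_checkpoint_sketch(load_dir, booster=None, model=None, optimizer=None, lr_scheduler=None):
    # Simplified illustration of the resume contract; not the actual coati implementation.
    if model is not None:
        booster.load_model(model=model, checkpoint=os.path.join(load_dir, "modeling"))
    if optimizer is not None:
        booster.load_optimizer(optimizer=optimizer, checkpoint=os.path.join(load_dir, "optimizer"))
    if lr_scheduler is not None:
        booster.load_lr_scheduler(lr_scheduler=lr_scheduler, checkpoint=os.path.join(load_dir, "lr_scheduler"))

    # Always read the bookkeeping written at save time.
    with open(os.path.join(load_dir, "running_states.json")) as f:
        states = json.load(f)
    return states["epoch"], states["step"], states["sample_start_index"]  # key names assumed


# Producer-style call: no booster or model, only the counters are needed.
# start_epoch, start_step, sampler_start_idx = load_checkpoint_sketch("/path/to/checkpoint")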


@@ -18,6 +18,13 @@ os.environ["no_proxy"] = "127.0.0.1,localhost"
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
+    parser.add_argument(
+        "-cp",
+        "--checkpoint-path",
+        type=str,
+        default=None,
+        help="Path to the checkpoint to load the model from. If not provided, the model will be loaded from the model path.",
+    )
     parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
     parser.add_argument(
         "-ed",
@@ -226,8 +233,10 @@
     os.environ["TOKENIZERS_PARALLELISM"] = "false" # Disable tokenizers parallelism to avoid deadlock

-    inference_model_config = dict(path=args.model)
-    train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False)
+    inference_model_config = dict(path=args.model, checkpoint_path=args.checkpoint_path)
+    train_model_config = dict(
+        path=args.model, use_flash_attention_2=True, use_cache=False, checkpoint_path=args.checkpoint_path
+    )
     generate_config = dict(top_k=args.top_k, top_p=args.top_p, temperature=args.temperature)

     if args.backend == "transformers":
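Taken together, the resume path is wired through the model-config dicts: the launcher stores the new checkpoint path in both configs, and producer and consumer each pop it back out before passing the remaining kwargs to the model constructor. A condensed sketch of that plumbing follows; the checkpoint directory is a placeholder, and the entry-script name is assumed from the example layout (e.g. python rl_example.py -m Qwen/Qwen2.5-7B -d data.jsonl -cp <checkpoint_dir>).

# Sketch of how --checkpoint-path travels from the CLI into producer/consumer configs.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
parser.add_argument("-cp", "--checkpoint-path", type=str, default=None)
args = parser.parse_args(["-cp", "/path/to/saved/checkpoint"])  # placeholder directory

train_model_config = dict(
    path=args.model, use_flash_attention_2=True, use_cache=False, checkpoint_path=args.checkpoint_path
)

# Inside BaseProducer/BaseConsumer __init__: pull the flag back out so the remaining
# kwargs stay a clean model config.
checkpoint_path = train_model_config.pop("checkpoint_path", None)
if checkpoint_path is not None:
    print(f"Will resume from {checkpoint_path}")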