Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-08-27 20:22:09 +00:00
Merge pull request #6378 from hpcaitech/grpo-latest-rebase-fix-resume
[feat] fix resume training
Commit: 4ac2227488
@@ -1,4 +1,3 @@
import os
from contextlib import nullcontext
from typing import Any, Dict, Optional

@@ -7,11 +6,13 @@ import ray.util.collective as cc
import torch
import torch.distributed as dist
from coati.distributed.profiling_utils import CustomProfiler
from coati.utils import save_checkpoint
from tqdm import tqdm
from transformers import AutoModelForCausalLM

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.cluster import DistCoordinator
from colossalai.initialize import launch
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
@@ -55,6 +56,7 @@ class BaseConsumer:
self.enable_profiling = enable_profiling
assert batch_size % minibatch_size == 0, "batch_size should be divisible by minibatch_size"
self.num_microbatches = batch_size // minibatch_size
self.checkpoint_path = model_config.pop("checkpoint_path", None)

self.model_config = model_config
self.plugin_config = plugin_config
@@ -62,9 +64,11 @@
self.device = get_current_device()
self.lr_scheduler = None
self.n_behind = n_behind
self.total_prompt_trained = 0  # for setting the start index when resuming training

def setup(self) -> None:
launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)
self.coordinator = DistCoordinator()

plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2)
if (
@@ -143,6 +147,26 @@ class BaseConsumer:
return effective_group_to_raw_group_mapping

def loop(self) -> None:
self.profiler.enter("sync_model")
torch.cuda.empty_cache()
state_dict = self.state_dict()
if self.pp_size > 1:
if self.tp_rank == 0 and self.dp_rank == 0:
ray_broadcast_tensor_dict(
state_dict,
src=self.num_producers,
device=self.device,
group_name=f"sync_model_{self.pp_rank}",
)
else:
if self.rank == 0:
ray_broadcast_tensor_dict(
state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
)
del state_dict
torch.cuda.empty_cache()
self.profiler.exit("sync_model")

print(
f"Consumer{self.rank} num_update: {self.num_update_per_episode}, num_recv: {self.num_recv_per_update}, nmb: {self.num_microbatches}"
)
@@ -208,6 +232,7 @@ class BaseConsumer:
for k, v in raw_batch.items()
}
# [batch_size, num_generations] -> [batch_size]
self.total_prompt_trained += raw_batch["reward"].size(0)
reward = raw_batch["reward"][:, :, 0]
format_acc = raw_batch["format_acc"][:, :, 0]
ans_acc = raw_batch["ans_acc"][:, :, 0]
@@ -285,10 +310,19 @@ class BaseConsumer:
if (step + 1) % self.save_interval == 0 or (step + 1) == self.num_update_per_episode:
if self.rank == 0:
print(f"Start saving policy model at step {step + 1}.")
save_path = os.path.join(self.save_dir, f"modeling-episode-{episode}-step-{step + 1}")
self.booster.save_model(self.policy_model, save_path, shard=True)
save_checkpoint(
save_dir=self.save_dir,
booster=self.booster,
model=self.policy_model,
optimizer=self.optimizer,
lr_scheduler=self.lr_scheduler,
epoch=episode,
step=step,
batch_size=int(self.total_prompt_trained / step),
coordinator=self.coordinator,
)  # for setting start index when resuming training
if self.rank == 0:
print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")
print(f"Saved model checkpoint at step {step + 1} in folder {self.save_dir}")

if (episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1) and (
episode != 0 or step >= self.n_behind

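A note on the bookkeeping in the hunk above: save_checkpoint is given the episode, the step, and int(self.total_prompt_trained / step), i.e. roughly how many prompts are consumed per training step, so the checkpoint records enough to recompute how far the data sampler had advanced. A minimal sketch of that idea, assuming the running states land in a running_states.json file with a sample_start_index key (the exact file name and keys used inside coati.utils are assumptions here, not lines from this PR):

import json
import os


def save_running_states(save_dir: str, epoch: int, step: int, batch_size: int) -> None:
    # Record enough information to recompute how many prompts were already consumed.
    states = {"epoch": epoch, "step": step, "sample_start_index": step * batch_size}
    with open(os.path.join(save_dir, "running_states.json"), "w") as f:
        json.dump(states, f)


def resume_start_index(save_dir: str) -> int:
    # On resume, the producer feeds this value to the sampler's set_start_index().
    with open(os.path.join(save_dir, "running_states.json")) as f:
        return json.load(f)["sample_start_index"]
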
@@ -8,6 +8,7 @@ from coati.distributed.consumer import BaseConsumer
from coati.distributed.loss import PolicyLoss
from coati.distributed.utils import entropy_from_logits, memory_efficient_logprob
from coati.trainer.utils import all_reduce_mean, all_reduce_sum
from coati.utils import load_checkpoint
from transformers import AutoModelForCausalLM, AutoTokenizer

from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
@@ -157,6 +158,14 @@ class GRPOConsumer(BaseConsumer):
)
if self.policy_loss_fn.beta > 0:
self.reference_model, *_ = self.booster.boost(self.reference_model)
if self.checkpoint_path is not None:
load_checkpoint(
self.checkpoint_path,
self.booster,
self.policy_model,
self.optimizer,
self.lr_scheduler,
)
self.plugin.logger.set_level("ERROR")

def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:

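One ordering detail in the hunk above: the resume load runs only after self.booster.boost(...) has wrapped the policy model and optimizer, because load_checkpoint restores sharded states through the booster. A minimal sketch of that pattern, assuming the load_checkpoint signature and return values shown later in this diff:

from coati.utils import load_checkpoint
from colossalai.booster import Booster


def boost_then_maybe_resume(booster: Booster, policy_model, optimizer, lr_scheduler, checkpoint_path=None):
    # Boost first so the booster manages the (possibly sharded) model and optimizer states.
    policy_model, optimizer, _, _, lr_scheduler = booster.boost(
        policy_model, optimizer, lr_scheduler=lr_scheduler
    )
    if checkpoint_path is not None:
        # Restore weights, optimizer states and the LR schedule into the boosted objects;
        # the returned running states can seed the training loop's start epoch/step.
        start_epoch, start_step, sampler_start_idx = load_checkpoint(
            checkpoint_path, booster, policy_model, optimizer, lr_scheduler
        )
    return policy_model, optimizer, lr_scheduler
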
@@ -8,10 +8,12 @@ import ray.util.collective as cc
import torch
import tqdm
import wandb
from coati.dataset import StatefulDistributedSampler
from coati.dataset.loader import RawConversationDataset, collate_fn_grpo
from coati.distributed.profiling_utils import CustomProfiler
from coati.distributed.reward.reward_fn import boxed_math_reward_fn, code_reward_fn, math_reward_fn
from coati.distributed.reward.verifiable_reward import VerifiableReward
from coati.utils import load_checkpoint
from ray.util.collective import allreduce
from ray.util.collective.types import Backend, ReduceOp
from torch.utils.data import DataLoader, DistributedSampler
@@ -68,6 +70,7 @@ class BaseProducer:
self.profiler = CustomProfiler(f"P{self.producer_idx}", disabled=not enable_profiling)

self.train_dataset_config = train_dataset_config
self.checkpoint_path = model_config.pop("checkpoint_path", None)
self.model_config = model_config
self.generate_config = generate_config
self.tokenizer_config = tokenizer_config
@@ -121,7 +124,7 @@ class BaseProducer:
self.train_dataloader = DataLoader(
self.train_dataset,
batch_size=microbatch_size,
sampler=DistributedSampler(
sampler=StatefulDistributedSampler(
self.train_dataset,
num_replicas=num_producers,
rank=producer_idx,
@@ -133,6 +136,13 @@ class BaseProducer:
drop_last=True,
collate_fn=collate_fn_grpo,
)
if self.checkpoint_path is not None:
# resume training from checkpoint
start_epoch, start_step, sampler_start_idx = load_checkpoint(self.checkpoint_path, None, None, None, None)
self.train_dataloader.sampler.set_start_index(sampler_start_idx)
print(
f"[P{self.producer_idx}] Resume training from checkpoint {self.checkpoint_path}, start epoch {start_epoch}, start step {start_step}, sampler start index {sampler_start_idx}"
)
if grpo_config["reward_fn_type"] == "think_answer_tags":
self.evaluation_function = math_reward_fn
elif grpo_config["reward_fn_type"] == "boxed":
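Swapping DistributedSampler for StatefulDistributedSampler is what makes the producer's dataloader resumable: set_start_index tells the sampler how many samples to skip before it starts yielding again. A minimal sketch of a sampler with that behavior; this illustrates the idea and is not the coati.dataset implementation itself:

from torch.utils.data import DistributedSampler


class StatefulDistributedSamplerSketch(DistributedSampler):
    """Minimal sketch: a DistributedSampler that can resume mid-epoch."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.start_index = 0

    def set_start_index(self, start_index: int) -> None:
        # Number of samples already consumed before the restart.
        self.start_index = start_index

    def __iter__(self):
        # Reuse the parent's shuffling/sharding, then drop the already-seen prefix.
        indices = list(super().__iter__())
        return iter(indices[self.start_index:])

    def __len__(self) -> int:
        return self.num_samples - self.start_index

The producer only needs to call set_start_index once, before iterating, as the hunk above does.
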
@@ -203,6 +213,29 @@
raise NotImplementedError

def loop(self) -> None:

torch.cuda.empty_cache()
self.profiler.enter("sync_model")
if self.consumer_pp_size > 1:
for pp_idx in range(self.consumer_pp_size):
state_dict = ray_broadcast_tensor_dict(
None, self.num_producers, device=self.device, group_name=f"sync_model_{pp_idx}"
)
if "consumer_global_step" in state_dict:
self.consumer_global_step = state_dict.pop("consumer_global_step").item()
self.load_state_dict(state_dict)
else:
state_dict = ray_broadcast_tensor_dict(
None, self.num_producers, device=self.device, group_name="sync_model"
)
if "consumer_global_step" in state_dict:
self.consumer_global_step = state_dict.pop("consumer_global_step").item()
self.load_state_dict(state_dict)
self.profiler.exit("sync_model")
print(f"[P{self.producer_idx}] Sync initial model done.")
del state_dict
torch.cuda.empty_cache()

num_update_per_episode = len(self.train_dataloader) // self.num_microbatches
num_valid_microbatches = num_update_per_episode * self.num_microbatches

@@ -81,9 +81,12 @@ def load_checkpoint(
"""

# Update booster params states.
booster.load_model(model=model, checkpoint=os.path.join(load_dir, "modeling"))
booster.load_optimizer(optimizer=optimizer, checkpoint=os.path.join(load_dir, "optimizer"))
booster.load_lr_scheduler(lr_scheduler=lr_scheduler, checkpoint=os.path.join(load_dir, "lr_scheduler"))
if model is not None:
booster.load_model(model=model, checkpoint=os.path.join(load_dir, "modeling"))
if optimizer is not None:
booster.load_optimizer(optimizer=optimizer, checkpoint=os.path.join(load_dir, "optimizer"))
if lr_scheduler is not None:
booster.load_lr_scheduler(lr_scheduler=lr_scheduler, checkpoint=os.path.join(load_dir, "lr_scheduler"))

running_states = load_json(file_path=os.path.join(load_dir, "running_states.json"))
return (

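Making each booster restore conditional lets the same helper serve both workers in this PR: the consumer restores everything the booster manages, while the producer passes None for all of them and only uses the returned running states to fast-forward its sampler. Condensed from the call sites shown earlier in this diff:

# Consumer side: restore model, optimizer and LR scheduler states through the booster.
load_checkpoint(self.checkpoint_path, self.booster, self.policy_model, self.optimizer, self.lr_scheduler)

# Producer side: nothing boosted to restore; only the running states are needed
# to resume the data sampler where the last checkpoint left off.
start_epoch, start_step, sampler_start_idx = load_checkpoint(self.checkpoint_path, None, None, None, None)
self.train_dataloader.sampler.set_start_index(sampler_start_idx)
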
@@ -18,6 +18,13 @@ os.environ["no_proxy"] = "127.0.0.1,localhost"
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
parser.add_argument(
"-cp",
"--checkpoint-path",
type=str,
default=None,
help="Path to the checkpoint to load the model from. If not provided, the model will be loaded from the model path.",
)
parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
parser.add_argument(
"-ed",
@@ -226,8 +233,10 @@ if __name__ == "__main__":

os.environ["TOKENIZERS_PARALLELISM"] = "false"  # Disable tokenizers parallelism to avoid deadlock

inference_model_config = dict(path=args.model)
train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False)
inference_model_config = dict(path=args.model, checkpoint_path=args.checkpoint_path)
train_model_config = dict(
path=args.model, use_flash_attention_2=True, use_cache=False, checkpoint_path=args.checkpoint_path
)
generate_config = dict(top_k=args.top_k, top_p=args.top_p, temperature=args.temperature)

if args.backend == "transformers":
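The new --checkpoint-path argument reaches the workers through the model config dicts; because both the consumer and the producer pop "checkpoint_path" back out of model_config (see the earlier hunks), the extra key never reaches the underlying Hugging Face loader. A minimal sketch of that pop-and-forward pattern; the from_pretrained call and the paths here are illustrative assumptions, not lines from this PR:

from transformers import AutoModelForCausalLM

# A config dict in the shape built above; the checkpoint path is illustrative.
train_model_config = dict(
    path="Qwen/Qwen2.5-7B", use_flash_attention_2=True, use_cache=False, checkpoint_path="./ckpts/run1"
)

# Workers strip the resume-only key before constructing the model, so only
# genuine Hugging Face kwargs are forwarded.
checkpoint_path = train_model_config.pop("checkpoint_path", None)
path = train_model_config.pop("path")
model = AutoModelForCausalLM.from_pretrained(path, **train_model_config)

if checkpoint_path is not None:
    # Hand off to the booster/load_checkpoint path shown earlier in this diff.
    ...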