mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-08-28 12:43:17 +00:00)

Merge pull request #6378 from hpcaitech/grpo-latest-rebase-fix-resume

[feat] fix resume training

Commit 4ac2227488
@@ -1,4 +1,3 @@
-import os
 from contextlib import nullcontext
 from typing import Any, Dict, Optional

@@ -7,11 +6,13 @@ import ray.util.collective as cc
 import torch
 import torch.distributed as dist
 from coati.distributed.profiling_utils import CustomProfiler
+from coati.utils import save_checkpoint
 from tqdm import tqdm
 from transformers import AutoModelForCausalLM

 from colossalai.booster import Booster
 from colossalai.booster.plugin import HybridParallelPlugin
+from colossalai.cluster import DistCoordinator
 from colossalai.initialize import launch
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.utils import get_current_device
@@ -55,6 +56,7 @@ class BaseConsumer:
         self.enable_profiling = enable_profiling
         assert batch_size % minibatch_size == 0, "batch_size should be divisible by microbatch_size"
         self.num_microbatches = batch_size // minibatch_size
+        self.checkpoint_path = model_config.pop("checkpoint_path", None)

         self.model_config = model_config
         self.plugin_config = plugin_config
@@ -62,9 +64,11 @@ class BaseConsumer:
         self.device = get_current_device()
         self.lr_scheduler = None
         self.n_behind = n_behind
+        self.total_prompt_trained = 0  # for setting start index when resume training

     def setup(self) -> None:
         launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)
+        self.coordinator = DistCoordinator()

         plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2)
         if (
@@ -143,6 +147,26 @@ class BaseConsumer:
         return effective_group_to_raw_group_mapping

     def loop(self) -> None:
+        self.profiler.enter("sync_model")
+        torch.cuda.empty_cache()
+        state_dict = self.state_dict()
+        if self.pp_size > 1:
+            if self.tp_rank == 0 and self.dp_rank == 0:
+                ray_broadcast_tensor_dict(
+                    state_dict,
+                    src=self.num_producers,
+                    device=self.device,
+                    group_name=f"sync_model_{self.pp_rank}",
+                )
+        else:
+            if self.rank == 0:
+                ray_broadcast_tensor_dict(
+                    state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
+                )
+        del state_dict
+        torch.cuda.empty_cache()
+        self.profiler.exit("sync_model")
+
         print(
             f"Consumer{self.rank} num_update: {self.num_update_per_episode}, num_recv: {self.num_recv_per_update}, nmb: {self.num_microbatches}"
         )
@@ -208,6 +232,7 @@ class BaseConsumer:
                     for k, v in raw_batch.items()
                 }
                 # [batch_size, num_generations] -> [batch_size]
+                self.total_prompt_trained += raw_batch["reward"].size(0)
                 reward = raw_batch["reward"][:, :, 0]
                 format_acc = raw_batch["format_acc"][:, :, 0]
                 ans_acc = raw_batch["ans_acc"][:, :, 0]
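
Note: total_prompt_trained (introduced above) appears to count prompts rather than individual samples: raw_batch["reward"] is indexed as [batch_size, num_generations, ...], so .size(0) adds one per prompt. Dividing by the step counter at save time then recovers roughly the number of prompts consumed per optimizer step, which is what save_checkpoint records as batch_size in the next hunk and, presumably, what the load side uses to rebuild the sampler start index. A toy arithmetic sketch with assumed numbers, not values from the diff:

# Toy numbers only: how the counter relates to the batch_size recorded by
# save_checkpoint in the next hunk.
total_prompt_trained = 0
prompts_per_minibatch = 8   # assumed value, not from the diff
minibatches_per_step = 4    # assumed value, not from the diff

for step in range(1, 11):   # 10 optimizer steps
    for _ in range(minibatches_per_step):
        total_prompt_trained += prompts_per_minibatch  # mirrors raw_batch["reward"].size(0)

print(total_prompt_trained)              # 320
print(int(total_prompt_trained / step))  # 32 == prompts consumed per optimizer step
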
@@ -285,10 +310,19 @@
                 if (step + 1) % self.save_interval == 0 or (step + 1) == self.num_update_per_episode:
                     if self.rank == 0:
                         print(f"Start saving policy model at step {step + 1}.")
-                    save_path = os.path.join(self.save_dir, f"modeling-episode-{episode}-step-{step + 1}")
-                    self.booster.save_model(self.policy_model, save_path, shard=True)
+                    save_checkpoint(
+                        save_dir=self.save_dir,
+                        booster=self.booster,
+                        model=self.policy_model,
+                        optimizer=self.optimizer,
+                        lr_scheduler=self.lr_scheduler,
+                        epoch=episode,
+                        step=step,
+                        batch_size=int(self.total_prompt_trained / step),
+                        coordinator=self.coordinator,
+                    )  # for setting start index when resuming training
                     if self.rank == 0:
-                        print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")
+                        print(f"Saved model checkpoint at step {step + 1} in folder {self.save_dir}")

                 if (episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1) and (
                     episode != 0 or step >= self.n_behind
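
Note: save_checkpoint replaces the old booster.save_model call, so a checkpoint now carries optimizer and lr-scheduler state plus a running_states.json with the resume bookkeeping (epoch, step, and enough information to recompute the sampler position). The on-disk layout assumed below is inferred from the load_checkpoint hunk near the end of this diff (entries "modeling", "optimizer", "lr_scheduler" and the JSON file); a minimal inspection sketch under that assumption:

import json
import os


def describe_checkpoint(ckpt_dir: str) -> None:
    """Print what a resume checkpoint appears to contain.

    Assumes only the layout visible in the load_checkpoint hunk below:
    entries named "modeling", "optimizer", "lr_scheduler" plus a
    "running_states.json" holding the resume metadata.
    """
    for name in ("modeling", "optimizer", "lr_scheduler"):
        path = os.path.join(ckpt_dir, name)
        print(f"{name:12s} -> {'present' if os.path.exists(path) else 'missing'}")
    with open(os.path.join(ckpt_dir, "running_states.json")) as f:
        # Key names are whatever save_checkpoint wrote; print them as-is
        # rather than assuming a schema.
        print(json.load(f))

Producers only need the JSON part, which is why they call load_checkpoint with model, optimizer and scheduler set to None (see the producer hunk further down).
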
@@ -8,6 +8,7 @@ from coati.distributed.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
 from coati.distributed.utils import entropy_from_logits, memory_efficient_logprob
 from coati.trainer.utils import all_reduce_mean, all_reduce_sum
+from coati.utils import load_checkpoint
 from transformers import AutoModelForCausalLM, AutoTokenizer

 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
@@ -157,6 +158,14 @@ class GRPOConsumer(BaseConsumer):
         )
         if self.policy_loss_fn.beta > 0:
             self.reference_model, *_ = self.booster.boost(self.reference_model)
+        if self.checkpoint_path is not None:
+            load_checkpoint(
+                self.checkpoint_path,
+                self.booster,
+                self.policy_model,
+                self.optimizer,
+                self.lr_scheduler,
+            )
         self.plugin.logger.set_level("ERROR")

     def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
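
Note: on the consumer side the restore runs after booster.boost(...), so the wrapped, sharded model and optimizer are the objects loaded into. A minimal standalone sketch of the same call order; the model path and learning rate are placeholders, the plugin arguments mirror the consumer defaults in this diff, and load_checkpoint is the coati.utils helper shown later in this diff:

from coati.utils import load_checkpoint
from transformers import AutoModelForCausalLM

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.nn.optimizer import HybridAdam

# Assumes the distributed environment has already been launched, as
# BaseConsumer.setup does before building its plugin.
plugin = HybridParallelPlugin(tp_size=1, pp_size=1, precision="bf16", zero_stage=2)
booster = Booster(plugin=plugin)

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B")  # placeholder model path
optimizer = HybridAdam(model.parameters(), lr=1e-6)              # placeholder learning rate
model, optimizer, *_ = booster.boost(model, optimizer)

# Restore weights, optimizer and (optionally) scheduler state in one call,
# mirroring the order used in GRPOConsumer.setup above; the scheduler slot
# may be None thanks to the guards added to load_checkpoint below.
start_epoch, start_step, _ = load_checkpoint("/path/to/checkpoint", booster, model, optimizer, None)
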
@@ -8,10 +8,12 @@ import ray.util.collective as cc
 import torch
 import tqdm
 import wandb
+from coati.dataset import StatefulDistributedSampler
 from coati.dataset.loader import RawConversationDataset, collate_fn_grpo
 from coati.distributed.profiling_utils import CustomProfiler
 from coati.distributed.reward.reward_fn import boxed_math_reward_fn, code_reward_fn, math_reward_fn
 from coati.distributed.reward.verifiable_reward import VerifiableReward
+from coati.utils import load_checkpoint
 from ray.util.collective import allreduce
 from ray.util.collective.types import Backend, ReduceOp
 from torch.utils.data import DataLoader, DistributedSampler
@@ -68,6 +70,7 @@ class BaseProducer:
         self.profiler = CustomProfiler(f"P{self.producer_idx}", disabled=not enable_profiling)

         self.train_dataset_config = train_dataset_config
+        self.checkpoint_path = model_config.pop("checkpoint_path", None)
         self.model_config = model_config
         self.generate_config = generate_config
         self.tokenizer_config = tokenizer_config
@@ -121,7 +124,7 @@ class BaseProducer:
         self.train_dataloader = DataLoader(
             self.train_dataset,
             batch_size=microbatch_size,
-            sampler=DistributedSampler(
+            sampler=StatefulDistributedSampler(
                 self.train_dataset,
                 num_replicas=num_producers,
                 rank=producer_idx,
@@ -133,6 +136,13 @@ class BaseProducer:
             drop_last=True,
             collate_fn=collate_fn_grpo,
         )
+        if self.checkpoint_path is not None:
+            # resume training from checkpoint
+            start_epoch, start_step, sampler_start_idx = load_checkpoint(self.checkpoint_path, None, None, None, None)
+            self.train_dataloader.sampler.set_start_index(sampler_start_idx)
+            print(
+                f"[P{self.producer_idx}] Resume training from checkpoint {self.checkpoint_path}, start epoch {start_epoch}, start step {start_step}, sampler start index {sampler_start_idx}"
+            )
         if grpo_config["reward_fn_type"] == "think_answer_tags":
             self.evaluation_function = math_reward_fn
         elif grpo_config["reward_fn_type"] == "boxed":
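
Note: the producer holds no trainable weights, so it only needs the bookkeeping part of the checkpoint. It passes None for booster, model, optimizer and scheduler (which the relaxed load_checkpoint below now tolerates) and uses the returned index to fast-forward the sampler. This is also why DistributedSampler is swapped for StatefulDistributedSampler: the plain torch sampler has no set_start_index, so it cannot skip prompts already consumed before the restart. A small wrapper showing the metadata-only pattern; the function name is mine, not part of the PR:

from typing import Tuple

from coati.utils import load_checkpoint


def read_resume_point(checkpoint_path: str) -> Tuple[int, int, int]:
    """Metadata-only resume, as BaseProducer does above.

    Nothing is restored into a booster/model/optimizer/scheduler; only the
    values stored in running_states.json are read back.
    """
    start_epoch, start_step, sampler_start_idx = load_checkpoint(checkpoint_path, None, None, None, None)
    return start_epoch, start_step, sampler_start_idx


# Hypothetical usage when building a dataloader:
# epoch, step, idx = read_resume_point("/path/to/checkpoint")
# dataloader.sampler.set_start_index(idx)  # StatefulDistributedSampler API used in the diff
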
@@ -203,6 +213,29 @@ class BaseProducer:
         raise NotImplementedError

     def loop(self) -> None:
+
+        torch.cuda.empty_cache()
+        self.profiler.enter("sync_model")
+        if self.consumer_pp_size > 1:
+            for pp_idx in range(self.consumer_pp_size):
+                state_dict = ray_broadcast_tensor_dict(
+                    None, self.num_producers, device=self.device, group_name=f"sync_model_{pp_idx}"
+                )
+                if "consumer_global_step" in state_dict:
+                    self.consumer_global_step = state_dict.pop("consumer_global_step").item()
+                self.load_state_dict(state_dict)
+        else:
+            state_dict = ray_broadcast_tensor_dict(
+                None, self.num_producers, device=self.device, group_name="sync_model"
+            )
+            if "consumer_global_step" in state_dict:
+                self.consumer_global_step = state_dict.pop("consumer_global_step").item()
+            self.load_state_dict(state_dict)
+        self.profiler.exit("sync_model")
+        print(f"[P{self.producer_idx}] Sync initial model done.")
+        del state_dict
+        torch.cuda.empty_cache()
+
         num_update_per_episode = len(self.train_dataloader) // self.num_microbatches
         num_valid_microbatches = num_update_per_episode * self.num_microbatches

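
Note: the two loop() additions form a one-time handshake at startup: the consumer broadcasts its (possibly just-restored) weights on the "sync_model" collective group, or one "sync_model_{pp_rank}" group per pipeline stage, and every producer receives them before generating the first rollout, so a resumed run does not sample from the stale base model. A condensed sketch of the pairing; the broadcast argument stands in for the ray_broadcast_tensor_dict helper used above (its import path is not shown in this diff), the collective groups are assumed to be initialized elsewhere in coati, and the consumer_global_step bookkeeping is omitted:

def consumer_send(broadcast, state_dict, num_producers, device, pp_size, pp_rank, tp_rank, dp_rank, rank):
    """Consumer side: one rank per pipeline stage sends its shard of weights."""
    if pp_size > 1:
        if tp_rank == 0 and dp_rank == 0:
            broadcast(state_dict, src=num_producers, device=device, group_name=f"sync_model_{pp_rank}")
    elif rank == 0:
        broadcast(state_dict, src=num_producers, device=device, group_name="sync_model")


def producer_recv(broadcast, num_producers, device, consumer_pp_size):
    """Producer side: receive one dict per pipeline stage (or a single dict)
    and combine them, roughly what the per-stage load_state_dict calls above do."""
    merged = {}
    if consumer_pp_size > 1:
        for pp_idx in range(consumer_pp_size):
            merged.update(broadcast(None, num_producers, device=device, group_name=f"sync_model_{pp_idx}"))
    else:
        merged = broadcast(None, num_producers, device=device, group_name="sync_model")
    return merged
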
@@ -81,9 +81,12 @@ def load_checkpoint(
     """

     # Update booster params states.
-    booster.load_model(model=model, checkpoint=os.path.join(load_dir, "modeling"))
-    booster.load_optimizer(optimizer=optimizer, checkpoint=os.path.join(load_dir, "optimizer"))
-    booster.load_lr_scheduler(lr_scheduler=lr_scheduler, checkpoint=os.path.join(load_dir, "lr_scheduler"))
+    if model is not None:
+        booster.load_model(model=model, checkpoint=os.path.join(load_dir, "modeling"))
+    if optimizer is not None:
+        booster.load_optimizer(optimizer=optimizer, checkpoint=os.path.join(load_dir, "optimizer"))
+    if lr_scheduler is not None:
+        booster.load_lr_scheduler(lr_scheduler=lr_scheduler, checkpoint=os.path.join(load_dir, "lr_scheduler"))

     running_states = load_json(file_path=os.path.join(load_dir, "running_states.json"))
     return (
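
Note: besides the two call sites in this PR (full restore in the GRPO consumer, metadata-only in the producer), the None-guards also make partial restores possible. A small sketch of one such use; the helper name is mine and not part of the PR:

from coati.utils import load_checkpoint


def restore_weights_only(checkpoint_path, booster, model):
    """Enabled by the guards above: reload model weights from the checkpoint
    while keeping a freshly initialized optimizer and lr scheduler (their
    booster.load_* calls are skipped when None is passed)."""
    return load_checkpoint(checkpoint_path, booster, model, None, None)
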
@@ -18,6 +18,13 @@ os.environ["no_proxy"] = "127.0.0.1,localhost"
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
+    parser.add_argument(
+        "-cp",
+        "--checkpoint-path",
+        type=str,
+        default=None,
+        help="Path to the checkpoint to load the model from. If not provided, the model will be loaded from the model path.",
+    )
     parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
     parser.add_argument(
         "-ed",
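
Note: resuming is opted into purely from the command line; when -cp/--checkpoint-path is omitted, the default None keeps the old behaviour of training from the base model. A runnable sketch of just that flag (the sample argv values are placeholders; the real script defines many more arguments):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
parser.add_argument(
    "-cp",
    "--checkpoint-path",
    type=str,
    default=None,
    help="Path to the checkpoint to load the model from. If not provided, the model will be loaded from the model path.",
)

args = parser.parse_args(["-m", "Qwen/Qwen2.5-7B", "-cp", "/ckpts/run1"])  # placeholder values
print(args.checkpoint_path)  # /ckpts/run1
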
@@ -226,8 +233,10 @@

     os.environ["TOKENIZERS_PARALLELISM"] = "false"  # Disable tokenizers parallelism to avoid deadlock

-    inference_model_config = dict(path=args.model)
-    train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False)
+    inference_model_config = dict(path=args.model, checkpoint_path=args.checkpoint_path)
+    train_model_config = dict(
+        path=args.model, use_flash_attention_2=True, use_cache=False, checkpoint_path=args.checkpoint_path
+    )
     generate_config = dict(top_k=args.top_k, top_p=args.top_p, temperature=args.temperature)

     if args.backend == "transformers":
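
Note: the flag reaches the trainers by riding inside the model-config dicts built above; BaseProducer.__init__ and BaseConsumer.__init__ each pop it back out (see the pop("checkpoint_path", None) lines in earlier hunks), so the extra key never reaches the underlying HF model constructor. A tiny sketch of that plumbing with placeholder values:

# Placeholder values standing in for args.model / args.checkpoint_path.
args_checkpoint_path = "/ckpts/run1"

train_model_config = dict(
    path="Qwen/Qwen2.5-7B", use_flash_attention_2=True, use_cache=False, checkpoint_path=args_checkpoint_path
)

checkpoint_path = train_model_config.pop("checkpoint_path", None)  # what producer/consumer do on their side
print(checkpoint_path, "checkpoint_path" in train_model_config)    # /ckpts/run1 False
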