"""
|
|
PPO trainer
|
|
"""
|
|
|
|
import os
|
|
from typing import Dict, List, Optional
|
|
|
|
import torch
|
|
import wandb
|
|
from coati.experience_buffer import NaiveExperienceBuffer
|
|
from coati.experience_maker import Experience, NaiveExperienceMaker
|
|
from coati.models import Critic, RewardModel
|
|
from coati.models.loss import GPTLMLoss, PolicyLoss, ValueLoss
|
|
from coati.models.utils import calc_action_log_probs
|
|
from coati.trainer.callbacks import Callback
|
|
from coati.trainer.utils import all_reduce_mean
|
|
from coati.utils import AccumulativeMeanMeter, save_checkpoint
|
|
from torch.optim import Optimizer
|
|
from torch.optim.lr_scheduler import _LRScheduler
|
|
from torch.utils.data import DataLoader, DistributedSampler
|
|
from tqdm import tqdm
|
|
from transformers import PreTrainedModel, PreTrainedTokenizerBase
|
|
|
|
from colossalai.booster import Booster
|
|
from colossalai.booster.plugin import GeminiPlugin
|
|
from colossalai.cluster import DistCoordinator
|
|
from colossalai.utils import get_current_device
|
|
|
|
from .base import OLTrainer
|
|
from .utils import CycledDataLoader, is_rank_0, to_device
|
|
|
|
|
|


def _set_default_generate_kwargs(actor: PreTrainedModel) -> Dict:
    """
    Set default keyword arguments for generation based on the actor model.

    Args:
        actor (PreTrainedModel): The actor model.

    Returns:
        Dict: A dictionary containing the default keyword arguments for generation.
    """
    unwrapped_model = actor.unwrap()
    new_kwargs = {}
    # use huggingface models method directly
    if hasattr(unwrapped_model, "prepare_inputs_for_generation"):
        new_kwargs["prepare_inputs_fn"] = unwrapped_model.prepare_inputs_for_generation

    if hasattr(unwrapped_model, "_update_model_kwargs_for_generation"):
        new_kwargs["update_model_kwargs_fn"] = unwrapped_model._update_model_kwargs_for_generation
    return new_kwargs
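
# Illustrative note (not exhaustive): for a typical HuggingFace causal LM both hooks
# exist, so the returned dict maps "prepare_inputs_fn" and "update_model_kwargs_fn"
# to the model's own generation helpers; models lacking either hook simply omit it.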


class PPOTrainer(OLTrainer):
    """
    Trainer for the PPO algorithm.

    Args:
        actor_booster (Booster): the booster to use for the actor model
        critic_booster (Booster): the booster to use for the critic model
        actor (PreTrainedModel): the actor model in the PPO algorithm
        critic (Critic): the critic model in the PPO algorithm
        reward_model (RewardModel): the reward model in the RLHF pipeline, used to score generated sequences
        initial_model (PreTrainedModel): the reference model in the RLHF pipeline, used to generate reference logits that bound the actor's updates
        actor_optim (Optimizer): the optimizer to use for the actor model
        critic_optim (Optimizer): the optimizer to use for the critic model
        actor_lr_scheduler (_LRScheduler): the learning-rate scheduler for the actor model
        critic_lr_scheduler (_LRScheduler): the learning-rate scheduler for the critic model
        tokenizer (PreTrainedTokenizerBase): the tokenizer shared by the models
        kl_coef (float, defaults to 0.1): the coefficient of the KL-divergence penalty
        ptx_coef (float, defaults to 0.9): the coefficient of the ptx (pretraining) loss
        train_batch_size (int, defaults to 8): the batch size to use for training
        buffer_limit (int, defaults to 0): the max size of the experience buffer
        buffer_cpu_offload (bool, defaults to True): whether to offload the experience buffer to CPU
        eps_clip (float, defaults to 0.2): the clip coefficient of the policy loss
        vf_coef (float, defaults to 1.0): the coefficient of the value loss
        value_clip (float, defaults to 0.2): the clip coefficient of the value loss
        sample_buffer (bool, defaults to False): whether to sample from the buffer
        dataloader_pin_memory (bool, defaults to True): whether to pin memory for the data loader
        offload_inference_models (bool, defaults to True): whether to offload inference models to CPU during training
        callbacks (List[Callback], defaults to []): the callbacks to call during training
        generate_kwargs (dict, optional): the kwargs to use while the model generates
    """

    def __init__(
        self,
        actor_booster: Booster,
        critic_booster: Booster,
        actor: PreTrainedModel,
        critic: Critic,
        reward_model: RewardModel,
        initial_model: PreTrainedModel,
        actor_optim: Optimizer,
        critic_optim: Optimizer,
        actor_lr_scheduler: _LRScheduler,
        critic_lr_scheduler: _LRScheduler,
        tokenizer: PreTrainedTokenizerBase,
        kl_coef: float = 0.1,
        ptx_coef: float = 0.9,
        train_batch_size: int = 8,
        buffer_limit: int = 0,
        buffer_cpu_offload: bool = True,
        eps_clip: float = 0.2,
        vf_coef: float = 1.0,
        value_clip: float = 0.2,
        sample_buffer: bool = False,
        dataloader_pin_memory: bool = True,
        offload_inference_models: bool = True,
        apply_loss_mask: bool = True,
        accumulation_steps: int = 1,
        save_interval: int = 0,
        save_dir: Optional[str] = None,
        use_tp: bool = False,
        coordinator: Optional[DistCoordinator] = None,
        callbacks: List[Callback] = [],
        **generate_kwargs,
    ) -> None:
        # check the booster's plugin; the booster itself is never a GeminiPlugin
        if isinstance(actor_booster.plugin, GeminiPlugin):
            assert not offload_inference_models, "GeminiPlugin is not compatible with manual model.to('cpu')"

        data_buffer = NaiveExperienceBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
        super().__init__(
            actor_booster, critic_booster, data_buffer, sample_buffer, dataloader_pin_memory, callbacks=callbacks
        )
        self.generate_kwargs = _set_default_generate_kwargs(actor)
        self.generate_kwargs.update(generate_kwargs)

        self.actor = actor
        self.critic = critic
        self.actor_booster = actor_booster
        self.critic_booster = critic_booster
        self.actor_scheduler = actor_lr_scheduler
        self.critic_scheduler = critic_lr_scheduler
        self.tokenizer = tokenizer
        self.experience_maker = NaiveExperienceMaker(
            self.actor, self.critic, reward_model, initial_model, self.tokenizer, kl_coef
        )
        self.train_batch_size = train_batch_size

        self.actor_loss_fn = PolicyLoss(eps_clip)
        self.critic_loss_fn = ValueLoss(value_clip)
        self.vf_coef = vf_coef
        self.ptx_loss_fn = GPTLMLoss()
        self.ptx_coef = ptx_coef
        self.actor_optim = actor_optim
        self.critic_optim = critic_optim
        self.save_interval = save_interval
        self.apply_loss_mask = apply_loss_mask
        self.coordinator = coordinator
        self.actor_save_dir = os.path.join(save_dir, "actor")
        self.critic_save_dir = os.path.join(save_dir, "critic")
        self.num_train_step = 0
        self.accumulation_steps = accumulation_steps
        self.use_tp = use_tp
        self.accumulative_meter = AccumulativeMeanMeter()
        self.offload_inference_models = offload_inference_models
        self.device = get_current_device()

    def _before_fit(
        self,
        prompt_dataloader: DataLoader,
        pretrain_dataloader: Optional[DataLoader] = None,
        log_dir: Optional[str] = None,
        use_wandb: bool = False,
    ):
        """
        Args:
            prompt_dataloader (DataLoader): the dataloader to use for prompt data
            pretrain_dataloader (DataLoader, optional): the dataloader to use for pretrain data
        """
        self.prompt_dataloader = CycledDataLoader(prompt_dataloader)
        self.pretrain_dataloader = CycledDataLoader(pretrain_dataloader) if pretrain_dataloader is not None else None

        self.writer = None
        if use_wandb and is_rank_0():
            assert log_dir is not None, "log_dir must be provided when use_wandb is True"
            self.wandb_run = wandb.init(project="Coati-ppo", sync_tensorboard=True)
        if log_dir is not None and is_rank_0():
            from torch.utils.tensorboard import SummaryWriter

            log_dir = os.path.join(log_dir, "ppo")
            log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
            self.writer = SummaryWriter(log_dir=log_dir)

    def _setup_update_phrase_dataload(self):
        """
        Why not use a distributed dataloader?
        If TP is used, the input on each rank is identical, so the same dataloader feeds the
        same experience to every rank; if TP is not used, the input on each rank differs, and
        we expect a different experience to be fed to each rank.
        """
        self.dataloader = DataLoader(
            self.data_buffer,
            batch_size=self.train_batch_size,
            shuffle=True,
            drop_last=True,
            pin_memory=self.dataloader_pin_memory,
            collate_fn=self.data_buffer.collate_fn,
        )

    def _make_experience(self, collect_step: int) -> Experience:
        """
        Make experience
        """
        prompts = self.prompt_dataloader.next()
        if self.offload_inference_models:
            # TODO(ver217): this may be controlled by strategy if they are prepared by strategy
            self.experience_maker.initial_model.to(self.device)
            self.experience_maker.reward_model.to(self.device)
        return self.experience_maker.make_experience(
            input_ids=prompts["input_ids"].to(get_current_device()),
            attention_mask=prompts["attention_mask"].to(get_current_device()),
            **self.generate_kwargs,
        )

    def _training_step(self, experience: Experience):
        """
        Args:
            experience:
                sequences: [batch_size, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>
        """
        self.num_train_step += 1
        self.actor.train()
        self.critic.train()
        num_actions = experience.action_log_probs.size(1)
        # policy loss

        actor_logits = self.actor(input_ids=experience.sequences, attention_mask=experience.attention_mask)[
            "logits"
        ]  # [batch size, prompt_length + response_length]
        action_log_probs = calc_action_log_probs(actor_logits, experience.sequences, num_actions)
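
        # PolicyLoss is assumed to implement the standard PPO clipped surrogate
        # (ratio = exp(log_probs - old_log_probs); loss = -min(ratio * A, clip(ratio, 1 - eps, 1 + eps) * A)),
        # additionally reporting a skip flag and the max ratio; a plain-torch
        # reference sketch is given at the end of this file.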
        actor_loss, to_skip, max_ratio = self.actor_loss_fn(
            action_log_probs,
            experience.action_log_probs,
            experience.advantages,
            action_mask=experience.action_mask if self.apply_loss_mask else None,
        )
        actor_loss = (1 - self.ptx_coef) * actor_loss
        if not to_skip:
            self.actor_booster.backward(loss=actor_loss, optimizer=self.actor_optim)

        # ptx loss
        if self.ptx_coef != 0:
            batch = self.pretrain_dataloader.next()
            batch = to_device(batch, self.device)
            outputs = self.actor(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
            ptx_loss = outputs.loss
            ptx_loss = self.ptx_coef * ptx_loss
            self.actor_booster.backward(loss=ptx_loss, optimizer=self.actor_optim)

        # value loss
        values = self.critic(
            input_ids=experience.sequences, attention_mask=experience.attention_mask
        )  # [batch size, prompt_length + response_length]
        critic_loss = self.critic_loss_fn(
            values[:, -num_actions:],
            experience.values,
            experience.advantages,
            action_mask=experience.action_mask if self.apply_loss_mask else None,
        )
        critic_loss = critic_loss * self.vf_coef
        self.critic_booster.backward(loss=critic_loss, optimizer=self.critic_optim)

        # sync
        actor_loss_mean = all_reduce_mean(tensor=actor_loss)
        critic_loss_mean = all_reduce_mean(tensor=critic_loss)
        max_ratio_mean = all_reduce_mean(tensor=max_ratio)
        reward_mean = all_reduce_mean(tensor=experience.reward.mean())
        value_mean = all_reduce_mean(tensor=experience.values.mean())
        advantages_mean = all_reduce_mean(tensor=experience.advantages.mean())
        kl_mean = all_reduce_mean(tensor=experience.kl.mean())
        if self.ptx_coef != 0:
            ptx_loss_mean = all_reduce_mean(tensor=ptx_loss)

        self.accumulative_meter.add("actor_loss", actor_loss_mean.to(torch.float16).mean().item())
        self.accumulative_meter.add("critic_loss", critic_loss_mean.to(torch.float16).mean().item())
        self.accumulative_meter.add("max_ratio", max_ratio_mean.to(torch.float16).item())
        self.accumulative_meter.add("reward", reward_mean.to(torch.float16).mean().item())
        self.accumulative_meter.add("value", value_mean.to(torch.float16).mean().item())
        self.accumulative_meter.add("advantages", advantages_mean.to(torch.float16).item())
        self.accumulative_meter.add("skip_ratio", 1.0 if to_skip else 0.0)
        self.accumulative_meter.add("kl", kl_mean.to(torch.float16).item())
        if self.ptx_coef != 0:
            self.accumulative_meter.add("ptx_loss", ptx_loss_mean.to(torch.float16).mean().item())
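
        # Gradients from the backward() calls above accumulate across micro-steps;
        # the optimizers and schedulers only step once every `accumulation_steps` steps.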
        if self.num_train_step % self.accumulation_steps == self.accumulation_steps - 1:
            self.actor_optim.step()
            self.critic_optim.step()
            self.actor_optim.zero_grad()
            self.critic_optim.zero_grad()
            self.actor_scheduler.step()
            self.critic_scheduler.step()

            # prepare logging of model outputs and the corresponding rewards
            if self.num_train_step % 10 == 1:
                response_text = self.experience_maker.tokenizer.batch_decode(
                    experience.sequences, skip_special_tokens=True
                )
                for i in range(len(response_text)):
                    response_text[i] = response_text[i] + f"\n\nReward: {experience.reward[i]}"

                if self.writer and is_rank_0() and "wandb_run" in self.__dict__:
                    # log output to wandb
                    my_table = wandb.Table(
                        columns=[f"sample response {i}" for i in range(len(response_text))], data=[response_text]
                    )
                    try:
                        self.wandb_run.log({"sample_response": my_table})
                    except OSError as e:
                        self.coordinator.print_on_master(e)
                elif self.writer and is_rank_0():
                    for line in response_text:
                        self.coordinator.print_on_master(line)

            if self.writer and is_rank_0():
                self.writer.add_scalar(
                    "train/max_ratio", self.accumulative_meter.get("max_ratio"), self.num_train_step
                )
                self.writer.add_scalar(
                    "train/skip_ratio", self.accumulative_meter.get("skip_ratio"), self.num_train_step
                )
                self.writer.add_scalar(
                    "train/actor_loss", self.accumulative_meter.get("actor_loss"), self.num_train_step
                )
                self.writer.add_scalar("train/lr_actor", self.actor_optim.param_groups[0]["lr"], self.num_train_step)
                self.writer.add_scalar("train/lr_critic", self.critic_optim.param_groups[0]["lr"], self.num_train_step)
                self.writer.add_scalar(
                    "train/critic_loss", self.accumulative_meter.get("critic_loss"), self.num_train_step
                )
                if self.ptx_coef != 0:
                    self.writer.add_scalar(
                        "train/ptx_loss", self.accumulative_meter.get("ptx_loss"), self.num_train_step
                    )
                self.writer.add_scalar("reward", self.accumulative_meter.get("reward"), self.num_train_step)
                self.writer.add_scalar("approx_kl", self.accumulative_meter.get("kl"), self.num_train_step)
                self.writer.add_scalar("value", self.accumulative_meter.get("value"), self.num_train_step)
                self.writer.add_scalar("advantages", self.accumulative_meter.get("advantages"), self.num_train_step)
                self.accumulative_meter.reset()

    def _learn(self, update_step: int):
        """
        Perform the learning step of the PPO algorithm.

        Args:
            update_step (int): The current update step.

        Returns:
            None
        """
        if self.offload_inference_models:
            self.experience_maker.initial_model.to("cpu")
            self.experience_maker.reward_model.to("cpu")

        # the buffer may be empty at first, so we rebuild at each training step
        if self.sample_buffer:
            experience = self.data_buffer.sample()
            self._on_learn_batch_start()
            experience.to_device(self.device)
            self._training_step(experience)
            self._on_learn_batch_end(experience)
        else:
            if isinstance(self.dataloader.sampler, DistributedSampler):
                self.dataloader.sampler.set_epoch(update_step)
            pbar = tqdm(self.dataloader, desc=f"Train epoch [{update_step + 1}]", disable=not is_rank_0())
            for experience in pbar:
                self._on_learn_batch_start()
                experience.to_device(self.device)
                self._training_step(experience)
                self._on_learn_batch_end(experience)

    def _save_checkpoint(self, episode: int = 0):
        """
        Save the actor and critic checkpoints with running states.

        Args:
            episode (int): The current episode number.

        Returns:
            None
        """
        self.coordinator.print_on_master("\nStart saving actor checkpoint with running states")
        save_checkpoint(
            save_dir=self.actor_save_dir,
            booster=self.actor_booster,
            model=self.actor,
            optimizer=self.actor_optim,
            lr_scheduler=self.actor_scheduler,
            epoch=0,
            step=episode + 1,
            batch_size=self.train_batch_size,
            coordinator=self.coordinator,
        )
        self.coordinator.print_on_master(
            f"Saved actor checkpoint at episode {(episode + 1)} at folder {self.actor_save_dir}"
        )

        self.coordinator.print_on_master("\nStart saving critic checkpoint with running states")
        save_checkpoint(
            save_dir=self.critic_save_dir,
            booster=self.critic_booster,
            model=self.critic,
            optimizer=self.critic_optim,
            lr_scheduler=self.critic_scheduler,
            epoch=0,
            step=episode + 1,
            batch_size=self.train_batch_size,
            coordinator=self.coordinator,
        )
        self.coordinator.print_on_master(
            f"Saved critic checkpoint at episode {(episode + 1)} at folder {self.critic_save_dir}"
        )
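

# --- Reference sketch (illustrative only, not used by the trainer) -----------------
# PolicyLoss and ValueLoss live in coati.models.loss; the two functions below are a
# minimal plain-torch sketch of the standard clipped PPO formulations they are assumed
# to follow, kept here purely for documentation. Action masking and the skip/max-ratio
# bookkeeping of the real PolicyLoss are omitted for brevity.
def _ppo_policy_loss_sketch(
    log_probs: torch.Tensor,
    old_log_probs: torch.Tensor,
    advantages: torch.Tensor,
    eps_clip: float = 0.2,
) -> torch.Tensor:
    """Clipped surrogate objective: -E[min(r * A, clip(r, 1 - eps, 1 + eps) * A)]."""
    ratio = (log_probs - old_log_probs).exp()
    surr1 = ratio * advantages
    surr2 = ratio.clamp(1 - eps_clip, 1 + eps_clip) * advantages
    return -torch.min(surr1, surr2).mean()


def _ppo_value_loss_sketch(
    values: torch.Tensor,
    old_values: torch.Tensor,
    returns: torch.Tensor,
    value_clip: float = 0.2,
) -> torch.Tensor:
    """Clipped value loss: E[max((v - R)^2, (clip(v) - R)^2)], clipping v around v_old.

    The trainer passes advantages rather than returns; a common convention is
    returns = advantages + old_values, but the actual ValueLoss may differ.
    """
    values_clipped = old_values + (values - old_values).clamp(-value_clip, value_clip)
    surr1 = (values - returns) ** 2
    surr2 = (values_clipped - returns) ** 2
    return torch.max(surr1, surr2).mean()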