mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-17 07:00:37 +00:00
[ColossalChat] Update RLHF V2 (#5286)
* Add dpo. Fix sft, ppo, lora. Refactor all * fix and tested ppo * 2 nd round refactor * add ci tests * fix ci * fix ci * fix readme, style * fix readme style * fix style, fix benchmark * reproduce benchmark result, remove useless files * rename to ColossalChat * use new image * fix ci workflow * fix ci * use local model/tokenizer for ci tests * fix ci * fix ci * fix ci * fix ci timeout * fix rm progress bar. fix ci timeout * fix ci * fix ci typo * remove 3d plugin from ci temporary * test environment * cannot save optimizer * support chat template * fix readme * fix path * test ci locally * restore build_or_pr * fix ci data path * fix benchmark * fix ci, move ci tests to 3080, disable fast tokenizer * move ci to 85 * support flash attention 2 * add all-in-one data preparation script. Fix colossal-llama2-chat chat template * add hardware requirements * move ci test data * fix save_model, add unwrap * fix missing bos * fix missing bos; support grad accumulation with gemini * fix ci * fix ci * fix ci * fix llama2 chat template config * debug sft * debug sft * fix colossalai version requirement * fix ci * add sanity check to prevent NaN loss * fix requirements * add dummy data generation script * add dummy data generation script * add dummy data generation script * add dummy data generation script * update readme * update readme * update readme and ignore * fix logger bug * support parallel_output * modify data preparation logic * fix tokenization * update lr * fix inference * run pre-commit --------- Co-authored-by: Tong Li <tong.li352711588@gmail.com>
This commit is contained in:
179
applications/ColossalChat/coati/ray/detached_trainer_base.py
Executable file
179
applications/ColossalChat/coati/ray/detached_trainer_base.py
Executable file
@@ -0,0 +1,179 @@
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import ray
|
||||
import torch
|
||||
from coati.experience_buffer.utils import BufferItem
|
||||
from coati.experience_maker import Experience
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
from .callbacks import TrainerCallback
|
||||
from .detached_replay_buffer import DetachedReplayBuffer
|
||||
from .utils import is_rank_0
|
||||
|
||||
|
||||
class DetachedTrainer(ABC):
|
||||
"""
|
||||
Base class for detached rlhf trainers.
|
||||
'detach' means that the experience maker is detached compared to a normal Trainer.
|
||||
Please set name attribute during init:
|
||||
>>> trainer = DetachedTrainer.options(..., name = "xxx", ...).remote()
|
||||
So an ExperienceMakerHolder can reach the detached_replay_buffer by Actor's name.
|
||||
Args:
|
||||
detached_strategy (DetachedStrategy): the strategy to use for training
|
||||
detached_replay_buffer_ref (ObjectRef[DetachedReplayBuffer]): the replay buffer to use for training
|
||||
data_loader_pin_memory (bool, defaults to True): whether to pin memory for data loader
|
||||
callbacks (List[Callback], defaults to []): the callbacks to call during training process
|
||||
generate_kwargs (dict, optional): the kwargs to use while model generating
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
experience_maker_holder_name_list: List[str],
|
||||
train_batch_size: int = 8,
|
||||
buffer_limit: int = 0,
|
||||
dataloader_pin_memory: bool = True,
|
||||
callbacks: List[TrainerCallback] = [],
|
||||
debug: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.detached_replay_buffer = DetachedReplayBuffer(train_batch_size, limit=buffer_limit)
|
||||
self.dataloader_pin_memory = dataloader_pin_memory
|
||||
self.callbacks = callbacks
|
||||
self.target_holder_name_list = experience_maker_holder_name_list
|
||||
self.target_holder_list = []
|
||||
self._is_target_holder_initialized = False
|
||||
self._debug = debug
|
||||
|
||||
def update_target_holder_list(self):
|
||||
# as the length of target_holder_list may be zero, we need to check it by a bool flag
|
||||
if not self._is_target_holder_initialized:
|
||||
for name in self.target_holder_name_list:
|
||||
self.target_holder_list.append(ray.get_actor(name, namespace=os.environ["RAY_NAMESPACE"]))
|
||||
self._is_target_holder_initialized = True
|
||||
|
||||
@abstractmethod
|
||||
def _update_remote_makers(self, fully_update: bool = False, **kwargs):
|
||||
pass
|
||||
|
||||
def sync_models_to_remote_makers(self, **kwargs):
|
||||
self._update_remote_makers(fully_update=True, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def training_step(self, experience: Experience) -> Dict[str, Any]:
|
||||
pass
|
||||
|
||||
def _learn(self, update_steps: int, train_epochs: int) -> None:
|
||||
data = []
|
||||
# warmup
|
||||
pbar = tqdm(range(update_steps), desc=f"Train epoch [1/{train_epochs}]", disable=not is_rank_0())
|
||||
self._on_epoch_start(0)
|
||||
self._learn_epoch(pbar, data)
|
||||
self._on_epoch_end(0)
|
||||
# item is already a batch
|
||||
dataloader = DataLoader(
|
||||
data, batch_size=1, shuffle=True, pin_memory=self.dataloader_pin_memory, collate_fn=lambda x: x[0]
|
||||
)
|
||||
for epoch in range(1, train_epochs):
|
||||
pbar = tqdm(dataloader, desc=f"Train epoch [{epoch + 1}/{train_epochs}]", disable=not is_rank_0())
|
||||
self._on_epoch_start(epoch)
|
||||
self._learn_epoch(pbar, data)
|
||||
self._on_epoch_end(epoch)
|
||||
|
||||
def _learn_epoch(self, pbar: tqdm, data: List[Experience]) -> None:
|
||||
is_warmup = len(data) == 0
|
||||
for x in pbar:
|
||||
if self._debug:
|
||||
print("[trainer] training step")
|
||||
# sample a batch and then train to avoid waiting
|
||||
experience = x if not is_warmup else self._buffer_sample()
|
||||
experience.to_device(torch.cuda.current_device())
|
||||
self._on_batch_start()
|
||||
metrics = self.training_step(experience)
|
||||
self._on_batch_end(metrics, experience)
|
||||
|
||||
if self._debug:
|
||||
print("[trainer] step over")
|
||||
experience.to_device("cpu")
|
||||
if is_warmup:
|
||||
data.append(experience)
|
||||
pbar.set_postfix(metrics)
|
||||
|
||||
def fit(self, total_steps: int, update_steps: int, train_epochs: int = 1) -> None:
|
||||
self._on_fit_start()
|
||||
for i in tqdm(range(total_steps // update_steps), desc="Trainer", disable=not is_rank_0()):
|
||||
self._on_episode_start(i)
|
||||
self._learn(update_steps, train_epochs)
|
||||
self._on_update_start()
|
||||
self._update_remote_makers()
|
||||
self._on_update_end()
|
||||
self._on_episode_end(i)
|
||||
self._on_fit_end()
|
||||
|
||||
@ray.method(concurrency_group="buffer_length")
|
||||
def buffer_get_length(self):
|
||||
# called by ExperienceMakerHolder
|
||||
if self._debug:
|
||||
print("[trainer] telling length")
|
||||
return self.detached_replay_buffer.get_length()
|
||||
|
||||
@ray.method(concurrency_group="buffer_append")
|
||||
def buffer_append(self, experience: Experience):
|
||||
# called by ExperienceMakerHolder
|
||||
if self._debug:
|
||||
print(f"[trainer] receiving exp.")
|
||||
self.detached_replay_buffer.append(experience)
|
||||
|
||||
@ray.method(concurrency_group="buffer_append")
|
||||
def buffer_extend(self, items: List[BufferItem]):
|
||||
# called by ExperienceMakerHolder
|
||||
if self._debug:
|
||||
print(f"[trainer] receiving exp.")
|
||||
self.detached_replay_buffer.extend(items)
|
||||
|
||||
@ray.method(concurrency_group="buffer_sample")
|
||||
def _buffer_sample(self):
|
||||
return self.detached_replay_buffer.sample()
|
||||
|
||||
def _on_fit_start(self) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_fit_start()
|
||||
|
||||
def _on_fit_end(self) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_fit_end()
|
||||
|
||||
def _on_episode_start(self, episode: int) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_episode_start(episode)
|
||||
|
||||
def _on_episode_end(self, episode: int) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_episode_end(episode)
|
||||
|
||||
def _on_epoch_start(self, epoch: int) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_epoch_start(epoch)
|
||||
|
||||
def _on_epoch_end(self, epoch: int) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_epoch_end(epoch)
|
||||
|
||||
def _on_batch_start(self) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_batch_start()
|
||||
|
||||
def _on_batch_end(self, metrics: dict, experience: Experience) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_batch_end(metrics, experience)
|
||||
|
||||
def _on_update_start(self) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_update_start()
|
||||
|
||||
def _on_update_end(self) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_update_end()
|
Reference in New Issue
Block a user