ColossalAI/applications/Chat/coati/ray/detached_trainer_base.py
Hongxin Liu b5f0566363
[chat] add distributed PPO trainer (#3740)
* Detached ppo (#9)

* run the base

* working on dist ppo

* sync

* detached trainer

* update detached trainer. no maker update function

* facing init problem

* 1 maker 1 trainer detached run. but no model update

* facing cuda problem

* fix save functions

* verified maker update

* nothing

* add ignore

* analyze loss issue

* remove some debug codes

* facing 2m1t stuck issue

* 2m1t verified

* do not use torchrun

* working on 2m2t

* working on 2m2t

* initialize strategy in ray actor env

* facing actor's init order issue

* facing ddp model update issue (need to unwrap ddp)

* unwrap ddp actor

* checking 1m2t stuck problem

* nothing

* set timeout for trainer choosing. It solves the stuck problem!

* delete some debug output

* rename to sync with upstream

* rename to sync with upstream

* coati rename

* nothing

* I am going to detach the replay buffer from the trainer and make it a Ray Actor. Two benefits: 1. support TP trainer. 2. asynchronous buffer operations

* experience_maker_holder performs target-revolving _send_experience() instead of length comparison.

* move code to ray subfolder

* working on pipeline inference

* apply comments

* working on pipeline strategy. in progress.

* remove pipeline code. clean this branch

* update remote parameters by state_dict. no test

* nothing

* state_dict sharding transfer

* merge debug branch

* gemini _unwrap_model fix

* simplify code

* simplify code & fix LoRALinear AttributeError

* critic unwrapped state_dict

---------

Co-authored-by: csric <richcsr256@gmail.com>

* [chat] add performance evaluator and fix bugs (#10)

* [chat] add performance evaluator for ray

* [chat] refactor debug arg

* [chat] support hf config

* [chat] fix generation

* [chat] add 1mmt dummy example

* [chat] fix gemini ckpt

* split experience to send (#11)

Co-authored-by: csric <richcsr256@gmail.com>

* [chat] refactor trainer and maker (#12)

* [chat] refactor experience maker holder

* [chat] refactor model init

* [chat] refactor trainer args

* [chat] refactor model init

* [chat] refactor trainer

* [chat] refactor experience sending logic and training loop args (#13)

* [chat] refactor experience send logic

* [chat] refactor trainer

* [chat] refactor trainer

* [chat] refactor experience maker

* [chat] refactor pbar

* [chat] refactor example folder (#14)

* [chat] support quant (#15)

* [chat] add quant

* [chat] add quant example

* prompt example (#16)

* prompt example

* prompt load csv data

* remove legacy try

---------

Co-authored-by: csric <richcsr256@gmail.com>

* [chat] add mmmt dummy example and refactor experience sending (#17)

* [chat] add mmmt dummy example

* [chat] refactor naive strategy

* [chat] fix stuck problem

* [chat] fix naive strategy

* [chat] optimize experience maker sending logic

* [chat] refactor sending assignment

* [chat] refactor performance evaluator (#18)

* Prompt Example & requires_grad state_dict & sharding state_dict (#19)

* prompt example

* prompt load csv data

* remove legacy try

* maker models require_grad set to False

* working on zero redundancy update

* mmmt_prompt example; naive strategy requires_grad state_dict & sharding; maker model requires_no_grad.

* remove legacy examples

* remove legacy examples

* remove replay buffer tp state. bad design

---------

Co-authored-by: csric <richcsr256@gmail.com>

* state_dict sending adapts to new unwrap function (#20)

* prompt example

* prompt load csv data

* remove legacy try

* maker models require_grad set to False

* working on zero redundancy update

* mmmt_prompt example; naive strategy requires_grad state_dict & sharding; maker model requires_no_grad.

* remove legacy examples

* remove legacy examples

* remove replay buffer tp state. bad design

* opt benchmark

* better script

* nothing

* [chat] strategy refactor unwrap model

* [chat] strategy refactor save model

* [chat] add docstr

* [chat] refactor trainer save model

* [chat] fix strategy typing

* [chat] refactor trainer save model

* [chat] update readme

* [chat] fix unit test

* working on lora reconstruction

* state_dict sending adapts to new unwrap function

* remove comments

---------

Co-authored-by: csric <richcsr256@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>

* [chat-ray] add readme (#21)

* add readme

* transparent graph

* add note background

---------

Co-authored-by: csric <richcsr256@gmail.com>

* [chat] get images from url (#22)

* Refactor/chat ray (#23)

* [chat] lora add todo

* [chat] remove unused pipeline strategy

* [chat] refactor example structure

* [chat] setup ci for ray

* [chat-ray] Support LoRA trainer. LoRA weights reconstruction. (#24)

* lora support prototype

* lora support

* 1mmt lora & remove useless code

---------

Co-authored-by: csric <richcsr256@gmail.com>

* [chat] fix test ci for ray

* [chat] fix test ci requirements for ray

* [chat] fix ray runtime env

* [chat] fix ray runtime env

* [chat] fix example ci docker args

* [chat] add debug info in trainer

* [chat] add nccl debug info

* [chat] skip ray test

* [doc] fix typo

---------

Co-authored-by: csric <59389055+CsRic@users.noreply.github.com>
Co-authored-by: csric <richcsr256@gmail.com>
2023-06-07 10:41:16 +08:00


import os
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Iterable, List, Optional, Union

import ray
import torch
from coati.experience_maker import Experience
from coati.replay_buffer.utils import BufferItem
from torch.utils.data import DataLoader
from tqdm import tqdm

from .callbacks import TrainerCallback
from .detached_replay_buffer import DetachedReplayBuffer
from .utils import is_rank_0


class DetachedTrainer(ABC):
    '''
    Base class for detached RLHF trainers.
    'Detached' means that the experience maker runs separately from this trainer,
    rather than being owned by it as in a normal Trainer.
    Please set the actor name during init:
        >>> trainer = DetachedTrainer.options(..., name="xxx", ...).remote()
    so that an ExperienceMakerHolder can reach the detached_replay_buffer through the Ray actor's name.

    Args:
        experience_maker_holder_name_list (List[str]): names of the ExperienceMakerHolder actors to update
        train_batch_size (int, defaults to 8): batch size of the detached replay buffer
        buffer_limit (int, defaults to 0): size limit of the detached replay buffer
        dataloader_pin_memory (bool, defaults to True): whether to pin memory for the data loader
        callbacks (List[TrainerCallback], defaults to []): callbacks to call during the training process
        debug (bool, defaults to False): whether to print debug information
    '''
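
    # Minimal sketch of how an ExperienceMakerHolder is expected to reach this trainer
    # actor by name and push experience (the actor name below is illustrative, not part
    # of this file):
    #
    #     trainer = ray.get_actor("trainer1", namespace=os.environ["RAY_NAMESPACE"])
    #     trainer.buffer_extend.remote(items)   # items: List[BufferItem]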

    def __init__(self,
                 experience_maker_holder_name_list: List[str],
                 train_batch_size: int = 8,
                 buffer_limit: int = 0,
                 dataloader_pin_memory: bool = True,
                 callbacks: List[TrainerCallback] = [],
                 debug: bool = False) -> None:
        super().__init__()
        self.detached_replay_buffer = DetachedReplayBuffer(train_batch_size, limit=buffer_limit)
        self.dataloader_pin_memory = dataloader_pin_memory
        self.callbacks = callbacks
        self.target_holder_name_list = experience_maker_holder_name_list
        self.target_holder_list = []
        self._is_target_holder_initialized = False
        self._debug = debug

    def update_target_holder_list(self):
        # as the length of target_holder_list may be zero, we need to check it by a bool flag
        if not self._is_target_holder_initialized:
            for name in self.target_holder_name_list:
                self.target_holder_list.append(ray.get_actor(name, namespace=os.environ["RAY_NAMESPACE"]))
            self._is_target_holder_initialized = True

    @abstractmethod
    def _update_remote_makers(self, fully_update: bool = False, **kwargs):
        pass

    def sync_models_to_remote_makers(self, **kwargs):
        self._update_remote_makers(fully_update=True, **kwargs)

    @abstractmethod
    def training_step(self, experience: Experience) -> Dict[str, Any]:
        pass

    def _learn(self, update_steps: int, train_epochs: int) -> None:
        data = []
        # warmup: the first epoch samples `update_steps` batches from the detached replay buffer
        pbar = tqdm(range(update_steps), desc=f'Train epoch [1/{train_epochs}]', disable=not is_rank_0())
        self._on_epoch_start(0)
        self._learn_epoch(pbar, data)
        self._on_epoch_end(0)

        # later epochs reuse the batches collected during warmup; each item is already a batch,
        # so the DataLoader only shuffles them (batch_size=1, collate unwraps the singleton list)
        dataloader = DataLoader(data,
                                batch_size=1,
                                shuffle=True,
                                pin_memory=self.dataloader_pin_memory,
                                collate_fn=lambda x: x[0])
        for epoch in range(1, train_epochs):
            pbar = tqdm(dataloader, desc=f'Train epoch [{epoch + 1}/{train_epochs}]', disable=not is_rank_0())
            self._on_epoch_start(epoch)
            self._learn_epoch(pbar, data)
            self._on_epoch_end(epoch)

    def _learn_epoch(self, pbar: tqdm, data: List[Experience]) -> None:
        is_warmup = len(data) == 0
        for x in pbar:
            if self._debug:
                print("[trainer] training step")
            # sample a batch and then train to avoid waiting
            experience = x if not is_warmup else self._buffer_sample()
            experience.to_device(torch.cuda.current_device())
            self._on_batch_start()
            metrics = self.training_step(experience)
            self._on_batch_end(metrics, experience)

            if self._debug:
                print("[trainer] step over")
            experience.to_device("cpu")
            if is_warmup:
                data.append(experience)
            pbar.set_postfix(metrics)

    def fit(self, total_steps: int, update_steps: int, train_epochs: int = 1) -> None:
        self._on_fit_start()
        for i in tqdm(range(total_steps // update_steps), desc='Trainer', disable=not is_rank_0()):
            self._on_episode_start(i)
            self._learn(update_steps, train_epochs)
            self._on_update_start()
            self._update_remote_makers()
            self._on_update_end()
            self._on_episode_end(i)
        self._on_fit_end()
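
    # The @ray.method(concurrency_group=...) decorators below assume the concrete trainer
    # class declares these groups via Ray's concurrency-group mechanism, e.g.
    # (group sizes are illustrative, not taken from this file):
    #
    #     @ray.remote(concurrency_groups={"buffer_length": 1, "buffer_append": 1, "buffer_sample": 1})
    #     class SomeDetachedTrainer(DetachedTrainer): ...
    #
    # so that buffer calls from experience makers can run concurrently with training.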

    @ray.method(concurrency_group="buffer_length")
    def buffer_get_length(self):
        # called by ExperienceMakerHolder
        if self._debug:
            print("[trainer] telling length")
        return self.detached_replay_buffer.get_length()

    @ray.method(concurrency_group="buffer_append")
    def buffer_append(self, experience: Experience):
        # called by ExperienceMakerHolder
        if self._debug:
            print("[trainer] receiving exp.")
        self.detached_replay_buffer.append(experience)

    @ray.method(concurrency_group="buffer_append")
    def buffer_extend(self, items: List[BufferItem]):
        # called by ExperienceMakerHolder
        if self._debug:
            print("[trainer] receiving exp.")
        self.detached_replay_buffer.extend(items)

    @ray.method(concurrency_group="buffer_sample")
    def _buffer_sample(self):
        return self.detached_replay_buffer.sample()

    def _on_fit_start(self) -> None:
        for callback in self.callbacks:
            callback.on_fit_start()

    def _on_fit_end(self) -> None:
        for callback in self.callbacks:
            callback.on_fit_end()

    def _on_episode_start(self, episode: int) -> None:
        for callback in self.callbacks:
            callback.on_episode_start(episode)

    def _on_episode_end(self, episode: int) -> None:
        for callback in self.callbacks:
            callback.on_episode_end(episode)

    def _on_epoch_start(self, epoch: int) -> None:
        for callback in self.callbacks:
            callback.on_epoch_start(epoch)

    def _on_epoch_end(self, epoch: int) -> None:
        for callback in self.callbacks:
            callback.on_epoch_end(epoch)

    def _on_batch_start(self) -> None:
        for callback in self.callbacks:
            callback.on_batch_start()

    def _on_batch_end(self, metrics: dict, experience: Experience) -> None:
        for callback in self.callbacks:
            callback.on_batch_end(metrics, experience)

    def _on_update_start(self) -> None:
        for callback in self.callbacks:
            callback.on_update_start()

    def _on_update_end(self) -> None:
        for callback in self.callbacks:
            callback.on_update_end()
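

# Minimal usage sketch (illustrative only; `MyDetachedTrainer` and the maker/trainer names
# are hypothetical, only the DetachedTrainer API used here comes from this file):
#
#     trainer = MyDetachedTrainer.options(name="trainer1").remote(
#         experience_maker_holder_name_list=["maker1"])
#     trainer.sync_models_to_remote_makers.remote()
#     ray.get(trainer.fit.remote(total_steps=100, update_steps=10, train_epochs=1))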