ColossalAI/applications/Chat/coati/ray/detached_trainer_ppo.py
Hongxin Liu b5f0566363
[chat] add distributed PPO trainer (#3740)
* Detached ppo (#9)

* run the base

* working on dist ppo

* sync

* detached trainer

* update detached trainer. no maker update function

* facing init problem

* 1 maker 1 trainer detached run. but no model update

* facing cuda problem

* fix save functions

* verified maker update

* nothing

* add ignore

* analyze loss issue

* remove some debug codes

* facing 2m1t stuck issue

* 2m1t verified

* do not use torchrun

* working on 2m2t

* working on 2m2t

* initialize strategy in ray actor env

* facing actor's init order issue

* facing ddp model update issue (need to unwrap ddp)

* unwrap ddp actor

* checking 1m2t stuck problem

* nothing

* set timeout for trainer choosing. It solves the stuck problem!

* delete some debug output

* rename to sync with upstream

* rename to sync with upstream

* coati rename

* nothing

* I am going to detach the replay buffer from the trainer and make it a Ray Actor. Two benefits: 1. support TP trainer. 2. asynchronous buffer operations

* experience_maker_holder performs target-revolving _send_experience() instead of length comparison.

* move code to ray subfolder

* working on pipeline inference

* apply comments

* working on pipeline strategy. in progress.

* remove pipeline code. clean this branch

* update remote parameters by state_dict. no test

* nothing

* state_dict sharding transfer

* merge debug branch

* gemini _unwrap_model fix

* simplify code

* simplify code & fix LoRALinear AttributeError

* critic unwrapped state_dict

---------

Co-authored-by: csric <richcsr256@gmail.com>

* [chat] add performance evaluator and fix bugs (#10)

* [chat] add performance evaluator for ray

* [chat] refactor debug arg

* [chat] support hf config

* [chat] fix generation

* [chat] add 1mmt dummy example

* [chat] fix gemini ckpt

* split experience to send (#11)

Co-authored-by: csric <richcsr256@gmail.com>

* [chat] refactor trainer and maker (#12)

* [chat] refactor experience maker holder

* [chat] refactor model init

* [chat] refactor trainer args

* [chat] refactor model init

* [chat] refactor trainer

* [chat] refactor experience sending logic and training loop args (#13)

* [chat] refactor experience send logic

* [chat] refactor trainer

* [chat] refactor trainer

* [chat] refactor experience maker

* [chat] refactor pbar

* [chat] refactor example folder (#14)

* [chat] support quant (#15)

* [chat] add quant

* [chat] add quant example

* prompt example (#16)

* prompt example

* prompt load csv data

* remove legacy try

---------

Co-authored-by: csric <richcsr256@gmail.com>

* [chat] add mmmt dummy example and refactor experience sending (#17)

* [chat] add mmmt dummy example

* [chat] refactor naive strategy

* [chat] fix stuck problem

* [chat] fix naive strategy

* [chat] optimize experience maker sending logic

* [chat] refactor sending assignment

* [chat] refactor performance evaluator (#18)

* Prompt Example & requires_grad state_dict & sharding state_dict (#19)

* prompt example

* prompt load csv data

* remove legacy try

* maker models require_grad set to False

* working on zero redundancy update

* mmmt_prompt example; naive strategy requires_grad state_dict & sharding; maker model requires_no_grad.

* remove legacy examples

* remove legacy examples

* remove replay buffer tp state. bad design

---------

Co-authored-by: csric <richcsr256@gmail.com>

* state_dict sending adapts to new unwrap function (#20)

* prompt example

* prompt load csv data

* remove legacy try

* maker models require_grad set to False

* working on zero redundancy update

* mmmt_prompt example; naive strategy requires_grad state_dict & sharding; maker model requires_no_grad.

* remove legacy examples

* remove legacy examples

* remove replay buffer tp state. bad design

* opt benchmark

* better script

* nothing

* [chat] strategy refactor unwrap model

* [chat] strategy refactor save model

* [chat] add docstr

* [chat] refactor trainer save model

* [chat] fix strategy typing

* [chat] refactor trainer save model

* [chat] update readme

* [chat] fix unit test

* working on lora reconstruction

* state_dict sending adapts to new unwrap function

* remove comments

---------

Co-authored-by: csric <richcsr256@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>

* [chat-ray] add readme (#21)

* add readme

* transparent graph

* add note background

---------

Co-authored-by: csric <richcsr256@gmail.com>

* [chat] get images from url (#22)

* Refactor/chat ray (#23)

* [chat] lora add todo

* [chat] remove unused pipeline strategy

* [chat] refactor example structure

* [chat] setup ci for ray

* [chat-ray] Support LoRA trainer. LoRA weights reconstruction. (#24)

* lora support prototype

* lora support

* 1mmt lora & remove useless code

---------

Co-authored-by: csric <richcsr256@gmail.com>

* [chat] fix test ci for ray

* [chat] fix test ci requirements for ray

* [chat] fix ray runtime env

* [chat] fix ray runtime env

* [chat] fix example ci docker args

* [chat] add debug info in trainer

* [chat] add nccl debug info

* [chat] skip ray test

* [doc] fix typo

---------

Co-authored-by: csric <59389055+CsRic@users.noreply.github.com>
Co-authored-by: csric <richcsr256@gmail.com>
2023-06-07 10:41:16 +08:00

201 lines
8.9 KiB
Python

from typing import Any, Callable, Dict, List, Optional, Tuple

import ray
import torch
from coati.experience_maker import Experience, NaiveExperienceMaker
from coati.models.base import Actor, Critic
from coati.models.loss import PolicyLoss, ValueLoss
from coati.trainer.callbacks import Callback
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy, Strategy
from torch.optim import Adam

from colossalai.nn.optimizer import HybridAdam

from .callbacks import TrainerCallback, TrainerPerformanceEvaluator
from .detached_trainer_base import DetachedTrainer
from .lora_constructor import LoRAConstructor
from .utils import (
    get_actor_from_args,
    get_critic_from_args,
    get_model_numel,
    get_rank,
    get_strategy_from_args,
    is_rank_0,
    set_dist_env,
    state_dict_to,
)


@ray.remote(concurrency_groups={
    "buffer_length": 1,
    "buffer_append": 1,
    "buffer_sample": 1,
    "model_io": 1,
    "compute": 1
})
class DetachedPPOTrainer(DetachedTrainer):
    '''
    Detached Trainer for PPO algorithm

    Args:
        experience_maker_holder_name_list (List[str]): names of the experience maker holders that receive model updates
        strategy_fn (Callable[[], Strategy]): factory that builds the strategy to use for training
        model_fn (Callable[[], Tuple[Actor, Critic]]): factory that builds the actor and the critic
        env_info (Dict[str, str], optional): distributed environment settings passed to set_dist_env
        train_batch_size (int, defaults to 8): the batch size to use for training
        buffer_limit (int, defaults to 0): the max_size limitation of replay buffer
        eps_clip (float, defaults to 0.2): the clip coefficient of policy loss
        value_clip (float, defaults to 0.4): the clip coefficient of value loss
        dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
        callbacks (List[TrainerCallback], optional): the callbacks to call during training process
        eval_performance (bool, defaults to False): whether to add a TrainerPerformanceEvaluator callback
        debug (bool, defaults to False): whether to print debug information
        update_lora_weights (bool, defaults to False): whether to send only LoRA weights on non-full updates
    '''

    def __init__(
        self,
        experience_maker_holder_name_list: List[str],
        strategy_fn: Callable[[], Strategy],
        model_fn: Callable[[], Tuple[Actor, Critic]],
        env_info: Optional[Dict[str, str]] = None,
        train_batch_size: int = 8,
        buffer_limit: int = 0,
        eps_clip: float = 0.2,
        value_clip: float = 0.4,
        dataloader_pin_memory: bool = True,
        callbacks: Optional[List[TrainerCallback]] = None,
        eval_performance: bool = False,
        debug: bool = False,
        update_lora_weights: bool = False,
    ) -> None:
        # avoid sharing a mutable default list between instances
        callbacks = list(callbacks) if callbacks is not None else []
        # set environment variables
        if env_info:
            set_dist_env(env_info=env_info)
        # configure strategy
        self.strategy = strategy_fn()
        # configure models, loss and optimizers
        with self.strategy.model_init_context():
            self.actor, self.critic = model_fn()

        if eval_performance:
            actor_numel = get_model_numel(self.actor)
            critic_numel = get_model_numel(self.critic)
            evaluator = TrainerPerformanceEvaluator(actor_numel, critic_numel)
            callbacks = callbacks + [evaluator]

        if isinstance(self.strategy, ColossalAIStrategy):
            self.actor_optim = HybridAdam(self.actor.parameters(), lr=1e-7)
            self.critic_optim = HybridAdam(self.critic.parameters(), lr=1e-7)
        else:
            self.actor_optim = Adam(self.actor.parameters(), lr=1e-7)
            self.critic_optim = Adam(self.critic.parameters(), lr=1e-7)

        (self.actor, self.actor_optim), (self.critic, self.critic_optim) = \
            self.strategy.prepare((self.actor, self.actor_optim), (self.critic, self.critic_optim))

        # configure trainer
        self.actor_loss_fn = PolicyLoss(eps_clip)
        self.critic_loss_fn = ValueLoss(value_clip)

        super().__init__(experience_maker_holder_name_list,
                         train_batch_size=train_batch_size,
                         buffer_limit=buffer_limit,
                         dataloader_pin_memory=dataloader_pin_memory,
                         callbacks=callbacks,
                         debug=debug)
        if self._debug:
            print(f'[trainer{get_rank()}] will send state dict to {experience_maker_holder_name_list}')

        self._update_lora_weights = update_lora_weights

    @ray.method(concurrency_group="model_io")
    @torch.no_grad()
    def _update_remote_makers(self, fully_update: bool = False, **config):
        # TODO: balance duties
        if not fully_update:
            config['requires_grad_only'] = True
        self.update_target_holder_list()
        # mark start, ensure order
        tasks = []
        for target_holder in self.target_holder_list:
            tasks.append(target_holder.update_experience_maker.remote(chunk_start=True, fully_update=fully_update))
        ray.get(tasks)
        # sending loop (actor shards)
        tasks = []
        for state_dict_shard in self._get_model_state_dict_shard(self.actor, fully_update=fully_update, **config):
            for target_holder in self.target_holder_list:
                tasks.append(
                    target_holder.update_experience_maker.remote(
                        new_actor_state_dict=state_dict_shard,
                        new_actor_lora_config_dict=self._get_model_lora_config_dict(self.actor),
                        fully_update=fully_update))
        # sending loop (critic shards)
        for state_dict_shard in self._get_model_state_dict_shard(self.critic, fully_update=fully_update, **config):
            for target_holder in self.target_holder_list:
                tasks.append(
                    target_holder.update_experience_maker.remote(
                        new_critic_state_dict=state_dict_shard,
                        new_critic_lora_config_dict=self._get_model_lora_config_dict(self.critic),
                        fully_update=fully_update))
        ray.get(tasks)
        # mark end
        for target_holder in self.target_holder_list:
            target_holder.update_experience_maker.remote(chunk_end=True, fully_update=fully_update)

    @ray.method(concurrency_group="compute")
    def training_step(self, experience: Experience) -> Dict[str, float]:
        self.actor.train()
        self.critic.train()

        # actor update
        num_actions = experience.action_mask.size(1)
        action_log_probs = self.actor(experience.sequences, num_actions, attention_mask=experience.attention_mask)
        actor_loss = self.actor_loss_fn(action_log_probs,
                                        experience.action_log_probs,
                                        experience.advantages,
                                        action_mask=experience.action_mask)
        self.strategy.backward(actor_loss, self.actor, self.actor_optim)
        self.strategy.optimizer_step(self.actor_optim)
        self.actor_optim.zero_grad()

        # critic update
        values = self.critic(experience.sequences,
                             action_mask=experience.action_mask,
                             attention_mask=experience.attention_mask)
        critic_loss = self.critic_loss_fn(values,
                                          experience.values,
                                          experience.reward,
                                          action_mask=experience.action_mask)

        self.strategy.backward(critic_loss, self.critic, self.critic_optim)
        self.strategy.optimizer_step(self.critic_optim)
        self.critic_optim.zero_grad()
        return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()}

    def strategy_save_actor(self, path: str, only_rank0: bool = False) -> None:
        self.strategy.save_model(self.actor, path, only_rank0)

    def strategy_save_critic(self, path: str, only_rank0: bool = False) -> None:
        self.strategy.save_model(self.critic, path, only_rank0)

    def strategy_save_actor_optim(self, path: str, only_rank0: bool = False) -> None:
        self.strategy.save_optimizer(self.actor_optim, path, only_rank0)

    def strategy_save_critic_optim(self, path: str, only_rank0: bool = False) -> None:
        self.strategy.save_optimizer(self.critic_optim, path, only_rank0)

    def _get_model_state_dict_shard(self, model: torch.nn.Module, fully_update=False, **config):
        for state_dict in self.strategy.get_model_state_dict_shard(model, **config):
            if not self._update_lora_weights or fully_update:
                # send the shard as-is
                yield state_dict_to(state_dict)
            else:
                # send only the LoRA entries of the shard
                state_dict_lora, _ = LoRAConstructor.filter_state_dict_lora(state_dict)
                yield state_dict_to(state_dict_lora)

    def _get_model_lora_config_dict(self, model: torch.nn.Module):
        if not self._update_lora_weights:
            return None
        unwrapped_model = self.strategy.unwrap_model(model)
        return LoRAConstructor.extract_lora_config(unwrapped_model)
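
The chunked update in `_update_remote_makers` (mark start, send state dict shards, mark end) can be illustrated with a small in-process sketch. This is not the coati or Ray implementation: `iter_shards`, `ToyHolder`, and `push_weights` are hypothetical names, plain Python lists stand in for tensors, and the holder applying each shard as soon as it arrives is an assumption made for illustration only.

```python
from typing import Dict, Iterator, List, Optional

StateDict = Dict[str, List[float]]


def iter_shards(state_dict: StateDict, max_items: int) -> Iterator[StateDict]:
    """Yield sub-dicts whose total element count stays within max_items.

    Hypothetical helper; an entry larger than max_items gets a shard of its own.
    """
    shard: StateDict = {}
    size = 0
    for name, values in state_dict.items():
        if shard and size + len(values) > max_items:
            yield shard
            shard, size = {}, 0
        shard[name] = values
        size += len(values)
    if shard:
        yield shard


class ToyHolder:
    """In-process stand-in for a remote experience maker holder (assumption)."""

    def __init__(self) -> None:
        self.weights: StateDict = {}
        self.updating = False

    def update(self,
               new_state_dict: Optional[StateDict] = None,
               chunk_start: bool = False,
               chunk_end: bool = False) -> None:
        if chunk_start:
            self.updating = True    # an update window is open
        if new_state_dict is not None:
            self.weights.update(new_state_dict)    # apply the shard immediately
        if chunk_end:
            self.updating = False    # the update window is closed


def push_weights(state_dict: StateDict, holders: List[ToyHolder], max_items: int = 4) -> int:
    """Send state_dict to every holder in shards; return the number of shards."""
    # mark start on every holder before any shard is sent
    for holder in holders:
        holder.update(chunk_start=True)
    # sending loop: each shard goes to every holder
    num_shards = 0
    for shard in iter_shards(state_dict, max_items):
        num_shards += 1
        for holder in holders:
            holder.update(new_state_dict=shard)
    # mark end
    for holder in holders:
        holder.update(chunk_end=True)
    return num_shards
```

Calling `push_weights(state, [ToyHolder(), ToyHolder()])` leaves both holders with the full `state` while only one shard is in flight at a time, which is the point of sending shards rather than a whole state dict. In the trainer above, the same three phases are issued as `update_experience_maker.remote(...)` calls and synchronized with `ray.get`.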