mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-10 20:32:40 +00:00
* Detached ppo (#9) * run the base * working on dist ppo * sync * detached trainer * update detached trainer. no maker update function * facing init problem * 1 maker 1 trainer detached run. but no model update * facing cuda problem * fix save functions * verified maker update * nothing * add ignore * analyize loss issue * remove some debug codes * facing 2m1t stuck issue * 2m1t verified * do not use torchrun * working on 2m2t * working on 2m2t * initialize strategy in ray actor env * facing actor's init order issue * facing ddp model update issue (need unwarp ddp) * unwrap ddp actor * checking 1m2t stuck problem * nothing * set timeout for trainer choosing. It solves the stuck problem! * delete some debug output * rename to sync with upstream * rename to sync with upstream * coati rename * nothing * I am going to detach the replaybuffer from trainer and make it a Ray Actor. Two benefits: 1. support TP trainer. 2. asynchronized buffer operations * experience_maker_holder performs target-revolving _send_experience() instead of length comparison. * move code to ray subfolder * working on pipeline inference * apply comments * working on pipeline strategy. in progress. * remove pipeline code. clean this branch * update remote parameters by state_dict. 
no test * nothing * state_dict sharding transfer * merge debug branch * gemini _unwrap_model fix * simplify code * simplify code & fix LoRALinear AttributeError * critic unwrapped state_dict --------- Co-authored-by: csric <richcsr256@gmail.com> * [chat] add perfomance evaluator and fix bugs (#10) * [chat] add performance evaluator for ray * [chat] refactor debug arg * [chat] support hf config * [chat] fix generation * [chat] add 1mmt dummy example * [chat] fix gemini ckpt * split experience to send (#11) Co-authored-by: csric <richcsr256@gmail.com> * [chat] refactor trainer and maker (#12) * [chat] refactor experience maker holder * [chat] refactor model init * [chat] refactor trainer args * [chat] refactor model init * [chat] refactor trainer * [chat] refactor experience sending logic and training loop args (#13) * [chat] refactor experience send logic * [chat] refactor trainer * [chat] refactor trainer * [chat] refactor experience maker * [chat] refactor pbar * [chat] refactor example folder (#14) * [chat] support quant (#15) * [chat] add quant * [chat] add quant example * prompt example (#16) * prompt example * prompt load csv data * remove legacy try --------- Co-authored-by: csric <richcsr256@gmail.com> * [chat] add mmmt dummy example and refactor experience sending (#17) * [chat] add mmmt dummy example * [chat] refactor naive strategy * [chat] fix struck problem * [chat] fix naive strategy * [chat] optimize experience maker sending logic * [chat] refactor sending assignment * [chat] refactor performance evaluator (#18) * Prompt Example & requires_grad state_dict & sharding state_dict (#19) * prompt example * prompt load csv data * remove legacy try * maker models require_grad set to False * working on zero redundancy update * mmmt_prompt example; naive strategy requires_grad state_dict & sharding; maker model requires_no_grad. * remove legacy examples * remove legacy examples * remove replay buffer tp state. 
bad design --------- Co-authored-by: csric <richcsr256@gmail.com> * state_dict sending adapts to new unwrap function (#20) * prompt example * prompt load csv data * remove legacy try * maker models require_grad set to False * working on zero redundancy update * mmmt_prompt example; naive strategy requires_grad state_dict & sharding; maker model requires_no_grad. * remove legacy examples * remove legacy examples * remove replay buffer tp state. bad design * opt benchmark * better script * nothing * [chat] strategy refactor unwrap model * [chat] strategy refactor save model * [chat] add docstr * [chat] refactor trainer save model * [chat] fix strategy typing * [chat] refactor trainer save model * [chat] update readme * [chat] fix unit test * working on lora reconstruction * state_dict sending adapts to new unwrap function * remove comments --------- Co-authored-by: csric <richcsr256@gmail.com> Co-authored-by: ver217 <lhx0217@gmail.com> * [chat-ray] add readme (#21) * add readme * transparent graph * add note background --------- Co-authored-by: csric <richcsr256@gmail.com> * [chat] get images from url (#22) * Refactor/chat ray (#23) * [chat] lora add todo * [chat] remove unused pipeline strategy * [chat] refactor example structure * [chat] setup ci for ray * [chat-ray] Support LoRA trainer. LoRA weights reconstruction. (#24) * lora support prototype * lora support * 1mmt lora & remove useless code --------- Co-authored-by: csric <richcsr256@gmail.com> * [chat] fix test ci for ray * [chat] fix test ci requirements for ray * [chat] fix ray runtime env * [chat] fix ray runtime env * [chat] fix example ci docker args * [chat] add debug info in trainer * [chat] add nccl debug info * [chat] skip ray test * [doc] fix typo --------- Co-authored-by: csric <59389055+CsRic@users.noreply.github.com> Co-authored-by: csric <richcsr256@gmail.com>
201 lines
8.9 KiB
Python
201 lines
8.9 KiB
Python
from typing import Any, Callable, Dict, List, Optional, Tuple
|
|
|
|
import ray
|
|
import torch
|
|
from coati.experience_maker import Experience, NaiveExperienceMaker
|
|
from coati.models.base import Actor, Critic
|
|
from coati.models.loss import PolicyLoss, ValueLoss
|
|
from coati.trainer.callbacks import Callback
|
|
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy, Strategy
|
|
from torch.optim import Adam
|
|
|
|
from colossalai.nn.optimizer import HybridAdam
|
|
|
|
from .callbacks import TrainerCallback, TrainerPerformanceEvaluator
|
|
from .detached_trainer_base import DetachedTrainer
|
|
from .lora_constructor import LoRAConstructor
|
|
from .utils import (
|
|
get_actor_from_args,
|
|
get_critic_from_args,
|
|
get_model_numel,
|
|
get_rank,
|
|
get_strategy_from_args,
|
|
is_rank_0,
|
|
set_dist_env,
|
|
state_dict_to,
|
|
)
|
|
|
|
|
|
@ray.remote(concurrency_groups={
    "buffer_length": 1,
    "buffer_append": 1,
    "buffer_sample": 1,
    "model_io": 1,
    "compute": 1
})
class DetachedPPOTrainer(DetachedTrainer):
    '''
        Detached Trainer for PPO algorithm

    Runs as a Ray actor. It owns the actor / critic models and their optimizers,
    performs the PPO update in ``training_step`` and pushes refreshed weights to
    the remote experience maker holders in ``_update_remote_makers``.

    Args:
        experience_maker_holder_name_list (List[str]): names of the experience maker holders
            this trainer sends its state dict to (forwarded to ``DetachedTrainer``)
        strategy_fn (Callable[[], Strategy]): zero-argument factory returning the strategy to use for training
        model_fn (Callable[[], Tuple[Actor, Critic]]): zero-argument factory returning ``(actor, critic)``;
            it is called inside ``strategy.model_init_context()``
        env_info (Dict[str, str], optional): passed to ``set_dist_env`` when truthy, before the strategy is built
        train_batch_size (int, defaults to 8): the batch size to use for training
        buffer_limit (int, defaults to 0): the max_size limitation of replay buffer
        eps_clip (float, defaults to 0.2): the clip coefficient of policy loss
        value_clip (float, defaults to 0.4): the clip coefficient of value loss
        dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
        callbacks (List[TrainerCallback], optional): the callbacks to call during training process.
            ``None`` (the default) means no callbacks
        eval_performance (bool, defaults to False): whether to append a ``TrainerPerformanceEvaluator``
            (built from the actor / critic parameter counts) to the callbacks
        debug (bool, defaults to False): forwarded to ``DetachedTrainer``; enables debug printing
        update_lora_weights (bool, defaults to False): if True, only LoRA weights (plus the LoRA config)
            are sent to the makers on non-full updates
    '''

    def __init__(
        self,
        experience_maker_holder_name_list: List[str],
        strategy_fn: Callable[[], Strategy],
        model_fn: Callable[[], Tuple[Actor, Critic]],
        env_info: Optional[Dict[str, str]] = None,
        train_batch_size: int = 8,
        buffer_limit: int = 0,
        eps_clip: float = 0.2,
        value_clip: float = 0.4,
        dataloader_pin_memory: bool = True,
        callbacks: Optional[List[TrainerCallback]] = None,
        eval_performance: bool = False,
        debug: bool = False,
        update_lora_weights: bool = False,
    ) -> None:
        # A `[]` default would be one list object shared by every construction
        # (and handed to the base class); use a None sentinel instead.
        if callbacks is None:
            callbacks = []

        # set environment variables
        if env_info:
            set_dist_env(env_info=env_info)
        # configure strategy
        self.strategy = strategy_fn()
        # configure models, loss and optimizers
        with self.strategy.model_init_context():
            self.actor, self.critic = model_fn()

        if eval_performance:
            actor_numel = get_model_numel(self.actor)
            critic_numel = get_model_numel(self.critic)
            evaluator = TrainerPerformanceEvaluator(actor_numel, critic_numel)
            # build a new list so the caller's list is never mutated
            callbacks = callbacks + [evaluator]

        if isinstance(self.strategy, ColossalAIStrategy):
            self.actor_optim = HybridAdam(self.actor.parameters(), lr=1e-7)
            self.critic_optim = HybridAdam(self.critic.parameters(), lr=1e-7)
        else:
            self.actor_optim = Adam(self.actor.parameters(), lr=1e-7)
            self.critic_optim = Adam(self.critic.parameters(), lr=1e-7)

        (self.actor, self.actor_optim), (self.critic, self.critic_optim) = \
            self.strategy.prepare((self.actor, self.actor_optim), (self.critic, self.critic_optim))

        # configure trainer
        self.actor_loss_fn = PolicyLoss(eps_clip)
        self.critic_loss_fn = ValueLoss(value_clip)

        super().__init__(experience_maker_holder_name_list,
                         train_batch_size=train_batch_size,
                         buffer_limit=buffer_limit,
                         dataloader_pin_memory=dataloader_pin_memory,
                         callbacks=callbacks,
                         debug=debug)
        # NOTE(review): `self._debug` is presumably set by DetachedTrainer.__init__ from `debug` -- confirm.
        if self._debug:
            print(f'[trainer{get_rank()}] will send state dict to {experience_maker_holder_name_list}')

        self._update_lora_weights = update_lora_weights

    @ray.method(concurrency_group="model_io")
    @torch.no_grad()
    def _update_remote_makers(self, fully_update: bool = False, **config):
        '''
        Send the current actor and critic state dicts, shard by shard, to every target holder.

        The transfer is bracketed by a ``chunk_start`` call (waited on) and a
        ``chunk_end`` call (not waited on) to each holder's ``update_experience_maker``.

        Args:
            fully_update (bool, defaults to False): forwarded to the holders and to the shard
                generator. When False, ``requires_grad_only=True`` is added to ``config``.
            **config: extra keyword arguments forwarded to ``strategy.get_model_state_dict_shard``
        '''
        # TODO: balance duties
        if not fully_update:
            config['requires_grad_only'] = True
        # NOTE(review): presumably refreshes `self.target_holder_list` (defined in the base class) -- confirm.
        self.update_target_holder_list()
        # mark start, ensure order
        tasks = [
            target_holder.update_experience_maker.remote(chunk_start=True, fully_update=fully_update)
            for target_holder in self.target_holder_list
        ]
        ray.get(tasks)

        tasks = []

        # sending loop (actor)
        # The LoRA config depends only on the model, not on the shard or the holder,
        # so compute it once instead of once per (shard, holder) pair.
        actor_lora_config = self._get_model_lora_config_dict(self.actor)
        for state_dict_shard in self._get_model_state_dict_shard(self.actor, fully_update=fully_update, **config):
            for target_holder in self.target_holder_list:
                tasks.append(
                    target_holder.update_experience_maker.remote(new_actor_state_dict=state_dict_shard,
                                                                 new_actor_lora_config_dict=actor_lora_config,
                                                                 fully_update=fully_update))

        # sending loop (critic)
        critic_lora_config = self._get_model_lora_config_dict(self.critic)
        for state_dict_shard in self._get_model_state_dict_shard(self.critic, fully_update=fully_update, **config):
            for target_holder in self.target_holder_list:
                tasks.append(
                    target_holder.update_experience_maker.remote(new_critic_state_dict=state_dict_shard,
                                                                 new_critic_lora_config_dict=critic_lora_config,
                                                                 fully_update=fully_update))
        # wait until every shard has been delivered before marking the end
        ray.get(tasks)

        # mark end (fire-and-forget: the returned object refs are intentionally not waited on)
        for target_holder in self.target_holder_list:
            target_holder.update_experience_maker.remote(chunk_end=True, fully_update=fully_update)

    @ray.method(concurrency_group="compute")
    def training_step(self, experience: Experience) -> Dict[str, float]:
        '''
        Run one PPO update of the actor and then the critic on a batch of experience.

        Args:
            experience (Experience): the batch to train on; its ``sequences``, ``attention_mask``,
                ``action_mask``, ``action_log_probs``, ``advantages``, ``values`` and ``reward`` are read

        Returns:
            Dict[str, float]: ``{'actor_loss': ..., 'critic_loss': ...}`` as Python floats
        '''
        self.actor.train()
        self.critic.train()

        # actor update
        num_actions = experience.action_mask.size(1)
        action_log_probs = self.actor(experience.sequences, num_actions, attention_mask=experience.attention_mask)
        actor_loss = self.actor_loss_fn(action_log_probs,
                                        experience.action_log_probs,
                                        experience.advantages,
                                        action_mask=experience.action_mask)
        self.strategy.backward(actor_loss, self.actor, self.actor_optim)
        self.strategy.optimizer_step(self.actor_optim)
        self.actor_optim.zero_grad()

        # critic update
        values = self.critic(experience.sequences,
                             action_mask=experience.action_mask,
                             attention_mask=experience.attention_mask)
        critic_loss = self.critic_loss_fn(values,
                                          experience.values,
                                          experience.reward,
                                          action_mask=experience.action_mask)

        self.strategy.backward(critic_loss, self.critic, self.critic_optim)
        self.strategy.optimizer_step(self.critic_optim)
        self.critic_optim.zero_grad()
        return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()}

    def strategy_save_actor(self, path: str, only_rank0: bool = False) -> None:
        '''Save the actor model to ``path`` through the strategy.'''
        self.strategy.save_model(self.actor, path, only_rank0)

    def strategy_save_critic(self, path: str, only_rank0: bool = False) -> None:
        '''Save the critic model to ``path`` through the strategy.'''
        self.strategy.save_model(self.critic, path, only_rank0)

    def strategy_save_actor_optim(self, path: str, only_rank0: bool = False) -> None:
        '''Save the actor optimizer to ``path`` through the strategy.'''
        self.strategy.save_optimizer(self.actor_optim, path, only_rank0)

    def strategy_save_critic_optim(self, path: str, only_rank0: bool = False) -> None:
        '''Save the critic optimizer to ``path`` through the strategy.'''
        self.strategy.save_optimizer(self.critic_optim, path, only_rank0)

    def _get_model_state_dict_shard(self, model: torch.nn.Module, fully_update: bool = False, **config):
        '''
        Yield the shards of ``model``'s state dict, each passed through ``state_dict_to``.

        When ``update_lora_weights`` is enabled and this is not a full update, each shard is
        first reduced to its LoRA entries via ``LoRAConstructor.filter_state_dict_lora``.

        Args:
            model (torch.nn.Module): the model whose state dict is sharded by the strategy
            fully_update (bool, defaults to False): if True, always yield the unfiltered shards
            **config: forwarded to ``strategy.get_model_state_dict_shard``
        '''
        # the decision does not depend on the shard, so make it once
        lora_only = self._update_lora_weights and not fully_update
        for state_dict in self.strategy.get_model_state_dict_shard(model, **config):
            if lora_only:
                state_dict_lora, _ = LoRAConstructor.filter_state_dict_lora(state_dict)
                yield state_dict_to(state_dict_lora)
            else:
                yield state_dict_to(state_dict)

    def _get_model_lora_config_dict(self, model: torch.nn.Module):
        '''
        Return the LoRA config extracted from the unwrapped ``model``,
        or ``None`` when ``update_lora_weights`` is disabled.
        '''
        if not self._update_lora_weights:
            return None
        unwrapped_model = self.strategy.unwrap_model(model)
        return LoRAConstructor.extract_lora_config(unwrapped_model)
|