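"""Performance evaluator callbacks for Coati's Ray-based PPO pipeline.

ExperienceMakerPerformanceEvaluator reports throughput and estimated TFLOPS
for experience generation; TrainerPerformanceEvaluator reports the same for
the training loop.
"""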
from time import time
from typing import Optional

import torch
import torch.distributed as dist
from coati.experience_maker import Experience

from .base import MakerCallback, TrainerCallback


def get_world_size() -> int:
    if dist.is_initialized():
        return dist.get_world_size()
    return 1


def print_rank_0(*args, **kwargs) -> None:
    if not dist.is_initialized() or dist.get_rank() == 0:
        print(*args, **kwargs)


@torch.no_grad()
def all_reduce_mean(x: float, world_size: int) -> float:
    if world_size == 1:
        return x
    # dist.all_reduce defaults to SUM, so divide by world size to get the mean.
    tensor = torch.tensor([x], device=torch.cuda.current_device())
    dist.all_reduce(tensor)
    tensor = tensor / world_size
    return tensor.item()


class Timer:
    """Accumulates wall-clock time over repeated start()/end() intervals."""

    def __init__(self) -> None:
        self.start_time: Optional[float] = None
        self.duration: float = 0.

    def start(self) -> None:
        self.start_time = time()

    def end(self) -> None:
        assert self.start_time is not None, 'Timer.end() called before Timer.start()'
        self.duration += time() - self.start_time

    def reset(self) -> None:
        self.duration = 0.


class ExperienceMakerPerformanceEvaluator(MakerCallback):
    """Reports throughput and estimated TFLOPS of the experience maker."""

    def __init__(self, actor_num_params: int, critic_num_params: int, initial_model_num_params: int,
                 reward_model_num_params: int) -> None:
        super().__init__()
        self.world_size = get_world_size()
        self.actor_num_params = actor_num_params
        self.critic_num_params = critic_num_params
        self.initial_model_num_params = initial_model_num_params
        self.reward_model_num_params = reward_model_num_params

        self.batch_timer = Timer()
        self.send_timer = Timer()
        self.make_experience_timer = Timer()
        self.total_samples: int = 0
        self.make_experience_flop: int = 0

        print_rank_0(f'ExperienceMaker actor: {actor_num_params/1024**3:.2f}B, '
                     f'critic: {critic_num_params/1024**3:.2f}B, '
                     f'initial model: {initial_model_num_params/1024**3:.2f}B, '
                     f'reward model: {reward_model_num_params/1024**3:.2f}B, '
                     f'world size: {self.world_size}')
    def on_make_experience_start(self) -> None:
        self.make_experience_timer.start()

    def on_make_experience_end(self, experience: Experience) -> None:
        self.make_experience_timer.end()

        batch_size, seq_len = experience.sequences.shape
        self.total_samples += batch_size
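        # FLOP estimate for generation: each forward token position costs about
        # 2 * num_params FLOPs (one multiply-accumulate per parameter). Assuming
        # each of the num_actions decoding steps re-processes the full current
        # context, the token positions processed form the arithmetic series
        # input_len + (input_len + 1) + ... + (seq_len - 1)
        #   = (input_len + seq_len - 1) * num_actions / 2.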
        # actor generate
        num_actions = experience.action_mask.size(1)
        input_len = seq_len - num_actions
        total_seq_len = (input_len + seq_len - 1) * num_actions / 2
        self.make_experience_flop += self.actor_num_params * batch_size * total_seq_len * 2
        # actor forward
        self.make_experience_flop += self.actor_num_params * batch_size * seq_len * 2
        # critic forward
        self.make_experience_flop += self.critic_num_params * batch_size * seq_len * 2
        # initial model forward
        self.make_experience_flop += self.initial_model_num_params * batch_size * seq_len * 2
        # reward model forward
        self.make_experience_flop += self.reward_model_num_params * batch_size * seq_len * 2
    def on_send_start(self) -> None:
        self.send_timer.start()

    def on_send_end(self) -> None:
        self.send_timer.end()

    def on_batch_start(self) -> None:
        self.batch_timer.start()

    def on_batch_end(self) -> None:
        self.batch_timer.end()
    def on_loop_end(self) -> None:
        if self.total_samples == 0:
            print_rank_0('No samples are collected, skip experience maker performance evaluation')
            return
        avg_make_experience_duration = all_reduce_mean(self.make_experience_timer.duration, self.world_size)
        avg_overall_duration = all_reduce_mean(self.batch_timer.duration, self.world_size)
        avg_send_duration = all_reduce_mean(self.send_timer.duration, self.world_size)

        avg_throughput = self.total_samples * self.world_size / (avg_overall_duration + 1e-12)
        avg_make_experience_tflops = self.make_experience_flop / 1e12 / (avg_make_experience_duration + 1e-12)
        avg_time_per_sample = (avg_overall_duration + 1e-12) / (self.total_samples * self.world_size)
        avg_make_experience_time_per_sample = (avg_make_experience_duration + 1e-12) / \
            (self.total_samples * self.world_size)
        avg_send_time_per_sample = (avg_send_duration + 1e-12) / (self.total_samples * self.world_size)

        print_rank_0(
            'Making Experience Performance Summary:\n' +
            f'Throughput: {avg_throughput:.3f} samples/sec\n' +
            f'TFLOPS per GPU: {avg_make_experience_tflops:.3f}\n' +
            f'Sample time (overall): {avg_time_per_sample:.3f} s\n' +
            f'Sample time (make experience): {avg_make_experience_time_per_sample:.3f} s, '
            f'{avg_make_experience_time_per_sample/avg_time_per_sample*100:.2f}%\n' +
            f'Sample time (send): {avg_send_time_per_sample:.3f} s, '
            f'{avg_send_time_per_sample/avg_time_per_sample*100:.2f}%\n')


class TrainerPerformanceEvaluator(TrainerCallback):
    """Reports throughput and estimated TFLOPS of the trainer.

    The first `ignore_first_episodes` episodes are excluded from the
    statistics to avoid skew from warmup.
    """

    def __init__(self,
                 actor_num_params: int,
                 critic_num_params: int,
                 enable_grad_checkpoint: bool = False,
                 ignore_first_episodes: int = 1) -> None:
        super().__init__()
        self.world_size = get_world_size()
        self.actor_num_params = actor_num_params
        self.critic_num_params = critic_num_params
        self.enable_grad_checkpoint = enable_grad_checkpoint
        self.ignore_first_episodes = ignore_first_episodes
        self.ignore_this_episode = False

        self.episode_timer = Timer()
        self.batch_timer = Timer()
        self.update_timer = Timer()
        self.total_samples: int = 0
        self.learn_flop: int = 0

        print_rank_0(f'Trainer actor: {self.actor_num_params/1024**3:.2f}B, '
                     f'critic: {self.critic_num_params/1024**3:.2f}B, '
                     f'world size: {self.world_size}')
    def on_episode_start(self, episodes: int) -> None:
        self.ignore_this_episode = episodes < self.ignore_first_episodes
        if self.ignore_this_episode:
            return
        self.episode_timer.start()

    def on_episode_end(self, episodes: int) -> None:
        if self.ignore_this_episode:
            return
        self.episode_timer.end()

    def on_batch_start(self) -> None:
        if self.ignore_this_episode:
            return
        self.batch_timer.start()

    def on_batch_end(self, metrics: dict, experience: Experience) -> None:
        if self.ignore_this_episode:
            return
        self.batch_timer.end()

        batch_size, seq_len = experience.sequences.shape
        self.total_samples += batch_size
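        # FLOP estimate for training: a forward pass costs about 2 * num_params
        # FLOPs per token and the backward pass roughly twice that, giving the
        # factor 3; gradient checkpointing recomputes activations, adding one
        # extra forward pass, hence (3 + int(self.enable_grad_checkpoint)).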
        # actor forward-backward, 3 means forward(1) + backward(2)
        self.learn_flop += self.actor_num_params * batch_size * seq_len * 2 * (3 + int(self.enable_grad_checkpoint))
        # critic forward-backward
        self.learn_flop += self.critic_num_params * batch_size * seq_len * 2 * (3 + int(self.enable_grad_checkpoint))
    def on_update_start(self) -> None:
        if self.ignore_this_episode:
            return
        self.update_timer.start()

    def on_update_end(self) -> None:
        if self.ignore_this_episode:
            return
        self.update_timer.end()
    def on_fit_end(self) -> None:
        if self.total_samples == 0:
            print_rank_0('No samples are collected, skip trainer performance evaluation')
            return
        avg_train_duration = all_reduce_mean(self.batch_timer.duration, self.world_size)
        avg_update_duration = all_reduce_mean(self.update_timer.duration, self.world_size)
        avg_episode_duration = all_reduce_mean(self.episode_timer.duration, self.world_size)

        avg_throughput = self.total_samples * self.world_size / (avg_episode_duration + 1e-12)
        avg_learn_tflops = self.learn_flop / 1e12 / (avg_train_duration + 1e-12)
        avg_time_per_sample = (avg_episode_duration + 1e-12) / (self.total_samples * self.world_size)
        avg_train_time_per_sample = (avg_train_duration + 1e-12) / (self.total_samples * self.world_size)
        avg_update_time_per_sample = (avg_update_duration + 1e-12) / (self.total_samples * self.world_size)

        print_rank_0(
            'Learning Performance Summary:\n' +
            f'Throughput: {avg_throughput:.3f} samples/sec\n' +
            f'TFLOPS per GPU: {avg_learn_tflops:.3f}\n' +
            f'Sample time (overall): {avg_time_per_sample:.3f} s\n' +
            f'Sample time (train): {avg_train_time_per_sample:.3f} s, '
            f'{avg_train_time_per_sample/avg_time_per_sample*100:.2f}%\n' +
            f'Sample time (update): {avg_update_time_per_sample:.3f} s, '
            f'{avg_update_time_per_sample/avg_time_per_sample*100:.2f}%\n')
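

# A minimal usage sketch (hypothetical wiring; the actual holder/trainer
# constructor names and arguments live elsewhere in the coati ray package
# and may differ):
#
#   maker_eval = ExperienceMakerPerformanceEvaluator(actor_n, critic_n, init_n, rm_n)
#   trainer_eval = TrainerPerformanceEvaluator(actor_n, critic_n, enable_grad_checkpoint=False)
#   maker = ExperienceMakerHolder(..., callbacks=[maker_eval])
#   trainer = DetachedPPOTrainer(..., callbacks=[trainer_eval])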