Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-20 09:01:06 +00:00)
[chat] refactor trainer (#3648)
* [chat] ppo trainer remove useless args
* [chat] update examples
* [chat] update benchmark
* [chat] update examples
* [chat] fix sft training with wandb
* [chat] polish docstr
@@ -15,7 +15,6 @@ class Trainer(ABC):
     Args:
         strategy (Strategy):the strategy to use for training
         max_epochs (int, defaults to 1): the number of epochs of training process
-        tokenizer (Callable, optional): the tokenizer to use for tokenizing the input
         dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
         callbacks (List[Callback], defaults to []): the callbacks to call during training process
         generate_kwargs (dict, optional): the kwargs to use while model generating
@@ -24,14 +23,12 @@ class Trainer(ABC):
     def __init__(self,
                  strategy: Strategy,
                  max_epochs: int = 1,
-                 tokenizer: Optional[Callable[[Any], dict]] = None,
                  dataloader_pin_memory: bool = True,
                  callbacks: List[Callback] = [],
                  **generate_kwargs) -> None:
         super().__init__()
         self.strategy = strategy
         self.max_epochs = max_epochs
-        self.tokenizer = tokenizer
         self.generate_kwargs = generate_kwargs
         self.dataloader_pin_memory = dataloader_pin_memory
         self.callbacks = callbacks
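With `tokenizer` removed from the base class, subclasses forward only the strategy, epoch count, dataloader and callback options. A minimal sketch of the slimmed-down constructor contract, using a hypothetical subclass (`MyTrainer` below is not part of the repository):

from typing import List

from coati.trainer.base import Trainer
from coati.trainer.callbacks import Callback
from coati.trainer.strategies import Strategy


class MyTrainer(Trainer):
    """Hypothetical subclass illustrating the refactored base signature."""

    def __init__(self,
                 strategy: Strategy,
                 max_epochs: int = 1,
                 dataloader_pin_memory: bool = True,
                 callbacks: List[Callback] = [],
                 **generate_kwargs) -> None:
        # Tokenization now stays with the dataset/caller; nothing is forwarded here.
        super().__init__(strategy, max_epochs, dataloader_pin_memory, callbacks, **generate_kwargs)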
@@ -4,7 +4,7 @@ import torch
 import torch.nn as nn
 from coati.experience_maker import Experience, NaiveExperienceMaker
 from coati.models.base import Actor, Critic
-from coati.models.loss import PolicyLoss, ValueLoss
+from coati.models.loss import GPTLMLoss, PolicyLoss, ValueLoss
 from coati.replay_buffer import NaiveReplayBuffer
 from torch import Tensor
 from torch.optim import Optimizer
@@ -12,10 +12,12 @@ from torch.utils.data import DistributedSampler
 from tqdm import tqdm
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase

+from colossalai.utils import get_current_device
+
 from .base import Trainer
 from .callbacks import Callback
 from .strategies import Strategy
-from .utils import is_rank_0
+from .utils import is_rank_0, to_device


 class PPOTrainer(Trainer):
@@ -38,11 +40,10 @@ class PPOTrainer(Trainer):
         vf_coef (float, defaults to 1.0): the coefficient of value loss
         ptx_coef (float, defaults to 0.9): the coefficient of ptx loss
         value_clip (float, defaults to 0.4): the clip coefficient of value loss
         experience_batch_size (int, defaults to 8): the batch size to use for experience generation
         max_epochs (int, defaults to 1): the number of epochs of training process
-        tokenizer (Callable, optional): the tokenizer to use for tokenizing the input
         sample_replay_buffer (bool, defaults to False): whether to sample from replay buffer
         dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
+        offload_inference_models (bool, defaults to True): whether to offload inference models to cpu during training process
         callbacks (List[Callback], defaults to []): the callbacks to call during training process
         generate_kwargs (dict, optional): the kwargs to use while model generating
     """
@@ -63,22 +64,21 @@ class PPOTrainer(Trainer):
                 eps_clip: float = 0.2,
                 vf_coef: float = 1.0,
                 value_clip: float = 0.4,
                 experience_batch_size: int = 8,
                 max_epochs: int = 1,
-                tokenizer: Optional[Callable[[Any], dict]] = None,
                 sample_replay_buffer: bool = False,
                 dataloader_pin_memory: bool = True,
+                offload_inference_models: bool = True,
                 callbacks: List[Callback] = [],
                 **generate_kwargs) -> None:
         experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef)
         replay_buffer = NaiveReplayBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
         generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)
-        super().__init__(strategy, max_epochs, tokenizer, dataloader_pin_memory, callbacks, **generate_kwargs)
+        super().__init__(strategy, max_epochs, dataloader_pin_memory, callbacks, **generate_kwargs)

         self.experience_maker = experience_maker
         self.replay_buffer = replay_buffer
         self.experience_batch_size = experience_batch_size
         self.sample_replay_buffer = sample_replay_buffer
+        self.offload_inference_models = offload_inference_models

         self.actor = actor
         self.critic = critic
@@ -86,11 +86,13 @@ class PPOTrainer(Trainer):
         self.actor_loss_fn = PolicyLoss(eps_clip)
         self.critic_loss_fn = ValueLoss(value_clip)
         self.vf_coef = vf_coef
-        self.ptx_loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
+        self.ptx_loss_fn = GPTLMLoss()
         self.ptx_coef = ptx_coef
         self.actor_optim = actor_optim
         self.critic_optim = critic_optim

+        self.device = get_current_device()
+
     def _make_experience(self, inputs: Union[Tensor, Dict[str, Tensor]]) -> Experience:
         if isinstance(inputs, Tensor):
             return self.experience_maker.make_experience(inputs, **self.generate_kwargs)
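Caching the device once via `colossalai.utils.get_current_device()` replaces the scattered `torch.cuda.current_device()` calls further down. A rough sketch of the assumed behaviour (the real helper lives in ColossalAI and may differ): prefer the CUDA device bound to this process, otherwise fall back to CPU.

import torch


def get_current_device_sketch() -> torch.device:
    # Assumed behaviour of colossalai.utils.get_current_device():
    # the active CUDA device when CUDA is available, otherwise the CPU.
    if torch.cuda.is_available():
        return torch.device('cuda', torch.cuda.current_device())
    return torch.device('cpu')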
@@ -99,20 +101,15 @@ class PPOTrainer(Trainer):
         else:
             raise ValueError(f'Unsupported input type "{type(inputs)}"')

-    def _sample_prompts(self, prompts) -> list:
-        indices = list(range(len(prompts)))
-        sampled_indices = self.strategy.experience_sampler.choice(indices, self.experience_batch_size, replace=False)
-        return [prompts[i] for i in sampled_indices]
-
     def _learn(self):
         # replay buffer may be empty at first, we should rebuild at each training
         if not self.sample_replay_buffer:
             dataloader = self.strategy.setup_dataloader(self.replay_buffer, self.dataloader_pin_memory)
-        device = torch.cuda.current_device()
         if self.sample_replay_buffer:
             pbar = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0())
             for _ in pbar:
                 experience = self.replay_buffer.sample()
+                experience.to_device(self.device)
                 metrics = self.training_step(experience)
                 pbar.set_postfix(metrics)
         else:
@@ -123,7 +120,7 @@ class PPOTrainer(Trainer):
                 pbar = tqdm(dataloader, desc=f'Train epoch [{epoch+1}/{self.max_epochs}]', disable=not is_rank_0())
                 for experience in pbar:
                     self._on_learn_batch_start()
-                    experience.to_device(device)
+                    experience.to_device(self.device)
                     metrics = self.training_step(experience)
                     self._on_learn_batch_end(metrics, experience)
                     pbar.set_postfix(metrics)
@@ -147,14 +144,17 @@ class PPOTrainer(Trainer):
                 time += 1
                 prompts = next(iter(self.prompt_dataloader))
                 self._on_make_experience_start()
-                self.experience_maker.initial_model.to(torch.cuda.current_device())
-                self.experience_maker.reward_model.to(torch.cuda.current_device())
+                if self.offload_inference_models:
+                    # TODO(ver217): this may be controlled by strategy if they are prepared by strategy
+                    self.experience_maker.initial_model.to(self.device)
+                    self.experience_maker.reward_model.to(self.device)
                 experience = self._make_experience(prompts)
                 self._on_make_experience_end(experience)
                 self.replay_buffer.append(experience)
                 if time % update_timesteps == 0:
-                    self.experience_maker.initial_model.to('cpu')
-                    self.experience_maker.reward_model.to('cpu')
+                    if self.offload_inference_models:
+                        self.experience_maker.initial_model.to('cpu')
+                        self.experience_maker.reward_model.to('cpu')
                     self._learn()
                     self.replay_buffer.clear()
             self._on_episode_end(episode)
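The `fit` loop now moves the frozen initial and reward models to the training device only for experience generation and back to CPU before each learning phase, and skips the shuffling entirely when `offload_inference_models` is disabled. A standalone sketch of that pattern (the function and variable names below are illustrative, not part of the trainer):

import torch
import torch.nn as nn


def generate_then_learn(frozen_model: nn.Module, offload: bool, device: torch.device) -> None:
    """Illustrative offload pattern: GPU only while generating, CPU while learning."""
    if offload:
        frozen_model.to(device)    # bring the frozen model up for rollout/generation
    # ... make experience with frozen_model ...
    if offload:
        frozen_model.to('cpu')     # release GPU memory before the PPO update
    # ... run the learning phase ...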
@@ -174,11 +174,10 @@ class PPOTrainer(Trainer):
         # ptx loss
         if self.ptx_coef != 0:
             batch = next(iter(self.pretrain_dataloader))
-            ptx = batch['input_ids'].to(torch.cuda.current_device())
-            label = batch['labels'].to(torch.cuda.current_device())[:, 1:]
-            attention_mask = batch['attention_mask'].to(torch.cuda.current_device())
-            ptx_log_probs = self.actor.get_base_model()(ptx, attention_mask=attention_mask)['logits'][..., :-1, :]
-            ptx_loss = self.ptx_loss_fn(ptx_log_probs.view(-1, ptx_log_probs.size(-1)), label.view(-1))
+            batch = to_device(batch, self.device)
+            ptx_log_probs = self.actor.get_base_model()(batch['input_ids'],
+                                                        attention_mask=batch['attention_mask'])['logits']
+            ptx_loss = self.ptx_loss_fn(ptx_log_probs, batch['labels'])
             actor_loss = ptx_loss * self.ptx_coef + actor_loss * (1 - self.ptx_coef)

         self.strategy.backward(actor_loss, self.actor, self.actor_optim)
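The manual next-token shift and `view` reshaping disappear because `GPTLMLoss` is expected to do that work internally. A minimal sketch of such a loss, mirroring the removed lines (`ignore_index=-100` is carried over from the replaced `nn.CrossEntropyLoss` setup; the actual `coati.models.loss.GPTLMLoss` may differ in detail):

import torch
import torch.nn as nn


class GPTLMLossSketch(nn.Module):
    """Causal-LM loss sketch: shift logits/labels by one position, then cross entropy."""

    def __init__(self) -> None:
        super().__init__()
        self.loss = nn.CrossEntropyLoss(ignore_index=-100)

    def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        shift_logits = logits[..., :-1, :].contiguous()  # predictions for positions 0..N-2
        shift_labels = labels[..., 1:].contiguous()      # targets are the next tokens
        return self.loss(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))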
@@ -1,6 +1,6 @@
 import math
 import time
-from typing import Optional, List
+from typing import List, Optional

 import loralib as lora
 import torch
@@ -18,8 +18,8 @@ from transformers.trainer import get_scheduler

 from colossalai.logging import get_dist_logger

-from .callbacks import Callback
 from .base import Trainer
+from .callbacks import Callback
 from .strategies import Strategy
 from .utils import is_rank_0

@@ -70,9 +70,10 @@ class SFTTrainer(Trainer):
                                       num_warmup_steps=math.ceil(max_steps * 0.03),
                                       num_training_steps=max_steps)

-    def fit(self, logger, log_interval=10):
-        wandb.init(project="Coati", name=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
-        wandb.watch(self.model)
+    def fit(self, logger, use_wandb: bool = False):
+        if use_wandb:
+            wandb.init(project="Coati", name=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
+            wandb.watch(self.model)
         total_loss = 0
         # epoch_bar = tqdm(range(self.epochs), desc='Epochs', disable=not is_rank_0())
         step_bar = tqdm(range(len(self.train_dataloader) // self.accimulation_steps * self.max_epochs),
@@ -111,7 +112,7 @@ class SFTTrainer(Trainer):
                 self.strategy.optimizer_step(self.optimizer)
                 self.optimizer.zero_grad()
                 self.scheduler.step()
-                if is_rank_0():
+                if is_rank_0() and use_wandb:
                     wandb.log({
                         "loss": total_loss / self.accimulation_steps,
                         "lr": self.scheduler.get_last_lr()[0],
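Weights & Biases logging is now opt-in: `wandb.init`/`wandb.watch` run only when the flag is set, and only the rank-0 process calls `wandb.log`. A hedged usage sketch (trainer construction is elided; `logger` is whatever `get_dist_logger()` returns, as imported above):

from colossalai.logging import get_dist_logger


def run_sft(trainer) -> None:
    # `trainer` is an already constructed SFTTrainer; construction is elided here.
    logger = get_dist_logger()
    # Without the flag no wandb run is created; with it, only rank 0 logs metrics.
    trainer.fit(logger, use_wandb=True)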
@@ -1,14 +1,19 @@
-import torch.distributed as dist
-from typing import Any, Callable, Dict, List, Optional
-from coati.models.bloom import BLOOMActor, BLOOMCritic
-from coati.models.gpt import GPTActor, GPTCritic
-from coati.models.opt import OPTActor, OPTCritic
-from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
-import os
+from typing import Any
+
+import torch
+import torch.distributed as dist
+from torch.utils._pytree import tree_map


 def is_rank_0() -> bool:
     return not dist.is_initialized() or dist.get_rank() == 0
+
+
+def to_device(x: Any, device: torch.device) -> Any:
+
+    def _to(t: Any):
+        if isinstance(t, torch.Tensor):
+            return t.to(device)
+        return t
+
+    return tree_map(_to, x)
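The new `to_device` helper relies on `torch.utils._pytree.tree_map` to move every tensor inside an arbitrarily nested container while passing non-tensor leaves through untouched, which is what lets the ptx branch above ship a whole tokenized batch with a single call. A small usage sketch (the batch contents are made up; the import path is assumed from the `from .utils import is_rank_0, to_device` line above):

import torch

from coati.trainer.utils import to_device

batch = {
    'input_ids': torch.randint(0, 100, (2, 8)),
    'attention_mask': torch.ones(2, 8, dtype=torch.long),
    'meta': {'lengths': [8, 8]},      # non-tensor leaves are returned unchanged
}

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch = to_device(batch, device)      # every tensor leaf now lives on `device`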