[chat]: update rm, add wandb and fix bugs (#4471)

* feat: modify forward fn of critic and reward model

* feat: modify calc_action_log_probs

* to: add wandb in sft and rm trainer

* feat: update train_sft

* feat: update train_rm

* style: modify type annotation and add warning

* feat: pass tokenizer to ppo trainer

* to: modify trainer base and maker base

* feat: add wandb in ppo trainer

* feat: pass tokenizer to generate

* test: update generate fn tests

* test: update train tests

* fix: remove action_mask

* feat: remove unused code

* fix: fix wrong ignore_index

* fix: fix mock tokenizer

* chore: update requirements

* revert: modify make_experience

* fix: fix inference

* fix: add padding side

* style: modify _on_learn_batch_end

* test: use mock tokenizer

* fix: use bf16 to avoid overflow

* fix: fix workflow

* [chat] fix gemini strategy

* [chat] fix

* sync: update colossalai strategy

* fix: fix args and model dtype

* fix: fix checkpoint test

* fix: fix requirements

* fix: fix missing import and wrong arg

* fix: temporarily skip gemini test in stage 3

* style: apply pre-commit

* fix: temporarily skip gemini test in stage 1&2

---------

Co-authored-by: Mingyan Jiang <1829166702@qq.com>
Author: Wenhao Chen
Date: 2023-09-20 15:53:58 +08:00
Committed by: GitHub
Parent: 07c2e3d09c
Commit: 7b9b86441f
36 changed files with 382 additions and 332 deletions

View File

@@ -13,7 +13,7 @@
# limitations under the License.
import copy
from typing import Dict, Sequence, Tuple
from typing import Dict, Optional, Sequence, Tuple
import torch
from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
@@ -57,6 +57,7 @@ def _preprocess(
sources, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
)
assert sequences_token["attention_mask"].dim() == 2, "seq2seq model should be preprocessed differently"
labels = copy.deepcopy(sequences_token["input_ids"])
for i in range(labels.shape[0]):
source_len = sources_token["attention_mask"][i].sum().item()
@@ -64,9 +65,10 @@ def _preprocess(
if tokenizer.padding_side == "right":
# |prompt|completion|eos|pad|
labels[i][:source_len] = IGNORE_INDEX
labels[i][-pad_len:] = IGNORE_INDEX
elif tokenizer.padding_side == "left":
# |pad|prompt|completion|eos|
labels[i][pad_len : pad_len + source_len] = IGNORE_INDEX
labels[i][: pad_len + source_len] = IGNORE_INDEX
else:
raise RuntimeError()
@@ -126,6 +128,8 @@ class SFTDataset(Dataset):
sources = [data["prompt"] for data in dataset]
targets = [data["completion"] + tokenizer.eos_token for data in tqdm(dataset, disable=not is_rank_0())]
logger.info("Tokenizing inputs... This may take some time...")
if isinstance(tokenizer, ChatGLMTokenizer):
self.input_ids, self.labels, self.attention_mask = _preprocess_chatglm(
sources, targets, tokenizer, max_length
@@ -133,6 +137,8 @@ class SFTDataset(Dataset):
else:
self.input_ids, self.labels, self.attention_mask = _preprocess(sources, targets, tokenizer, max_length)
logger.info("Loaded dataset.")
def __len__(self):
length = self.input_ids.shape[0]
return length
@@ -148,7 +154,11 @@ class SupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning."""
def __init__(
self, data_path: str, tokenizer: PreTrainedTokenizer, max_datasets_size: int = None, max_length: int = 512
self,
data_path: str,
tokenizer: PreTrainedTokenizer,
max_datasets_size: Optional[int] = None,
max_length: int = 512,
):
super().__init__()
logger.info("Loading data...")
@@ -175,6 +185,8 @@ class SupervisedDataset(Dataset):
else:
self.input_ids, self.labels, self.attention_mask = _preprocess(sources, targets, tokenizer, max_length)
logger.info("Loaded dataset.")
def __len__(self):
length = self.input_ids.shape[0]
return length
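
For reference, a minimal standalone sketch of the new label-masking rule for one example (mask_labels, source_len, and pad_len are illustrative names, not from this patch; IGNORE_INDEX is -100, matching sft_dataset.py):

import torch

IGNORE_INDEX = -100  # same value that nn.CrossEntropyLoss ignores by default

def mask_labels(input_ids: torch.Tensor, source_len: int, pad_len: int, padding_side: str) -> torch.Tensor:
    # Only completion tokens keep their ids; prompt and padding are excluded from the loss.
    labels = input_ids.clone()
    if padding_side == "right":
        # |prompt|completion|eos|pad|
        labels[:source_len] = IGNORE_INDEX
        if pad_len > 0:  # guard: labels[-0:] would select the whole tensor
            labels[-pad_len:] = IGNORE_INDEX
    elif padding_side == "left":
        # |pad|prompt|completion|eos|
        labels[: pad_len + source_len] = IGNORE_INDEX
    else:
        raise RuntimeError(f"unsupported padding side: {padding_side}")
    return labels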

View File

@@ -1,4 +1,5 @@
import random
import warnings
from typing import List
import torch
@@ -30,9 +31,11 @@ class NaiveExperienceBuffer(ExperienceBuffer):
experience.to_device(torch.device("cpu"))
items = split_experience_batch(experience)
self.items.extend(items)
if self.limit > 0:
samples_to_remove = len(self.items) - self.limit
if samples_to_remove > 0:
warnings.warn(f"Experience buffer is full. Removing {samples_to_remove} samples.")
self.items = self.items[samples_to_remove:]
def clear(self) -> None:

View File

@@ -3,8 +3,7 @@ from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
from coati.models.base import Actor
from coati.models.base import Actor, Critic, RewardModel
@dataclass
@@ -59,16 +58,13 @@ class Experience:
class ExperienceMaker(ABC):
def __init__(
self, actor: Actor, critic: nn.Module, reward_model: nn.Module, initial_model: Actor, kl_coef: float = 0.1
) -> None:
def __init__(self, actor: Actor, critic: Critic, reward_model: RewardModel, initial_model: Actor) -> None:
super().__init__()
self.actor = actor
self.critic = critic
self.reward_model = reward_model
self.initial_model = initial_model
self.kl_coef = kl_coef
@abstractmethod
def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experience:
def make_experience(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **generate_kwargs) -> Experience:
pass

View File

@@ -1,7 +1,9 @@
import torch
import torch.nn.functional as F
from coati.models.base import Actor, Critic, RewardModel
from coati.models.generation import generate
from coati.models.utils import calc_action_log_probs, compute_reward
from transformers import PreTrainedTokenizer
from .base import Experience, ExperienceMaker
@@ -11,6 +13,19 @@ class NaiveExperienceMaker(ExperienceMaker):
Naive experience maker.
"""
def __init__(
self,
actor: Actor,
critic: Critic,
reward_model: RewardModel,
initial_model: Actor,
tokenizer: PreTrainedTokenizer,
kl_coef: float = 0.1,
) -> None:
super().__init__(actor, critic, reward_model, initial_model)
self.tokenizer = tokenizer
self.kl_coef = kl_coef
@torch.no_grad()
def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experience:
self.actor.eval()
@@ -19,16 +34,16 @@ class NaiveExperienceMaker(ExperienceMaker):
self.reward_model.eval()
# generate sequences
sequences = generate(self.actor, input_ids, **generate_kwargs)
sequences = generate(self.actor, input_ids, self.tokenizer, **generate_kwargs)
# calculate auxiliary tensors
attention_mask = None
pad_token_id = generate_kwargs.get("pad_token_id", None)
pad_token_id = self.tokenizer.pad_token_id
if pad_token_id is not None:
attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
input_len = input_ids.size(1)
eos_token_id = generate_kwargs.get("eos_token_id", None)
eos_token_id = self.tokenizer.eos_token_id
if eos_token_id is None:
action_mask = torch.ones_like(sequences, dtype=torch.bool)
else:
@@ -40,11 +55,11 @@ class NaiveExperienceMaker(ExperienceMaker):
action_mask = action_mask[:, -(sequences.size(1) - input_len) :]
num_actions = action_mask.size(1)
actor_output = self.actor(sequences, attention_mask)
actor_output = self.actor(sequences, attention_mask)["logits"]
action_log_probs = calc_action_log_probs(actor_output, sequences, num_actions)
base_model_output = self.initial_model(sequences, attention_mask)
base_model_output = self.initial_model(sequences, attention_mask)["logits"]
base_action_log_probs = calc_action_log_probs(base_model_output, sequences, num_actions)
value = self.critic(sequences, action_mask, attention_mask)
value = self.critic(sequences, attention_mask)
r = self.reward_model(sequences, attention_mask)
reward = compute_reward(r, self.kl_coef, action_log_probs, base_action_log_probs, action_mask=action_mask)
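
compute_reward then folds a KL penalty against the frozen initial model into the scalar reward. A rough sketch of that idea (kl_penalized_reward is an illustrative name; the exact formulation lives in coati.models.utils and may differ in detail):

import torch

def kl_penalized_reward(r, action_log_probs, base_action_log_probs, kl_coef=0.1, action_mask=None):
    # Per-sequence divergence between the actor and the initial model, averaged over generated tokens.
    log_ratio = action_log_probs - base_action_log_probs
    if action_mask is not None:
        kl = (log_ratio * action_mask).sum(dim=-1) / action_mask.sum(dim=-1)
    else:
        kl = log_ratio.mean(dim=-1)
    return r - kl_coef * kl  # shape (B,), matching the reward model output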

View File

@@ -25,7 +25,7 @@ class Actor(LoRAModule):
self,
input_ids: torch.LongTensor,
attention_mask: Optional[torch.Tensor] = None,
**model_kwargs, # HACK: `generate` method may pass more kwargs
**model_kwargs,
) -> torch.Tensor:
"""Returns model output."""
output = self.model(input_ids, attention_mask=attention_mask, **model_kwargs)

View File

@@ -1,10 +1,7 @@
from typing import Optional
import torch
import torch.nn as nn
from ..lora import LoRAModule
from ..utils import masked_mean
class Critic(LoRAModule):
@@ -19,37 +16,19 @@ class Critic(LoRAModule):
"""
def __init__(
self,
model: nn.Module,
value_head: nn.Module,
lora_rank: int = 0,
lora_train_bias: str = "none",
use_action_mask: bool = False,
self, model: nn.Module, value_head: nn.Module, lora_rank: int = 0, lora_train_bias: str = "none"
) -> None:
super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
self.model = model
self.value_head = value_head
self.use_action_mask = use_action_mask
self.convert_to_lora()
def forward(
self,
sequences: torch.LongTensor,
action_mask: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
def forward(self, sequences: torch.LongTensor, attention_mask: torch.Tensor) -> torch.Tensor:
outputs = self.model(sequences, attention_mask=attention_mask)
last_hidden_states = outputs["last_hidden_state"]
values = self.value_head(last_hidden_states).squeeze(-1)
if action_mask is not None and self.use_action_mask:
num_actions = action_mask.size(1)
prompt_mask = attention_mask[:, :-num_actions]
values = values[:, :-num_actions]
value = masked_mean(values, prompt_mask, dim=1)
return value
values = values[:, :-1]
value = values.mean(dim=1)
return value
sequence_lengths = torch.max(attention_mask * torch.arange(sequences.size(1), device=sequences.device), dim=1)[
0
]
sequence_hidden_states = last_hidden_states[torch.arange(last_hidden_states.size(0)), sequence_lengths]
values = self.value_head(sequence_hidden_states).squeeze(1) # ensure shape is (B, )
return values
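
The critic (and, below, the reward model) now scores the hidden state of the last non-padding token instead of averaging over all positions. The indexing trick on toy tensors (none of these names come from the patch):

import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])                   # (B, S), right padding
positions = torch.arange(attention_mask.size(1))
last_token_idx = torch.max(attention_mask * positions, dim=1)[0]   # tensor([2, 4])

hidden = torch.randn(2, 5, 8)                                      # (B, S, H) stand-in for last_hidden_state
pooled = hidden[torch.arange(hidden.size(0)), last_token_idx]      # (B, H)
value = torch.nn.Linear(8, 1)(pooled).squeeze(1)                   # (B,), one scalar per sequence

Because padded positions multiply to zero, the max picks the largest index where the mask is 1, which is the last real token under either padding side.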

View File

@@ -35,9 +35,12 @@ class RewardModel(LoRAModule):
else:
self.value_head = nn.Linear(model.config.n_embd, 1)
def forward(self, sequences: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
def forward(self, sequences: torch.LongTensor, attention_mask: torch.Tensor) -> torch.Tensor:
outputs = self.model(sequences, attention_mask=attention_mask)
last_hidden_states = outputs["last_hidden_state"]
values = self.value_head(last_hidden_states)[:, :-1]
value = values.mean(dim=1).squeeze(1) # ensure shape is (B)
return value
sequence_lengths = torch.max(attention_mask * torch.arange(sequences.size(1), device=sequences.device), dim=1)[
0
]
sequence_hidden_states = last_hidden_states[torch.arange(last_hidden_states.size(0)), sequence_lengths]
values = self.value_head(sequence_hidden_states).squeeze(1) # ensure shape is (B, )
return values

View File

@@ -2,6 +2,7 @@ from typing import Any, Callable, Optional
import torch
import torch.distributed as dist
from transformers import PreTrainedTokenizer
from .base import Actor
@@ -63,8 +64,8 @@ def _sample(
)
outputs = model(**model_inputs)
# NOTE: this is correct only in left padding mode
next_token_logits = outputs["logits"][:, -1, :]
# pre-process distribution
next_token_logits = logits_processor(input_ids, next_token_logits)
# sample
probs = torch.softmax(next_token_logits, dim=-1, dtype=torch.float)
@@ -72,8 +73,7 @@ def _sample(
# finished sentences should have their next token be a padding token
if eos_token_id is not None:
if pad_token_id is None:
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
assert pad_token_id is not None, "If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
# update generated ids, model inputs for next step
@@ -96,12 +96,11 @@ def _sample(
def generate(
model: Actor,
input_ids: torch.Tensor,
tokenizer: PreTrainedTokenizer,
max_length: int,
num_beams: int = 1,
do_sample: bool = True,
early_stopping: bool = False,
eos_token_id: Optional[int] = None,
pad_token_id: Optional[int] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
temperature: Optional[float] = None,
@@ -118,14 +117,13 @@ def generate(
num_beams (int, optional): number of beams. Defaults to 1.
do_sample (bool, optional): whether to do sample. Defaults to True.
early_stopping (bool, optional): if True, the sequence length may be smaller than max_length due to finding eos. Defaults to False.
eos_token_id (Optional[int], optional): end of sequence token id. Defaults to None.
pad_token_id (Optional[int], optional): pad token id. Defaults to None.
top_k (Optional[int], optional): the number of highest probability vocabulary tokens to keep for top-k-filtering. Defaults to None.
top_p (Optional[float], optional): If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. Defaults to None.
temperature (Optional[float], optional): The value used to module the next token probabilities. Defaults to None.
prepare_inputs_fn (Optional[Callable[[torch.Tensor, Any], dict]], optional): Function to preprocess model inputs. Arguments of this function should be input_ids and model_kwargs. Defaults to None.
update_model_kwargs_fn (Optional[Callable[[dict, Any], dict]], optional): Function to update model_kwargs based on outputs. Arguments of this function should be outputs and model_kwargs. Defaults to None.
"""
assert tokenizer.padding_side == "left", "Current generation only supports left padding."
is_greedy_gen_mode = (num_beams == 1) and do_sample is False
is_sample_gen_mode = (num_beams == 1) and do_sample is True
is_beam_gen_mode = (num_beams > 1) and do_sample is False
@@ -139,8 +137,8 @@ def generate(
input_ids,
max_length,
early_stopping=early_stopping,
eos_token_id=eos_token_id,
pad_token_id=pad_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
top_k=top_k,
top_p=top_p,
temperature=temperature,
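
Since generate now reads pad/eos ids from the tokenizer and asserts left padding, callers prepare prompts roughly as follows (a sketch; the actor and the final generate call come from coati and are left commented out):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token   # GPT-2 has no pad token by default
tokenizer.padding_side = "left"             # required by the new assertion

batch = tokenizer(["Hello", "A much longer prompt"], return_tensors="pt", padding=True)
# sequences = generate(actor, batch["input_ids"], tokenizer, max_length=64, do_sample=True, top_k=50)

Left padding keeps every prompt's last real token at the final position, which is why _sample can take outputs["logits"][:, -1, :].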

View File

@@ -13,6 +13,7 @@ class GPTLMLoss(nn.Module):
def __init__(self):
super().__init__()
# NOTE: default ignore_index is -100, which is equal to IGNORE_INDEX in sft_dataset.py
self.loss = nn.CrossEntropyLoss()
def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
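
The new comment hinges on nn.CrossEntropyLoss defaulting to ignore_index=-100, so SFT labels masked with IGNORE_INDEX contribute nothing to the loss; a quick check:

import torch
import torch.nn as nn

loss_fn = nn.CrossEntropyLoss()              # ignore_index defaults to -100
logits = torch.randn(4, 10)                  # (N, vocab_size)
labels = torch.tensor([1, -100, 3, -100])    # masked positions are skipped entirely
print(loss_fn(logits, labels))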

View File

@@ -46,18 +46,17 @@ def _log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.
return log_probs_labels.squeeze(-1)
def calc_action_log_probs(output: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor:
def calc_action_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor:
"""Calculate action log probs.
Args:
output (torch.Tensor): Output tensor of Actor.forward.
output (torch.Tensor): Output tensor of Actor.forward.logits.
sequences (torch.LongTensor): Input sequences.
num_actions (int): Number of actions.
Returns:
torch.Tensor: Action log probs.
"""
logits = output["logits"]
log_probs = _log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
return log_probs[:, -num_actions:]
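
calc_action_log_probs now expects raw logits (callers index ["logits"] themselves). Its computation, sketched end to end with toy shapes (log_probs_from_logits mirrors the private helper):

import torch
import torch.nn.functional as F

def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # log-probability of each label token under the model's next-token distribution
    log_probs = F.log_softmax(logits, dim=-1)
    return log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)

B, S, V, num_actions = 2, 6, 11, 3
logits = torch.randn(B, S, V)                 # model output for the full sequence
sequences = torch.randint(0, V, (B, S))

# logits[:, t] predicts sequences[:, t + 1], hence the one-token shift
log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
action_log_probs = log_probs[:, -num_actions:]   # (B, num_actions): generated tokens only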

View File

@@ -41,13 +41,13 @@ def get_actor_from_args(model: str, pretrained: str = None, config=None, lora_ra
def get_critic_from_args(model: str, pretrained: str = None, config=None, lora_rank=0):
if model == "gpt2":
critic = GPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
critic = GPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config)
elif model == "bloom":
critic = BLOOMCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
critic = BLOOMCritic(pretrained=pretrained, lora_rank=lora_rank, config=config)
elif model == "opt":
critic = OPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
critic = OPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config)
elif model == "llama":
critic = LlamaCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
critic = LlamaCritic(pretrained=pretrained, lora_rank=lora_rank, config=config)
else:
raise ValueError(f'Unsupported reward model "{model}"')
return critic

View File

@@ -7,11 +7,10 @@ import tqdm
from coati.experience_buffer import NaiveExperienceBuffer
from coati.experience_maker import Experience
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from .callbacks import Callback
from .strategies import Strategy
from .utils import CycledDataLoader, is_rank_0
from .utils import is_rank_0
class SLTrainer(ABC):
@@ -47,11 +46,11 @@ class SLTrainer(ABC):
raise NotImplementedError()
def _before_fit(self):
self.no_epoch_bar = False
raise NotImplementedError()
def fit(self, *args, **kwargs):
self._before_fit(*args, **kwargs)
for epoch in tqdm.trange(self.max_epochs, desc="Epochs", disable=not is_rank_0() or self.no_epoch_bar):
for epoch in tqdm.trange(self.max_epochs, desc="Epochs", disable=not is_rank_0()):
self._train(epoch)
self._eval(epoch)
@@ -123,9 +122,9 @@ class OnPolicyTrainer(ABC):
for callback in self.callbacks:
callback.on_learn_batch_start()
def _on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
def _on_learn_batch_end(self, experience: Experience) -> None:
for callback in self.callbacks:
callback.on_learn_batch_end(metrics, experience)
callback.on_learn_batch_end(experience)
@abstractmethod
def _make_experience(self, collect_step: int):
@@ -153,27 +152,26 @@ class OnPolicyTrainer(ABC):
self._learn(update_step)
self._on_learn_epoch_end(update_step)
def _before_fit(self, *args, **kwargs):
raise NotImplementedError()
def fit(
self,
prompt_dataloader: DataLoader,
pretrain_dataloader: DataLoader,
num_episodes: int,
num_collect_steps: int,
num_update_steps: int,
*args,
**kwargs,
):
"""
The main training loop of on-policy rl trainers.
Args:
prompt_dataloader (DataLoader): the dataloader to use for prompt data
pretrain_dataloader (DataLoader): the dataloader to use for pretrain data
num_episodes (int): the number of episodes to train
num_collect_steps (int): the number of collect steps per episode
num_update_steps (int): the number of update steps per episode
"""
self.prompt_dataloader = CycledDataLoader(prompt_dataloader)
self.pretrain_dataloader = CycledDataLoader(pretrain_dataloader)
self._before_fit(*args, **kwargs)
with self._fit_ctx():
for episode in tqdm.trange(num_episodes, desc="Episodes", disable=not is_rank_0()):
with self._episode_ctx(episode):

View File

@@ -35,5 +35,5 @@ class Callback(ABC):
def on_learn_batch_start(self) -> None:
pass
def on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
def on_learn_batch_end(self, experience: Experience) -> None:
pass

View File

@@ -137,7 +137,7 @@ class PerformanceEvaluator(Callback):
return
self.learn_timer.start()
def on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
def on_learn_batch_end(self, experience: Experience) -> None:
if self.disable:
return
self.learn_timer.end()

View File

@@ -1,27 +1,26 @@
from typing import Dict, List
from typing import Dict, List, Optional
import torch.nn as nn
from coati.experience_buffer import NaiveExperienceBuffer
from coati.experience_maker import Experience, NaiveExperienceMaker
from coati.models.base import Actor, Critic, get_base_model
from coati.models.base import Actor, Critic, RewardModel, get_base_model
from coati.models.loss import GPTLMLoss, PolicyLoss, ValueLoss
from coati.models.utils import calc_action_log_probs
from torch import Tensor
from torch.optim import Optimizer
from torch.utils.data import DistributedSampler
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm
from transformers import PreTrainedTokenizerBase
from colossalai.utils import get_current_device
from .base import OnPolicyTrainer
from .callbacks import Callback
from .strategies import GeminiStrategy, Strategy
from .utils import is_rank_0, to_device
from .utils import CycledDataLoader, is_rank_0, to_device
def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> Dict:
unwrapper_model = strategy.unwrap_model(actor)
hf_model = get_base_model(unwrapper_model)
unwrapped_model = strategy.unwrap_model(actor)
hf_model = get_base_model(unwrapped_model)
new_kwargs = {**generate_kwargs}
# use huggingface models method directly
if "prepare_inputs_fn" not in generate_kwargs and hasattr(hf_model, "prepare_inputs_for_generation"):
@@ -41,7 +40,7 @@ class PPOTrainer(OnPolicyTrainer):
strategy (Strategy): the strategy to use for training
actor (Actor): the actor model in ppo algorithm
critic (Critic): the critic model in ppo algorithm
reward_model (nn.Module): the reward model in rlhf algorithm to make reward of sentences
reward_model (RewardModel): the reward model in rlhf algorithm to make reward of sentences
initial_model (Actor): the initial model in rlhf algorithm to generate reference logics to limit the update of actor
actor_optim (Optimizer): the optimizer to use for actor model
critic_optim (Optimizer): the optimizer to use for critic model
@@ -65,10 +64,11 @@ class PPOTrainer(OnPolicyTrainer):
strategy: Strategy,
actor: Actor,
critic: Critic,
reward_model: nn.Module,
reward_model: RewardModel,
initial_model: Actor,
actor_optim: Optimizer,
critic_optim: Optimizer,
tokenizer: PreTrainedTokenizerBase,
kl_coef: float = 0.1,
ptx_coef: float = 0.9,
train_batch_size: int = 8,
@@ -90,11 +90,11 @@ class PPOTrainer(OnPolicyTrainer):
super().__init__(strategy, data_buffer, sample_buffer, dataloader_pin_memory, callbacks)
self.generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)
self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef)
self.offload_inference_models = offload_inference_models
self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, tokenizer, kl_coef)
self.actor = actor
self.critic = critic
self.tokenizer = tokenizer
self.actor_loss_fn = PolicyLoss(eps_clip)
self.critic_loss_fn = ValueLoss(value_clip)
@@ -104,58 +104,81 @@ class PPOTrainer(OnPolicyTrainer):
self.actor_optim = actor_optim
self.critic_optim = critic_optim
self.offload_inference_models = offload_inference_models
self.device = get_current_device()
def _before_fit(
self,
prompt_dataloader: DataLoader,
pretrain_dataloader: DataLoader,
log_dir: Optional[str] = None,
use_wandb: bool = False,
):
"""
Args:
prompt_dataloader (DataLoader): the dataloader to use for prompt data
pretrain_dataloader (DataLoader): the dataloader to use for pretrain data
"""
self.prompt_dataloader = CycledDataLoader(prompt_dataloader)
self.pretrain_dataloader = CycledDataLoader(pretrain_dataloader)
self.writer = None
if use_wandb and is_rank_0():
assert log_dir is not None, "log_dir must be provided when use_wandb is True"
import wandb
wandb.init(project="Coati-ppo", sync_tensorboard=True)
if log_dir is not None and is_rank_0():
import os
import time
from torch.utils.tensorboard import SummaryWriter
log_dir = os.path.join(log_dir, "ppo")
log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
self.writer = SummaryWriter(log_dir=log_dir)
def _make_experience(self, collect_step: int) -> Experience:
prompts = self.prompt_dataloader.next()
if self.offload_inference_models:
# TODO(ver217): this may be controlled by strategy if they are prepared by strategy
self.experience_maker.initial_model.to(self.device)
self.experience_maker.reward_model.to(self.device)
if isinstance(prompts, Tensor):
return self.experience_maker.make_experience(prompts, **self.generate_kwargs)
elif isinstance(prompts, dict):
return self.experience_maker.make_experience(**prompts, **self.generate_kwargs)
else:
raise ValueError(f'Unsupported input type "{type(prompts)}"')
assert isinstance(prompts, dict), f'Unsupported input type "{type(prompts)}"'
return self.experience_maker.make_experience(**prompts, **self.generate_kwargs)
def _training_step(self, experience: Experience) -> Dict[str, float]:
def _training_step(self, experience: Experience):
self.actor.train()
self.critic.train()
# policy loss
num_actions = experience.action_mask.size(1)
actor_output = self.actor(experience.sequences, attention_mask=experience.attention_mask)
action_log_probs = calc_action_log_probs(actor_output, experience.sequences, num_actions)
num_actions = experience.action_log_probs.size(1)
actor_logits = self.actor(experience.sequences, experience.attention_mask)["logits"]
action_log_probs = calc_action_log_probs(actor_logits, experience.sequences, num_actions)
actor_loss = self.actor_loss_fn(
action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask
)
actor_loss = (1 - self.ptx_coef) * actor_loss
self.strategy.backward(actor_loss, self.actor, self.actor_optim)
# ptx loss
if self.ptx_coef != 0:
batch = self.pretrain_dataloader.next()
batch = to_device(batch, self.device)
ptx_log_probs = self.actor(batch["input_ids"], attention_mask=batch["attention_mask"])["logits"]
ptx_loss = self.ptx_loss_fn(ptx_log_probs, batch["labels"])
actor_loss = ptx_loss * self.ptx_coef + actor_loss * (1 - self.ptx_coef)
ptx_log_probs = self.actor(batch["input_ids"], batch["attention_mask"])["logits"]
ptx_loss = self.ptx_coef * self.ptx_loss_fn(ptx_log_probs, batch["labels"])
self.strategy.backward(ptx_loss, self.actor, self.actor_optim)
self.strategy.backward(actor_loss, self.actor, self.actor_optim)
self.strategy.optimizer_step(self.actor_optim)
self.actor_optim.zero_grad()
# value loss
values = self.critic(
experience.sequences, action_mask=experience.action_mask, attention_mask=experience.attention_mask
)
critic_loss = self.critic_loss_fn(
values, experience.values, experience.reward, action_mask=experience.action_mask
)
values = self.critic(experience.sequences, attention_mask=experience.attention_mask)
critic_loss = self.critic_loss_fn(values, experience.values, experience.reward)
critic_loss = critic_loss * self.vf_coef
self.strategy.backward(critic_loss, self.critic, self.critic_optim)
self.strategy.optimizer_step(self.critic_optim)
self.critic_optim.zero_grad()
return {"reward": experience.reward.mean().item()}
def _learn(self, update_step: int):
if self.offload_inference_models:
self.experience_maker.initial_model.to("cpu")
@@ -166,8 +189,8 @@ class PPOTrainer(OnPolicyTrainer):
experience = self.data_buffer.sample()
self._on_learn_batch_start()
experience.to_device(self.device)
metrics = self._training_step(experience)
self._on_learn_batch_end(metrics, experience)
self._training_step(experience)
self._on_learn_batch_end(experience)
else:
if isinstance(self.dataloader.sampler, DistributedSampler):
self.dataloader.sampler.set_epoch(update_step)
@@ -175,6 +198,5 @@ class PPOTrainer(OnPolicyTrainer):
for experience in pbar:
self._on_learn_batch_start()
experience.to_device(self.device)
metrics = self._training_step(experience)
self._on_learn_batch_end(metrics, experience)
pbar.set_postfix(metrics)
self._training_step(experience)
self._on_learn_batch_end(experience)
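
PolicyLoss(eps_clip) and ValueLoss(value_clip) used above implement the usual PPO clipping. For orientation, the clipped policy surrogate looks like this (an illustrative function, not coati's exact code):

import torch

def clipped_policy_loss(log_probs, old_log_probs, advantages, eps_clip=0.2, action_mask=None):
    ratio = (log_probs - old_log_probs).exp()
    surr1 = ratio * advantages
    surr2 = ratio.clamp(1.0 - eps_clip, 1.0 + eps_clip) * advantages
    loss = -torch.min(surr1, surr2)                        # pessimistic (clipped) objective
    if action_mask is not None:
        return (loss * action_mask).sum() / action_mask.sum()
    return loss.mean()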

View File

@@ -1,7 +1,5 @@
from datetime import datetime
from typing import Callable
from typing import Callable, Optional
import pandas as pd
import torch
import tqdm
from torch.optim import Optimizer
@@ -40,10 +38,12 @@ class RewardModelTrainer(SLTrainer):
self.loss_fn = loss_fn
self.scheduler = lr_scheduler
self.num_train_step = 0
def _eval(self, epoch):
if self.eval_dataloader is not None:
self.model.eval()
dist, on, cnt = 0, 0, 0
dist, num_correct, num_samples = 0, 0, 0
with torch.no_grad():
for chosen_ids, c_mask, reject_ids, r_mask in self.eval_dataloader:
chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
@@ -52,27 +52,21 @@ class RewardModelTrainer(SLTrainer):
r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
reject_reward = self.model(reject_ids, attention_mask=r_mask)
for i in range(len(chosen_reward)):
cnt += 1
if chosen_reward[i] > reject_reward[i]:
on += 1
num_samples += chosen_ids.size(0)
num_correct += (chosen_reward > reject_reward).sum().item()
dist += (chosen_reward - reject_reward).mean().item()
self.dist = dist / len(self.eval_dataloader)
self.acc = on / cnt
self.acc = num_correct / num_samples
if is_rank_0():
log = pd.DataFrame(
[[(epoch + 1) * len(self.train_dataloader), self.loss.item(), self.dist, self.acc]],
columns=["step", "loss", "dist", "acc"],
)
log.to_csv("log.csv", mode="a", header=False, index=False)
if self.writer:
self.writer.add_scalar("eval/dist", self.dist, epoch)
self.writer.add_scalar("eval/acc", self.acc, epoch)
def _train(self, epoch):
self.model.train()
step_bar = tqdm.trange(
len(self.train_dataloader), desc="Train step of epoch %d" % epoch, disable=not is_rank_0()
len(self.train_dataloader), desc=f"Epoch {epoch + 1}/{self.max_epochs}", disable=not is_rank_0()
)
cnt = 0
for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader:
chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
c_mask = c_mask.squeeze(1).to(torch.cuda.current_device())
@@ -80,26 +74,50 @@ class RewardModelTrainer(SLTrainer):
r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
reject_reward = self.model(reject_ids, attention_mask=r_mask)
self.loss = self.loss_fn(chosen_reward, reject_reward)
self.strategy.backward(self.loss, self.model, self.optimizer)
loss = self.loss_fn(chosen_reward, reject_reward)
self.strategy.backward(loss, self.model, self.optimizer)
self.strategy.optimizer_step(self.optimizer)
self.optimizer.zero_grad()
cnt += 1
if cnt % 100 == 0:
if self.writer:
self.writer.add_scalar("train/loss", loss.item(), self.num_train_step)
self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
self.writer.add_scalar("train/dist", (chosen_reward - reject_reward).mean().item(), self.num_train_step)
self.writer.add_scalar(
"train/acc", (chosen_reward > reject_reward).float().mean().item(), self.num_train_step
)
self.num_train_step += 1
if self.num_train_step % 100 == 0:
self.scheduler.step()
step_bar.update()
step_bar.close()
def _before_fit(self, train_dataloader: DataLoader, valid_dataloader: DataLoader, eval_dataloader: DataLoader):
def _before_fit(
self,
train_dataloader: DataLoader,
eval_dataloader: DataLoader,
log_dir: Optional[str] = None,
use_wandb: bool = False,
):
"""
Args:
train_dataloader (DataLoader): the dataloader to use for training
valid_dataloader (DataLoader): the dataloader to use for validation
eval_dataloader (DataLoader): the dataloader to use for evaluation
"""
super()._before_fit()
self.datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
self.train_dataloader = train_dataloader
self.valid_dataloader = valid_dataloader
self.eval_dataloader = eval_dataloader
self.writer = None
if use_wandb and is_rank_0():
assert log_dir is not None, "log_dir must be provided when use_wandb is True"
import wandb
wandb.init(project="Coati-rm", sync_tensorboard=True)
if log_dir is not None and is_rank_0():
import os
import time
from torch.utils.tensorboard import SummaryWriter
log_dir = os.path.join(log_dir, "rm")
log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
self.writer = SummaryWriter(log_dir=log_dir)
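
The same logging pattern now appears in the PPO, RM, and SFT trainers: TensorBoard does the actual writing and wandb mirrors it through sync_tensorboard. An illustrative helper (make_writer and its arguments are not part of the patch):

import os
import time

def make_writer(log_dir, subdir, project, use_wandb, rank_0=True):
    writer = None
    if use_wandb and rank_0:
        assert log_dir is not None, "log_dir must be provided when use_wandb is True"
        import wandb
        wandb.init(project=project, sync_tensorboard=True)   # e.g. "Coati-rm"
    if log_dir is not None and rank_0:
        from torch.utils.tensorboard import SummaryWriter
        run_dir = os.path.join(log_dir, subdir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
        writer = SummaryWriter(log_dir=run_dir)
    return writer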

View File

@@ -1,10 +1,8 @@
import time
from typing import Optional
import torch
import torch.distributed as dist
import tqdm
import wandb
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
@@ -48,38 +46,34 @@ class SFTTrainer(SLTrainer):
self.accumulation_steps = accumulation_steps
self.scheduler = lr_scheduler
self.num_train_step = 0
self.num_eval_step = 0
def _train(self, epoch: int):
self.model.train()
for batch_id, batch in enumerate(self.train_dataloader):
step_bar = tqdm.trange(
len(self.train_dataloader) // self.accumulation_steps,
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
disable=not is_rank_0(),
)
for i, batch in enumerate(self.train_dataloader):
batch = to_device(batch, torch.cuda.current_device())
if "attention_mask" in batch:
outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
else:
outputs = self.model(batch["input_ids"], labels=batch["labels"])
loss = outputs.loss
loss = loss / self.accumulation_steps
self.strategy.backward(loss, self.model, self.optimizer)
outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
loss = outputs.loss / self.accumulation_steps
self.total_loss += loss.item()
self.strategy.backward(loss, self.model, self.optimizer)
# gradient accumulation
if (batch_id + 1) % self.accumulation_steps == 0:
if (i + 1) % self.accumulation_steps == 0:
self.strategy.optimizer_step(self.optimizer)
self.optimizer.zero_grad()
self.scheduler.step()
if is_rank_0() and self.use_wandb:
wandb.log(
{
"loss": self.total_loss / self.accumulation_steps,
"lr": self.scheduler.get_last_lr()[0],
"epoch": epoch,
"batch_id": batch_id,
}
)
if self.writer:
self.writer.add_scalar("train/loss", self.total_loss, self.num_train_step)
self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], self.num_train_step)
self.num_train_step += 1
self.total_loss = 0
self.step_bar.update()
step_bar.update()
step_bar.close()
def _eval(self, epoch: int):
if self.eval_dataloader is not None:
@@ -91,20 +85,21 @@ class SFTTrainer(SLTrainer):
outputs = self.model(
batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"]
)
loss = outputs.loss
loss_sum += loss.item()
loss_sum += outputs.loss.item()
num_seen += batch["input_ids"].size(0)
loss_mean = loss_sum / num_seen
if dist.get_rank() == 0:
self.logger.info(f"Eval Epoch {epoch}/{self.max_epochs} loss {loss_mean}")
if self.writer:
self.writer.add_scalar("eval/loss", loss_mean, self.num_eval_step)
self.num_eval_step += 1
def _before_fit(
self,
train_dataloader: DataLoader,
eval_dataloader: Optional[DataLoader] = None,
logger: Optional[DistributedLogger] = None,
log_dir: Optional[str] = None,
use_wandb: bool = False,
):
"""
@@ -116,15 +111,20 @@ class SFTTrainer(SLTrainer):
self.eval_dataloader = eval_dataloader
self.logger = logger
self.use_wandb = use_wandb
if use_wandb:
wandb.init(project="Coati", name=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
wandb.watch(self.model)
self.writer = None
if use_wandb and is_rank_0():
assert log_dir is not None, "log_dir must be provided when use_wandb is True"
import wandb
wandb.init(project="Coati-sft", sync_tensorboard=True)
if log_dir is not None and is_rank_0():
import os
import time
from torch.utils.tensorboard import SummaryWriter
log_dir = os.path.join(log_dir, "sft")
log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
self.writer = SummaryWriter(log_dir=log_dir)
self.total_loss = 0
self.no_epoch_bar = True
self.step_bar = tqdm.trange(
len(self.train_dataloader) // self.accumulation_steps * self.max_epochs,
desc=f"steps",
disable=not is_rank_0(),
)
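
SFTTrainer._train now scales the loss by accumulation_steps and steps the optimizer every accumulation_steps batches; the pattern in isolation (toy model and data):

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
accumulation_steps = 4

for i, batch in enumerate(torch.randn(16, 8, 4)):             # 16 toy batches of shape (8, 4)
    loss = model(batch).pow(2).mean() / accumulation_steps     # scale so gradients average over the window
    loss.backward()
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
        # progress bar / writer updates happen once per optimizer step, not once per batch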

View File

@@ -1,17 +1,13 @@
import warnings
from typing import Optional
import torch
import torch.distributed as dist
import torch.nn as nn
import colossalai
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin
from colossalai.booster.plugin.gemini_plugin import GeminiModel
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
from colossalai.tensor import ProcessGroup, ShardSpec
from colossalai.lazy.lazy_init import LazyInitContext
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext
from colossalai.zero.gemini.gemini_ddp import GeminiDDP
from .ddp import DDPStrategy
@@ -65,14 +61,11 @@ class LowLevelZeroStrategy(DDPStrategy):
assert precision in ("fp32", "fp16"), f'Unsupported precision "{precision}"'
plugin_initializer = lambda: LowLevelZeroPlugin(
# zero_config
stage=stage,
precision=precision,
# zero_optim_config
reduce_bucket_size_in_m=reduce_bucket_size,
overlap_communication=overlap_communication,
cpu_offload=(placement_policy == "cpu"),
# optim_config
initial_scale=initial_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
@@ -136,7 +129,7 @@ class GeminiStrategy(DDPStrategy):
self,
seed: int = 42,
shard_init: bool = False, # only for stage 3
placement_policy: str = "cuda",
placement_policy: str = "auto",
pin_memory: bool = True, # only for stage 3
force_outputs_fp32: bool = False, # only for stage 3
search_range_m: int = 32, # only for stage 3
@@ -153,8 +146,6 @@ class GeminiStrategy(DDPStrategy):
max_norm: float = 0.0,
norm_type: float = 2.0,
) -> None:
assert placement_policy in ("cpu", "cuda"), f'Unsupported placement policy "{placement_policy}"'
# TODO(ver217): support shard_init when using from_pretrained()
if shard_init:
warnings.warn(
@@ -167,8 +158,7 @@ class GeminiStrategy(DDPStrategy):
# NOTE: dist should be initialized before calling get_current_device()
plugin_initializer = lambda: GeminiPlugin(
# gemini_config
device=get_current_device(),
chunk_init_device=get_current_device(),
placement_policy=placement_policy,
precision="fp16",
pin_memory=pin_memory,
@@ -177,9 +167,7 @@ class GeminiStrategy(DDPStrategy):
search_range_m=search_range_m,
hidden_dim=hidden_dim,
min_chunk_size_m=min_chunk_size_m,
# zero_optim_config
gpu_margin_mem_ratio=gpu_margin_mem_ratio,
# optim_config
initial_scale=initial_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
@@ -200,15 +188,8 @@ class GeminiStrategy(DDPStrategy):
colossalai.launch_from_torch({}, seed=self.seed)
def model_init_context(self):
world_size = dist.get_world_size()
shard_pg = ProcessGroup(tp_degree=world_size) if self.shard_init else None
default_dist_spec = ShardSpec([-1], [world_size]) if self.shard_init else None
return ColoInitContext(
device=get_current_device(), dtype=torch.half, default_pg=shard_pg, default_dist_spec=default_dist_spec
)
return LazyInitContext(default_device=get_current_device())
def unwrap_model(self, model: nn.Module) -> nn.Module:
assert isinstance(model, GeminiModel)
ddp_model = model.unwrap()
assert isinstance(ddp_model, GeminiDDP)
return ddp_model.module
assert isinstance(model, GeminiDDP)
return model.module
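
GeminiStrategy now builds models under colossalai's LazyInitContext instead of ColoInitContext, so parameters stay lazy until the GeminiPlugin materializes them. A hedged usage sketch (the import path and default_device argument are taken from the diff; the Linear model is a stand-in and colossalai must be installed):

import torch
from colossalai.lazy.lazy_init import LazyInitContext

with LazyInitContext(default_device=torch.device("cuda")):
    # Parameters declared here are not allocated eagerly; the Gemini plugin
    # materializes and shards them when the strategy prepares the model.
    model = torch.nn.Linear(1024, 1024)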