mirror of https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-22 18:09:06 +00:00
[chat]: update rm, add wandb and fix bugs (#4471)
* feat: modify forward fn of critic and reward model
* feat: modify calc_action_log_probs
* to: add wandb in sft and rm trainer
* feat: update train_sft
* feat: update train_rm
* style: modify type annotation and add warning
* feat: pass tokenizer to ppo trainer
* to: modify trainer base and maker base
* feat: add wandb in ppo trainer
* feat: pass tokenizer to generate
* test: update generate fn tests
* test: update train tests
* fix: remove action_mask
* feat: remove unused code
* fix: fix wrong ignore_index
* fix: fix mock tokenizer
* chore: update requirements
* revert: modify make_experience
* fix: fix inference
* fix: add padding side
* style: modify _on_learn_batch_end
* test: use mock tokenizer
* fix: use bf16 to avoid overflow
* fix: fix workflow
* [chat] fix gemini strategy
* [chat] fix
* sync: update colossalai strategy
* fix: fix args and model dtype
* fix: fix checkpoint test
* fix: fix requirements
* fix: fix missing import and wrong arg
* fix: temporarily skip gemini test in stage 3
* style: apply pre-commit
* fix: temporarily skip gemini test in stage 1&2

---------

Co-authored-by: Mingyan Jiang <1829166702@qq.com>
@@ -25,7 +25,7 @@ class Actor(LoRAModule):
         self,
         input_ids: torch.LongTensor,
         attention_mask: Optional[torch.Tensor] = None,
-        **model_kwargs,  # HACK: `generate` method may pass more kwargs
+        **model_kwargs,
     ) -> torch.Tensor:
         """Returns model output."""
         output = self.model(input_ids, attention_mask=attention_mask, **model_kwargs)
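For context, a minimal runnable sketch of this kwargs pass-through (the toy backbone and class bodies below are illustrative assumptions, not code from this repository): extra generation-time kwargs flow straight through the wrapper to the backbone, and the raw output dict is returned for callers to unpack.

```python
import torch
import torch.nn as nn


class ToyBackbone(nn.Module):
    """Stand-in for the wrapped Hugging Face causal LM (an assumption)."""

    def __init__(self, vocab_size: int = 100, hidden: int = 8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.head = nn.Linear(hidden, vocab_size)

    def forward(self, input_ids, attention_mask=None, **model_kwargs):
        # generation-time kwargs (e.g. use_cache) arrive here untouched
        return {"logits": self.head(self.embed(input_ids))}


class ToyActor(nn.Module):
    """Mirrors the wrapper shape shown in the diff (hypothetical)."""

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask=None, **model_kwargs):
        """Returns model output."""
        return self.model(input_ids, attention_mask=attention_mask, **model_kwargs)


actor = ToyActor(ToyBackbone())
ids = torch.randint(0, 100, (2, 5))
out = actor(ids, attention_mask=torch.ones_like(ids), use_cache=True)
print(out["logits"].shape)  # torch.Size([2, 5, 100])
```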
@@ -1,10 +1,7 @@
-from typing import Optional
-
 import torch
 import torch.nn as nn
 
 from ..lora import LoRAModule
-from ..utils import masked_mean
 
 
 class Critic(LoRAModule):
@@ -19,37 +16,19 @@ class Critic(LoRAModule):
     """
 
     def __init__(
-        self,
-        model: nn.Module,
-        value_head: nn.Module,
-        lora_rank: int = 0,
-        lora_train_bias: str = "none",
-        use_action_mask: bool = False,
+        self, model: nn.Module, value_head: nn.Module, lora_rank: int = 0, lora_train_bias: str = "none"
     ) -> None:
         super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
         self.model = model
         self.value_head = value_head
-        self.use_action_mask = use_action_mask
         self.convert_to_lora()
 
-    def forward(
-        self,
-        sequences: torch.LongTensor,
-        action_mask: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+    def forward(self, sequences: torch.LongTensor, attention_mask: torch.Tensor) -> torch.Tensor:
         outputs = self.model(sequences, attention_mask=attention_mask)
         last_hidden_states = outputs["last_hidden_state"]
-
-        values = self.value_head(last_hidden_states).squeeze(-1)
-
-        if action_mask is not None and self.use_action_mask:
-            num_actions = action_mask.size(1)
-            prompt_mask = attention_mask[:, :-num_actions]
-            values = values[:, :-num_actions]
-            value = masked_mean(values, prompt_mask, dim=1)
-            return value
-
-        values = values[:, :-1]
-        value = values.mean(dim=1)
-        return value
+        sequence_lengths = torch.max(attention_mask * torch.arange(sequences.size(1), device=sequences.device), dim=1)[
+            0
+        ]
+        sequence_hidden_states = last_hidden_states[torch.arange(last_hidden_states.size(0)), sequence_lengths]
+        values = self.value_head(sequence_hidden_states).squeeze(1)  # ensure shape is (B, )
+        return values
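The rewritten forward no longer mean-pools values (optionally under an action mask); it scores the hidden state of each sequence's last non-padding token. A self-contained sketch of that indexing, with dummy shapes chosen for illustration:

```python
import torch

batch_size, seq_len, hidden = 2, 6, 4

# 1 marks real tokens, 0 marks padding (right-padded example)
attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1, 1]])
last_hidden_states = torch.randn(batch_size, seq_len, hidden)
value_head = torch.nn.Linear(hidden, 1)

# weighting positions by the mask means a padded slot can never win the
# max, so this yields the index of the last attended token in each row
sequence_lengths = torch.max(attention_mask * torch.arange(seq_len), dim=1)[0]
print(sequence_lengths)  # tensor([3, 5])

sequence_hidden_states = last_hidden_states[torch.arange(batch_size), sequence_lengths]
values = value_head(sequence_hidden_states).squeeze(1)
print(values.shape)  # torch.Size([2]), one scalar value per sequence
```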
@@ -35,9 +35,12 @@ class RewardModel(LoRAModule):
         else:
             self.value_head = nn.Linear(model.config.n_embd, 1)
 
-    def forward(self, sequences: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def forward(self, sequences: torch.LongTensor, attention_mask: torch.Tensor) -> torch.Tensor:
         outputs = self.model(sequences, attention_mask=attention_mask)
         last_hidden_states = outputs["last_hidden_state"]
-        values = self.value_head(last_hidden_states)[:, :-1]
-        value = values.mean(dim=1).squeeze(1)  # ensure shape is (B)
-        return value
+        sequence_lengths = torch.max(attention_mask * torch.arange(sequences.size(1), device=sequences.device), dim=1)[
+            0
+        ]
+        sequence_hidden_states = last_hidden_states[torch.arange(last_hidden_states.size(0)), sequence_lengths]
+        values = self.value_head(sequence_hidden_states).squeeze(1)  # ensure shape is (B, )
+        return values
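The reward model gets the same last-token treatment. A small numeric sketch (made-up values) of why this matters: the old positional mean lets padding dilute the score, while last-token indexing reads one well-defined scalar per sequence.

```python
import torch

# made-up per-position value-head outputs; zeros sit at padded slots
values = torch.tensor([[1.0, 2.0, 3.0, 0.0, 0.0],
                       [1.0, 2.0, 3.0, 4.0, 5.0]])
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])

old_style = values[:, :-1].mean(dim=1)  # padding leaks into the mean
last = torch.max(attention_mask * torch.arange(5), dim=1)[0]
new_style = values[torch.arange(2), last]  # last real token only

print(old_style)  # tensor([1.5000, 2.5000])
print(new_style)  # tensor([3., 5.])
```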
@@ -2,6 +2,7 @@ from typing import Any, Callable, Optional
 
 import torch
 import torch.distributed as dist
+from transformers import PreTrainedTokenizer
 
 from .base import Actor
 
@@ -63,8 +64,8 @@ def _sample(
         )
         outputs = model(**model_inputs)
-
+        # NOTE: this is correct only in left padding mode
         next_token_logits = outputs["logits"][:, -1, :]
         # pre-process distribution
         next_token_logits = logits_processor(input_ids, next_token_logits)
         # sample
         probs = torch.softmax(next_token_logits, dim=-1, dtype=torch.float)
@@ -72,8 +73,8 @@ def _sample(
 
         # finished sentences should have their next token be a padding token
         if eos_token_id is not None:
-            if pad_token_id is None:
-                raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
+            assert pad_token_id is not None, "If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
             next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
 
         # update generated ids, model inputs for next step
@@ -96,12 +96,11 @@ def _sample(
 def generate(
     model: Actor,
     input_ids: torch.Tensor,
+    tokenizer: PreTrainedTokenizer,
     max_length: int,
     num_beams: int = 1,
     do_sample: bool = True,
     early_stopping: bool = False,
-    eos_token_id: Optional[int] = None,
-    pad_token_id: Optional[int] = None,
     top_k: Optional[int] = None,
     top_p: Optional[float] = None,
     temperature: Optional[float] = None,
@@ -118,14 +117,13 @@ def generate(
         num_beams (int, optional): number of beams. Defaults to 1.
         do_sample (bool, optional): whether to do sample. Defaults to True.
         early_stopping (bool, optional): if True, the sequence length may be smaller than max_length due to finding eos. Defaults to False.
-        eos_token_id (Optional[int], optional): end of sequence token id. Defaults to None.
-        pad_token_id (Optional[int], optional): pad token id. Defaults to None.
         top_k (Optional[int], optional): the number of highest probability vocabulary tokens to keep for top-k-filtering. Defaults to None.
         top_p (Optional[float], optional): If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. Defaults to None.
         temperature (Optional[float], optional): The value used to module the next token probabilities. Defaults to None.
         prepare_inputs_fn (Optional[Callable[[torch.Tensor, Any], dict]], optional): Function to preprocess model inputs. Arguments of this function should be input_ids and model_kwargs. Defaults to None.
         update_model_kwargs_fn (Optional[Callable[[dict, Any], dict]], optional): Function to update model_kwargs based on outputs. Arguments of this function should be outputs and model_kwargs. Defaults to None.
     """
+    assert tokenizer.padding_side == "left", "Current generation only supports left padding."
     is_greedy_gen_mode = (num_beams == 1) and do_sample is False
     is_sample_gen_mode = (num_beams == 1) and do_sample is True
     is_beam_gen_mode = (num_beams > 1) and do_sample is False
@@ -139,8 +137,8 @@ def generate(
             input_ids,
             max_length,
             early_stopping=early_stopping,
-            eos_token_id=eos_token_id,
-            pad_token_id=pad_token_id,
+            eos_token_id=tokenizer.eos_token_id,
+            pad_token_id=tokenizer.pad_token_id,
             top_k=top_k,
             top_p=top_p,
            temperature=temperature,
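Why the left-padding assertion: `_sample` reads the next-token distribution at position -1, which is the last real token only when prompts are left padded. A caller-side sketch (assuming a standard GPT-2 tokenizer; the actual generate call is elided):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token
tokenizer.padding_side = "left"            # generate() now asserts this

batch = tokenizer(
    ["a short prompt", "a much longer prompt than the first one"],
    padding=True,
    return_tensors="pt",
)
# with left padding, every row's final position holds a real token, so
# outputs["logits"][:, -1, :] is the next-token distribution for all rows;
# eos/pad ids now come from the tokenizer rather than explicit arguments
print(batch["attention_mask"][:, -1])  # tensor([1, 1])
```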
@@ -13,6 +13,7 @@ class GPTLMLoss(nn.Module):
 
     def __init__(self):
         super().__init__()
+        # NOTE: default ignore_index is -100, which is equal to IGNORE_INDEX in sft_dataset.py
         self.loss = nn.CrossEntropyLoss()
 
     def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
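The new NOTE pins down the contract behind the "fix wrong ignore_index" item in the commit message: `nn.CrossEntropyLoss` skips targets equal to its `ignore_index` (-100 by default), which must match the masking value used when building SFT labels. A quick self-contained check:

```python
import torch
import torch.nn as nn

loss_fn = nn.CrossEntropyLoss()  # ignore_index defaults to -100

logits = torch.randn(4, 10)                # 4 positions, 10-way vocab
labels = torch.tensor([3, -100, 7, -100])  # two positions masked out

masked = loss_fn(logits, labels)
unmasked = loss_fn(logits[[0, 2]], labels[[0, 2]])
print(torch.allclose(masked, unmasked))  # True: -100 targets contribute nothing
```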
@@ -46,18 +46,17 @@ def _log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.
     return log_probs_labels.squeeze(-1)
 
 
-def calc_action_log_probs(output: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor:
+def calc_action_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor:
     """Calculate action log probs.
 
     Args:
-        output (torch.Tensor): Output tensor of Actor.forward.
+        output (torch.Tensor): Output tensor of Actor.forward.logits.
         sequences (torch.LongTensor): Input sequences.
         num_actions (int): Number of actions.
 
     Returns:
         torch.Tensor: Action log probs.
     """
-    logits = output["logits"]
     log_probs = _log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
     return log_probs[:, -num_actions:]
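Call sites change accordingly: trainers now pass `output["logits"]` instead of the whole output dict. A self-contained sketch of the shift-and-slice logic (dummy shapes; the gather helper is re-implemented here as an assumption, mirroring `_log_probs_from_logits` above):

```python
import torch
import torch.nn.functional as F


def log_probs_from_logits(logits, labels):
    # gather the log-probability assigned to each label token
    log_probs = F.log_softmax(logits, dim=-1)
    return log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)


B, S, V, num_actions = 2, 8, 50, 3
logits = torch.randn(B, S, V)            # e.g. output["logits"] from the actor
sequences = torch.randint(0, V, (B, S))  # prompt tokens + generated actions

# logits[:, t] predicts sequences[:, t + 1], hence the off-by-one slices;
# the trailing num_actions positions are the generated (action) tokens
log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
action_log_probs = log_probs[:, -num_actions:]
print(action_log_probs.shape)  # torch.Size([2, 3])
```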