Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-22 18:09:06 +00:00
[chat] refactor actor class (#3968)
* refactor: separate log_probs fn from Actor forward fn
* refactor: separate generate fn from Actor class
* feat: update unwrap_model and get_base_model
  * unwrap_model returns model not wrapped by Strategy
  * get_base_model returns HF model for Actor, Critic and RewardModel
* feat: simplify Strategy.prepare
* style: remove get_base_model method of Actor
* perf: tokenize text in batches
* refactor: move calc_action_log_probs to utils of model
* test: update test with new forward fn
* style: rename forward fn args
* fix: do not unwrap model in save_model fn of naive strategy
* test: add gemini test for train_prompts
* fix: fix _set_default_generate_kwargs
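In rough terms, the refactor moves generation and log-prob extraction off the Actor and into free functions. A hypothetical before/after sketch of a call site, inferred from the diffs below (all variable names are assumed):

    # old: sequences, attention_mask, action_mask = actor.generate(input_ids, **generate_kwargs)
    #      action_log_probs = actor(sequences, num_actions, attention_mask)
    # new:
    sequences, attention_mask, action_mask = generate_with_actor(actor, input_ids, **generate_kwargs)
    output = actor(sequences, attention_mask=attention_mask)
    action_log_probs = calc_action_log_probs(output, sequences, action_mask.size(1))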
@@ -1,3 +1,5 @@
+from typing import Union
+
 import torch.nn as nn
 
 from .actor import Actor
@@ -5,10 +7,10 @@ from .critic import Critic
 from .reward_model import RewardModel
 
 
-def get_base_model(model: nn.Module) -> nn.Module:
+def get_base_model(model: Union[Actor, Critic, RewardModel]) -> nn.Module:
     """Get the base model of our wrapper classes.
-    For Actor, it's base model is ``actor.model`` and it's usually a ``transformers.PreTrainedModel``.
-    For Critic and RewardModel, it's base model is itself.
+    For Actor, Critic and RewardModel, return ``model.model``,
+    it's usually a ``transformers.PreTrainedModel``.
 
     Args:
         model (nn.Module): model to get base model from
@@ -16,9 +18,9 @@ def get_base_model(model: nn.Module) -> nn.Module:
     Returns:
         nn.Module: the base model
     """
-    if isinstance(model, Actor):
-        return model.get_base_model()
-    return model
+    assert isinstance(model, (Actor, Critic, RewardModel)), \
+        f'Expect Actor, Critic or RewardModel, got {type(model)}, use unwrap_model first.'
+    return model.model
 
 
 __all__ = ['Actor', 'Critic', 'RewardModel', 'get_base_model']
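A minimal sketch of the new get_base_model contract (the `model` and `actor` names here are assumptions, not part of the diff):

    actor = Actor(model)            # wraps a transformers.PreTrainedModel
    base = get_base_model(actor)    # returns actor.model, the HF model
    # A module still wrapped by a Strategy now trips the assert; per the
    # commit message, call unwrap_model first, then get_base_model.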
@@ -1,12 +1,9 @@
-from typing import Optional, Tuple, Union
+from typing import Optional
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
-from ..generation import generate
 from ..lora import LoRAModule
-from ..utils import log_probs_from_logits
 
 
 class Actor(LoRAModule):
@@ -24,42 +21,16 @@ class Actor(LoRAModule):
         self.model = model
         self.convert_to_lora()
 
-    @torch.no_grad()
-    def generate(
-        self,
-        input_ids: torch.Tensor,
-        return_action_mask: bool = True,
-        **kwargs
-    ) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
-        sequences = generate(self.model, input_ids, **kwargs)
-        attention_mask = None
-        pad_token_id = kwargs.get('pad_token_id', None)
-        if pad_token_id is not None:
-            attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
-        if not return_action_mask:
-            return sequences, attention_mask, None
-        input_len = input_ids.size(1)
-        eos_token_id = kwargs.get('eos_token_id', None)
-        if eos_token_id is None:
-            action_mask = torch.ones_like(sequences, dtype=torch.bool)
-        else:
-            # left padding may be applied, only mask action
-            action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
-            action_mask = F.pad(action_mask, (1 + input_len, -1), value=True)    # include eos token and input
-        action_mask[:, :input_len] = False
-        action_mask = action_mask[:, 1:]
-        return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len):]
-
     def forward(self,
-                sequences: torch.LongTensor,
-                num_actions: int,
-                attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
-        """Returns action log probs
+                input_ids: torch.LongTensor,
+                attention_mask: Optional[torch.Tensor] = None,
+                **model_kwargs,    # HACK: `generate` method may pass more kwargs
+                ) -> torch.Tensor:
+        """Returns model output.
         """
-        output = self.model(sequences, attention_mask=attention_mask)
-        logits = output['logits']
-        log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
-        return log_probs[:, -num_actions:]
-
-    def get_base_model(self):
-        return self.model
+        output = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            **model_kwargs
+        )
+        return output
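With the generate and log-prob logic stripped out, forward is now a thin pass-through to the wrapped HF model. A hedged sketch of a call site (names assumed):

    output = actor(input_ids, attention_mask=attention_mask)
    logits = output['logits']    # raw HF output; log probs are computed elsewhere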
@@ -1,8 +1,10 @@
-from typing import Any, Callable, Optional
+from typing import Any, Callable, Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+import torch.nn.functional as F
+
 
 try:
     from transformers.generation_logits_process import (
@@ -55,9 +57,8 @@ def sample(model: nn.Module,
     unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
 
     for _ in range(input_ids.size(1), max_length):
-        model_inputs = prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else {
-            'input_ids': input_ids
-        }
+        model_inputs = prepare_inputs_fn(input_ids, **model_kwargs) \
+            if prepare_inputs_fn is not None else {'input_ids': input_ids}
         outputs = model(**model_inputs)
 
         next_token_logits = outputs['logits'][:, -1, :]
@@ -144,3 +145,35 @@ def generate(model: nn.Module,
             raise NotImplementedError
     else:
         raise ValueError("Unsupported generation mode")
+
+
+@torch.no_grad()
+def generate_with_actor(actor_model: nn.Module,
+                        input_ids: torch.Tensor,
+                        return_action_mask: bool = True,
+                        **kwargs
+                        ) -> Union[Tuple[torch.LongTensor, torch.LongTensor],
+                                   Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
+    """Generate token sequence with actor model. Refer to `generate` for more details.
+    """
+    # generate sequences
+    sequences = generate(actor_model, input_ids, **kwargs)
+
+    # calculate auxiliary tensors
+    attention_mask = None
+    pad_token_id = kwargs.get('pad_token_id', None)
+    if pad_token_id is not None:
+        attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
+    if not return_action_mask:
+        return sequences, attention_mask, None
+    input_len = input_ids.size(1)
+    eos_token_id = kwargs.get('eos_token_id', None)
+    if eos_token_id is None:
+        action_mask = torch.ones_like(sequences, dtype=torch.bool)
+    else:
+        # left padding may be applied, only mask action
+        action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
+        action_mask = F.pad(action_mask, (1 + input_len, -1), value=True)    # include eos token and input
+    action_mask[:, :input_len] = False
+    action_mask = action_mask[:, 1:]
+    return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len):]
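A hypothetical call of the new free function (the `actor` and `tokenizer` names and the generation kwargs are assumptions, not from the diff):

    sequences, attention_mask, action_mask = generate_with_actor(
        actor, input_ids,
        max_length=128,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id)
    num_actions = action_mask.size(1)    # generated tokens, up to and including the first EOS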
@@ -46,6 +46,25 @@ def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
     return log_probs_labels.squeeze(-1)
 
 
+def calc_action_log_probs(output: torch.Tensor,
+                          sequences: torch.LongTensor,
+                          num_actions: int
+                          ) -> torch.Tensor:
+    """Calculate action log probs.
+
+    Args:
+        output (torch.Tensor): Output tensor of Actor.forward.
+        sequences (torch.LongTensor): Input sequences.
+        num_actions (int): Number of actions.
+
+    Returns:
+        torch.Tensor: Action log probs.
+    """
+    logits = output['logits']
+    log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
+    return log_probs[:, -num_actions:]
+
+
 def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
     tensor = tensor * mask
     tensor = tensor.sum(dim=dim)
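calc_action_log_probs aligns logits[t] with the token at position t+1, then keeps only the trailing generated positions. A tiny shape sketch (shapes assumed, not from the diff):

    # sequences: (B, S) token ids; output['logits']: (B, S, V)
    # log probs are computed for tokens 1..S-1, giving (B, S-1),
    # then truncated to the last num_actions positions: (B, num_actions)
    action_log_probs = calc_action_log_probs(output, sequences, num_actions)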