Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-22 18:09:06 +00:00)
[chat] fix bugs and add unit tests (#4213)
* style: rename replay buffer (experience replay is typically used in off-policy algorithms; using this name in PPO may be misleading)
* fix: fix wrong zero2 default arg
* test: update experience tests
* style: rename zero_pad fn
* fix: defer init in CycledDataLoader
* test: add benchmark test
* style: rename internal fn of generation
* style: rename internal fn of lora
* fix: remove unused loss fn
* fix: remove unused utils fn
* refactor: remove generate_with_actor fn
* fix: fix type annotation
* test: add models tests
* fix: skip llama due to long execution time
* style: modify dataset
* style: apply formatter
* perf: update reward dataset
* fix: fix wrong IGNORE_INDEX in sft dataset
* fix: remove DataCollatorForSupervisedDataset
* test: add dataset tests
* style: apply formatter
* style: rename test_ci to test_train
* feat: add llama in inference
* test: add inference tests
* test: change test scripts directory
* fix: update ci
* fix: fix typo
* fix: skip llama due to oom
* fix: fix file mod
* style: apply formatter
* refactor: remove duplicated llama_gptq
* style: apply formatter
* to: update rm test
* feat: add tokenizer arg
* feat: add download model script
* test: update train tests
* fix: modify gemini load and save pretrained
* test: update checkpoint io test
* to: modify nproc_per_node
* fix: do not remove existing dir
* fix: modify save path
* test: add random choice
* fix: fix sft path
* fix: enlarge nproc_per_node to avoid oom
* fix: add num_retry
* fix: make lora config of rm and critic consistent
* fix: add warning about lora weights
* fix: skip some gpt2 tests
* fix: remove grad ckpt in rm and critic due to errors
* refactor: directly use Actor in train_sft
* test: add more arguments
* fix: disable grad ckpt when using lora
* fix: fix save_pretrained and related tests
* test: enable zero2 tests
* revert: remove useless fn
* style: polish code
* test: modify test args
@@ -1,8 +1,8 @@
 from .base import Actor, Critic, RewardModel
 from .lora import LoRAModule, convert_to_lora_module
-from .loss import LogExpLoss, LogSigLoss, PolicyLoss, PPOPtxActorLoss, ValueLoss
+from .loss import LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss

 __all__ = [
-    'Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'PPOPtxActorLoss', 'LogSigLoss', 'LogExpLoss',
+    'Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'LogSigLoss', 'LogExpLoss',
     'LoRAModule', 'convert_to_lora_module'
 ]
@@ -14,7 +14,6 @@ class BLOOMCritic(Critic):
     Args:
         pretrained (str): Pretrained model name or path.
         config (BloomConfig): Model config.
-        checkpoint (bool): Enable gradient checkpointing.
         lora_rank (int): LoRA rank.
         lora_train_bias (str): LoRA bias training mode.
     """
@@ -22,7 +21,6 @@ class BLOOMCritic(Critic):
     def __init__(self,
                  pretrained: str = None,
                  config: Optional[BloomConfig] = None,
-                 checkpoint: bool = False,
                  lora_rank: int = 0,
                  lora_train_bias: str = 'none',
                  **kwargs) -> None:
@@ -32,7 +30,6 @@ class BLOOMCritic(Critic):
             model = BloomModel(config)
         else:
             model = BloomModel(BloomConfig())
-        if checkpoint:
-            model.gradient_checkpointing_enable()
+
         value_head = nn.Linear(model.config.hidden_size, 1)
         super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
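For readers outside the codebase: every critic hunk in this commit touches the same pattern, a pretrained backbone plus a scalar value head. A minimal self-contained sketch of that pattern (the class name and sizes are hypothetical, not the coati API):

import torch
import torch.nn as nn

class TinyCritic(nn.Module):
    # Stand-in for the Critic pattern above: a backbone encoder
    # plus a scalar value head over the hidden states.
    def __init__(self, hidden_size: int = 8):
        super().__init__()
        self.backbone = nn.Embedding(100, hidden_size)    # stand-in for BloomModel etc.
        self.value_head = nn.Linear(hidden_size, 1)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        hidden = self.backbone(input_ids)              # (batch, seq, hidden)
        values = self.value_head(hidden).squeeze(-1)   # (batch, seq)
        return values[:, -1]                           # value of the last token

critic = TinyCritic()
print(critic(torch.randint(0, 100, (2, 5))).shape)    # torch.Size([2])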
@@ -13,7 +13,6 @@ class BLOOMRM(RewardModel):
     Args:
         pretrained (str): Pretrained model name or path.
         config (BloomConfig): Model config.
-        checkpoint (bool): Enable gradient checkpointing.
         lora_rank (int): LoRA rank.
         lora_train_bias (str): LoRA bias training mode.
     """
@@ -21,7 +20,6 @@ class BLOOMRM(RewardModel):
     def __init__(self,
                  pretrained: str = None,
                  config: Optional[BloomConfig] = None,
-                 checkpoint: bool = False,
                  lora_rank: int = 0,
                  lora_train_bias: str = 'none') -> None:
         if pretrained is not None:
@@ -30,8 +28,7 @@ class BLOOMRM(RewardModel):
             model = BloomModel(config)
         else:
             model = BloomModel(BloomConfig())
-        if checkpoint:
-            model.gradient_checkpointing_enable()
+
         value_head = nn.Linear(model.config.hidden_size, 1)
         value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.hidden_size + 1))
         super().__init__(model, value_head, lora_rank, lora_train_bias)
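The reward-model hunks keep the value head's small normal initialization, std = 1/(hidden_size + 1), which keeps initial scores near zero regardless of model width. A standalone sketch of just that head, with an illustrative hidden size:

import torch
import torch.nn as nn

hidden_size = 64
value_head = nn.Linear(hidden_size, 1)
# Same initialization as the hunk above.
value_head.weight.data.normal_(mean=0.0, std=1 / (hidden_size + 1))

hidden = torch.randn(4, hidden_size)      # pretend last-token hidden states
print(value_head(hidden).squeeze(-1))     # scores, small in magnitude at init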
@@ -1,9 +1,9 @@
-from typing import Any, Callable, Optional, Tuple, Union
+from typing import Any, Callable, Optional

 import torch
 import torch.distributed as dist
 import torch.nn as nn
-import torch.nn.functional as F
+
+from .base import Actor

 try:
     from transformers.generation_logits_process import (
@@ -16,9 +16,9 @@ except ImportError:
     from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper


-def prepare_logits_processor(top_k: Optional[int] = None,
-                             top_p: Optional[float] = None,
-                             temperature: Optional[float] = None) -> LogitsProcessorList:
+def _prepare_logits_processor(top_k: Optional[int] = None,
+                              top_p: Optional[float] = None,
+                              temperature: Optional[float] = None) -> LogitsProcessorList:
     processor_list = LogitsProcessorList()
     if temperature is not None and temperature != 1.0:
         processor_list.append(TemperatureLogitsWarper(temperature))
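The renamed helper simply stacks the standard transformers logits warpers, applied in order to reshape the next-token distribution. A runnable sketch of the same composition (the parameter values are illustrative):

import torch
from transformers import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper

processors = LogitsProcessorList()
processors.append(TemperatureLogitsWarper(0.7))
processors.append(TopKLogitsWarper(50))
processors.append(TopPLogitsWarper(0.9))

input_ids = torch.tensor([[1, 2, 3]])
logits = torch.randn(1, 100)              # fake next-token logits
warped = processors(input_ids, logits)    # -inf outside the top-k / top-p set
print(torch.isinf(warped).sum().item(), "tokens masked out")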
@@ -37,22 +37,22 @@ def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool:
     return unfinished_sequences.max() == 0


-def sample(model: nn.Module,
-           input_ids: torch.Tensor,
-           max_length: int,
-           early_stopping: bool = False,
-           eos_token_id: Optional[int] = None,
-           pad_token_id: Optional[int] = None,
-           top_k: Optional[int] = None,
-           top_p: Optional[float] = None,
-           temperature: Optional[float] = None,
-           prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
-           update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
-           **model_kwargs) -> torch.Tensor:
+def _sample(model: Actor,
+            input_ids: torch.Tensor,
+            max_length: int,
+            early_stopping: bool = False,
+            eos_token_id: Optional[int] = None,
+            pad_token_id: Optional[int] = None,
+            top_k: Optional[int] = None,
+            top_p: Optional[float] = None,
+            temperature: Optional[float] = None,
+            prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
+            update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
+            **model_kwargs) -> torch.Tensor:
     if input_ids.size(1) >= max_length:
         return input_ids

-    logits_processor = prepare_logits_processor(top_k, top_p, temperature)
+    logits_processor = _prepare_logits_processor(top_k, top_p, temperature)
     unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)

     for _ in range(input_ids.size(1), max_length):
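The loop body that follows this hunk (not shown in the diff) warps the logits and samples one token per step, tracking which sequences have finished. A minimal sketch of that core step on standalone tensors, not the coati loop verbatim:

import torch

eos_token_id = 2
logits = torch.randn(4, 100)                     # (batch, vocab) next-token logits
probs = torch.softmax(logits, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

unfinished = torch.ones(4, dtype=torch.long)     # 1 = still generating
unfinished = unfinished.mul((next_tokens != eos_token_id).long())
print(next_tokens, unfinished)                   # sequences hitting EOS drop to 0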
@@ -89,7 +89,8 @@ def sample(model: nn.Module,
     return input_ids


-def generate(model: nn.Module,
+@torch.no_grad()
+def generate(model: Actor,
              input_ids: torch.Tensor,
              max_length: int,
              num_beams: int = 1,
@@ -128,51 +129,19 @@ def generate(model: nn.Module,
         raise NotImplementedError
     elif is_sample_gen_mode:
-        # run sample
-        return sample(model,
-                      input_ids,
-                      max_length,
-                      early_stopping=early_stopping,
-                      eos_token_id=eos_token_id,
-                      pad_token_id=pad_token_id,
-                      top_k=top_k,
-                      top_p=top_p,
-                      temperature=temperature,
-                      prepare_inputs_fn=prepare_inputs_fn,
-                      update_model_kwargs_fn=update_model_kwargs_fn,
-                      **model_kwargs)
+        return _sample(model,
+                       input_ids,
+                       max_length,
+                       early_stopping=early_stopping,
+                       eos_token_id=eos_token_id,
+                       pad_token_id=pad_token_id,
+                       top_k=top_k,
+                       top_p=top_p,
+                       temperature=temperature,
+                       prepare_inputs_fn=prepare_inputs_fn,
+                       update_model_kwargs_fn=update_model_kwargs_fn,
+                       **model_kwargs)
     elif is_beam_gen_mode:
         raise NotImplementedError
     else:
         raise ValueError("Unsupported generation mode")
-
-
-@torch.no_grad()
-def generate_with_actor(
-    actor_model: nn.Module,
-    input_ids: torch.Tensor,
-    return_action_mask: bool = True,
-    **kwargs
-) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
-    """Generate token sequence with actor model. Refer to `generate` for more details.
-    """
-    # generate sequences
-    sequences = generate(actor_model, input_ids, **kwargs)
-
-    # calculate auxiliary tensors
-    attention_mask = None
-    pad_token_id = kwargs.get('pad_token_id', None)
-    if pad_token_id is not None:
-        attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
-    if not return_action_mask:
-        return sequences, attention_mask, None
-    input_len = input_ids.size(1)
-    eos_token_id = kwargs.get('eos_token_id', None)
-    if eos_token_id is None:
-        action_mask = torch.ones_like(sequences, dtype=torch.bool)
-    else:
-        # left padding may be applied, only mask action
-        action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
-        action_mask = F.pad(action_mask, (1 + input_len, -1), value=True)    # include eos token and input
-        action_mask[:, :input_len] = False
-        action_mask = action_mask[:, 1:]
-    return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len):]
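The removed generate_with_actor built an action mask by zeroing every position after the first EOS, via a cumulative sum over an equality mask. A worked example of that cumsum trick in isolation (the original additionally pads and shifts the mask so the EOS token itself stays unmasked):

import torch

eos_token_id = 2
actions = torch.tensor([[5, 7, 2, 9],     # EOS at position 2
                        [4, 4, 4, 4]])    # no EOS
action_mask = (actions == eos_token_id).cumsum(dim=-1) == 0
print(action_mask)
# tensor([[ True,  True, False, False],
#         [ True,  True,  True,  True]])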
@@ -14,7 +14,6 @@ class GPTCritic(Critic):
     Args:
         pretrained (str): Pretrained model name or path.
         config (GPT2Config): Model config.
-        checkpoint (bool): Enable gradient checkpointing.
         lora_rank (int): Rank of the LO-RA decomposition.
         lora_train_bias (str): LoRA bias training mode.
     """
@@ -22,7 +21,6 @@ class GPTCritic(Critic):
     def __init__(self,
                  pretrained: Optional[str] = None,
                  config: Optional[GPT2Config] = None,
-                 checkpoint: bool = False,
                  lora_rank: int = 0,
                  lora_train_bias: str = 'none',
                  **kwargs) -> None:
@@ -32,7 +30,6 @@ class GPTCritic(Critic):
             model = GPT2Model(config)
         else:
             model = GPT2Model(GPT2Config())
-        if checkpoint:
-            model.gradient_checkpointing_enable()
+
         value_head = nn.Linear(model.config.n_embd, 1)
         super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
@@ -14,7 +14,6 @@ class GPTRM(RewardModel):
     Args:
         pretrained (str): Pretrained model name or path.
         config (GPT2Config): Model config.
-        checkpoint (bool): Enable gradient checkpointing.
         lora_rank (int): Rank of the low-rank approximation.
         lora_train_bias (str): LoRA bias training mode.
     """
@@ -22,7 +21,6 @@ class GPTRM(RewardModel):
     def __init__(self,
                  pretrained: Optional[str] = None,
                  config: Optional[GPT2Config] = None,
-                 checkpoint: bool = False,
                  lora_rank: int = 0,
                  lora_train_bias: str = 'none') -> None:
         if pretrained is not None:
@@ -31,8 +29,6 @@ class GPTRM(RewardModel):
             model = GPT2Model(config)
         else:
             model = GPT2Model(GPT2Config())
-        if checkpoint:
-            model.gradient_checkpointing_enable()

         value_head = nn.Linear(model.config.n_embd, 1)
         value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.n_embd + 1))
@@ -13,7 +13,6 @@ class LlamaCritic(Critic):
     Args:
         pretrained (str): Pretrained model name or path.
         config (LlamaConfig): Model config.
-        checkpoint (bool): Enable gradient checkpointing.
         lora_rank (int): LoRA rank.
         lora_train_bias (str): LoRA bias training mode.
     """
@@ -21,7 +20,6 @@ class LlamaCritic(Critic):
     def __init__(self,
                  pretrained: Optional[str] = None,
                  config: Optional[LlamaConfig] = None,
-                 checkpoint: bool = False,
                  lora_rank: int = 0,
                  lora_train_bias: str = 'none',
                  **kwargs) -> None:
@@ -33,9 +31,5 @@ class LlamaCritic(Critic):
         else:
             model = LlamaModel(LlamaConfig())

-        if checkpoint:
-            model.gradient_checkpointing_enable()
-
         value_head = nn.Linear(model.config.hidden_size, 1)
-
         super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
@@ -13,7 +13,6 @@ class LlamaRM(RewardModel):
     Args:
         pretrained (str): Pretrained model name or path.
         config (LlamaConfig): Model config.
-        checkpoint (bool): Enable gradient checkpointing.
         lora_rank (int): LoRA rank.
         lora_train_bias (str): LoRA bias training mode.
     """
@@ -21,7 +20,6 @@ class LlamaRM(RewardModel):
     def __init__(self,
                  pretrained: Optional[str] = None,
                  config: Optional[LlamaConfig] = None,
-                 checkpoint: bool = False,
                  lora_rank: int = 0,
                  lora_train_bias: str = 'none') -> None:

@@ -32,8 +30,6 @@ class LlamaRM(RewardModel):
         else:
             model = LlamaModel(LlamaConfig())

-        if checkpoint:
-            model.gradient_checkpointing_enable()
         value_head = nn.Linear(model.config.hidden_size, 1)
         value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.hidden_size + 1))

@@ -98,18 +98,18 @@ class LoraLinear(lora.LoRALayer, nn.Module):
         return F.linear(x, T(self.weight), bias=self.bias)


-def lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
+def _lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
     assert lora_rank <= linear.in_features, f'LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})'
     lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank, merge_weights=False)
     return lora_linear


-def convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None:
+def _convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None:
     for name, child in module.named_children():
         if isinstance(child, nn.Linear):
-            setattr(module, name, lora_linear_wrapper(child, lora_rank))
+            setattr(module, name, _lora_linear_wrapper(child, lora_rank))
         else:
-            convert_to_lora_recursively(child, lora_rank)
+            _convert_to_lora_recursively(child, lora_rank)


 def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: str = 'none') -> nn.Module:
@@ -124,7 +124,7 @@ def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: str = 'none') -> nn.Module:
     """
     if lora_rank <= 0:
         return module
-    convert_to_lora_recursively(module, lora_rank)
+    _convert_to_lora_recursively(module, lora_rank)
     lora.mark_only_lora_as_trainable(module, lora_train_bias)
     return module

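Conceptually, convert_to_lora_module swaps every nn.Linear for a LoRA layer and freezes the base weights, so only the low-rank factors (and optionally biases) train. A loralib-free sketch of the idea, with a hypothetical class name, where the layer computes Wx plus a trainable low-rank update BAx:

import torch
import torch.nn as nn

class LoRALinearSketch(nn.Module):
    # Illustrative only, not the coati/loralib implementation.
    def __init__(self, linear: nn.Linear, r: int = 4):
        super().__init__()
        self.linear = linear
        self.lora_A = nn.Parameter(torch.randn(r, linear.in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(linear.out_features, r))
        self.linear.weight.requires_grad_(False)    # freeze the base weight

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x) + x @ self.lora_A.t() @ self.lora_B.t()

layer = LoRALinearSketch(nn.Linear(16, 16), r=4)
out = layer(torch.randn(2, 16))
trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
print(out.shape, trainable)    # torch.Size([2, 16]) and 144 (A, B, and the base bias)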
@@ -68,31 +68,6 @@ class ValueLoss(nn.Module):
         return 0.5 * loss


-class PPOPtxActorLoss(nn.Module):
-    """
-    To Do:
-
-    PPO-ptx Actor Loss
-    """
-
-    def __init__(self, policy_clip_eps: float = 0.2, pretrain_coef: float = 0.0, pretrain_loss_fn=GPTLMLoss()) -> None:
-        super().__init__()
-        self.pretrain_coef = pretrain_coef
-        self.policy_loss_fn = PolicyLoss(clip_eps=policy_clip_eps)
-        self.pretrain_loss_fn = pretrain_loss_fn
-
-    def forward(self,
-                log_probs: torch.Tensor,
-                old_log_probs: torch.Tensor,
-                advantages: torch.Tensor,
-                lm_logits: torch.Tensor,
-                lm_input_ids: torch.Tensor,
-                action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
-        policy_loss = self.policy_loss_fn(log_probs, old_log_probs, advantages, action_mask=action_mask)
-        lm_loss = self.pretrain_loss_fn(lm_logits, lm_input_ids)
-        return policy_loss + self.pretrain_coef * lm_loss
-
-
 class LogSigLoss(nn.Module):
     """
     Pairwise Loss for Reward Model
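The surviving LogSigLoss is the standard pairwise reward-model objective: the model should score the chosen response above the rejected one, and the loss is -log(sigmoid(margin)). In one line, on illustrative tensors:

import torch
import torch.nn.functional as F

chosen_reward = torch.tensor([1.2, 0.3])
reject_reward = torch.tensor([0.4, 0.9])
loss = -F.logsigmoid(chosen_reward - reject_reward).mean()
print(loss)    # smaller when chosen_reward > reject_reward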
@@ -14,7 +14,6 @@ class OPTCritic(Critic):
     Args:
         pretrained (str): Pretrained model name or path.
         config (OPTConfig): Model config.
-        checkpoint (bool): Enable gradient checkpointing.
         lora_rank (int): Rank of the low-rank approximation.
         lora_train_bias (str): LoRA bias training mode.
     """
@@ -22,7 +21,6 @@ class OPTCritic(Critic):
     def __init__(self,
                  pretrained: Optional[str] = None,
                  config: Optional[OPTConfig] = None,
-                 checkpoint: bool = False,
                  lora_rank: int = 0,
                  lora_train_bias: str = 'none',
                  **kwargs) -> None:
@@ -32,7 +30,6 @@ class OPTCritic(Critic):
             model = OPTModel(config)
         else:
             model = OPTModel(OPTConfig())
-        if checkpoint:
-            model.gradient_checkpointing_enable()
+
         value_head = nn.Linear(model.config.word_embed_proj_dim, 1)
         super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
@@ -13,7 +13,6 @@ class OPTRM(RewardModel):
     Args:
         pretrained (str): Pretrained model name or path.
         config (OPTConfig): Model config.
-        checkpoint (bool): Enable gradient checkpointing.
         lora_rank (int): Rank of the low-rank approximation.
         lora_train_bias (str): LoRA bias training mode.
     """
@@ -21,7 +20,6 @@ class OPTRM(RewardModel):
     def __init__(self,
                  pretrained: Optional[str] = None,
                  config: Optional[OPTConfig] = None,
-                 checkpoint: bool = False,
                  lora_rank: int = 0,
                  lora_train_bias: str = 'none') -> None:
         if pretrained is not None:
@@ -30,8 +28,6 @@ class OPTRM(RewardModel):
             model = OPTModel(config)
         else:
             model = OPTModel(OPTConfig())
-        if checkpoint:
-            model.gradient_checkpointing_enable()

         value_head = nn.Linear(model.config.word_embed_proj_dim, 1)
         value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.word_embed_proj_dim + 1))
@@ -1,14 +1,12 @@
 from typing import Optional, Union

-import loralib as lora
 import torch
-import torch.nn as nn
 import torch.nn.functional as F


-def compute_approx_kl(log_probs: torch.Tensor,
-                      log_probs_base: torch.Tensor,
-                      action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+def _compute_approx_kl(log_probs: torch.Tensor,
+                       log_probs_base: torch.Tensor,
+                       action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
     """
     Compute the approximate KL divergence between two distributions.
     Schulman blog: http://joschu.net/blog/kl-approx.html
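The docstring cites Schulman's note on KL approximation. A sketch of the k3-style estimator from that note, on illustrative per-token log probs; coati's exact variant may differ in detail:

import torch

log_probs = torch.tensor([-1.0, -0.5, -2.0])         # log p(token) under the actor
log_probs_base = torch.tensor([-1.2, -0.7, -1.5])    # log p(token) under the reference
log_ratio = log_probs - log_probs_base
approx_kl = (log_ratio.exp() - 1) - log_ratio        # k3 = (r - 1) - log r, always >= 0
print(approx_kl.mean())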
@@ -35,12 +33,12 @@ def compute_reward(r: Union[torch.Tensor, float],
                    action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
     if kl_coef <= 0.0:
         return r
-    kl = compute_approx_kl(log_probs, log_probs_base, action_mask=action_mask)
+    kl = _compute_approx_kl(log_probs, log_probs_base, action_mask=action_mask)
     reward = r - kl_coef * kl
     return reward


-def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+def _log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
     log_probs = F.log_softmax(logits, dim=-1)
     log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
     return log_probs_labels.squeeze(-1)
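compute_reward shapes the reward-model score with a KL penalty: the actor is discouraged from drifting far from the reference policy. A tiny numeric example with illustrative values:

import torch

r = torch.tensor([1.0, 0.2])       # reward model scores
kl = torch.tensor([0.05, 0.30])    # approximate per-sequence KL
kl_coef = 0.1
reward = r - kl_coef * kl
print(reward)    # tensor([0.9950, 0.1700])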
@@ -58,7 +56,7 @@ def calc_action_log_probs(output: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor:
         torch.Tensor: Action log probs.
     """
     logits = output['logits']
-    log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
+    log_probs = _log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
     return log_probs[:, -num_actions:]

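The off-by-one slicing here is the usual next-token alignment: logits at position t score token t+1, so logits[:, :-1] pairs with sequences[:, 1:]. A standalone sketch with hypothetical shapes:

import torch
import torch.nn.functional as F

vocab, num_actions = 10, 2
sequences = torch.tensor([[3, 5, 7, 1]])             # prompt + generated actions
logits = torch.randn(1, 4, vocab)                    # one row of logits per position
log_probs = F.log_softmax(logits[:, :-1, :], dim=-1)
token_log_probs = log_probs.gather(-1, sequences[:, 1:].unsqueeze(-1)).squeeze(-1)
print(token_log_probs[:, -num_actions:].shape)       # torch.Size([1, 2]): last 2 actions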
@@ -68,41 +66,3 @@ def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
     mask_sum = mask.sum(dim=dim)
     mean = tensor / (mask_sum + 1e-8)
     return mean
-
-
-def masked_normalize(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1, eps: float = 1e-8) -> torch.Tensor:
-    tensor = tensor * mask
-    mean = masked_mean(tensor, mask, dim=dim)
-    mean_centered = tensor - mean
-    var = masked_mean(mean_centered**2, mask, dim=dim)
-    return mean_centered * var.clamp(min=eps).rsqrt()
-
-
-def normalize(tensor: torch.Tensor, dim: int = 0, eps: float = 1e-8) -> torch.Tensor:
-    mean = tensor.mean(dim)
-    mean_centered = tensor - mean
-    var = (mean_centered**2).mean(dim)
-    norm = mean_centered * var.clamp(min=eps).rsqrt()
-    return norm
-
-
-def convert_to_lora(model: nn.Module,
-                    input_size: int,
-                    output_size: int,
-                    lora_rank: int = 16,
-                    lora_alpha: int = 1,
-                    lora_dropout: float = 0.,
-                    fan_in_fan_out: bool = False,
-                    merge_weights: bool = True):
-    if lora_rank > min(input_size, output_size):
-        raise ValueError(f"LoRA rank {lora_rank} must be less or equal than {min(input_size, output_size)}")
-
-    for name, module in model.named_modules():
-        if isinstance(module, nn.Linear):
-            module._modules[name] = lora.Linear(input_size,
-                                                output_size,
-                                                r=lora_rank,
-                                                lora_alpha=lora_alpha,
-                                                lora_dropout=lora_dropout,
-                                                fan_in_fan_out=fan_in_fan_out,
-                                                merge_weights=merge_weights)
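masked_mean, the helper kept by this hunk, averages only over unmasked positions: sum the masked entries, then divide by the mask count (with a small epsilon for empty masks). A worked example with illustrative tensors:

import torch

tensor = torch.tensor([[1.0, 2.0, 3.0],
                       [4.0, 5.0, 6.0]])
mask = torch.tensor([[1.0, 1.0, 0.0],
                     [1.0, 1.0, 1.0]])
mean = (tensor * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-8)
print(mean)    # tensor([1.5000, 5.0000])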