mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-28 13:05:26 +00:00
[Coati] first commit (#3283)
This commit is contained in:
4
applications/Chat/coati/models/__init__.py
Normal file
4
applications/Chat/coati/models/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .base import Actor, Critic, RewardModel
|
||||
from .loss import LogExpLoss, LogSigLoss, PolicyLoss, PPOPtxActorLoss, ValueLoss
|
||||
|
||||
__all__ = ['Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'PPOPtxActorLoss', 'LogSigLoss', 'LogExpLoss']
|
6
applications/Chat/coati/models/base/__init__.py
Normal file
6
applications/Chat/coati/models/base/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from .actor import Actor
|
||||
from .critic import Critic
|
||||
from .lm import LM
|
||||
from .reward_model import RewardModel
|
||||
|
||||
__all__ = ['Actor', 'Critic', 'RewardModel', 'LM']
|
65
applications/Chat/coati/models/base/actor.py
Normal file
65
applications/Chat/coati/models/base/actor.py
Normal file
@@ -0,0 +1,65 @@
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..generation import generate
|
||||
from ..lora import LoRAModule
|
||||
from ..utils import log_probs_from_logits
|
||||
|
||||
|
||||
class Actor(LoRAModule):
    """
    Actor model base class.

    Wraps a language model so that it can sample sequences (``generate``)
    and score the sampled tokens (``forward``).

    Args:
        model (nn.Module): Actor Model.
        lora_rank (int): LoRA rank.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self, model: nn.Module, lora_rank: int = 0, lora_train_bias: str = 'none') -> None:
        super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
        self.model = model
        self.convert_to_lora()

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        return_action_mask: bool = True,
        **kwargs
    ) -> Tuple[torch.LongTensor, Optional[torch.LongTensor], Optional[torch.BoolTensor]]:
        """Sample sequences and build the accompanying masks.

        Args:
            input_ids (torch.Tensor): Prompt token ids; dim 1 is read as the prompt length.
            return_action_mask (bool): When False the third returned item is ``None``.
            **kwargs: Forwarded to the module-level ``generate``. ``pad_token_id`` and
                ``eos_token_id`` are additionally read here to build the masks.

        Returns:
            A 3-tuple ``(sequences, attention_mask, action_mask)``.
            ``attention_mask`` is ``None`` unless ``pad_token_id`` was given.
            ``action_mask`` has one column per generated token and is True for
            every generated token that is not preceded by a generated EOS
            (so the first EOS itself is included).
        """
        sequences = generate(self.model, input_ids, **kwargs)

        attention_mask = None
        pad_token_id = kwargs.get('pad_token_id', None)
        if pad_token_id is not None:
            attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
        if not return_action_mask:
            return sequences, attention_mask, None

        input_len = input_ids.size(1)
        num_actions = sequences.size(1) - input_len
        if num_actions <= 0:
            # Nothing was generated (prompt already at max_length). Slicing with
            # ``-0`` would return the whole mask, so return an explicit empty one.
            empty_mask = torch.zeros((sequences.size(0), 0), dtype=torch.bool, device=sequences.device)
            return sequences, attention_mask, empty_mask

        eos_token_id = kwargs.get('eos_token_id', None)
        if eos_token_id is None:
            action_mask = torch.ones_like(sequences, dtype=torch.bool)
        else:
            # left padding may be applied, only mask action
            action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
            # shift right by one so the EOS token itself stays unmasked; the
            # prompt columns are padded in and cleared right below
            action_mask = F.pad(action_mask, (1 + input_len, -1), value=True)
        action_mask[:, :input_len] = False
        action_mask = action_mask[:, 1:]
        return sequences, attention_mask, action_mask[:, -num_actions:]

    def forward(self,
                sequences: torch.LongTensor,
                num_actions: int,
                attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Returns action log probs

        The logits at position ``t`` are paired with the token at ``t + 1``
        and only the last ``num_actions`` columns are returned.
        """
        output = self.model(sequences, attention_mask=attention_mask)
        logits = output['logits']
        log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
        return log_probs[:, -num_actions:]

    def get_base_model(self):
        """Return the wrapped model."""
        return self.model
|
54
applications/Chat/coati/models/base/critic.py
Normal file
54
applications/Chat/coati/models/base/critic.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..lora import LoRAModule
|
||||
from ..utils import masked_mean
|
||||
|
||||
|
||||
class Critic(LoRAModule):
    """
    Critic model base class.

    Args:
        model (nn.Module): Critic model.
        value_head (nn.Module): Value head to get value.
        lora_rank (int): LoRA rank.
        lora_train_bias (str): LoRA bias training mode.
        use_action_mask (bool): When True and an ``action_mask`` is passed to
            ``forward``, only the positions before the actions are averaged.
    """

    def __init__(
        self,
        model: nn.Module,
        value_head: nn.Module,
        lora_rank: int = 0,
        lora_train_bias: str = 'none',
        use_action_mask: bool = False,
    ) -> None:

        super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
        self.model = model
        self.value_head = value_head
        self.use_action_mask = use_action_mask
        self.convert_to_lora()

    def forward(self,
                sequences: torch.LongTensor,
                action_mask: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return one value per sequence.

        Args:
            sequences (torch.LongTensor): Token ids fed to the wrapped model.
            action_mask (Optional[torch.Tensor]): Only its size along dim 1 is
                used, as the number of trailing action positions to drop.
            attention_mask (Optional[torch.Tensor]): Forwarded to the model and,
                on the action-mask path, used to weight the remaining positions.

        Returns:
            torch.Tensor: Per-token values reduced along dim 1.
        """
        outputs = self.model(sequences, attention_mask=attention_mask)
        last_hidden_states = outputs['last_hidden_state']

        # assumes the value head emits a trailing dim of size 1 — TODO confirm
        values = self.value_head(last_hidden_states).squeeze(-1)

        if action_mask is not None and self.use_action_mask:
            num_actions = action_mask.size(1)
            values = values[:, :-num_actions]
            if attention_mask is None:
                # No mask supplied: every remaining position counts equally.
                # (Previously this path crashed on ``None[:, :-num_actions]``.)
                return values.mean(dim=1)
            prompt_mask = attention_mask[:, :-num_actions]
            return masked_mean(values, prompt_mask, dim=1)

        # default path: average over every position except the last one
        return values[:, :-1].mean(dim=1)
|
30
applications/Chat/coati/models/base/lm.py
Normal file
30
applications/Chat/coati/models/base/lm.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..generation import generate
|
||||
from .actor import Actor
|
||||
|
||||
|
||||
class LM(Actor):
    """
    Language model base class.

    Reuses the ``Actor`` construction logic but returns log-probabilities
    over the whole vocabulary from ``forward``.

    Args:
        model (nn.Module): Language Model.
        lora_rank (int): LoRA rank.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self, model: nn.Module, lora_rank: int = 0, lora_train_bias: str = 'none') -> None:
        super().__init__(model=model, lora_rank=lora_rank, lora_train_bias=lora_train_bias)

    def forward(self, sequences: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Returns output log probs

        The model's ``logits`` are normalised with a log-softmax over the last dim.
        """
        model_output = self.model(sequences, attention_mask=attention_mask)
        return F.log_softmax(model_output['logits'], dim=-1)
|
41
applications/Chat/coati/models/base/reward_model.py
Normal file
41
applications/Chat/coati/models/base/reward_model.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..lora import LoRAModule
|
||||
|
||||
|
||||
class RewardModel(LoRAModule):
    """
    Reward model base class.

    Args:
        model (nn.Module): Reward model.
        value_head (nn.Module): Value head to get reward score. When omitted, a
            linear head sized from ``model.config`` is created.
        lora_rank (int): LoRA rank.
        lora_train_bias (str): LoRA bias training mode.

    Raises:
        ValueError: If a supplied ``value_head`` has ``out_features != 1``.
    """

    def __init__(self,
                 model: nn.Module,
                 value_head: Optional[nn.Module] = None,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
        self.model = model
        self.convert_to_lora()

        if value_head is not None:
            if value_head.out_features != 1:
                raise ValueError("The value head of reward model's output dim should be 1!")
            self.value_head = value_head
        else:
            # ``n_embd`` is a GPT-2 specific config name; the Bloom/Deberta
            # wrappers in this package read ``hidden_size``. Prefer that and
            # keep ``n_embd`` as the fallback so existing GPT-2 usage still works.
            hidden_size = getattr(model.config, 'hidden_size', None)
            if hidden_size is None:
                hidden_size = model.config.n_embd
            self.value_head = nn.Linear(hidden_size, 1)

    def forward(self, sequences: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return one reward score per sequence.

        The value head is applied to the model's ``last_hidden_state``, the
        last position is dropped, and the rest is averaged along dim 1.
        """
        outputs = self.model(sequences, attention_mask=attention_mask)
        last_hidden_states = outputs['last_hidden_state']
        values = self.value_head(last_hidden_states)[:, :-1]
        value = values.mean(dim=1).squeeze(1)    # ensure shape is (B)
        return value
|
6
applications/Chat/coati/models/bloom/__init__.py
Normal file
6
applications/Chat/coati/models/bloom/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from .bloom_actor import BLOOMActor
|
||||
from .bloom_critic import BLOOMCritic
|
||||
from .bloom_lm import BLOOMLM
|
||||
from .bloom_rm import BLOOMRM
|
||||
|
||||
__all__ = ['BLOOMActor', 'BLOOMCritic', 'BLOOMRM', 'BLOOMLM']
|
35
applications/Chat/coati/models/bloom/bloom_actor.py
Normal file
35
applications/Chat/coati/models/bloom/bloom_actor.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from transformers import BloomConfig, BloomForCausalLM, BloomModel
|
||||
|
||||
from ..base import Actor
|
||||
|
||||
|
||||
class BLOOMActor(Actor):
    """
    BLOOM Actor model.

    Builds a ``BloomForCausalLM`` and hands it to ``Actor``.

    Args:
        pretrained (str): Pretrained model name or path. Takes precedence over ``config``.
        config (BloomConfig): Model config. Used only when ``pretrained`` is None;
            a default ``BloomConfig()`` is used when both are None.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): LoRA rank.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[BloomConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        if pretrained is not None:
            model = BloomForCausalLM.from_pretrained(pretrained)
        elif config is not None:
            model = BloomForCausalLM(config)
        else:
            model = BloomForCausalLM(BloomConfig())
        if checkpoint:
            model.gradient_checkpointing_enable()
        super().__init__(model, lora_rank, lora_train_bias)
|
38
applications/Chat/coati/models/bloom/bloom_critic.py
Normal file
38
applications/Chat/coati/models/bloom/bloom_critic.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import BloomConfig, BloomForCausalLM, BloomModel
|
||||
|
||||
from ..base import Critic
|
||||
|
||||
|
||||
class BLOOMCritic(Critic):
    """
    BLOOM Critic model.

    Builds a ``BloomModel`` plus a linear value head of width
    ``model.config.hidden_size`` and hands both to ``Critic``.

    Args:
        pretrained (str): Pretrained model name or path. Takes precedence over ``config``.
        config (BloomConfig): Model config. Used only when ``pretrained`` is None;
            a default ``BloomConfig()`` is used when both are None.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): LoRA rank.
        lora_train_bias (str): LoRA bias training mode.
        **kwargs: Extra keyword arguments forwarded to ``Critic.__init__``.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[BloomConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none',
                 **kwargs) -> None:
        if pretrained is not None:
            model = BloomModel.from_pretrained(pretrained)
        elif config is not None:
            model = BloomModel(config)
        else:
            model = BloomModel(BloomConfig())
        if checkpoint:
            model.gradient_checkpointing_enable()
        value_head = nn.Linear(model.config.hidden_size, 1)
        super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
|
35
applications/Chat/coati/models/bloom/bloom_lm.py
Normal file
35
applications/Chat/coati/models/bloom/bloom_lm.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from transformers import BloomConfig, BloomForCausalLM, BloomModel
|
||||
|
||||
from ..base import LM
|
||||
|
||||
|
||||
class BLOOMLM(LM):
    """
    BLOOM language model.

    Builds a ``BloomForCausalLM`` and hands it to ``LM``.

    Args:
        pretrained (str): Pretrained model name or path. Takes precedence over ``config``.
        config (BloomConfig): Model config. Used only when ``pretrained`` is None;
            a default ``BloomConfig()`` is used when both are None.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): LoRA rank.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[BloomConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        if pretrained is not None:
            model = BloomForCausalLM.from_pretrained(pretrained)
        elif config is not None:
            model = BloomForCausalLM(config)
        else:
            model = BloomForCausalLM(BloomConfig())
        if checkpoint:
            model.gradient_checkpointing_enable()
        super().__init__(model, lora_rank, lora_train_bias)
|
37
applications/Chat/coati/models/bloom/bloom_rm.py
Normal file
37
applications/Chat/coati/models/bloom/bloom_rm.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch.nn as nn
|
||||
from transformers import BloomConfig, BloomForCausalLM, BloomModel
|
||||
|
||||
from ..base import RewardModel
|
||||
|
||||
|
||||
class BLOOMRM(RewardModel):
    """
    BLOOM Reward model.

    Builds a ``BloomModel`` plus a linear value head and hands both to
    ``RewardModel``.

    Args:
        pretrained (str): Pretrained model name or path. Takes precedence over ``config``.
        config (BloomConfig): Model config. Used only when ``pretrained`` is None;
            a default ``BloomConfig()`` is used when both are None.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): LoRA rank.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[BloomConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        if pretrained is not None:
            model = BloomModel.from_pretrained(pretrained)
        elif config is not None:
            model = BloomModel(config)
        else:
            model = BloomModel(BloomConfig())
        if checkpoint:
            model.gradient_checkpointing_enable()
        hidden_size = model.config.hidden_size
        value_head = nn.Linear(hidden_size, 1)
        # re-initialise the head weights with std = 1 / (hidden_size + 1)
        value_head.weight.data.normal_(mean=0.0, std=1 / (hidden_size + 1))
        super().__init__(model, value_head, lora_rank, lora_train_bias)
|
4
applications/Chat/coati/models/deberta/__init__.py
Normal file
4
applications/Chat/coati/models/deberta/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .deberta_critic import DebertaCritic
|
||||
from .deberta_rm import DebertaRM
|
||||
|
||||
__all__ = ['DebertaCritic', 'DebertaRM']
|
36
applications/Chat/coati/models/deberta/deberta_critic.py
Normal file
36
applications/Chat/coati/models/deberta/deberta_critic.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch.nn as nn
|
||||
from transformers import DebertaV2Config, DebertaV2Model
|
||||
|
||||
from ..base import Critic
|
||||
|
||||
|
||||
class DebertaCritic(Critic):
    """
    Deberta Critic model.

    Builds a ``DebertaV2Model`` plus a linear value head of width
    ``model.config.hidden_size`` and hands both to ``Critic``.

    Args:
        pretrained (str): Pretrained model name or path. Takes precedence over ``config``.
        config (DebertaV2Config): Model config. Used only when ``pretrained`` is None;
            a default ``DebertaV2Config()`` is used when both are None.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): Rank of the LoRA decomposition.
        lora_train_bias (str): LoRA bias training mode.
        **kwargs: Extra keyword arguments forwarded to ``Critic.__init__``
            (accepted for parity with ``BLOOMCritic``).
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[DebertaV2Config] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none',
                 **kwargs) -> None:
        if pretrained is not None:
            model = DebertaV2Model.from_pretrained(pretrained)
        elif config is not None:
            model = DebertaV2Model(config)
        else:
            model = DebertaV2Model(DebertaV2Config())
        if checkpoint:
            model.gradient_checkpointing_enable()
        value_head = nn.Linear(model.config.hidden_size, 1)
        super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
|
37
applications/Chat/coati/models/deberta/deberta_rm.py
Normal file
37
applications/Chat/coati/models/deberta/deberta_rm.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch.nn as nn
|
||||
from transformers import DebertaV2Config, DebertaV2Model
|
||||
|
||||
from ..base import RewardModel
|
||||
|
||||
|
||||
class DebertaRM(RewardModel):
    """
    Deberta Reward model.

    Builds a ``DebertaV2Model`` plus a linear value head and hands both to
    ``RewardModel``.

    Args:
        pretrained (str): Pretrained model name or path. Takes precedence over ``config``.
        config (DebertaV2Config): Model config. Used only when ``pretrained`` is None;
            a default ``DebertaV2Config()`` is used when both are None.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): Rank of the LoRA decomposition.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[DebertaV2Config] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        if pretrained is not None:
            model = DebertaV2Model.from_pretrained(pretrained)
        elif config is not None:
            model = DebertaV2Model(config)
        else:
            model = DebertaV2Model(DebertaV2Config())
        if checkpoint:
            model.gradient_checkpointing_enable()
        hidden_size = model.config.hidden_size
        value_head = nn.Linear(hidden_size, 1)
        # re-initialise the head weights with std = 1 / (hidden_size + 1)
        value_head.weight.data.normal_(mean=0.0, std=1 / (hidden_size + 1))
        super().__init__(model, value_head, lora_rank, lora_train_bias)
|
146
applications/Chat/coati/models/generation.py
Normal file
146
applications/Chat/coati/models/generation.py
Normal file
@@ -0,0 +1,146 @@
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
|
||||
try:
|
||||
from transformers.generation_logits_process import (
|
||||
LogitsProcessorList,
|
||||
TemperatureLogitsWarper,
|
||||
TopKLogitsWarper,
|
||||
TopPLogitsWarper,
|
||||
)
|
||||
except ImportError:
|
||||
from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper
|
||||
|
||||
|
||||
def prepare_logits_processor(top_k: Optional[int] = None,
                             top_p: Optional[float] = None,
                             temperature: Optional[float] = None) -> LogitsProcessorList:
    """Assemble the logits warpers requested by the sampling options.

    Warpers are appended in the order temperature, top-k, top-p. An option
    is skipped when it is ``None`` or set to its neutral value
    (``temperature == 1.0``, ``top_k == 0``, ``top_p >= 1.0``).
    """
    warpers = LogitsProcessorList()

    wants_temperature = temperature is not None and temperature != 1.0
    wants_top_k = top_k is not None and top_k != 0
    wants_top_p = top_p is not None and top_p < 1.0

    if wants_temperature:
        warpers.append(TemperatureLogitsWarper(temperature))
    if wants_top_k:
        warpers.append(TopKLogitsWarper(top_k))
    if wants_top_p:
        warpers.append(TopPLogitsWarper(top_p))
    return warpers
|
||||
|
||||
|
||||
def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool:
    """Tell whether every sequence has finished generating.

    Args:
        unfinished_sequences (torch.Tensor): Per-sequence flags; a sequence
            counts as finished when its flag is 0.

    Returns:
        bool: True when the largest flag is 0. A plain Python ``bool`` is
        returned (not a 0-dim tensor), matching the annotation.
    """
    if dist.is_initialized() and dist.get_world_size() > 1:
        # consider DP: reduce a copy across ranks so the caller's tensor is
        # left untouched and all ranks reach the same decision
        unfinished_sequences = unfinished_sequences.clone()
        dist.all_reduce(unfinished_sequences)
    return bool(unfinished_sequences.max() == 0)
|
||||
|
||||
|
||||
def sample(model: nn.Module,
           input_ids: torch.Tensor,
           max_length: int,
           early_stopping: bool = False,
           eos_token_id: Optional[int] = None,
           pad_token_id: Optional[int] = None,
           top_k: Optional[int] = None,
           top_p: Optional[float] = None,
           temperature: Optional[float] = None,
           prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
           update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
           **model_kwargs) -> torch.Tensor:
    """Extend ``input_ids`` by multinomial sampling, one token per step.

    Args:
        model (nn.Module): Called with the prepared inputs; its output must
            expose ``'logits'``.
        input_ids (torch.Tensor): Prompt ids; returned unchanged when its
            length along dim 1 is already ``>= max_length``.
        max_length (int): Upper bound on the returned sequence length.
        early_stopping (bool): Stop as soon as every sequence has produced EOS.
        eos_token_id (Optional[int]): End-of-sequence id. Requires ``pad_token_id``.
        pad_token_id (Optional[int]): Id written after a sequence has finished.
        top_k, top_p, temperature: Passed to ``prepare_logits_processor``.
        prepare_inputs_fn: Builds the model inputs from ``input_ids`` and
            ``model_kwargs``; defaults to ``{'input_ids': input_ids}``.
        update_model_kwargs_fn: Refreshes ``model_kwargs`` from each step's outputs.
        **model_kwargs: Extra state threaded through the two callbacks.

    Returns:
        torch.Tensor: ``input_ids`` with the sampled tokens appended.

    Raises:
        ValueError: If ``eos_token_id`` is set but ``pad_token_id`` is not.
    """
    if input_ids.size(1) >= max_length:
        return input_ids

    # Loop-invariant validation: fail before the first forward pass instead of
    # re-checking inside every iteration.
    if eos_token_id is not None and pad_token_id is None:
        raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")

    logits_processor = prepare_logits_processor(top_k, top_p, temperature)
    # 1 while a sequence is still generating, 0 once it has emitted EOS
    unfinished_sequences = input_ids.new_ones(input_ids.shape[0])

    for _ in range(input_ids.size(1), max_length):
        if prepare_inputs_fn is not None:
            model_inputs = prepare_inputs_fn(input_ids, **model_kwargs)
        else:
            model_inputs = {'input_ids': input_ids}
        outputs = model(**model_inputs)

        next_token_logits = outputs['logits'][:, -1, :]
        # pre-process distribution
        next_token_logits = logits_processor(input_ids, next_token_logits)
        # sample
        probs = torch.softmax(next_token_logits, dim=-1, dtype=torch.float)
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

        # finished sentences should have their next token be a padding token
        if eos_token_id is not None:
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

        # update generated ids, model inputs for next step
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        if update_model_kwargs_fn is not None:
            model_kwargs = update_model_kwargs_fn(outputs, **model_kwargs)

        # if eos_token was found in one sentence, set sentence to finished
        if eos_token_id is not None:
            unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())

        # stop when each sentence is finished if early_stopping=True
        if early_stopping and _is_sequence_finished(unfinished_sequences):
            break

    return input_ids
|
||||
|
||||
|
||||
def generate(model: nn.Module,
             input_ids: torch.Tensor,
             max_length: int,
             num_beams: int = 1,
             do_sample: bool = True,
             early_stopping: bool = False,
             eos_token_id: Optional[int] = None,
             pad_token_id: Optional[int] = None,
             top_k: Optional[int] = None,
             top_p: Optional[float] = None,
             temperature: Optional[float] = None,
             prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
             update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
             **model_kwargs) -> torch.Tensor:
    """Generate token sequence. The returned sequence is input_ids + generated_tokens.

    Only single-beam sampling (``num_beams == 1`` and ``do_sample is True``) is
    implemented; it is delegated to ``sample``.

    Args:
        model (nn.Module): model
        input_ids (torch.Tensor): input sequence
        max_length (int): max length of the returned sequence
        num_beams (int, optional): number of beams. Defaults to 1.
        do_sample (bool, optional): whether to do sample. Defaults to True.
        early_stopping (bool, optional): if True, the sequence length may be smaller than max_length due to finding eos. Defaults to False.
        eos_token_id (Optional[int], optional): end of sequence token id. Defaults to None.
        pad_token_id (Optional[int], optional): pad token id. Defaults to None.
        top_k (Optional[int], optional): the number of highest probability vocabulary tokens to keep for top-k-filtering. Defaults to None.
        top_p (Optional[float], optional): If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. Defaults to None.
        temperature (Optional[float], optional): The value used to modulate the next token probabilities. Defaults to None.
        prepare_inputs_fn (Optional[Callable[[torch.Tensor, Any], dict]], optional): Function to preprocess model inputs. Arguments of this function should be input_ids and model_kwargs. Defaults to None.
        update_model_kwargs_fn (Optional[Callable[[dict, Any], dict]], optional): Function to update model_kwargs based on outputs. Arguments of this function should be outputs and model_kwargs. Defaults to None.

    Raises:
        NotImplementedError: For greedy search (one beam, ``do_sample is False``)
            and beam search (several beams, ``do_sample is False``).
        ValueError: For any other combination of ``num_beams`` and ``do_sample``.
    """
    single_beam = (num_beams == 1)
    multi_beam = (num_beams > 1)

    if single_beam and do_sample is False:
        # greedy search
        raise NotImplementedError

    if single_beam and do_sample is True:
        # multinomial sampling
        sampling_options = dict(early_stopping=early_stopping,
                                eos_token_id=eos_token_id,
                                pad_token_id=pad_token_id,
                                top_k=top_k,
                                top_p=top_p,
                                temperature=temperature,
                                prepare_inputs_fn=prepare_inputs_fn,
                                update_model_kwargs_fn=update_model_kwargs_fn)
        return sample(model, input_ids, max_length, **sampling_options, **model_kwargs)

    if multi_beam and do_sample is False:
        # beam search
        raise NotImplementedError

    raise ValueError("Unsupported generation mode")
|
92
applications/Chat/coati/models/generation_utils.py
Normal file
92
applications/Chat/coati/models/generation_utils.py
Normal file
@@ -0,0 +1,92 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def gpt_prepare_inputs_fn(input_ids: torch.Tensor, past: Optional[torch.Tensor] = None, **kwargs) -> dict:
    """Build the keyword arguments for one GPT forward step.

    When ``past`` is truthy, ``input_ids``, ``token_type_ids`` and the derived
    ``position_ids`` are cut down to their last column. ``position_ids`` are
    derived from ``attention_mask`` only when a mask is given and no
    ``position_ids`` kwarg was passed; in every other case they are ``None``.
    """

    def _last_column(tensor: torch.Tensor) -> torch.Tensor:
        # keep only the final step, preserving a trailing axis of size 1
        return tensor[:, -1].unsqueeze(-1)

    token_type_ids = kwargs.get("token_type_ids", None)
    # only last token for inputs_ids if past is defined in kwargs
    if past:
        input_ids = _last_column(input_ids)
        if token_type_ids is not None:
            token_type_ids = _last_column(token_type_ids)

    attention_mask = kwargs.get("attention_mask", None)

    position_ids = None
    if attention_mask is not None and kwargs.get("position_ids", None) is None:
        # create position_ids on the fly for batch generation:
        # running count of attended tokens, with masked slots set to 1
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past:
            position_ids = _last_column(position_ids)

    model_inputs = {}
    model_inputs["input_ids"] = input_ids
    model_inputs["past_key_values"] = past
    model_inputs["use_cache"] = kwargs.get("use_cache")
    model_inputs["position_ids"] = position_ids
    model_inputs["attention_mask"] = attention_mask
    model_inputs["token_type_ids"] = token_type_ids
    return model_inputs
|
||||
|
||||
|
||||
def update_model_kwargs_fn(outputs: dict, **model_kwargs) -> dict:
    """Advance ``model_kwargs`` by one generation step.

    ``past`` is set to ``outputs["past_key_values"]`` (or ``None`` when that key
    is absent). If present, ``token_type_ids`` gets its last column repeated
    and ``attention_mask`` gets a column of ones appended.
    """
    has_cache = "past_key_values" in outputs
    model_kwargs["past"] = outputs["past_key_values"] if has_cache else None

    # update token_type_ids with last value
    if "token_type_ids" in model_kwargs:
        previous_types = model_kwargs["token_type_ids"]
        repeated_last = previous_types[:, -1].unsqueeze(-1)
        model_kwargs["token_type_ids"] = torch.cat([previous_types, repeated_last], dim=-1)

    # update attention mask
    if "attention_mask" in model_kwargs:
        previous_mask = model_kwargs["attention_mask"]
        new_column = previous_mask.new_ones((previous_mask.shape[0], 1))
        model_kwargs["attention_mask"] = torch.cat([previous_mask, new_column], dim=-1)

    return model_kwargs
|
||||
|
||||
|
||||
def opt_prepare_inputs_fn(input_ids: torch.Tensor,
                          past: Optional[torch.Tensor] = None,
                          attention_mask: Optional[torch.Tensor] = None,
                          use_cache: Optional[bool] = None,
                          **kwargs) -> dict:
    """Build the keyword arguments for one OPT forward step.

    A missing ``attention_mask`` is replaced by ones shaped like the full
    ``input_ids``. When ``past`` is truthy only the last column of
    ``input_ids`` is passed on. Extra ``kwargs`` are ignored.
    """
    # the default mask must be built before input_ids is trimmed
    mask = input_ids.new_ones(input_ids.shape) if attention_mask is None else attention_mask

    step_ids = input_ids[:, -1:] if past else input_ids

    return dict(
        input_ids=step_ids,
        attention_mask=mask,
        past_key_values=past,
        use_cache=use_cache,
    )
|
||||
|
||||
|
||||
def bloom_prepare_inputs_fn(input_ids: torch.Tensor,
                            past: Optional[torch.Tensor] = None,
                            attention_mask: Optional[torch.Tensor] = None,
                            use_cache: Optional[bool] = None,
                            **kwargs) -> dict:
    """Build the keyword arguments for one BLOOM forward step.

    A missing ``attention_mask`` is replaced by ones shaped like the full
    ``input_ids``. When ``past`` is truthy only the last column of
    ``input_ids`` is passed on. Extra ``kwargs`` are ignored.
    """
    if attention_mask is None:
        # built from the untrimmed ids so it covers every position so far
        attention_mask = input_ids.new_ones(input_ids.shape)

    model_inputs = {"input_ids": input_ids}
    if past:
        model_inputs["input_ids"] = input_ids[:, -1:]
    model_inputs["attention_mask"] = attention_mask
    model_inputs["past_key_values"] = past
    model_inputs["use_cache"] = use_cache
    return model_inputs
|
6
applications/Chat/coati/models/gpt/__init__.py
Normal file
6
applications/Chat/coati/models/gpt/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from .gpt_actor import GPTActor
|
||||
from .gpt_critic import GPTCritic
|
||||
from .gpt_lm import GPTLM
|
||||
from .gpt_rm import GPTRM
|
||||
|
||||
__all__ = ['GPTActor', 'GPTCritic', 'GPTRM', 'GPTLM']
|
35
applications/Chat/coati/models/gpt/gpt_actor.py
Normal file
35
applications/Chat/coati/models/gpt/gpt_actor.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from typing import Optional
|
||||
|
||||
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
|
||||
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
|
||||
|
||||
from ..base import Actor
|
||||
|
||||
|
||||
class GPTActor(Actor):
    """
    GPT Actor model.

    Builds a ``GPT2LMHeadModel`` and hands it to ``Actor``.

    Args:
        pretrained (str): Pretrained model name or path. Takes precedence over ``config``.
        config (GPT2Config): Model config. Used only when ``pretrained`` is None;
            a default ``GPT2Config()`` is used when both are None.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): Rank of the LoRA layer.
        lora_train_bias (str): Bias training strategy for the LoRA layer.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[GPT2Config] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        if pretrained is not None:
            model = GPT2LMHeadModel.from_pretrained(pretrained)
        else:
            # fall back to a default config when none was supplied
            model_config = GPT2Config() if config is None else config
            model = GPT2LMHeadModel(model_config)

        if checkpoint:
            model.gradient_checkpointing_enable()

        super().__init__(model, lora_rank, lora_train_bias)
|
37
applications/Chat/coati/models/gpt/gpt_critic.py
Normal file
37
applications/Chat/coati/models/gpt/gpt_critic.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch.nn as nn
|
||||
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
|
||||
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
|
||||
|
||||
from ..base import Critic
|
||||
|
||||
|
||||
class GPTCritic(Critic):
    """
    GPT Critic model.

    Builds a ``GPT2Model`` plus a linear value head of width
    ``model.config.n_embd`` and hands both to ``Critic``.

    Args:
        pretrained (str): Pretrained model name or path. Takes precedence over ``config``.
        config (GPT2Config): Model config. Used only when ``pretrained`` is None;
            a default ``GPT2Config()`` is used when both are None.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): Rank of the LoRA decomposition.
        lora_train_bias (str): LoRA bias training mode.
        **kwargs: Extra keyword arguments forwarded to ``Critic.__init__``
            (accepted for parity with ``BLOOMCritic``).
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[GPT2Config] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none',
                 **kwargs) -> None:
        if pretrained is not None:
            model = GPT2Model.from_pretrained(pretrained)
        elif config is not None:
            model = GPT2Model(config)
        else:
            model = GPT2Model(GPT2Config())
        if checkpoint:
            model.gradient_checkpointing_enable()
        value_head = nn.Linear(model.config.n_embd, 1)
        super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
|
35
applications/Chat/coati/models/gpt/gpt_lm.py
Normal file
35
applications/Chat/coati/models/gpt/gpt_lm.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from typing import Optional
|
||||
|
||||
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
|
||||
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
|
||||
|
||||
from ..base import LM
|
||||
|
||||
|
||||
class GPTLM(LM):
    """
    GPT language model.

    Builds a ``GPT2LMHeadModel`` and hands it to ``LM``.

    Args:
        pretrained (str): Pretrained model name or path. Takes precedence over ``config``.
        config (GPT2Config): Model config. Used only when ``pretrained`` is None;
            a default ``GPT2Config()`` is used when both are None.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): Rank of the LoRA layer.
        lora_train_bias (str): Bias training strategy for the LoRA layer.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[GPT2Config] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        if pretrained is not None:
            model = GPT2LMHeadModel.from_pretrained(pretrained)
        else:
            # fall back to a default config when none was supplied
            model_config = GPT2Config() if config is None else config
            model = GPT2LMHeadModel(model_config)

        if checkpoint:
            model.gradient_checkpointing_enable()

        super().__init__(model, lora_rank, lora_train_bias)
|
39
applications/Chat/coati/models/gpt/gpt_rm.py
Normal file
39
applications/Chat/coati/models/gpt/gpt_rm.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch.nn as nn
|
||||
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
|
||||
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
|
||||
|
||||
from ..base import RewardModel
|
||||
|
||||
|
||||
class GPTRM(RewardModel):
    """
    GPT-2 reward model.

    Wraps a headless ``GPT2Model`` and adds a linear value head that maps each
    hidden state (size ``n_embd``) to a single scalar.

    Args:
        pretrained (str): Pretrained model name or path.
        config (GPT2Config): Model config, only used when ``pretrained`` is None.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): Rank of the low-rank approximation.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[GPT2Config] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        if pretrained is None:
            # No checkpoint requested: build from the given config, or from defaults.
            backbone = GPT2Model(GPT2Config() if config is None else config)
        else:
            backbone = GPT2Model.from_pretrained(pretrained)

        if checkpoint:
            backbone.gradient_checkpointing_enable()

        embed_dim = backbone.config.n_embd
        head = nn.Linear(embed_dim, 1)
        # Small-variance init of the head weights: std shrinks with the hidden size.
        head.weight.data.normal_(mean=0.0, std=1 / (embed_dim + 1))
        super().__init__(backbone, head, lora_rank, lora_train_bias)
|
6
applications/Chat/coati/models/llama/__init__.py
Normal file
6
applications/Chat/coati/models/llama/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from .llama_actor import LlamaActor
|
||||
from .llama_critic import LlamaCritic
|
||||
from .llama_lm import LlamaLM
|
||||
from .llama_rm import LlamaRM
|
||||
|
||||
__all__ = ['LlamaActor', 'LlamaCritic', 'LlamaRM', 'LlamaLM']
|
38
applications/Chat/coati/models/llama/llama_actor.py
Normal file
38
applications/Chat/coati/models/llama/llama_actor.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM
|
||||
|
||||
from ..base import Actor
|
||||
|
||||
|
||||
class LlamaActor(Actor):
    """
    Llama actor model.

    The backbone is a ``LlamaForCausalLM``. It is loaded from ``pretrained``
    when that is given; otherwise it is freshly built from ``config``, or from
    a default ``LlamaConfig`` when no config is supplied either.

    Args:
        pretrained (str): Pretrained model name or path.
        config (LlamaConfig): Model config, only used when ``pretrained`` is None.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): LoRA rank.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[LlamaConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        if pretrained is None:
            # No checkpoint requested: build from the given config, or from defaults.
            backbone = LlamaForCausalLM(LlamaConfig() if config is None else config)
        else:
            backbone = LlamaForCausalLM.from_pretrained(pretrained)

        if checkpoint:
            backbone.gradient_checkpointing_enable()

        super().__init__(backbone, lora_rank, lora_train_bias)
|
42
applications/Chat/coati/models/llama/llama_critic.py
Normal file
42
applications/Chat/coati/models/llama/llama_critic.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM
|
||||
|
||||
from ..base import Critic
|
||||
|
||||
|
||||
class LlamaCritic(Critic):
    """
    Llama critic model.

    Wraps a headless ``LlamaModel`` and adds a linear value head that maps each
    hidden state (size ``hidden_size``) to a single scalar. The backbone is
    loaded from ``pretrained`` when that is given; otherwise it is freshly
    built from ``config``, or from a default ``LlamaConfig``.

    Args:
        pretrained (str): Pretrained model name or path.
        config (LlamaConfig): Model config, only used when ``pretrained`` is None.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): LoRA rank.
        lora_train_bias (str): LoRA bias training mode.
        **kwargs: Extra keyword arguments forwarded to ``Critic.__init__``.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[LlamaConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none',
                 **kwargs) -> None:
        # Imported here so this block does not depend on the module-level
        # import list, which does not name LlamaModel.
        from transformers import LlamaModel

        # Use the bare transformer (no LM head), like the GPT/OPT critics and
        # the Llama reward model: the value head below consumes hidden states
        # of size `hidden_size`, not vocabulary logits.
        if pretrained is not None:
            model = LlamaModel.from_pretrained(pretrained)
        elif config is not None:
            model = LlamaModel(config)
        else:
            model = LlamaModel(LlamaConfig())

        if checkpoint:
            model.gradient_checkpointing_enable()

        value_head = nn.Linear(model.config.hidden_size, 1)

        super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
|
40
applications/Chat/coati/models/llama/llama_lm.py
Normal file
40
applications/Chat/coati/models/llama/llama_lm.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from typing import Optional
|
||||
|
||||
from transformers import LlamaConfig, LlamaForCausalLM
|
||||
|
||||
from ..base import LM
|
||||
|
||||
|
||||
class LlamaLM(LM):
    """
    Llama language model wrapper.

    The backbone is a ``LlamaForCausalLM``. It is loaded from ``pretrained``
    when that is given; otherwise it is freshly built from ``config``, or from
    a default ``LlamaConfig`` when no config is supplied either.

    Args:
        pretrained (str): Pretrained model name or path.
        config (LlamaConfig): Model config, only used when ``pretrained`` is None.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): LoRA rank.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[LlamaConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        if pretrained is None:
            # No checkpoint requested: build from the given config, or from defaults.
            backbone = LlamaForCausalLM(LlamaConfig() if config is None else config)
        else:
            backbone = LlamaForCausalLM.from_pretrained(pretrained)

        if checkpoint:
            backbone.gradient_checkpointing_enable()

        super().__init__(backbone, lora_rank, lora_train_bias)

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        """Call the wrapped model, passing ``attention_mask``, ``labels`` and any extra keyword arguments through."""
        outputs = self.model(input_ids, attention_mask=attention_mask, labels=labels, **kwargs)
        return outputs
|
40
applications/Chat/coati/models/llama/llama_rm.py
Normal file
40
applications/Chat/coati/models/llama/llama_rm.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch.nn as nn
|
||||
from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel
|
||||
|
||||
from ..base import RewardModel
|
||||
|
||||
|
||||
class LlamaRM(RewardModel):
    """
    Llama reward model.

    Wraps a headless ``LlamaModel`` and adds a linear value head that maps each
    hidden state (size ``hidden_size``) to a single scalar.

    Args:
        pretrained (str): Pretrained model name or path.
        config (LlamaConfig): Model config, only used when ``pretrained`` is None.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): LoRA rank.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[LlamaConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        if pretrained is None:
            # No checkpoint requested: build from the given config, or from defaults.
            backbone = LlamaModel(LlamaConfig() if config is None else config)
        else:
            backbone = LlamaModel.from_pretrained(pretrained)

        if checkpoint:
            backbone.gradient_checkpointing_enable()

        hidden_dim = backbone.config.hidden_size
        head = nn.Linear(hidden_dim, 1)
        # Small-variance init of the head weights: std shrinks with the hidden size.
        head.weight.data.normal_(mean=0.0, std=1 / (hidden_dim + 1))

        super().__init__(backbone, head, lora_rank, lora_train_bias)
|
129
applications/Chat/coati/models/lora.py
Normal file
129
applications/Chat/coati/models/lora.py
Normal file
@@ -0,0 +1,129 @@
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import loralib as lora
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class LoraLinear(lora.LoRALayer, nn.Module):
    """Linear layer with a LoRA low-rank update, built around an existing weight/bias.

    Replaces in-place ops with out-of-place ops in the forward pass to fit
    gemini. Used to convert a ``torch.nn.Linear`` to a LoRA-enabled layer
    without copying its parameters.

    When ``r > 0`` the original weight is frozen and two trainable matrices
    ``lora_A`` (r x in_features) and ``lora_B`` (out_features x r) are added;
    the effective weight is ``weight + (lora_B @ lora_A) * lora_alpha / r``.

    When ``merge_weights`` is True, switching to eval mode folds the low-rank
    update into ``weight`` and switching back to train mode removes it again.
    The LoRA parameters are kept after merging so that this round trip is
    always possible.

    Args:
        weight (nn.Parameter): Weight of the layer being wrapped; its shape is
            read as (out_features, in_features).
        bias (nn.Parameter, optional): Bias of the layer being wrapped.
        r (int): LoRA rank. 0 disables the low-rank update.
        lora_alpha (int): Numerator of the LoRA scaling factor ``lora_alpha / r``.
        lora_dropout (float): Dropout setting passed to ``lora.LoRALayer``.
        fan_in_fan_out (bool): Set this to True if the layer to replace stores
            weight like (fan_in, fan_out).
        merge_weights (bool): Merge the LoRA update into ``weight`` in eval mode.
    """

    def __init__(
            self,
            weight: nn.Parameter,
            bias: Optional[nn.Parameter],
            r: int = 0,
            lora_alpha: int = 1,
            lora_dropout: float = 0.,
            fan_in_fan_out: bool = False,    # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
            merge_weights: bool = True,
    ):
        nn.Module.__init__(self)
        # Presumably sets self.r, self.lora_alpha, self.lora_dropout,
        # self.merged and self.merge_weights, which are read below -- confirm
        # against the installed loralib version.
        lora.LoRALayer.__init__(self,
                                r=r,
                                lora_alpha=lora_alpha,
                                lora_dropout=lora_dropout,
                                merge_weights=merge_weights)
        self.weight = weight
        self.bias = bias

        # NOTE(review): the shape is read as (out, in) even when fan_in_fan_out
        # is True; confirm that callers passing fan_in_fan_out=True get the
        # orientation they expect.
        out_features, in_features = weight.shape
        self.in_features = in_features
        self.out_features = out_features

        self.fan_in_fan_out = fan_in_fan_out
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
            self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        self.reset_parameters()
        if fan_in_fan_out:
            self.weight.data = self.weight.data.T

    def reset_parameters(self):
        """Re-initialise the LoRA matrices (no-op when the layer has none)."""
        if hasattr(self, 'lora_A'):
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def _transpose(self, w: torch.Tensor) -> torch.Tensor:
        """Return ``w`` transposed when the layer stores weights as (fan_in, fan_out)."""
        return w.T if self.fan_in_fan_out else w

    def _lora_delta(self) -> torch.Tensor:
        """Return the scaled low-rank update in the same orientation as ``weight``."""
        return self._transpose(self.lora_B @ self.lora_A) * self.scaling

    def train(self, mode: bool = True):
        """Set train/eval mode and merge or un-merge the LoRA update accordingly.

        ``nn.Module.eval()`` and a parent module's ``eval()`` both dispatch to
        ``train(False)``, so both directions are handled here.

        Returns:
            The module itself, like ``nn.Module.train``.
        """
        nn.Module.train(self, mode)
        if self.merge_weights:
            if mode and self.merged:
                # Make sure that the weights are not merged while training
                if self.r > 0:
                    self.weight.data -= self._lora_delta()
                self.merged = False
            elif not mode and not self.merged:
                # Merge the weights and mark it
                if self.r > 0:
                    self.weight.data += self._lora_delta()
                self.merged = True
        return self

    def eval(self):
        """Switch to eval mode; equivalent to ``train(False)``."""
        return self.train(False)

    def forward(self, x: torch.Tensor):
        """Apply the linear map, adding the LoRA branch when it is not merged into ``weight``."""
        result = F.linear(x, self._transpose(self.weight), bias=self.bias)
        if self.r > 0 and not self.merged:
            # Out-of-place add on purpose (see class docstring).
            result = result + (self.lora_dropout(x) @ self.lora_A.t() @ self.lora_B.t()) * self.scaling
        return result
|
||||
|
||||
|
||||
def lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
    """Build a ``LoraLinear`` that reuses the weight and bias of ``linear``.

    Args:
        linear (nn.Linear): Layer whose parameters are wrapped.
        lora_rank (int): LoRA rank; must not exceed ``linear.in_features``.

    Returns:
        LoraLinear: The wrapping layer, created with ``merge_weights=False``.
    """
    rank_fits = lora_rank <= linear.in_features
    assert rank_fits, f'LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})'
    return LoraLinear(linear.weight, linear.bias, r=lora_rank, merge_weights=False)
|
||||
|
||||
|
||||
def convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None:
    """Replace every ``nn.Linear`` below ``module`` with its LoRA wrapper, in place.

    Args:
        module (nn.Module): Root of the subtree to convert.
        lora_rank (int): LoRA rank passed to ``lora_linear_wrapper``.
    """
    for child_name, child in module.named_children():
        if not isinstance(child, nn.Linear):
            # Not a linear layer itself: look for linear layers further down.
            convert_to_lora_recursively(child, lora_rank)
            continue
        setattr(module, child_name, lora_linear_wrapper(child, lora_rank))
|
||||
|
||||
|
||||
class LoRAModule(nn.Module):
    """Base class for modules that can be converted to LoRA.

    All derived classes should call `convert_to_lora()` at the bottom of `__init__()`.
    That call converts every torch.nn.Linear layer into a LoraLinear layer.

    Args:
        lora_rank (int, optional): LoRA rank. 0 means LoRA is not applied. Defaults to 0.
        lora_train_bias (str, optional): Whether LoRA train biases.
            'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers.
            Defaults to 'none'.
    """

    def __init__(self, lora_rank: int = 0, lora_train_bias: str = 'none') -> None:
        super().__init__()
        self.lora_rank = lora_rank
        self.lora_train_bias = lora_train_bias

    def convert_to_lora(self) -> None:
        """Convert the linear layers of this module to LoRA layers when ``lora_rank`` is positive."""
        if self.lora_rank > 0:
            convert_to_lora_recursively(self, self.lora_rank)
            # Leave only the LoRA parameters (and biases, per lora_train_bias) trainable.
            lora.mark_only_lora_as_trainable(self, self.lora_train_bias)
|
117
applications/Chat/coati/models/loss.py
Normal file
117
applications/Chat/coati/models/loss.py
Normal file
@@ -0,0 +1,117 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .utils import masked_mean
|
||||
|
||||
|
||||
class GPTLMLoss(nn.Module):
    """
    GPT language model loss.

    Cross-entropy between the logits at each position and the label at the
    following position.
    """

    def __init__(self):
        super().__init__()
        self.loss = nn.CrossEntropyLoss()

    def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Return the cross-entropy of ``logits[..., :-1, :]`` against ``labels[..., 1:]``."""
        vocab_size = logits.size(-1)
        # Drop the last position's logits and the first label so the two line up.
        predictions = logits[..., :-1, :].contiguous()
        targets = labels[..., 1:].contiguous()
        # Flatten the tokens
        flat_predictions = predictions.view(-1, vocab_size)
        flat_targets = targets.view(-1)
        return self.loss(flat_predictions, flat_targets)
|
||||
|
||||
|
||||
class PolicyLoss(nn.Module):
    """
    Clipped surrogate policy loss for PPO.

    Args:
        clip_eps (float): The probability ratio is clamped to
            ``[1 - clip_eps, 1 + clip_eps]``.
    """

    def __init__(self, clip_eps: float = 0.2) -> None:
        super().__init__()
        self.clip_eps = clip_eps

    def forward(self,
                log_probs: torch.Tensor,
                old_log_probs: torch.Tensor,
                advantages: torch.Tensor,
                action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return the mean of the negated minimum of the clipped and unclipped objectives.

        When ``action_mask`` is given, the loss is first reduced with
        ``masked_mean`` and then averaged.
        """
        prob_ratio = torch.exp(log_probs - old_log_probs)
        lower, upper = 1 - self.clip_eps, 1 + self.clip_eps
        unclipped = prob_ratio * advantages
        clipped = torch.clamp(prob_ratio, lower, upper) * advantages
        loss = -torch.min(unclipped, clipped)
        if action_mask is not None:
            loss = masked_mean(loss, action_mask)
        return loss.mean()
|
||||
|
||||
|
||||
class ValueLoss(nn.Module):
    """
    Clipped value loss for PPO.

    Args:
        clip_eps (float): Maximum distance the new values may move away from
            ``old_values`` in the clipped term.
    """

    def __init__(self, clip_eps: float = 0.4) -> None:
        super().__init__()
        self.clip_eps = clip_eps

    def forward(self,
                values: torch.Tensor,
                old_values: torch.Tensor,
                reward: torch.Tensor,
                action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return the mean of the larger of the clipped and unclipped squared errors.

        ``action_mask`` is accepted but not used by this loss.
        """
        step = (values - old_values).clamp(-self.clip_eps, self.clip_eps)
        clipped_values = old_values + step
        clipped_error = (clipped_values - reward)**2
        raw_error = (values - reward)**2
        return torch.max(clipped_error, raw_error).mean()
|
||||
|
||||
|
||||
class PPOPtxActorLoss(nn.Module):
    """
    PPO-ptx actor loss: the PPO policy loss plus a weighted pretraining
    (language modelling) loss.

    Args:
        policy_clip_eps (float): Clip range passed to ``PolicyLoss``.
        pretrain_coef (float): Weight of the pretraining loss term.
        pretrain_loss_fn: Callable ``(lm_logits, lm_input_ids) -> loss``.
            Defaults to a fresh ``GPTLMLoss()`` per instance.
    """

    def __init__(self, policy_clip_eps: float = 0.2, pretrain_coef: float = 0.0, pretrain_loss_fn=None) -> None:
        super().__init__()
        self.pretrain_coef = pretrain_coef
        self.policy_loss_fn = PolicyLoss(clip_eps=policy_clip_eps)
        # Build the default here rather than in the signature, so it is not
        # created once at definition time and shared between all instances.
        self.pretrain_loss_fn = GPTLMLoss() if pretrain_loss_fn is None else pretrain_loss_fn

    def forward(self,
                log_probs: torch.Tensor,
                old_log_probs: torch.Tensor,
                advantages: torch.Tensor,
                lm_logits: torch.Tensor,
                lm_input_ids: torch.Tensor,
                action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return ``policy_loss + pretrain_coef * lm_loss``."""
        policy_loss = self.policy_loss_fn(log_probs, old_log_probs, advantages, action_mask=action_mask)
        lm_loss = self.pretrain_loss_fn(lm_logits, lm_input_ids)
        return policy_loss + self.pretrain_coef * lm_loss
|
||||
|
||||
|
||||
class LogSigLoss(nn.Module):
    """
    Pairwise Loss for Reward Model
    Details: https://arxiv.org/abs/2203.02155

    Computes ``-log(sigmoid(chosen_reward - reject_reward))`` averaged over
    all elements.
    """

    def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
        """Return the mean negative log-sigmoid of the reward margin."""
        # logsigmoid evaluates log(sigmoid(x)) in a numerically stable way;
        # composing the two calls underflows to log(0) = -inf when the
        # rejected reward is far above the chosen one.
        log_probs = nn.functional.logsigmoid(chosen_reward - reject_reward)
        return -log_probs.mean()
|
||||
|
||||
|
||||
class LogExpLoss(nn.Module):
    """
    Pairwise Loss for Reward Model
    Details: https://arxiv.org/abs/2204.05862

    Computes ``log(1 + exp(reject_reward - chosen_reward))`` averaged over all
    elements.
    """

    def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
        """Return the mean softplus of the negated reward margin."""
        # softplus(x) == log(1 + exp(x)), but does not overflow to inf when
        # the rejected reward is far above the chosen one.
        return nn.functional.softplus(reject_reward - chosen_reward).mean()
|
6
applications/Chat/coati/models/opt/__init__.py
Normal file
6
applications/Chat/coati/models/opt/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from .opt_actor import OPTActor
|
||||
from .opt_critic import OPTCritic
|
||||
from .opt_lm import OPTLM
|
||||
from .opt_rm import OPTRM
|
||||
|
||||
__all__ = ['OPTActor', 'OPTCritic', 'OPTRM', 'OPTLM']
|
35
applications/Chat/coati/models/opt/opt_actor.py
Normal file
35
applications/Chat/coati/models/opt/opt_actor.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from typing import Optional
|
||||
|
||||
from transformers.models.opt.configuration_opt import OPTConfig
|
||||
from transformers.models.opt.modeling_opt import OPTForCausalLM
|
||||
|
||||
from ..base import Actor
|
||||
|
||||
|
||||
class OPTActor(Actor):
    """
    OPT actor model.

    The backbone is an ``OPTForCausalLM``. It is loaded from ``pretrained``
    when that is given; otherwise it is freshly built from ``config``, or from
    a default ``OPTConfig`` when no config is supplied either.

    Args:
        pretrained (str): Pretrained model name or path.
        config (OPTConfig): Model config, only used when ``pretrained`` is None.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): Rank of the low-rank approximation.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[OPTConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        if pretrained is None:
            # No checkpoint requested: build from the given config, or from defaults.
            backbone = OPTForCausalLM(OPTConfig() if config is None else config)
        else:
            backbone = OPTForCausalLM.from_pretrained(pretrained)

        if checkpoint:
            backbone.gradient_checkpointing_enable()

        super().__init__(backbone, lora_rank, lora_train_bias)
|
38
applications/Chat/coati/models/opt/opt_critic.py
Normal file
38
applications/Chat/coati/models/opt/opt_critic.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch.nn as nn
|
||||
from transformers.models.opt.configuration_opt import OPTConfig
|
||||
from transformers.models.opt.modeling_opt import OPTModel
|
||||
|
||||
from ..base import Critic
|
||||
|
||||
|
||||
class OPTCritic(Critic):
    """
    OPT critic model.

    Wraps a headless ``OPTModel`` and adds a linear value head that maps each
    hidden state (size ``word_embed_proj_dim``) to a single scalar.

    Args:
        pretrained (str): Pretrained model name or path.
        config (OPTConfig): Model config, only used when ``pretrained`` is None.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): Rank of the low-rank approximation.
        lora_train_bias (str): LoRA bias training mode.
        **kwargs: Extra keyword arguments forwarded to ``Critic.__init__``.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[OPTConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none',
                 **kwargs) -> None:
        if pretrained is None:
            # No checkpoint requested: build from the given config, or from defaults.
            backbone = OPTModel(OPTConfig() if config is None else config)
        else:
            backbone = OPTModel.from_pretrained(pretrained)

        if checkpoint:
            backbone.gradient_checkpointing_enable()

        head = nn.Linear(backbone.config.word_embed_proj_dim, 1)
        super().__init__(backbone, head, lora_rank, lora_train_bias, **kwargs)
|
35
applications/Chat/coati/models/opt/opt_lm.py
Normal file
35
applications/Chat/coati/models/opt/opt_lm.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from typing import Optional
|
||||
|
||||
from transformers.models.opt.configuration_opt import OPTConfig
|
||||
from transformers.models.opt.modeling_opt import OPTForCausalLM
|
||||
|
||||
from ..base import LM
|
||||
|
||||
|
||||
class OPTLM(LM):
    """
    OPT language model wrapper.

    The backbone is an ``OPTForCausalLM``. It is loaded from ``pretrained``
    when that is given; otherwise it is freshly built from ``config``, or from
    a default ``OPTConfig`` when no config is supplied either.

    Args:
        pretrained (str): Pretrained model name or path.
        config (OPTConfig): Model config, only used when ``pretrained`` is None.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): Rank of the low-rank approximation.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[OPTConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        if pretrained is None:
            # No checkpoint requested: build from the given config, or from defaults.
            backbone = OPTForCausalLM(OPTConfig() if config is None else config)
        else:
            backbone = OPTForCausalLM.from_pretrained(pretrained)

        if checkpoint:
            backbone.gradient_checkpointing_enable()

        super().__init__(backbone, lora_rank, lora_train_bias)
|
38
applications/Chat/coati/models/opt/opt_rm.py
Normal file
38
applications/Chat/coati/models/opt/opt_rm.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch.nn as nn
|
||||
from transformers import OPTConfig, OPTModel
|
||||
|
||||
from ..base import RewardModel
|
||||
|
||||
|
||||
class OPTRM(RewardModel):
    """
    OPT reward model.

    Wraps a headless ``OPTModel`` and adds a linear value head that maps each
    hidden state (size ``word_embed_proj_dim``) to a single scalar.

    Args:
        pretrained (str): Pretrained model name or path.
        config (OPTConfig): Model config, only used when ``pretrained`` is None.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): Rank of the low-rank approximation.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[OPTConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        if pretrained is None:
            # No checkpoint requested: build from the given config, or from defaults.
            backbone = OPTModel(OPTConfig() if config is None else config)
        else:
            backbone = OPTModel.from_pretrained(pretrained)

        if checkpoint:
            backbone.gradient_checkpointing_enable()

        proj_dim = backbone.config.word_embed_proj_dim
        head = nn.Linear(proj_dim, 1)
        # Small-variance init of the head weights: std shrinks with the hidden size.
        head.weight.data.normal_(mean=0.0, std=1 / (proj_dim + 1))
        super().__init__(backbone, head, lora_rank, lora_train_bias)
|
92
applications/Chat/coati/models/utils.py
Normal file
92
applications/Chat/coati/models/utils.py
Normal file
@@ -0,0 +1,92 @@
|
||||
from typing import Optional, Union
|
||||
|
||||
import loralib as lora
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def compute_approx_kl(log_probs: torch.Tensor,
                      log_probs_base: torch.Tensor,
                      action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    Compute the approximate KL divergence between two distributions.
    Schulman blog: http://joschu.net/blog/kl-approx.html

    With ``d = log_probs - log_probs_base`` the per-element estimate is
    ``(exp(d) - 1) - d``; it is then averaged over dim 1, restricted to
    ``action_mask`` when one is given.

    NOTE(review): which direction of KL this estimates depends on which
    distribution the tokens were sampled from -- confirm against the caller.

    Args:
        log_probs: Log probabilities of the new distribution.
        log_probs_base: Log probabilities of the base distribution.
        action_mask: Mask for actions.
    """
    delta = log_probs - log_probs_base
    per_token_kl = (delta.exp() - 1) - delta
    if action_mask is None:
        return per_token_kl.mean(dim=1)
    return masked_mean(per_token_kl, action_mask, dim=1)
|
||||
|
||||
|
||||
def compute_reward(r: Union[torch.Tensor, float],
                   kl_coef: float,
                   log_probs: torch.Tensor,
                   log_probs_base: torch.Tensor,
                   action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    Subtract a KL penalty from the reward.

    Args:
        r: Raw reward.
        kl_coef: Weight of the KL penalty. When it is not positive, ``r`` is
            returned unchanged.
        log_probs: Log probabilities of the new distribution.
        log_probs_base: Log probabilities of the base distribution.
        action_mask: Mask for actions.

    Returns:
        ``r - kl_coef * approx_kl``, or ``r`` itself when ``kl_coef <= 0``.
    """
    if kl_coef <= 0.0:
        return r
    penalty = kl_coef * compute_approx_kl(log_probs, log_probs_base, action_mask=action_mask)
    return r - penalty
|
||||
|
||||
|
||||
def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Return the log-probability that ``logits`` assigns to each entry of ``labels``.

    A log-softmax is taken over the last dimension of ``logits`` and the entry
    indexed by the matching label is picked out, so the result has the shape
    of ``labels``.
    """
    label_index = labels.unsqueeze(-1)
    picked = F.log_softmax(logits, dim=-1).gather(dim=-1, index=label_index)
    return picked.squeeze(-1)
|
||||
|
||||
|
||||
def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
    """Return the mean of ``tensor`` along ``dim``, weighted by ``mask``.

    The masked sum is divided by the mask sum plus 1e-8, so a fully masked
    slice yields 0 rather than a division by zero.
    """
    masked_total = (tensor * mask).sum(dim=dim)
    mask_total = mask.sum(dim=dim)
    return masked_total / (mask_total + 1e-8)
|
||||
|
||||
|
||||
def masked_normalize(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1, eps: float = 1e-8) -> torch.Tensor:
    """Standardise ``tensor`` along ``dim`` using only the positions selected by ``mask``.

    The mean and variance are mask-weighted statistics along ``dim``; the
    variance is clamped to at least ``eps`` before taking the reciprocal
    square root.

    Args:
        tensor: Values to normalise.
        mask: Weights (typically 0/1) with a shape broadcastable to ``tensor``.
        dim: Dimension along which statistics are computed.
        eps: Lower bound on the variance.

    Returns:
        A tensor with the same shape as ``tensor``.
    """
    tensor = tensor * mask
    # keepdim=True keeps the reduced axis so the statistics broadcast back
    # along `dim`; without it a (B, T) input minus a (B,) mean either fails
    # or lines up with the wrong axis.
    mask_total = mask.sum(dim=dim, keepdim=True) + 1e-8
    mean = (tensor * mask).sum(dim=dim, keepdim=True) / mask_total
    mean_centered = tensor - mean
    var = (mean_centered**2 * mask).sum(dim=dim, keepdim=True) / mask_total
    return mean_centered * var.clamp(min=eps).rsqrt()
|
||||
|
||||
|
||||
def normalize(tensor: torch.Tensor, dim: int = 0, eps: float = 1e-8) -> torch.Tensor:
    """Standardise ``tensor`` along ``dim`` to zero mean and unit (biased) variance.

    The variance is clamped to at least ``eps`` before taking the reciprocal
    square root.

    Args:
        tensor: Values to normalise.
        dim: Dimension along which statistics are computed.
        eps: Lower bound on the variance.

    Returns:
        A tensor with the same shape as ``tensor``.
    """
    # keepdim=True keeps the reduced axis so the statistics broadcast back
    # along `dim`; without it only dim=0 lines up correctly.
    mean = tensor.mean(dim, keepdim=True)
    mean_centered = tensor - mean
    var = (mean_centered**2).mean(dim, keepdim=True)
    return mean_centered * var.clamp(min=eps).rsqrt()
|
||||
|
||||
|
||||
def convert_to_lora(model: nn.Module,
                    input_size: int,
                    output_size: int,
                    lora_rank: int = 16,
                    lora_alpha: int = 1,
                    lora_dropout: float = 0.,
                    fan_in_fan_out: bool = False,
                    merge_weights: bool = True):
    """Replace every ``nn.Linear`` child inside ``model`` with a new ``lora.Linear``, in place.

    Every replacement is created with the same ``input_size`` and
    ``output_size``; parameters of the replaced layers are not copied over.
    ``model`` itself is never replaced, even if it is an ``nn.Linear``.

    Args:
        model: Module whose linear sub-layers are replaced.
        input_size: ``in_features`` of every new layer.
        output_size: ``out_features`` of every new layer.
        lora_rank: LoRA rank; must not exceed ``min(input_size, output_size)``.
        lora_alpha: LoRA scaling numerator, passed to ``lora.Linear``.
        lora_dropout: LoRA dropout, passed to ``lora.Linear``.
        fan_in_fan_out: Passed to ``lora.Linear``.
        merge_weights: Passed to ``lora.Linear``.

    Raises:
        ValueError: If ``lora_rank`` is larger than ``min(input_size, output_size)``.
    """
    if lora_rank > min(input_size, output_size):
        raise ValueError(f"LoRA rank {lora_rank} must be less or equal than {min(input_size, output_size)}")

    # Snapshot the tree first: the replacement has to be set on the *parent*
    # under the child's own name, and doing that while modules() is still
    # walking the tree would make it descend into the freshly inserted layers.
    for parent in list(model.modules()):
        for child_name, child in list(parent.named_children()):
            if isinstance(child, nn.Linear):
                setattr(
                    parent, child_name,
                    lora.Linear(input_size,
                                output_size,
                                r=lora_rank,
                                lora_alpha=lora_alpha,
                                lora_dropout=lora_dropout,
                                fan_in_fan_out=fan_in_fan_out,
                                merge_weights=merge_weights))
|
Reference in New Issue
Block a user