mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-22 18:09:06 +00:00
[chatgpt] add pre-trained model RoBERTa for RLHF stage 2 & 3 (#3223)
* Add RoBERTa for RLHF Stage 2 & 3 (test) RoBERTa for RLHF Stage 2 & 3 (still in testing) * Revert "Add RoBERTa for RLHF Stage 2 & 3 (test)" This reverts commit06741d894d
. * Add RoBERTa for RLHF stage 2 & 3 1. add roberta folder under model folder 2. add roberta option in train_reward_model.py 3. add some test in testci * add test for reward model training * Update test_ci.sh * Revert "Update test_ci.sh" This reverts commit 9c7352b81766f3177d31eeec0ec178a301df966a. * Add RoBERTa for RLHF Stage 2 & 3 (test) RoBERTa for RLHF Stage 2 & 3 (still in testing) * Revert "Add RoBERTa for RLHF Stage 2 & 3 (test)" This reverts commit06741d894d
. * Add RoBERTa for RLHF stage 2 & 3 1. add roberta folder under model folder 2. add roberta option in train_reward_model.py 3. add some test in testci * Update test_ci.sh * Revert "Update test_ci.sh" This reverts commit 9c7352b81766f3177d31eeec0ec178a301df966a. * update roberta with coati
This commit is contained in:
5
applications/Chat/coati/models/roberta/__init__.py
Normal file
5
applications/Chat/coati/models/roberta/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .roberta_actor import RoBERTaActor
|
||||
from .roberta_critic import RoBERTaCritic
|
||||
from .roberta_rm import RoBERTaRM
|
||||
|
||||
__all__ = ['RoBERTaActor', 'RoBERTaCritic', 'RoBERTaRM']
|
35
applications/Chat/coati/models/roberta/roberta_actor.py
Normal file
35
applications/Chat/coati/models/roberta/roberta_actor.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from typing import Optional
|
||||
|
||||
from transformers.models.roberta.configuration_roberta import RobertaConfig
|
||||
from transformers.models.roberta.modeling_roberta import RobertaForCausalLM
|
||||
|
||||
from ..base import Actor
|
||||
|
||||
class RoBERTaActor(Actor):
|
||||
"""
|
||||
RoBERTa Actor model.
|
||||
|
||||
Args:
|
||||
pretrained (str): Pretrained model name or path.
|
||||
config (RoBERTaConfig): Model config.
|
||||
checkpoint (bool): Enable gradient checkpointing.
|
||||
lora_rank (int): Rank of the low-rank approximation.
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self,
|
||||
pretrained: Optional[str] = None,
|
||||
config: Optional[RobertaConfig] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none') -> None:
|
||||
if pretrained is not None:
|
||||
model = RobertaForCausalLM.from_pretrained(pretrained)
|
||||
elif config is not None:
|
||||
model = RobertaForCausalLM(config)
|
||||
else:
|
||||
model = RobertaForCausalLM(RobertaConfig())
|
||||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
super().__init__(model, lora_rank, lora_train_bias)
|
38
applications/Chat/coati/models/roberta/roberta_critic.py
Normal file
38
applications/Chat/coati/models/roberta/roberta_critic.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch.nn as nn
|
||||
from transformers.models.roberta.configuration_roberta import RobertaConfig
|
||||
from transformers.models.roberta.modeling_roberta import RobertaModel
|
||||
|
||||
from ..base import Critic
|
||||
|
||||
|
||||
class RoBERTaCritic(Critic):
|
||||
"""
|
||||
RoBERTa Critic model.
|
||||
|
||||
Args:
|
||||
pretrained (str): Pretrained model name or path.
|
||||
config (RoBERTa Config): Model config.
|
||||
checkpoint (bool): Enable gradient checkpointing.
|
||||
lora_rank (int): Rank of the low-rank approximation.
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pretrained: Optional[str] = None,
|
||||
config: Optional[RobertaConfig] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none',
|
||||
**kwargs) -> None:
|
||||
if pretrained is not None:
|
||||
model = RobertaModel.from_pretrained(pretrained, add_pooling_layer=False)
|
||||
elif config is not None:
|
||||
model = RobertaModel(config)
|
||||
else:
|
||||
model = RobertaModel(RobertaConfig())
|
||||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
value_head = nn.Linear(model.config.hidden_size, 1)
|
||||
super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
|
39
applications/Chat/coati/models/roberta/roberta_rm.py
Normal file
39
applications/Chat/coati/models/roberta/roberta_rm.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch.nn as nn
|
||||
from transformers import RobertaConfig, RobertaModel
|
||||
|
||||
|
||||
from ..base import RewardModel
|
||||
|
||||
|
||||
class RoBERTaRM(RewardModel):
|
||||
"""
|
||||
RoBERTa Reward model.
|
||||
|
||||
Args:
|
||||
pretrained (str): Pretrained model name or path.
|
||||
config (RoBERTaConfig): Model config.
|
||||
checkpoint (bool): Enable gradient checkpointing.
|
||||
lora_rank (int): Rank of the low-rank approximation.
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pretrained: Optional[str] = None,
|
||||
config: Optional[RobertaConfig] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none') -> None:
|
||||
if pretrained is not None:
|
||||
model = RobertaModel.from_pretrained(pretrained, add_pooling_layer=False)
|
||||
elif config is not None:
|
||||
model = RobertaModel(config)
|
||||
else:
|
||||
model = RobertaModel(RobertaConfig())
|
||||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
value_head = nn.Linear(model.config.hidden_size, 1)
|
||||
value_head.weight.data.normal_(mean=0.0, std=1/(model.config.hidden_size + 1))
|
||||
super().__init__(model, value_head, lora_rank, lora_train_bias)
|
Reference in New Issue
Block a user