mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-21 09:29:47 +00:00
[chat] use official transformers and fix some issues (#4117)
* feat: remove on_learn_epoch fn as not used * revert: add _on_learn_epoch fn * feat: remove NaiveStrategy * test: update train_prompts tests * fix: remove prepare_llama_tokenizer_and_embedding * test: add lora arg * feat: remove roberta support in train_prompts due to runtime errs * feat: remove deberta & roberta in rm as not used * test: remove deberta and roberta tests * feat: remove deberta and roberta models as not used * fix: remove calls to roberta * fix: remove prepare_llama_tokenizer_and_embedding * chore: update transformers version * docs: update transformers version * fix: fix actor inference * fix: fix ci * feat: change llama pad token to unk * revert: revert ddp setup_distributed * fix: change llama pad token to unk * revert: undo unnecessary changes * fix: use pip to install transformers
This commit is contained in:
@@ -1,4 +0,0 @@
|
||||
from .deberta_critic import DebertaCritic
|
||||
from .deberta_rm import DebertaRM
|
||||
|
||||
__all__ = ['DebertaCritic', 'DebertaRM']
|
@@ -1,36 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch.nn as nn
|
||||
from transformers import DebertaV2Config, DebertaV2Model
|
||||
|
||||
from ..base import Critic
|
||||
|
||||
|
||||
class DebertaCritic(Critic):
|
||||
"""
|
||||
Deberta Critic model.
|
||||
|
||||
Args:
|
||||
pretrained (str): Pretrained model name or path.
|
||||
config (DebertaV2Config): Model config.
|
||||
checkpoint (bool): Enable gradient checkpointing.
|
||||
lora_rank (int): Rank of the LO-RA decomposition.
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pretrained: Optional[str] = None,
|
||||
config: Optional[DebertaV2Config] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none') -> None:
|
||||
if pretrained is not None:
|
||||
model = DebertaV2Model.from_pretrained(pretrained)
|
||||
elif config is not None:
|
||||
model = DebertaV2Model(config)
|
||||
else:
|
||||
model = DebertaV2Model(DebertaV2Config())
|
||||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
value_head = nn.Linear(model.config.hidden_size, 1)
|
||||
super().__init__(model, value_head, lora_rank, lora_train_bias)
|
@@ -1,37 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch.nn as nn
|
||||
from transformers import DebertaV2Config, DebertaV2Model
|
||||
|
||||
from ..base import RewardModel
|
||||
|
||||
|
||||
class DebertaRM(RewardModel):
|
||||
"""
|
||||
Deberta Reward model.
|
||||
|
||||
Args:
|
||||
pretrained (str): Pretrained model name or path.
|
||||
config (DebertaV2Config): Model config.
|
||||
checkpoint (bool): Enable gradient checkpointing.
|
||||
lora_rank (int): Rank of the LO-RA decomposition.
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pretrained: str = None,
|
||||
config: Optional[DebertaV2Config] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none') -> None:
|
||||
if pretrained is not None:
|
||||
model = DebertaV2Model.from_pretrained(pretrained)
|
||||
elif config is not None:
|
||||
model = DebertaV2Model(config)
|
||||
else:
|
||||
model = DebertaV2Model(DebertaV2Config())
|
||||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
value_head = nn.Linear(model.config.hidden_size, 1)
|
||||
value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.hidden_size + 1))
|
||||
super().__init__(model, value_head, lora_rank, lora_train_bias)
|
@@ -1,5 +0,0 @@
|
||||
from .roberta_actor import RoBERTaActor
|
||||
from .roberta_critic import RoBERTaCritic
|
||||
from .roberta_rm import RoBERTaRM
|
||||
|
||||
__all__ = ['RoBERTaActor', 'RoBERTaCritic', 'RoBERTaRM']
|
@@ -1,35 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
from transformers.models.roberta.configuration_roberta import RobertaConfig
|
||||
from transformers.models.roberta.modeling_roberta import RobertaForCausalLM
|
||||
|
||||
from ..base import Actor
|
||||
|
||||
class RoBERTaActor(Actor):
|
||||
"""
|
||||
RoBERTa Actor model.
|
||||
|
||||
Args:
|
||||
pretrained (str): Pretrained model name or path.
|
||||
config (RoBERTaConfig): Model config.
|
||||
checkpoint (bool): Enable gradient checkpointing.
|
||||
lora_rank (int): Rank of the low-rank approximation.
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self,
|
||||
pretrained: Optional[str] = None,
|
||||
config: Optional[RobertaConfig] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none') -> None:
|
||||
if pretrained is not None:
|
||||
model = RobertaForCausalLM.from_pretrained(pretrained)
|
||||
elif config is not None:
|
||||
model = RobertaForCausalLM(config)
|
||||
else:
|
||||
model = RobertaForCausalLM(RobertaConfig())
|
||||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
super().__init__(model, lora_rank, lora_train_bias)
|
@@ -1,38 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch.nn as nn
|
||||
from transformers.models.roberta.configuration_roberta import RobertaConfig
|
||||
from transformers.models.roberta.modeling_roberta import RobertaModel
|
||||
|
||||
from ..base import Critic
|
||||
|
||||
|
||||
class RoBERTaCritic(Critic):
|
||||
"""
|
||||
RoBERTa Critic model.
|
||||
|
||||
Args:
|
||||
pretrained (str): Pretrained model name or path.
|
||||
config (RoBERTa Config): Model config.
|
||||
checkpoint (bool): Enable gradient checkpointing.
|
||||
lora_rank (int): Rank of the low-rank approximation.
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pretrained: Optional[str] = None,
|
||||
config: Optional[RobertaConfig] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none',
|
||||
**kwargs) -> None:
|
||||
if pretrained is not None:
|
||||
model = RobertaModel.from_pretrained(pretrained, add_pooling_layer=False)
|
||||
elif config is not None:
|
||||
model = RobertaModel(config)
|
||||
else:
|
||||
model = RobertaModel(RobertaConfig())
|
||||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
value_head = nn.Linear(model.config.hidden_size, 1)
|
||||
super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
|
@@ -1,39 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch.nn as nn
|
||||
from transformers import RobertaConfig, RobertaModel
|
||||
|
||||
|
||||
from ..base import RewardModel
|
||||
|
||||
|
||||
class RoBERTaRM(RewardModel):
|
||||
"""
|
||||
RoBERTa Reward model.
|
||||
|
||||
Args:
|
||||
pretrained (str): Pretrained model name or path.
|
||||
config (RoBERTaConfig): Model config.
|
||||
checkpoint (bool): Enable gradient checkpointing.
|
||||
lora_rank (int): Rank of the low-rank approximation.
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pretrained: Optional[str] = None,
|
||||
config: Optional[RobertaConfig] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none') -> None:
|
||||
if pretrained is not None:
|
||||
model = RobertaModel.from_pretrained(pretrained, add_pooling_layer=False)
|
||||
elif config is not None:
|
||||
model = RobertaModel(config)
|
||||
else:
|
||||
model = RobertaModel(RobertaConfig())
|
||||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
value_head = nn.Linear(model.config.hidden_size, 1)
|
||||
value_head.weight.data.normal_(mean=0.0, std=1/(model.config.hidden_size + 1))
|
||||
super().__init__(model, value_head, lora_rank, lora_train_bias)
|
@@ -9,10 +9,8 @@ from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic
|
||||
from coati.models.gpt import GPTRM, GPTActor, GPTCritic
|
||||
from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
|
||||
from coati.models.opt import OPTRM, OPTActor, OPTCritic
|
||||
from coati.models.roberta import RoBERTaActor, RoBERTaCritic, RoBERTaRM
|
||||
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
|
||||
from coati.utils import prepare_llama_tokenizer_and_embedding
|
||||
from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer, RobertaTokenizer
|
||||
from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer
|
||||
|
||||
|
||||
def is_rank_0() -> bool:
|
||||
@@ -36,8 +34,6 @@ def get_actor_from_args(model: str, pretrained: str = None, config=None, lora_ra
|
||||
actor = OPTActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
|
||||
elif model == 'llama':
|
||||
actor = LlamaActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
|
||||
elif model == 'roberta':
|
||||
actor = RoBERTaActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
|
||||
else:
|
||||
raise ValueError(f'Unsupported actor model "{model}"')
|
||||
return actor
|
||||
@@ -52,8 +48,6 @@ def get_critic_from_args(model: str, pretrained: str = None, config=None, lora_r
|
||||
critic = OPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
|
||||
elif model == 'llama':
|
||||
critic = LlamaCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
|
||||
elif model == 'roberta':
|
||||
critic = RoBERTaCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
|
||||
else:
|
||||
raise ValueError(f'Unsupported reward model "{model}"')
|
||||
return critic
|
||||
@@ -68,8 +62,6 @@ def get_reward_model_from_args(model: str, pretrained: str = None, config=None):
|
||||
reward_model = OPTRM(pretrained=pretrained, config=config)
|
||||
elif model == 'llama':
|
||||
reward_model = LlamaRM(pretrained=pretrained, config=config)
|
||||
elif model == 'roberta':
|
||||
reward_model = RoBERTaRM(pretrained=pretrained, config=config)
|
||||
else:
|
||||
raise ValueError(f'Unsupported reward model "{model}"')
|
||||
return reward_model
|
||||
@@ -101,8 +93,6 @@ def get_tokenizer_from_args(model: str, **kwargs):
|
||||
elif model == 'llama':
|
||||
pretrain_path = kwargs["pretrain"]
|
||||
tokenizer = AutoTokenizer.from_pretrained(pretrain_path)
|
||||
elif model == 'roberta':
|
||||
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
|
||||
else:
|
||||
raise ValueError(f'Unsupported model "{model}"')
|
||||
|
||||
|
@@ -1,3 +0,0 @@
|
||||
from .tokenizer_utils import prepare_llama_tokenizer_and_embedding, smart_tokenizer_and_embedding_resize
|
||||
|
||||
__all__ = ['smart_tokenizer_and_embedding_resize', 'prepare_llama_tokenizer_and_embedding']
|
@@ -1,73 +0,0 @@
|
||||
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Dict
|
||||
|
||||
import transformers
|
||||
|
||||
DEFAULT_PAD_TOKEN = "[PAD]"
|
||||
DEFAULT_EOS_TOKEN = "</s>"
|
||||
DEFAULT_BOS_TOKEN = "</s>"
|
||||
DEFAULT_UNK_TOKEN = "</s>"
|
||||
|
||||
|
||||
def prepare_llama_tokenizer_and_embedding(
|
||||
tokenizer: transformers.PreTrainedTokenizer,
|
||||
model: transformers.PreTrainedModel,
|
||||
special_tokens_dict: Dict = dict(pad_token=DEFAULT_PAD_TOKEN),
|
||||
):
|
||||
"""prepare llama tokenizer and embedding.
|
||||
|
||||
"""
|
||||
|
||||
if tokenizer.pad_token is None:
|
||||
smart_tokenizer_and_embedding_resize(
|
||||
special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
|
||||
tokenizer=tokenizer,
|
||||
model=model,
|
||||
)
|
||||
|
||||
tokenizer.add_special_tokens({
|
||||
"eos_token": DEFAULT_EOS_TOKEN,
|
||||
"bos_token": DEFAULT_BOS_TOKEN,
|
||||
"unk_token": DEFAULT_UNK_TOKEN,
|
||||
})
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
def smart_tokenizer_and_embedding_resize(
|
||||
tokenizer: transformers.PreTrainedTokenizer,
|
||||
model: transformers.PreTrainedModel,
|
||||
special_tokens_dict: Dict = dict(pad_token=DEFAULT_PAD_TOKEN),
|
||||
):
|
||||
"""Resize tokenizer and embedding.
|
||||
|
||||
Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
|
||||
"""
|
||||
|
||||
if tokenizer.pad_token is None:
|
||||
num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
|
||||
|
||||
model.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
if num_new_tokens > 0:
|
||||
input_embeddings = model.get_input_embeddings().weight.data
|
||||
output_embeddings = model.get_output_embeddings().weight.data
|
||||
|
||||
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
|
||||
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
|
||||
|
||||
input_embeddings[-num_new_tokens:] = input_embeddings_avg
|
||||
output_embeddings[-num_new_tokens:] = output_embeddings_avg
|
Reference in New Issue
Block a user