mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 09:07:51 +00:00
Merge branch 'main' into feature/shardformer
This commit is contained in:
@@ -19,7 +19,7 @@ import torch
|
||||
from torch.utils.data import Dataset
|
||||
from tqdm import tqdm
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
from .utils import is_rank_0, jload
|
||||
@@ -71,6 +71,42 @@ def _preprocess(sources: Sequence[str],
|
||||
return sequences_token["input_ids"], labels, sequences_token["attention_mask"]
|
||||
|
||||
|
||||
def _preprocess_chatglm(sources: Sequence[str],
|
||||
targets: Sequence[str],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
max_length: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Preprocess the data by tokenizing.
|
||||
None for attention mask, ChatGLM will calculate attention mask according to input ids
|
||||
"""
|
||||
|
||||
labels = []
|
||||
input_ids = []
|
||||
for source, target in zip(sources, targets):
|
||||
source_id = tokenizer.encode(text=source, add_special_tokens=False)
|
||||
target_id = tokenizer.encode(text=target, add_special_tokens=False)
|
||||
input_id = tokenizer.build_inputs_with_special_tokens(source_id, target_id)
|
||||
# truncate
|
||||
sp_token_list = [tokenizer.gmask_token_id, tokenizer.bos_token_id]
|
||||
truncate_length = max(0, len(input_id) - max_length)
|
||||
input_id = input_id[truncate_length: ]
|
||||
if truncate_length == len(source_id) + 1:
|
||||
input_id = sp_token_list + input_id[1: ]
|
||||
elif truncate_length > len(source_id) + 1:
|
||||
input_id = sp_token_list + input_id[2: ]
|
||||
|
||||
context_length = input_id.index(tokenizer.bos_token_id)
|
||||
mask_position = context_length - 1
|
||||
label = [IGNORE_INDEX] * context_length + input_id[mask_position+1:]
|
||||
|
||||
pad_len = max_length - len(input_id)
|
||||
input_id = input_id + [tokenizer.pad_token_id] * pad_len
|
||||
input_ids.append(input_id)
|
||||
labels.append(label + [IGNORE_INDEX] * pad_len)
|
||||
return torch.tensor(input_ids), torch.tensor(labels), None
|
||||
|
||||
|
||||
class SFTDataset(Dataset):
|
||||
"""
|
||||
Dataset for sft model
|
||||
@@ -94,18 +130,25 @@ class SFTDataset(Dataset):
|
||||
data["completion"] + tokenizer.eos_token
|
||||
for data in tqdm(dataset, disable=not is_rank_0())
|
||||
]
|
||||
|
||||
self.input_ids, self.labels, self.attention_mask = \
|
||||
_preprocess(sources, targets, tokenizer, max_length)
|
||||
if isinstance(tokenizer, ChatGLMTokenizer):
|
||||
self.input_ids, self.labels, self.attention_mask = \
|
||||
_preprocess_chatglm(sources, targets, tokenizer, max_length)
|
||||
else:
|
||||
self.input_ids, self.labels, self.attention_mask = \
|
||||
_preprocess(sources, targets, tokenizer, max_length)
|
||||
|
||||
def __len__(self):
|
||||
length = self.input_ids.shape[0]
|
||||
return length
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return dict(input_ids=self.input_ids[idx],
|
||||
labels=self.labels[idx],
|
||||
attention_mask=self.attention_mask[idx])
|
||||
if self.attention_mask is not None:
|
||||
return dict(input_ids=self.input_ids[idx],
|
||||
labels=self.labels[idx],
|
||||
attention_mask=self.attention_mask[idx])
|
||||
else:
|
||||
return dict(input_ids=self.input_ids[idx],
|
||||
labels=self.labels[idx])
|
||||
|
||||
|
||||
class SupervisedDataset(Dataset):
|
||||
@@ -137,14 +180,22 @@ class SupervisedDataset(Dataset):
|
||||
]
|
||||
|
||||
logger.info("Tokenizing inputs... This may take some time...")
|
||||
self.input_ids, self.labels, self.attention_mask = \
|
||||
_preprocess(sources, targets, tokenizer, max_length)
|
||||
if isinstance(tokenizer, ChatGLMTokenizer):
|
||||
self.input_ids, self.labels, self.attention_mask = \
|
||||
_preprocess_chatglm(sources, targets, tokenizer, max_length)
|
||||
else:
|
||||
self.input_ids, self.labels, self.attention_mask = \
|
||||
_preprocess(sources, targets, tokenizer, max_length)
|
||||
|
||||
def __len__(self):
|
||||
length = self.input_ids.shape[0]
|
||||
return length
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return dict(input_ids=self.input_ids[idx],
|
||||
labels=self.labels[idx],
|
||||
attention_mask=self.attention_mask[idx])
|
||||
if self.attention_mask is not None:
|
||||
return dict(input_ids=self.input_ids[idx],
|
||||
labels=self.labels[idx],
|
||||
attention_mask=self.attention_mask[idx])
|
||||
else:
|
||||
return dict(input_ids=self.input_ids[idx],
|
||||
labels=self.labels[idx])
|
||||
|
3
applications/Chat/coati/models/chatglm/__init__.py
Normal file
3
applications/Chat/coati/models/chatglm/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .chatglm_actor import ChatGLMActor
|
||||
|
||||
__all__ = ['ChatGLMActor']
|
34
applications/Chat/coati/models/chatglm/chatglm_actor.py
Normal file
34
applications/Chat/coati/models/chatglm/chatglm_actor.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from .configuration_chatglm import ChatGLMConfig
|
||||
from .modeling_chatglm import ChatGLMForConditionalGeneration
|
||||
|
||||
from ..base import Actor
|
||||
|
||||
|
||||
class ChatGLMActor(Actor):
|
||||
"""
|
||||
ChatGLM Actor model.
|
||||
|
||||
Args:
|
||||
pretrained (str): Pretrained model name or path.
|
||||
config (ChatGLMConfig): Model config.
|
||||
checkpoint (bool): Enable gradient checkpointing.
|
||||
|
||||
do not support lora for now.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pretrained: str = None,
|
||||
config: Optional[ChatGLMConfig] = None,
|
||||
checkpoint: bool = False) -> None:
|
||||
if pretrained is not None:
|
||||
model = ChatGLMForConditionalGeneration.from_pretrained(pretrained)
|
||||
elif config is not None:
|
||||
model = ChatGLMForConditionalGeneration(config)
|
||||
else:
|
||||
model = ChatGLMForConditionalGeneration(ChatGLMConfig())
|
||||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
super().__init__(model, lora_rank=0, lora_train_bias='none')
|
446
applications/Chat/coati/models/chatglm/chatglm_tokenizer.py
Normal file
446
applications/Chat/coati/models/chatglm/chatglm_tokenizer.py
Normal file
@@ -0,0 +1,446 @@
|
||||
"""
|
||||
This code is copied from https://huggingface.co/THUDM/chatglm-6b/blob/main/tokenization_chatglm.py
|
||||
"""
|
||||
"""Tokenization classes for ChatGLM."""
|
||||
from typing import List, Optional, Union
|
||||
import os
|
||||
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
from transformers.utils import logging, PaddingStrategy
|
||||
from transformers.tokenization_utils_base import EncodedInput, BatchEncoding
|
||||
from typing import Dict
|
||||
import sentencepiece as spm
|
||||
import numpy as np
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
||||
"THUDM/chatglm-6b": 2048,
|
||||
}
|
||||
|
||||
|
||||
class TextTokenizer:
|
||||
def __init__(self, model_path):
|
||||
self.sp = spm.SentencePieceProcessor()
|
||||
self.sp.Load(model_path)
|
||||
self.num_tokens = self.sp.vocab_size()
|
||||
|
||||
def encode(self, text):
|
||||
return self.sp.EncodeAsIds(text)
|
||||
|
||||
def decode(self, ids: List[int]):
|
||||
return self.sp.DecodeIds(ids)
|
||||
|
||||
def tokenize(self, text):
|
||||
return self.sp.EncodeAsPieces(text)
|
||||
|
||||
def convert_tokens_to_string(self, tokens):
|
||||
return self.sp.DecodePieces(tokens)
|
||||
|
||||
def convert_tokens_to_ids(self, tokens):
|
||||
return [self.sp.PieceToId(token) for token in tokens]
|
||||
|
||||
def convert_token_to_id(self, token):
|
||||
return self.sp.PieceToId(token)
|
||||
|
||||
def convert_id_to_token(self, idx):
|
||||
return self.sp.IdToPiece(idx)
|
||||
|
||||
def __len__(self):
|
||||
return self.num_tokens
|
||||
|
||||
|
||||
class SPTokenizer:
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file,
|
||||
num_image_tokens=20000,
|
||||
max_blank_length=80,
|
||||
byte_fallback=True,
|
||||
):
|
||||
assert vocab_file is not None
|
||||
self.vocab_file = vocab_file
|
||||
self.num_image_tokens = num_image_tokens
|
||||
self.special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "<unused_0>", "<sop>", "<eop>", "<ENC>", "<dBLOCK>"]
|
||||
self.max_blank_length = max_blank_length
|
||||
self.byte_fallback = byte_fallback
|
||||
self.text_tokenizer = TextTokenizer(vocab_file)
|
||||
|
||||
def _get_text_tokenizer(self):
|
||||
return self.text_tokenizer
|
||||
|
||||
@staticmethod
|
||||
def get_blank_token(length: int):
|
||||
assert length >= 2
|
||||
return f"<|blank_{length}|>"
|
||||
|
||||
@staticmethod
|
||||
def get_tab_token():
|
||||
return f"<|tab|>"
|
||||
|
||||
@property
|
||||
def num_text_tokens(self):
|
||||
return self.text_tokenizer.num_tokens
|
||||
|
||||
@property
|
||||
def num_tokens(self):
|
||||
return self.num_image_tokens + self.num_text_tokens
|
||||
|
||||
@staticmethod
|
||||
def _encode_whitespaces(text: str, max_len: int = 80):
|
||||
text = text.replace("\t", SPTokenizer.get_tab_token())
|
||||
for i in range(max_len, 1, -1):
|
||||
text = text.replace(" " * i, SPTokenizer.get_blank_token(i))
|
||||
return text
|
||||
|
||||
def _preprocess(self, text: str, linebreak=True, whitespaces=True):
|
||||
if linebreak:
|
||||
text = text.replace("\n", "<n>")
|
||||
if whitespaces:
|
||||
text = self._encode_whitespaces(text, max_len=self.max_blank_length)
|
||||
return text
|
||||
|
||||
def encode(
|
||||
self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True
|
||||
) -> List[int]:
|
||||
"""
|
||||
@param text: Text to encode.
|
||||
@param linebreak: Whether to encode newline (\n) in text.
|
||||
@param whitespaces: Whether to encode multiple whitespaces or tab in text, useful for source code encoding.
|
||||
@param special_tokens: Whether to encode special token ([MASK], [gMASK], etc.) in text.
|
||||
@param add_dummy_prefix: Whether to add dummy blank space in the beginning.
|
||||
"""
|
||||
text = self._preprocess(text, linebreak, whitespaces)
|
||||
if not add_dummy_prefix:
|
||||
text = "<n>" + text
|
||||
tmp = self._get_text_tokenizer().encode(text)
|
||||
tokens = [x + self.num_image_tokens for x in tmp]
|
||||
return tokens if add_dummy_prefix else tokens[2:]
|
||||
|
||||
def postprocess(self, text):
|
||||
text = text.replace("<n>", "\n")
|
||||
text = text.replace(SPTokenizer.get_tab_token(), "\t")
|
||||
for i in range(2, self.max_blank_length + 1):
|
||||
text = text.replace(self.get_blank_token(i), " " * i)
|
||||
return text
|
||||
|
||||
def decode(self, text_ids: List[int]) -> str:
|
||||
ids = [int(_id) - self.num_image_tokens for _id in text_ids]
|
||||
ids = [_id for _id in ids if _id >= 0]
|
||||
text = self._get_text_tokenizer().decode(ids)
|
||||
text = self.postprocess(text)
|
||||
return text
|
||||
|
||||
def decode_tokens(self, tokens: List[str]) -> str:
|
||||
text = self._get_text_tokenizer().convert_tokens_to_string(tokens)
|
||||
text = self.postprocess(text)
|
||||
return text
|
||||
|
||||
def tokenize(
|
||||
self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True
|
||||
) -> List[str]:
|
||||
"""
|
||||
@param text: Text to encode.
|
||||
@param linebreak: Whether to encode newline (\n) in text.
|
||||
@param whitespaces: Whether to encode multiple whitespaces or tab in text, useful for source code encoding.
|
||||
@param special_tokens: Whether to encode special token ([MASK], [gMASK], etc.) in text.
|
||||
@param add_dummy_prefix: Whether to add dummy blank space in the beginning.
|
||||
"""
|
||||
text = self._preprocess(text, linebreak, whitespaces)
|
||||
if not add_dummy_prefix:
|
||||
text = "<n>" + text
|
||||
tokens = self._get_text_tokenizer().tokenize(text)
|
||||
return tokens if add_dummy_prefix else tokens[2:]
|
||||
|
||||
def __getitem__(self, x: Union[int, str]):
|
||||
if isinstance(x, int):
|
||||
if x < self.num_image_tokens:
|
||||
return "<image_{}>".format(x)
|
||||
else:
|
||||
return self.text_tokenizer.convert_id_to_token(x - self.num_image_tokens)
|
||||
elif isinstance(x, str):
|
||||
if x.startswith("<image_") and x.endswith(">") and x[7:-1].isdigit():
|
||||
return int(x[7:-1])
|
||||
else:
|
||||
return self.text_tokenizer.convert_token_to_id(x) + self.num_image_tokens
|
||||
else:
|
||||
raise ValueError("The key should be str or int.")
|
||||
|
||||
|
||||
class ChatGLMTokenizer(PreTrainedTokenizer):
|
||||
"""
|
||||
Construct a ChatGLM tokenizer. Based on byte-level Byte-Pair-Encoding.
|
||||
|
||||
Args:
|
||||
vocab_file (`str`):
|
||||
Path to the vocabulary file.
|
||||
"""
|
||||
|
||||
vocab_files_names = {"vocab_file": "ice_text.model"}
|
||||
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
||||
model_input_names = ["input_ids", "attention_mask", "position_ids"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file,
|
||||
do_lower_case=False,
|
||||
remove_space=False,
|
||||
bos_token='<sop>',
|
||||
eos_token='<eop>',
|
||||
end_token='</s>',
|
||||
mask_token='[MASK]',
|
||||
gmask_token='[gMASK]',
|
||||
padding_side="left",
|
||||
pad_token="<pad>",
|
||||
unk_token="<unk>",
|
||||
num_image_tokens=20000,
|
||||
**kwargs
|
||||
) -> None:
|
||||
super().__init__(
|
||||
do_lower_case=do_lower_case,
|
||||
remove_space=remove_space,
|
||||
padding_side=padding_side,
|
||||
bos_token=bos_token,
|
||||
eos_token=eos_token,
|
||||
end_token=end_token,
|
||||
mask_token=mask_token,
|
||||
gmask_token=gmask_token,
|
||||
pad_token=pad_token,
|
||||
unk_token=unk_token,
|
||||
num_image_tokens=num_image_tokens,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
self.do_lower_case = do_lower_case
|
||||
self.remove_space = remove_space
|
||||
self.vocab_file = vocab_file
|
||||
|
||||
self.bos_token = bos_token
|
||||
self.eos_token = eos_token
|
||||
self.end_token = end_token
|
||||
self.mask_token = mask_token
|
||||
self.gmask_token = gmask_token
|
||||
|
||||
self.sp_tokenizer = SPTokenizer(vocab_file, num_image_tokens=num_image_tokens)
|
||||
|
||||
""" Initialisation """
|
||||
|
||||
@property
|
||||
def gmask_token_id(self) -> Optional[int]:
|
||||
if self.gmask_token is None:
|
||||
return None
|
||||
return self.convert_tokens_to_ids(self.gmask_token)
|
||||
|
||||
@property
|
||||
def end_token_id(self) -> Optional[int]:
|
||||
"""
|
||||
`Optional[int]`: Id of the end of context token in the vocabulary. Returns `None` if the token has not been
|
||||
set.
|
||||
"""
|
||||
if self.end_token is None:
|
||||
return None
|
||||
return self.convert_tokens_to_ids(self.end_token)
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
""" Returns vocab size """
|
||||
return self.sp_tokenizer.num_tokens
|
||||
|
||||
def get_vocab(self):
|
||||
""" Returns vocab as a dict """
|
||||
vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
|
||||
vocab.update(self.added_tokens_encoder)
|
||||
return vocab
|
||||
|
||||
def preprocess_text(self, inputs):
|
||||
if self.remove_space:
|
||||
outputs = " ".join(inputs.strip().split())
|
||||
else:
|
||||
outputs = inputs
|
||||
|
||||
if self.do_lower_case:
|
||||
outputs = outputs.lower()
|
||||
|
||||
return outputs
|
||||
|
||||
def _tokenize(self, text, **kwargs):
|
||||
""" Returns a tokenized string. """
|
||||
text = self.preprocess_text(text)
|
||||
|
||||
seq = self.sp_tokenizer.tokenize(text)
|
||||
|
||||
return seq
|
||||
|
||||
def convert_tokens_to_string(self, tokens: List[str]) -> str:
|
||||
return self.sp_tokenizer.decode_tokens(tokens)
|
||||
|
||||
def _decode(
|
||||
self,
|
||||
token_ids: Union[int, List[int]],
|
||||
**kwargs
|
||||
) -> str:
|
||||
if isinstance(token_ids, int):
|
||||
token_ids = [token_ids]
|
||||
if len(token_ids) == 0:
|
||||
return ""
|
||||
if self.pad_token_id in token_ids: # remove pad
|
||||
token_ids = list(filter((self.pad_token_id).__ne__, token_ids))
|
||||
return super()._decode(token_ids, **kwargs)
|
||||
|
||||
def _convert_token_to_id(self, token):
|
||||
""" Converts a token (str) in an id using the vocab. """
|
||||
return self.sp_tokenizer[token]
|
||||
|
||||
def _convert_id_to_token(self, index):
|
||||
"""Converts an index (integer) in a token (str) using the vocab."""
|
||||
return self.sp_tokenizer[index]
|
||||
|
||||
def save_vocabulary(self, save_directory, filename_prefix=None):
|
||||
"""
|
||||
Save the vocabulary and special tokens file to a directory.
|
||||
|
||||
Args:
|
||||
save_directory (`str`):
|
||||
The directory in which to save the vocabulary.
|
||||
filename_prefix (`str`, *optional*):
|
||||
An optional prefix to add to the named of the saved files.
|
||||
|
||||
Returns:
|
||||
`Tuple(str)`: Paths to the files saved.
|
||||
"""
|
||||
if os.path.isdir(save_directory):
|
||||
vocab_file = os.path.join(
|
||||
save_directory, self.vocab_files_names["vocab_file"]
|
||||
)
|
||||
else:
|
||||
vocab_file = save_directory
|
||||
|
||||
with open(self.vocab_file, 'rb') as fin:
|
||||
proto_str = fin.read()
|
||||
|
||||
with open(vocab_file, "wb") as writer:
|
||||
writer.write(proto_str)
|
||||
|
||||
return (vocab_file,)
|
||||
|
||||
def build_inputs_with_special_tokens(
|
||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||
) -> List[int]:
|
||||
"""
|
||||
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
|
||||
adding special tokens. A BERT sequence has the following format:
|
||||
|
||||
- single sequence: `[CLS] X [SEP]`
|
||||
- pair of sequences: `[CLS] A [SEP] B [SEP]`
|
||||
|
||||
Args:
|
||||
token_ids_0 (`List[int]`):
|
||||
List of IDs to which the special tokens will be added.
|
||||
token_ids_1 (`List[int]`, *optional*):
|
||||
Optional second list of IDs for sequence pairs.
|
||||
|
||||
Returns:
|
||||
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
|
||||
"""
|
||||
gmask_id = self.sp_tokenizer[self.gmask_token]
|
||||
eos_id = self.sp_tokenizer[self.eos_token]
|
||||
token_ids_0 = token_ids_0 + [gmask_id, self.sp_tokenizer[self.bos_token]]
|
||||
if token_ids_1 is not None:
|
||||
token_ids_0 = token_ids_0 + token_ids_1
|
||||
return token_ids_0
|
||||
|
||||
def _pad(
|
||||
self,
|
||||
encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
|
||||
max_length: Optional[int] = None,
|
||||
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
|
||||
pad_to_multiple_of: Optional[int] = None,
|
||||
return_attention_mask: Optional[bool] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
|
||||
|
||||
Args:
|
||||
encoded_inputs:
|
||||
Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
|
||||
max_length: maximum length of the returned list and optionally padding length (see below).
|
||||
Will truncate by taking into account the special tokens.
|
||||
padding_strategy: PaddingStrategy to use for padding.
|
||||
|
||||
- PaddingStrategy.LONGEST Pad to the longest sequence in the batch
|
||||
- PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
|
||||
- PaddingStrategy.DO_NOT_PAD: Do not pad
|
||||
The tokenizer padding sides are defined in self.padding_side:
|
||||
|
||||
- 'left': pads on the left of the sequences
|
||||
- 'right': pads on the right of the sequences
|
||||
pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
|
||||
This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
|
||||
`>= 7.5` (Volta).
|
||||
return_attention_mask:
|
||||
(optional) Set to False to avoid returning attention mask (default: set to model specifics)
|
||||
"""
|
||||
# Load from model defaults
|
||||
bos_token_id = self.sp_tokenizer[self.bos_token]
|
||||
mask_token_id = self.sp_tokenizer[self.mask_token]
|
||||
gmask_token_id = self.sp_tokenizer[self.gmask_token]
|
||||
assert self.padding_side == "left"
|
||||
|
||||
required_input = encoded_inputs[self.model_input_names[0]]
|
||||
seq_length = len(required_input)
|
||||
|
||||
if padding_strategy == PaddingStrategy.LONGEST:
|
||||
max_length = len(required_input)
|
||||
|
||||
if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
|
||||
max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
|
||||
|
||||
needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
|
||||
|
||||
# Initialize attention mask if not present.
|
||||
if max_length is not None:
|
||||
if "attention_mask" not in encoded_inputs:
|
||||
if bos_token_id in required_input:
|
||||
context_length = required_input.index(bos_token_id)
|
||||
else:
|
||||
context_length = seq_length
|
||||
attention_mask = np.ones((1, seq_length, seq_length))
|
||||
attention_mask = np.tril(attention_mask)
|
||||
attention_mask[:, :, :context_length] = 1
|
||||
attention_mask = np.bool_(attention_mask < 0.5)
|
||||
encoded_inputs["attention_mask"] = attention_mask
|
||||
|
||||
if "position_ids" not in encoded_inputs:
|
||||
if bos_token_id in required_input:
|
||||
context_length = required_input.index(bos_token_id)
|
||||
else:
|
||||
context_length = seq_length
|
||||
position_ids = np.arange(seq_length, dtype=np.int64)
|
||||
mask_token = mask_token_id if mask_token_id in required_input else gmask_token_id
|
||||
if mask_token in required_input:
|
||||
mask_position = required_input.index(mask_token)
|
||||
position_ids[context_length:] = mask_position
|
||||
block_position_ids = np.concatenate(
|
||||
[np.zeros(context_length, dtype=np.int64),
|
||||
np.arange(1, seq_length - context_length + 1, dtype=np.int64)])
|
||||
encoded_inputs["position_ids"] = np.stack([position_ids, block_position_ids], axis=0)
|
||||
|
||||
if needs_to_be_padded:
|
||||
difference = max_length - len(required_input)
|
||||
|
||||
if "attention_mask" in encoded_inputs:
|
||||
encoded_inputs["attention_mask"] = np.pad(encoded_inputs["attention_mask"],
|
||||
pad_width=[(0, 0), (difference, 0), (difference, 0)],
|
||||
mode='constant', constant_values=True)
|
||||
if "token_type_ids" in encoded_inputs:
|
||||
encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
|
||||
"token_type_ids"
|
||||
]
|
||||
if "special_tokens_mask" in encoded_inputs:
|
||||
encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
|
||||
if "position_ids" in encoded_inputs:
|
||||
encoded_inputs["position_ids"] = np.pad(encoded_inputs["position_ids"],
|
||||
pad_width=[(0, 0), (difference, 0)])
|
||||
encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
|
||||
|
||||
return encoded_inputs
|
107
applications/Chat/coati/models/chatglm/configuration_chatglm.py
Normal file
107
applications/Chat/coati/models/chatglm/configuration_chatglm.py
Normal file
@@ -0,0 +1,107 @@
|
||||
"""
|
||||
This code is copied from https://huggingface.co/THUDM/chatglm-6b/resolve/main/configuration_chatglm.py
|
||||
"""
|
||||
|
||||
""" ChatGLM model configuration """
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class ChatGLMConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`~ChatGLMModel`].
|
||||
It is used to instantiate an ChatGLM model according to the specified arguments, defining the model
|
||||
architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
|
||||
the ChatGLM-6B [THUDM/ChatGLM-6B](https://huggingface.co/THUDM/chatglm-6b) architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used
|
||||
to control the model outputs. Read the documentation from [`PretrainedConfig`]
|
||||
for more information.
|
||||
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 150528):
|
||||
Vocabulary size of the ChatGLM-6B model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`~ChatGLMModel`] or
|
||||
[`~TFChatGLMModel`].
|
||||
hidden_size (`int`, *optional*, defaults to 4096):
|
||||
Dimension of the encoder layers and the pooler layer.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 28):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 32):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
inner_hidden_size (`int`, *optional*, defaults to 16384):
|
||||
Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
||||
max_sequence_length (`int`, *optional*, defaults to 512):
|
||||
The maximum sequence length that this model might ever be used with.
|
||||
Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
|
||||
layernorm_epsilon (`float`, *optional*, defaults to 1e-5):
|
||||
The epsilon used by the layer normalization layers.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether the model should return the last key/values attentions (not used by all models).
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from configuration_chatglm import ChatGLMConfig
|
||||
>>> from modeling_chatglm import ChatGLMModel
|
||||
|
||||
>>> # Initializing a ChatGLM-6B THUDM/ChatGLM-6B style configuration
|
||||
>>> configuration = ChatGLMConfig()
|
||||
|
||||
>>> # Initializing a model from the THUDM/ChatGLM-6B style configuration
|
||||
>>> model = ChatGLMModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```
|
||||
"""
|
||||
model_type = "chatglm"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=130528,
|
||||
hidden_size=4096,
|
||||
num_layers=28,
|
||||
num_attention_heads=32,
|
||||
layernorm_epsilon=1e-5,
|
||||
use_cache=True,
|
||||
bos_token_id=130004,
|
||||
eos_token_id=130005,
|
||||
mask_token_id=130000,
|
||||
gmask_token_id=130001,
|
||||
pad_token_id=3,
|
||||
max_sequence_length=2048,
|
||||
inner_hidden_size=16384,
|
||||
position_encoding_2d=True,
|
||||
quantization_bit=0,
|
||||
pre_seq_len=None,
|
||||
prefix_projection=False,
|
||||
**kwargs
|
||||
):
|
||||
self.num_layers = num_layers
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.max_sequence_length = max_sequence_length
|
||||
self.layernorm_epsilon = layernorm_epsilon
|
||||
self.inner_hidden_size = inner_hidden_size
|
||||
self.use_cache = use_cache
|
||||
self.bos_token_id = bos_token_id
|
||||
self.eos_token_id = eos_token_id
|
||||
self.pad_token_id = pad_token_id
|
||||
self.mask_token_id = mask_token_id
|
||||
self.gmask_token_id = gmask_token_id
|
||||
self.position_encoding_2d = position_encoding_2d
|
||||
self.quantization_bit = quantization_bit
|
||||
self.pre_seq_len = pre_seq_len
|
||||
self.prefix_projection = prefix_projection
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
**kwargs
|
||||
)
|
1439
applications/Chat/coati/models/chatglm/modeling_chatglm.py
Normal file
1439
applications/Chat/coati/models/chatglm/modeling_chatglm.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -52,9 +52,13 @@ class SFTTrainer(SLTrainer):
|
||||
for batch_id, batch in enumerate(self.train_dataloader):
|
||||
|
||||
batch = to_device(batch, torch.cuda.current_device())
|
||||
outputs = self.model(batch["input_ids"],
|
||||
attention_mask=batch["attention_mask"],
|
||||
labels=batch["labels"])
|
||||
if "attention_mask" in batch:
|
||||
outputs = self.model(batch["input_ids"],
|
||||
attention_mask=batch["attention_mask"],
|
||||
labels=batch["labels"])
|
||||
else:
|
||||
outputs = self.model(batch["input_ids"],
|
||||
labels=batch["labels"])
|
||||
|
||||
loss = outputs.loss
|
||||
loss = loss / self.accumulation_steps
|
||||
|
@@ -16,10 +16,9 @@
|
||||
"chat": {
|
||||
"GPT": [
|
||||
"language organization",
|
||||
"relevance",
|
||||
"naturalness",
|
||||
"engagingness",
|
||||
"reasonableness"
|
||||
"fidelity"
|
||||
],
|
||||
"Metrics": [
|
||||
"Distinct"
|
||||
@@ -27,7 +26,6 @@
|
||||
},
|
||||
"classification": {
|
||||
"GPT": [
|
||||
"language organization",
|
||||
"relevance",
|
||||
"correctness"
|
||||
],
|
||||
@@ -40,7 +38,6 @@
|
||||
},
|
||||
"closed_qa": {
|
||||
"GPT": [
|
||||
"language organization",
|
||||
"relevance",
|
||||
"correctness"
|
||||
],
|
||||
@@ -53,7 +50,6 @@
|
||||
},
|
||||
"extraction": {
|
||||
"GPT": [
|
||||
"language organization",
|
||||
"relevance",
|
||||
"correctness"
|
||||
],
|
||||
@@ -74,7 +70,20 @@
|
||||
"BLEU",
|
||||
"ROUGE",
|
||||
"BERTScore"
|
||||
]
|
||||
]
|
||||
},
|
||||
"logical_reasoning": {
|
||||
"GPT": [
|
||||
"correctness",
|
||||
"relevance",
|
||||
"reasonableness"
|
||||
],
|
||||
"Metrics": [
|
||||
"BLEU",
|
||||
"ROUGE",
|
||||
"BERTScore",
|
||||
"CHRF"
|
||||
]
|
||||
},
|
||||
"open_qa": {
|
||||
"GPT": [
|
||||
@@ -117,11 +126,79 @@
|
||||
"conciseness"
|
||||
],
|
||||
"Metrics": [
|
||||
"BLEU",
|
||||
"ROUGE",
|
||||
"BERTScore",
|
||||
"CHRF"
|
||||
]
|
||||
]
|
||||
},
|
||||
"Finance": {
|
||||
"GPT": [
|
||||
"relevance",
|
||||
"correctness"
|
||||
],
|
||||
"Metrics": [
|
||||
]
|
||||
},
|
||||
"Law": {
|
||||
"GPT": [
|
||||
"relevance",
|
||||
"correctness"
|
||||
],
|
||||
"Metrics": [
|
||||
]
|
||||
},
|
||||
"Education": {
|
||||
"GPT": [
|
||||
"relevance",
|
||||
"correctness"
|
||||
],
|
||||
"Metrics": [
|
||||
]
|
||||
},
|
||||
"Medical": {
|
||||
"GPT": [
|
||||
"relevance",
|
||||
"correctness"
|
||||
],
|
||||
"Metrics": [
|
||||
]
|
||||
},
|
||||
"STEM": {
|
||||
"GPT": [
|
||||
"relevance",
|
||||
"correctness"
|
||||
],
|
||||
"Metrics": [
|
||||
]
|
||||
},
|
||||
"SocialScience": {
|
||||
"GPT": [
|
||||
"relevance",
|
||||
"correctness"
|
||||
],
|
||||
"Metrics": [
|
||||
]
|
||||
},
|
||||
"Humanity": {
|
||||
"GPT": [
|
||||
"relevance",
|
||||
"correctness"
|
||||
],
|
||||
"Metrics": [
|
||||
]
|
||||
},
|
||||
"Other": {
|
||||
"GPT": [
|
||||
"relevance",
|
||||
"correctness"
|
||||
],
|
||||
"Metrics": [
|
||||
]
|
||||
},
|
||||
"ethics": {
|
||||
"GPT": [
|
||||
"relevance",
|
||||
"correctness"
|
||||
],
|
||||
"Metrics": [
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -26,10 +26,9 @@
|
||||
"chat": {
|
||||
"GPT": [
|
||||
"language organization",
|
||||
"relevance",
|
||||
"naturalness",
|
||||
"engagingness",
|
||||
"reasonableness"
|
||||
"fidelity"
|
||||
],
|
||||
"Metrics": [
|
||||
"Distinct"
|
||||
@@ -45,7 +44,6 @@
|
||||
},
|
||||
"classification": {
|
||||
"GPT": [
|
||||
"language organization",
|
||||
"relevance",
|
||||
"correctness"
|
||||
],
|
||||
@@ -63,7 +61,6 @@
|
||||
},
|
||||
"closed_qa": {
|
||||
"GPT": [
|
||||
"language organization",
|
||||
"relevance",
|
||||
"correctness"
|
||||
],
|
||||
@@ -81,7 +78,6 @@
|
||||
},
|
||||
"extraction": {
|
||||
"GPT": [
|
||||
"language organization",
|
||||
"relevance",
|
||||
"correctness"
|
||||
],
|
||||
@@ -114,6 +110,21 @@
|
||||
"data2text-informativeness"
|
||||
]
|
||||
},
|
||||
"logical_reasoning": {
|
||||
"GPT": [
|
||||
"correctness",
|
||||
"relevance",
|
||||
"reasonableness"
|
||||
],
|
||||
"Metrics": [
|
||||
"BLEU",
|
||||
"ROUGE",
|
||||
"BERTScore",
|
||||
"CHRF"
|
||||
],
|
||||
"UniEval": [
|
||||
]
|
||||
},
|
||||
"open_qa": {
|
||||
"GPT": [
|
||||
"language organization",
|
||||
@@ -176,12 +187,96 @@
|
||||
"CHRF"
|
||||
],
|
||||
"UniEval": [
|
||||
"summarization-coherence",
|
||||
"summarization-consistency",
|
||||
"summarization-fluency",
|
||||
"summarization-relevance",
|
||||
"data2text-naturalness",
|
||||
"data2text-informativeness"
|
||||
]
|
||||
},
|
||||
"Finance": {
|
||||
"GPT": [
|
||||
"relevance",
|
||||
"correctness"
|
||||
],
|
||||
"Metrics": [
|
||||
],
|
||||
"UniEval": [
|
||||
]
|
||||
},
|
||||
"Law": {
|
||||
"GPT": [
|
||||
"relevance",
|
||||
"correctness"
|
||||
],
|
||||
"Metrics": [
|
||||
],
|
||||
"UniEval": [
|
||||
]
|
||||
},
|
||||
"Education": {
|
||||
"GPT": [
|
||||
"relevance",
|
||||
"correctness"
|
||||
],
|
||||
"Metrics": [
|
||||
],
|
||||
"UniEval": [
|
||||
]
|
||||
},
|
||||
"Medical": {
|
||||
"GPT": [
|
||||
"relevance",
|
||||
"correctness"
|
||||
],
|
||||
"Metrics": [
|
||||
],
|
||||
"UniEval": [
|
||||
]
|
||||
},
|
||||
"STEM": {
|
||||
"GPT": [
|
||||
"relevance",
|
||||
"correctness"
|
||||
],
|
||||
"Metrics": [
|
||||
],
|
||||
"UniEval": [
|
||||
]
|
||||
},
|
||||
"SocialScience": {
|
||||
"GPT": [
|
||||
"relevance",
|
||||
"correctness"
|
||||
],
|
||||
"Metrics": [
|
||||
],
|
||||
"UniEval": [
|
||||
]
|
||||
},
|
||||
"Humanity": {
|
||||
"GPT": [
|
||||
"relevance",
|
||||
"correctness"
|
||||
],
|
||||
"Metrics": [
|
||||
],
|
||||
"UniEval": [
|
||||
]
|
||||
},
|
||||
"Other": {
|
||||
"GPT": [
|
||||
"relevance",
|
||||
"correctness"
|
||||
],
|
||||
"Metrics": [
|
||||
],
|
||||
"UniEval": [
|
||||
]
|
||||
},
|
||||
"ethics": {
|
||||
"GPT": [
|
||||
"relevance",
|
||||
"correctness"
|
||||
],
|
||||
"Metrics": [
|
||||
],
|
||||
"UniEval": [
|
||||
]
|
||||
}
|
||||
}
|
||||
|
@@ -26,14 +26,16 @@
|
||||
"relevance": "切题(1-5):答案内容是否切题,不答非所问,并且严格遵照题目要求。",
|
||||
"naturalness": "自然(1-5):答案是否自然,并且符合问题给定的身份。",
|
||||
"engagingness": "参与感(1-5):答案是否对前面的对话内容做出了恰当的反应,是否理解对话的语境和背景。",
|
||||
"reasonableness": "合理性(1-5):答案是否能够与前面的对话内容形成逻辑上的衔接,是否符合常理,能否在这个上下文中合理存在。"
|
||||
"reasonableness": "合理性(1-5):答案是否能够与前面的对话内容形成逻辑上的衔接,是否符合常理,能否在这个上下文中合理存在。",
|
||||
"fidelity": "保真度(1-5):答案是否能够严格遵守角色的设定回答给定的请求。"
|
||||
},
|
||||
"CoT": {
|
||||
"language organization": "1. 阅读答案,并检查是否有语法错误、用词不当或其他显著的错误。\n2. 检查答案是否具有逻辑性,能够按照合理的顺序传达信息并且能够自圆其说。\n3. 确定答案是否与问题或主题相关,并且能够传达清晰的信息。\n4. 检查答案是否连贯,是否使用适当的转换和过渡来保持句子和段落之间的连贯性。\n5. 检查答案是否具有明确的结构和组织方式,使得读者可以轻松理解信息的层次和结构。\n6. 根据以上因素综合评估答案的语言组织,并给出一个1到5的分数,其中5表示语言组织非常好,而1表示语言组织非常差。\n\n语言组织:",
|
||||
"relevance": "1. 阅读题目,确定题目所问的问题是什么,以及需要回答哪些方面的问题。\n2. 阅读答案,确认答案是否直接回答了题目所问的问题。\n3. 检查答案是否严格遵照了题目的要求,包括答题方式、答题长度、答题格式等等。\n4. 根据以上因素综合评估答案的切题程度,并给出一个1到5的分数,其中5表示答案非常切题,而1表示答案完全没有切题。\n\n切题:",
|
||||
"naturalness": "1. 阅读题目,确定题目提供的身份信息。\n2. 检查答案内容是否符合题目给定的身份。\n3. 根据以上因素,对该回答的自然性进行打分,分数从1到5,其中1表示不自然,5表示非常自然,并符合问题给定的身份。\n\n自然:",
|
||||
"engagingness": "1. 阅读题目,确定对话的语境和背景。\n2. 检查答案是否充分理解对话的语境和背景,能否自然地融入到对话中而不显得突兀。\n3. 根据以上因素,对该回答的参与感进行打分,分数从1到5,其中1表示没有参与感,5表示非常有参与感,并且恰当地理解了对话的语境和背景。\n\n参与感:",
|
||||
"reasonableness": "1. 阅读题目,确定对话的主题以及问题期望的回答方向。\n2. 判断答案是否能够与前面的对话内容形成逻辑上的衔接,是否符合常理,能否在这个上下文中合理存在。\n3. 根据以上因素,对该回答的合理性进行打分,分数从1到5,其中1表示不合理,5表示非常合理,并且能够与前面的对话内容形成逻辑上的衔接,并符合常理。\n\n合理性:"
|
||||
"reasonableness": "1. 阅读题目,确定对话的主题以及问题期望的回答方向。\n2. 判断答案是否能够与前面的对话内容形成逻辑上的衔接,是否符合常理,能否在这个上下文中合理存在。\n3. 根据以上因素,对该回答的合理性进行打分,分数从1到5,其中1表示不合理,5表示非常合理,并且能够与前面的对话内容形成逻辑上的衔接,并符合常理。\n\n合理性:",
|
||||
"fidelity": "1. 仔细阅读问题,了解角色在问题中的设定和表现,包括职业、背景、观点、性格等方面。\n阅读题目的请求,确认回答请求时需要注意的细节。\n3. 对比提供的回答与该角色的设定,评估回答是否能够严格遵守角色的设定。\n4. 结合以上评估结果给出保真度的评分,范围从1到5分,其中1分表示回答与角色设定完全不符,5分表示回答完全符合角色设定且满足给定请求。\n\n保真度:"
|
||||
},
|
||||
"prompt": "你是一个好助手。请你为下面的“补全对话”问题的答案打分。\n\n问题如下:\n\n{question}\n\n答案如下:\n\n{answer}\n\n评分的指标如下:\n\n{metric}\n\n请你遵照以下的评分步骤:\n\n{steps}"
|
||||
},
|
||||
|
@@ -26,14 +26,16 @@
|
||||
"relevance": "Relevance (1-5): whether the content of the answer is relevant to the topic, does not answer the wrong question, and strictly follows the requirements of the topic.",
|
||||
"naturalness": "Naturalness (1-5): whether the answer is natural and fits the identity given by the question.",
|
||||
"engagingness": "Engagingness (1-5): whether the answer responds appropriately to the content of the preceding conversation and whether it understands the context and background of the conversation.",
|
||||
"reasonableness": "Reasonableness (1-5): Whether the answer can form a logical connection with the content of the previous dialogue, whether it is consistent with common sense, and whether it can reasonably exist in this context."
|
||||
"reasonableness": "Reasonableness (1-5): Whether the answer can form a logical connection with the content of the previous dialogue, whether it is consistent with common sense, and whether it can reasonably exist in this context.",
|
||||
"fidelity": "Fidelity (1-5): whether the answer is able to answer the given request in strict compliance with the role setting."
|
||||
},
|
||||
"CoT": {
|
||||
"language organization": "1. Read the answers and check for grammatical errors, poor word choice, or other significant mistakes.\n2. Check that the answer is logical, conveys the information in a logical order, and is self-explanatory.\n3. Determine if the answer is relevant to the question or topic and conveys a clear message.\n4. Check that the answer is coherent and that appropriate transitions and switches are used to maintain coherence between sentences and paragraphs.\n5. Check that the answer is clearly structured and organized in such a way that the reader can easily understand the hierarchy and structure of the information.\n6. Evaluate the language organization of the answer based on a combination of the above factors and give a score of 1 to 5, where 5 indicates very good language organization and 1 indicates very poor language organization.\n\nLanguage organization:",
|
||||
"relevance": "1. Read the question to determine what the question asks and what aspects of the question need to be answered.\n2. Read the answers to make sure that they directly answer the question asked.\n3. Check that the answer follows the requirements of the question, including the way it is answered, the length of the answer, the format of the answer, etc.\n4. Evaluate how relevant the answer is based on the above factors and give a score of 1 to 5, where 5 means the answer is very relevant and 1 means the answer is not relevant at all.\n\nRelevance:",
|
||||
"naturalness": "1. Read the question and determine the identity information provided in the question.\n2. Check whether the content of the answer matches the identity given in the question.\n3. Based on the above factors, score the naturalness of the response on a scale from 1 to 5, where 1 means unnatural and 5 means very natural and in accordance with the identity given in the question.\n\nNaturalness:",
|
||||
"engagingness": "1. Read the questions to determine the context and background of the dialogue.\n2. Check that the answer fully understands the context and background of the conversation and that it fits naturally into the conversation without seeming abrupt.\n3. Based on the above factors, rate the response's engagement on a scale from 1 to 5, where 1 means not engaged and 5 means very engaged and appropriately understands the context and background of the conversation.\n\nEngagingness:",
|
||||
"reasonableness": "1. Read the question and determine the topic of the conversation and the direction the question expects the answer to go.\n2. Determine whether the answer can be logically connected to the preceding conversation, whether it makes common sense, and whether it can reasonably exist in this context.\n3. Based on the above factors, rate the reasonableness of the answer on a scale from 1 to 5, where 1 means unreasonable and 5 means very reasonable and able to form a logical connection with the preceding dialogue content and consistent with common sense.\n\nReasonableness:"
|
||||
"reasonableness": "1. Read the question and determine the topic of the conversation and the direction the question expects the answer to go.\n2. Determine whether the answer can be logically connected to the preceding conversation, whether it makes common sense, and whether it can reasonably exist in this context.\n3. Based on the above factors, rate the reasonableness of the answer on a scale from 1 to 5, where 1 means unreasonable and 5 means very reasonable and able to form a logical connection with the preceding dialogue content and consistent with common sense.\n\nReasonableness:",
|
||||
"fidelity": "1. Read the question carefully to understand how the character is set up and represented in the question, including aspects such as occupation, background, point of view, and personality.\n2. Read the question's request and confirm the details that need to be taken into account when answering the request.\n3. Compare the provided answer with the setting of the role and assess whether the answer can strictly adhere to the setting of the role.\n4. Combine the results of the above assessment to give a fidelity score ranging from 1 to 5, where a score of 1 means that the response does not match the persona at all, and a score of 5 means that the response fully complies with the persona and satisfies the given request.\n\nFidelity:"
|
||||
},
|
||||
"prompt": "You are a good assistant. Please rate the given answer to the \"chat\" question below.\n\nThe question is as follows:\n\n{question}\n\nThe answer is as follows:\n\n{answer}\n\nThe metric for evaluation is as follows:\n\n{metric}\n\nYou should follow the following evaluation steps:\n\n{steps}"
|
||||
},
|
||||
|
@@ -1,2 +1,3 @@
|
||||
pandas>=1.4.1
|
||||
sentencepiece
|
||||
colossalai==0.3.1
|
@@ -9,13 +9,15 @@ from coati.models.bloom import BLOOMActor
|
||||
from coati.models.gpt import GPTActor
|
||||
from coati.models.llama import LlamaActor
|
||||
from coati.models.opt import OPTActor
|
||||
from coati.models.chatglm import ChatGLMActor
|
||||
from coati.trainer import SFTTrainer
|
||||
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
|
||||
from datasets import load_dataset
|
||||
from torch.optim import Adam
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer
|
||||
from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer, AutoModel
|
||||
from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
|
||||
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
|
||||
from transformers.trainer import get_scheduler
|
||||
|
||||
@@ -58,6 +60,8 @@ def train(args):
|
||||
model = LlamaActor(pretrained=args.pretrain,
|
||||
lora_rank=args.lora_rank,
|
||||
checkpoint=args.grad_checkpoint)
|
||||
elif args.model == 'chatglm':
|
||||
model = ChatGLMActor(pretrained=args.pretrain)
|
||||
else:
|
||||
raise ValueError(f'Unsupported model "{args.model}"')
|
||||
|
||||
@@ -81,6 +85,9 @@ def train(args):
|
||||
"hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer)
|
||||
tokenizer.eos_token = '<\s>'
|
||||
tokenizer.pad_token = tokenizer.unk_token
|
||||
elif args.model == 'chatglm':
|
||||
tokenizer = ChatGLMTokenizer.from_pretrained(
|
||||
"THUDM/chatglm-6b" if args.tokenizer is None else args.tokenizer, trust_remote_code=True)
|
||||
else:
|
||||
raise ValueError(f'Unsupported model "{args.model}"')
|
||||
|
||||
@@ -99,7 +106,6 @@ def train(args):
|
||||
optim = HybridAdam(model.parameters(), lr=args.lr, clipping_norm=1.0)
|
||||
else:
|
||||
optim = Adam(model.parameters(), lr=args.lr)
|
||||
|
||||
logger = get_dist_logger()
|
||||
|
||||
# configure dataset
|
||||
@@ -185,7 +191,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--strategy',
|
||||
choices=['ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_zero2_cpu'],
|
||||
default='colossalai_zero2')
|
||||
parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom')
|
||||
parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama', 'chatglm'], default='bloom')
|
||||
parser.add_argument('--tokenizer', type=str, default=None)
|
||||
parser.add_argument('--pretrain', type=str, default=None)
|
||||
parser.add_argument('--dataset', type=str, default=None)
|
||||
|
@@ -1 +1,2 @@
|
||||
pytest
|
||||
colossalai==0.3.1
|
@@ -2,7 +2,7 @@ transformers>=4.20.1
|
||||
tqdm
|
||||
datasets
|
||||
loralib
|
||||
colossalai>=0.2.4
|
||||
colossalai==0.3.1
|
||||
torch<2.0.0, >=1.12.1
|
||||
langchain
|
||||
tokenizers
|
||||
|
@@ -11,7 +11,7 @@ from coati.dataset.sft_dataset import IGNORE_INDEX, SFTDataset, SupervisedDatase
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer, PreTrainedTokenizer
|
||||
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
|
||||
|
||||
from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
|
||||
SFT_DATASET = [
|
||||
{
|
||||
"instruction":
|
||||
@@ -80,6 +80,8 @@ def make_tokenizer(model: str):
|
||||
elif model == "llama":
|
||||
tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
|
||||
tokenizer.pad_token = tokenizer.unk_token
|
||||
elif model == "chatglm":
|
||||
tokenizer = ChatGLMTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
|
||||
else:
|
||||
raise ValueError(f"Unsupported model '{model}'")
|
||||
return tokenizer
|
||||
@@ -93,13 +95,19 @@ def check_content(input_ids_stripped: torch.Tensor, tokenizer: PreTrainedTokeniz
|
||||
elif model == "llama":
|
||||
assert input_ids_stripped[0] == tokenizer.bos_token_id
|
||||
input_ids_stripped = input_ids_stripped[1:]
|
||||
|
||||
elif model == "chatglm":
|
||||
assert input_ids_stripped[0] == tokenizer.bos_token_id
|
||||
assert input_ids_stripped[-1] == tokenizer.eos_token_id
|
||||
input_ids_stripped = input_ids_stripped[1:-1]
|
||||
assert torch.all(input_ids_stripped != tokenizer.pad_token_id)
|
||||
assert torch.all(input_ids_stripped != tokenizer.bos_token_id)
|
||||
assert torch.all(input_ids_stripped != tokenizer.eos_token_id)
|
||||
assert input_ids_stripped != tokenizer.sep_token_id
|
||||
assert input_ids_stripped != tokenizer.cls_token_id
|
||||
assert input_ids_stripped != tokenizer.mask_token_id
|
||||
if model == "chatglm":
|
||||
assert torch.all(input_ids_stripped != tokenizer.mask_token_id)
|
||||
else:
|
||||
assert input_ids_stripped != tokenizer.mask_token_id
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama"])
|
||||
@@ -190,7 +198,8 @@ def test_reward_dataset(model: str, dataset_path: str, subset: Optional[str], ma
|
||||
assert torch.all(r_mask)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama"])
|
||||
|
||||
@pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama", "chatglm"])
|
||||
@pytest.mark.parametrize("dataset_path", ["yizhongw/self_instruct", None])
|
||||
@pytest.mark.parametrize("max_dataset_size", [2])
|
||||
@pytest.mark.parametrize("max_length", [32, 1024])
|
||||
@@ -211,6 +220,19 @@ def test_sft_dataset(model: str, dataset_path: Optional[str], max_dataset_size:
|
||||
max_length=max_length)
|
||||
assert len(sft_dataset) == min(max_dataset_size, len(SFT_DATASET))
|
||||
|
||||
if isinstance(tokenizer, ChatGLMTokenizer):
|
||||
for i in range(max_dataset_size):
|
||||
assert isinstance(sft_dataset[i], dict)
|
||||
assert list(sft_dataset[i].keys()) == ["input_ids", "labels"]
|
||||
input_ids = sft_dataset[i]["input_ids"]
|
||||
labels = sft_dataset[i]["labels"]
|
||||
assert input_ids.shape == labels.shape == torch.Size([max_length])
|
||||
|
||||
ignore_mask = labels == IGNORE_INDEX
|
||||
assert input_ids.masked_select(torch.logical_not(ignore_mask))[0] == tokenizer.bos_token_id
|
||||
check_content(input_ids.masked_select(torch.logical_not(ignore_mask)), tokenizer, model)
|
||||
return
|
||||
|
||||
for i in range(max_dataset_size):
|
||||
assert isinstance(sft_dataset[i], dict)
|
||||
assert list(sft_dataset[i].keys()) == ["input_ids", "labels", "attention_mask"]
|
||||
@@ -238,4 +260,7 @@ if __name__ == "__main__":
|
||||
max_datasets_size=8,
|
||||
max_length=256)
|
||||
|
||||
test_prompt_dataset(model="opt", max_datasets_size=2, max_length=128)
|
||||
test_prompt_dataset(model="opt",
|
||||
max_datasets_size=2,
|
||||
max_length=128)
|
||||
|
||||
|
@@ -9,11 +9,12 @@ from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic
|
||||
from coati.models.generation import generate
|
||||
from coati.models.gpt import GPTRM, GPTActor, GPTCritic
|
||||
from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
|
||||
from coati.models.chatglm import ChatGLMActor
|
||||
from coati.models.lora import LoraLinear, convert_to_lora_module
|
||||
from coati.models.loss import GPTLMLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
|
||||
from coati.models.opt import OPTRM, OPTActor, OPTCritic
|
||||
from coati.models.utils import calc_action_log_probs, compute_reward, masked_mean
|
||||
|
||||
from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [4])
|
||||
@pytest.mark.parametrize("seq_len", [32])
|
||||
@@ -24,8 +25,10 @@ from coati.models.utils import calc_action_log_probs, compute_reward, masked_mea
|
||||
lambda: GPTActor(),
|
||||
# HACK: skip llama due to long execution time
|
||||
# lambda: LlamaActor(),
|
||||
lambda: OPTActor()
|
||||
])
|
||||
lambda: OPTActor(),
|
||||
# lambda: ChatGLMActor(),
|
||||
])
|
||||
|
||||
@pytest.mark.parametrize("generate_kwargs", [{
|
||||
"max_length": 64,
|
||||
"use_cache": True,
|
||||
@@ -115,11 +118,13 @@ def test_lora(lora_rank: int, num_dim: int, num_layers: int):
|
||||
lambda: (GPTActor(), GPTCritic(), GPTRM()),
|
||||
# HACK: skip llama due to long execution time
|
||||
# lambda: (LlamaActor(), LlamaCritic(), LlamaRM()),
|
||||
lambda: (OPTActor(), OPTCritic(), OPTRM()),
|
||||
])
|
||||
lambda: (OPTActor(), OPTCritic(), OPTRM()),
|
||||
lambda: (ChatGLMActor(), None, None),
|
||||
])
|
||||
@torch.no_grad()
|
||||
def test_models(models_maker: Callable[[], Tuple[Actor, Critic, RewardModel]], batch_size: int, seq_len: int):
|
||||
|
||||
def test_models(models_maker: Callable[[], Tuple[Actor, Critic, RewardModel]],
|
||||
batch_size: int,
|
||||
seq_len: int):
|
||||
actor_input = {
|
||||
"input_ids": torch.randint(0, 100, (batch_size, seq_len)),
|
||||
"attention_mask": torch.randint(0, 2, (batch_size, seq_len))
|
||||
@@ -135,20 +140,30 @@ def test_models(models_maker: Callable[[], Tuple[Actor, Critic, RewardModel]], b
|
||||
}
|
||||
|
||||
actor, critic, rm = models_maker()
|
||||
if isinstance(actor, ChatGLMActor):
|
||||
actor = actor.float()
|
||||
tokenizer = ChatGLMTokenizer.from_pretrained( "THUDM/chatglm-6b", trust_remote_code=True)
|
||||
chatglm_special_token = torch.tensor([tokenizer.gmask_token_id, tokenizer.bos_token_id]).repeat(batch_size, 1)
|
||||
actor_input ={
|
||||
"input_ids": torch.cat((torch.randint(0, 100, (batch_size, seq_len//2)), chatglm_special_token, torch.randint(0, 100, (batch_size, seq_len//2 - 2))), dim=1),
|
||||
"attention_mask": torch.randint(0, 2, (batch_size, 1, seq_len, seq_len))
|
||||
}
|
||||
assert isinstance(actor, Actor)
|
||||
base_actor_model = get_base_model(actor)
|
||||
assert isinstance(critic, Critic)
|
||||
base_critic_model = get_base_model(critic)
|
||||
assert isinstance(rm, RewardModel)
|
||||
base_rm_model = get_base_model(rm)
|
||||
|
||||
actor_output = actor(**actor_input)
|
||||
critic_output = critic(**critic_input)
|
||||
rm_output = rm(**rm_input)
|
||||
|
||||
assert actor_output.logits.shape[:2] == (batch_size, seq_len)
|
||||
assert critic_output.shape == (batch_size,)
|
||||
assert rm_output.shape == (batch_size,)
|
||||
|
||||
if critic:
|
||||
assert isinstance(critic, Critic)
|
||||
base_critic_model = get_base_model(critic)
|
||||
critic_output = critic(**critic_input)
|
||||
assert critic_output.shape == (batch_size, )
|
||||
|
||||
if rm:
|
||||
assert isinstance(rm, RewardModel)
|
||||
base_rm_model = get_base_model(rm)
|
||||
rm_output = rm(**rm_input)
|
||||
assert rm_output.shape == (batch_size, )
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [16])
|
||||
@@ -203,4 +218,4 @@ if __name__ == "__main__":
|
||||
|
||||
test_models(models_maker=lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()), batch_size=8, seq_len=128)
|
||||
|
||||
test_loss(batch_size=8, seq_len=128, num_labels=100)
|
||||
test_loss(batch_size=8, seq_len=128, num_labels=100)
|
Reference in New Issue
Block a user