mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 20:10:17 +00:00
[ColossalChat] Update RLHF V2 (#5286)
* Add dpo. Fix sft, ppo, lora. Refactor all
* fix and tested ppo
* 2nd round refactor
* add ci tests
* fix ci
* fix ci
* fix readme, style
* fix readme style
* fix style, fix benchmark
* reproduce benchmark result, remove useless files
* rename to ColossalChat
* use new image
* fix ci workflow
* fix ci
* use local model/tokenizer for ci tests
* fix ci
* fix ci
* fix ci
* fix ci timeout
* fix rm progress bar. fix ci timeout
* fix ci
* fix ci typo
* remove 3d plugin from ci temporarily
* test environment
* cannot save optimizer
* support chat template
* fix readme
* fix path
* test ci locally
* restore build_or_pr
* fix ci data path
* fix benchmark
* fix ci, move ci tests to 3080, disable fast tokenizer
* move ci to 85
* support flash attention 2
* add all-in-one data preparation script. Fix colossal-llama2-chat chat template
* add hardware requirements
* move ci test data
* fix save_model, add unwrap
* fix missing bos
* fix missing bos; support grad accumulation with gemini
* fix ci
* fix ci
* fix ci
* fix llama2 chat template config
* debug sft
* debug sft
* fix colossalai version requirement
* fix ci
* add sanity check to prevent NaN loss
* fix requirements
* add dummy data generation script
* add dummy data generation script
* add dummy data generation script
* add dummy data generation script
* update readme
* update readme
* update readme and ignore
* fix logger bug
* support parallel_output
* modify data preparation logic
* fix tokenization
* update lr
* fix inference
* run pre-commit

---------

Co-authored-by: Tong Li <tong.li352711588@gmail.com>
applications/ColossalChat/examples/community/peft/easy_dataset.py (executable file, 240 additions)
@@ -0,0 +1,240 @@
import copy
import json
from typing import Dict, Sequence

import torch
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoTokenizer

IGNORE_INDEX = -100


def _tokenize_fn(strings: Sequence[str], tokenizer: AutoTokenizer, max_length: int = 512) -> Dict:
    """Tokenize a list of strings."""
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=max_length,
            truncation=True,
        )
        for text in strings
    ]
    input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
    input_ids_lens = labels_lens = [
        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
    ]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )


def preprocess(sources: Sequence[str], targets: Sequence[str], tokenizer: AutoTokenizer, max_length: int = 512) -> Dict:
    """Preprocess the data by tokenizing."""
    examples = [s + t for s, t in zip(sources, targets)]
    examples_tokenized, sources_tokenized = [
        _tokenize_fn(strings, tokenizer, max_length) for strings in (examples, sources)
    ]
    input_ids = examples_tokenized["input_ids"]
    labels = copy.deepcopy(input_ids)
    for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
        label[:source_len] = IGNORE_INDEX
    return dict(input_ids=input_ids, labels=labels)
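# Illustration of preprocess() above (not part of the original file): each source is
# concatenated with its target and tokenized, input_ids is copied into labels, and the
# first source_len positions of each label are overwritten with IGNORE_INDEX (-100),
# so the loss is computed only on the target (answer) tokens.
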

class EasySupervisedDataset(Dataset):
    def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 512) -> None:
        super(EasySupervisedDataset, self).__init__()
        with open(data_file, "r", encoding="UTF-8") as f:
            all_lines = f.readlines()
        # Split each line into source and target: the source is everything up to and
        # including "回答:", the target is everything after it.
        sources, targets = [], []
        for line in all_lines:
            if "回答:" in line:
                sep_index = line.index("回答:")
                sources.append(line[: sep_index + 3])
                targets.append(line[sep_index + 3 :] + tokenizer.eos_token)
            else:
                sources.append(line)
                targets.append("" + tokenizer.eos_token)
        data_dict = preprocess(sources, targets, tokenizer, max_length)

        self.input_ids = data_dict["input_ids"]
        self.labels = data_dict["labels"]
        self.data_file = data_file

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        return dict(input_ids=self.input_ids[i], labels=self.labels[i])

    def __repr__(self):
        return f"EasySupervisedDataset(data_file={self.data_file}, input_ids_len={len(self.input_ids)}, labels_len={len(self.labels)})"

    def __str__(self):
        return f"EasySupervisedDataset(data_file={self.data_file}, input_ids_len={len(self.input_ids)}, labels_len={len(self.labels)})"

class EasyPromptsDataset(Dataset):
    def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 96) -> None:
        super(EasyPromptsDataset, self).__init__()
        with open(data_file, "r", encoding="UTF-8") as f:
            all_lines = f.readlines()
            # Keep only the prompt part of each line: everything up to and including "回答:".
            all_lines = [line if "回答:" not in line else line[: line.index("回答:") + 3] for line in all_lines]
        self.prompts = [
            tokenizer(line, return_tensors="pt", max_length=max_length, padding="max_length", truncation=True)[
                "input_ids"
            ]
            .to(torch.cuda.current_device())
            .squeeze(0)
            for line in tqdm(all_lines)
        ]
        self.data_file = data_file

    def __len__(self):
        return len(self.prompts)

    def __getitem__(self, idx):
        return self.prompts[idx]

    def __repr__(self):
        return f"EasyPromptsDataset(data_file={self.data_file}, prompts_len={len(self.prompts)})"

    def __str__(self):
        return f"EasyPromptsDataset(data_file={self.data_file}, prompts_len={len(self.prompts)})"

class EasyRewardDataset(Dataset):
    def __init__(self, train_file: str, tokenizer: AutoTokenizer, special_token=None, max_length=512) -> None:
        super(EasyRewardDataset, self).__init__()
        self.chosen = []
        self.reject = []
        if special_token is None:
            self.end_token = tokenizer.eos_token
        else:
            self.end_token = special_token
        print(self.end_token)
        # Read every line of train_file; each line is a JSON record with "prompt", "chosen" and "rejected" fields.
        with open(train_file, "r", encoding="UTF-8") as f:
            all_lines = f.readlines()
        for line in tqdm(all_lines):
            data = json.loads(line)
            prompt = "提问:" + data["prompt"] + " 回答:"

            chosen = prompt + data["chosen"] + self.end_token
            chosen_token = tokenizer(
                chosen, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
            )
            self.chosen.append(
                {"input_ids": chosen_token["input_ids"], "attention_mask": chosen_token["attention_mask"]}
            )

            reject = prompt + data["rejected"] + self.end_token
            reject_token = tokenizer(
                reject, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
            )
            self.reject.append(
                {"input_ids": reject_token["input_ids"], "attention_mask": reject_token["attention_mask"]}
            )

    def __len__(self):
        length = len(self.chosen)
        return length

    def __getitem__(self, idx):
        return (
            self.chosen[idx]["input_ids"],
            self.chosen[idx]["attention_mask"],
            self.reject[idx]["input_ids"],
            self.reject[idx]["attention_mask"],
        )

    # Python representation and string representation of the object.
    def __repr__(self):
        return f"EasyRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})"

    def __str__(self):
        return f"EasyRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})"

"""
Easy SFT accepts a plain text file that can be read line by line. By default the dataset groups
texts together up to max_length so the LLM can learn the meaning of the texts better.
If individual lines are not related, just set is_group_texts to False.
"""


class EasySFTDataset(Dataset):
    def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length=512, is_group_texts=True) -> None:
        super().__init__()
        # Read the data_file line by line.
        with open(data_file, "r", encoding="UTF-8") as f:
            # Encode each line and collect the raw Python input_ids lists in raw_input_ids.
            raw_input_ids = []
            for line in f:
                encoded_ids = tokenizer.encode(line)
                # If encoded_ids is longer than max_length, split it into several chunks.
                if len(encoded_ids) > max_length:
                    for i in range(0, len(encoded_ids), max_length):
                        raw_input_ids.append(encoded_ids[i : i + max_length])
                else:
                    raw_input_ids.append(encoded_ids)

        grouped_input_ids = []
        current_input_ids = []
        attention_mask = []
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
        if is_group_texts:
            for input_ids in raw_input_ids:
                if len(current_input_ids) + len(input_ids) > max_length:
                    # Pad current_input_ids to max_length with tokenizer.pad_token_id.
                    padded_length = max_length - len(current_input_ids)
                    current_input_ids.extend([tokenizer.pad_token_id] * padded_length)
                    grouped_input_ids.append(torch.tensor(current_input_ids, dtype=torch.long))
                    attention_mask.append(
                        torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)
                    )
                    current_input_ids = []
                else:
                    current_input_ids.extend(input_ids)
            if len(current_input_ids) > 0:
                padded_length = max_length - len(current_input_ids)
                current_input_ids.extend([tokenizer.pad_token_id] * padded_length)
                grouped_input_ids.append(torch.tensor(current_input_ids, dtype=torch.long))
                attention_mask.append(
                    torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)
                )
        else:
            # Without grouping, just pad each entry of raw_input_ids to max_length.
            for input_ids in raw_input_ids:
                padded_length = max_length - len(input_ids)
                input_ids.extend([tokenizer.pad_token_id] * padded_length)
                attention_mask.append(
                    torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)
                )
                grouped_input_ids.append(torch.tensor(input_ids, dtype=torch.long))
        self.input_ids = grouped_input_ids
        self.labels = copy.deepcopy(self.input_ids)
        self.file_name = data_file
        self.attention_mask = attention_mask

    def __len__(self):
        return len(self.input_ids)

    # Get one item from the dataset.
    def __getitem__(self, idx):
        return dict(input_ids=self.input_ids[idx], labels=self.labels[idx], attention_mask=self.attention_mask[idx])

    # Generate the dataset description printed by print() in Python.
    def __repr__(self):
        return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})"

    # Generate the dataset description printed by print() in Python.
    def __str__(self):
        return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})"
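
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). The tokenizer name and
# data path below are placeholders chosen for illustration, not values used by
# ColossalChat.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer
    dataset = EasySFTDataset("data.txt", tokenizer, max_length=512, is_group_texts=True)
    loader = DataLoader(dataset, batch_size=4, shuffle=True)
    for batch in loader:
        # Each item is a dict with input_ids, labels and attention_mask tensors of
        # length max_length; the default collate_fn stacks them into batches.
        print(batch["input_ids"].shape, batch["attention_mask"].shape)
        break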