ColossalAI/applications/ColossalChat/examples/community/peft/easy_dataset.py
YeAnbang df5e9c53cf
[ColossalChat] Update RLHF V2 (#5286)
* Add dpo. Fix sft, ppo, lora. Refactor all

* fix and tested ppo

* 2 nd round refactor

* add ci tests

* fix ci

* fix ci

* fix readme, style

* fix readme style

* fix style, fix benchmark

* reproduce benchmark result, remove useless files

* rename to ColossalChat

* use new image

* fix ci workflow

* fix ci

* use local model/tokenizer for ci tests

* fix ci

* fix ci

* fix ci

* fix ci timeout

* fix rm progress bar. fix ci timeout

* fix ci

* fix ci typo

* remove 3d plugin from ci temporary

* test environment

* cannot save optimizer

* support chat template

* fix readme

* fix path

* test ci locally

* restore build_or_pr

* fix ci data path

* fix benchmark

* fix ci, move ci tests to 3080, disable fast tokenizer

* move ci to 85

* support flash attention 2

* add all-in-one data preparation script. Fix colossal-llama2-chat chat template

* add hardware requirements

* move ci test data

* fix save_model, add unwrap

* fix missing bos

* fix missing bos; support grad accumulation with gemini

* fix ci

* fix ci

* fix ci

* fix llama2 chat template config

* debug sft

* debug sft

* fix colossalai version requirement

* fix ci

* add sanity check to prevent NaN loss

* fix requirements

* add dummy data generation script

* add dummy data generation script

* add dummy data generation script

* add dummy data generation script

* update readme

* update readme

* update readme and ignore

* fix logger bug

* support parallel_output

* modify data preparation logic

* fix tokenization

* update lr

* fix inference

* run pre-commit

---------

Co-authored-by: Tong Li <tong.li352711588@gmail.com>
2024-03-29 14:12:29 +08:00

241 lines
9.9 KiB
Python
Executable File

import copy
import json
from typing import Dict, Sequence
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoTokenizer
IGNORE_INDEX = -100
def _tokenize_fn(strings: Sequence[str], tokenizer: AutoTokenizer, max_length: int = 512) -> Dict:
"""Tokenize a list of strings."""
tokenized_list = [
tokenizer(
text,
return_tensors="pt",
padding="longest",
max_length=max_length,
truncation=True,
)
for text in strings
]
input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
input_ids_lens = labels_lens = [
tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
]
return dict(
input_ids=input_ids,
labels=labels,
input_ids_lens=input_ids_lens,
labels_lens=labels_lens,
)
def preprocess(sources: Sequence[str], targets: Sequence[str], tokenizer: AutoTokenizer, max_length: int = 512) -> Dict:
"""Preprocess the data by tokenizing."""
examples = [s + t for s, t in zip(sources, targets)]
examples_tokenized, sources_tokenized = [
_tokenize_fn(strings, tokenizer, max_length) for strings in (examples, sources)
]
input_ids = examples_tokenized["input_ids"]
labels = copy.deepcopy(input_ids)
for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
label[:source_len] = IGNORE_INDEX
return dict(input_ids=input_ids, labels=labels)
class EasySupervisedDataset(Dataset):
def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 512) -> None:
super(EasySupervisedDataset, self).__init__()
with open(data_file, "r", encoding="UTF-8") as f:
all_lines = f.readlines()
# split to source and target ,source the characters before "回答:" including "回答:", target the characters after "回答:"
sources, targets = [], []
for line in all_lines:
if "回答:" in line:
sep_index = line.index("回答:")
sources.append(line[: sep_index + 3])
targets.append(line[sep_index + 3 :] + tokenizer.eos_token)
else:
sources.append(line)
targets.append("" + tokenizer.eos_token)
data_dict = preprocess(sources, targets, tokenizer, max_length)
self.input_ids = data_dict["input_ids"]
self.labels = data_dict["labels"]
self.data_file = data_file
def __len__(self):
return len(self.input_ids)
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
return dict(input_ids=self.input_ids[i], labels=self.labels[i])
def __repr__(self):
return f"LawSupervisedDataset(data_file={self.data_file}, input_ids_len={len(self.input_ids)}, labels_len={len(self.labels)})"
def __str__(self):
return f"LawSupervisedDataset(data_file={self.data_file}, input_ids_len={len(self.input_ids)}, labels_len={len(self.labels)})"
class EasyPromptsDataset(Dataset):
def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 96) -> None:
super(EasyPromptsDataset, self).__init__()
with open(data_file, "r", encoding="UTF-8") as f:
all_lines = f.readlines()
all_lines = [line if "回答:" not in line else line[: line.index("回答:") + 3] for line in all_lines]
self.prompts = [
tokenizer(line, return_tensors="pt", max_length=max_length, padding="max_length", truncation=True)[
"input_ids"
]
.to(torch.cuda.current_device())
.squeeze(0)
for line in tqdm(all_lines)
]
self.data_file = data_file
def __len__(self):
return len(self.prompts)
def __getitem__(self, idx):
return self.prompts[idx]
def __repr__(self):
return f"LawPromptsDataset(data_file={self.data_file}, prompts_len={len(self.prompts)})"
def __str__(self):
return f"LawPromptsDataset(data_file={self.data_file}, prompts_len={len(self.prompts)})"
class EasyRewardDataset(Dataset):
def __init__(self, train_file: str, tokenizer: AutoTokenizer, special_token=None, max_length=512) -> None:
super(EasyRewardDataset, self).__init__()
self.chosen = []
self.reject = []
if special_token is None:
self.end_token = tokenizer.eos_token
else:
self.end_token = special_token
print(self.end_token)
# read all lines in the train_file to a list
with open(train_file, "r", encoding="UTF-8") as f:
all_lines = f.readlines()
for line in tqdm(all_lines):
data = json.loads(line)
prompt = "提问:" + data["prompt"] + " 回答:"
chosen = prompt + data["chosen"] + self.end_token
chosen_token = tokenizer(
chosen, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
)
self.chosen.append(
{"input_ids": chosen_token["input_ids"], "attention_mask": chosen_token["attention_mask"]}
)
reject = prompt + data["rejected"] + self.end_token
reject_token = tokenizer(
reject, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
)
self.reject.append(
{"input_ids": reject_token["input_ids"], "attention_mask": reject_token["attention_mask"]}
)
def __len__(self):
length = len(self.chosen)
return length
def __getitem__(self, idx):
return (
self.chosen[idx]["input_ids"],
self.chosen[idx]["attention_mask"],
self.reject[idx]["input_ids"],
self.reject[idx]["attention_mask"],
)
# python representation of the object and the string representation of the object
def __repr__(self):
return f"LawRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})"
def __str__(self):
return f"LawRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})"
"""
Easy SFT just accept a text file which can be read line by line. However the datasets will group texts together to max_length so LLM will learn the texts meaning better.
If individual lines are not related, just set is_group_texts to False.
"""
class EasySFTDataset(Dataset):
def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length=512, is_group_texts=True) -> None:
super().__init__()
# read the data_file line by line
with open(data_file, "r", encoding="UTF-8") as f:
# encode the text data line by line and put raw python list input_ids only to raw_input_ids list
raw_input_ids = []
for line in f:
encoded_ids = tokenizer.encode(line)
# if the encoded_ids is longer than max_length, then split it into several parts
if len(encoded_ids) > max_length:
for i in range(0, len(encoded_ids), max_length):
raw_input_ids.append(encoded_ids[i : i + max_length])
else:
raw_input_ids.append(encoded_ids)
grouped_input_ids = []
current_input_ids = []
attention_mask = []
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = tokenizer.eos_token_id
if is_group_texts:
for input_ids in raw_input_ids:
if len(current_input_ids) + len(input_ids) > max_length:
# pad the current_input_ids to max_length with tokenizer.pad_token_id
padded_length = max_length - len(current_input_ids)
current_input_ids.extend([tokenizer.pad_token_id] * padded_length)
grouped_input_ids.append(torch.tensor(current_input_ids, dtype=torch.long))
attention_mask.append(
torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)
)
current_input_ids = []
else:
current_input_ids.extend(input_ids)
if len(current_input_ids) > 0:
padded_length = max_length - len(current_input_ids)
current_input_ids.extend([tokenizer.pad_token_id] * padded_length)
grouped_input_ids.append(torch.tensor(current_input_ids, dtype=torch.long))
attention_mask.append(
torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)
)
else:
# just append the raw_input_ids to max_length
for input_ids in raw_input_ids:
padded_length = max_length - len(input_ids)
input_ids.extend([tokenizer.pad_token_id] * padded_length)
attention_mask.append(
torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)
)
grouped_input_ids.append(torch.tensor(input_ids, dtype=torch.long))
self.input_ids = grouped_input_ids
self.labels = copy.deepcopy(self.input_ids)
self.file_name = data_file
self.attention_mask = attention_mask
def __len__(self):
return len(self.input_ids)
# get item from dataset
def __getitem__(self, idx):
return dict(input_ids=self.input_ids[idx], labels=self.labels[idx], attention_mask=self.attention_mask[idx])
# generate the dataset description to be printed by print in python
def __repr__(self):
return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})"
# generate the dataset description to be printed by print in python
def __str__(self):
return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})"