[ColossalChat] Update RLHF V2 (#5286)

* Add dpo. Fix sft, ppo, lora. Refactor all

* fix and tested ppo

* 2nd round refactor

* add ci tests

* fix ci

* fix ci

* fix readme, style

* fix readme style

* fix style, fix benchmark

* reproduce benchmark result, remove useless files

* rename to ColossalChat

* use new image

* fix ci workflow

* fix ci

* use local model/tokenizer for ci tests

* fix ci

* fix ci

* fix ci

* fix ci timeout

* fix rm progress bar. fix ci timeout

* fix ci

* fix ci typo

* remove 3d plugin from ci temporarily

* test environment

* cannot save optimizer

* support chat template

* fix readme

* fix path

* test ci locally

* restore build_or_pr

* fix ci data path

* fix benchmark

* fix ci, move ci tests to 3080, disable fast tokenizer

* move ci to 85

* support flash attention 2

* add all-in-one data preparation script. Fix colossal-llama2-chat chat template

* add hardware requirements

* move ci test data

* fix save_model, add unwrap

* fix missing bos

* fix missing bos; support grad accumulation with gemini

* fix ci

* fix ci

* fix ci

* fix llama2 chat template config

* debug sft

* debug sft

* fix colossalai version requirement

* fix ci

* add sanity check to prevent NaN loss

* fix requirements

* add dummy data generation script

* add dummy data generation script

* add dummy data generation script

* add dummy data generation script

* update readme

* update readme

* update readme and ignore

* fix logger bug

* support parallel_output

* modify data preparation logic

* fix tokenization

* update lr

* fix inference

* run pre-commit

---------

Co-authored-by: Tong Li <tong.li352711588@gmail.com>
YeAnbang
2024-03-29 14:12:29 +08:00
committed by GitHub
parent 36c4bb2893
commit df5e9c53cf
200 changed files with 8848 additions and 8049 deletions


@@ -0,0 +1,30 @@
:warning: **This content may be outdated since the major update of ColossalChat. We will update this content soon.**
# Add Peft support for SFT and Prompts model training
The original implementation simply adopts loralib and merges the LoRA layers into the final model. The Hugging Face peft library is a better LoRA implementation that is easier to train and distribute.
Since the reward model is relatively small, I just keep it as in the original implementation. I suggest training the full model to get a proper reward/critic model.
# Preliminary installation
Since the current PyPI peft package (0.2) has some bugs, please install the peft package from source.
```
git clone https://github.com/huggingface/peft
cd peft
pip install .
```
# Usage
For SFT training, just call train_peft_sft.py.
Its arguments are almost identical to those of train_sft.py, except that it adds a new eval_dataset argument in case you have an evaluation data file. The data file is just a plain text file; please check the format in easy_dataset.py.
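For example, a launch might look like the following. This is only a sketch: the torchrun launcher, the number of processes, and all paths are placeholders, while the flags themselves come from the argument parser of train_peft_sft.py.
```
torchrun --standalone --nproc_per_node=4 train_peft_sft.py \
    --model bloom \
    --pretrain bigscience/bloom-560m \
    --dataset /path/to/train_data.txt \
    --eval_dataset /path/to/eval_data.txt \
    --save_path /path/to/peft_sft_output \
    --strategy colossalai_zero2 \
    --batch_size 4 \
    --accumulation_steps 8 \
    --max_epochs 3 \
    --lora_rank 32
```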
For stage-3 RLHF training, call train_peft_prompts.py.
Its arguments are almost identical to those of train_prompts.py. The only difference is that plain text files are used for the prompt and pretrain data. The models are defined in easy_models.py. Currently only BLOOM models are tested, but technically gpt2/opt/llama should be supported.
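A sketch of a launch command follows; the launcher, process count, and all paths are again placeholders, and note that the script currently only accepts bloom as the actor model.
```
torchrun --standalone --nproc_per_node=4 train_peft_prompts.py \
    --model bloom \
    --pretrain bigscience/bloom-560m \
    --sft_lora_path /path/to/peft_sft_output \
    --rm_model bloom \
    --rm_pretrain bigscience/bloom-560m \
    --rm_path /path/to/reward_model.pt \
    --prompt_path /path/to/prompts.txt \
    --pretrain_dataset /path/to/pretrain_data.txt \
    --strategy colossalai_zero2 \
    --num_episodes 10 \
    --train_batch_size 2
```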
# Data format
Please refer to the formats in test_sft.txt, test_prompts.txt and test_pretrained.txt. All of them are plain text files read line by line; for the prompt and pretrain files, easy_dataset.py uses "回答:" ("Answer:") as the separator between a question and its answer.


@@ -0,0 +1,240 @@
import copy
import json
from typing import Dict, Sequence
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoTokenizer
IGNORE_INDEX = -100
def _tokenize_fn(strings: Sequence[str], tokenizer: AutoTokenizer, max_length: int = 512) -> Dict:
"""Tokenize a list of strings."""
tokenized_list = [
tokenizer(
text,
return_tensors="pt",
padding="longest",
max_length=max_length,
truncation=True,
)
for text in strings
]
input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
input_ids_lens = labels_lens = [
tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
]
return dict(
input_ids=input_ids,
labels=labels,
input_ids_lens=input_ids_lens,
labels_lens=labels_lens,
)
def preprocess(sources: Sequence[str], targets: Sequence[str], tokenizer: AutoTokenizer, max_length: int = 512) -> Dict:
"""Preprocess the data by tokenizing."""
examples = [s + t for s, t in zip(sources, targets)]
examples_tokenized, sources_tokenized = [
_tokenize_fn(strings, tokenizer, max_length) for strings in (examples, sources)
]
input_ids = examples_tokenized["input_ids"]
labels = copy.deepcopy(input_ids)
for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
label[:source_len] = IGNORE_INDEX
return dict(input_ids=input_ids, labels=labels)
class EasySupervisedDataset(Dataset):
def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 512) -> None:
super(EasySupervisedDataset, self).__init__()
with open(data_file, "r", encoding="UTF-8") as f:
all_lines = f.readlines()
        # Split each line into source and target: the source is everything up to and including "回答:" ("Answer:"), the target is everything after it.
sources, targets = [], []
for line in all_lines:
if "回答:" in line:
sep_index = line.index("回答:")
sources.append(line[: sep_index + 3])
targets.append(line[sep_index + 3 :] + tokenizer.eos_token)
else:
sources.append(line)
targets.append("" + tokenizer.eos_token)
data_dict = preprocess(sources, targets, tokenizer, max_length)
self.input_ids = data_dict["input_ids"]
self.labels = data_dict["labels"]
self.data_file = data_file
def __len__(self):
return len(self.input_ids)
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
return dict(input_ids=self.input_ids[i], labels=self.labels[i])
def __repr__(self):
return f"LawSupervisedDataset(data_file={self.data_file}, input_ids_len={len(self.input_ids)}, labels_len={len(self.labels)})"
def __str__(self):
return f"LawSupervisedDataset(data_file={self.data_file}, input_ids_len={len(self.input_ids)}, labels_len={len(self.labels)})"
class EasyPromptsDataset(Dataset):
def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 96) -> None:
super(EasyPromptsDataset, self).__init__()
with open(data_file, "r", encoding="UTF-8") as f:
all_lines = f.readlines()
all_lines = [line if "回答:" not in line else line[: line.index("回答:") + 3] for line in all_lines]
self.prompts = [
tokenizer(line, return_tensors="pt", max_length=max_length, padding="max_length", truncation=True)[
"input_ids"
]
.to(torch.cuda.current_device())
.squeeze(0)
for line in tqdm(all_lines)
]
self.data_file = data_file
def __len__(self):
return len(self.prompts)
def __getitem__(self, idx):
return self.prompts[idx]
def __repr__(self):
return f"LawPromptsDataset(data_file={self.data_file}, prompts_len={len(self.prompts)})"
def __str__(self):
return f"LawPromptsDataset(data_file={self.data_file}, prompts_len={len(self.prompts)})"
class EasyRewardDataset(Dataset):
def __init__(self, train_file: str, tokenizer: AutoTokenizer, special_token=None, max_length=512) -> None:
super(EasyRewardDataset, self).__init__()
self.chosen = []
self.reject = []
if special_token is None:
self.end_token = tokenizer.eos_token
else:
self.end_token = special_token
print(self.end_token)
# read all lines in the train_file to a list
with open(train_file, "r", encoding="UTF-8") as f:
all_lines = f.readlines()
for line in tqdm(all_lines):
data = json.loads(line)
prompt = "提问:" + data["prompt"] + " 回答:"
chosen = prompt + data["chosen"] + self.end_token
chosen_token = tokenizer(
chosen, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
)
self.chosen.append(
{"input_ids": chosen_token["input_ids"], "attention_mask": chosen_token["attention_mask"]}
)
reject = prompt + data["rejected"] + self.end_token
reject_token = tokenizer(
reject, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
)
self.reject.append(
{"input_ids": reject_token["input_ids"], "attention_mask": reject_token["attention_mask"]}
)
def __len__(self):
length = len(self.chosen)
return length
def __getitem__(self, idx):
return (
self.chosen[idx]["input_ids"],
self.chosen[idx]["attention_mask"],
self.reject[idx]["input_ids"],
self.reject[idx]["attention_mask"],
)
# python representation of the object and the string representation of the object
def __repr__(self):
return f"LawRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})"
def __str__(self):
return f"LawRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})"
"""
Easy SFT just accept a text file which can be read line by line. However the datasets will group texts together to max_length so LLM will learn the texts meaning better.
If individual lines are not related, just set is_group_texts to False.
"""
class EasySFTDataset(Dataset):
def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length=512, is_group_texts=True) -> None:
super().__init__()
# read the data_file line by line
with open(data_file, "r", encoding="UTF-8") as f:
# encode the text data line by line and put raw python list input_ids only to raw_input_ids list
raw_input_ids = []
for line in f:
encoded_ids = tokenizer.encode(line)
# if the encoded_ids is longer than max_length, then split it into several parts
if len(encoded_ids) > max_length:
for i in range(0, len(encoded_ids), max_length):
raw_input_ids.append(encoded_ids[i : i + max_length])
else:
raw_input_ids.append(encoded_ids)
grouped_input_ids = []
current_input_ids = []
attention_mask = []
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = tokenizer.eos_token_id
if is_group_texts:
for input_ids in raw_input_ids:
if len(current_input_ids) + len(input_ids) > max_length:
# pad the current_input_ids to max_length with tokenizer.pad_token_id
padded_length = max_length - len(current_input_ids)
current_input_ids.extend([tokenizer.pad_token_id] * padded_length)
grouped_input_ids.append(torch.tensor(current_input_ids, dtype=torch.long))
attention_mask.append(
torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)
)
current_input_ids = []
else:
current_input_ids.extend(input_ids)
if len(current_input_ids) > 0:
padded_length = max_length - len(current_input_ids)
current_input_ids.extend([tokenizer.pad_token_id] * padded_length)
grouped_input_ids.append(torch.tensor(current_input_ids, dtype=torch.long))
attention_mask.append(
torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)
)
else:
            # no grouping: pad each line's input_ids to max_length individually
for input_ids in raw_input_ids:
padded_length = max_length - len(input_ids)
input_ids.extend([tokenizer.pad_token_id] * padded_length)
attention_mask.append(
torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)
)
grouped_input_ids.append(torch.tensor(input_ids, dtype=torch.long))
self.input_ids = grouped_input_ids
self.labels = copy.deepcopy(self.input_ids)
self.file_name = data_file
self.attention_mask = attention_mask
def __len__(self):
return len(self.input_ids)
# get item from dataset
def __getitem__(self, idx):
return dict(input_ids=self.input_ids[idx], labels=self.labels[idx], attention_mask=self.attention_mask[idx])
# generate the dataset description to be printed by print in python
def __repr__(self):
return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})"
# generate the dataset description to be printed by print in python
def __str__(self):
return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})"


@@ -0,0 +1,93 @@
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from coati.models.generation import generate
from coati.models.utils import log_probs_from_logits
from peft import PeftModel
from torch.nn.modules import Module
from transformers import BloomConfig, BloomForCausalLM
class Actor(Module):
"""
Actor model base class.
Args:
model (nn.Module): Actor Model.
"""
def __init__(self, model: nn.Module) -> None:
super().__init__()
self.model = model
@torch.no_grad()
def generate(
self, input_ids: torch.Tensor, return_action_mask: bool = True, **kwargs
) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
sequences = generate(self.model, input_ids, **kwargs)
attention_mask = None
pad_token_id = kwargs.get("pad_token_id", None)
if pad_token_id is not None:
attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
if not return_action_mask:
return sequences, attention_mask, None
input_len = input_ids.size(1)
eos_token_id = kwargs.get("eos_token_id", None)
if eos_token_id is None:
action_mask = torch.ones_like(sequences, dtype=torch.bool)
else:
# left padding may be applied, only mask action
action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input
action_mask[:, :input_len] = False
action_mask = action_mask[:, 1:]
return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len) :]
def forward(
self, sequences: torch.LongTensor, num_actions: int, attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""Returns action log probs"""
output = self.model(sequences, attention_mask=attention_mask)
logits = output["logits"]
log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
return log_probs[:, -num_actions:]
def get_base_model(self):
return self.model
class BLOOMActor(Actor):
"""
BLOOM Actor model.
Args:
pretrained (str): Pretrained model name or path.
config (BloomConfig): Model config.
checkpoint (bool): Enable gradient checkpointing.
        lora_path (str): Path to a saved PEFT LoRA adapter to load on top of the base model.
"""
def __init__(
self,
pretrained: str = None,
config: Optional[BloomConfig] = None,
checkpoint: bool = False,
lora_path: str = None,
) -> None:
if pretrained is not None:
model = BloomForCausalLM.from_pretrained(pretrained)
elif config is not None:
model = BloomForCausalLM(config)
else:
model = BloomForCausalLM(BloomConfig())
if lora_path is not None:
model = PeftModel.from_pretrained(model, lora_path)
if checkpoint:
model.gradient_checkpointing_enable()
super().__init__(model)
def print_trainable_parameters(self):
self.get_base_model().print_trainable_parameters()


@@ -0,0 +1,224 @@
import argparse
import torch
import torch.distributed as dist
from coati.dataset import DataCollatorForSupervisedDataset
from coati.models.bloom import BLOOMRM, BLOOMCritic
from coati.models.gpt import GPTRM, GPTCritic
from coati.models.llama import LlamaCritic, LlamaRM
from coati.models.opt import OPTRM, OPTCritic
from coati.trainer import PPOTrainer
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
from easy_dataset import EasyPromptsDataset, EasySupervisedDataset
from easy_models import BLOOMActor
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer
from colossalai.nn.optimizer import HybridAdam
def main(args):
# configure strategy
if args.strategy == "ddp":
strategy = DDPStrategy()
elif args.strategy == "colossalai_gemini":
strategy = GeminiStrategy(
placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5
)
elif args.strategy == "colossalai_zero2":
strategy = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
else:
raise ValueError(f'Unsupported strategy "{args.strategy}"')
if args.rm_path is not None:
state_dict = torch.load(args.rm_path, map_location="cpu")
# configure model
if args.model == "bloom":
# initial_model = BLOOMActor(pretrained=args.pretrain)
print("Using peft lora to load Bloom model as initial_model")
initial_model = BLOOMActor(pretrained=args.pretrain, lora_path=args.sft_lora_path)
print("Using peft lora to load Bloom model as initial_model (Done)")
else:
raise ValueError(f'Unsupported actor model "{args.model}"')
    if args.rm_model is None:
rm_model_name = args.model
else:
rm_model_name = args.rm_model
if rm_model_name == "gpt2":
reward_model = GPTRM(pretrained=args.rm_pretrain)
elif rm_model_name == "bloom":
print("load bloom reward model ", args.rm_pretrain)
reward_model = BLOOMRM(pretrained=args.rm_pretrain)
elif rm_model_name == "opt":
reward_model = OPTRM(pretrained=args.rm_pretrain)
elif rm_model_name == "llama":
reward_model = LlamaRM(pretrained=args.rm_pretrain)
else:
raise ValueError(f'Unsupported reward model "{rm_model_name}"')
if args.rm_path is not None:
print("Loading reward model from", args.rm_path)
reward_model.load_state_dict(state_dict)
if args.strategy != "colossalai_gemini":
initial_model.to(torch.float16).to(torch.cuda.current_device())
reward_model.to(torch.float16).to(torch.cuda.current_device())
with strategy.model_init_context():
if args.model == "bloom":
# actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
print("Using peft lora to load Bloom model as Actor")
actor = BLOOMActor(pretrained=args.pretrain, lora_path=args.sft_lora_path)
print("Using peft lora to load Bloom model as Actor (Done)")
else:
raise ValueError(f'Unsupported actor model "{args.model}"')
if rm_model_name == "gpt2":
critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
elif rm_model_name == "bloom":
print("load bloom critic ", args.rm_pretrain, " lora_rank ", args.lora_rank, " use_action_mask ", True)
critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
print("load bloom critic (Done) ")
elif rm_model_name == "opt":
critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
elif rm_model_name == "llama":
critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
else:
raise ValueError(f'Unsupported reward model "{rm_model_name}"')
if args.rm_path is not None:
print("Loading reward model from", args.rm_path)
critic.load_state_dict(state_dict)
del state_dict
if args.strategy != "colossalai_gemini":
critic.to(torch.float16).to(torch.cuda.current_device())
actor.to(torch.float16).to(torch.cuda.current_device())
# configure optimizer
if args.strategy.startswith("colossalai"):
actor_optim = HybridAdam(actor.parameters(), lr=1e-7)
critic_optim = HybridAdam(critic.parameters(), lr=1e-7)
else:
actor_optim = Adam(actor.parameters(), lr=1e-7)
critic_optim = Adam(critic.parameters(), lr=1e-7)
# configure tokenizer
if args.model == "gpt2":
tokenizer = GPT2Tokenizer.from_pretrained(args.rm_pretrain)
tokenizer.pad_token = tokenizer.eos_token
elif args.model == "bloom":
tokenizer = BloomTokenizerFast.from_pretrained(args.rm_pretrain)
tokenizer.pad_token = tokenizer.eos_token
elif args.model == "opt":
tokenizer = AutoTokenizer.from_pretrained(args.rm_pretrain)
tokenizer.pad_token = tokenizer.eos_token
elif args.model == "llama":
tokenizer = LlamaTokenizer.from_pretrained(args.pretrain)
tokenizer.eos_token = "</s>"
tokenizer.pad_token = tokenizer.unk_token
else:
raise ValueError(f'Unsupported model "{args.model}"')
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
prompt_dataset = EasyPromptsDataset(args.prompt_path, tokenizer)
if dist.is_initialized() and dist.get_world_size() > 1:
prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True)
else:
prompt_sampler = None
prompt_dataloader = DataLoader(
prompt_dataset, shuffle=(prompt_sampler is None), sampler=prompt_sampler, batch_size=args.train_batch_size
)
pretrain_dataset = EasySupervisedDataset(args.pretrain_dataset, tokenizer)
if dist.is_initialized() and dist.get_world_size() > 1:
pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True)
else:
pretrain_sampler = None
pretrain_dataloader = DataLoader(
pretrain_dataset,
shuffle=(pretrain_sampler is None),
sampler=pretrain_sampler,
batch_size=args.ptx_batch_size,
collate_fn=data_collator,
)
def tokenize_fn(texts):
        # MUST pad to max length to ensure inputs on all ranks have the same length.
        # Different lengths may cause a hang when using gemini, because ranks would take different numbers of generation steps.
batch = tokenizer(texts, return_tensors="pt", max_length=96, padding="max_length", truncation=True)
return {k: v.to(torch.cuda.current_device()) for k, v in batch.items()}
(actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim))
# configure trainer
trainer = PPOTrainer(
strategy,
actor,
critic,
reward_model,
initial_model,
actor_optim,
critic_optim,
kl_coef=args.kl_coef,
ptx_coef=args.ptx_coef,
train_batch_size=args.train_batch_size,
experience_batch_size=args.experience_batch_size,
tokenizer=tokenize_fn,
max_length=512,
do_sample=True,
temperature=1.0,
top_k=50,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
)
trainer.fit(
prompt_dataloader=prompt_dataloader,
pretrain_dataloader=pretrain_dataloader,
num_episodes=args.num_episodes,
num_update_steps=args.num_update_steps,
num_collect_steps=args.num_collect_steps,
)
# save model checkpoint after fitting
trainer.save_model(args.save_path, only_rank0=True, tokenizer=tokenizer)
# save optimizer checkpoint on all ranks
if args.need_optim_ckpt:
strategy.save_optimizer(
actor_optim, "actor_optim_checkpoint_prompts_%d.pt" % (torch.cuda.current_device()), only_rank0=False
)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--prompt_path", type=str, default=None, help="path to the prompt dataset")
parser.add_argument("--pretrain_dataset", type=str, default=None, help="path to the pretrained dataset")
parser.add_argument(
"--strategy", choices=["ddp", "colossalai_gemini", "colossalai_zero2"], default="ddp", help="strategy to use"
)
parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument("--sft_lora_path", type=str, default=None)
parser.add_argument("--rm_model", default=None, choices=["gpt2", "bloom", "opt", "llama"])
parser.add_argument("--rm_path", type=str, default=None)
parser.add_argument("--rm_pretrain", type=str, default=None)
parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts")
parser.add_argument("--need_optim_ckpt", type=bool, default=False)
parser.add_argument("--num_episodes", type=int, default=10)
parser.add_argument("--num_collect_steps", type=int, default=10)
parser.add_argument("--num_update_steps", type=int, default=5)
parser.add_argument("--train_batch_size", type=int, default=2)
parser.add_argument("--ptx_batch_size", type=int, default=1)
parser.add_argument("--experience_batch_size", type=int, default=8)
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument("--kl_coef", type=float, default=0.1)
parser.add_argument("--ptx_coef", type=float, default=0.9)
args = parser.parse_args()
main(args)


@@ -0,0 +1,185 @@
import argparse
import os
import torch
import torch.distributed as dist
from coati.trainer import SFTTrainer
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
from easy_dataset import EasyDataset
from peft import LoraConfig, PeftModel, TaskType, get_peft_model
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoModelForCausalLM, AutoTokenizer, BloomTokenizerFast
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor import ColoParameter
def train(args):
# configure strategy
if args.strategy == "ddp":
strategy = DDPStrategy()
elif args.strategy == "colossalai_gemini":
strategy = GeminiStrategy(placement_policy="static")
elif args.strategy == "colossalai_zero2":
strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
else:
raise ValueError(f'Unsupported strategy "{args.strategy}"')
# configure model
with strategy.model_init_context():
print("Warning: currently only bloom is tested, gpt2,llama and opt are not tested")
model = AutoModelForCausalLM.from_pretrained(args.pretrain).to(torch.cuda.current_device())
    # if args.save_path already contains a saved PEFT adapter (adapter_config.json and adapter_model.bin), load it
if (
os.path.exists(args.save_path)
and os.path.exists(args.save_path + "/adapter_config.json")
and os.path.exists(args.save_path + "/adapter_model.bin")
):
print("loading from saved peft model ", args.save_path)
model = PeftModel.from_pretrained(model, args.save_path)
else:
        # otherwise create a new PEFT LoRA adapter
lora_rank = args.lora_rank if args.lora_rank > 0 else 32
# config lora with rank of lora_rank
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM, inference_mode=False, r=lora_rank, lora_alpha=32, lora_dropout=0.1
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# configure tokenizer
if args.model == "gpt2":
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
elif args.model == "bloom":
tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m")
tokenizer.pad_token = tokenizer.eos_token
elif args.model == "opt":
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
tokenizer.pad_token = tokenizer.eos_token
elif args.model == "llama":
tokenizer = AutoTokenizer.from_pretrained(
args.pretrain,
padding_side="right",
use_fast=False,
)
tokenizer.eos_token = "</s>"
tokenizer.pad_token = tokenizer.unk_token
else:
raise ValueError(f'Unsupported model "{args.model}"')
if args.model == "llama" and args.strategy == "colossalai_gemini":
# this is a hack to deal with the resized embedding
# to make sure all parameters are ColoParameter for Colossal-AI Gemini Compatibility
for name, param in model.named_parameters():
if not isinstance(param, ColoParameter):
sub_module_name = ".".join(name.split(".")[:-1])
weight_name = name.split(".")[-1]
sub_module = model.get_submodule(sub_module_name)
setattr(sub_module, weight_name, ColoParameter(param))
# configure optimizer
if args.strategy.startswith("colossalai"):
optim = HybridAdam(model.parameters(), lr=args.lr, clipping_norm=1.0)
else:
optim = Adam(model.parameters(), lr=args.lr)
logger = get_dist_logger()
logger.set_level("WARNING")
# configure dataset
law_dataset = EasyDataset(args.dataset, tokenizer=tokenizer, is_group_texts=not args.is_short_text)
train_dataset = law_dataset
print(train_dataset)
eval_dataset = None
if args.eval_dataset is not None:
eval_dataset = EasyDataset(args.eval_dataset, tokenizer=tokenizer, is_group_texts=not args.is_short_text)
data_collator = default_collate
if dist.is_initialized() and dist.get_world_size() > 1:
train_sampler = DistributedSampler(
train_dataset,
shuffle=True,
seed=42,
drop_last=True,
rank=dist.get_rank(),
num_replicas=dist.get_world_size(),
)
if eval_dataset is not None:
eval_sampler = DistributedSampler(
eval_dataset,
shuffle=False,
seed=42,
drop_last=False,
rank=dist.get_rank(),
num_replicas=dist.get_world_size(),
)
else:
train_sampler = None
eval_sampler = None
train_dataloader = DataLoader(
train_dataset,
shuffle=(train_sampler is None),
sampler=train_sampler,
batch_size=args.batch_size,
collate_fn=data_collator,
pin_memory=True,
)
if eval_dataset is not None:
eval_dataloader = DataLoader(
eval_dataset,
shuffle=(eval_sampler is None),
sampler=eval_sampler,
batch_size=args.batch_size,
collate_fn=data_collator,
pin_memory=True,
)
else:
eval_dataloader = None
trainer = SFTTrainer(
model=model,
strategy=strategy,
optim=optim,
train_dataloader=train_dataloader,
eval_dataloader=eval_dataloader,
batch_size=args.batch_size,
max_epochs=args.max_epochs,
accumulation_steps=args.accumulation_steps,
)
trainer.fit(logger=logger, log_interval=args.log_interval)
# save model checkpoint after fitting on only rank0
trainer.save_model(path=args.save_path, only_rank0=True, tokenizer=tokenizer)
# save optimizer checkpoint on all ranks
if args.need_optim_ckpt:
strategy.save_optimizer(
trainer.optimizer, "rm_optim_checkpoint_%d.pt" % (torch.cuda.current_device()), only_rank0=False
)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--strategy", choices=["ddp", "colossalai_gemini", "colossalai_zero2"], default="ddp")
parser.add_argument("--model", choices=["gpt2", "bloom", "opt", "llama"], default="bloom")
parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument("--dataset", type=str, default=None)
parser.add_argument("--eval_dataset", type=str, default=None)
parser.add_argument("--save_path", type=str, default="output")
parser.add_argument("--need_optim_ckpt", type=bool, default=False)
parser.add_argument("--max_epochs", type=int, default=3)
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument("--log_interval", type=int, default=100, help="how many steps to log")
parser.add_argument("--lr", type=float, default=5e-6)
parser.add_argument("--accumulation_steps", type=int, default=8)
parser.add_argument("--enable_peft_lora", action="store_true", default=False)
parser.add_argument("--is_short_text", action="store_true", default=False)
args = parser.parse_args()
train(args)