Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-07 03:52:01 +00:00
[ColossalChat] Update RLHF V2 (#5286)
* Add dpo. Fix sft, ppo, lora. Refactor all
* fix and tested ppo
* 2nd round refactor
* add ci tests
* fix ci
* fix ci
* fix readme, style
* fix readme style
* fix style, fix benchmark
* reproduce benchmark result, remove useless files
* rename to ColossalChat
* use new image
* fix ci workflow
* fix ci
* use local model/tokenizer for ci tests
* fix ci
* fix ci
* fix ci
* fix ci timeout
* fix rm progress bar. fix ci timeout
* fix ci
* fix ci typo
* remove 3d plugin from ci temporary
* test environment
* cannot save optimizer
* support chat template
* fix readme
* fix path
* test ci locally
* restore build_or_pr
* fix ci data path
* fix benchmark
* fix ci, move ci tests to 3080, disable fast tokenizer
* move ci to 85
* support flash attention 2
* add all-in-one data preparation script. Fix colossal-llama2-chat chat template
* add hardware requirements
* move ci test data
* fix save_model, add unwrap
* fix missing bos
* fix missing bos; support grad accumulation with gemini
* fix ci
* fix ci
* fix ci
* fix llama2 chat template config
* debug sft
* debug sft
* fix colossalai version requirement
* fix ci
* add sanity check to prevent NaN loss
* fix requirements
* add dummy data generation script
* add dummy data generation script
* add dummy data generation script
* add dummy data generation script
* update readme
* update readme
* update readme and ignore
* fix logger bug
* support parallel_output
* modify data preparation logic
* fix tokenization
* update lr
* fix inference
* run pre-commit

---------

Co-authored-by: Tong Li <tong.li352711588@gmail.com>
30 applications/ColossalChat/examples/community/peft/README.md Executable file
@@ -0,0 +1,30 @@

:warning: **This content may be outdated since the major update of Colossal Chat. We will update this content soon.**

# Add Peft support for SFT and Prompts model training

The original implementation just adopts loralib and merges the LoRA layers into the final model. The Hugging Face peft library is a better LoRA implementation that can easily be trained and distributed.

Since the reward model is relatively small, I just keep it as the original one. I suggest training the full model to get a proper reward/critic model.

# Preliminary installation

Since the current PyPI peft package (0.2) has some bugs, please install the peft package from source.

```
git clone https://github.com/huggingface/peft
cd peft
pip install .
```

# Usage

For SFT training, just call train_peft_sft.py.

Its arguments are almost identical to those of train_sft.py, except for a new eval_dataset argument for when you have an evaluation dataset file. The data file is a plain text file; please check the format in easy_dataset.py.
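
A minimal launch sketch (single process; the model name and file paths are illustrative, adjust them to your setup):

```
torchrun --standalone --nproc_per_node=1 train_peft_sft.py \
    --model bloom \
    --pretrain bigscience/bloom-560m \
    --dataset test_sft.txt \
    --eval_dataset test_sft.txt \
    --save_path output \
    --strategy ddp \
    --lora_rank 32 \
    --batch_size 4 \
    --max_epochs 3
```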

For stage-3 RLHF training, call train_peft_prompts.py.
Its arguments are almost identical to those of train_prompts.py. The only difference is that I use text files to specify the prompt and pretraining data files. The models are included in easy_models.py. Currently only BLOOM models are tested, but technically gpt2/opt/llama should be supported as well.
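
A minimal launch sketch (the checkpoint paths are placeholders; per the note above, only --model bloom is tested):

```
torchrun --standalone --nproc_per_node=1 train_peft_prompts.py \
    --model bloom \
    --pretrain bigscience/bloom-560m \
    --sft_lora_path output \
    --rm_model bloom \
    --rm_pretrain bigscience/bloom-560m \
    --rm_path rm_checkpoint.pt \
    --prompt_path test_prompts.txt \
    --pretrain_dataset test_pretrained.txt \
    --strategy ddp
```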

# Data format

Please refer to the formats in test_sft.txt, test_prompts.txt, test_pretrained.txt.
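
As a rough illustration of what the loaders in easy_dataset.py expect (the sample files above remain the authority): SFT, prompt, and pretraining files are plain text with one sample per line, where "回答:" ("Answer:") separates the prompt from the answer; reward data (see EasyRewardDataset) is one JSON object per line with "prompt", "chosen" and "rejected" fields:

```
提问:什么是LoRA? 回答:LoRA是一种参数高效的微调方法。
{"prompt": "什么是LoRA?", "chosen": "LoRA是一种参数高效的微调方法。", "rejected": "我不知道。"}
```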
240 applications/ColossalChat/examples/community/peft/easy_dataset.py Executable file
@@ -0,0 +1,240 @@
import copy
import json
from typing import Dict, Sequence

import torch
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoTokenizer

IGNORE_INDEX = -100


def _tokenize_fn(strings: Sequence[str], tokenizer: AutoTokenizer, max_length: int = 512) -> Dict:
    """Tokenize a list of strings."""
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=max_length,
            truncation=True,
        )
        for text in strings
    ]
    input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
    input_ids_lens = labels_lens = [
        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
    ]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )


def preprocess(sources: Sequence[str], targets: Sequence[str], tokenizer: AutoTokenizer, max_length: int = 512) -> Dict:
    """Preprocess the data by tokenizing."""
    examples = [s + t for s, t in zip(sources, targets)]
    examples_tokenized, sources_tokenized = [
        _tokenize_fn(strings, tokenizer, max_length) for strings in (examples, sources)
    ]
    input_ids = examples_tokenized["input_ids"]
    labels = copy.deepcopy(input_ids)
    for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
        label[:source_len] = IGNORE_INDEX
    return dict(input_ids=input_ids, labels=labels)

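# A worked sketch of the masking above (hypothetical token ids): if a source
# "提问:X 回答:" tokenizes to [5, 6, 7] (source_len = 3) and the full example
# "提问:X 回答:Y</s>" tokenizes to [5, 6, 7, 8, 9], the labels become
# [-100, -100, -100, 8, 9], so the loss is computed only on the answer tokens.

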
class EasySupervisedDataset(Dataset):
    def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 512) -> None:
        super(EasySupervisedDataset, self).__init__()
        with open(data_file, "r", encoding="UTF-8") as f:
            all_lines = f.readlines()
        # Split each line into source and target: the source is everything up to and
        # including the separator "回答:" ("Answer:"); the target is everything after it.
        sources, targets = [], []
        for line in all_lines:
            if "回答:" in line:
                sep_index = line.index("回答:")
                sources.append(line[: sep_index + 3])
                targets.append(line[sep_index + 3 :] + tokenizer.eos_token)
            else:
                sources.append(line)
                targets.append("" + tokenizer.eos_token)
        data_dict = preprocess(sources, targets, tokenizer, max_length)

        self.input_ids = data_dict["input_ids"]
        self.labels = data_dict["labels"]
        self.data_file = data_file

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        return dict(input_ids=self.input_ids[i], labels=self.labels[i])

    def __repr__(self):
        return f"EasySupervisedDataset(data_file={self.data_file}, input_ids_len={len(self.input_ids)}, labels_len={len(self.labels)})"

    def __str__(self):
        return f"EasySupervisedDataset(data_file={self.data_file}, input_ids_len={len(self.input_ids)}, labels_len={len(self.labels)})"


class EasyPromptsDataset(Dataset):
    def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 96) -> None:
        super(EasyPromptsDataset, self).__init__()
        with open(data_file, "r", encoding="UTF-8") as f:
            all_lines = f.readlines()
            # keep only the prompt part of each line, i.e. everything up to and including "回答:"
            all_lines = [line if "回答:" not in line else line[: line.index("回答:") + 3] for line in all_lines]
        self.prompts = [
            tokenizer(line, return_tensors="pt", max_length=max_length, padding="max_length", truncation=True)[
                "input_ids"
            ]
            .to(torch.cuda.current_device())
            .squeeze(0)
            for line in tqdm(all_lines)
        ]
        self.data_file = data_file

    def __len__(self):
        return len(self.prompts)

    def __getitem__(self, idx):
        return self.prompts[idx]

    def __repr__(self):
        return f"EasyPromptsDataset(data_file={self.data_file}, prompts_len={len(self.prompts)})"

    def __str__(self):
        return f"EasyPromptsDataset(data_file={self.data_file}, prompts_len={len(self.prompts)})"


class EasyRewardDataset(Dataset):
    def __init__(self, train_file: str, tokenizer: AutoTokenizer, special_token=None, max_length=512) -> None:
        super(EasyRewardDataset, self).__init__()
        self.chosen = []
        self.reject = []
        if special_token is None:
            self.end_token = tokenizer.eos_token
        else:
            self.end_token = special_token
        print(self.end_token)
        # read all lines in the train_file into a list
        with open(train_file, "r", encoding="UTF-8") as f:
            all_lines = f.readlines()
        for line in tqdm(all_lines):
            data = json.loads(line)
            prompt = "提问:" + data["prompt"] + " 回答:"

            chosen = prompt + data["chosen"] + self.end_token
            chosen_token = tokenizer(
                chosen, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
            )
            self.chosen.append(
                {"input_ids": chosen_token["input_ids"], "attention_mask": chosen_token["attention_mask"]}
            )

            reject = prompt + data["rejected"] + self.end_token
            reject_token = tokenizer(
                reject, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
            )
            self.reject.append(
                {"input_ids": reject_token["input_ids"], "attention_mask": reject_token["attention_mask"]}
            )

    def __len__(self):
        length = len(self.chosen)
        return length

    def __getitem__(self, idx):
        return (
            self.chosen[idx]["input_ids"],
            self.chosen[idx]["attention_mask"],
            self.reject[idx]["input_ids"],
            self.reject[idx]["attention_mask"],
        )

    # python representation of the object and the string representation of the object
    def __repr__(self):
        return f"EasyRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})"

    def __str__(self):
        return f"EasyRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})"


"""
|
||||
Easy SFT just accept a text file which can be read line by line. However the datasets will group texts together to max_length so LLM will learn the texts meaning better.
|
||||
If individual lines are not related, just set is_group_texts to False.
|
||||
"""


class EasySFTDataset(Dataset):
    def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length=512, is_group_texts=True) -> None:
        super().__init__()
        # read the data file line by line
        with open(data_file, "r", encoding="UTF-8") as f:
            # encode each line and collect the raw python input_ids lists in raw_input_ids
            raw_input_ids = []
            for line in f:
                encoded_ids = tokenizer.encode(line)
                # if encoded_ids is longer than max_length, split it into several chunks
                if len(encoded_ids) > max_length:
                    for i in range(0, len(encoded_ids), max_length):
                        raw_input_ids.append(encoded_ids[i : i + max_length])
                else:
                    raw_input_ids.append(encoded_ids)

        grouped_input_ids = []
        current_input_ids = []
        attention_mask = []
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
        if is_group_texts:
            for input_ids in raw_input_ids:
                if len(current_input_ids) + len(input_ids) > max_length:
                    # pad current_input_ids to max_length with tokenizer.pad_token_id and flush it
                    padded_length = max_length - len(current_input_ids)
                    current_input_ids.extend([tokenizer.pad_token_id] * padded_length)
                    grouped_input_ids.append(torch.tensor(current_input_ids, dtype=torch.long))
                    attention_mask.append(
                        torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)
                    )
                    # start the next group with the line that did not fit
                    current_input_ids = list(input_ids)
                else:
                    current_input_ids.extend(input_ids)
            if len(current_input_ids) > 0:
                padded_length = max_length - len(current_input_ids)
                current_input_ids.extend([tokenizer.pad_token_id] * padded_length)
                grouped_input_ids.append(torch.tensor(current_input_ids, dtype=torch.long))
                attention_mask.append(
                    torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)
                )
        else:
            # no grouping: just pad each line's input_ids to max_length
            for input_ids in raw_input_ids:
                padded_length = max_length - len(input_ids)
                input_ids.extend([tokenizer.pad_token_id] * padded_length)
                attention_mask.append(
                    torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)
                )
                grouped_input_ids.append(torch.tensor(input_ids, dtype=torch.long))
        self.input_ids = grouped_input_ids
        self.labels = copy.deepcopy(self.input_ids)
        self.file_name = data_file
        self.attention_mask = attention_mask

    def __len__(self):
        return len(self.input_ids)

    # get an item from the dataset
    def __getitem__(self, idx):
        return dict(input_ids=self.input_ids[idx], labels=self.labels[idx], attention_mask=self.attention_mask[idx])

    # generate the dataset description printed by print in python
    def __repr__(self):
        return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})"

    def __str__(self):
        return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})"
93 applications/ColossalChat/examples/community/peft/easy_models.py Executable file
@@ -0,0 +1,93 @@
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from coati.models.generation import generate
from coati.models.utils import log_probs_from_logits
from peft import PeftModel
from torch.nn.modules import Module
from transformers import BloomConfig, BloomForCausalLM


class Actor(Module):
    """
    Actor model base class.

    Args:
        model (nn.Module): Actor Model.
    """

    def __init__(self, model: nn.Module) -> None:
        super().__init__()
        self.model = model

    @torch.no_grad()
    def generate(
        self, input_ids: torch.Tensor, return_action_mask: bool = True, **kwargs
    ) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
        sequences = generate(self.model, input_ids, **kwargs)
        attention_mask = None
        pad_token_id = kwargs.get("pad_token_id", None)
        if pad_token_id is not None:
            attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
        if not return_action_mask:
            return sequences, attention_mask, None
        input_len = input_ids.size(1)
        eos_token_id = kwargs.get("eos_token_id", None)
        if eos_token_id is None:
            action_mask = torch.ones_like(sequences, dtype=torch.bool)
        else:
            # left padding may be applied; only mask the action part
            action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
            action_mask = F.pad(action_mask, (1 + input_len, -1), value=True)  # include eos token and input
        action_mask[:, :input_len] = False
        action_mask = action_mask[:, 1:]
        return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len) :]

    def forward(
        self, sequences: torch.LongTensor, num_actions: int, attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Returns action log probs"""
        output = self.model(sequences, attention_mask=attention_mask)
        logits = output["logits"]
        log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
        return log_probs[:, -num_actions:]

    def get_base_model(self):
        return self.model


class BLOOMActor(Actor):
    """
    BLOOM Actor model.

    Args:
        pretrained (str): Pretrained model name or path.
        config (BloomConfig): Model config.
        checkpoint (bool): Enable gradient checkpointing.
        lora_path (str): Path to a saved peft LoRA adapter to load.
    """

    def __init__(
        self,
        pretrained: str = None,
        config: Optional[BloomConfig] = None,
        checkpoint: bool = False,
        lora_path: str = None,
    ) -> None:
        if pretrained is not None:
            model = BloomForCausalLM.from_pretrained(pretrained)
        elif config is not None:
            model = BloomForCausalLM(config)
        else:
            model = BloomForCausalLM(BloomConfig())
        if lora_path is not None:
            model = PeftModel.from_pretrained(model, lora_path)
        if checkpoint:
            model.gradient_checkpointing_enable()
        super().__init__(model)

    def print_trainable_parameters(self):
        self.get_base_model().print_trainable_parameters()
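
# Usage sketch (hypothetical paths; the tokenizer is assumed to exist): load a
# LoRA-adapted BLOOM actor and sample; generate() returns the sequences, an
# attention mask, and an action mask over the generated tokens.
#     actor = BLOOMActor(pretrained="bigscience/bloom-560m", lora_path="output")
#     seqs, attn_mask, action_mask = actor.generate(
#         input_ids, max_length=512, pad_token_id=tokenizer.pad_token_id,
#         eos_token_id=tokenizer.eos_token_id,
#     )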
224 applications/ColossalChat/examples/community/peft/train_peft_prompts.py Executable file
@@ -0,0 +1,224 @@
import argparse

import torch
import torch.distributed as dist
from coati.dataset import DataCollatorForSupervisedDataset
from coati.models.bloom import BLOOMRM, BLOOMCritic
from coati.models.gpt import GPTRM, GPTCritic
from coati.models.llama import LlamaCritic, LlamaRM
from coati.models.opt import OPTRM, OPTCritic
from coati.trainer import PPOTrainer
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
from easy_dataset import EasyPromptsDataset, EasySupervisedDataset
from easy_models import BLOOMActor
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer

from colossalai.nn.optimizer import HybridAdam


def main(args):
    # configure strategy
    if args.strategy == "ddp":
        strategy = DDPStrategy()
    elif args.strategy == "colossalai_gemini":
        strategy = GeminiStrategy(
            placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5
        )
    elif args.strategy == "colossalai_zero2":
        strategy = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
    else:
        raise ValueError(f'Unsupported strategy "{args.strategy}"')

    if args.rm_path is not None:
        state_dict = torch.load(args.rm_path, map_location="cpu")

    # configure model
    if args.model == "bloom":
        # initial_model = BLOOMActor(pretrained=args.pretrain)
        print("Using peft lora to load Bloom model as initial_model")
        initial_model = BLOOMActor(pretrained=args.pretrain, lora_path=args.sft_lora_path)
        print("Using peft lora to load Bloom model as initial_model (Done)")
    else:
        raise ValueError(f'Unsupported actor model "{args.model}"')

    if args.rm_model is None:
        rm_model_name = args.model
    else:
        rm_model_name = args.rm_model

    if rm_model_name == "gpt2":
        reward_model = GPTRM(pretrained=args.rm_pretrain)
    elif rm_model_name == "bloom":
        print("load bloom reward model ", args.rm_pretrain)
        reward_model = BLOOMRM(pretrained=args.rm_pretrain)
    elif rm_model_name == "opt":
        reward_model = OPTRM(pretrained=args.rm_pretrain)
    elif rm_model_name == "llama":
        reward_model = LlamaRM(pretrained=args.rm_pretrain)
    else:
        raise ValueError(f'Unsupported reward model "{rm_model_name}"')

    if args.rm_path is not None:
        print("Loading reward model from", args.rm_path)
        reward_model.load_state_dict(state_dict)

    if args.strategy != "colossalai_gemini":
        initial_model.to(torch.float16).to(torch.cuda.current_device())
        reward_model.to(torch.float16).to(torch.cuda.current_device())

    with strategy.model_init_context():
        if args.model == "bloom":
            # actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
            print("Using peft lora to load Bloom model as Actor")
            actor = BLOOMActor(pretrained=args.pretrain, lora_path=args.sft_lora_path)
            print("Using peft lora to load Bloom model as Actor (Done)")
        else:
            raise ValueError(f'Unsupported actor model "{args.model}"')

        if rm_model_name == "gpt2":
            critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
        elif rm_model_name == "bloom":
            print("load bloom critic ", args.rm_pretrain, " lora_rank ", args.lora_rank, " use_action_mask ", True)
            critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
            print("load bloom critic (Done) ")
        elif rm_model_name == "opt":
            critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
        elif rm_model_name == "llama":
            critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
        else:
            raise ValueError(f'Unsupported reward model "{rm_model_name}"')

    if args.rm_path is not None:
        print("Loading reward model from", args.rm_path)
        critic.load_state_dict(state_dict)
        del state_dict

    if args.strategy != "colossalai_gemini":
        critic.to(torch.float16).to(torch.cuda.current_device())
        actor.to(torch.float16).to(torch.cuda.current_device())

    # configure optimizer
    if args.strategy.startswith("colossalai"):
        actor_optim = HybridAdam(actor.parameters(), lr=1e-7)
        critic_optim = HybridAdam(critic.parameters(), lr=1e-7)
    else:
        actor_optim = Adam(actor.parameters(), lr=1e-7)
        critic_optim = Adam(critic.parameters(), lr=1e-7)

    # configure tokenizer
    if args.model == "gpt2":
        tokenizer = GPT2Tokenizer.from_pretrained(args.rm_pretrain)
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == "bloom":
        tokenizer = BloomTokenizerFast.from_pretrained(args.rm_pretrain)
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == "opt":
        tokenizer = AutoTokenizer.from_pretrained(args.rm_pretrain)
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == "llama":
        tokenizer = LlamaTokenizer.from_pretrained(args.pretrain)
        tokenizer.eos_token = "</s>"
        tokenizer.pad_token = tokenizer.unk_token
    else:
        raise ValueError(f'Unsupported model "{args.model}"')

    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)

    prompt_dataset = EasyPromptsDataset(args.prompt_path, tokenizer)
    if dist.is_initialized() and dist.get_world_size() > 1:
        prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True)
    else:
        prompt_sampler = None
    prompt_dataloader = DataLoader(
        prompt_dataset, shuffle=(prompt_sampler is None), sampler=prompt_sampler, batch_size=args.train_batch_size
    )

    pretrain_dataset = EasySupervisedDataset(args.pretrain_dataset, tokenizer)
    if dist.is_initialized() and dist.get_world_size() > 1:
        pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True)
    else:
        pretrain_sampler = None
    pretrain_dataloader = DataLoader(
        pretrain_dataset,
        shuffle=(pretrain_sampler is None),
        sampler=pretrain_sampler,
        batch_size=args.ptx_batch_size,
        collate_fn=data_collator,
    )

    def tokenize_fn(texts):
        # MUST pad to max length so that inputs on all ranks have the same length.
        # Different lengths may cause a hang with gemini, since ranks would take
        # different numbers of generation steps.
        batch = tokenizer(texts, return_tensors="pt", max_length=96, padding="max_length", truncation=True)
        return {k: v.to(torch.cuda.current_device()) for k, v in batch.items()}

    (actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim))

    # configure trainer
    trainer = PPOTrainer(
        strategy,
        actor,
        critic,
        reward_model,
        initial_model,
        actor_optim,
        critic_optim,
        kl_coef=args.kl_coef,
        ptx_coef=args.ptx_coef,
        train_batch_size=args.train_batch_size,
        experience_batch_size=args.experience_batch_size,
        tokenizer=tokenize_fn,
        max_length=512,
        do_sample=True,
        temperature=1.0,
        top_k=50,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    trainer.fit(
        prompt_dataloader=prompt_dataloader,
        pretrain_dataloader=pretrain_dataloader,
        num_episodes=args.num_episodes,
        num_update_steps=args.num_update_steps,
        num_collect_steps=args.num_collect_steps,
    )

    # save model checkpoint after fitting
    trainer.save_model(args.save_path, only_rank0=True, tokenizer=tokenizer)
    # save optimizer checkpoint on all ranks
    if args.need_optim_ckpt:
        strategy.save_optimizer(
            actor_optim, "actor_optim_checkpoint_prompts_%d.pt" % (torch.cuda.current_device()), only_rank0=False
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt_path", type=str, default=None, help="path to the prompt dataset")
    parser.add_argument("--pretrain_dataset", type=str, default=None, help="path to the pretraining dataset")
    parser.add_argument(
        "--strategy", choices=["ddp", "colossalai_gemini", "colossalai_zero2"], default="ddp", help="strategy to use"
    )
    parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
    parser.add_argument("--pretrain", type=str, default=None)
    parser.add_argument("--sft_lora_path", type=str, default=None)
    parser.add_argument("--rm_model", default=None, choices=["gpt2", "bloom", "opt", "llama"])
    parser.add_argument("--rm_path", type=str, default=None)
    parser.add_argument("--rm_pretrain", type=str, default=None)
    parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts")
    parser.add_argument("--need_optim_ckpt", type=bool, default=False)
    parser.add_argument("--num_episodes", type=int, default=10)
    parser.add_argument("--num_collect_steps", type=int, default=10)
    parser.add_argument("--num_update_steps", type=int, default=5)
    parser.add_argument("--train_batch_size", type=int, default=2)
    parser.add_argument("--ptx_batch_size", type=int, default=1)
    parser.add_argument("--experience_batch_size", type=int, default=8)
    parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
    parser.add_argument("--kl_coef", type=float, default=0.1)
    parser.add_argument("--ptx_coef", type=float, default=0.9)
    args = parser.parse_args()
    main(args)
185 applications/ColossalChat/examples/community/peft/train_peft_sft.py Executable file
@@ -0,0 +1,185 @@
import argparse
import os

import torch
import torch.distributed as dist
from coati.trainer import SFTTrainer
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
from easy_dataset import EasySFTDataset
from peft import LoraConfig, PeftModel, TaskType, get_peft_model
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoModelForCausalLM, AutoTokenizer, BloomTokenizerFast
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer

from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor import ColoParameter


def train(args):
    # configure strategy
    if args.strategy == "ddp":
        strategy = DDPStrategy()
    elif args.strategy == "colossalai_gemini":
        strategy = GeminiStrategy(placement_policy="static")
    elif args.strategy == "colossalai_zero2":
        strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
    else:
        raise ValueError(f'Unsupported strategy "{args.strategy}"')

    # configure model
    with strategy.model_init_context():
        print("Warning: currently only bloom is tested; gpt2, llama and opt are not tested")
        model = AutoModelForCausalLM.from_pretrained(args.pretrain).to(torch.cuda.current_device())
        # if args.save_path exists and contains a saved adapter (adapter_config.json
        # and adapter_model.bin), resume from the saved peft model
        if (
            os.path.exists(args.save_path)
            and os.path.exists(args.save_path + "/adapter_config.json")
            and os.path.exists(args.save_path + "/adapter_model.bin")
        ):
            print("loading from saved peft model ", args.save_path)
            model = PeftModel.from_pretrained(model, args.save_path)
        else:
            # otherwise, wrap the model with a fresh LoRA adapter via the peft library
            lora_rank = args.lora_rank if args.lora_rank > 0 else 32
            # configure lora with rank lora_rank
            lora_config = LoraConfig(
                task_type=TaskType.CAUSAL_LM, inference_mode=False, r=lora_rank, lora_alpha=32, lora_dropout=0.1
            )
            model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()

    # configure tokenizer
    if args.model == "gpt2":
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == "bloom":
        tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m")
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == "opt":
        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == "llama":
        tokenizer = AutoTokenizer.from_pretrained(
            args.pretrain,
            padding_side="right",
            use_fast=False,
        )
        tokenizer.eos_token = "</s>"
        tokenizer.pad_token = tokenizer.unk_token
    else:
        raise ValueError(f'Unsupported model "{args.model}"')

    if args.model == "llama" and args.strategy == "colossalai_gemini":
        # this is a hack to deal with the resized embedding,
        # to make sure all parameters are ColoParameter for Colossal-AI Gemini compatibility
        for name, param in model.named_parameters():
            if not isinstance(param, ColoParameter):
                sub_module_name = ".".join(name.split(".")[:-1])
                weight_name = name.split(".")[-1]
                sub_module = model.get_submodule(sub_module_name)
                setattr(sub_module, weight_name, ColoParameter(param))

    # configure optimizer
    if args.strategy.startswith("colossalai"):
        optim = HybridAdam(model.parameters(), lr=args.lr, clipping_norm=1.0)
    else:
        optim = Adam(model.parameters(), lr=args.lr)

    logger = get_dist_logger()
    logger.set_level("WARNING")

    # configure dataset
    law_dataset = EasySFTDataset(args.dataset, tokenizer=tokenizer, is_group_texts=not args.is_short_text)
    train_dataset = law_dataset
    print(train_dataset)
    eval_dataset = None
    if args.eval_dataset is not None:
        eval_dataset = EasySFTDataset(args.eval_dataset, tokenizer=tokenizer, is_group_texts=not args.is_short_text)
    data_collator = default_collate
    if dist.is_initialized() and dist.get_world_size() > 1:
        train_sampler = DistributedSampler(
            train_dataset,
            shuffle=True,
            seed=42,
            drop_last=True,
            rank=dist.get_rank(),
            num_replicas=dist.get_world_size(),
        )
        if eval_dataset is not None:
            eval_sampler = DistributedSampler(
                eval_dataset,
                shuffle=False,
                seed=42,
                drop_last=False,
                rank=dist.get_rank(),
                num_replicas=dist.get_world_size(),
            )
    else:
        train_sampler = None
        eval_sampler = None

    train_dataloader = DataLoader(
        train_dataset,
        shuffle=(train_sampler is None),
        sampler=train_sampler,
        batch_size=args.batch_size,
        collate_fn=data_collator,
        pin_memory=True,
    )
    if eval_dataset is not None:
        eval_dataloader = DataLoader(
            eval_dataset,
            shuffle=(eval_sampler is None),
            sampler=eval_sampler,
            batch_size=args.batch_size,
            collate_fn=data_collator,
            pin_memory=True,
        )
    else:
        eval_dataloader = None

    trainer = SFTTrainer(
        model=model,
        strategy=strategy,
        optim=optim,
        train_dataloader=train_dataloader,
        eval_dataloader=eval_dataloader,
        batch_size=args.batch_size,
        max_epochs=args.max_epochs,
        accumulation_steps=args.accumulation_steps,
    )

    trainer.fit(logger=logger, log_interval=args.log_interval)

    # save model checkpoint after fitting on only rank0
    trainer.save_model(path=args.save_path, only_rank0=True, tokenizer=tokenizer)
    # save optimizer checkpoint on all ranks
    if args.need_optim_ckpt:
        strategy.save_optimizer(
            trainer.optimizer, "rm_optim_checkpoint_%d.pt" % (torch.cuda.current_device()), only_rank0=False
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--strategy", choices=["ddp", "colossalai_gemini", "colossalai_zero2"], default="ddp")
    parser.add_argument("--model", choices=["gpt2", "bloom", "opt", "llama"], default="bloom")
    parser.add_argument("--pretrain", type=str, default=None)
    parser.add_argument("--dataset", type=str, default=None)
    parser.add_argument("--eval_dataset", type=str, default=None)
    parser.add_argument("--save_path", type=str, default="output")
    parser.add_argument("--need_optim_ckpt", type=bool, default=False)
    parser.add_argument("--max_epochs", type=int, default=3)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
    parser.add_argument("--log_interval", type=int, default=100, help="how many steps to log")
    parser.add_argument("--lr", type=float, default=5e-6)
    parser.add_argument("--accumulation_steps", type=int, default=8)
    parser.add_argument("--enable_peft_lora", action="store_true", default=False)
    parser.add_argument("--is_short_text", action="store_true", default=False)
    args = parser.parse_args()
    train(args)