[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
Author: Hongxin Liu (committed by GitHub)
Date: 2023-09-19 14:20:26 +08:00
Commit: 079bf3cb26 (parent 3c6b831c26)
1268 changed files with 50037 additions and 38444 deletions


@@ -3,7 +3,6 @@ import json
from typing import Dict, Sequence
import torch
from datasets import load_dataset
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoTokenizer
@@ -20,7 +19,8 @@ def _tokenize_fn(strings: Sequence[str], tokenizer: AutoTokenizer, max_length: i
padding="longest",
max_length=max_length,
truncation=True,
) for text in strings
)
for text in strings
]
input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
input_ids_lens = labels_lens = [
@@ -48,18 +48,17 @@ def preprocess(sources: Sequence[str], targets: Sequence[str], tokenizer: AutoTo
class EasySupervisedDataset(Dataset):
def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 512) -> None:
super(EasySupervisedDataset, self).__init__()
with open(data_file, "r", encoding="UTF-8") as f:
all_lines = f.readlines()
#split to source and target ,source the characters before "回答:" including "回答:", target the characters after "回答:"
# split to source and target ,source the characters before "回答:" including "回答:", target the characters after "回答:"
sources, targets = [], []
for line in all_lines:
if "回答:" in line:
sep_index = line.index("回答:")
sources.append(line[:sep_index + 3])
targets.append(line[sep_index + 3:] + tokenizer.eos_token)
sources.append(line[: sep_index + 3])
targets.append(line[sep_index + 3 :] + tokenizer.eos_token)
else:
sources.append(line)
targets.append("" + tokenizer.eos_token)
@@ -83,15 +82,17 @@ class EasySupervisedDataset(Dataset):
class EasyPromptsDataset(Dataset):
def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 96) -> None:
super(EasyPromptsDataset, self).__init__()
with open(data_file, "r", encoding="UTF-8") as f:
all_lines = f.readlines()
all_lines = [line if "回答:" not in line else line[:line.index("回答:") + 3] for line in all_lines]
all_lines = [line if "回答:" not in line else line[: line.index("回答:") + 3] for line in all_lines]
self.prompts = [
tokenizer(line, return_tensors='pt', max_length=max_length, padding='max_length',
truncation=True)['input_ids'].to(torch.cuda.current_device()).squeeze(0)
tokenizer(line, return_tensors="pt", max_length=max_length, padding="max_length", truncation=True)[
"input_ids"
]
.to(torch.cuda.current_device())
.squeeze(0)
for line in tqdm(all_lines)
]
self.data_file = data_file
@@ -110,7 +111,6 @@ class EasyPromptsDataset(Dataset):
class EasyRewardDataset(Dataset):
def __init__(self, train_file: str, tokenizer: AutoTokenizer, special_token=None, max_length=512) -> None:
super(EasyRewardDataset, self).__init__()
self.chosen = []
@@ -120,44 +120,42 @@ class EasyRewardDataset(Dataset):
else:
self.end_token = special_token
print(self.end_token)
#read all lines in the train_file to a list
# read all lines in the train_file to a list
with open(train_file, "r", encoding="UTF-8") as f:
all_lines = f.readlines()
for line in tqdm(all_lines):
data = json.loads(line)
prompt = "提问:" + data['prompt'] + " 回答:"
prompt = "提问:" + data["prompt"] + " 回答:"
chosen = prompt + data['chosen'] + self.end_token
chosen_token = tokenizer(chosen,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt")
self.chosen.append({
"input_ids": chosen_token['input_ids'],
"attention_mask": chosen_token['attention_mask']
})
chosen = prompt + data["chosen"] + self.end_token
chosen_token = tokenizer(
chosen, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
)
self.chosen.append(
{"input_ids": chosen_token["input_ids"], "attention_mask": chosen_token["attention_mask"]}
)
reject = prompt + data['rejected'] + self.end_token
reject_token = tokenizer(reject,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt")
self.reject.append({
"input_ids": reject_token['input_ids'],
"attention_mask": reject_token['attention_mask']
})
reject = prompt + data["rejected"] + self.end_token
reject_token = tokenizer(
reject, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
)
self.reject.append(
{"input_ids": reject_token["input_ids"], "attention_mask": reject_token["attention_mask"]}
)
def __len__(self):
length = len(self.chosen)
return length
def __getitem__(self, idx):
return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][
"input_ids"], self.reject[idx]["attention_mask"]
return (
self.chosen[idx]["input_ids"],
self.chosen[idx]["attention_mask"],
self.reject[idx]["input_ids"],
self.reject[idx]["attention_mask"],
)
#python representation of the object and the string representation of the object
# python representation of the object and the string representation of the object
def __repr__(self):
return f"LawRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})"
@@ -165,26 +163,25 @@ class EasyRewardDataset(Dataset):
return f"LawRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})"
'''
"""
Easy SFT just accept a text file which can be read line by line. However the datasets will group texts together to max_length so LLM will learn the texts meaning better.
If individual lines are not related, just set is_group_texts to False.
'''
"""
class EasySFTDataset(Dataset):
def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length=512, is_group_texts=True) -> None:
super().__init__()
#read the data_file line by line
# read the data_file line by line
with open(data_file, "r", encoding="UTF-8") as f:
#encode the text data line by line and put raw python list input_ids only to raw_input_ids list
# encode the text data line by line and put raw python list input_ids only to raw_input_ids list
raw_input_ids = []
for line in f:
encoded_ids = tokenizer.encode(line)
#if the encoded_ids is longer than max_length, then split it into several parts
# if the encoded_ids is longer than max_length, then split it into several parts
if len(encoded_ids) > max_length:
for i in range(0, len(encoded_ids), max_length):
raw_input_ids.append(encoded_ids[i:i + max_length])
raw_input_ids.append(encoded_ids[i : i + max_length])
else:
raw_input_ids.append(encoded_ids)
@@ -196,12 +193,13 @@ class EasySFTDataset(Dataset):
if is_group_texts:
for input_ids in raw_input_ids:
if len(current_input_ids) + len(input_ids) > max_length:
#pad the current_input_ids to max_length with tokenizer.pad_token_id
# pad the current_input_ids to max_length with tokenizer.pad_token_id
padded_length = max_length - len(current_input_ids)
current_input_ids.extend([tokenizer.pad_token_id] * padded_length)
grouped_input_ids.append(torch.tensor(current_input_ids, dtype=torch.long))
attention_mask.append(
torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long))
torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)
)
current_input_ids = []
else:
current_input_ids.extend(input_ids)
@@ -210,14 +208,16 @@ class EasySFTDataset(Dataset):
current_input_ids.extend([tokenizer.pad_token_id] * padded_length)
grouped_input_ids.append(torch.tensor(current_input_ids, dtype=torch.long))
attention_mask.append(
torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long))
torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)
)
else:
#just append the raw_input_ids to max_length
# just append the raw_input_ids to max_length
for input_ids in raw_input_ids:
padded_length = max_length - len(input_ids)
input_ids.extend([tokenizer.pad_token_id] * padded_length)
attention_mask.append(
torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long))
torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)
)
grouped_input_ids.append(torch.tensor(input_ids, dtype=torch.long))
self.input_ids = grouped_input_ids
self.labels = copy.deepcopy(self.input_ids)
@@ -227,14 +227,14 @@ class EasySFTDataset(Dataset):
def __len__(self):
return len(self.input_ids)
#get item from dataset
# get item from dataset
def __getitem__(self, idx):
return dict(input_ids=self.input_ids[idx], labels=self.labels[idx], attention_mask=self.attention_mask[idx])
#generate the dataset description to be printed by print in python
# generate the dataset description to be printed by print in python
def __repr__(self):
return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})"
#generate the dataset description to be printed by print in python
# generate the dataset description to be printed by print in python
def __str__(self):
return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})"


@@ -4,7 +4,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from coati.models.generation import generate
from coati.models.utils import log_probs_from_logits, masked_mean
from coati.models.utils import log_probs_from_logits
from peft import PeftModel
from torch.nn.modules import Module
from transformers import BloomConfig, BloomForCausalLM
@@ -24,38 +24,33 @@ class Actor(Module):
@torch.no_grad()
def generate(
self,
input_ids: torch.Tensor,
return_action_mask: bool = True,
**kwargs
self, input_ids: torch.Tensor, return_action_mask: bool = True, **kwargs
) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
sequences = generate(self.model, input_ids, **kwargs)
attention_mask = None
pad_token_id = kwargs.get('pad_token_id', None)
pad_token_id = kwargs.get("pad_token_id", None)
if pad_token_id is not None:
attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
if not return_action_mask:
return sequences, attention_mask, None
input_len = input_ids.size(1)
eos_token_id = kwargs.get('eos_token_id', None)
eos_token_id = kwargs.get("eos_token_id", None)
if eos_token_id is None:
action_mask = torch.ones_like(sequences, dtype=torch.bool)
else:
# left padding may be applied, only mask action
action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input
action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input
action_mask[:, :input_len] = False
action_mask = action_mask[:, 1:]
return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len):]
return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len) :]
def forward(self,
sequences: torch.LongTensor,
num_actions: int,
attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
"""Returns action log probs
"""
def forward(
self, sequences: torch.LongTensor, num_actions: int, attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""Returns action log probs"""
output = self.model(sequences, attention_mask=attention_mask)
logits = output['logits']
logits = output["logits"]
log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
return log_probs[:, -num_actions:]
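
The generate method above hides everything the model produces after its first EOS token with a cumulative-sum trick: once an EOS has been seen, the running count of EOS occurrences becomes non-zero. Below is a small standalone demonstration of just that trick with made-up token ids; the full code additionally shifts the mask so that the EOS token itself and the prompt are included, as its inline comment says.

import torch

eos_token_id = 2
# Two hypothetical generated continuations (rows); 2 marks end-of-sequence.
generated = torch.tensor([[5, 7, 2, 9, 9],
                          [4, 4, 4, 2, 2]])
# True until (and excluding) the first EOS in each row, False afterwards.
before_first_eos = (generated == eos_token_id).cumsum(dim=-1) == 0
print(before_first_eos)
# tensor([[ True,  True, False, False, False],
#         [ True,  True,  True, False, False]])
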
@@ -75,11 +70,13 @@ class BLOOMActor(Actor):
lora_train_bias (str): LoRA bias training mode.
"""
def __init__(self,
pretrained: str = None,
config: Optional[BloomConfig] = None,
checkpoint: bool = False,
lora_path: str = None) -> None:
def __init__(
self,
pretrained: str = None,
config: Optional[BloomConfig] = None,
checkpoint: bool = False,
lora_path: str = None,
) -> None:
if pretrained is not None:
model = BloomForCausalLM.from_pretrained(pretrained)
elif config is not None:


@@ -1,18 +1,16 @@
import argparse
import pandas as pd
import torch
import torch.distributed as dist
from coati.dataset import DataCollatorForSupervisedDataset, PromptDataset, SupervisedDataset
from coati.dataset import DataCollatorForSupervisedDataset
from coati.models.bloom import BLOOMRM, BLOOMCritic
from coati.models.gpt import GPTRM, GPTActor, GPTCritic
from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
from coati.models.opt import OPTRM, OPTActor, OPTCritic
from coati.models.gpt import GPTRM, GPTCritic
from coati.models.llama import LlamaCritic, LlamaRM
from coati.models.opt import OPTRM, OPTCritic
from coati.trainer import PPOTrainer
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
from easy_dataset import EasyPromptsDataset, EasySupervisedDataset
from easy_models import BLOOMActor
from peft import PeftModel
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
@@ -23,24 +21,24 @@ from colossalai.nn.optimizer import HybridAdam
def main(args):
# configure strategy
if args.strategy == 'ddp':
if args.strategy == "ddp":
strategy = DDPStrategy()
elif args.strategy == 'colossalai_gemini':
strategy = GeminiStrategy(placement_policy='cpu', initial_scale=2**5)
elif args.strategy == 'colossalai_zero2':
strategy = LowLevelZeroStrategy(stage=2, placement_policy='cpu')
elif args.strategy == "colossalai_gemini":
strategy = GeminiStrategy(placement_policy="cpu", initial_scale=2**5)
elif args.strategy == "colossalai_zero2":
strategy = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
else:
raise ValueError(f'Unsupported strategy "{args.strategy}"')
if args.rm_path is not None:
state_dict = torch.load(args.rm_path, map_location='cpu')
state_dict = torch.load(args.rm_path, map_location="cpu")
# configure model
if args.model == 'bloom':
if args.model == "bloom":
# initial_model = BLOOMActor(pretrained=args.pretrain)
print('Using peft lora to load Bloom model as initial_model')
print("Using peft lora to load Bloom model as initial_model")
initial_model = BLOOMActor(pretrained=args.pretrain, lora_path=args.sft_lora_path)
print('Using peft lora to load Bloom model as initial_model (Done)')
print("Using peft lora to load Bloom model as initial_model (Done)")
else:
raise ValueError(f'Unsupported actor model "{args.model}"')
@@ -49,59 +47,59 @@ def main(args):
else:
rm_model_name = args.rm_model
if rm_model_name == 'gpt2':
if rm_model_name == "gpt2":
reward_model = GPTRM(pretrained=args.rm_pretrain)
elif rm_model_name == 'bloom':
elif rm_model_name == "bloom":
print("load bloom reward model ", args.rm_pretrain)
reward_model = BLOOMRM(pretrained=args.rm_pretrain)
elif rm_model_name == 'opt':
elif rm_model_name == "opt":
reward_model = OPTRM(pretrained=args.rm_pretrain)
elif rm_model_name == 'llama':
elif rm_model_name == "llama":
reward_model = LlamaRM(pretrained=args.rm_pretrain)
else:
raise ValueError(f'Unsupported reward model "{rm_model_name}"')
if args.rm_path is not None:
print('Loading reward model from', args.rm_path)
print("Loading reward model from", args.rm_path)
reward_model.load_state_dict(state_dict)
if args.strategy != 'colossalai_gemini':
if args.strategy != "colossalai_gemini":
initial_model.to(torch.float16).to(torch.cuda.current_device())
reward_model.to(torch.float16).to(torch.cuda.current_device())
with strategy.model_init_context():
if args.model == 'bloom':
if args.model == "bloom":
# actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
print('Using peft lora to load Bloom model as Actor')
print("Using peft lora to load Bloom model as Actor")
actor = BLOOMActor(pretrained=args.pretrain, lora_path=args.sft_lora_path)
print('Using peft lora to load Bloom model as Actor (Done)')
print("Using peft lora to load Bloom model as Actor (Done)")
else:
raise ValueError(f'Unsupported actor model "{args.model}"')
if rm_model_name == 'gpt2':
if rm_model_name == "gpt2":
critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
elif rm_model_name == 'bloom':
elif rm_model_name == "bloom":
print("load bloom critic ", args.rm_pretrain, " lora_rank ", args.lora_rank, " use_action_mask ", True)
critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
print("load bloom critic (Done) ")
elif rm_model_name == 'opt':
elif rm_model_name == "opt":
critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
elif rm_model_name == 'llama':
elif rm_model_name == "llama":
critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
else:
raise ValueError(f'Unsupported reward model "{rm_model_name}"')
if args.rm_path is not None:
print('Loading reward model from', args.rm_path)
print("Loading reward model from", args.rm_path)
critic.load_state_dict(state_dict)
del state_dict
if args.strategy != 'colossalai_gemini':
if args.strategy != "colossalai_gemini":
critic.to(torch.float16).to(torch.cuda.current_device())
actor.to(torch.float16).to(torch.cuda.current_device())
# configure optimizer
if args.strategy.startswith('colossalai'):
if args.strategy.startswith("colossalai"):
actor_optim = HybridAdam(actor.parameters(), lr=1e-7)
critic_optim = HybridAdam(critic.parameters(), lr=1e-7)
else:
@@ -109,18 +107,18 @@ def main(args):
critic_optim = Adam(critic.parameters(), lr=1e-7)
# configure tokenizer
if args.model == 'gpt2':
if args.model == "gpt2":
tokenizer = GPT2Tokenizer.from_pretrained(args.rm_pretrain)
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'bloom':
elif args.model == "bloom":
tokenizer = BloomTokenizerFast.from_pretrained(args.rm_pretrain)
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'opt':
elif args.model == "opt":
tokenizer = AutoTokenizer.from_pretrained(args.rm_pretrain)
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'llama':
elif args.model == "llama":
tokenizer = LlamaTokenizer.from_pretrained(args.pretrain)
tokenizer.eos_token = '<\s>'
tokenizer.eos_token = "<\s>"
tokenizer.pad_token = tokenizer.unk_token
else:
raise ValueError(f'Unsupported model "{args.model}"')
@@ -132,26 +130,27 @@ def main(args):
prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True)
else:
prompt_sampler = None
prompt_dataloader = DataLoader(prompt_dataset,
shuffle=(prompt_sampler is None),
sampler=prompt_sampler,
batch_size=args.train_batch_size)
prompt_dataloader = DataLoader(
prompt_dataset, shuffle=(prompt_sampler is None), sampler=prompt_sampler, batch_size=args.train_batch_size
)
pretrain_dataset = EasySupervisedDataset(args.pretrain_dataset, tokenizer)
if dist.is_initialized() and dist.get_world_size() > 1:
pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True)
else:
pretrain_sampler = None
pretrain_dataloader = DataLoader(pretrain_dataset,
shuffle=(pretrain_sampler is None),
sampler=pretrain_sampler,
batch_size=args.ptx_batch_size,
collate_fn=data_collator)
pretrain_dataloader = DataLoader(
pretrain_dataset,
shuffle=(pretrain_sampler is None),
sampler=pretrain_sampler,
batch_size=args.ptx_batch_size,
collate_fn=data_collator,
)
def tokenize_fn(texts):
# MUST padding to max length to ensure inputs of all ranks have the same length
# Different length may lead to hang when using gemini, as different generation steps
batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True)
batch = tokenizer(texts, return_tensors="pt", max_length=96, padding="max_length", truncation=True)
return {k: v.to(torch.cuda.current_device()) for k, v in batch.items()}
(actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim))
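
The comment inside tokenize_fn above explains why padding="max_length" is required: every rank must produce inputs of identical shape, otherwise generation can take a different number of steps on different ranks and hang under the Gemini strategy. The toy function below (no tokenizer or GPU involved, values invented) merely illustrates that fixed-length padding makes batch shapes independent of the raw prompt lengths.

def pad_to_fixed_length(token_ids, max_length, pad_id=0):
    """Truncate or right-pad a list of token ids to exactly max_length entries."""
    return (token_ids[:max_length] + [pad_id] * max_length)[:max_length]


rank0_inputs = [pad_to_fixed_length([11, 12, 13], 8) for _ in range(2)]
rank1_inputs = [pad_to_fixed_length([21, 22, 23, 24, 25, 26], 8) for _ in range(2)]
# Both ranks now hold 2 x 8 inputs even though their raw prompts differ in length.
assert {len(row) for row in rank0_inputs + rank1_inputs} == {8}
print(rank0_inputs[0])  # [11, 12, 13, 0, 0, 0, 0, 0]
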
@@ -178,45 +177,46 @@ def main(args):
eos_token_id=tokenizer.eos_token_id,
)
trainer.fit(prompt_dataloader=prompt_dataloader,
pretrain_dataloader=pretrain_dataloader,
num_episodes=args.num_episodes,
num_update_steps=args.num_update_steps,
num_collect_steps=args.num_collect_steps)
trainer.fit(
prompt_dataloader=prompt_dataloader,
pretrain_dataloader=pretrain_dataloader,
num_episodes=args.num_episodes,
num_update_steps=args.num_update_steps,
num_collect_steps=args.num_collect_steps,
)
# save model checkpoint after fitting
trainer.save_model(args.save_path, only_rank0=True, tokenizer=tokenizer)
# save optimizer checkpoint on all ranks
if args.need_optim_ckpt:
strategy.save_optimizer(actor_optim,
'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
only_rank0=False)
strategy.save_optimizer(
actor_optim, "actor_optim_checkpoint_prompts_%d.pt" % (torch.cuda.current_device()), only_rank0=False
)
if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--prompt_path', type=str, default=None, help='path to the prompt dataset')
parser.add_argument('--pretrain_dataset', type=str, default=None, help='path to the pretrained dataset')
parser.add_argument('--strategy',
choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'],
default='ddp',
help='strategy to use')
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
parser.add_argument('--pretrain', type=str, default=None)
parser.add_argument('--sft_lora_path', type=str, default=None)
parser.add_argument('--rm_model', default=None, choices=['gpt2', 'bloom', 'opt', 'llama'])
parser.add_argument('--rm_path', type=str, default=None)
parser.add_argument('--rm_pretrain', type=str, default=None)
parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts')
parser.add_argument('--need_optim_ckpt', type=bool, default=False)
parser.add_argument('--num_episodes', type=int, default=10)
parser.add_argument('--num_collect_steps', type=int, default=10)
parser.add_argument('--num_update_steps', type=int, default=5)
parser.add_argument('--train_batch_size', type=int, default=2)
parser.add_argument('--ptx_batch_size', type=int, default=1)
parser.add_argument('--experience_batch_size', type=int, default=8)
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument('--kl_coef', type=float, default=0.1)
parser.add_argument('--ptx_coef', type=float, default=0.9)
parser.add_argument("--prompt_path", type=str, default=None, help="path to the prompt dataset")
parser.add_argument("--pretrain_dataset", type=str, default=None, help="path to the pretrained dataset")
parser.add_argument(
"--strategy", choices=["ddp", "colossalai_gemini", "colossalai_zero2"], default="ddp", help="strategy to use"
)
parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument("--sft_lora_path", type=str, default=None)
parser.add_argument("--rm_model", default=None, choices=["gpt2", "bloom", "opt", "llama"])
parser.add_argument("--rm_path", type=str, default=None)
parser.add_argument("--rm_pretrain", type=str, default=None)
parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts")
parser.add_argument("--need_optim_ckpt", type=bool, default=False)
parser.add_argument("--num_episodes", type=int, default=10)
parser.add_argument("--num_collect_steps", type=int, default=10)
parser.add_argument("--num_update_steps", type=int, default=5)
parser.add_argument("--train_batch_size", type=int, default=2)
parser.add_argument("--ptx_batch_size", type=int, default=1)
parser.add_argument("--experience_batch_size", type=int, default=8)
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument("--kl_coef", type=float, default=0.1)
parser.add_argument("--ptx_coef", type=float, default=0.9)
args = parser.parse_args()
main(args)


@@ -1,18 +1,10 @@
import argparse
import os
import loralib as lora
import torch
import torch.distributed as dist
from coati.dataset import DataCollatorForSupervisedDataset, SFTDataset, SupervisedDataset
from coati.models.base import RewardModel
from coati.models.bloom import BLOOMLM
from coati.models.gpt import GPTLM
from coati.models.llama import LlamaLM
from coati.models.opt import OPTLM
from coati.trainer import SFTTrainer
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
from datasets import load_dataset
from easy_dataset import EasyDataset
from peft import LoraConfig, PeftModel, TaskType, get_peft_model
from torch.optim import Adam
@@ -29,75 +21,76 @@ from colossalai.tensor import ColoParameter
def train(args):
# configure strategy
if args.strategy == 'ddp':
if args.strategy == "ddp":
strategy = DDPStrategy()
elif args.strategy == 'colossalai_gemini':
strategy = GeminiStrategy(placement_policy='cuda')
elif args.strategy == 'colossalai_zero2':
strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
elif args.strategy == "colossalai_gemini":
strategy = GeminiStrategy(placement_policy="cuda")
elif args.strategy == "colossalai_zero2":
strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
else:
raise ValueError(f'Unsupported strategy "{args.strategy}"')
# configure model
with strategy.model_init_context():
print('Warning: currently only bloom is tested, gpt2,llama and opt are not tested')
print("Warning: currently only bloom is tested, gpt2,llama and opt are not tested")
model = AutoModelForCausalLM.from_pretrained(args.pretrain).to(torch.cuda.current_device())
# if the args.save_path exists and args.save_path+'/adapter_config.json' exists, we'll load the adapter_config.json
if os.path.exists(args.save_path) and os.path.exists(args.save_path + '/adapter_config.json') \
and os.path.exists(args.save_path + '/adapter_model.bin'):
if (
os.path.exists(args.save_path)
and os.path.exists(args.save_path + "/adapter_config.json")
and os.path.exists(args.save_path + "/adapter_model.bin")
):
print("loading from saved peft model ", args.save_path)
model = PeftModel.from_pretrained(model, args.save_path)
else:
# we'll use peft lora library to do the lora
lora_rank = args.lora_rank if args.lora_rank > 0 else 32
# config lora with rank of lora_rank
lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM,
inference_mode=False,
r=lora_rank,
lora_alpha=32,
lora_dropout=0.1)
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM, inference_mode=False, r=lora_rank, lora_alpha=32, lora_dropout=0.1
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# configure tokenizer
if args.model == 'gpt2':
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
if args.model == "gpt2":
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'bloom':
elif args.model == "bloom":
tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m")
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'opt':
elif args.model == "opt":
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'llama':
elif args.model == "llama":
tokenizer = AutoTokenizer.from_pretrained(
args.pretrain,
padding_side="right",
use_fast=False,
)
tokenizer.eos_token = '<\s>'
tokenizer.eos_token = "<\s>"
tokenizer.pad_token = tokenizer.unk_token
else:
raise ValueError(f'Unsupported model "{args.model}"')
if args.model == 'llama' and args.strategy == 'colossalai_gemini':
if args.model == "llama" and args.strategy == "colossalai_gemini":
# this is a hack to deal with the resized embedding
# to make sure all parameters are ColoParameter for Colossal-AI Gemini Compatibility
for name, param in model.named_parameters():
if not isinstance(param, ColoParameter):
sub_module_name = '.'.join(name.split('.')[:-1])
weight_name = name.split('.')[-1]
sub_module_name = ".".join(name.split(".")[:-1])
weight_name = name.split(".")[-1]
sub_module = model.get_submodule(sub_module_name)
setattr(sub_module, weight_name, ColoParameter(param))
# configure optimizer
if args.strategy.startswith('colossalai'):
if args.strategy.startswith("colossalai"):
optim = HybridAdam(model.parameters(), lr=args.lr, clipping_norm=1.0)
else:
optim = Adam(model.parameters(), lr=args.lr)
logger = get_dist_logger()
logger.set_level('WARNING')
logger.set_level("WARNING")
# configure dataset
law_dataset = EasyDataset(args.dataset, tokenizer=tokenizer, is_group_texts=not args.is_short_text)
@@ -108,47 +101,57 @@ def train(args):
eval_dataset = EasyDataset(args.eval_dataset, tokenizer=tokenizer, is_group_texts=not args.is_short_text)
data_collator = default_collate
if dist.is_initialized() and dist.get_world_size() > 1:
train_sampler = DistributedSampler(train_dataset,
shuffle=True,
seed=42,
drop_last=True,
rank=dist.get_rank(),
num_replicas=dist.get_world_size())
train_sampler = DistributedSampler(
train_dataset,
shuffle=True,
seed=42,
drop_last=True,
rank=dist.get_rank(),
num_replicas=dist.get_world_size(),
)
if eval_dataset is not None:
eval_sampler = DistributedSampler(eval_dataset,
shuffle=False,
seed=42,
drop_last=False,
rank=dist.get_rank(),
num_replicas=dist.get_world_size())
eval_sampler = DistributedSampler(
eval_dataset,
shuffle=False,
seed=42,
drop_last=False,
rank=dist.get_rank(),
num_replicas=dist.get_world_size(),
)
else:
train_sampler = None
eval_sampler = None
train_dataloader = DataLoader(train_dataset,
shuffle=(train_sampler is None),
sampler=train_sampler,
batch_size=args.batch_size,
collate_fn=data_collator,
pin_memory=True)
train_dataloader = DataLoader(
train_dataset,
shuffle=(train_sampler is None),
sampler=train_sampler,
batch_size=args.batch_size,
collate_fn=data_collator,
pin_memory=True,
)
if eval_dataset is not None:
eval_dataloader = DataLoader(eval_dataset,
shuffle=(eval_sampler is None),
sampler=eval_sampler,
batch_size=args.batch_size,
collate_fn=data_collator,
pin_memory=True)
eval_dataloader = DataLoader(
eval_dataset,
shuffle=(eval_sampler is None),
sampler=eval_sampler,
batch_size=args.batch_size,
collate_fn=data_collator,
pin_memory=True,
)
else:
eval_dataloader = None
trainer = SFTTrainer(model=model,
strategy=strategy,
optim=optim,
train_dataloader=train_dataloader,
eval_dataloader=eval_dataloader,
batch_size=args.batch_size,
max_epochs=args.max_epochs,
accumulation_steps=args.accumulation_steps)
trainer = SFTTrainer(
model=model,
strategy=strategy,
optim=optim,
train_dataloader=train_dataloader,
eval_dataloader=eval_dataloader,
batch_size=args.batch_size,
max_epochs=args.max_epochs,
accumulation_steps=args.accumulation_steps,
)
trainer.fit(logger=logger, log_interval=args.log_interval)
@@ -156,29 +159,27 @@ def train(args):
trainer.save_model(path=args.save_path, only_rank0=True, tokenizer=tokenizer)
# save optimizer checkpoint on all ranks
if args.need_optim_ckpt:
strategy.save_optimizer(trainer.optimizer,
'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()),
only_rank0=False)
strategy.save_optimizer(
trainer.optimizer, "rm_optim_checkpoint_%d.pt" % (torch.cuda.current_device()), only_rank0=False
)
if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--strategy',
choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'],
default='ddp')
parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom')
parser.add_argument('--pretrain', type=str, default=None)
parser.add_argument('--dataset', type=str, default=None)
parser.add_argument('--eval_dataset', type=str, default=None)
parser.add_argument('--save_path', type=str, default='output')
parser.add_argument('--need_optim_ckpt', type=bool, default=False)
parser.add_argument('--max_epochs', type=int, default=3)
parser.add_argument('--batch_size', type=int, default=4)
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument('--log_interval', type=int, default=100, help="how many steps to log")
parser.add_argument('--lr', type=float, default=5e-6)
parser.add_argument('--accumulation_steps', type=int, default=8)
parser.add_argument('--enable_peft_lora', action='store_true', default=False)
parser.add_argument("--is_short_text", action='store_true', default=False)
parser.add_argument("--strategy", choices=["ddp", "colossalai_gemini", "colossalai_zero2"], default="ddp")
parser.add_argument("--model", choices=["gpt2", "bloom", "opt", "llama"], default="bloom")
parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument("--dataset", type=str, default=None)
parser.add_argument("--eval_dataset", type=str, default=None)
parser.add_argument("--save_path", type=str, default="output")
parser.add_argument("--need_optim_ckpt", type=bool, default=False)
parser.add_argument("--max_epochs", type=int, default=3)
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument("--log_interval", type=int, default=100, help="how many steps to log")
parser.add_argument("--lr", type=float, default=5e-6)
parser.add_argument("--accumulation_steps", type=int, default=8)
parser.add_argument("--enable_peft_lora", action="store_true", default=False)
parser.add_argument("--is_short_text", action="store_true", default=False)
args = parser.parse_args()
train(args)


@@ -6,16 +6,25 @@ from ray.job_submission import JobSubmissionClient
def main(api_server_endpoint="http://127.0.0.1:8265"):
client = JobSubmissionClient(api_server_endpoint)
client.submit_job(
entrypoint=
"python experimental/ray/train_prompts_on_ray.py --strategy colossalai_zero2 --prompt_csv_url https://huggingface.co/datasets/fka/awesome-chatgpt-prompts/resolve/main/prompts.csv",
entrypoint="python experimental/ray/train_prompts_on_ray.py --strategy colossalai_zero2 --prompt_csv_url https://huggingface.co/datasets/fka/awesome-chatgpt-prompts/resolve/main/prompts.csv",
runtime_env={
"working_dir":
"applications/Chat",
"working_dir": "applications/Chat",
"pip": [
"torch==1.13.1", "transformers>=4.20.1", "datasets", "loralib", "colossalai>=0.2.4", "langchain",
"tokenizers", "fastapi", "sse_starlette", "wandb", "sentencepiece", "gpustat"
]
})
"torch==1.13.1",
"transformers>=4.20.1",
"datasets",
"loralib",
"colossalai>=0.2.4",
"langchain",
"tokenizers",
"fastapi",
"sse_starlette",
"wandb",
"sentencepiece",
"gpustat",
],
},
)
if __name__ == "__main__":


@@ -26,9 +26,14 @@ from colossalai.nn.optimizer import HybridAdam
class ExperienceCompositionRefs:
def __init__(self, sequences_attention_mask_action_mask_ref: ray.ObjectRef, action_log_probs_ref: ray.ObjectRef,
base_action_log_probs_ref: ray.ObjectRef, value_ref: ray.ObjectRef, r_ref: ray.ObjectRef) -> None:
def __init__(
self,
sequences_attention_mask_action_mask_ref: ray.ObjectRef,
action_log_probs_ref: ray.ObjectRef,
base_action_log_probs_ref: ray.ObjectRef,
value_ref: ray.ObjectRef,
r_ref: ray.ObjectRef,
) -> None:
self.sequences_attention_mask_action_mask_ref = sequences_attention_mask_action_mask_ref
self.action_log_probs_ref = action_log_probs_ref
self.base_action_log_probs_ref = base_action_log_probs_ref
@@ -37,14 +42,14 @@ class ExperienceCompositionRefs:
class ExperienceMaker:
def __init__(self, kl_coef) -> None:
self.kl_coef = kl_coef
@torch.no_grad()
def make_experience(self, experiment_computation_refs: ExperienceCompositionRefs):
sequences, attention_mask, action_mask = ray.get(
experiment_computation_refs.sequences_attention_mask_action_mask_ref)
experiment_computation_refs.sequences_attention_mask_action_mask_ref
)
action_log_probs = ray.get(experiment_computation_refs.action_log_probs_ref)
base_action_log_probs = ray.get(experiment_computation_refs.base_action_log_probs_ref)
r = ray.get(experiment_computation_refs.r_ref)
@@ -58,11 +63,10 @@ class ExperienceMaker:
class DistributedTorchRayActor:
def __init__(self, world_size, rank, local_rank, master_addr, master_port):
logging.basicConfig(format='%(asctime)s %(levelname)-8s %(message)s',
level=logging.INFO,
datefmt='%Y-%m-%d %H:%M:%S')
logging.basicConfig(
format="%(asctime)s %(levelname)-8s %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S"
)
self._model = None
self._world_size = world_size
self._rank = rank
@@ -82,7 +86,7 @@ class DistributedTorchRayActor:
@staticmethod
def _get_free_port():
with socket.socket() as sock:
sock.bind(('', 0))
sock.bind(("", 0))
return sock.getsockname()[1]
def get_master_addr_port(self):
@@ -90,7 +94,6 @@ class DistributedTorchRayActor:
class BasePPORole(DistributedTorchRayActor):
def add_experience_maker(self, kl_coef: float = 0.1):
self._experience_maker = ExperienceMaker(kl_coef)
@@ -99,12 +102,12 @@ class BasePPORole(DistributedTorchRayActor):
def _init_strategy(self, strategy: str):
# configure strategy
if strategy == 'ddp':
if strategy == "ddp":
self._strategy = DDPStrategy()
elif strategy == 'colossalai_gemini':
self._strategy = GeminiStrategy(placement_policy='cuda', initial_scale=2**5)
elif strategy == 'colossalai_zero2':
self._strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
elif strategy == "colossalai_gemini":
self._strategy = GeminiStrategy(placement_policy="cuda", initial_scale=2**5)
elif strategy == "colossalai_zero2":
self._strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
else:
raise ValueError(f'Unsupported strategy "{strategy}"')
@@ -124,11 +127,9 @@ class BasePPORole(DistributedTorchRayActor):
def _load_model_from_pretrained(self, model_class: Type[LoRAModule], pretrain: str):
raise NotImplementedError()
def init_model_from_pretrained(self,
strategy: str,
model_class: Type[LoRAModule],
pretrain: str,
has_optimizer=False):
def init_model_from_pretrained(
self, strategy: str, model_class: Type[LoRAModule], pretrain: str, has_optimizer=False
):
self._init_strategy(strategy)
self._load_model_from_pretrained(model_class, pretrain)
self._prepare_model_with_strategy(has_optimizer)
@@ -138,7 +139,6 @@ class BasePPORole(DistributedTorchRayActor):
class TrainablePPORole(BasePPORole):
def _load_model_from_pretrained(self, model_class, pretrain):
with self._strategy.model_init_context():
self._model = model_class(pretrain).to(torch.cuda.current_device())
@@ -161,38 +161,39 @@ class TrainablePPORole(BasePPORole):
@ray.remote(num_gpus=1)
class RayPPOActor(TrainablePPORole):
def set_loss_function(self, eps_clip: float):
self._actor_loss_fn = PolicyLoss(eps_clip)
def load_tokenizer_from_pretrained(self, model_type: str, pretrained):
if model_type == 'gpt2':
if model_type == "gpt2":
self._model_tokenizer = GPT2Tokenizer.from_pretrained(pretrained)
self._model_tokenizer.pad_token = self._model_tokenizer.eos_token
elif model_type == 'bloom':
elif model_type == "bloom":
self._model_tokenizer = BloomTokenizerFast.from_pretrained(pretrained)
self._model_tokenizer.pad_token = self._model_tokenizer.eos_token
elif model_type == 'opt':
elif model_type == "opt":
self._model_tokenizer = AutoTokenizer.from_pretrained(pretrained)
else:
raise ValueError(f'Unsupported model "{model_type}"')
# Set tokenize function for sequence generation
def _text_input_tokenize_fn(texts):
batch = self._model_tokenizer(texts, return_tensors='pt', max_length=96, padding=True, truncation=True)
batch = self._model_tokenizer(texts, return_tensors="pt", max_length=96, padding=True, truncation=True)
return {k: v.cuda() for k, v in batch.items()}
self._sample_tokenize_function = _text_input_tokenize_fn
def setup_generate_kwargs(self, generate_kwargs: dict):
from coati.trainer.ppo import _set_default_generate_kwargs
self._generate_kwargs = _set_default_generate_kwargs(self._strategy, generate_kwargs, self._model)
self._generate_kwargs['pad_token_id'] = self._model_tokenizer.pad_token_id
self._generate_kwargs['eos_token_id'] = self._model_tokenizer.eos_token_id
self._generate_kwargs["pad_token_id"] = self._model_tokenizer.pad_token_id
self._generate_kwargs["eos_token_id"] = self._model_tokenizer.eos_token_id
def load_csv_prompt_file_from_url_to_sampler(self, prompt_url):
import pandas as pd
prompts = pd.read_csv(prompt_url)['prompt']
prompts = pd.read_csv(prompt_url)["prompt"]
self._sampler = self._strategy.setup_sampler(prompts)
def _generate(self, input_ids, **generate_kwargs):
@@ -214,10 +215,9 @@ class RayPPOActor(TrainablePPORole):
def _training_step(self, experience):
num_actions = experience.action_mask.size(1)
action_log_probs = self._model(experience.sequences, num_actions, attention_mask=experience.attention_mask)
actor_loss = self._actor_loss_fn(action_log_probs,
experience.action_log_probs,
experience.advantages,
action_mask=experience.action_mask)
actor_loss = self._actor_loss_fn(
action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask
)
self._strategy.backward(actor_loss, self._model, self._optimizer)
self._strategy.optimizer_step(self._optimizer)
self._optimizer.zero_grad()
@@ -229,17 +229,18 @@ class RayPPOActor(TrainablePPORole):
self._strategy.save_model(self._model, save_path, only_rank0=True)
# save optimizer checkpoint on all ranks
if should_save_optimizer:
self._strategy.save_optimizer(self._optimizer,
'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
only_rank0=False)
self._strategy.save_optimizer(
self._optimizer,
"actor_optim_checkpoint_prompts_%d.pt" % (torch.cuda.current_device()),
only_rank0=False,
)
def generate_answer(self, prompt, max_length=30, num_return_sequences=5):
encoded_input = self._model_tokenizer(prompt, return_tensors='pt')
encoded_input = self._model_tokenizer(prompt, return_tensors="pt")
input_ids = {k: v.cuda() for k, v in encoded_input.items()}
sequence, _ = self._model.generate(**input_ids,
max_length=max_length,
return_action_mask=False,
num_return_sequences=num_return_sequences)
sequence, _ = self._model.generate(
**input_ids, max_length=max_length, return_action_mask=False, num_return_sequences=num_return_sequences
)
token_list = list(sequence.data[0])
output = " ".join([self._model_tokenizer.decode(token) for token in token_list])
return output
@@ -247,18 +248,16 @@ class RayPPOActor(TrainablePPORole):
@ray.remote(num_gpus=1)
class RayPPOCritic(TrainablePPORole):
def set_loss_function(self, value_clip: float):
self._critic_loss_fn = ValueLoss(value_clip)
def _training_step(self, experience):
values = self._model(experience.sequences,
action_mask=experience.action_mask,
attention_mask=experience.attention_mask)
critic_loss = self._critic_loss_fn(values,
experience.values,
experience.reward,
action_mask=experience.action_mask)
values = self._model(
experience.sequences, action_mask=experience.action_mask, attention_mask=experience.attention_mask
)
critic_loss = self._critic_loss_fn(
values, experience.values, experience.reward, action_mask=experience.action_mask
)
self._strategy.backward(critic_loss, self._model, self._optimizer)
self._strategy.optimizer_step(self._optimizer)
self._optimizer.zero_grad()
@@ -272,12 +271,12 @@ class RayPPOCritic(TrainablePPORole):
@ray.remote(num_gpus=1)
class RayPPORewardModel(BasePPORole):
def _load_model_from_pretrained(self, model_class, pretrain):
with self._strategy.model_init_context():
critic = model_class(pretrained=pretrain).to(torch.cuda.current_device())
self._model = RewardModel(deepcopy(critic.model),
deepcopy(critic.value_head)).to(torch.cuda.current_device())
self._model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).to(
torch.cuda.current_device()
)
@torch.no_grad()
def calculate_r(self, sequence_attention_action_mask):
@@ -287,7 +286,6 @@ class RayPPORewardModel(BasePPORole):
@ray.remote(num_gpus=1)
class RayPPOInitialModel(BasePPORole):
def _load_model_from_pretrained(self, model_class, pretrain):
with self._strategy.model_init_context():
self._model = model_class(pretrain).to(torch.cuda.current_device())
@@ -300,8 +298,8 @@ class RayPPOInitialModel(BasePPORole):
class PPORayActorGroup:
"""
A group of ray actors
Functions start with 'async' should return list of object refs
A group of ray actors
Functions start with 'async' should return list of object refs
"""
def __init__(self, num_nodes, num_gpus_per_node, ray_actor_type: Type[BasePPORole]) -> None:
@@ -319,8 +317,9 @@ class PPORayActorGroup:
pg = placement_group(bundles, strategy="STRICT_SPREAD")
ray.get(pg.ready())
if pg:
master_actor = self.ray_actor_type.options(scheduling_strategy=PlacementGroupSchedulingStrategy(
placement_group=pg, placement_group_bundle_index=0)).remote(world_size, 0, 0, None, None)
master_actor = self.ray_actor_type.options(
scheduling_strategy=PlacementGroupSchedulingStrategy(placement_group=pg, placement_group_bundle_index=0)
).remote(world_size, 0, 0, None, None)
else:
master_actor = self.ray_actor_type.options(num_gpus=1).remote(world_size, 0, 0, None, None)
self._actor_handlers = [master_actor]
@@ -331,16 +330,20 @@ class PPORayActorGroup:
for rank in range(1, world_size):
local_rank = rank % self._num_gpus_per_node
if pg:
worker_actor = self.ray_actor_type.options(scheduling_strategy=PlacementGroupSchedulingStrategy(
placement_group=pg, placement_group_bundle_index=rank // self._num_gpus_per_node)).remote(
world_size, rank, local_rank, master_addr, master_port)
worker_actor = self.ray_actor_type.options(
scheduling_strategy=PlacementGroupSchedulingStrategy(
placement_group=pg, placement_group_bundle_index=rank // self._num_gpus_per_node
)
).remote(world_size, rank, local_rank, master_addr, master_port)
else:
worker_actor = self.ray_actor_type.options(num_gpus=1).remote(world_size, rank, local_rank,
master_addr, master_port)
worker_actor = self.ray_actor_type.options(num_gpus=1).remote(
world_size, rank, local_rank, master_addr, master_port
)
self._actor_handlers.append(worker_actor)
def async_init_model_from_pretrained(self, strategy: str, model_class: Type[LoRAModule], pretrain: str,
has_optimizer: bool):
def async_init_model_from_pretrained(
self, strategy: str, model_class: Type[LoRAModule], pretrain: str, has_optimizer: bool
):
return [
actor.init_model_from_pretrained.remote(strategy, model_class, pretrain, has_optimizer)
for actor in self._actor_handlers
@@ -348,7 +351,6 @@ class PPORayActorGroup:
class TrainableModelRayActorGroup(PPORayActorGroup):
def async_learn_on_experiences(self, experience_refs):
num_actors = len(self._actor_handlers)
learn_result_refs = []
@@ -359,7 +361,6 @@ class TrainableModelRayActorGroup(PPORayActorGroup):
class PPOActorRayActorGroup(TrainableModelRayActorGroup):
def __init__(self, num_nodes, num_gpus_per_node) -> None:
super().__init__(num_nodes, num_gpus_per_node, RayPPOActor)
@@ -381,7 +382,8 @@ class PPOActorRayActorGroup(TrainableModelRayActorGroup):
action_log_probs_refs = []
for i in range(len(sequences_attention_mask_action_mask_refs)):
action_log_probs_ref = self._actor_handlers[i % num_actors].calculate_action_log_probs.remote(
sequences_attention_mask_action_mask_refs[i])
sequences_attention_mask_action_mask_refs[i]
)
action_log_probs_refs.append(action_log_probs_ref)
return action_log_probs_refs
@@ -393,7 +395,6 @@ class PPOActorRayActorGroup(TrainableModelRayActorGroup):
class PPOCriticRayActorGroup(TrainableModelRayActorGroup):
def __init__(self, num_nodes, num_gpus_per_node) -> None:
super().__init__(num_nodes, num_gpus_per_node, RayPPOCritic)
@@ -402,7 +403,8 @@ class PPOCriticRayActorGroup(TrainableModelRayActorGroup):
value_refs = []
for i in range(len(sequences_attention_mask_action_mask_refs)):
value_ref = self._actor_handlers[i % num_actors].calculate_value.remote(
sequences_attention_mask_action_mask_refs[i])
sequences_attention_mask_action_mask_refs[i]
)
value_refs.append(value_ref)
return value_refs
@@ -411,7 +413,6 @@ class PPOCriticRayActorGroup(TrainableModelRayActorGroup):
class PPOInitialRayActorGroup(PPORayActorGroup):
def __init__(self, num_nodes, num_gpus_per_node) -> None:
super().__init__(num_nodes, num_gpus_per_node, RayPPOInitialModel)
@@ -420,13 +421,13 @@ class PPOInitialRayActorGroup(PPORayActorGroup):
base_action_log_probs_refs = []
for i in range(len(sequences_attention_mask_action_mask_refs)):
base_action_log_probs_ref = self._actor_handlers[i % num_actors].calculate_base_action_log_probs.remote(
sequences_attention_mask_action_mask_refs[i])
sequences_attention_mask_action_mask_refs[i]
)
base_action_log_probs_refs.append(base_action_log_probs_ref)
return base_action_log_probs_refs
class PPORewardRayActorGroup(PPORayActorGroup):
def __init__(self, num_nodes, num_gpus_per_node) -> None:
super().__init__(num_nodes, num_gpus_per_node, RayPPORewardModel)
@@ -435,20 +436,21 @@ class PPORewardRayActorGroup(PPORayActorGroup):
r_refs = []
for i in range(len(sequences_attention_mask_action_mask_refs)):
r_ref = self._actor_handlers[i % num_actors].calculate_r.remote(
sequences_attention_mask_action_mask_refs[i])
sequences_attention_mask_action_mask_refs[i]
)
r_refs.append(r_ref)
return r_refs
def main(args):
logging.basicConfig(format='%(asctime)s %(levelname)-8s %(message)s',
level=logging.INFO,
datefmt='%Y-%m-%d %H:%M:%S')
if args.model == 'gpt2':
logging.basicConfig(
format="%(asctime)s %(levelname)-8s %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S"
)
if args.model == "gpt2":
actor_model_class, critic_model_class = GPTActor, GPTCritic
elif args.model == 'bloom':
elif args.model == "bloom":
actor_model_class, critic_model_class = BLOOMActor, BLOOMCritic
elif args.model == 'opt':
elif args.model == "opt":
actor_model_class, critic_model_class = OPTActor, OPTCritic
else:
raise ValueError(f'Unsupported model "{args.model}"')
@@ -462,13 +464,14 @@ def main(args):
logging.info("Actors created")
# Prepare model for training
generate_kwargs = {'max_length': 128, 'do_sample': True, 'temperature': 1.0, 'top_k': 50}
generate_kwargs = {"max_length": 128, "do_sample": True, "temperature": 1.0, "top_k": 50}
ray.get(
actor_group.async_init_model_from_pretrained(args.strategy, actor_model_class, args.pretrain, True) +
critic_group.async_init_model_from_pretrained(args.strategy, critic_model_class, args.pretrain, True) +
initial_group.async_init_model_from_pretrained(args.strategy, actor_model_class, args.pretrain, False) +
reward_group.async_init_model_from_pretrained(args.strategy, critic_model_class, args.pretrain, False) +
actor_group.async_prepare_for_sequence_generation(args.model, args.pretrain, generate_kwargs))
actor_group.async_init_model_from_pretrained(args.strategy, actor_model_class, args.pretrain, True)
+ critic_group.async_init_model_from_pretrained(args.strategy, critic_model_class, args.pretrain, True)
+ initial_group.async_init_model_from_pretrained(args.strategy, actor_model_class, args.pretrain, False)
+ reward_group.async_init_model_from_pretrained(args.strategy, critic_model_class, args.pretrain, False)
+ actor_group.async_prepare_for_sequence_generation(args.model, args.pretrain, generate_kwargs)
)
logging.info("Models prepared for training")
# Prepare models for training
@@ -483,8 +486,12 @@ def main(args):
# Start training
logging.info("Training start")
# Set all models to eval and add experience maker
all_ray_actors = actor_group._actor_handlers + critic_group._actor_handlers + \
initial_group._actor_handlers + reward_group._actor_handlers
all_ray_actors = (
actor_group._actor_handlers
+ critic_group._actor_handlers
+ initial_group._actor_handlers
+ reward_group._actor_handlers
)
num_ray_actors = len(all_ray_actors)
ray.get([ray_actor.eval.remote() for ray_actor in all_ray_actors])
ray.get([ray_actor.add_experience_maker.remote() for ray_actor in all_ray_actors])
@@ -497,18 +504,28 @@ def main(args):
time += 1
# Experience queueing stage
sequences_attention_mask_action_mask_refs = actor_group.async_sample_prompts_and_make_sequence(
experience_batch_size)
experience_batch_size
)
base_action_log_probs_refs = initial_group.async_calculate_base_action_log_probs(
sequences_attention_mask_action_mask_refs)
sequences_attention_mask_action_mask_refs
)
values_refs = critic_group.async_calculate_value(sequences_attention_mask_action_mask_refs)
r_refs = reward_group.async_calculate_r(sequences_attention_mask_action_mask_refs)
action_log_probs_refs = actor_group.async_calculate_action_log_probs(
sequences_attention_mask_action_mask_refs)
experience_composition_refs.extend([
ExperienceCompositionRefs(sequences_attention_mask_action_mask_refs[i], action_log_probs_refs[i],
base_action_log_probs_refs[i], values_refs[i], r_refs[i])
for i in range(len(sequences_attention_mask_action_mask_refs))
])
sequences_attention_mask_action_mask_refs
)
experience_composition_refs.extend(
[
ExperienceCompositionRefs(
sequences_attention_mask_action_mask_refs[i],
action_log_probs_refs[i],
base_action_log_probs_refs[i],
values_refs[i],
r_refs[i],
)
for i in range(len(sequences_attention_mask_action_mask_refs))
]
)
# Learning stage
if time % update_timesteps == 0:
experience_refs = []
@@ -519,8 +536,9 @@ def main(args):
experience_refs.append(selected_ray_actor.make_experience.remote(exp_composition_ref))
# backward
ray.get(
actor_group.async_learn_on_experiences(experience_refs) +
critic_group.async_learn_on_experiences(experience_refs))
actor_group.async_learn_on_experiences(experience_refs)
+ critic_group.async_learn_on_experiences(experience_refs)
)
# clear refs queue
experience_composition_refs.clear()
logging.info("Training finished")
@@ -528,26 +546,24 @@ def main(args):
actor_group.save_checkpoint(args.save_path, args.need_optim_ckpt)
if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--prompt_csv_url', type=str)
parser.add_argument('--strategy',
choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'],
default='ddp')
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt'])
parser.add_argument('--pretrain', type=str, default='gpt2')
parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts.pt')
parser.add_argument('--need_optim_ckpt', type=bool, default=False)
parser.add_argument('--num_episodes', type=int, default=10)
parser.add_argument('--max_timesteps', type=int, default=10)
parser.add_argument('--update_timesteps', type=int, default=10)
parser.add_argument('--train_batch_size', type=int, default=8)
parser.add_argument('--experience_batch_size', type=int, default=8)
parser.add_argument('--num_actor_nodes', type=int, help='num of nodes to use to host actor model', default=1)
parser.add_argument('--num_critic_nodes', type=int, help='num of nodes to use to host critic model', default=1)
parser.add_argument('--num_initial_nodes', type=int, help='num of nodes to use to host initial model', default=1)
parser.add_argument('--num_reward_nodes', type=int, help='num of nodes to use to host reward model', default=1)
parser.add_argument('--num_gpus_per_node', type=int, help='num of gpus on a ray node', default=1)
parser.add_argument("--prompt_csv_url", type=str)
parser.add_argument("--strategy", choices=["ddp", "colossalai_gemini", "colossalai_zero2"], default="ddp")
parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt"])
parser.add_argument("--pretrain", type=str, default="gpt2")
parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts.pt")
parser.add_argument("--need_optim_ckpt", type=bool, default=False)
parser.add_argument("--num_episodes", type=int, default=10)
parser.add_argument("--max_timesteps", type=int, default=10)
parser.add_argument("--update_timesteps", type=int, default=10)
parser.add_argument("--train_batch_size", type=int, default=8)
parser.add_argument("--experience_batch_size", type=int, default=8)
parser.add_argument("--num_actor_nodes", type=int, help="num of nodes to use to host actor model", default=1)
parser.add_argument("--num_critic_nodes", type=int, help="num of nodes to use to host critic model", default=1)
parser.add_argument("--num_initial_nodes", type=int, help="num of nodes to use to host initial model", default=1)
parser.add_argument("--num_reward_nodes", type=int, help="num of nodes to use to host reward model", default=1)
parser.add_argument("--num_gpus_per_node", type=int, help="num of gpus on a ray node", default=1)
args = parser.parse_args()
ray.init()
main(args)
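
The PPO*RayActorGroup classes above dispatch work by sending item i to handler i % num_actors and collecting the returned object refs for a later ray.get. Below is a minimal, self-contained sketch of that round-robin pattern with a hypothetical worker actor; it illustrates the dispatch style only and is not code from this script.

import ray


@ray.remote
class SquareWorker:
    """Hypothetical stand-in for a trainable PPO role actor."""

    def compute(self, x: int) -> int:
        return x * x


if __name__ == "__main__":
    ray.init()
    handlers = [SquareWorker.remote() for _ in range(2)]
    items = list(range(5))
    # Round-robin: item i goes to handler i % len(handlers); the refs resolve later.
    refs = [handlers[i % len(handlers)].compute.remote(x) for i, x in enumerate(items)]
    print(ray.get(refs))  # [0, 1, 4, 9, 16]
    ray.shutdown()
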


@@ -22,7 +22,7 @@ class HFRepoFiles:
file_path = hf_hub_download(self.repo_id, file, local_dir=dir_path)
def download_all(self):
file_path = snapshot_download(self.repo_id)
snapshot_download(self.repo_id)
def test_init(model: str, dir_path: str):
@@ -31,19 +31,19 @@ def test_init(model: str, dir_path: str):
actor = GPTActor(config=config)
critic = GPTCritic(config=config)
reward_model = GPTRM(config=config)
tokenizer = GPT2Tokenizer.from_pretrained(dir_path)
GPT2Tokenizer.from_pretrained(dir_path)
elif model == "bloom":
config = BloomConfig.from_pretrained(dir_path)
actor = BLOOMActor(config=config)
critic = BLOOMCritic(config=config)
reward_model = BLOOMRM(config=config)
tokenizer = BloomTokenizerFast.from_pretrained(dir_path)
BloomTokenizerFast.from_pretrained(dir_path)
elif model == "opt":
config = AutoConfig.from_pretrained(dir_path)
actor = OPTActor(config=config)
critic = OPTCritic(config=config)
reward_model = OPTRM(config=config)
tokenizer = AutoTokenizer.from_pretrained(dir_path)
AutoTokenizer.from_pretrained(dir_path)
else:
raise NotImplementedError(f"Model {model} not implemented")
@@ -59,17 +59,12 @@ if __name__ == "__main__":
exit(0)
repo_list = {
"gpt2": HFRepoFiles(
repo_id="gpt2",
files=["config.json", "tokenizer.json", "vocab.json", "merges.txt"]
),
"gpt2": HFRepoFiles(repo_id="gpt2", files=["config.json", "tokenizer.json", "vocab.json", "merges.txt"]),
"bloom": HFRepoFiles(
repo_id="bigscience/bloom-560m",
files=["config.json", "tokenizer.json", "tokenizer_config.json"]
repo_id="bigscience/bloom-560m", files=["config.json", "tokenizer.json", "tokenizer_config.json"]
),
"opt": HFRepoFiles(
repo_id="facebook/opt-350m",
files=["config.json", "tokenizer_config.json", "vocab.json", "merges.txt"]
repo_id="facebook/opt-350m", files=["config.json", "tokenizer_config.json", "vocab.json", "merges.txt"]
),
}
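For context, the HFRepoFiles entries collapsed above wrap plain huggingface_hub calls (hf_hub_download for the listed files, snapshot_download for a full mirror); a minimal standalone sketch using only the repo ids shown here:

    from huggingface_hub import hf_hub_download, snapshot_download

    # Fetch just the listed files of a repo into a local directory ...
    for filename in ["config.json", "tokenizer.json", "vocab.json", "merges.txt"]:
        hf_hub_download("gpt2", filename, local_dir="./gpt2")

    # ... or pull the whole repository snapshot into the local cache.
    snapshot_download("bigscience/bloom-560m")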


@@ -31,9 +31,11 @@ def generate_alpaca():
def generate_sharegpt():
# ShareGPT data requires less processing.
conversation_dataset = []
dataset = load_dataset("anon8231489123/ShareGPT_Vicuna_unfiltered",
data_files="ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json",
split="train")
dataset = load_dataset(
"anon8231489123/ShareGPT_Vicuna_unfiltered",
data_files="ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json",
split="train",
)
conversations = dataset["conversations"]
@@ -43,23 +45,24 @@ def generate_sharegpt():
del conv["markdown"]
del conv["text"]
conversation = dict(type="conversation",
language="Multilingual",
dataset="ShareGPT",
conversations=conversations[idx])
conversation = dict(
type="conversation", language="Multilingual", dataset="ShareGPT", conversations=conversations[idx]
)
conversation_dataset.append(conversation)
return conversation_dataset
if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--dataset',
type=str,
default="All",
choices=["Alpaca", "ShareGPT", "All"],
help="which dataset to convert, All will combine Alpaca and ShareGPT")
parser.add_argument('--save_path', type=str, default="dataset.json", help="path to save the converted dataset")
parser.add_argument(
"--dataset",
type=str,
default="All",
choices=["Alpaca", "ShareGPT", "All"],
help="which dataset to convert, All will combine Alpaca and ShareGPT",
)
parser.add_argument("--save_path", type=str, default="dataset.json", help="path to save the converted dataset")
args = parser.parse_args()
conversation_dataset = []
@@ -75,5 +78,5 @@ if __name__ == '__main__':
for idx, sample in enumerate(conversation_dataset):
sample["id"] = idx + 1
with open(args.save_path, mode='w') as f:
with open(args.save_path, mode="w") as f:
json.dump(conversation_dataset, f, indent=4, default=str, ensure_ascii=False)
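The resulting file written to save_path is a JSON list in which each ShareGPT entry ends up with the shape assembled above (the id field is attached in the final loop), roughly:

    sample = {
        "type": "conversation",
        "language": "Multilingual",
        "dataset": "ShareGPT",
        "conversations": [],  # the raw ShareGPT turns, with the "markdown" and "text" keys stripped
        "id": 1,
    }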


@@ -6,7 +6,7 @@ random.seed(42)
def sample(args):
with open(args.dataset_path, mode='r') as f:
with open(args.dataset_path, mode="r") as f:
dataset_list = json.load(f)
sampled_dataset = [
@@ -14,18 +14,14 @@ def sample(args):
for idx, sample in enumerate(random.sample(dataset_list, args.sample_size))
]
with open(args.save_path, mode='w') as f:
json.dump(sampled_dataset, f, indent=4,
default=str, ensure_ascii=False)
with open(args.save_path, mode="w") as f:
json.dump(sampled_dataset, f, indent=4, default=str, ensure_ascii=False)
if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--dataset_path', type=str, default=None,
required=True, help="path to the pretrain dataset")
parser.add_argument('--save_path', type=str, default='prompt.json',
help="path to save the prompt dataset")
parser.add_argument('--sample_size', type=int,
default=16384, help="size of the prompt dataset")
parser.add_argument("--dataset_path", type=str, default=None, required=True, help="path to the pretrain dataset")
parser.add_argument("--save_path", type=str, default="prompt.json", help="path to save the prompt dataset")
parser.add_argument("--sample_size", type=int, default=16384, help="size of the prompt dataset")
args = parser.parse_args()
sample(args)
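Because the module seeds Python's global random state (random.seed(42) in the hunk header above), re-running the script over the same dataset reproduces the same prompt subset; a tiny illustration of that property:

    import random

    random.seed(42)
    first = random.sample(range(1000), 5)
    random.seed(42)
    second = random.sample(range(1000), 5)
    assert first == second  # identical seed, identical sample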


@@ -11,13 +11,13 @@ from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, Llama
def eval(args):
# configure model
if args.model == 'gpt2':
if args.model == "gpt2":
actor = GPTActor(pretrained=args.pretrain)
elif args.model == 'bloom':
elif args.model == "bloom":
actor = BLOOMActor(pretrained=args.pretrain)
elif args.model == 'opt':
elif args.model == "opt":
actor = OPTActor(pretrained=args.pretrain)
elif args.model == 'llama':
elif args.model == "llama":
actor = LlamaActor(pretrained=args.pretrain)
else:
raise ValueError(f'Unsupported model "{args.model}"')
@@ -28,45 +28,38 @@ def eval(args):
actor.load_state_dict(state_dict)
# configure tokenizer
if args.model == 'gpt2':
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
if args.model == "gpt2":
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'bloom':
tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m')
elif args.model == "bloom":
tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m")
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'opt':
elif args.model == "opt":
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'llama':
elif args.model == "llama":
tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
tokenizer.eos_token = '<\s>'
tokenizer.eos_token = "<\s>"
tokenizer.pad_token = tokenizer.unk_token
else:
raise ValueError(f'Unsupported model "{args.model}"')
actor.eval()
input_ids = tokenizer.encode(args.input,
return_tensors='pt')\
.to(torch.cuda.current_device())
outputs = generate(actor,
input_ids,
max_length=args.max_length,
do_sample=True,
top_k=50,
top_p=0.95,
num_return_sequences=1)
output = tokenizer.batch_decode(outputs[0],
skip_special_tokens=True)
input_ids = tokenizer.encode(args.input, return_tensors="pt").to(torch.cuda.current_device())
outputs = generate(
actor, input_ids, max_length=args.max_length, do_sample=True, top_k=50, top_p=0.95, num_return_sequences=1
)
output = tokenizer.batch_decode(outputs[0], skip_special_tokens=True)
print(f"[Output]: {''.join(output)}")
if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
# We suggest to use the pretrained model from HuggingFace, use pretrain to configure model
parser.add_argument('--pretrain', type=str, default=None)
parser.add_argument('--model_path', type=str, default=None)
parser.add_argument('--input', type=str, default='Question: How are you ? Answer:')
parser.add_argument('--max_length', type=int, default=100)
parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument("--model_path", type=str, default=None)
parser.add_argument("--input", type=str, default="Question: How are you ? Answer:")
parser.add_argument("--max_length", type=int, default=100)
args = parser.parse_args()
eval(args)
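The reformatted generate call keeps the same sampling knobs (do_sample, top_k=50, top_p=0.95). For readers without the coati wrapper, a roughly equivalent standalone sketch with a stock transformers model (an illustration, not this repository's code path):

    import torch
    from transformers import GPT2LMHeadModel, GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token
    model = GPT2LMHeadModel.from_pretrained("gpt2").eval()

    input_ids = tokenizer.encode("Question: How are you ? Answer:", return_tensors="pt")
    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            max_length=100,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,
        )
    print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])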


@@ -5,7 +5,6 @@ from functools import partial
import pandas as pd
import ray
import torch
from coati.quant import llama_load_quant, low_resource_init
from coati.ray.detached_trainer_ppo import DetachedPPOTrainer
from coati.ray.experience_maker_holder import ExperienceMakerHolder
@@ -23,13 +22,13 @@ from transformers.modeling_utils import no_init_weights
def get_free_port():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(('', 0))
s.bind(("", 0))
return s.getsockname()[1]
def get_local_ip():
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
s.connect(('8.8.8.8', 80))
s.connect(("8.8.8.8", 80))
return s.getsockname()[0]
@@ -37,22 +36,25 @@ def main(args):
master_addr = str(get_local_ip())
# trainer_env_info
trainer_port = str(get_free_port())
env_info_trainers = [{
'local_rank': '0',
'rank': str(rank),
'world_size': str(args.num_trainers),
'master_port': trainer_port,
'master_addr': master_addr
} for rank in range(args.num_trainers)]
env_info_trainers = [
{
"local_rank": "0",
"rank": str(rank),
"world_size": str(args.num_trainers),
"master_port": trainer_port,
"master_addr": master_addr,
}
for rank in range(args.num_trainers)
]
# maker_env_info
maker_port = str(get_free_port())
env_info_maker = {
'local_rank': '0',
'rank': '0',
'world_size': '1',
'master_port': maker_port,
'master_addr': master_addr
"local_rank": "0",
"rank": "0",
"world_size": "1",
"master_port": maker_port,
"master_addr": master_addr,
}
# configure tokenizer
@@ -75,27 +77,33 @@ def main(args):
eval_performance=True,
debug=args.debug,
update_lora_weights=not (args.lora_rank == 0),
) for i, env_info_trainer in enumerate(env_info_trainers)
)
for i, env_info_trainer in enumerate(env_info_trainers)
]
def model_fn():
actor = get_actor_from_args(args.model, args.pretrain).requires_grad_(False).half().cuda()
critic = get_critic_from_args(args.model, args.critic_pretrain).requires_grad_(False).half().cuda()
reward_model = get_reward_model_from_args(args.model, args.critic_pretrain).requires_grad_(False).half().cuda()
if args.initial_model_quant_ckpt is not None and args.model == 'llama':
if args.initial_model_quant_ckpt is not None and args.model == "llama":
# quantize initial model
actor_cfg = AutoConfig.from_pretrained(args.pretrain)
with low_resource_init(), no_init_weights():
initial_model = get_actor_from_args(args.model, config=actor_cfg)
initial_model.model = llama_load_quant(initial_model.model, args.initial_model_quant_ckpt, args.quant_bits,
args.quant_group_size).cuda().requires_grad_(False)
initial_model.model = (
llama_load_quant(
initial_model.model, args.initial_model_quant_ckpt, args.quant_bits, args.quant_group_size
)
.cuda()
.requires_grad_(False)
)
else:
initial_model = get_actor_from_args(args.model, args.pretrain).requires_grad_(False).half().cuda()
return actor, critic, reward_model, initial_model
# configure Experience Maker
experience_holder_ref = ExperienceMakerHolder.options(name="maker1", num_gpus=1, max_concurrency=2).remote(
detached_trainer_name_list=[f'trainer{i}' for i in range(args.num_trainers)],
detached_trainer_name_list=[f"trainer{i}" for i in range(args.num_trainers)],
strategy_fn=partial(get_strategy_from_args, args.maker_strategy),
model_fn=model_fn,
env_info=env_info_maker,
@@ -130,12 +138,11 @@ def main(args):
dataset_size = args.experience_batch_size * 4
def build_dataloader():
def tokenize_fn(texts):
batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True)
batch = tokenizer(texts, return_tensors="pt", max_length=96, padding="max_length", truncation=True)
return {k: v.cuda() for k, v in batch.items()}
dataset = pd.read_csv(args.prompt_path)['prompt']
dataset = pd.read_csv(args.prompt_path)["prompt"]
dataloader = DataLoader(dataset=dataset, batch_size=dataset_size, shuffle=True, collate_fn=tokenize_fn)
return dataloader
@@ -144,32 +151,31 @@ def main(args):
ray.get(wait_tasks)
if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--prompt_path', type=str, default=None)
parser.add_argument('--num_trainers', type=int, default=1)
parser.add_argument('--trainer_strategy',
choices=[
'ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_gemini_cpu',
'colossalai_zero2_cpu'
],
default='ddp')
parser.add_argument('--maker_strategy', choices=['naive'], default='naive')
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
parser.add_argument('--critic_model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
parser.add_argument('--pretrain', type=str, default=None)
parser.add_argument('--critic_pretrain', type=str, default=None)
parser.add_argument('--experience_steps', type=int, default=4)
parser.add_argument('--experience_batch_size', type=int, default=8)
parser.add_argument('--train_epochs', type=int, default=1)
parser.add_argument('--update_steps', type=int, default=2)
parser.add_argument('--train_batch_size', type=int, default=8)
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument("--prompt_path", type=str, default=None)
parser.add_argument("--num_trainers", type=int, default=1)
parser.add_argument(
"--trainer_strategy",
choices=["ddp", "colossalai_gemini", "colossalai_zero2", "colossalai_gemini_cpu", "colossalai_zero2_cpu"],
default="ddp",
)
parser.add_argument("--maker_strategy", choices=["naive"], default="naive")
parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
parser.add_argument("--critic_model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument("--critic_pretrain", type=str, default=None)
parser.add_argument("--experience_steps", type=int, default=4)
parser.add_argument("--experience_batch_size", type=int, default=8)
parser.add_argument("--train_epochs", type=int, default=1)
parser.add_argument("--update_steps", type=int, default=2)
parser.add_argument("--train_batch_size", type=int, default=8)
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument('--initial_model_quant_ckpt', type=str, default=None)
parser.add_argument('--quant_bits', type=int, default=4)
parser.add_argument('--quant_group_size', type=int, default=128)
parser.add_argument('--debug', action='store_true')
parser.add_argument("--initial_model_quant_ckpt", type=str, default=None)
parser.add_argument("--quant_bits", type=int, default=4)
parser.add_argument("--quant_group_size", type=int, default=128)
parser.add_argument("--debug", action="store_true")
args = parser.parse_args()
ray.init(namespace=os.environ["RAY_NAMESPACE"], runtime_env={"env_vars": dict(os.environ)})
main(args)
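The env_info dictionaries assembled above carry the standard torch.distributed rendezvous fields; presumably each remote worker exports them and then initializes its process group through the env:// method, along the lines of this sketch (the helper name is illustrative, not taken from the repository):

    import os
    import torch.distributed as dist

    def init_from_env_info(env_info: dict, backend: str = "nccl") -> None:
        # Export the rendezvous variables, then rely on the env:// initialization method.
        os.environ["MASTER_ADDR"] = env_info["master_addr"]
        os.environ["MASTER_PORT"] = env_info["master_port"]
        os.environ["RANK"] = env_info["rank"]
        os.environ["WORLD_SIZE"] = env_info["world_size"]
        dist.init_process_group(backend=backend, init_method="env://")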


@@ -5,7 +5,6 @@ from functools import partial
import pandas as pd
import ray
import torch
from coati.quant import llama_load_quant, low_resource_init
from coati.ray.detached_trainer_ppo import DetachedPPOTrainer
from coati.ray.experience_maker_holder import ExperienceMakerHolder
@@ -23,13 +22,13 @@ from transformers.modeling_utils import no_init_weights
def get_free_port():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(('', 0))
s.bind(("", 0))
return s.getsockname()[1]
def get_local_ip():
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
s.connect(('8.8.8.8', 80))
s.connect(("8.8.8.8", 80))
return s.getsockname()[0]
@@ -37,23 +36,29 @@ def main(args):
master_addr = str(get_local_ip())
# trainer_env_info
trainer_port = str(get_free_port())
env_info_trainers = [{
'local_rank': '0',
'rank': str(rank),
'world_size': str(args.num_trainers),
'master_port': trainer_port,
'master_addr': master_addr
} for rank in range(args.num_trainers)]
env_info_trainers = [
{
"local_rank": "0",
"rank": str(rank),
"world_size": str(args.num_trainers),
"master_port": trainer_port,
"master_addr": master_addr,
}
for rank in range(args.num_trainers)
]
# maker_env_info
maker_port = str(get_free_port())
env_info_makers = [{
'local_rank': '0',
'rank': str(rank),
'world_size': str(args.num_makers),
'master_port': maker_port,
'master_addr': master_addr
} for rank in range(args.num_makers)]
env_info_makers = [
{
"local_rank": "0",
"rank": str(rank),
"world_size": str(args.num_makers),
"master_port": maker_port,
"master_addr": master_addr,
}
for rank in range(args.num_makers)
]
# configure tokenizer
tokenizer = AutoTokenizer.from_pretrained(args.pretrain)
@@ -63,13 +68,18 @@ def main(args):
actor = get_actor_from_args(args.model, args.pretrain).requires_grad_(False).half().cuda()
critic = get_critic_from_args(args.model, args.critic_pretrain).requires_grad_(False).half().cuda()
reward_model = get_reward_model_from_args(args.model, args.critic_pretrain).requires_grad_(False).half().cuda()
if args.initial_model_quant_ckpt is not None and args.model == 'llama':
if args.initial_model_quant_ckpt is not None and args.model == "llama":
# quantize initial model
actor_cfg = AutoConfig.from_pretrained(args.pretrain)
with low_resource_init(), no_init_weights():
initial_model = get_actor_from_args(args.model, config=actor_cfg)
initial_model.model = llama_load_quant(initial_model.model, args.initial_model_quant_ckpt, args.quant_bits,
args.quant_group_size).cuda().requires_grad_(False)
initial_model.model = (
llama_load_quant(
initial_model.model, args.initial_model_quant_ckpt, args.quant_bits, args.quant_group_size
)
.cuda()
.requires_grad_(False)
)
else:
initial_model = get_actor_from_args(args.model, args.pretrain).requires_grad_(False).half().cuda()
return actor, critic, reward_model, initial_model
@@ -78,7 +88,7 @@ def main(args):
experience_holder_refs = [
ExperienceMakerHolder.options(name=f"maker{i}", num_gpus=1, max_concurrency=2).remote(
detached_trainer_name_list=[
f'trainer{x}'
f"trainer{x}"
for x in get_receivers_per_sender(i, args.num_makers, args.num_trainers, allow_idle_sender=False)
],
strategy_fn=partial(get_strategy_from_args, args.maker_strategy),
@@ -87,8 +97,8 @@ def main(args):
kl_coef=0.1,
debug=args.debug,
update_lora_weights=not (args.lora_rank == 0),
# sync_models_from_trainers=True,
# generation kwargs:
# sync_models_from_trainers=True,
# generation kwargs:
max_length=512,
do_sample=True,
temperature=1.0,
@@ -128,12 +138,11 @@ def main(args):
dataset_size = args.experience_batch_size * 4
def build_dataloader():
def tokenize_fn(texts):
batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True)
batch = tokenizer(texts, return_tensors="pt", max_length=96, padding="max_length", truncation=True)
return {k: v.cuda() for k, v in batch.items()}
dataset = pd.read_csv(args.prompt_path)['prompt']
dataset = pd.read_csv(args.prompt_path)["prompt"]
dataloader = DataLoader(dataset=dataset, batch_size=dataset_size, shuffle=True, collate_fn=tokenize_fn)
return dataloader
@@ -148,39 +157,44 @@ def main(args):
for experience_holder_ref in experience_holder_refs:
wait_tasks.append(experience_holder_ref.workingloop.remote(build_dataloader, num_steps=args.experience_steps))
total_steps = args.experience_batch_size * args.experience_steps * \
args.num_makers // (args.num_trainers * args.train_batch_size)
total_steps = (
args.experience_batch_size
* args.experience_steps
* args.num_makers
// (args.num_trainers * args.train_batch_size)
)
for trainer_ref in trainer_refs:
wait_tasks.append(trainer_ref.fit.remote(total_steps, args.update_steps, args.train_epochs))
ray.get(wait_tasks)
if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--prompt_path', type=str, default=None)
parser.add_argument('--num_makers', type=int, default=1)
parser.add_argument('--num_trainers', type=int, default=1)
parser.add_argument("--prompt_path", type=str, default=None)
parser.add_argument("--num_makers", type=int, default=1)
parser.add_argument("--num_trainers", type=int, default=1)
parser.add_argument(
'--trainer_strategy',
choices=['ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_gemini_cpu', 'colossalai_zero2_cpu'],
default='ddp')
parser.add_argument('--maker_strategy', choices=['naive'], default='naive')
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
parser.add_argument('--critic_model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
parser.add_argument('--pretrain', type=str, default=None)
parser.add_argument('--critic_pretrain', type=str, default=None)
parser.add_argument('--experience_steps', type=int, default=4)
parser.add_argument('--experience_batch_size', type=int, default=8)
parser.add_argument('--train_epochs', type=int, default=1)
parser.add_argument('--update_steps', type=int, default=2)
parser.add_argument('--train_batch_size', type=int, default=8)
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
"--trainer_strategy",
choices=["ddp", "colossalai_gemini", "colossalai_zero2", "colossalai_gemini_cpu", "colossalai_zero2_cpu"],
default="ddp",
)
parser.add_argument("--maker_strategy", choices=["naive"], default="naive")
parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
parser.add_argument("--critic_model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument("--critic_pretrain", type=str, default=None)
parser.add_argument("--experience_steps", type=int, default=4)
parser.add_argument("--experience_batch_size", type=int, default=8)
parser.add_argument("--train_epochs", type=int, default=1)
parser.add_argument("--update_steps", type=int, default=2)
parser.add_argument("--train_batch_size", type=int, default=8)
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument('--initial_model_quant_ckpt', type=str, default=None)
parser.add_argument('--quant_bits', type=int, default=4)
parser.add_argument('--quant_group_size', type=int, default=128)
parser.add_argument('--debug', action='store_true')
parser.add_argument("--initial_model_quant_ckpt", type=str, default=None)
parser.add_argument("--quant_bits", type=int, default=4)
parser.add_argument("--quant_group_size", type=int, default=128)
parser.add_argument("--debug", action="store_true")
args = parser.parse_args()
ray.init(namespace=os.environ["RAY_NAMESPACE"], runtime_env={"env_vars": dict(os.environ)})
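With the default arguments defined just above, the reformatted total_steps expression in this file works out to only a handful of fit steps per trainer:

    experience_batch_size = 8  # --experience_batch_size default
    experience_steps = 4       # --experience_steps default
    num_makers = 1             # --num_makers default
    num_trainers = 1           # --num_trainers default
    train_batch_size = 8       # --train_batch_size default

    total_steps = experience_batch_size * experience_steps * num_makers // (num_trainers * train_batch_size)
    print(total_steps)  # 4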


@@ -1,3 +1,3 @@
pandas>=1.4.1
sentencepiece
colossalai==0.3.1
colossalai==0.3.1


@@ -20,28 +20,28 @@ from colossalai.nn.optimizer import HybridAdam
def main(args):
# configure strategy
if args.strategy == 'ddp':
if args.strategy == "ddp":
strategy = DDPStrategy()
elif args.strategy == 'colossalai_gemini':
strategy = GeminiStrategy(placement_policy='cuda', initial_scale=2**5)
elif args.strategy == 'colossalai_zero2':
strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
elif args.strategy == "colossalai_gemini":
strategy = GeminiStrategy(placement_policy="cuda", initial_scale=2**5)
elif args.strategy == "colossalai_zero2":
strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
else:
raise ValueError(f'Unsupported strategy "{args.strategy}"')
if args.rm_path is not None:
warnings.warn('LoRA weights should be merged with the model weights')
state_dict = torch.load(args.rm_path, map_location='cpu')
warnings.warn("LoRA weights should be merged with the model weights")
state_dict = torch.load(args.rm_path, map_location="cpu")
with strategy.model_init_context():
# configure model
if args.model == 'gpt2':
if args.model == "gpt2":
initial_model = GPTActor(pretrained=args.pretrain)
elif args.model == 'bloom':
elif args.model == "bloom":
initial_model = BLOOMActor(pretrained=args.pretrain)
elif args.model == 'opt':
elif args.model == "opt":
initial_model = OPTActor(pretrained=args.pretrain)
elif args.model == 'llama':
elif args.model == "llama":
initial_model = LlamaActor(pretrained=args.pretrain)
else:
raise ValueError(f'Unsupported actor model "{args.model}"')
@@ -51,13 +51,13 @@ def main(args):
else:
rm_model_name = args.rm_model
if rm_model_name == 'gpt2':
if rm_model_name == "gpt2":
reward_model = GPTRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
elif rm_model_name == 'bloom':
elif rm_model_name == "bloom":
reward_model = BLOOMRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
elif rm_model_name == 'opt':
elif rm_model_name == "opt":
reward_model = OPTRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
elif rm_model_name == 'llama':
elif rm_model_name == "llama":
reward_model = LlamaRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
else:
raise ValueError(f'Unsupported reward model "{rm_model_name}"')
@@ -68,24 +68,24 @@ def main(args):
initial_model.to(torch.float16).to(torch.cuda.current_device())
reward_model.to(torch.float16).to(torch.cuda.current_device())
if args.model == 'gpt2':
if args.model == "gpt2":
actor = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
elif args.model == 'bloom':
elif args.model == "bloom":
actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
elif args.model == 'opt':
elif args.model == "opt":
actor = OPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
elif args.model == 'llama':
elif args.model == "llama":
actor = LlamaActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
else:
raise ValueError(f'Unsupported actor model "{args.model}"')
if rm_model_name == 'gpt2':
if rm_model_name == "gpt2":
critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
elif rm_model_name == 'bloom':
elif rm_model_name == "bloom":
critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
elif rm_model_name == 'opt':
elif rm_model_name == "opt":
critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
elif rm_model_name == 'llama':
elif rm_model_name == "llama":
critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
else:
raise ValueError(f'Unsupported reward model "{rm_model_name}"')
@@ -94,12 +94,12 @@ def main(args):
critic.load_state_dict(state_dict, strict=False)
del state_dict
if args.strategy != 'colossalai_gemini':
if args.strategy != "colossalai_gemini":
critic.to(torch.float16).to(torch.cuda.current_device())
actor.to(torch.float16).to(torch.cuda.current_device())
# configure optimizer
if args.strategy.startswith('colossalai'):
if args.strategy.startswith("colossalai"):
actor_optim = HybridAdam(actor.parameters(), lr=1e-7)
critic_optim = HybridAdam(critic.parameters(), lr=1e-7)
else:
@@ -107,22 +107,22 @@ def main(args):
critic_optim = Adam(critic.parameters(), lr=1e-7)
# configure tokenizer
if args.model == 'gpt2':
tokenizer = GPT2Tokenizer.from_pretrained(
'gpt2' if args.tokenizer is None else args.tokenizer)
if args.model == "gpt2":
tokenizer = GPT2Tokenizer.from_pretrained("gpt2" if args.tokenizer is None else args.tokenizer)
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'bloom':
elif args.model == "bloom":
tokenizer = BloomTokenizerFast.from_pretrained(
'bigscience/bloom-560m' if args.tokenizer is None else args.tokenizer)
"bigscience/bloom-560m" if args.tokenizer is None else args.tokenizer
)
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'opt':
tokenizer = AutoTokenizer.from_pretrained(
"facebook/opt-350m" if args.tokenizer is None else args.tokenizer)
elif args.model == "opt":
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m" if args.tokenizer is None else args.tokenizer)
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'llama':
elif args.model == "llama":
tokenizer = LlamaTokenizer.from_pretrained(
"hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer)
tokenizer.eos_token = '<\s>'
"hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer
)
tokenizer.eos_token = "<\s>"
tokenizer.pad_token = tokenizer.unk_token
else:
raise ValueError(f'Unsupported model "{args.model}"')
@@ -132,27 +132,25 @@ def main(args):
prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True)
else:
prompt_sampler = None
prompt_dataloader = DataLoader(prompt_dataset,
shuffle=(prompt_sampler is None),
sampler=prompt_sampler,
batch_size=args.experience_batch_size)
prompt_dataloader = DataLoader(
prompt_dataset, shuffle=(prompt_sampler is None), sampler=prompt_sampler, batch_size=args.experience_batch_size
)
pretrain_dataset = SupervisedDataset(tokenizer=tokenizer,
data_path=args.pretrain_dataset,
max_datasets_size=16384,
max_length=args.max_input_len)
pretrain_dataset = SupervisedDataset(
tokenizer=tokenizer, data_path=args.pretrain_dataset, max_datasets_size=16384, max_length=args.max_input_len
)
if dist.is_initialized() and dist.get_world_size() > 1:
pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True)
else:
pretrain_sampler = None
pretrain_dataloader = DataLoader(pretrain_dataset,
shuffle=(pretrain_sampler is None),
sampler=pretrain_sampler,
batch_size=args.ptx_batch_size)
pretrain_dataloader = DataLoader(
pretrain_dataset, shuffle=(pretrain_sampler is None), sampler=pretrain_sampler, batch_size=args.ptx_batch_size
)
# NOTE: For small models like opt-1.3b, reward model and initial model are not required to be parallelized.
(actor, actor_optim), (critic, critic_optim), reward_model, initial_model = \
strategy.prepare((actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
(actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
(actor, actor_optim), (critic, critic_optim), reward_model, initial_model
)
# configure trainer
trainer = PPOTrainer(
@@ -173,50 +171,54 @@ def main(args):
top_k=50,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
offload_inference_models=args.strategy != 'colossalai_gemini'
offload_inference_models=args.strategy != "colossalai_gemini",
)
trainer.fit(prompt_dataloader=prompt_dataloader,
pretrain_dataloader=pretrain_dataloader,
num_episodes=args.num_episodes,
num_collect_steps=args.num_collect_steps,
num_update_steps=args.num_update_steps)
trainer.fit(
prompt_dataloader=prompt_dataloader,
pretrain_dataloader=pretrain_dataloader,
num_episodes=args.num_episodes,
num_collect_steps=args.num_collect_steps,
num_update_steps=args.num_update_steps,
)
# save model checkpoint after fitting
strategy.save_model(actor, args.save_path, only_rank0=True)
# save optimizer checkpoint on all ranks
if args.need_optim_ckpt:
strategy.save_optimizer(actor_optim,
'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
only_rank0=False)
strategy.save_optimizer(
actor_optim, "actor_optim_checkpoint_prompts_%d.pt" % (torch.cuda.current_device()), only_rank0=False
)
if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--prompt_dataset', type=str, default=None, help='path to the prompt dataset')
parser.add_argument('--pretrain_dataset', type=str, default=None, help='path to the pretrained dataset')
parser.add_argument('--strategy',
choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'],
default='colossalai_zero2',
help='strategy to use')
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
parser.add_argument('--tokenizer', type=str, default=None)
parser.add_argument('--pretrain', type=str, default=None)
parser.add_argument('--rm_model', default=None, choices=['gpt2', 'bloom', 'opt', 'llama'])
parser.add_argument('--rm_path', type=str, default=None)
parser.add_argument('--rm_pretrain', type=str, default=None)
parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts')
parser.add_argument('--need_optim_ckpt', type=bool, default=False)
parser.add_argument('--num_episodes', type=int, default=10)
parser.add_argument('--num_collect_steps', type=int, default=10)
parser.add_argument('--num_update_steps', type=int, default=5)
parser.add_argument('--train_batch_size', type=int, default=8)
parser.add_argument('--ptx_batch_size', type=int, default=1)
parser.add_argument('--experience_batch_size', type=int, default=8)
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument('--kl_coef', type=float, default=0.1)
parser.add_argument('--ptx_coef', type=float, default=0.9)
parser.add_argument('--max_input_len', type=int, default=96)
parser.add_argument('--max_seq_len', type=int, default=128)
parser.add_argument("--prompt_dataset", type=str, default=None, help="path to the prompt dataset")
parser.add_argument("--pretrain_dataset", type=str, default=None, help="path to the pretrained dataset")
parser.add_argument(
"--strategy",
choices=["ddp", "colossalai_gemini", "colossalai_zero2"],
default="colossalai_zero2",
help="strategy to use",
)
parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
parser.add_argument("--tokenizer", type=str, default=None)
parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument("--rm_model", default=None, choices=["gpt2", "bloom", "opt", "llama"])
parser.add_argument("--rm_path", type=str, default=None)
parser.add_argument("--rm_pretrain", type=str, default=None)
parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts")
parser.add_argument("--need_optim_ckpt", type=bool, default=False)
parser.add_argument("--num_episodes", type=int, default=10)
parser.add_argument("--num_collect_steps", type=int, default=10)
parser.add_argument("--num_update_steps", type=int, default=5)
parser.add_argument("--train_batch_size", type=int, default=8)
parser.add_argument("--ptx_batch_size", type=int, default=1)
parser.add_argument("--experience_batch_size", type=int, default=8)
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument("--kl_coef", type=float, default=0.1)
parser.add_argument("--ptx_coef", type=float, default=0.9)
parser.add_argument("--max_input_len", type=int, default=96)
parser.add_argument("--max_seq_len", type=int, default=128)
args = parser.parse_args()
main(args)
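The dataloader construction reformatted above follows the usual PyTorch constraint that shuffle=True and an explicit sampler are mutually exclusive, hence shuffle=(sampler is None); the same pattern on a toy dataset:

    import torch
    import torch.distributed as dist
    from torch.utils.data import DataLoader, TensorDataset
    from torch.utils.data.distributed import DistributedSampler

    dataset = TensorDataset(torch.arange(32))
    if dist.is_initialized() and dist.get_world_size() > 1:
        sampler = DistributedSampler(dataset, shuffle=True, seed=42, drop_last=True)
    else:
        sampler = None
    # DataLoader raises if shuffle=True is combined with a sampler, so shuffle only when no sampler is set.
    dataloader = DataLoader(dataset, shuffle=(sampler is None), sampler=sampler, batch_size=8)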


@@ -24,24 +24,24 @@ from colossalai.nn.optimizer import HybridAdam
def train(args):
# configure strategy
if args.strategy == 'ddp':
if args.strategy == "ddp":
strategy = DDPStrategy()
elif args.strategy == 'colossalai_gemini':
strategy = GeminiStrategy(placement_policy='cuda')
elif args.strategy == 'colossalai_zero2':
strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
elif args.strategy == "colossalai_gemini":
strategy = GeminiStrategy(placement_policy="cuda")
elif args.strategy == "colossalai_zero2":
strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
else:
raise ValueError(f'Unsupported strategy "{args.strategy}"')
# configure model
with strategy.model_init_context():
if args.model == 'bloom':
if args.model == "bloom":
model = BLOOMRM(pretrained=args.pretrain, lora_rank=args.lora_rank)
elif args.model == 'opt':
elif args.model == "opt":
model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank)
elif args.model == 'gpt2':
elif args.model == "gpt2":
model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank)
elif args.model == 'llama':
elif args.model == "llama":
model = LlamaRM(pretrained=args.pretrain, lora_rank=args.lora_rank)
else:
raise ValueError(f'Unsupported model "{args.model}"')
@@ -53,36 +53,36 @@ def train(args):
model.load_state_dict(state_dict)
# configure tokenizer
if args.model == 'gpt2':
tokenizer = GPT2Tokenizer.from_pretrained(
'gpt2' if args.tokenizer is None else args.tokenizer)
if args.model == "gpt2":
tokenizer = GPT2Tokenizer.from_pretrained("gpt2" if args.tokenizer is None else args.tokenizer)
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'bloom':
elif args.model == "bloom":
tokenizer = BloomTokenizerFast.from_pretrained(
'bigscience/bloom-560m' if args.tokenizer is None else args.tokenizer)
"bigscience/bloom-560m" if args.tokenizer is None else args.tokenizer
)
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'opt':
tokenizer = AutoTokenizer.from_pretrained(
"facebook/opt-350m" if args.tokenizer is None else args.tokenizer)
elif args.model == "opt":
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m" if args.tokenizer is None else args.tokenizer)
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'llama':
elif args.model == "llama":
tokenizer = LlamaTokenizer.from_pretrained(
"hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer)
tokenizer.eos_token = '<\s>'
"hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer
)
tokenizer.eos_token = "<\s>"
tokenizer.pad_token = tokenizer.unk_token
else:
raise ValueError(f'Unsupported model "{args.model}"')
# configure optimizer
if args.strategy.startswith('colossalai'):
if args.strategy.startswith("colossalai"):
optim = HybridAdam(model.parameters(), lr=5e-6)
else:
optim = Adam(model.parameters(), lr=5e-6)
# configure loss function
if args.loss_fn == 'log_sig':
if args.loss_fn == "log_sig":
loss_fn = LogSigLoss()
elif args.loss_fn == 'log_exp':
elif args.loss_fn == "log_exp":
loss_fn = LogExpLoss()
else:
raise ValueError(f'Unsupported loss function "{args.loss_fn}"')
@@ -94,18 +94,18 @@ def train(args):
data = load_dataset(args.dataset)
if args.test:
train_data = data['train'].select(range(20))
eval_data = data['test'].select(range(5))
train_data = data["train"].select(range(20))
eval_data = data["test"].select(range(5))
else:
train_data = data['train']
eval_data = data['test']
valid_data = data['test'].select((randint(0, len(eval_data) - 1) for _ in range(len(eval_data) // 5)))
train_data = data["train"]
eval_data = data["test"]
valid_data = data["test"].select((randint(0, len(eval_data) - 1) for _ in range(len(eval_data) // 5)))
if args.dataset == 'Dahoas/rm-static':
if args.dataset == "Dahoas/rm-static":
train_dataset = RmStaticDataset(train_data, tokenizer, args.max_len)
valid_dataset = RmStaticDataset(valid_data, tokenizer, args.max_len)
eval_dataset = RmStaticDataset(eval_data, tokenizer, args.max_len)
elif args.dataset == 'Anthropic/hh-rlhf':
elif args.dataset == "Anthropic/hh-rlhf":
train_dataset = HhRlhfDataset(train_data, tokenizer, args.max_len)
valid_dataset = HhRlhfDataset(valid_data, tokenizer, args.max_len)
eval_dataset = HhRlhfDataset(eval_data, tokenizer, args.max_len)
@@ -113,90 +113,99 @@ def train(args):
raise ValueError(f'Unsupported dataset "{args.dataset}"')
if dist.is_initialized() and dist.get_world_size() > 1:
train_sampler = DistributedSampler(train_dataset,
shuffle=True,
seed=42,
drop_last=True,
rank=dist.get_rank(),
num_replicas=dist.get_world_size())
valid_sampler = DistributedSampler(valid_dataset,
shuffle=True,
seed=42,
drop_last=True,
rank=dist.get_rank(),
num_replicas=dist.get_world_size())
eval_sampler = DistributedSampler(eval_dataset,
shuffle=True,
seed=42,
drop_last=True,
rank=dist.get_rank(),
num_replicas=dist.get_world_size())
train_sampler = DistributedSampler(
train_dataset,
shuffle=True,
seed=42,
drop_last=True,
rank=dist.get_rank(),
num_replicas=dist.get_world_size(),
)
valid_sampler = DistributedSampler(
valid_dataset,
shuffle=True,
seed=42,
drop_last=True,
rank=dist.get_rank(),
num_replicas=dist.get_world_size(),
)
eval_sampler = DistributedSampler(
eval_dataset,
shuffle=True,
seed=42,
drop_last=True,
rank=dist.get_rank(),
num_replicas=dist.get_world_size(),
)
else:
train_sampler = None
valid_sampler = None
eval_sampler = None
train_dataloader = DataLoader(train_dataset,
shuffle=(train_sampler is None),
sampler=train_sampler,
batch_size=args.batch_size,
pin_memory=True)
train_dataloader = DataLoader(
train_dataset,
shuffle=(train_sampler is None),
sampler=train_sampler,
batch_size=args.batch_size,
pin_memory=True,
)
valid_dataloader = DataLoader(valid_dataset,
shuffle=(valid_sampler is None),
sampler=valid_sampler,
batch_size=args.batch_size,
pin_memory=True)
valid_dataloader = DataLoader(
valid_dataset,
shuffle=(valid_sampler is None),
sampler=valid_sampler,
batch_size=args.batch_size,
pin_memory=True,
)
eval_dataloader = DataLoader(eval_dataset,
shuffle=(eval_sampler is None),
sampler=eval_sampler,
batch_size=args.batch_size,
pin_memory=True)
eval_dataloader = DataLoader(
eval_dataset, shuffle=(eval_sampler is None), sampler=eval_sampler, batch_size=args.batch_size, pin_memory=True
)
lr_scheduler = CosineAnnealingLR(optim, train_dataloader.__len__() // 100)
strategy_dict = strategy.prepare(dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler))
model = strategy_dict['model']
optim = strategy_dict['optimizer']
lr_scheduler = strategy_dict['lr_scheduler']
trainer = RewardModelTrainer(model=model,
strategy=strategy,
optim=optim,
lr_scheduler=lr_scheduler,
loss_fn=loss_fn,
max_epochs=args.max_epochs)
model = strategy_dict["model"]
optim = strategy_dict["optimizer"]
lr_scheduler = strategy_dict["lr_scheduler"]
trainer = RewardModelTrainer(
model=model,
strategy=strategy,
optim=optim,
lr_scheduler=lr_scheduler,
loss_fn=loss_fn,
max_epochs=args.max_epochs,
)
trainer.fit(train_dataloader=train_dataloader, valid_dataloader=valid_dataloader, eval_dataloader=eval_dataloader)
# save model checkpoint after fitting on only rank0
strategy.save_model(model, args.save_path, only_rank0=True)
# save optimizer checkpoint on all ranks
if args.need_optim_ckpt:
strategy.save_optimizer(trainer.optimizer,
'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()),
only_rank0=False)
strategy.save_optimizer(
trainer.optimizer, "rm_optim_checkpoint_%d.pt" % (torch.cuda.current_device()), only_rank0=False
)
if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--strategy',
choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'],
default='colossalai_zero2')
parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom')
parser.add_argument('--tokenizer', type=str, default=None)
parser.add_argument('--pretrain', type=str, default=None)
parser.add_argument('--model_path', type=str, default=None)
parser.add_argument('--need_optim_ckpt', type=bool, default=False)
parser.add_argument('--dataset',
type=str,
choices=['Anthropic/hh-rlhf', 'Dahoas/rm-static'],
default='Dahoas/rm-static')
parser.add_argument('--subset', type=lambda x: None if x == 'None' else x, default=None)
parser.add_argument('--save_path', type=str, default='rm_ckpt')
parser.add_argument('--max_epochs', type=int, default=1)
parser.add_argument('--batch_size', type=int, default=1)
parser.add_argument('--max_len', type=int, default=512)
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument('--loss_fn', type=str, default='log_sig', choices=['log_sig', 'log_exp'])
parser.add_argument('--test', type=bool, default=False)
parser.add_argument(
"--strategy", choices=["ddp", "colossalai_gemini", "colossalai_zero2"], default="colossalai_zero2"
)
parser.add_argument("--model", choices=["gpt2", "bloom", "opt", "llama"], default="bloom")
parser.add_argument("--tokenizer", type=str, default=None)
parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument("--model_path", type=str, default=None)
parser.add_argument("--need_optim_ckpt", type=bool, default=False)
parser.add_argument(
"--dataset", type=str, choices=["Anthropic/hh-rlhf", "Dahoas/rm-static"], default="Dahoas/rm-static"
)
parser.add_argument("--subset", type=lambda x: None if x == "None" else x, default=None)
parser.add_argument("--save_path", type=str, default="rm_ckpt")
parser.add_argument("--max_epochs", type=int, default=1)
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--max_len", type=int, default=512)
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument("--loss_fn", type=str, default="log_sig", choices=["log_sig", "log_exp"])
parser.add_argument("--test", type=bool, default=False)
args = parser.parse_args()
train(args)
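The loss_fn flag above switches between two pairwise ranking objectives for the reward model. Assuming LogSigLoss follows the standard Bradley-Terry form (an assumption about the coati implementation, not something visible in this diff), it reduces to:

    import torch
    import torch.nn.functional as F

    def log_sig_loss(chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
        # Assumed form: -log(sigmoid(r_chosen - r_rejected)), averaged over the batch.
        return -F.logsigmoid(chosen_reward - reject_reward).mean()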


@@ -6,18 +6,18 @@ import torch
import torch.distributed as dist
from coati.dataset import SFTDataset, SupervisedDataset
from coati.models.bloom import BLOOMActor
from coati.models.chatglm import ChatGLMActor
from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
from coati.models.gpt import GPTActor
from coati.models.llama import LlamaActor
from coati.models.opt import OPTActor
from coati.models.chatglm import ChatGLMActor
from coati.trainer import SFTTrainer
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
from datasets import load_dataset
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer, AutoModel
from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
from transformers.trainer import get_scheduler
@@ -28,14 +28,14 @@ from colossalai.tensor import ColoParameter
def train(args):
# configure strategy
if args.strategy == 'ddp':
if args.strategy == "ddp":
strategy = DDPStrategy()
elif args.strategy == 'colossalai_gemini':
strategy = GeminiStrategy(placement_policy='cuda')
elif args.strategy == 'colossalai_zero2':
strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
elif args.strategy == 'colossalai_zero2_cpu':
strategy = LowLevelZeroStrategy(stage=2, placement_policy='cpu')
elif args.strategy == "colossalai_gemini":
strategy = GeminiStrategy(placement_policy="cuda")
elif args.strategy == "colossalai_zero2":
strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
elif args.strategy == "colossalai_zero2_cpu":
strategy = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
else:
raise ValueError(f'Unsupported strategy "{args.strategy}"')
@@ -44,23 +44,15 @@ def train(args):
warnings.warn("Gradient checkpoint is disabled when using LoRA")
args.grad_checkpoint = False
with strategy.model_init_context():
if args.model == 'bloom':
model = BLOOMActor(pretrained=args.pretrain,
lora_rank=args.lora_rank,
checkpoint=args.grad_checkpoint)
elif args.model == 'opt':
model = OPTActor(pretrained=args.pretrain,
lora_rank=args.lora_rank,
checkpoint=args.grad_checkpoint)
elif args.model == 'gpt2':
model = GPTActor(pretrained=args.pretrain,
lora_rank=args.lora_rank,
checkpoint=args.grad_checkpoint)
elif args.model == 'llama':
model = LlamaActor(pretrained=args.pretrain,
lora_rank=args.lora_rank,
checkpoint=args.grad_checkpoint)
elif args.model == 'chatglm':
if args.model == "bloom":
model = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint)
elif args.model == "opt":
model = OPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint)
elif args.model == "gpt2":
model = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint)
elif args.model == "llama":
model = LlamaActor(pretrained=args.pretrain, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint)
elif args.model == "chatglm":
model = ChatGLMActor(pretrained=args.pretrain)
else:
raise ValueError(f'Unsupported model "{args.model}"')
@@ -68,144 +60,157 @@ def train(args):
model.to(torch.float16).to(torch.cuda.current_device())
# configure tokenizer
if args.model == 'gpt2':
tokenizer = GPT2Tokenizer.from_pretrained(
'gpt2' if args.tokenizer is None else args.tokenizer)
if args.model == "gpt2":
tokenizer = GPT2Tokenizer.from_pretrained("gpt2" if args.tokenizer is None else args.tokenizer)
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'bloom':
elif args.model == "bloom":
tokenizer = BloomTokenizerFast.from_pretrained(
'bigscience/bloom-560m' if args.tokenizer is None else args.tokenizer)
"bigscience/bloom-560m" if args.tokenizer is None else args.tokenizer
)
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'opt':
tokenizer = AutoTokenizer.from_pretrained(
"facebook/opt-350m" if args.tokenizer is None else args.tokenizer)
elif args.model == "opt":
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m" if args.tokenizer is None else args.tokenizer)
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'llama':
elif args.model == "llama":
tokenizer = LlamaTokenizer.from_pretrained(
"hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer)
tokenizer.eos_token = '<\s>'
"hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer
)
tokenizer.eos_token = "<\s>"
tokenizer.pad_token = tokenizer.unk_token
elif args.model == 'chatglm':
elif args.model == "chatglm":
tokenizer = ChatGLMTokenizer.from_pretrained(
"THUDM/chatglm-6b" if args.tokenizer is None else args.tokenizer, trust_remote_code=True)
"THUDM/chatglm-6b" if args.tokenizer is None else args.tokenizer, trust_remote_code=True
)
else:
raise ValueError(f'Unsupported model "{args.model}"')
if args.model == 'llama' and args.strategy == 'colossalai_gemini':
if args.model == "llama" and args.strategy == "colossalai_gemini":
# this is a hack to deal with the resized embedding
# to make sure all parameters are ColoParameter for Colossal-AI Gemini Compatibility
for name, param in model.named_parameters():
if not isinstance(param, ColoParameter):
sub_module_name = '.'.join(name.split('.')[:-1])
weight_name = name.split('.')[-1]
sub_module_name = ".".join(name.split(".")[:-1])
weight_name = name.split(".")[-1]
sub_module = model.get_submodule(sub_module_name)
setattr(sub_module, weight_name, ColoParameter(param))
# configure optimizer
if args.strategy.startswith('colossalai'):
if args.strategy.startswith("colossalai"):
optim = HybridAdam(model.parameters(), lr=args.lr, clipping_norm=1.0)
else:
optim = Adam(model.parameters(), lr=args.lr)
logger = get_dist_logger()
# configure dataset
if args.dataset == 'yizhongw/self_instruct':
train_data = load_dataset(args.dataset, 'super_natural_instructions', split='train')
eval_data = load_dataset(args.dataset, 'super_natural_instructions', split='test')
if args.dataset == "yizhongw/self_instruct":
train_data = load_dataset(args.dataset, "super_natural_instructions", split="train")
eval_data = load_dataset(args.dataset, "super_natural_instructions", split="test")
train_dataset = SFTDataset(train_data, tokenizer, args.max_len)
eval_dataset = SFTDataset(eval_data, tokenizer, args.max_len)
else:
train_dataset = SupervisedDataset(tokenizer=tokenizer,
data_path=args.dataset,
max_datasets_size=args.max_datasets_size,
max_length=args.max_len)
train_dataset = SupervisedDataset(
tokenizer=tokenizer,
data_path=args.dataset,
max_datasets_size=args.max_datasets_size,
max_length=args.max_len,
)
eval_dataset = None
if dist.is_initialized() and dist.get_world_size() > 1:
train_sampler = DistributedSampler(train_dataset,
shuffle=True,
seed=42,
drop_last=True,
rank=dist.get_rank(),
num_replicas=dist.get_world_size())
train_sampler = DistributedSampler(
train_dataset,
shuffle=True,
seed=42,
drop_last=True,
rank=dist.get_rank(),
num_replicas=dist.get_world_size(),
)
if eval_dataset is not None:
eval_sampler = DistributedSampler(eval_dataset,
shuffle=False,
seed=42,
drop_last=False,
rank=dist.get_rank(),
num_replicas=dist.get_world_size())
eval_sampler = DistributedSampler(
eval_dataset,
shuffle=False,
seed=42,
drop_last=False,
rank=dist.get_rank(),
num_replicas=dist.get_world_size(),
)
else:
train_sampler = None
eval_sampler = None
train_dataloader = DataLoader(train_dataset,
shuffle=(train_sampler is None),
sampler=train_sampler,
batch_size=args.batch_size,
pin_memory=True)
train_dataloader = DataLoader(
train_dataset,
shuffle=(train_sampler is None),
sampler=train_sampler,
batch_size=args.batch_size,
pin_memory=True,
)
if eval_dataset is not None:
eval_dataloader = DataLoader(eval_dataset,
shuffle=(eval_sampler is None),
sampler=eval_sampler,
batch_size=args.batch_size,
pin_memory=True)
eval_dataloader = DataLoader(
eval_dataset,
shuffle=(eval_sampler is None),
sampler=eval_sampler,
batch_size=args.batch_size,
pin_memory=True,
)
else:
eval_dataloader = None
num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
max_steps = math.ceil(args.max_epochs * num_update_steps_per_epoch)
lr_scheduler = get_scheduler("cosine",
optim,
num_warmup_steps=math.ceil(max_steps * 0.03),
num_training_steps=max_steps)
lr_scheduler = get_scheduler(
"cosine", optim, num_warmup_steps=math.ceil(max_steps * 0.03), num_training_steps=max_steps
)
strategy_dict = strategy.prepare(dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler))
model = strategy_dict['model']
optim = strategy_dict['optimizer']
lr_scheduler = strategy_dict['lr_scheduler']
trainer = SFTTrainer(model=model,
strategy=strategy,
optim=optim,
lr_scheduler=lr_scheduler,
max_epochs=args.max_epochs,
accumulation_steps=args.accumulation_steps)
model = strategy_dict["model"]
optim = strategy_dict["optimizer"]
lr_scheduler = strategy_dict["lr_scheduler"]
trainer = SFTTrainer(
model=model,
strategy=strategy,
optim=optim,
lr_scheduler=lr_scheduler,
max_epochs=args.max_epochs,
accumulation_steps=args.accumulation_steps,
)
trainer.fit(train_dataloader=train_dataloader,
eval_dataloader=eval_dataloader,
logger=logger,
use_wandb=args.use_wandb)
trainer.fit(
train_dataloader=train_dataloader, eval_dataloader=eval_dataloader, logger=logger, use_wandb=args.use_wandb
)
# save model checkpoint after fitting on only rank0
strategy.save_pretrained(model, path=args.save_path, only_rank0=True, tokenizer=tokenizer)
# save optimizer checkpoint on all ranks
if args.need_optim_ckpt:
strategy.save_optimizer(trainer.optimizer,
'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()),
only_rank0=False)
strategy.save_optimizer(
trainer.optimizer, "rm_optim_checkpoint_%d.pt" % (torch.cuda.current_device()), only_rank0=False
)
if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--strategy',
choices=['ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_zero2_cpu'],
default='colossalai_zero2')
parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama', 'chatglm'], default='bloom')
parser.add_argument('--tokenizer', type=str, default=None)
parser.add_argument('--pretrain', type=str, default=None)
parser.add_argument('--dataset', type=str, default=None)
parser.add_argument('--max_datasets_size', type=int, default=None)
parser.add_argument('--save_path', type=str, default='output')
parser.add_argument('--need_optim_ckpt', type=bool, default=False)
parser.add_argument('--max_epochs', type=int, default=3)
parser.add_argument('--batch_size', type=int, default=4)
parser.add_argument('--max_len', type=int, default=512)
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument('--log_interval', type=int, default=100, help="how many steps to log")
parser.add_argument('--lr', type=float, default=5e-6)
parser.add_argument('--accumulation_steps', type=int, default=8)
parser.add_argument('--use_wandb', default=False, action='store_true')
parser.add_argument('--grad_checkpoint', default=False, action='store_true')
parser.add_argument(
"--strategy",
choices=["ddp", "colossalai_gemini", "colossalai_zero2", "colossalai_zero2_cpu"],
default="colossalai_zero2",
)
parser.add_argument("--model", choices=["gpt2", "bloom", "opt", "llama", "chatglm"], default="bloom")
parser.add_argument("--tokenizer", type=str, default=None)
parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument("--dataset", type=str, default=None)
parser.add_argument("--max_datasets_size", type=int, default=None)
parser.add_argument("--save_path", type=str, default="output")
parser.add_argument("--need_optim_ckpt", type=bool, default=False)
parser.add_argument("--max_epochs", type=int, default=3)
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--max_len", type=int, default=512)
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument("--log_interval", type=int, default=100, help="how many steps to log")
parser.add_argument("--lr", type=float, default=5e-6)
parser.add_argument("--accumulation_steps", type=int, default=8)
parser.add_argument("--use_wandb", default=False, action="store_true")
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
args = parser.parse_args()
train(args)
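As a concrete reading of the scheduler setup reformatted above: with the default accumulation_steps of 8 and max_epochs of 3, and a hypothetical dataloader of 1000 batches, the step counts come out as follows:

    import math

    num_batches = 1000      # hypothetical len(train_dataloader)
    accumulation_steps = 8  # --accumulation_steps default
    max_epochs = 3          # --max_epochs default

    num_update_steps_per_epoch = num_batches // accumulation_steps  # 125 optimizer steps per epoch
    max_steps = math.ceil(max_epochs * num_update_steps_per_epoch)  # 375 steps in total
    num_warmup_steps = math.ceil(max_steps * 0.03)                  # 12 warmup steps for the cosine schedule
    print(num_update_steps_per_epoch, max_steps, num_warmup_steps)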