mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2026-01-29 21:49:54 +00:00)
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -3,7 +3,6 @@ import json
 from typing import Dict, Sequence

 import torch
-from datasets import load_dataset
 from torch.utils.data import Dataset
 from tqdm import tqdm
 from transformers import AutoTokenizer
@@ -20,7 +19,8 @@ def _tokenize_fn(strings: Sequence[str], tokenizer: AutoTokenizer, max_length: i
             padding="longest",
             max_length=max_length,
             truncation=True,
-        ) for text in strings
+        )
+        for text in strings
     ]
     input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
     input_ids_lens = labels_lens = [
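Note: padding="longest" is applied per call here, and each call tokenizes a single string, so no padding is actually added; the comprehension effectively yields unpadded, truncated id sequences. A minimal sketch of the same pattern (the bloom-560m checkpoint is only illustrative, not the script's default):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")  # illustrative checkpoint
strings = ["提问:1+1=? 回答:2", "提问:hello 回答:world"]
tokenized_list = [
    tokenizer(text, return_tensors="pt", padding="longest", max_length=512, truncation=True)
    for text in strings
]
input_ids = [t.input_ids[0] for t in tokenized_list]
# length counting mirrors _tokenize_fn: count non-pad tokens per sequence
lens = [ids.ne(tokenizer.pad_token_id).sum().item() for ids in input_ids]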
@@ -48,18 +48,17 @@ def preprocess(sources: Sequence[str], targets: Sequence[str], tokenizer: AutoTo


 class EasySupervisedDataset(Dataset):
-
     def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 512) -> None:
         super(EasySupervisedDataset, self).__init__()
         with open(data_file, "r", encoding="UTF-8") as f:
             all_lines = f.readlines()
-        #split to source and target ,source the characters before "回答:" including "回答:", target the characters after "回答:"
+        # split to source and target ,source the characters before "回答:" including "回答:", target the characters after "回答:"
         sources, targets = [], []
         for line in all_lines:
             if "回答:" in line:
                 sep_index = line.index("回答:")
-                sources.append(line[:sep_index + 3])
-                targets.append(line[sep_index + 3:] + tokenizer.eos_token)
+                sources.append(line[: sep_index + 3])
+                targets.append(line[sep_index + 3 :] + tokenizer.eos_token)
             else:
                 sources.append(line)
                 targets.append("" + tokenizer.eos_token)
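Note: "回答:" is three characters (回, 答, and a fullwidth colon), which is why both slices use sep_index + 3: the separator stays attached to the source/prompt side. A worked example with a hypothetical input line:

line = "提问:天空为什么是蓝色的? 回答:因为瑞利散射。"
sep_index = line.index("回答:")
source = line[: sep_index + 3]  # "提问:天空为什么是蓝色的? 回答:"
target = line[sep_index + 3 :]  # "因为瑞利散射。"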
@@ -83,15 +82,17 @@ class EasySupervisedDataset(Dataset):


 class EasyPromptsDataset(Dataset):
-
     def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 96) -> None:
         super(EasyPromptsDataset, self).__init__()
         with open(data_file, "r", encoding="UTF-8") as f:
             all_lines = f.readlines()
-            all_lines = [line if "回答:" not in line else line[:line.index("回答:") + 3] for line in all_lines]
+            all_lines = [line if "回答:" not in line else line[: line.index("回答:") + 3] for line in all_lines]
         self.prompts = [
-            tokenizer(line, return_tensors='pt', max_length=max_length, padding='max_length',
-                      truncation=True)['input_ids'].to(torch.cuda.current_device()).squeeze(0)
+            tokenizer(line, return_tensors="pt", max_length=max_length, padding="max_length", truncation=True)[
+                "input_ids"
+            ]
+            .to(torch.cuda.current_device())
+            .squeeze(0)
             for line in tqdm(all_lines)
         ]
         self.data_file = data_file
@@ -110,7 +111,6 @@ class EasyPromptsDataset(Dataset):


 class EasyRewardDataset(Dataset):
-
     def __init__(self, train_file: str, tokenizer: AutoTokenizer, special_token=None, max_length=512) -> None:
         super(EasyRewardDataset, self).__init__()
         self.chosen = []
@@ -120,44 +120,42 @@ class EasyRewardDataset(Dataset):
         else:
             self.end_token = special_token
         print(self.end_token)
-        #read all lines in the train_file to a list
+        # read all lines in the train_file to a list
         with open(train_file, "r", encoding="UTF-8") as f:
             all_lines = f.readlines()
             for line in tqdm(all_lines):
                 data = json.loads(line)
-                prompt = "提问:" + data['prompt'] + " 回答:"
+                prompt = "提问:" + data["prompt"] + " 回答:"

-                chosen = prompt + data['chosen'] + self.end_token
-                chosen_token = tokenizer(chosen,
-                                         max_length=max_length,
-                                         padding="max_length",
-                                         truncation=True,
-                                         return_tensors="pt")
-                self.chosen.append({
-                    "input_ids": chosen_token['input_ids'],
-                    "attention_mask": chosen_token['attention_mask']
-                })
+                chosen = prompt + data["chosen"] + self.end_token
+                chosen_token = tokenizer(
+                    chosen, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
+                )
+                self.chosen.append(
+                    {"input_ids": chosen_token["input_ids"], "attention_mask": chosen_token["attention_mask"]}
+                )

-                reject = prompt + data['rejected'] + self.end_token
-                reject_token = tokenizer(reject,
-                                         max_length=max_length,
-                                         padding="max_length",
-                                         truncation=True,
-                                         return_tensors="pt")
-                self.reject.append({
-                    "input_ids": reject_token['input_ids'],
-                    "attention_mask": reject_token['attention_mask']
-                })
+                reject = prompt + data["rejected"] + self.end_token
+                reject_token = tokenizer(
+                    reject, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
+                )
+                self.reject.append(
+                    {"input_ids": reject_token["input_ids"], "attention_mask": reject_token["attention_mask"]}
+                )

     def __len__(self):
         length = len(self.chosen)
         return length

     def __getitem__(self, idx):
-        return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][
-            "input_ids"], self.reject[idx]["attention_mask"]
+        return (
+            self.chosen[idx]["input_ids"],
+            self.chosen[idx]["attention_mask"],
+            self.reject[idx]["input_ids"],
+            self.reject[idx]["attention_mask"],
+        )

-    #python representation of the object and the string representation of the object
+    # python representation of the object and the string representation of the object
     def __repr__(self):
         return f"LawRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})"
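Note: each line of train_file is expected to be a JSON record with prompt, chosen, and rejected fields; __getitem__ then returns both tokenized sides so a pairwise reward loss can compare them. A sketch with a hypothetical record:

import json

line = '{"prompt": "1+1=?", "chosen": "2", "rejected": "3"}'  # hypothetical record
data = json.loads(line)
prompt = "提问:" + data["prompt"] + " 回答:"
chosen = prompt + data["chosen"]    # the dataset class also appends self.end_token
reject = prompt + data["rejected"]  # both sides share the same prompt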
@@ -165,26 +163,25 @@ class EasyRewardDataset(Dataset):
         return f"LawRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})"


-'''
+"""
 Easy SFT just accept a text file which can be read line by line. However the datasets will group texts together to max_length so LLM will learn the texts meaning better.
 If individual lines are not related, just set is_group_texts to False.
-'''
+"""


 class EasySFTDataset(Dataset):
-
     def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length=512, is_group_texts=True) -> None:
         super().__init__()
-        #read the data_file line by line
+        # read the data_file line by line
         with open(data_file, "r", encoding="UTF-8") as f:
-            #encode the text data line by line and put raw python list input_ids only to raw_input_ids list
+            # encode the text data line by line and put raw python list input_ids only to raw_input_ids list
             raw_input_ids = []
             for line in f:
                 encoded_ids = tokenizer.encode(line)
-                #if the encoded_ids is longer than max_length, then split it into several parts
+                # if the encoded_ids is longer than max_length, then split it into several parts
                 if len(encoded_ids) > max_length:
                     for i in range(0, len(encoded_ids), max_length):
-                        raw_input_ids.append(encoded_ids[i:i + max_length])
+                        raw_input_ids.append(encoded_ids[i : i + max_length])
                 else:
                     raw_input_ids.append(encoded_ids)
@@ -196,12 +193,13 @@ class EasySFTDataset(Dataset):
         if is_group_texts:
             for input_ids in raw_input_ids:
                 if len(current_input_ids) + len(input_ids) > max_length:
-                    #pad the current_input_ids to max_length with tokenizer.pad_token_id
+                    # pad the current_input_ids to max_length with tokenizer.pad_token_id
                     padded_length = max_length - len(current_input_ids)
                     current_input_ids.extend([tokenizer.pad_token_id] * padded_length)
                     grouped_input_ids.append(torch.tensor(current_input_ids, dtype=torch.long))
                     attention_mask.append(
-                        torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long))
+                        torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)
+                    )
                     current_input_ids = []
                 else:
                     current_input_ids.extend(input_ids)
@@ -210,14 +208,16 @@ class EasySFTDataset(Dataset):
                 current_input_ids.extend([tokenizer.pad_token_id] * padded_length)
                 grouped_input_ids.append(torch.tensor(current_input_ids, dtype=torch.long))
                 attention_mask.append(
-                    torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long))
+                    torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)
+                )
         else:
-            #just append the raw_input_ids to max_length
+            # just append the raw_input_ids to max_length
             for input_ids in raw_input_ids:
                 padded_length = max_length - len(input_ids)
                 input_ids.extend([tokenizer.pad_token_id] * padded_length)
                 attention_mask.append(
-                    torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long))
+                    torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)
+                )
                 grouped_input_ids.append(torch.tensor(input_ids, dtype=torch.long))
         self.input_ids = grouped_input_ids
         self.labels = copy.deepcopy(self.input_ids)
@@ -227,14 +227,14 @@ class EasySFTDataset(Dataset):
     def __len__(self):
         return len(self.input_ids)

-    #get item from dataset
+    # get item from dataset
     def __getitem__(self, idx):
        return dict(input_ids=self.input_ids[idx], labels=self.labels[idx], attention_mask=self.attention_mask[idx])

-    #generate the dataset description to be printed by print in python
+    # generate the dataset description to be printed by print in python
     def __repr__(self):
         return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})"

-    #generate the dataset description to be printed by print in python
+    # generate the dataset description to be printed by print in python
     def __str__(self):
         return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})"
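Note: with is_group_texts=True the dataset packs consecutive encoded lines into fixed max_length blocks, padding a block when the next line would overflow it. A standalone sketch of just the packing rule (token ids are made up); as in the dataset code above, the line that triggers a flush is itself skipped:

def group_to_blocks(raw_input_ids, max_length, pad_id):
    blocks, current = [], []
    for ids in raw_input_ids:
        if len(current) + len(ids) > max_length:
            # pad the open block and start a new one; `ids` is dropped here,
            # mirroring the if-branch in EasySFTDataset.__init__
            current += [pad_id] * (max_length - len(current))
            blocks.append(current)
            current = []
        else:
            current.extend(ids)
    return blocks  # the trailing partial block is flushed separately in the class

print(group_to_blocks([[1, 2], [3, 4, 5], [6]], max_length=4, pad_id=0))
# [[1, 2, 0, 0]]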
@@ -4,7 +4,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from coati.models.generation import generate
-from coati.models.utils import log_probs_from_logits, masked_mean
+from coati.models.utils import log_probs_from_logits
 from peft import PeftModel
 from torch.nn.modules import Module
 from transformers import BloomConfig, BloomForCausalLM
@@ -24,38 +24,33 @@ class Actor(Module):

     @torch.no_grad()
     def generate(
-        self,
-        input_ids: torch.Tensor,
-        return_action_mask: bool = True,
-        **kwargs
+        self, input_ids: torch.Tensor, return_action_mask: bool = True, **kwargs
     ) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
         sequences = generate(self.model, input_ids, **kwargs)
         attention_mask = None
-        pad_token_id = kwargs.get('pad_token_id', None)
+        pad_token_id = kwargs.get("pad_token_id", None)
         if pad_token_id is not None:
             attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
         if not return_action_mask:
             return sequences, attention_mask, None
         input_len = input_ids.size(1)
-        eos_token_id = kwargs.get('eos_token_id', None)
+        eos_token_id = kwargs.get("eos_token_id", None)
         if eos_token_id is None:
             action_mask = torch.ones_like(sequences, dtype=torch.bool)
         else:
             # left padding may be applied, only mask action
             action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
-            action_mask = F.pad(action_mask, (1 + input_len, -1), value=True)    # include eos token and input
+            action_mask = F.pad(action_mask, (1 + input_len, -1), value=True)  # include eos token and input
         action_mask[:, :input_len] = False
         action_mask = action_mask[:, 1:]
-        return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len):]
+        return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len) :]

-    def forward(self,
-                sequences: torch.LongTensor,
-                num_actions: int,
-                attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
-        """Returns action log probs
-        """
+    def forward(
+        self, sequences: torch.LongTensor, num_actions: int, attention_mask: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+        """Returns action log probs"""
         output = self.model(sequences, attention_mask=attention_mask)
-        logits = output['logits']
+        logits = output["logits"]
         log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
         return log_probs[:, -num_actions:]
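Note: the eos masking works by cumulative sum: positions strictly before the first eos in the generated span stay True. A worked example of that one line:

import torch

eos_token_id = 2
generated = torch.tensor([[5, 7, 2, 9]])  # tokens after the prompt (made up)
action_mask = (generated == eos_token_id).cumsum(dim=-1) == 0
print(action_mask)  # tensor([[ True,  True, False, False]])
# The F.pad(..., (1 + input_len, -1), value=True) shift then re-includes the
# eos position itself, as the inline comment in the diff notes.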
@@ -75,11 +70,13 @@ class BLOOMActor(Actor):
         lora_train_bias (str): LoRA bias training mode.
     """

-    def __init__(self,
-                 pretrained: str = None,
-                 config: Optional[BloomConfig] = None,
-                 checkpoint: bool = False,
-                 lora_path: str = None) -> None:
+    def __init__(
+        self,
+        pretrained: str = None,
+        config: Optional[BloomConfig] = None,
+        checkpoint: bool = False,
+        lora_path: str = None,
+    ) -> None:
         if pretrained is not None:
             model = BloomForCausalLM.from_pretrained(pretrained)
         elif config is not None:
@@ -1,18 +1,16 @@
 import argparse

-import pandas as pd
 import torch
 import torch.distributed as dist
-from coati.dataset import DataCollatorForSupervisedDataset, PromptDataset, SupervisedDataset
+from coati.dataset import DataCollatorForSupervisedDataset
 from coati.models.bloom import BLOOMRM, BLOOMCritic
-from coati.models.gpt import GPTRM, GPTActor, GPTCritic
-from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
-from coati.models.opt import OPTRM, OPTActor, OPTCritic
+from coati.models.gpt import GPTRM, GPTCritic
+from coati.models.llama import LlamaCritic, LlamaRM
+from coati.models.opt import OPTRM, OPTCritic
 from coati.trainer import PPOTrainer
 from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
 from easy_dataset import EasyPromptsDataset, EasySupervisedDataset
 from easy_models import BLOOMActor
-from peft import PeftModel
 from torch.optim import Adam
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
@@ -23,24 +21,24 @@ from colossalai.nn.optimizer import HybridAdam

 def main(args):
     # configure strategy
-    if args.strategy == 'ddp':
+    if args.strategy == "ddp":
         strategy = DDPStrategy()
-    elif args.strategy == 'colossalai_gemini':
-        strategy = GeminiStrategy(placement_policy='cpu', initial_scale=2**5)
-    elif args.strategy == 'colossalai_zero2':
-        strategy = LowLevelZeroStrategy(stage=2, placement_policy='cpu')
+    elif args.strategy == "colossalai_gemini":
+        strategy = GeminiStrategy(placement_policy="cpu", initial_scale=2**5)
+    elif args.strategy == "colossalai_zero2":
+        strategy = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
     else:
         raise ValueError(f'Unsupported strategy "{args.strategy}"')

     if args.rm_path is not None:
-        state_dict = torch.load(args.rm_path, map_location='cpu')
+        state_dict = torch.load(args.rm_path, map_location="cpu")

     # configure model
-    if args.model == 'bloom':
+    if args.model == "bloom":
         # initial_model = BLOOMActor(pretrained=args.pretrain)
-        print('Using peft lora to load Bloom model as initial_model')
+        print("Using peft lora to load Bloom model as initial_model")
         initial_model = BLOOMActor(pretrained=args.pretrain, lora_path=args.sft_lora_path)
-        print('Using peft lora to load Bloom model as initial_model (Done)')
+        print("Using peft lora to load Bloom model as initial_model (Done)")
     else:
         raise ValueError(f'Unsupported actor model "{args.model}"')
@@ -49,59 +47,59 @@ def main(args):
     else:
         rm_model_name = args.rm_model

-    if rm_model_name == 'gpt2':
+    if rm_model_name == "gpt2":
         reward_model = GPTRM(pretrained=args.rm_pretrain)
-    elif rm_model_name == 'bloom':
+    elif rm_model_name == "bloom":
         print("load bloom reward model ", args.rm_pretrain)
         reward_model = BLOOMRM(pretrained=args.rm_pretrain)
-    elif rm_model_name == 'opt':
+    elif rm_model_name == "opt":
         reward_model = OPTRM(pretrained=args.rm_pretrain)
-    elif rm_model_name == 'llama':
+    elif rm_model_name == "llama":
         reward_model = LlamaRM(pretrained=args.rm_pretrain)
     else:
         raise ValueError(f'Unsupported reward model "{rm_model_name}"')

     if args.rm_path is not None:
-        print('Loading reward model from', args.rm_path)
+        print("Loading reward model from", args.rm_path)
         reward_model.load_state_dict(state_dict)

-    if args.strategy != 'colossalai_gemini':
+    if args.strategy != "colossalai_gemini":
         initial_model.to(torch.float16).to(torch.cuda.current_device())
         reward_model.to(torch.float16).to(torch.cuda.current_device())

     with strategy.model_init_context():
-        if args.model == 'bloom':
+        if args.model == "bloom":
             # actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
-            print('Using peft lora to load Bloom model as Actor')
+            print("Using peft lora to load Bloom model as Actor")
             actor = BLOOMActor(pretrained=args.pretrain, lora_path=args.sft_lora_path)
-            print('Using peft lora to load Bloom model as Actor (Done)')
+            print("Using peft lora to load Bloom model as Actor (Done)")
         else:
             raise ValueError(f'Unsupported actor model "{args.model}"')

-        if rm_model_name == 'gpt2':
+        if rm_model_name == "gpt2":
             critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
-        elif rm_model_name == 'bloom':
+        elif rm_model_name == "bloom":
             print("load bloom critic ", args.rm_pretrain, " lora_rank ", args.lora_rank, " use_action_mask ", True)
             critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
             print("load bloom critic (Done) ")
-        elif rm_model_name == 'opt':
+        elif rm_model_name == "opt":
             critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
-        elif rm_model_name == 'llama':
+        elif rm_model_name == "llama":
             critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
         else:
             raise ValueError(f'Unsupported reward model "{rm_model_name}"')

     if args.rm_path is not None:
-        print('Loading reward model from', args.rm_path)
+        print("Loading reward model from", args.rm_path)
         critic.load_state_dict(state_dict)
         del state_dict

-    if args.strategy != 'colossalai_gemini':
+    if args.strategy != "colossalai_gemini":
         critic.to(torch.float16).to(torch.cuda.current_device())
         actor.to(torch.float16).to(torch.cuda.current_device())

     # configure optimizer
-    if args.strategy.startswith('colossalai'):
+    if args.strategy.startswith("colossalai"):
         actor_optim = HybridAdam(actor.parameters(), lr=1e-7)
         critic_optim = HybridAdam(critic.parameters(), lr=1e-7)
     else:
@@ -109,18 +107,18 @@ def main(args):
         critic_optim = Adam(critic.parameters(), lr=1e-7)

     # configure tokenizer
-    if args.model == 'gpt2':
+    if args.model == "gpt2":
         tokenizer = GPT2Tokenizer.from_pretrained(args.rm_pretrain)
         tokenizer.pad_token = tokenizer.eos_token
-    elif args.model == 'bloom':
+    elif args.model == "bloom":
         tokenizer = BloomTokenizerFast.from_pretrained(args.rm_pretrain)
         tokenizer.pad_token = tokenizer.eos_token
-    elif args.model == 'opt':
+    elif args.model == "opt":
         tokenizer = AutoTokenizer.from_pretrained(args.rm_pretrain)
         tokenizer.pad_token = tokenizer.eos_token
-    elif args.model == 'llama':
+    elif args.model == "llama":
         tokenizer = LlamaTokenizer.from_pretrained(args.pretrain)
-        tokenizer.eos_token = '<\s>'
+        tokenizer.eos_token = "<\s>"
         tokenizer.pad_token = tokenizer.unk_token
     else:
         raise ValueError(f'Unsupported model "{args.model}"')
@@ -132,26 +130,27 @@ def main(args):
         prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True)
     else:
         prompt_sampler = None
-    prompt_dataloader = DataLoader(prompt_dataset,
-                                   shuffle=(prompt_sampler is None),
-                                   sampler=prompt_sampler,
-                                   batch_size=args.train_batch_size)
+    prompt_dataloader = DataLoader(
+        prompt_dataset, shuffle=(prompt_sampler is None), sampler=prompt_sampler, batch_size=args.train_batch_size
+    )

     pretrain_dataset = EasySupervisedDataset(args.pretrain_dataset, tokenizer)
     if dist.is_initialized() and dist.get_world_size() > 1:
         pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True)
     else:
         pretrain_sampler = None
-    pretrain_dataloader = DataLoader(pretrain_dataset,
-                                     shuffle=(pretrain_sampler is None),
-                                     sampler=pretrain_sampler,
-                                     batch_size=args.ptx_batch_size,
-                                     collate_fn=data_collator)
+    pretrain_dataloader = DataLoader(
+        pretrain_dataset,
+        shuffle=(pretrain_sampler is None),
+        sampler=pretrain_sampler,
+        batch_size=args.ptx_batch_size,
+        collate_fn=data_collator,
+    )

     def tokenize_fn(texts):
         # MUST padding to max length to ensure inputs of all ranks have the same length
         # Different length may lead to hang when using gemini, as different generation steps
-        batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True)
+        batch = tokenizer(texts, return_tensors="pt", max_length=96, padding="max_length", truncation=True)
         return {k: v.to(torch.cuda.current_device()) for k, v in batch.items()}

     (actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim))
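Note: the two comments inside tokenize_fn are the important constraint: under Gemini every rank must feed generation identically shaped inputs, or ranks take different numbers of generation steps and collectives hang. padding="max_length" fixes the width; a sketch of the difference (checkpoint illustrative):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")  # illustrative
texts = ["short", "a much longer prompt that would pad differently"]
fixed = tokenizer(texts, return_tensors="pt", max_length=96, padding="max_length", truncation=True)
assert fixed["input_ids"].shape[1] == 96  # same width on every rank
dynamic = tokenizer(texts, return_tensors="pt", padding=True)
# dynamic width tracks the longest text in the local batch, so it can differ
# across ranks - exactly the hang risk the comment warns about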
@@ -178,45 +177,46 @@ def main(args):
         eos_token_id=tokenizer.eos_token_id,
     )

-    trainer.fit(prompt_dataloader=prompt_dataloader,
-                pretrain_dataloader=pretrain_dataloader,
-                num_episodes=args.num_episodes,
-                num_update_steps=args.num_update_steps,
-                num_collect_steps=args.num_collect_steps)
+    trainer.fit(
+        prompt_dataloader=prompt_dataloader,
+        pretrain_dataloader=pretrain_dataloader,
+        num_episodes=args.num_episodes,
+        num_update_steps=args.num_update_steps,
+        num_collect_steps=args.num_collect_steps,
+    )

     # save model checkpoint after fitting
     trainer.save_model(args.save_path, only_rank0=True, tokenizer=tokenizer)
     # save optimizer checkpoint on all ranks
     if args.need_optim_ckpt:
-        strategy.save_optimizer(actor_optim,
-                                'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
-                                only_rank0=False)
+        strategy.save_optimizer(
+            actor_optim, "actor_optim_checkpoint_prompts_%d.pt" % (torch.cuda.current_device()), only_rank0=False
+        )


-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--prompt_path', type=str, default=None, help='path to the prompt dataset')
-    parser.add_argument('--pretrain_dataset', type=str, default=None, help='path to the pretrained dataset')
-    parser.add_argument('--strategy',
-                        choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'],
-                        default='ddp',
-                        help='strategy to use')
-    parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
-    parser.add_argument('--pretrain', type=str, default=None)
-    parser.add_argument('--sft_lora_path', type=str, default=None)
-    parser.add_argument('--rm_model', default=None, choices=['gpt2', 'bloom', 'opt', 'llama'])
-    parser.add_argument('--rm_path', type=str, default=None)
-    parser.add_argument('--rm_pretrain', type=str, default=None)
-    parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts')
-    parser.add_argument('--need_optim_ckpt', type=bool, default=False)
-    parser.add_argument('--num_episodes', type=int, default=10)
-    parser.add_argument('--num_collect_steps', type=int, default=10)
-    parser.add_argument('--num_update_steps', type=int, default=5)
-    parser.add_argument('--train_batch_size', type=int, default=2)
-    parser.add_argument('--ptx_batch_size', type=int, default=1)
-    parser.add_argument('--experience_batch_size', type=int, default=8)
-    parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
-    parser.add_argument('--kl_coef', type=float, default=0.1)
-    parser.add_argument('--ptx_coef', type=float, default=0.9)
+    parser.add_argument("--prompt_path", type=str, default=None, help="path to the prompt dataset")
+    parser.add_argument("--pretrain_dataset", type=str, default=None, help="path to the pretrained dataset")
+    parser.add_argument(
+        "--strategy", choices=["ddp", "colossalai_gemini", "colossalai_zero2"], default="ddp", help="strategy to use"
+    )
+    parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
+    parser.add_argument("--pretrain", type=str, default=None)
+    parser.add_argument("--sft_lora_path", type=str, default=None)
+    parser.add_argument("--rm_model", default=None, choices=["gpt2", "bloom", "opt", "llama"])
+    parser.add_argument("--rm_path", type=str, default=None)
+    parser.add_argument("--rm_pretrain", type=str, default=None)
+    parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts")
+    parser.add_argument("--need_optim_ckpt", type=bool, default=False)
+    parser.add_argument("--num_episodes", type=int, default=10)
+    parser.add_argument("--num_collect_steps", type=int, default=10)
+    parser.add_argument("--num_update_steps", type=int, default=5)
+    parser.add_argument("--train_batch_size", type=int, default=2)
+    parser.add_argument("--ptx_batch_size", type=int, default=1)
+    parser.add_argument("--experience_batch_size", type=int, default=8)
+    parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
+    parser.add_argument("--kl_coef", type=float, default=0.1)
+    parser.add_argument("--ptx_coef", type=float, default=0.9)
     args = parser.parse_args()
     main(args)
@@ -1,18 +1,10 @@
 import argparse
 import os

-import loralib as lora
 import torch
 import torch.distributed as dist
-from coati.dataset import DataCollatorForSupervisedDataset, SFTDataset, SupervisedDataset
-from coati.models.base import RewardModel
-from coati.models.bloom import BLOOMLM
-from coati.models.gpt import GPTLM
-from coati.models.llama import LlamaLM
-from coati.models.opt import OPTLM
 from coati.trainer import SFTTrainer
 from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
-from datasets import load_dataset
 from easy_dataset import EasyDataset
 from peft import LoraConfig, PeftModel, TaskType, get_peft_model
 from torch.optim import Adam
@@ -29,75 +21,76 @@ from colossalai.tensor import ColoParameter

 def train(args):
     # configure strategy
-    if args.strategy == 'ddp':
+    if args.strategy == "ddp":
         strategy = DDPStrategy()
-    elif args.strategy == 'colossalai_gemini':
-        strategy = GeminiStrategy(placement_policy='cuda')
-    elif args.strategy == 'colossalai_zero2':
-        strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
+    elif args.strategy == "colossalai_gemini":
+        strategy = GeminiStrategy(placement_policy="cuda")
+    elif args.strategy == "colossalai_zero2":
+        strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
     else:
         raise ValueError(f'Unsupported strategy "{args.strategy}"')

     # configure model
     with strategy.model_init_context():
-        print('Warning: currently only bloom is tested, gpt2,llama and opt are not tested')
+        print("Warning: currently only bloom is tested, gpt2,llama and opt are not tested")
         model = AutoModelForCausalLM.from_pretrained(args.pretrain).to(torch.cuda.current_device())
         # if the args.save_path exists and args.save_path+'/adapter_config.json' exists, we'll load the adapter_config.json
-        if os.path.exists(args.save_path) and os.path.exists(args.save_path + '/adapter_config.json') \
-                and os.path.exists(args.save_path + '/adapter_model.bin'):
+        if (
+            os.path.exists(args.save_path)
+            and os.path.exists(args.save_path + "/adapter_config.json")
+            and os.path.exists(args.save_path + "/adapter_model.bin")
+        ):
             print("loading from saved peft model ", args.save_path)
             model = PeftModel.from_pretrained(model, args.save_path)
         else:
             # we'll use peft lora library to do the lora
             lora_rank = args.lora_rank if args.lora_rank > 0 else 32
             # config lora with rank of lora_rank
-            lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM,
-                                     inference_mode=False,
-                                     r=lora_rank,
-                                     lora_alpha=32,
-                                     lora_dropout=0.1)
+            lora_config = LoraConfig(
+                task_type=TaskType.CAUSAL_LM, inference_mode=False, r=lora_rank, lora_alpha=32, lora_dropout=0.1
+            )
             model = get_peft_model(model, lora_config)
         model.print_trainable_parameters()

     # configure tokenizer
-    if args.model == 'gpt2':
-        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+    if args.model == "gpt2":
+        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
         tokenizer.pad_token = tokenizer.eos_token
-    elif args.model == 'bloom':
+    elif args.model == "bloom":
         tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m")
         tokenizer.pad_token = tokenizer.eos_token
-    elif args.model == 'opt':
+    elif args.model == "opt":
         tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
         tokenizer.pad_token = tokenizer.eos_token
-    elif args.model == 'llama':
+    elif args.model == "llama":
         tokenizer = AutoTokenizer.from_pretrained(
             args.pretrain,
             padding_side="right",
             use_fast=False,
         )
-        tokenizer.eos_token = '<\s>'
+        tokenizer.eos_token = "<\s>"
         tokenizer.pad_token = tokenizer.unk_token
     else:
         raise ValueError(f'Unsupported model "{args.model}"')

-    if args.model == 'llama' and args.strategy == 'colossalai_gemini':
+    if args.model == "llama" and args.strategy == "colossalai_gemini":
         # this is a hack to deal with the resized embedding
         # to make sure all parameters are ColoParameter for Colossal-AI Gemini Compatibility
         for name, param in model.named_parameters():
             if not isinstance(param, ColoParameter):
-                sub_module_name = '.'.join(name.split('.')[:-1])
-                weight_name = name.split('.')[-1]
+                sub_module_name = ".".join(name.split(".")[:-1])
+                weight_name = name.split(".")[-1]
                 sub_module = model.get_submodule(sub_module_name)
                 setattr(sub_module, weight_name, ColoParameter(param))

     # configure optimizer
-    if args.strategy.startswith('colossalai'):
+    if args.strategy.startswith("colossalai"):
         optim = HybridAdam(model.parameters(), lr=args.lr, clipping_norm=1.0)
     else:
         optim = Adam(model.parameters(), lr=args.lr)

     logger = get_dist_logger()
-    logger.set_level('WARNING')
+    logger.set_level("WARNING")

     # configure dataset
     law_dataset = EasyDataset(args.dataset, tokenizer=tokenizer, is_group_texts=not args.is_short_text)
@@ -108,47 +101,57 @@ def train(args):
         eval_dataset = EasyDataset(args.eval_dataset, tokenizer=tokenizer, is_group_texts=not args.is_short_text)
     data_collator = default_collate
     if dist.is_initialized() and dist.get_world_size() > 1:
-        train_sampler = DistributedSampler(train_dataset,
-                                           shuffle=True,
-                                           seed=42,
-                                           drop_last=True,
-                                           rank=dist.get_rank(),
-                                           num_replicas=dist.get_world_size())
+        train_sampler = DistributedSampler(
+            train_dataset,
+            shuffle=True,
+            seed=42,
+            drop_last=True,
+            rank=dist.get_rank(),
+            num_replicas=dist.get_world_size(),
+        )
         if eval_dataset is not None:
-            eval_sampler = DistributedSampler(eval_dataset,
-                                              shuffle=False,
-                                              seed=42,
-                                              drop_last=False,
-                                              rank=dist.get_rank(),
-                                              num_replicas=dist.get_world_size())
+            eval_sampler = DistributedSampler(
+                eval_dataset,
+                shuffle=False,
+                seed=42,
+                drop_last=False,
+                rank=dist.get_rank(),
+                num_replicas=dist.get_world_size(),
+            )
     else:
         train_sampler = None
         eval_sampler = None

-    train_dataloader = DataLoader(train_dataset,
-                                  shuffle=(train_sampler is None),
-                                  sampler=train_sampler,
-                                  batch_size=args.batch_size,
-                                  collate_fn=data_collator,
-                                  pin_memory=True)
+    train_dataloader = DataLoader(
+        train_dataset,
+        shuffle=(train_sampler is None),
+        sampler=train_sampler,
+        batch_size=args.batch_size,
+        collate_fn=data_collator,
+        pin_memory=True,
+    )
     if eval_dataset is not None:
-        eval_dataloader = DataLoader(eval_dataset,
-                                     shuffle=(eval_sampler is None),
-                                     sampler=eval_sampler,
-                                     batch_size=args.batch_size,
-                                     collate_fn=data_collator,
-                                     pin_memory=True)
+        eval_dataloader = DataLoader(
+            eval_dataset,
+            shuffle=(eval_sampler is None),
+            sampler=eval_sampler,
+            batch_size=args.batch_size,
+            collate_fn=data_collator,
+            pin_memory=True,
+        )
     else:
         eval_dataloader = None

-    trainer = SFTTrainer(model=model,
-                         strategy=strategy,
-                         optim=optim,
-                         train_dataloader=train_dataloader,
-                         eval_dataloader=eval_dataloader,
-                         batch_size=args.batch_size,
-                         max_epochs=args.max_epochs,
-                         accumulation_steps=args.accumulation_steps)
+    trainer = SFTTrainer(
+        model=model,
+        strategy=strategy,
+        optim=optim,
+        train_dataloader=train_dataloader,
+        eval_dataloader=eval_dataloader,
+        batch_size=args.batch_size,
+        max_epochs=args.max_epochs,
+        accumulation_steps=args.accumulation_steps,
+    )

     trainer.fit(logger=logger, log_interval=args.log_interval)
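Note: when a DistributedSampler is used, DataLoader shuffling is disabled because the sampler owns both sharding and ordering. For shuffling to differ across epochs, the sampler's epoch normally has to be advanced; a hypothetical loop fragment reusing the names above (the actual stepping lives inside SFTTrainer):

for epoch in range(args.max_epochs):
    if train_sampler is not None:
        train_sampler.set_epoch(epoch)  # re-seeds the per-epoch permutation
    for batch in train_dataloader:
        ...  # forward/backward handled by the trainer in the actual script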
@@ -156,29 +159,27 @@ def train(args):
     trainer.save_model(path=args.save_path, only_rank0=True, tokenizer=tokenizer)
     # save optimizer checkpoint on all ranks
     if args.need_optim_ckpt:
-        strategy.save_optimizer(trainer.optimizer,
-                                'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()),
-                                only_rank0=False)
+        strategy.save_optimizer(
+            trainer.optimizer, "rm_optim_checkpoint_%d.pt" % (torch.cuda.current_device()), only_rank0=False
+        )


-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--strategy',
-                        choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'],
-                        default='ddp')
-    parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom')
-    parser.add_argument('--pretrain', type=str, default=None)
-    parser.add_argument('--dataset', type=str, default=None)
-    parser.add_argument('--eval_dataset', type=str, default=None)
-    parser.add_argument('--save_path', type=str, default='output')
-    parser.add_argument('--need_optim_ckpt', type=bool, default=False)
-    parser.add_argument('--max_epochs', type=int, default=3)
-    parser.add_argument('--batch_size', type=int, default=4)
-    parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
-    parser.add_argument('--log_interval', type=int, default=100, help="how many steps to log")
-    parser.add_argument('--lr', type=float, default=5e-6)
-    parser.add_argument('--accumulation_steps', type=int, default=8)
-    parser.add_argument('--enable_peft_lora', action='store_true', default=False)
-    parser.add_argument("--is_short_text", action='store_true', default=False)
+    parser.add_argument("--strategy", choices=["ddp", "colossalai_gemini", "colossalai_zero2"], default="ddp")
+    parser.add_argument("--model", choices=["gpt2", "bloom", "opt", "llama"], default="bloom")
+    parser.add_argument("--pretrain", type=str, default=None)
+    parser.add_argument("--dataset", type=str, default=None)
+    parser.add_argument("--eval_dataset", type=str, default=None)
+    parser.add_argument("--save_path", type=str, default="output")
+    parser.add_argument("--need_optim_ckpt", type=bool, default=False)
+    parser.add_argument("--max_epochs", type=int, default=3)
+    parser.add_argument("--batch_size", type=int, default=4)
+    parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
+    parser.add_argument("--log_interval", type=int, default=100, help="how many steps to log")
+    parser.add_argument("--lr", type=float, default=5e-6)
+    parser.add_argument("--accumulation_steps", type=int, default=8)
+    parser.add_argument("--enable_peft_lora", action="store_true", default=False)
+    parser.add_argument("--is_short_text", action="store_true", default=False)
     args = parser.parse_args()
     train(args)
@@ -6,16 +6,25 @@ from ray.job_submission import JobSubmissionClient
 def main(api_server_endpoint="http://127.0.0.1:8265"):
     client = JobSubmissionClient(api_server_endpoint)
     client.submit_job(
-        entrypoint=
-        "python experimental/ray/train_prompts_on_ray.py --strategy colossalai_zero2 --prompt_csv_url https://huggingface.co/datasets/fka/awesome-chatgpt-prompts/resolve/main/prompts.csv",
+        entrypoint="python experimental/ray/train_prompts_on_ray.py --strategy colossalai_zero2 --prompt_csv_url https://huggingface.co/datasets/fka/awesome-chatgpt-prompts/resolve/main/prompts.csv",
         runtime_env={
-            "working_dir":
-            "applications/Chat",
+            "working_dir": "applications/Chat",
             "pip": [
-                "torch==1.13.1", "transformers>=4.20.1", "datasets", "loralib", "colossalai>=0.2.4", "langchain",
-                "tokenizers", "fastapi", "sse_starlette", "wandb", "sentencepiece", "gpustat"
-            ]
-        })
+                "torch==1.13.1",
+                "transformers>=4.20.1",
+                "datasets",
+                "loralib",
+                "colossalai>=0.2.4",
+                "langchain",
+                "tokenizers",
+                "fastapi",
+                "sse_starlette",
+                "wandb",
+                "sentencepiece",
+                "gpustat",
+            ],
+        },
+    )


 if __name__ == "__main__":
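Note: submit_job returns a job id string, which the same client API can poll; a minimal follow-up sketch (the entrypoint here is illustrative):

from ray.job_submission import JobSubmissionClient

client = JobSubmissionClient("http://127.0.0.1:8265")
job_id = client.submit_job(entrypoint="python -c 'print(42)'")
print(client.get_job_status(job_id))  # PENDING / RUNNING / SUCCEEDED / FAILED
print(client.get_job_logs(job_id))    # stdout/stderr captured by Ray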
@@ -26,9 +26,14 @@ from colossalai.nn.optimizer import HybridAdam


 class ExperienceCompositionRefs:
-
-    def __init__(self, sequences_attention_mask_action_mask_ref: ray.ObjectRef, action_log_probs_ref: ray.ObjectRef,
-                 base_action_log_probs_ref: ray.ObjectRef, value_ref: ray.ObjectRef, r_ref: ray.ObjectRef) -> None:
+    def __init__(
+        self,
+        sequences_attention_mask_action_mask_ref: ray.ObjectRef,
+        action_log_probs_ref: ray.ObjectRef,
+        base_action_log_probs_ref: ray.ObjectRef,
+        value_ref: ray.ObjectRef,
+        r_ref: ray.ObjectRef,
+    ) -> None:
         self.sequences_attention_mask_action_mask_ref = sequences_attention_mask_action_mask_ref
         self.action_log_probs_ref = action_log_probs_ref
         self.base_action_log_probs_ref = base_action_log_probs_ref
@@ -37,14 +42,14 @@ class ExperienceCompositionRefs:


 class ExperienceMaker:
-
     def __init__(self, kl_coef) -> None:
         self.kl_coef = kl_coef

     @torch.no_grad()
     def make_experience(self, experiment_computation_refs: ExperienceCompositionRefs):
         sequences, attention_mask, action_mask = ray.get(
-            experiment_computation_refs.sequences_attention_mask_action_mask_ref)
+            experiment_computation_refs.sequences_attention_mask_action_mask_ref
+        )
         action_log_probs = ray.get(experiment_computation_refs.action_log_probs_ref)
         base_action_log_probs = ray.get(experiment_computation_refs.base_action_log_probs_ref)
         r = ray.get(experiment_computation_refs.r_ref)
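Note: the ray.get calls here only fetch results; the forward passes already ran concurrently on the remote actors that produced these object refs. The pattern in isolation:

import ray

ray.init()

@ray.remote
def square(x):
    return x * x

refs = [square.remote(i) for i in range(4)]  # all four tasks run concurrently
print(ray.get(refs))  # [0, 1, 4, 9] - get() blocks until every ref resolves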
@@ -58,11 +63,10 @@ class ExperienceMaker:


 class DistributedTorchRayActor:
-
     def __init__(self, world_size, rank, local_rank, master_addr, master_port):
-        logging.basicConfig(format='%(asctime)s %(levelname)-8s %(message)s',
-                            level=logging.INFO,
-                            datefmt='%Y-%m-%d %H:%M:%S')
+        logging.basicConfig(
+            format="%(asctime)s %(levelname)-8s %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S"
+        )
         self._model = None
         self._world_size = world_size
         self._rank = rank
@@ -82,7 +86,7 @@ class DistributedTorchRayActor:
     @staticmethod
     def _get_free_port():
         with socket.socket() as sock:
-            sock.bind(('', 0))
+            sock.bind(("", 0))
             return sock.getsockname()[1]

     def get_master_addr_port(self):
@@ -90,7 +94,6 @@ class DistributedTorchRayActor:


 class BasePPORole(DistributedTorchRayActor):
-
     def add_experience_maker(self, kl_coef: float = 0.1):
         self._experience_maker = ExperienceMaker(kl_coef)
@@ -99,12 +102,12 @@ class BasePPORole(DistributedTorchRayActor):

     def _init_strategy(self, strategy: str):
         # configure strategy
-        if strategy == 'ddp':
+        if strategy == "ddp":
             self._strategy = DDPStrategy()
-        elif strategy == 'colossalai_gemini':
-            self._strategy = GeminiStrategy(placement_policy='cuda', initial_scale=2**5)
-        elif strategy == 'colossalai_zero2':
-            self._strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
+        elif strategy == "colossalai_gemini":
+            self._strategy = GeminiStrategy(placement_policy="cuda", initial_scale=2**5)
+        elif strategy == "colossalai_zero2":
+            self._strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
         else:
             raise ValueError(f'Unsupported strategy "{strategy}"')
@@ -124,11 +127,9 @@ class BasePPORole(DistributedTorchRayActor):
     def _load_model_from_pretrained(self, model_class: Type[LoRAModule], pretrain: str):
         raise NotImplementedError()

-    def init_model_from_pretrained(self,
-                                   strategy: str,
-                                   model_class: Type[LoRAModule],
-                                   pretrain: str,
-                                   has_optimizer=False):
+    def init_model_from_pretrained(
+        self, strategy: str, model_class: Type[LoRAModule], pretrain: str, has_optimizer=False
+    ):
         self._init_strategy(strategy)
         self._load_model_from_pretrained(model_class, pretrain)
         self._prepare_model_with_strategy(has_optimizer)
@@ -138,7 +139,6 @@ class BasePPORole(DistributedTorchRayActor):


 class TrainablePPORole(BasePPORole):
-
     def _load_model_from_pretrained(self, model_class, pretrain):
         with self._strategy.model_init_context():
             self._model = model_class(pretrain).to(torch.cuda.current_device())
@@ -161,38 +161,39 @@ class TrainablePPORole(BasePPORole):

 @ray.remote(num_gpus=1)
 class RayPPOActor(TrainablePPORole):
-
     def set_loss_function(self, eps_clip: float):
         self._actor_loss_fn = PolicyLoss(eps_clip)

     def load_tokenizer_from_pretrained(self, model_type: str, pretrained):
-        if model_type == 'gpt2':
+        if model_type == "gpt2":
             self._model_tokenizer = GPT2Tokenizer.from_pretrained(pretrained)
             self._model_tokenizer.pad_token = self._model_tokenizer.eos_token
-        elif model_type == 'bloom':
+        elif model_type == "bloom":
             self._model_tokenizer = BloomTokenizerFast.from_pretrained(pretrained)
             self._model_tokenizer.pad_token = self._model_tokenizer.eos_token
-        elif model_type == 'opt':
+        elif model_type == "opt":
             self._model_tokenizer = AutoTokenizer.from_pretrained(pretrained)
         else:
             raise ValueError(f'Unsupported model "{model_type}"')

         # Set tokenize function for sequence generation
         def _text_input_tokenize_fn(texts):
-            batch = self._model_tokenizer(texts, return_tensors='pt', max_length=96, padding=True, truncation=True)
+            batch = self._model_tokenizer(texts, return_tensors="pt", max_length=96, padding=True, truncation=True)
             return {k: v.cuda() for k, v in batch.items()}

         self._sample_tokenize_function = _text_input_tokenize_fn

     def setup_generate_kwargs(self, generate_kwargs: dict):
         from coati.trainer.ppo import _set_default_generate_kwargs

         self._generate_kwargs = _set_default_generate_kwargs(self._strategy, generate_kwargs, self._model)
-        self._generate_kwargs['pad_token_id'] = self._model_tokenizer.pad_token_id
-        self._generate_kwargs['eos_token_id'] = self._model_tokenizer.eos_token_id
+        self._generate_kwargs["pad_token_id"] = self._model_tokenizer.pad_token_id
+        self._generate_kwargs["eos_token_id"] = self._model_tokenizer.eos_token_id

     def load_csv_prompt_file_from_url_to_sampler(self, prompt_url):
         import pandas as pd
-        prompts = pd.read_csv(prompt_url)['prompt']
+
+        prompts = pd.read_csv(prompt_url)["prompt"]
         self._sampler = self._strategy.setup_sampler(prompts)

     def _generate(self, input_ids, **generate_kwargs):
@@ -214,10 +215,9 @@ class RayPPOActor(TrainablePPORole):
     def _training_step(self, experience):
         num_actions = experience.action_mask.size(1)
         action_log_probs = self._model(experience.sequences, num_actions, attention_mask=experience.attention_mask)
-        actor_loss = self._actor_loss_fn(action_log_probs,
-                                         experience.action_log_probs,
-                                         experience.advantages,
-                                         action_mask=experience.action_mask)
+        actor_loss = self._actor_loss_fn(
+            action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask
+        )
         self._strategy.backward(actor_loss, self._model, self._optimizer)
         self._strategy.optimizer_step(self._optimizer)
         self._optimizer.zero_grad()
@@ -229,17 +229,18 @@ class RayPPOActor(TrainablePPORole):
         self._strategy.save_model(self._model, save_path, only_rank0=True)
         # save optimizer checkpoint on all ranks
         if should_save_optimizer:
-            self._strategy.save_optimizer(self._optimizer,
-                                          'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
-                                          only_rank0=False)
+            self._strategy.save_optimizer(
+                self._optimizer,
+                "actor_optim_checkpoint_prompts_%d.pt" % (torch.cuda.current_device()),
+                only_rank0=False,
+            )

     def generate_answer(self, prompt, max_length=30, num_return_sequences=5):
-        encoded_input = self._model_tokenizer(prompt, return_tensors='pt')
+        encoded_input = self._model_tokenizer(prompt, return_tensors="pt")
         input_ids = {k: v.cuda() for k, v in encoded_input.items()}
-        sequence, _ = self._model.generate(**input_ids,
-                                           max_length=max_length,
-                                           return_action_mask=False,
-                                           num_return_sequences=num_return_sequences)
+        sequence, _ = self._model.generate(
+            **input_ids, max_length=max_length, return_action_mask=False, num_return_sequences=num_return_sequences
+        )
         token_list = list(sequence.data[0])
         output = " ".join([self._model_tokenizer.decode(token) for token in token_list])
         return output
@@ -247,18 +248,16 @@ class RayPPOActor(TrainablePPORole):

 @ray.remote(num_gpus=1)
 class RayPPOCritic(TrainablePPORole):
-
     def set_loss_function(self, value_clip: float):
         self._critic_loss_fn = ValueLoss(value_clip)

     def _training_step(self, experience):
-        values = self._model(experience.sequences,
-                             action_mask=experience.action_mask,
-                             attention_mask=experience.attention_mask)
-        critic_loss = self._critic_loss_fn(values,
-                                           experience.values,
-                                           experience.reward,
-                                           action_mask=experience.action_mask)
+        values = self._model(
+            experience.sequences, action_mask=experience.action_mask, attention_mask=experience.attention_mask
+        )
+        critic_loss = self._critic_loss_fn(
+            values, experience.values, experience.reward, action_mask=experience.action_mask
+        )
         self._strategy.backward(critic_loss, self._model, self._optimizer)
         self._strategy.optimizer_step(self._optimizer)
         self._optimizer.zero_grad()
@@ -272,12 +271,12 @@ class RayPPOCritic(TrainablePPORole):

 @ray.remote(num_gpus=1)
 class RayPPORewardModel(BasePPORole):
-
     def _load_model_from_pretrained(self, model_class, pretrain):
         with self._strategy.model_init_context():
             critic = model_class(pretrained=pretrain).to(torch.cuda.current_device())
-            self._model = RewardModel(deepcopy(critic.model),
-                                      deepcopy(critic.value_head)).to(torch.cuda.current_device())
+            self._model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).to(
+                torch.cuda.current_device()
+            )

     @torch.no_grad()
     def calculate_r(self, sequence_attention_action_mask):
@@ -287,7 +286,6 @@ class RayPPORewardModel(BasePPORole):

 @ray.remote(num_gpus=1)
 class RayPPOInitialModel(BasePPORole):
-
     def _load_model_from_pretrained(self, model_class, pretrain):
         with self._strategy.model_init_context():
             self._model = model_class(pretrain).to(torch.cuda.current_device())
@@ -300,8 +298,8 @@ class RayPPOInitialModel(BasePPORole):

 class PPORayActorGroup:
     """
-        A group of ray actors
-        Functions start with 'async' should return list of object refs
+    A group of ray actors
+    Functions start with 'async' should return list of object refs
     """

     def __init__(self, num_nodes, num_gpus_per_node, ray_actor_type: Type[BasePPORole]) -> None:
@@ -319,8 +317,9 @@ class PPORayActorGroup:
         pg = placement_group(bundles, strategy="STRICT_SPREAD")
         ray.get(pg.ready())
         if pg:
-            master_actor = self.ray_actor_type.options(scheduling_strategy=PlacementGroupSchedulingStrategy(
-                placement_group=pg, placement_group_bundle_index=0)).remote(world_size, 0, 0, None, None)
+            master_actor = self.ray_actor_type.options(
+                scheduling_strategy=PlacementGroupSchedulingStrategy(placement_group=pg, placement_group_bundle_index=0)
+            ).remote(world_size, 0, 0, None, None)
         else:
             master_actor = self.ray_actor_type.options(num_gpus=1).remote(world_size, 0, 0, None, None)
         self._actor_handlers = [master_actor]
@@ -331,16 +330,20 @@ class PPORayActorGroup:
         for rank in range(1, world_size):
             local_rank = rank % self._num_gpus_per_node
             if pg:
-                worker_actor = self.ray_actor_type.options(scheduling_strategy=PlacementGroupSchedulingStrategy(
-                    placement_group=pg, placement_group_bundle_index=rank // self._num_gpus_per_node)).remote(
-                        world_size, rank, local_rank, master_addr, master_port)
+                worker_actor = self.ray_actor_type.options(
+                    scheduling_strategy=PlacementGroupSchedulingStrategy(
+                        placement_group=pg, placement_group_bundle_index=rank // self._num_gpus_per_node
+                    )
+                ).remote(world_size, rank, local_rank, master_addr, master_port)
             else:
-                worker_actor = self.ray_actor_type.options(num_gpus=1).remote(world_size, rank, local_rank,
-                                                                              master_addr, master_port)
+                worker_actor = self.ray_actor_type.options(num_gpus=1).remote(
+                    world_size, rank, local_rank, master_addr, master_port
+                )
             self._actor_handlers.append(worker_actor)

-    def async_init_model_from_pretrained(self, strategy: str, model_class: Type[LoRAModule], pretrain: str,
-                                         has_optimizer: bool):
+    def async_init_model_from_pretrained(
+        self, strategy: str, model_class: Type[LoRAModule], pretrain: str, has_optimizer: bool
+    ):
         return [
             actor.init_model_from_pretrained.remote(strategy, model_class, pretrain, has_optimizer)
             for actor in self._actor_handlers
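Note: STRICT_SPREAD places one bundle per node, and binding each actor to a bundle index pins master and workers to distinct nodes. A minimal standalone version of the scheduling pattern (the resource shape and Worker class are illustrative):

import ray
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

ray.init()
bundles = [{"GPU": 1, "CPU": 1} for _ in range(2)]
pg = placement_group(bundles, strategy="STRICT_SPREAD")  # one bundle per node
ray.get(pg.ready())

@ray.remote(num_gpus=1)
class Worker:
    def ping(self):
        return "ok"

w = Worker.options(
    scheduling_strategy=PlacementGroupSchedulingStrategy(placement_group=pg, placement_group_bundle_index=0)
).remote()
print(ray.get(w.ping.remote()))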
@@ -348,7 +351,6 @@ class PPORayActorGroup:


 class TrainableModelRayActorGroup(PPORayActorGroup):
-
     def async_learn_on_experiences(self, experience_refs):
         num_actors = len(self._actor_handlers)
         learn_result_refs = []
@@ -359,7 +361,6 @@ class TrainableModelRayActorGroup(PPORayActorGroup):


 class PPOActorRayActorGroup(TrainableModelRayActorGroup):
-
     def __init__(self, num_nodes, num_gpus_per_node) -> None:
         super().__init__(num_nodes, num_gpus_per_node, RayPPOActor)
@@ -381,7 +382,8 @@ class PPOActorRayActorGroup(TrainableModelRayActorGroup):
         action_log_probs_refs = []
         for i in range(len(sequences_attention_mask_action_mask_refs)):
             action_log_probs_ref = self._actor_handlers[i % num_actors].calculate_action_log_probs.remote(
-                sequences_attention_mask_action_mask_refs[i])
+                sequences_attention_mask_action_mask_refs[i]
+            )
             action_log_probs_refs.append(action_log_probs_ref)
         return action_log_probs_refs
@@ -393,7 +395,6 @@ class PPOActorRayActorGroup(TrainableModelRayActorGroup):


 class PPOCriticRayActorGroup(TrainableModelRayActorGroup):
-
     def __init__(self, num_nodes, num_gpus_per_node) -> None:
         super().__init__(num_nodes, num_gpus_per_node, RayPPOCritic)
@@ -402,7 +403,8 @@ class PPOCriticRayActorGroup(TrainableModelRayActorGroup):
         value_refs = []
         for i in range(len(sequences_attention_mask_action_mask_refs)):
             value_ref = self._actor_handlers[i % num_actors].calculate_value.remote(
-                sequences_attention_mask_action_mask_refs[i])
+                sequences_attention_mask_action_mask_refs[i]
+            )
             value_refs.append(value_ref)
         return value_refs
@@ -411,7 +413,6 @@ class PPOCriticRayActorGroup(TrainableModelRayActorGroup):


 class PPOInitialRayActorGroup(PPORayActorGroup):
-
     def __init__(self, num_nodes, num_gpus_per_node) -> None:
         super().__init__(num_nodes, num_gpus_per_node, RayPPOInitialModel)
@@ -420,13 +421,13 @@ class PPOInitialRayActorGroup(PPORayActorGroup):
         base_action_log_probs_refs = []
         for i in range(len(sequences_attention_mask_action_mask_refs)):
             base_action_log_probs_ref = self._actor_handlers[i % num_actors].calculate_base_action_log_probs.remote(
-                sequences_attention_mask_action_mask_refs[i])
+                sequences_attention_mask_action_mask_refs[i]
+            )
             base_action_log_probs_refs.append(base_action_log_probs_ref)
         return base_action_log_probs_refs


 class PPORewardRayActorGroup(PPORayActorGroup):
-
     def __init__(self, num_nodes, num_gpus_per_node) -> None:
         super().__init__(num_nodes, num_gpus_per_node, RayPPORewardModel)
@@ -435,20 +436,21 @@ class PPORewardRayActorGroup(PPORayActorGroup):
|
||||
r_refs = []
|
||||
for i in range(len(sequences_attention_mask_action_mask_refs)):
|
||||
r_ref = self._actor_handlers[i % num_actors].calculate_r.remote(
|
||||
sequences_attention_mask_action_mask_refs[i])
|
||||
sequences_attention_mask_action_mask_refs[i]
|
||||
)
|
||||
r_refs.append(r_ref)
|
||||
return r_refs
|
||||
|
||||
|
||||
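Each calculate_* helper above fans its batch out round-robin over the group's Ray actor handles (i % num_actors) and returns object refs instead of blocking per call. A minimal, self-contained sketch of the same dispatch pattern with a hypothetical Worker actor:

import ray

ray.init()

@ray.remote
class Worker:  # stand-in for the Ray PPO actors above
    def compute(self, x):
        return x * 2

# Hypothetical pool of actor handles, mirroring self._actor_handlers.
actor_handlers = [Worker.remote() for _ in range(2)]
num_actors = len(actor_handlers)

# Dispatch 8 tasks round-robin: item i goes to actor i % num_actors,
# collecting object refs immediately instead of waiting on each call.
refs = [actor_handlers[i % num_actors].compute.remote(i) for i in range(8)]
print(ray.get(refs))  # [0, 2, 4, 6, 8, 10, 12, 14]
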
def main(args):
logging.basicConfig(format='%(asctime)s %(levelname)-8s %(message)s',
level=logging.INFO,
datefmt='%Y-%m-%d %H:%M:%S')
if args.model == 'gpt2':
logging.basicConfig(
format="%(asctime)s %(levelname)-8s %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S"
)
if args.model == "gpt2":
actor_model_class, critic_model_class = GPTActor, GPTCritic
elif args.model == 'bloom':
elif args.model == "bloom":
actor_model_class, critic_model_class = BLOOMActor, BLOOMCritic
elif args.model == 'opt':
elif args.model == "opt":
actor_model_class, critic_model_class = OPTActor, OPTCritic
else:
raise ValueError(f'Unsupported model "{args.model}"')
@@ -462,13 +464,14 @@ def main(args):
logging.info("Actors created")

# Prepare model for training
generate_kwargs = {'max_length': 128, 'do_sample': True, 'temperature': 1.0, 'top_k': 50}
generate_kwargs = {"max_length": 128, "do_sample": True, "temperature": 1.0, "top_k": 50}
ray.get(
actor_group.async_init_model_from_pretrained(args.strategy, actor_model_class, args.pretrain, True) +
critic_group.async_init_model_from_pretrained(args.strategy, critic_model_class, args.pretrain, True) +
initial_group.async_init_model_from_pretrained(args.strategy, actor_model_class, args.pretrain, False) +
reward_group.async_init_model_from_pretrained(args.strategy, critic_model_class, args.pretrain, False) +
actor_group.async_prepare_for_sequence_generation(args.model, args.pretrain, generate_kwargs))
actor_group.async_init_model_from_pretrained(args.strategy, actor_model_class, args.pretrain, True)
+ critic_group.async_init_model_from_pretrained(args.strategy, critic_model_class, args.pretrain, True)
+ initial_group.async_init_model_from_pretrained(args.strategy, actor_model_class, args.pretrain, False)
+ reward_group.async_init_model_from_pretrained(args.strategy, critic_model_class, args.pretrain, False)
+ actor_group.async_prepare_for_sequence_generation(args.model, args.pretrain, generate_kwargs)
)
logging.info("Models prepared for training")

# Prepare models for training
@@ -483,8 +486,12 @@ def main(args):
# Start training
logging.info("Training start")
# Set all models to eval and add experience maker
all_ray_actors = actor_group._actor_handlers + critic_group._actor_handlers + \
initial_group._actor_handlers + reward_group._actor_handlers
all_ray_actors = (
actor_group._actor_handlers
+ critic_group._actor_handlers
+ initial_group._actor_handlers
+ reward_group._actor_handlers
)
num_ray_actors = len(all_ray_actors)
ray.get([ray_actor.eval.remote() for ray_actor in all_ray_actors])
ray.get([ray_actor.add_experience_maker.remote() for ray_actor in all_ray_actors])
@@ -497,18 +504,28 @@ def main(args):
time += 1
# Experience queueing stage
sequences_attention_mask_action_mask_refs = actor_group.async_sample_prompts_and_make_sequence(
experience_batch_size)
experience_batch_size
)
base_action_log_probs_refs = initial_group.async_calculate_base_action_log_probs(
sequences_attention_mask_action_mask_refs)
sequences_attention_mask_action_mask_refs
)
values_refs = critic_group.async_calculate_value(sequences_attention_mask_action_mask_refs)
r_refs = reward_group.async_calculate_r(sequences_attention_mask_action_mask_refs)
action_log_probs_refs = actor_group.async_calculate_action_log_probs(
sequences_attention_mask_action_mask_refs)
experience_composition_refs.extend([
ExperienceCompositionRefs(sequences_attention_mask_action_mask_refs[i], action_log_probs_refs[i],
base_action_log_probs_refs[i], values_refs[i], r_refs[i])
for i in range(len(sequences_attention_mask_action_mask_refs))
])
sequences_attention_mask_action_mask_refs
)
experience_composition_refs.extend(
[
ExperienceCompositionRefs(
sequences_attention_mask_action_mask_refs[i],
action_log_probs_refs[i],
base_action_log_probs_refs[i],
values_refs[i],
r_refs[i],
)
for i in range(len(sequences_attention_mask_action_mask_refs))
]
)
# Learning stage
if time % update_timesteps == 0:
experience_refs = []
@@ -519,8 +536,9 @@ def main(args):
experience_refs.append(selected_ray_actor.make_experience.remote(exp_composition_ref))
# backward
ray.get(
actor_group.async_learn_on_experiences(experience_refs) +
critic_group.async_learn_on_experiences(experience_refs))
actor_group.async_learn_on_experiences(experience_refs)
+ critic_group.async_learn_on_experiences(experience_refs)
)
# clear refs queue
experience_composition_refs.clear()
logging.info("Training finished")
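The loop above queues experience-composition refs on every collection step and only triggers learning every update_timesteps steps, draining the queue afterwards. A toy sketch of that cadence with hypothetical stand-in values:

# Experiences accumulate each step; learning fires every update_timesteps
# steps and then clears the queue, mirroring the loop above.
update_timesteps = 10
experience_queue = []

for time in range(1, 31):  # 30 collection steps
    experience_queue.append(f"batch-{time}")  # stand-in for composed refs
    if time % update_timesteps == 0:
        print(f"step {time}: learn on {len(experience_queue)} queued batches")
        experience_queue.clear()
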
@@ -528,26 +546,24 @@ def main(args):
actor_group.save_checkpoint(args.save_path, args.need_optim_ckpt)


if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--prompt_csv_url', type=str)
parser.add_argument('--strategy',
choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'],
default='ddp')
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt'])
parser.add_argument('--pretrain', type=str, default='gpt2')
parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts.pt')
parser.add_argument('--need_optim_ckpt', type=bool, default=False)
parser.add_argument('--num_episodes', type=int, default=10)
parser.add_argument('--max_timesteps', type=int, default=10)
parser.add_argument('--update_timesteps', type=int, default=10)
parser.add_argument('--train_batch_size', type=int, default=8)
parser.add_argument('--experience_batch_size', type=int, default=8)
parser.add_argument('--num_actor_nodes', type=int, help='num of nodes to use to host actor model', default=1)
parser.add_argument('--num_critic_nodes', type=int, help='num of nodes to use to host critic model', default=1)
parser.add_argument('--num_initial_nodes', type=int, help='num of nodes to use to host initial model', default=1)
parser.add_argument('--num_reward_nodes', type=int, help='num of nodes to use to host reward model', default=1)
parser.add_argument('--num_gpus_per_node', type=int, help='num of gpus on a ray node', default=1)
parser.add_argument("--prompt_csv_url", type=str)
parser.add_argument("--strategy", choices=["ddp", "colossalai_gemini", "colossalai_zero2"], default="ddp")
parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt"])
parser.add_argument("--pretrain", type=str, default="gpt2")
parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts.pt")
parser.add_argument("--need_optim_ckpt", type=bool, default=False)
parser.add_argument("--num_episodes", type=int, default=10)
parser.add_argument("--max_timesteps", type=int, default=10)
parser.add_argument("--update_timesteps", type=int, default=10)
parser.add_argument("--train_batch_size", type=int, default=8)
parser.add_argument("--experience_batch_size", type=int, default=8)
parser.add_argument("--num_actor_nodes", type=int, help="num of nodes to use to host actor model", default=1)
parser.add_argument("--num_critic_nodes", type=int, help="num of nodes to use to host critic model", default=1)
parser.add_argument("--num_initial_nodes", type=int, help="num of nodes to use to host initial model", default=1)
parser.add_argument("--num_reward_nodes", type=int, help="num of nodes to use to host reward model", default=1)
parser.add_argument("--num_gpus_per_node", type=int, help="num of gpus on a ray node", default=1)
args = parser.parse_args()
ray.init()
main(args)

@@ -22,7 +22,7 @@ class HFRepoFiles:
file_path = hf_hub_download(self.repo_id, file, local_dir=dir_path)

def download_all(self):
file_path = snapshot_download(self.repo_id)
snapshot_download(self.repo_id)


def test_init(model: str, dir_path: str):
@@ -31,19 +31,19 @@ def test_init(model: str, dir_path: str):
actor = GPTActor(config=config)
critic = GPTCritic(config=config)
reward_model = GPTRM(config=config)
tokenizer = GPT2Tokenizer.from_pretrained(dir_path)
GPT2Tokenizer.from_pretrained(dir_path)
elif model == "bloom":
config = BloomConfig.from_pretrained(dir_path)
actor = BLOOMActor(config=config)
critic = BLOOMCritic(config=config)
reward_model = BLOOMRM(config=config)
tokenizer = BloomTokenizerFast.from_pretrained(dir_path)
BloomTokenizerFast.from_pretrained(dir_path)
elif model == "opt":
config = AutoConfig.from_pretrained(dir_path)
actor = OPTActor(config=config)
critic = OPTCritic(config=config)
reward_model = OPTRM(config=config)
tokenizer = AutoTokenizer.from_pretrained(dir_path)
AutoTokenizer.from_pretrained(dir_path)
else:
raise NotImplementedError(f"Model {model} not implemented")

@@ -59,17 +59,12 @@ if __name__ == "__main__":
exit(0)

repo_list = {
"gpt2": HFRepoFiles(
repo_id="gpt2",
files=["config.json", "tokenizer.json", "vocab.json", "merges.txt"]
),
"gpt2": HFRepoFiles(repo_id="gpt2", files=["config.json", "tokenizer.json", "vocab.json", "merges.txt"]),
"bloom": HFRepoFiles(
repo_id="bigscience/bloom-560m",
files=["config.json", "tokenizer.json", "tokenizer_config.json"]
repo_id="bigscience/bloom-560m", files=["config.json", "tokenizer.json", "tokenizer_config.json"]
),
"opt": HFRepoFiles(
repo_id="facebook/opt-350m",
files=["config.json", "tokenizer_config.json", "vocab.json", "merges.txt"]
repo_id="facebook/opt-350m", files=["config.json", "tokenizer_config.json", "vocab.json", "merges.txt"]
),
}

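The downloader above relies on two huggingface_hub entry points: hf_hub_download fetches a single file, while snapshot_download mirrors a whole repo and returns the local snapshot path. A small usage sketch (the /tmp/gpt2 target directory is arbitrary):

from huggingface_hub import hf_hub_download, snapshot_download

# Fetch one file into a chosen directory, or mirror the whole repo into
# the hub cache; both return the resulting local path.
config_path = hf_hub_download("gpt2", "config.json", local_dir="/tmp/gpt2")
snapshot_path = snapshot_download("gpt2")
print(config_path, snapshot_path)
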
@@ -31,9 +31,11 @@ def generate_alpaca():
def generate_sharegpt():
# ShareGPT data requires less processing.
conversation_dataset = []
dataset = load_dataset("anon8231489123/ShareGPT_Vicuna_unfiltered",
data_files="ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json",
split="train")
dataset = load_dataset(
"anon8231489123/ShareGPT_Vicuna_unfiltered",
data_files="ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json",
split="train",
)

conversations = dataset["conversations"]

@@ -43,23 +45,24 @@ def generate_sharegpt():
del conv["markdown"]
del conv["text"]

conversation = dict(type="conversation",
language="Multilingual",
dataset="ShareGPT",
conversations=conversations[idx])
conversation = dict(
type="conversation", language="Multilingual", dataset="ShareGPT", conversations=conversations[idx]
)
conversation_dataset.append(conversation)

return conversation_dataset


if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--dataset',
type=str,
default="All",
choices=["Alpaca", "ShareGPT", "All"],
help="which dataset to convert, All will combine Alpaca and ShareGPT")
parser.add_argument('--save_path', type=str, default="dataset.json", help="path to save the converted dataset")
parser.add_argument(
"--dataset",
type=str,
default="All",
choices=["Alpaca", "ShareGPT", "All"],
help="which dataset to convert, All will combine Alpaca and ShareGPT",
)
parser.add_argument("--save_path", type=str, default="dataset.json", help="path to save the converted dataset")
args = parser.parse_args()

conversation_dataset = []
@@ -75,5 +78,5 @@ if __name__ == '__main__':
for idx, sample in enumerate(conversation_dataset):
sample["id"] = idx + 1

with open(args.save_path, mode='w') as f:
with open(args.save_path, mode="w") as f:
json.dump(conversation_dataset, f, indent=4, default=str, ensure_ascii=False)

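Each converted sample is a flat dict with type, language, dataset, and conversations fields, serialized with ensure_ascii=False so multilingual text stays readable. A minimal sketch of one such record (field values are illustrative, and the from/value message keys follow the usual ShareGPT layout):

import json

# Hypothetical single record in the converter's output schema.
record = {
    "type": "conversation",
    "language": "Multilingual",
    "dataset": "ShareGPT",
    "conversations": [{"from": "human", "value": "Hi"}, {"from": "gpt", "value": "Hello!"}],
    "id": 1,
}
with open("dataset.json", mode="w") as f:
    json.dump([record], f, indent=4, default=str, ensure_ascii=False)
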
@@ -6,7 +6,7 @@ random.seed(42)


def sample(args):
with open(args.dataset_path, mode='r') as f:
with open(args.dataset_path, mode="r") as f:
dataset_list = json.load(f)

sampled_dataset = [
@@ -14,18 +14,14 @@ def sample(args):
for idx, sample in enumerate(random.sample(dataset_list, args.sample_size))
]

with open(args.save_path, mode='w') as f:
json.dump(sampled_dataset, f, indent=4,
default=str, ensure_ascii=False)
with open(args.save_path, mode="w") as f:
json.dump(sampled_dataset, f, indent=4, default=str, ensure_ascii=False)


if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--dataset_path', type=str, default=None,
required=True, help="path to the pretrain dataset")
parser.add_argument('--save_path', type=str, default='prompt.json',
help="path to save the prompt dataset")
parser.add_argument('--sample_size', type=int,
default=16384, help="size of the prompt dataset")
parser.add_argument("--dataset_path", type=str, default=None, required=True, help="path to the pretrain dataset")
parser.add_argument("--save_path", type=str, default="prompt.json", help="path to save the prompt dataset")
parser.add_argument("--sample_size", type=int, default=16384, help="size of the prompt dataset")
args = parser.parse_args()
sample(args)

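Because the module seeds the RNG up front (random.seed(42)), the drawn prompt subset is reproducible across runs. A tiny demonstration:

import random

# Fixing the seed makes random.sample deterministic, so the prompt
# subset written by the script above is stable run to run.
random.seed(42)
dataset_list = [f"prompt-{i}" for i in range(100)]
print(random.sample(dataset_list, 4))  # same 4 prompts on every run
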
@@ -11,13 +11,13 @@ from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, Llama

def eval(args):
# configure model
if args.model == 'gpt2':
if args.model == "gpt2":
actor = GPTActor(pretrained=args.pretrain)
elif args.model == 'bloom':
elif args.model == "bloom":
actor = BLOOMActor(pretrained=args.pretrain)
elif args.model == 'opt':
elif args.model == "opt":
actor = OPTActor(pretrained=args.pretrain)
elif args.model == 'llama':
elif args.model == "llama":
actor = LlamaActor(pretrained=args.pretrain)
else:
raise ValueError(f'Unsupported model "{args.model}"')
@@ -28,45 +28,38 @@ def eval(args):
actor.load_state_dict(state_dict)

# configure tokenizer
if args.model == 'gpt2':
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
if args.model == "gpt2":
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'bloom':
tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m')
elif args.model == "bloom":
tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m")
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'opt':
elif args.model == "opt":
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'llama':
elif args.model == "llama":
tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
tokenizer.eos_token = '<\s>'
tokenizer.eos_token = "<\s>"
tokenizer.pad_token = tokenizer.unk_token
else:
raise ValueError(f'Unsupported model "{args.model}"')

actor.eval()
input_ids = tokenizer.encode(args.input,
return_tensors='pt')\
.to(torch.cuda.current_device())
outputs = generate(actor,
input_ids,
max_length=args.max_length,
do_sample=True,
top_k=50,
top_p=0.95,
num_return_sequences=1)
output = tokenizer.batch_decode(outputs[0],
skip_special_tokens=True)
input_ids = tokenizer.encode(args.input, return_tensors="pt").to(torch.cuda.current_device())
outputs = generate(
actor, input_ids, max_length=args.max_length, do_sample=True, top_k=50, top_p=0.95, num_return_sequences=1
)
output = tokenizer.batch_decode(outputs[0], skip_special_tokens=True)
print(f"[Output]: {''.join(output)}")


if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
# We suggest to use the pretrained model from HuggingFace, use pretrain to configure model
parser.add_argument('--pretrain', type=str, default=None)
parser.add_argument('--model_path', type=str, default=None)
parser.add_argument('--input', type=str, default='Question: How are you ? Answer:')
parser.add_argument('--max_length', type=int, default=100)
parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument("--model_path", type=str, default=None)
parser.add_argument("--input", type=str, default="Question: How are you ? Answer:")
parser.add_argument("--max_length", type=int, default=100)
args = parser.parse_args()
eval(args)

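The eval path is the standard encode -> generate -> decode loop. A CPU-only sketch of the same flow with vanilla transformers (plain GPT-2 rather than coati's Actor wrapper and its generate helper):

from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
model = GPT2LMHeadModel.from_pretrained("gpt2")

# Encode the prompt, sample a continuation with the same knobs as above
# (top-k/top-p sampling), then decode back to text.
input_ids = tokenizer.encode("Question: How are you ? Answer:", return_tensors="pt")
outputs = model.generate(
    input_ids, max_length=100, do_sample=True, top_k=50, top_p=0.95, num_return_sequences=1
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
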
@@ -5,7 +5,6 @@ from functools import partial

import pandas as pd
import ray
import torch
from coati.quant import llama_load_quant, low_resource_init
from coati.ray.detached_trainer_ppo import DetachedPPOTrainer
from coati.ray.experience_maker_holder import ExperienceMakerHolder
@@ -23,13 +22,13 @@ from transformers.modeling_utils import no_init_weights

def get_free_port():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(('', 0))
s.bind(("", 0))
return s.getsockname()[1]


def get_local_ip():
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
s.connect(('8.8.8.8', 80))
s.connect(("8.8.8.8", 80))
return s.getsockname()[0]

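Both helpers above use well-known socket tricks: binding a TCP socket to port 0 lets the kernel assign any free port, and "connecting" a UDP socket toward a public address (no packets are actually sent) reveals the outbound interface's IP. A standalone sketch:

import socket

with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
    s.bind(("", 0))              # port 0: the OS picks a free port
    print(s.getsockname()[1])    # the port it picked

with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
    s.connect(("8.8.8.8", 80))   # UDP connect sends nothing; it only selects a route
    print(s.getsockname()[0])    # local IP of the chosen interface
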
@@ -37,22 +36,25 @@ def main(args):
master_addr = str(get_local_ip())
# trainer_env_info
trainer_port = str(get_free_port())
env_info_trainers = [{
'local_rank': '0',
'rank': str(rank),
'world_size': str(args.num_trainers),
'master_port': trainer_port,
'master_addr': master_addr
} for rank in range(args.num_trainers)]
env_info_trainers = [
{
"local_rank": "0",
"rank": str(rank),
"world_size": str(args.num_trainers),
"master_port": trainer_port,
"master_addr": master_addr,
}
for rank in range(args.num_trainers)
]

# maker_env_info
maker_port = str(get_free_port())
env_info_maker = {
'local_rank': '0',
'rank': '0',
'world_size': '1',
'master_port': maker_port,
'master_addr': master_addr
"local_rank": "0",
"rank": "0",
"world_size": "1",
"master_port": maker_port,
"master_addr": master_addr,
}

# configure tokenizer
@@ -75,27 +77,33 @@ def main(args):
eval_performance=True,
debug=args.debug,
update_lora_weights=not (args.lora_rank == 0),
) for i, env_info_trainer in enumerate(env_info_trainers)
)
for i, env_info_trainer in enumerate(env_info_trainers)
]

def model_fn():
actor = get_actor_from_args(args.model, args.pretrain).requires_grad_(False).half().cuda()
critic = get_critic_from_args(args.model, args.critic_pretrain).requires_grad_(False).half().cuda()
reward_model = get_reward_model_from_args(args.model, args.critic_pretrain).requires_grad_(False).half().cuda()
if args.initial_model_quant_ckpt is not None and args.model == 'llama':
if args.initial_model_quant_ckpt is not None and args.model == "llama":
# quantize initial model
actor_cfg = AutoConfig.from_pretrained(args.pretrain)
with low_resource_init(), no_init_weights():
initial_model = get_actor_from_args(args.model, config=actor_cfg)
initial_model.model = llama_load_quant(initial_model.model, args.initial_model_quant_ckpt, args.quant_bits,
args.quant_group_size).cuda().requires_grad_(False)
initial_model.model = (
llama_load_quant(
initial_model.model, args.initial_model_quant_ckpt, args.quant_bits, args.quant_group_size
)
.cuda()
.requires_grad_(False)
)
else:
initial_model = get_actor_from_args(args.model, args.pretrain).requires_grad_(False).half().cuda()
return actor, critic, reward_model, initial_model

# configure Experience Maker
experience_holder_ref = ExperienceMakerHolder.options(name="maker1", num_gpus=1, max_concurrency=2).remote(
detached_trainer_name_list=[f'trainer{i}' for i in range(args.num_trainers)],
detached_trainer_name_list=[f"trainer{i}" for i in range(args.num_trainers)],
strategy_fn=partial(get_strategy_from_args, args.maker_strategy),
model_fn=model_fn,
env_info=env_info_maker,
@@ -130,12 +138,11 @@ def main(args):
dataset_size = args.experience_batch_size * 4

def build_dataloader():

def tokenize_fn(texts):
batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True)
batch = tokenizer(texts, return_tensors="pt", max_length=96, padding="max_length", truncation=True)
return {k: v.cuda() for k, v in batch.items()}

dataset = pd.read_csv(args.prompt_path)['prompt']
dataset = pd.read_csv(args.prompt_path)["prompt"]
dataloader = DataLoader(dataset=dataset, batch_size=dataset_size, shuffle=True, collate_fn=tokenize_fn)
return dataloader

@@ -144,32 +151,31 @@ def main(args):
ray.get(wait_tasks)

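The prompt dataloader above tokenizes raw strings batch by batch inside the collate_fn rather than pre-tokenizing the whole dataset. A CPU-only sketch of the same pattern (the prompts.csv file with a "prompt" column is hypothetical):

import pandas as pd
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

def tokenize_fn(texts):
    # Collate a list of raw prompt strings into a padded token batch.
    return tokenizer(texts, return_tensors="pt", max_length=96, padding="max_length", truncation=True)

dataset = pd.read_csv("prompts.csv")["prompt"]  # hypothetical file
dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True, collate_fn=tokenize_fn)
for batch in dataloader:
    print(batch["input_ids"].shape)  # (4, 96)
    break
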
if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--prompt_path', type=str, default=None)
parser.add_argument('--num_trainers', type=int, default=1)
parser.add_argument('--trainer_strategy',
choices=[
'ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_gemini_cpu',
'colossalai_zero2_cpu'
],
default='ddp')
parser.add_argument('--maker_strategy', choices=['naive'], default='naive')
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
parser.add_argument('--critic_model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
parser.add_argument('--pretrain', type=str, default=None)
parser.add_argument('--critic_pretrain', type=str, default=None)
parser.add_argument('--experience_steps', type=int, default=4)
parser.add_argument('--experience_batch_size', type=int, default=8)
parser.add_argument('--train_epochs', type=int, default=1)
parser.add_argument('--update_steps', type=int, default=2)
parser.add_argument('--train_batch_size', type=int, default=8)
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument("--prompt_path", type=str, default=None)
parser.add_argument("--num_trainers", type=int, default=1)
parser.add_argument(
"--trainer_strategy",
choices=["ddp", "colossalai_gemini", "colossalai_zero2", "colossalai_gemini_cpu", "colossalai_zero2_cpu"],
default="ddp",
)
parser.add_argument("--maker_strategy", choices=["naive"], default="naive")
parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
parser.add_argument("--critic_model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument("--critic_pretrain", type=str, default=None)
parser.add_argument("--experience_steps", type=int, default=4)
parser.add_argument("--experience_batch_size", type=int, default=8)
parser.add_argument("--train_epochs", type=int, default=1)
parser.add_argument("--update_steps", type=int, default=2)
parser.add_argument("--train_batch_size", type=int, default=8)
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")

parser.add_argument('--initial_model_quant_ckpt', type=str, default=None)
parser.add_argument('--quant_bits', type=int, default=4)
parser.add_argument('--quant_group_size', type=int, default=128)
parser.add_argument('--debug', action='store_true')
parser.add_argument("--initial_model_quant_ckpt", type=str, default=None)
parser.add_argument("--quant_bits", type=int, default=4)
parser.add_argument("--quant_group_size", type=int, default=128)
parser.add_argument("--debug", action="store_true")
args = parser.parse_args()
ray.init(namespace=os.environ["RAY_NAMESPACE"], runtime_env={"env_vars": dict(os.environ)})
main(args)

@@ -5,7 +5,6 @@ from functools import partial

import pandas as pd
import ray
import torch
from coati.quant import llama_load_quant, low_resource_init
from coati.ray.detached_trainer_ppo import DetachedPPOTrainer
from coati.ray.experience_maker_holder import ExperienceMakerHolder
@@ -23,13 +22,13 @@ from transformers.modeling_utils import no_init_weights

def get_free_port():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(('', 0))
s.bind(("", 0))
return s.getsockname()[1]


def get_local_ip():
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
s.connect(('8.8.8.8', 80))
s.connect(("8.8.8.8", 80))
return s.getsockname()[0]


@@ -37,23 +36,29 @@ def main(args):
master_addr = str(get_local_ip())
# trainer_env_info
trainer_port = str(get_free_port())
env_info_trainers = [{
'local_rank': '0',
'rank': str(rank),
'world_size': str(args.num_trainers),
'master_port': trainer_port,
'master_addr': master_addr
} for rank in range(args.num_trainers)]
env_info_trainers = [
{
"local_rank": "0",
"rank": str(rank),
"world_size": str(args.num_trainers),
"master_port": trainer_port,
"master_addr": master_addr,
}
for rank in range(args.num_trainers)
]

# maker_env_info
maker_port = str(get_free_port())
env_info_makers = [{
'local_rank': '0',
'rank': str(rank),
'world_size': str(args.num_makers),
'master_port': maker_port,
'master_addr': master_addr
} for rank in range(args.num_makers)]
env_info_makers = [
{
"local_rank": "0",
"rank": str(rank),
"world_size": str(args.num_makers),
"master_port": maker_port,
"master_addr": master_addr,
}
for rank in range(args.num_makers)
]

# configure tokenizer
tokenizer = AutoTokenizer.from_pretrained(args.pretrain)
@@ -63,13 +68,18 @@ def main(args):
actor = get_actor_from_args(args.model, args.pretrain).requires_grad_(False).half().cuda()
critic = get_critic_from_args(args.model, args.critic_pretrain).requires_grad_(False).half().cuda()
reward_model = get_reward_model_from_args(args.model, args.critic_pretrain).requires_grad_(False).half().cuda()
if args.initial_model_quant_ckpt is not None and args.model == 'llama':
if args.initial_model_quant_ckpt is not None and args.model == "llama":
# quantize initial model
actor_cfg = AutoConfig.from_pretrained(args.pretrain)
with low_resource_init(), no_init_weights():
initial_model = get_actor_from_args(args.model, config=actor_cfg)
initial_model.model = llama_load_quant(initial_model.model, args.initial_model_quant_ckpt, args.quant_bits,
args.quant_group_size).cuda().requires_grad_(False)
initial_model.model = (
llama_load_quant(
initial_model.model, args.initial_model_quant_ckpt, args.quant_bits, args.quant_group_size
)
.cuda()
.requires_grad_(False)
)
else:
initial_model = get_actor_from_args(args.model, args.pretrain).requires_grad_(False).half().cuda()
return actor, critic, reward_model, initial_model
@@ -78,7 +88,7 @@ def main(args):
experience_holder_refs = [
ExperienceMakerHolder.options(name=f"maker{i}", num_gpus=1, max_concurrency=2).remote(
detached_trainer_name_list=[
f'trainer{x}'
f"trainer{x}"
for x in get_receivers_per_sender(i, args.num_makers, args.num_trainers, allow_idle_sender=False)
],
strategy_fn=partial(get_strategy_from_args, args.maker_strategy),
@@ -87,8 +97,8 @@ def main(args):
kl_coef=0.1,
debug=args.debug,
update_lora_weights=not (args.lora_rank == 0),
# sync_models_from_trainers=True,
# generation kwargs:
# sync_models_from_trainers=True,
# generation kwargs:
max_length=512,
do_sample=True,
temperature=1.0,
@@ -128,12 +138,11 @@ def main(args):
dataset_size = args.experience_batch_size * 4

def build_dataloader():

def tokenize_fn(texts):
batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True)
batch = tokenizer(texts, return_tensors="pt", max_length=96, padding="max_length", truncation=True)
return {k: v.cuda() for k, v in batch.items()}

dataset = pd.read_csv(args.prompt_path)['prompt']
dataset = pd.read_csv(args.prompt_path)["prompt"]
dataloader = DataLoader(dataset=dataset, batch_size=dataset_size, shuffle=True, collate_fn=tokenize_fn)
return dataloader

@@ -148,39 +157,44 @@ def main(args):
for experience_holder_ref in experience_holder_refs:
wait_tasks.append(experience_holder_ref.workingloop.remote(build_dataloader, num_steps=args.experience_steps))

total_steps = args.experience_batch_size * args.experience_steps * \
args.num_makers // (args.num_trainers * args.train_batch_size)
total_steps = (
args.experience_batch_size
* args.experience_steps
* args.num_makers
// (args.num_trainers * args.train_batch_size)
)
for trainer_ref in trainer_refs:
wait_tasks.append(trainer_ref.fit.remote(total_steps, args.update_steps, args.train_epochs))

ray.get(wait_tasks)

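The total_steps expression above balances production against consumption: makers produce experience_batch_size * experience_steps samples each, and trainers consume train_batch_size samples per step. A worked example with hypothetical settings:

# 8 prompts per experience batch, 4 experience steps, 2 makers,
# 1 trainer consuming batches of 8.
experience_batch_size, experience_steps, num_makers = 8, 4, 2
num_trainers, train_batch_size = 1, 8

total_steps = experience_batch_size * experience_steps * num_makers // (num_trainers * train_batch_size)
print(total_steps)  # 8 * 4 * 2 // (1 * 8) = 8 training steps per trainer
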
if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--prompt_path', type=str, default=None)
parser.add_argument('--num_makers', type=int, default=1)
parser.add_argument('--num_trainers', type=int, default=1)
parser.add_argument("--prompt_path", type=str, default=None)
parser.add_argument("--num_makers", type=int, default=1)
parser.add_argument("--num_trainers", type=int, default=1)
parser.add_argument(
'--trainer_strategy',
choices=['ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_gemini_cpu', 'colossalai_zero2_cpu'],
default='ddp')
parser.add_argument('--maker_strategy', choices=['naive'], default='naive')
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
parser.add_argument('--critic_model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
parser.add_argument('--pretrain', type=str, default=None)
parser.add_argument('--critic_pretrain', type=str, default=None)
parser.add_argument('--experience_steps', type=int, default=4)
parser.add_argument('--experience_batch_size', type=int, default=8)
parser.add_argument('--train_epochs', type=int, default=1)
parser.add_argument('--update_steps', type=int, default=2)
parser.add_argument('--train_batch_size', type=int, default=8)
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
"--trainer_strategy",
choices=["ddp", "colossalai_gemini", "colossalai_zero2", "colossalai_gemini_cpu", "colossalai_zero2_cpu"],
default="ddp",
)
parser.add_argument("--maker_strategy", choices=["naive"], default="naive")
parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
parser.add_argument("--critic_model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument("--critic_pretrain", type=str, default=None)
parser.add_argument("--experience_steps", type=int, default=4)
parser.add_argument("--experience_batch_size", type=int, default=8)
parser.add_argument("--train_epochs", type=int, default=1)
parser.add_argument("--update_steps", type=int, default=2)
parser.add_argument("--train_batch_size", type=int, default=8)
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")

parser.add_argument('--initial_model_quant_ckpt', type=str, default=None)
parser.add_argument('--quant_bits', type=int, default=4)
parser.add_argument('--quant_group_size', type=int, default=128)
parser.add_argument('--debug', action='store_true')
parser.add_argument("--initial_model_quant_ckpt", type=str, default=None)
parser.add_argument("--quant_bits", type=int, default=4)
parser.add_argument("--quant_group_size", type=int, default=128)
parser.add_argument("--debug", action="store_true")
args = parser.parse_args()

ray.init(namespace=os.environ["RAY_NAMESPACE"], runtime_env={"env_vars": dict(os.environ)})

@@ -1,3 +1,3 @@
pandas>=1.4.1
sentencepiece
colossalai==0.3.1
colossalai==0.3.1

@@ -20,28 +20,28 @@ from colossalai.nn.optimizer import HybridAdam

def main(args):
# configure strategy
if args.strategy == 'ddp':
if args.strategy == "ddp":
strategy = DDPStrategy()
elif args.strategy == 'colossalai_gemini':
strategy = GeminiStrategy(placement_policy='cuda', initial_scale=2**5)
elif args.strategy == 'colossalai_zero2':
strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
elif args.strategy == "colossalai_gemini":
strategy = GeminiStrategy(placement_policy="cuda", initial_scale=2**5)
elif args.strategy == "colossalai_zero2":
strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
else:
raise ValueError(f'Unsupported strategy "{args.strategy}"')

if args.rm_path is not None:
warnings.warn('LoRA weights should be merged with the model weights')
state_dict = torch.load(args.rm_path, map_location='cpu')
warnings.warn("LoRA weights should be merged with the model weights")
state_dict = torch.load(args.rm_path, map_location="cpu")

with strategy.model_init_context():
# configure model
if args.model == 'gpt2':
if args.model == "gpt2":
initial_model = GPTActor(pretrained=args.pretrain)
elif args.model == 'bloom':
elif args.model == "bloom":
initial_model = BLOOMActor(pretrained=args.pretrain)
elif args.model == 'opt':
elif args.model == "opt":
initial_model = OPTActor(pretrained=args.pretrain)
elif args.model == 'llama':
elif args.model == "llama":
initial_model = LlamaActor(pretrained=args.pretrain)
else:
raise ValueError(f'Unsupported actor model "{args.model}"')
@@ -51,13 +51,13 @@ def main(args):
else:
rm_model_name = args.rm_model

if rm_model_name == 'gpt2':
if rm_model_name == "gpt2":
reward_model = GPTRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
elif rm_model_name == 'bloom':
elif rm_model_name == "bloom":
reward_model = BLOOMRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
elif rm_model_name == 'opt':
elif rm_model_name == "opt":
reward_model = OPTRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
elif rm_model_name == 'llama':
elif rm_model_name == "llama":
reward_model = LlamaRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
else:
raise ValueError(f'Unsupported reward model "{rm_model_name}"')
@@ -68,24 +68,24 @@ def main(args):
initial_model.to(torch.float16).to(torch.cuda.current_device())
reward_model.to(torch.float16).to(torch.cuda.current_device())

if args.model == 'gpt2':
if args.model == "gpt2":
actor = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
elif args.model == 'bloom':
elif args.model == "bloom":
actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
elif args.model == 'opt':
elif args.model == "opt":
actor = OPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
elif args.model == 'llama':
elif args.model == "llama":
actor = LlamaActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
else:
raise ValueError(f'Unsupported actor model "{args.model}"')

if rm_model_name == 'gpt2':
if rm_model_name == "gpt2":
critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
elif rm_model_name == 'bloom':
elif rm_model_name == "bloom":
critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
elif rm_model_name == 'opt':
elif rm_model_name == "opt":
critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
elif rm_model_name == 'llama':
elif rm_model_name == "llama":
critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
else:
raise ValueError(f'Unsupported reward model "{rm_model_name}"')
@@ -94,12 +94,12 @@ def main(args):
critic.load_state_dict(state_dict, strict=False)
del state_dict

if args.strategy != 'colossalai_gemini':
if args.strategy != "colossalai_gemini":
critic.to(torch.float16).to(torch.cuda.current_device())
actor.to(torch.float16).to(torch.cuda.current_device())

# configure optimizer
if args.strategy.startswith('colossalai'):
if args.strategy.startswith("colossalai"):
actor_optim = HybridAdam(actor.parameters(), lr=1e-7)
critic_optim = HybridAdam(critic.parameters(), lr=1e-7)
else:
@@ -107,22 +107,22 @@ def main(args):
critic_optim = Adam(critic.parameters(), lr=1e-7)

# configure tokenizer
if args.model == 'gpt2':
tokenizer = GPT2Tokenizer.from_pretrained(
'gpt2' if args.tokenizer is None else args.tokenizer)
if args.model == "gpt2":
tokenizer = GPT2Tokenizer.from_pretrained("gpt2" if args.tokenizer is None else args.tokenizer)
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'bloom':
elif args.model == "bloom":
tokenizer = BloomTokenizerFast.from_pretrained(
'bigscience/bloom-560m' if args.tokenizer is None else args.tokenizer)
"bigscience/bloom-560m" if args.tokenizer is None else args.tokenizer
)
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'opt':
tokenizer = AutoTokenizer.from_pretrained(
"facebook/opt-350m" if args.tokenizer is None else args.tokenizer)
elif args.model == "opt":
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m" if args.tokenizer is None else args.tokenizer)
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'llama':
elif args.model == "llama":
tokenizer = LlamaTokenizer.from_pretrained(
"hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer)
tokenizer.eos_token = '<\s>'
"hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer
)
tokenizer.eos_token = "<\s>"
tokenizer.pad_token = tokenizer.unk_token
else:
raise ValueError(f'Unsupported model "{args.model}"')
@@ -132,27 +132,25 @@ def main(args):
prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True)
else:
prompt_sampler = None
prompt_dataloader = DataLoader(prompt_dataset,
shuffle=(prompt_sampler is None),
sampler=prompt_sampler,
batch_size=args.experience_batch_size)
prompt_dataloader = DataLoader(
prompt_dataset, shuffle=(prompt_sampler is None), sampler=prompt_sampler, batch_size=args.experience_batch_size
)

pretrain_dataset = SupervisedDataset(tokenizer=tokenizer,
data_path=args.pretrain_dataset,
max_datasets_size=16384,
max_length=args.max_input_len)
pretrain_dataset = SupervisedDataset(
tokenizer=tokenizer, data_path=args.pretrain_dataset, max_datasets_size=16384, max_length=args.max_input_len
)
if dist.is_initialized() and dist.get_world_size() > 1:
pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True)
else:
pretrain_sampler = None
pretrain_dataloader = DataLoader(pretrain_dataset,
shuffle=(pretrain_sampler is None),
sampler=pretrain_sampler,
batch_size=args.ptx_batch_size)
pretrain_dataloader = DataLoader(
pretrain_dataset, shuffle=(pretrain_sampler is None), sampler=pretrain_sampler, batch_size=args.ptx_batch_size
)

# NOTE: For small models like opt-1.3b, reward model and initial model are not required to be parallelized.
(actor, actor_optim), (critic, critic_optim), reward_model, initial_model = \
strategy.prepare((actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
(actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
(actor, actor_optim), (critic, critic_optim), reward_model, initial_model
)

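strategy.prepare accepts a mix of (model, optimizer) pairs and bare models and hands them back wrapped in the same shape, so a single call configures everything consistently. A schematic of that calling convention with a toy stand-in (not coati's actual Strategy class):

class ToyStrategy:
    def prepare(self, *items):
        # Pairs are wrapped together; bare models pass through alone.
        out = []
        for item in items:
            if isinstance(item, tuple):
                model, optim = item
                out.append((f"wrapped({model})", f"wrapped({optim})"))
            else:
                out.append(f"wrapped({item})")
        return out

(actor, actor_optim), reward = ToyStrategy().prepare(("actor", "adam"), "reward")
print(actor, actor_optim, reward)
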
# configure trainer
trainer = PPOTrainer(
@@ -173,50 +171,54 @@ def main(args):
top_k=50,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
offload_inference_models=args.strategy != 'colossalai_gemini'
offload_inference_models=args.strategy != "colossalai_gemini",
)

trainer.fit(prompt_dataloader=prompt_dataloader,
pretrain_dataloader=pretrain_dataloader,
num_episodes=args.num_episodes,
num_collect_steps=args.num_collect_steps,
num_update_steps=args.num_update_steps)
trainer.fit(
prompt_dataloader=prompt_dataloader,
pretrain_dataloader=pretrain_dataloader,
num_episodes=args.num_episodes,
num_collect_steps=args.num_collect_steps,
num_update_steps=args.num_update_steps,
)

# save model checkpoint after fitting
strategy.save_model(actor, args.save_path, only_rank0=True)
# save optimizer checkpoint on all ranks
if args.need_optim_ckpt:
strategy.save_optimizer(actor_optim,
'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
only_rank0=False)
strategy.save_optimizer(
actor_optim, "actor_optim_checkpoint_prompts_%d.pt" % (torch.cuda.current_device()), only_rank0=False
)


if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--prompt_dataset', type=str, default=None, help='path to the prompt dataset')
parser.add_argument('--pretrain_dataset', type=str, default=None, help='path to the pretrained dataset')
parser.add_argument('--strategy',
choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'],
default='colossalai_zero2',
help='strategy to use')
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
parser.add_argument('--tokenizer', type=str, default=None)
parser.add_argument('--pretrain', type=str, default=None)
parser.add_argument('--rm_model', default=None, choices=['gpt2', 'bloom', 'opt', 'llama'])
parser.add_argument('--rm_path', type=str, default=None)
parser.add_argument('--rm_pretrain', type=str, default=None)
parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts')
parser.add_argument('--need_optim_ckpt', type=bool, default=False)
parser.add_argument('--num_episodes', type=int, default=10)
parser.add_argument('--num_collect_steps', type=int, default=10)
parser.add_argument('--num_update_steps', type=int, default=5)
parser.add_argument('--train_batch_size', type=int, default=8)
parser.add_argument('--ptx_batch_size', type=int, default=1)
parser.add_argument('--experience_batch_size', type=int, default=8)
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument('--kl_coef', type=float, default=0.1)
parser.add_argument('--ptx_coef', type=float, default=0.9)
parser.add_argument('--max_input_len', type=int, default=96)
parser.add_argument('--max_seq_len', type=int, default=128)
parser.add_argument("--prompt_dataset", type=str, default=None, help="path to the prompt dataset")
parser.add_argument("--pretrain_dataset", type=str, default=None, help="path to the pretrained dataset")
parser.add_argument(
"--strategy",
choices=["ddp", "colossalai_gemini", "colossalai_zero2"],
default="colossalai_zero2",
help="strategy to use",
)
parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
parser.add_argument("--tokenizer", type=str, default=None)
parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument("--rm_model", default=None, choices=["gpt2", "bloom", "opt", "llama"])
parser.add_argument("--rm_path", type=str, default=None)
parser.add_argument("--rm_pretrain", type=str, default=None)
parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts")
parser.add_argument("--need_optim_ckpt", type=bool, default=False)
parser.add_argument("--num_episodes", type=int, default=10)
parser.add_argument("--num_collect_steps", type=int, default=10)
parser.add_argument("--num_update_steps", type=int, default=5)
parser.add_argument("--train_batch_size", type=int, default=8)
parser.add_argument("--ptx_batch_size", type=int, default=1)
parser.add_argument("--experience_batch_size", type=int, default=8)
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument("--kl_coef", type=float, default=0.1)
parser.add_argument("--ptx_coef", type=float, default=0.9)
parser.add_argument("--max_input_len", type=int, default=96)
parser.add_argument("--max_seq_len", type=int, default=128)
args = parser.parse_args()
main(args)

@@ -24,24 +24,24 @@ from colossalai.nn.optimizer import HybridAdam
|
||||
|
||||
def train(args):
|
||||
# configure strategy
|
||||
if args.strategy == 'ddp':
|
||||
if args.strategy == "ddp":
|
||||
strategy = DDPStrategy()
|
||||
elif args.strategy == 'colossalai_gemini':
|
||||
strategy = GeminiStrategy(placement_policy='cuda')
|
||||
elif args.strategy == 'colossalai_zero2':
|
||||
strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
|
||||
elif args.strategy == "colossalai_gemini":
|
||||
strategy = GeminiStrategy(placement_policy="cuda")
|
||||
elif args.strategy == "colossalai_zero2":
|
||||
strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
|
||||
else:
|
||||
raise ValueError(f'Unsupported strategy "{args.strategy}"')
|
||||
|
||||
# configure model
|
||||
with strategy.model_init_context():
|
||||
if args.model == 'bloom':
|
||||
if args.model == "bloom":
|
||||
model = BLOOMRM(pretrained=args.pretrain, lora_rank=args.lora_rank)
|
||||
elif args.model == 'opt':
|
||||
elif args.model == "opt":
|
||||
model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank)
|
||||
elif args.model == 'gpt2':
|
||||
elif args.model == "gpt2":
|
||||
model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank)
|
||||
elif args.model == 'llama':
|
||||
elif args.model == "llama":
|
||||
model = LlamaRM(pretrained=args.pretrain, lora_rank=args.lora_rank)
|
||||
else:
|
||||
raise ValueError(f'Unsupported model "{args.model}"')
|
||||
@@ -53,36 +53,36 @@ def train(args):
|
||||
model.load_state_dict(state_dict)
|
||||
|
||||
# configure tokenizer
|
||||
if args.model == 'gpt2':
|
||||
tokenizer = GPT2Tokenizer.from_pretrained(
|
||||
'gpt2' if args.tokenizer is None else args.tokenizer)
|
||||
if args.model == "gpt2":
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("gpt2" if args.tokenizer is None else args.tokenizer)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
elif args.model == 'bloom':
|
||||
elif args.model == "bloom":
|
||||
tokenizer = BloomTokenizerFast.from_pretrained(
|
||||
'bigscience/bloom-560m' if args.tokenizer is None else args.tokenizer)
|
||||
"bigscience/bloom-560m" if args.tokenizer is None else args.tokenizer
|
||||
)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
elif args.model == 'opt':
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"facebook/opt-350m" if args.tokenizer is None else args.tokenizer)
|
||||
elif args.model == "opt":
|
||||
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m" if args.tokenizer is None else args.tokenizer)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
elif args.model == 'llama':
|
||||
elif args.model == "llama":
|
||||
tokenizer = LlamaTokenizer.from_pretrained(
|
||||
"hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer)
|
||||
tokenizer.eos_token = '<\s>'
|
||||
"hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer
|
||||
)
|
||||
tokenizer.eos_token = "<\s>"
|
||||
tokenizer.pad_token = tokenizer.unk_token
|
||||
else:
|
||||
raise ValueError(f'Unsupported model "{args.model}"')
|
||||
|
||||
# configure optimizer
|
||||
if args.strategy.startswith('colossalai'):
|
||||
if args.strategy.startswith("colossalai"):
|
||||
optim = HybridAdam(model.parameters(), lr=5e-6)
|
||||
else:
|
||||
optim = Adam(model.parameters(), lr=5e-6)
|
||||
|
||||
# configure loss function
|
||||
if args.loss_fn == 'log_sig':
|
||||
if args.loss_fn == "log_sig":
|
||||
loss_fn = LogSigLoss()
|
||||
elif args.loss_fn == 'log_exp':
|
||||
elif args.loss_fn == "log_exp":
|
||||
loss_fn = LogExpLoss()
|
||||
else:
|
||||
raise ValueError(f'Unsupported loss function "{args.loss_fn}"')
|
||||
@@ -94,18 +94,18 @@ def train(args):
|
||||
data = load_dataset(args.dataset)
|
||||
|
||||
if args.test:
|
||||
train_data = data['train'].select(range(20))
|
||||
eval_data = data['test'].select(range(5))
|
||||
train_data = data["train"].select(range(20))
|
||||
eval_data = data["test"].select(range(5))
|
||||
else:
|
||||
train_data = data['train']
|
||||
eval_data = data['test']
|
||||
valid_data = data['test'].select((randint(0, len(eval_data) - 1) for _ in range(len(eval_data) // 5)))
|
||||
train_data = data["train"]
|
||||
eval_data = data["test"]
|
||||
valid_data = data["test"].select((randint(0, len(eval_data) - 1) for _ in range(len(eval_data) // 5)))
|
||||
|
||||
    if args.dataset == 'Dahoas/rm-static':
    if args.dataset == "Dahoas/rm-static":
        train_dataset = RmStaticDataset(train_data, tokenizer, args.max_len)
        valid_dataset = RmStaticDataset(valid_data, tokenizer, args.max_len)
        eval_dataset = RmStaticDataset(eval_data, tokenizer, args.max_len)
    elif args.dataset == 'Anthropic/hh-rlhf':
    elif args.dataset == "Anthropic/hh-rlhf":
        train_dataset = HhRlhfDataset(train_data, tokenizer, args.max_len)
        valid_dataset = HhRlhfDataset(valid_data, tokenizer, args.max_len)
        eval_dataset = HhRlhfDataset(eval_data, tokenizer, args.max_len)
@@ -113,90 +113,99 @@ def train(args):
        raise ValueError(f'Unsupported dataset "{args.dataset}"')

    if dist.is_initialized() and dist.get_world_size() > 1:
        train_sampler = DistributedSampler(train_dataset,
                                           shuffle=True,
                                           seed=42,
                                           drop_last=True,
                                           rank=dist.get_rank(),
                                           num_replicas=dist.get_world_size())
        valid_sampler = DistributedSampler(valid_dataset,
                                           shuffle=True,
                                           seed=42,
                                           drop_last=True,
                                           rank=dist.get_rank(),
                                           num_replicas=dist.get_world_size())
        eval_sampler = DistributedSampler(eval_dataset,
                                          shuffle=True,
                                          seed=42,
                                          drop_last=True,
                                          rank=dist.get_rank(),
                                          num_replicas=dist.get_world_size())
        train_sampler = DistributedSampler(
            train_dataset,
            shuffle=True,
            seed=42,
            drop_last=True,
            rank=dist.get_rank(),
            num_replicas=dist.get_world_size(),
        )
        valid_sampler = DistributedSampler(
            valid_dataset,
            shuffle=True,
            seed=42,
            drop_last=True,
            rank=dist.get_rank(),
            num_replicas=dist.get_world_size(),
        )
        eval_sampler = DistributedSampler(
            eval_dataset,
            shuffle=True,
            seed=42,
            drop_last=True,
            rank=dist.get_rank(),
            num_replicas=dist.get_world_size(),
        )
    else:
        train_sampler = None
        valid_sampler = None
        eval_sampler = None

    train_dataloader = DataLoader(train_dataset,
                                  shuffle=(train_sampler is None),
                                  sampler=train_sampler,
                                  batch_size=args.batch_size,
                                  pin_memory=True)
    train_dataloader = DataLoader(
        train_dataset,
        shuffle=(train_sampler is None),
        sampler=train_sampler,
        batch_size=args.batch_size,
        pin_memory=True,
    )

    valid_dataloader = DataLoader(valid_dataset,
                                  shuffle=(valid_sampler is None),
                                  sampler=valid_sampler,
                                  batch_size=args.batch_size,
                                  pin_memory=True)
    valid_dataloader = DataLoader(
        valid_dataset,
        shuffle=(valid_sampler is None),
        sampler=valid_sampler,
        batch_size=args.batch_size,
        pin_memory=True,
    )

    eval_dataloader = DataLoader(eval_dataset,
                                 shuffle=(eval_sampler is None),
                                 sampler=eval_sampler,
                                 batch_size=args.batch_size,
                                 pin_memory=True)
    eval_dataloader = DataLoader(
        eval_dataset, shuffle=(eval_sampler is None), sampler=eval_sampler, batch_size=args.batch_size, pin_memory=True
    )

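Both layouts of the hunk above encode the same constraint: torch's DataLoader raises a ValueError when shuffle=True is combined with a custom sampler, so under multi-process training the shuffling moves into the DistributedSampler and shuffle=(sampler is None) keeps the single-process path working. A minimal self-contained sketch of the pattern (names are illustrative):

import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def build_loader(dataset, batch_size):
    sampler = None
    if dist.is_initialized() and dist.get_world_size() > 1:
        # Each rank iterates a disjoint shard; rank and num_replicas
        # default to the current process group, matching the explicit
        # arguments passed in the hunk above.
        sampler = DistributedSampler(dataset, shuffle=True, seed=42, drop_last=True)
    # Shuffle only when no sampler is installed; DataLoader forbids both.
    return DataLoader(dataset, shuffle=(sampler is None), sampler=sampler,
                      batch_size=batch_size, pin_memory=True)

When a DistributedSampler is in use, calling sampler.set_epoch(epoch) at the start of each epoch is the usual way to reshuffle the shards.
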
    lr_scheduler = CosineAnnealingLR(optim, train_dataloader.__len__() // 100)
    strategy_dict = strategy.prepare(dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler))
    model = strategy_dict['model']
    optim = strategy_dict['optimizer']
    lr_scheduler = strategy_dict['lr_scheduler']
    trainer = RewardModelTrainer(model=model,
                                 strategy=strategy,
                                 optim=optim,
                                 lr_scheduler=lr_scheduler,
                                 loss_fn=loss_fn,
                                 max_epochs=args.max_epochs)
    model = strategy_dict["model"]
    optim = strategy_dict["optimizer"]
    lr_scheduler = strategy_dict["lr_scheduler"]
    trainer = RewardModelTrainer(
        model=model,
        strategy=strategy,
        optim=optim,
        lr_scheduler=lr_scheduler,
        loss_fn=loss_fn,
        max_epochs=args.max_epochs,
    )

    trainer.fit(train_dataloader=train_dataloader, valid_dataloader=valid_dataloader, eval_dataloader=eval_dataloader)
    # save model checkpoint after fitting on only rank0
    strategy.save_model(model, args.save_path, only_rank0=True)
    # save optimizer checkpoint on all ranks
    if args.need_optim_ckpt:
        strategy.save_optimizer(trainer.optimizer,
                                'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()),
                                only_rank0=False)
        strategy.save_optimizer(
            trainer.optimizer, "rm_optim_checkpoint_%d.pt" % (torch.cuda.current_device()), only_rank0=False
        )


if __name__ == '__main__':
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--strategy',
                        choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'],
                        default='colossalai_zero2')
    parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom')
    parser.add_argument('--tokenizer', type=str, default=None)
    parser.add_argument('--pretrain', type=str, default=None)
    parser.add_argument('--model_path', type=str, default=None)
    parser.add_argument('--need_optim_ckpt', type=bool, default=False)
    parser.add_argument('--dataset',
                        type=str,
                        choices=['Anthropic/hh-rlhf', 'Dahoas/rm-static'],
                        default='Dahoas/rm-static')
    parser.add_argument('--subset', type=lambda x: None if x == 'None' else x, default=None)
    parser.add_argument('--save_path', type=str, default='rm_ckpt')
    parser.add_argument('--max_epochs', type=int, default=1)
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--max_len', type=int, default=512)
    parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
    parser.add_argument('--loss_fn', type=str, default='log_sig', choices=['log_sig', 'log_exp'])
    parser.add_argument('--test', type=bool, default=False)
    parser.add_argument(
        "--strategy", choices=["ddp", "colossalai_gemini", "colossalai_zero2"], default="colossalai_zero2"
    )
    parser.add_argument("--model", choices=["gpt2", "bloom", "opt", "llama"], default="bloom")
    parser.add_argument("--tokenizer", type=str, default=None)
    parser.add_argument("--pretrain", type=str, default=None)
    parser.add_argument("--model_path", type=str, default=None)
    parser.add_argument("--need_optim_ckpt", type=bool, default=False)
    parser.add_argument(
        "--dataset", type=str, choices=["Anthropic/hh-rlhf", "Dahoas/rm-static"], default="Dahoas/rm-static"
    )
    parser.add_argument("--subset", type=lambda x: None if x == "None" else x, default=None)
    parser.add_argument("--save_path", type=str, default="rm_ckpt")
    parser.add_argument("--max_epochs", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--max_len", type=int, default=512)
    parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
    parser.add_argument("--loss_fn", type=str, default="log_sig", choices=["log_sig", "log_exp"])
    parser.add_argument("--test", type=bool, default=False)
    args = parser.parse_args()
    train(args)

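For reference, the reward-model script whose diff ends here can also be exercised without the CLI by handing train() an equivalent namespace. A minimal sketch, assuming the file is importable as train_reward_model (a hypothetical module name; the filename itself is not shown in this diff):

from argparse import Namespace

import train_reward_model  # hypothetical module name for the script above

args = Namespace(
    strategy="colossalai_zero2", model="bloom", tokenizer=None, pretrain=None,
    model_path=None, need_optim_ckpt=False, dataset="Dahoas/rm-static",
    subset=None, save_path="rm_ckpt", max_epochs=1, batch_size=1,
    max_len=512, lora_rank=0, loss_fn="log_sig", test=False,
)
train_reward_model.train(args)

The field names and values mirror the parser defaults above, so this is equivalent to launching the script with no flags.
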
@@ -6,18 +6,18 @@ import torch
import torch.distributed as dist
from coati.dataset import SFTDataset, SupervisedDataset
from coati.models.bloom import BLOOMActor
from coati.models.chatglm import ChatGLMActor
from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
from coati.models.gpt import GPTActor
from coati.models.llama import LlamaActor
from coati.models.opt import OPTActor
from coati.models.chatglm import ChatGLMActor
from coati.trainer import SFTTrainer
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
from datasets import load_dataset
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer, AutoModel
from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
from transformers.trainer import get_scheduler

@@ -28,14 +28,14 @@ from colossalai.tensor import ColoParameter

def train(args):
    # configure strategy
    if args.strategy == 'ddp':
    if args.strategy == "ddp":
        strategy = DDPStrategy()
    elif args.strategy == 'colossalai_gemini':
        strategy = GeminiStrategy(placement_policy='cuda')
    elif args.strategy == 'colossalai_zero2':
        strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
    elif args.strategy == 'colossalai_zero2_cpu':
        strategy = LowLevelZeroStrategy(stage=2, placement_policy='cpu')
    elif args.strategy == "colossalai_gemini":
        strategy = GeminiStrategy(placement_policy="cuda")
    elif args.strategy == "colossalai_zero2":
        strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
    elif args.strategy == "colossalai_zero2_cpu":
        strategy = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
    else:
        raise ValueError(f'Unsupported strategy "{args.strategy}"')

@@ -44,23 +44,15 @@ def train(args):
        warnings.warn("Gradient checkpoint is disabled when using LoRA")
        args.grad_checkpoint = False
    with strategy.model_init_context():
        if args.model == 'bloom':
            model = BLOOMActor(pretrained=args.pretrain,
                               lora_rank=args.lora_rank,
                               checkpoint=args.grad_checkpoint)
        elif args.model == 'opt':
            model = OPTActor(pretrained=args.pretrain,
                             lora_rank=args.lora_rank,
                             checkpoint=args.grad_checkpoint)
        elif args.model == 'gpt2':
            model = GPTActor(pretrained=args.pretrain,
                             lora_rank=args.lora_rank,
                             checkpoint=args.grad_checkpoint)
        elif args.model == 'llama':
            model = LlamaActor(pretrained=args.pretrain,
                               lora_rank=args.lora_rank,
                               checkpoint=args.grad_checkpoint)
        elif args.model == 'chatglm':
        if args.model == "bloom":
            model = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint)
        elif args.model == "opt":
            model = OPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint)
        elif args.model == "gpt2":
            model = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint)
        elif args.model == "llama":
            model = LlamaActor(pretrained=args.pretrain, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint)
        elif args.model == "chatglm":
            model = ChatGLMActor(pretrained=args.pretrain)
        else:
            raise ValueError(f'Unsupported model "{args.model}"')
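The model branch above repeats the same three keyword arguments for four of the five actors. An equivalent table-driven sketch (an illustrative refactor, not part of this commit; the classes and their signatures are taken from the imports and calls shown above):

from coati.models.bloom import BLOOMActor
from coati.models.gpt import GPTActor
from coati.models.llama import LlamaActor
from coati.models.opt import OPTActor

ACTORS = {"bloom": BLOOMActor, "opt": OPTActor, "gpt2": GPTActor, "llama": LlamaActor}

def build_actor(args):
    # ChatGLMActor is left out: per the hunk above it only accepts `pretrained`.
    if args.model not in ACTORS:
        raise ValueError(f'Unsupported model "{args.model}"')
    return ACTORS[args.model](pretrained=args.pretrain, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint)
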
@@ -68,144 +60,157 @@ def train(args):
        model.to(torch.float16).to(torch.cuda.current_device())

    # configure tokenizer
    if args.model == 'gpt2':
        tokenizer = GPT2Tokenizer.from_pretrained(
            'gpt2' if args.tokenizer is None else args.tokenizer)
    if args.model == "gpt2":
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2" if args.tokenizer is None else args.tokenizer)
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == 'bloom':
    elif args.model == "bloom":
        tokenizer = BloomTokenizerFast.from_pretrained(
            'bigscience/bloom-560m' if args.tokenizer is None else args.tokenizer)
            "bigscience/bloom-560m" if args.tokenizer is None else args.tokenizer
        )
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == 'opt':
        tokenizer = AutoTokenizer.from_pretrained(
            "facebook/opt-350m" if args.tokenizer is None else args.tokenizer)
    elif args.model == "opt":
        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m" if args.tokenizer is None else args.tokenizer)
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == 'llama':
    elif args.model == "llama":
        tokenizer = LlamaTokenizer.from_pretrained(
            "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer)
        tokenizer.eos_token = '<\s>'
            "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer
        )
        tokenizer.eos_token = "<\s>"
        tokenizer.pad_token = tokenizer.unk_token
    elif args.model == 'chatglm':
    elif args.model == "chatglm":
        tokenizer = ChatGLMTokenizer.from_pretrained(
            "THUDM/chatglm-6b" if args.tokenizer is None else args.tokenizer, trust_remote_code=True)
            "THUDM/chatglm-6b" if args.tokenizer is None else args.tokenizer, trust_remote_code=True
        )
    else:
        raise ValueError(f'Unsupported model "{args.model}"')

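Every branch above applies the same convention: tokenizers that ship without a pad token borrow an existing special token (EOS for GPT-2, BLOOM and OPT; UNK for LLaMA) so that batched padding works. A model-agnostic sketch of that fallback (illustrative, not part of this commit):

from transformers import AutoTokenizer

def load_tokenizer_with_padding(name_or_path: str):
    tokenizer = AutoTokenizer.from_pretrained(name_or_path)
    # Causal-LM tokenizers often define no pad token; reuse another
    # special token so calls with padding enabled do not fail.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token or tokenizer.unk_token
    return tokenizer
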
    if args.model == 'llama' and args.strategy == 'colossalai_gemini':
    if args.model == "llama" and args.strategy == "colossalai_gemini":
        # this is a hack to deal with the resized embedding
        # to make sure all parameters are ColoParameter for Colossal-AI Gemini Compatibility
        for name, param in model.named_parameters():
            if not isinstance(param, ColoParameter):
                sub_module_name = '.'.join(name.split('.')[:-1])
                weight_name = name.split('.')[-1]
                sub_module_name = ".".join(name.split(".")[:-1])
                weight_name = name.split(".")[-1]
                sub_module = model.get_submodule(sub_module_name)
                setattr(sub_module, weight_name, ColoParameter(param))

    # configure optimizer
    if args.strategy.startswith('colossalai'):
    if args.strategy.startswith("colossalai"):
        optim = HybridAdam(model.parameters(), lr=args.lr, clipping_norm=1.0)
    else:
        optim = Adam(model.parameters(), lr=args.lr)
    logger = get_dist_logger()

    # configure dataset
    if args.dataset == 'yizhongw/self_instruct':
        train_data = load_dataset(args.dataset, 'super_natural_instructions', split='train')
        eval_data = load_dataset(args.dataset, 'super_natural_instructions', split='test')
    if args.dataset == "yizhongw/self_instruct":
        train_data = load_dataset(args.dataset, "super_natural_instructions", split="train")
        eval_data = load_dataset(args.dataset, "super_natural_instructions", split="test")

        train_dataset = SFTDataset(train_data, tokenizer, args.max_len)
        eval_dataset = SFTDataset(eval_data, tokenizer, args.max_len)

    else:
        train_dataset = SupervisedDataset(tokenizer=tokenizer,
                                          data_path=args.dataset,
                                          max_datasets_size=args.max_datasets_size,
                                          max_length=args.max_len)
        train_dataset = SupervisedDataset(
            tokenizer=tokenizer,
            data_path=args.dataset,
            max_datasets_size=args.max_datasets_size,
            max_length=args.max_len,
        )
        eval_dataset = None

    if dist.is_initialized() and dist.get_world_size() > 1:
        train_sampler = DistributedSampler(train_dataset,
                                           shuffle=True,
                                           seed=42,
                                           drop_last=True,
                                           rank=dist.get_rank(),
                                           num_replicas=dist.get_world_size())
        train_sampler = DistributedSampler(
            train_dataset,
            shuffle=True,
            seed=42,
            drop_last=True,
            rank=dist.get_rank(),
            num_replicas=dist.get_world_size(),
        )
        if eval_dataset is not None:
            eval_sampler = DistributedSampler(eval_dataset,
                                              shuffle=False,
                                              seed=42,
                                              drop_last=False,
                                              rank=dist.get_rank(),
                                              num_replicas=dist.get_world_size())
            eval_sampler = DistributedSampler(
                eval_dataset,
                shuffle=False,
                seed=42,
                drop_last=False,
                rank=dist.get_rank(),
                num_replicas=dist.get_world_size(),
            )
    else:
        train_sampler = None
        eval_sampler = None

    train_dataloader = DataLoader(train_dataset,
                                  shuffle=(train_sampler is None),
                                  sampler=train_sampler,
                                  batch_size=args.batch_size,
                                  pin_memory=True)
    train_dataloader = DataLoader(
        train_dataset,
        shuffle=(train_sampler is None),
        sampler=train_sampler,
        batch_size=args.batch_size,
        pin_memory=True,
    )
    if eval_dataset is not None:
        eval_dataloader = DataLoader(eval_dataset,
                                     shuffle=(eval_sampler is None),
                                     sampler=eval_sampler,
                                     batch_size=args.batch_size,
                                     pin_memory=True)
        eval_dataloader = DataLoader(
            eval_dataset,
            shuffle=(eval_sampler is None),
            sampler=eval_sampler,
            batch_size=args.batch_size,
            pin_memory=True,
        )
    else:
        eval_dataloader = None

    num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
    max_steps = math.ceil(args.max_epochs * num_update_steps_per_epoch)
    lr_scheduler = get_scheduler("cosine",
                                 optim,
                                 num_warmup_steps=math.ceil(max_steps * 0.03),
                                 num_training_steps=max_steps)
    lr_scheduler = get_scheduler(
        "cosine", optim, num_warmup_steps=math.ceil(max_steps * 0.03), num_training_steps=max_steps
    )
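To make the scheduler arithmetic above concrete with illustrative numbers: 10,000 train batches with accumulation_steps=8 give 10000 // 8 = 1250 optimizer updates per epoch; over max_epochs=3 that is max_steps = 3750, of which ceil(3750 * 0.03) = 113 steps warm up before the cosine decay begins.

import math

# Illustrative values; the real ones depend on the dataset and flags.
num_batches, accumulation_steps, max_epochs = 10_000, 8, 3
num_update_steps_per_epoch = num_batches // accumulation_steps  # 1250
max_steps = math.ceil(max_epochs * num_update_steps_per_epoch)  # 3750
num_warmup_steps = math.ceil(max_steps * 0.03)                  # 113
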
    strategy_dict = strategy.prepare(dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler))
    model = strategy_dict['model']
    optim = strategy_dict['optimizer']
    lr_scheduler = strategy_dict['lr_scheduler']
    trainer = SFTTrainer(model=model,
                         strategy=strategy,
                         optim=optim,
                         lr_scheduler=lr_scheduler,
                         max_epochs=args.max_epochs,
                         accumulation_steps=args.accumulation_steps)
    model = strategy_dict["model"]
    optim = strategy_dict["optimizer"]
    lr_scheduler = strategy_dict["lr_scheduler"]
    trainer = SFTTrainer(
        model=model,
        strategy=strategy,
        optim=optim,
        lr_scheduler=lr_scheduler,
        max_epochs=args.max_epochs,
        accumulation_steps=args.accumulation_steps,
    )

    trainer.fit(train_dataloader=train_dataloader,
                eval_dataloader=eval_dataloader,
                logger=logger,
                use_wandb=args.use_wandb)
    trainer.fit(
        train_dataloader=train_dataloader, eval_dataloader=eval_dataloader, logger=logger, use_wandb=args.use_wandb
    )

    # save model checkpoint after fitting on only rank0
    strategy.save_pretrained(model, path=args.save_path, only_rank0=True, tokenizer=tokenizer)
    # save optimizer checkpoint on all ranks
    if args.need_optim_ckpt:
        strategy.save_optimizer(trainer.optimizer,
                                'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()),
                                only_rank0=False)
        strategy.save_optimizer(
            trainer.optimizer, "rm_optim_checkpoint_%d.pt" % (torch.cuda.current_device()), only_rank0=False
        )


if __name__ == '__main__':
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--strategy',
                        choices=['ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_zero2_cpu'],
                        default='colossalai_zero2')
    parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama', 'chatglm'], default='bloom')
    parser.add_argument('--tokenizer', type=str, default=None)
    parser.add_argument('--pretrain', type=str, default=None)
    parser.add_argument('--dataset', type=str, default=None)
    parser.add_argument('--max_datasets_size', type=int, default=None)
    parser.add_argument('--save_path', type=str, default='output')
    parser.add_argument('--need_optim_ckpt', type=bool, default=False)
    parser.add_argument('--max_epochs', type=int, default=3)
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--max_len', type=int, default=512)
    parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
    parser.add_argument('--log_interval', type=int, default=100, help="how many steps to log")
    parser.add_argument('--lr', type=float, default=5e-6)
    parser.add_argument('--accumulation_steps', type=int, default=8)
    parser.add_argument('--use_wandb', default=False, action='store_true')
    parser.add_argument('--grad_checkpoint', default=False, action='store_true')
    parser.add_argument(
        "--strategy",
        choices=["ddp", "colossalai_gemini", "colossalai_zero2", "colossalai_zero2_cpu"],
        default="colossalai_zero2",
    )
    parser.add_argument("--model", choices=["gpt2", "bloom", "opt", "llama", "chatglm"], default="bloom")
    parser.add_argument("--tokenizer", type=str, default=None)
    parser.add_argument("--pretrain", type=str, default=None)
    parser.add_argument("--dataset", type=str, default=None)
    parser.add_argument("--max_datasets_size", type=int, default=None)
    parser.add_argument("--save_path", type=str, default="output")
    parser.add_argument("--need_optim_ckpt", type=bool, default=False)
    parser.add_argument("--max_epochs", type=int, default=3)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--max_len", type=int, default=512)
    parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
    parser.add_argument("--log_interval", type=int, default=100, help="how many steps to log")
    parser.add_argument("--lr", type=float, default=5e-6)
    parser.add_argument("--accumulation_steps", type=int, default=8)
    parser.add_argument("--use_wandb", default=False, action="store_true")
    parser.add_argument("--grad_checkpoint", default=False, action="store_true")
    args = parser.parse_args()
    train(args)