From 62f4e2eb0760ac8bfe28834b061dbc2bda93ade9 Mon Sep 17 00:00:00 2001
From: YY Lin
Date: Thu, 6 Apr 2023 11:54:52 +0800
Subject: [PATCH] [Chat] Add Peft support & fix the ptx bug (#3433)

* Update ppo.py

Fix the bug of fetching the wrong batch data

* Add peft model support in SFT and Prompts training

Peft model support is added in stage-1 and stage-3, so the trained artifacts are
only a small set of LoRA weights instead of the whole set of model files.

* Delete test_prompts.txt

* Delete test_pretrained.txt

* Move the peft code to a community folder.

* Move the demo sft to community

* Delete dirty files

* Add instructions to install peft from source

* Remove Chinese comments

* Remove the Chinese comments
---
 applications/Chat/coati/trainer/ppo.py         |   7 +-
 .../Chat/examples/community/EasyPeftModel.md   |  24 ++
 .../Chat/examples/community/easy_dataset.py    | 242 ++++++++++++++++++
 .../Chat/examples/community/easy_models.py     |  97 +++++++
 .../examples/community/train_peft_prompts.py   | 227 ++++++++++++++++
 .../Chat/examples/community/train_peft_sft.py  | 187 ++++++++++++++
 6 files changed, 781 insertions(+), 3 deletions(-)
 create mode 100644 applications/Chat/examples/community/EasyPeftModel.md
 create mode 100644 applications/Chat/examples/community/easy_dataset.py
 create mode 100644 applications/Chat/examples/community/easy_models.py
 create mode 100644 applications/Chat/examples/community/train_peft_prompts.py
 create mode 100644 applications/Chat/examples/community/train_peft_sft.py

diff --git a/applications/Chat/coati/trainer/ppo.py b/applications/Chat/coati/trainer/ppo.py
index 5c7c71d20..2b0cfcc16 100644
--- a/applications/Chat/coati/trainer/ppo.py
+++ b/applications/Chat/coati/trainer/ppo.py
@@ -92,9 +92,10 @@ class PPOTrainer(Trainer):

         # ptx loss
         if self.ptx_coef != 0:
-            ptx = next(iter(self.pretrain_dataloader))['input_ids'].to(torch.cuda.current_device())
-            label = next(iter(self.pretrain_dataloader))['labels'].to(torch.cuda.current_device())[:, 1:]
-            attention_mask = next(iter(self.pretrain_dataloader))['attention_mask'].to(torch.cuda.current_device())
+            batch = next(iter(self.pretrain_dataloader))
+            ptx = batch['input_ids'].to(torch.cuda.current_device())
+            label = batch['labels'].to(torch.cuda.current_device())[:, 1:]
+            attention_mask = batch['attention_mask'].to(torch.cuda.current_device())
             ptx_log_probs = self.actor.get_base_model()(ptx, attention_mask=attention_mask)['logits'][..., :-1, :]
             ptx_loss = self.ptx_loss_fn(ptx_log_probs.view(-1, ptx_log_probs.size(-1)), label.view(-1))
             actor_loss = ptx_loss * self.ptx_coef + actor_loss * (1 - self.ptx_coef)
diff --git a/applications/Chat/examples/community/EasyPeftModel.md b/applications/Chat/examples/community/EasyPeftModel.md
new file mode 100644
index 000000000..16c4af76b
--- /dev/null
+++ b/applications/Chat/examples/community/EasyPeftModel.md
@@ -0,0 +1,24 @@
+# Add Peft support for SFT and Prompts model training
+
+The original implementation adopts loralib and merges the LoRA layers into the final model. The Hugging Face peft library is a better LoRA implementation and can be trained and distributed more easily.
+
+Since the reward model is relatively small, I just keep it as in the original implementation. I suggest training the full model to get a proper reward/critic model.
+
+# Preliminary installation
+Since the current PyPI peft package (0.2) has some bugs, please install the peft package from source.
+```
+git clone https://github.com/huggingface/peft
+cd peft
+pip install .
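+# optionally, sanity-check the source install; peft exposes __version__, so this
+# should print something newer than the buggy 0.2 PyPI release
+python -c "import peft; print(peft.__version__)"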
+```
+
+# Usage
+For SFT training, just run train_peft_sft.py.
+
+Its arguments are almost identical to those of train_sft.py, except that it adds an eval_dataset argument for an optional evaluation data file. The data file is just a plain text file; please check the format in easy_dataset.py.
+
+For stage-3 RLHF training, run train_peft_prompts.py.
+Its arguments are almost identical to those of train_prompts.py. The only difference is that I use text files for the prompt and pretraining data. The models are included in easy_models.py. Currently only bloom models are tested, but technically gpt2/opt/llama should be supported.
+
+# Data format
+Please refer to the formats in test_sft.txt, test_prompts.txt, test_pretrained.txt.
\ No newline at end of file
diff --git a/applications/Chat/examples/community/easy_dataset.py b/applications/Chat/examples/community/easy_dataset.py
new file mode 100644
index 000000000..15dd9a3cc
--- /dev/null
+++ b/applications/Chat/examples/community/easy_dataset.py
@@ -0,0 +1,242 @@
+import copy
+from typing import Dict, Sequence
+from datasets import load_dataset
+from torch.utils.data import Dataset
+from transformers import AutoTokenizer
+import torch
+from tqdm import tqdm
+import json
+
+IGNORE_INDEX = -100
+
+
+def _tokenize_fn(strings: Sequence[str], tokenizer: AutoTokenizer, max_length: int = 512) -> Dict:
+    """Tokenize a list of strings."""
+    tokenized_list = [
+        tokenizer(
+            text,
+            return_tensors="pt",
+            padding="longest",
+            max_length=max_length,
+            truncation=True,
+        ) for text in strings
+    ]
+    input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
+    input_ids_lens = labels_lens = [
+        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
+    ]
+    return dict(
+        input_ids=input_ids,
+        labels=labels,
+        input_ids_lens=input_ids_lens,
+        labels_lens=labels_lens,
+    )
+
+
+def preprocess(
+    sources: Sequence[str],
+    targets: Sequence[str],
+    tokenizer: AutoTokenizer,
+    max_length: int = 512
+) -> Dict:
+    """Preprocess the data by tokenizing."""
+    examples = [s + t for s, t in zip(sources, targets)]
+    examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer, max_length) for strings in (examples, sources)]
+    input_ids = examples_tokenized["input_ids"]
+    labels = copy.deepcopy(input_ids)
+    for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
+        label[:source_len] = IGNORE_INDEX
+    return dict(input_ids=input_ids, labels=labels)
+
+
+class EasySupervisedDataset(Dataset):
+
+    def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 512) -> None:
+        super(EasySupervisedDataset, self).__init__()
+        with open(data_file, "r", encoding="UTF-8") as f:
+            all_lines = f.readlines()
+        # split each line into source and target: the source is everything up to and
+        # including "回答:" ("Answer:"), the target is everything after it
+        sources, targets = [], []
+        for line in all_lines:
+            if "回答:" in line:
+                sep_index = line.index("回答:")
+                sources.append(line[:sep_index + 3])
+                targets.append(line[sep_index + 3:] + tokenizer.eos_token)
+            else:
+                sources.append(line)
+                targets.append("" + tokenizer.eos_token)
+        data_dict = preprocess(sources, targets, tokenizer, max_length)
+
+        self.input_ids = data_dict["input_ids"]
+        self.labels = data_dict["labels"]
+        self.data_file = data_file
+
+    def __len__(self):
+        return len(self.input_ids)
+
+    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
+        return dict(input_ids=self.input_ids[i], labels=self.labels[i])
+
+    def __repr__(self):
+        return
f"LawSupervisedDataset(data_file={self.data_file}, input_ids_len={len(self.input_ids)}, labels_len={len(self.labels)})" + + def __str__(self): + return f"LawSupervisedDataset(data_file={self.data_file}, input_ids_len={len(self.input_ids)}, labels_len={len(self.labels)})" + +class EasyPromptsDataset(Dataset): + def __init__(self,data_file :str, tokenizer :AutoTokenizer, max_length :int = 96) -> None: + super(EasyPromptsDataset,self).__init__() + with open(data_file,"r",encoding="UTF-8") as f: + all_lines = f.readlines() + all_lines = [line if "回答:" not in line else line[:line.index("回答:")+3] for line in all_lines] + self.prompts = [ + tokenizer(line, + return_tensors='pt', + max_length=max_length, + padding='max_length', + truncation=True)['input_ids'].to(torch.cuda.current_device()).squeeze(0) + for line in tqdm(all_lines) + ] + self.data_file = data_file + def __len__(self): + return len(self.prompts) + + def __getitem__(self, idx): + return self.prompts[idx] + + def __repr__(self): + return f"LawPromptsDataset(data_file={self.data_file}, prompts_len={len(self.prompts)})" + + def __str__(self): + return f"LawPromptsDataset(data_file={self.data_file}, prompts_len={len(self.prompts)})" + + +class EasyRewardDataset(Dataset): + def __init__(self,train_file :str,tokenizer :AutoTokenizer, special_token = None,max_length = 512) -> None: + super(EasyRewardDataset,self).__init__() + self.chosen = [] + self.reject = [] + if special_token is None: + self.end_token = tokenizer.eos_token + else: + self.end_token = special_token + print(self.end_token) + #read all lines in the train_file to a list + with open(train_file,"r",encoding="UTF-8") as f: + all_lines = f.readlines() + for line in tqdm(all_lines): + data = json.loads(line) + prompt = "提问:"+data['prompt']+" 回答:" + + chosen = prompt + data['chosen'] + self.end_token + chosen_token = tokenizer(chosen, + max_length=max_length, + padding="max_length", + truncation=True, + return_tensors="pt") + self.chosen.append({ + "input_ids": chosen_token['input_ids'], + "attention_mask": chosen_token['attention_mask'] + }) + + reject = prompt + data['rejected'] + self.end_token + reject_token = tokenizer(reject, + max_length=max_length, + padding="max_length", + truncation=True, + return_tensors="pt") + self.reject.append({ + "input_ids": reject_token['input_ids'], + "attention_mask": reject_token['attention_mask'] + }) + + def __len__(self): + length = len(self.chosen) + return length + + def __getitem__(self, idx): + return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][ + "input_ids"], self.reject[idx]["attention_mask"] + + #python representation of the object and the string representation of the object + def __repr__(self): + return f"LawRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})" + + def __str__(self): + return f"LawRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})" + +''' +Easy SFT just accept a text file which can be read line by line. However the datasest will group texts together to max_length so LLM will learn the texts meaning better. +If individual lines are not related, just set is_group_texts to False. 
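+
+A minimal usage sketch (the file name and model below are illustrative placeholders):
+
+    tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
+    dataset = EasySFTDataset("sft_data.txt", tokenizer, max_length=512, is_group_texts=True)
+    sample = dataset[0]    # dict with input_ids, labels and attention_mask, each of length max_length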
+''' +class EasySFTDataset(Dataset): + + def __init__(self,data_file :str,tokenizer :AutoTokenizer,max_length = 512,is_group_texts = True) -> None: + super().__init__() + #read the data_file line by line + with open(data_file,"r",encoding="UTF-8") as f: + #encode the text data line by line and put raw python list input_ids only to raw_input_ids list + raw_input_ids = [] + for line in f: + encoded_ids = tokenizer.encode(line) + #if the encoded_ids is longer than max_length, then split it into several parts + if len(encoded_ids) > max_length: + for i in range(0,len(encoded_ids),max_length): + raw_input_ids.append(encoded_ids[i:i+max_length]) + else: + raw_input_ids.append(encoded_ids) + + grouped_inpup_ids = [] + current_input_ids = [] + attention_mask = [] + if tokenizer.pad_token_id is None: + tokenizer.pad_token_id = tokenizer.eos_token_id + if is_group_texts: + for input_ids in raw_input_ids: + if len(current_input_ids) + len(input_ids) > max_length: + #pad the current_input_ids to max_length with tokenizer.pad_token_id + padded_length = max_length - len(current_input_ids) + current_input_ids.extend([tokenizer.pad_token_id] * padded_length) + grouped_inpup_ids.append(torch.tensor(current_input_ids,dtype=torch.long)) + attention_mask.append(torch.tensor([1] * (max_length - padded_length) + [0] * padded_length,dtype=torch.long)) + current_input_ids = [] + else: + current_input_ids.extend(input_ids) + if len(current_input_ids) > 0: + padded_length = max_length - len(current_input_ids) + current_input_ids.extend([tokenizer.pad_token_id] * padded_length) + grouped_inpup_ids.append(torch.tensor(current_input_ids,dtype=torch.long)) + attention_mask.append(torch.tensor([1] * (max_length - padded_length) + [0] * padded_length,dtype=torch.long)) + else: + #just append the raw_input_ids to max_length + for input_ids in raw_input_ids: + padded_length = max_length - len(input_ids) + input_ids.extend([tokenizer.pad_token_id] * padded_length) + attention_mask.append(torch.tensor([1] * (max_length - padded_length) + [0] * padded_length,dtype=torch.long)) + grouped_inpup_ids.append(torch.tensor(input_ids,dtype=torch.long)) + self.input_ids = grouped_inpup_ids + self.labels = copy.deepcopy(self.input_ids) + self.file_name = data_file + self.attention_mask = attention_mask + + def __len__(self): + return len(self.input_ids) + + #get item from dataset + def __getitem__(self,idx): + return dict(input_ids=self.input_ids[idx],labels=self.labels[idx],attention_mask=self.attention_mask[idx]) + + #generate the dataset description to be printed by print in python + def __repr__(self): + return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})" + + #generate the dataset description to be printed by print in python + def __str__(self): + return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})" + + + + + \ No newline at end of file diff --git a/applications/Chat/examples/community/easy_models.py b/applications/Chat/examples/community/easy_models.py new file mode 100644 index 000000000..080fc1802 --- /dev/null +++ b/applications/Chat/examples/community/easy_models.py @@ -0,0 +1,97 @@ +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.modules import Module + +from coati.models.generation import generate +from coati.models.utils import log_probs_from_logits,masked_mean +from transformers import BloomConfig,BloomForCausalLM +from peft import PeftModel + +class Actor(Module): + """ + Actor model base class. 
+ + Args: + model (nn.Module): Actor Model. + """ + + def __init__(self, model: nn.Module) -> None: + super().__init__() + self.model = model + + @torch.no_grad() + def generate( + self, + input_ids: torch.Tensor, + return_action_mask: bool = True, + **kwargs + ) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]: + sequences = generate(self.model, input_ids, **kwargs) + attention_mask = None + pad_token_id = kwargs.get('pad_token_id', None) + if pad_token_id is not None: + attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device) + if not return_action_mask: + return sequences, attention_mask, None + input_len = input_ids.size(1) + eos_token_id = kwargs.get('eos_token_id', None) + if eos_token_id is None: + action_mask = torch.ones_like(sequences, dtype=torch.bool) + else: + # left padding may be applied, only mask action + action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0 + action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input + action_mask[:, :input_len] = False + action_mask = action_mask[:, 1:] + return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len):] + + def forward(self, + sequences: torch.LongTensor, + num_actions: int, + attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + """Returns action log probs + """ + output = self.model(sequences, attention_mask=attention_mask) + logits = output['logits'] + log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:]) + return log_probs[:, -num_actions:] + + def get_base_model(self): + return self.model + + +class BLOOMActor(Actor): + """ + BLOOM Actor model. + + Args: + pretrained (str): Pretrained model name or path. + config (BloomConfig): Model config. + checkpoint (bool): Enable gradient checkpointing. + lora_rank (int): LoRA rank. + lora_train_bias (str): LoRA bias training mode. 
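+        lora_path (str): Path to a saved peft LoRA adapter; when given, the base model
+            is wrapped with PeftModel.from_pretrained before training.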
+ """ + + def __init__(self, + pretrained: str = None, + config: Optional[BloomConfig] = None, + checkpoint: bool = False, + lora_path: str = None) -> None: + if pretrained is not None: + model = BloomForCausalLM.from_pretrained(pretrained) + elif config is not None: + model = BloomForCausalLM(config) + else: + model = BloomForCausalLM(BloomConfig()) + if lora_path is not None: + model = PeftModel.from_pretrained(model,lora_path) + if checkpoint: + model.gradient_checkpointing_enable() + super().__init__(model) + + def print_trainable_parameters(self): + self.get_base_model().print_trainable_parameters() + diff --git a/applications/Chat/examples/community/train_peft_prompts.py b/applications/Chat/examples/community/train_peft_prompts.py new file mode 100644 index 000000000..b9394c9e4 --- /dev/null +++ b/applications/Chat/examples/community/train_peft_prompts.py @@ -0,0 +1,227 @@ +import argparse + +import pandas as pd +import torch +import torch.distributed as dist +from coati.dataset import DataCollatorForSupervisedDataset, PromptDataset, SupervisedDataset +from coati.models.bloom import BLOOMRM, BLOOMCritic +from easy_models import BLOOMActor +from coati.models.gpt import GPTRM, GPTActor, GPTCritic +from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM +from coati.models.opt import OPTRM, OPTActor, OPTCritic +from coati.trainer import PPOTrainer +from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy +from coati.utils import prepare_llama_tokenizer_and_embedding +from torch.optim import Adam +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer + +from colossalai.nn.optimizer import HybridAdam +from peft import PeftModel +from easy_dataset import EasyPromptsDataset,EasySupervisedDataset + +def main(args): + # configure strategy + if args.strategy == 'naive': + strategy = NaiveStrategy() + elif args.strategy == 'ddp': + strategy = DDPStrategy() + elif args.strategy == 'colossalai_gemini': + strategy = ColossalAIStrategy(stage=3, placement_policy='cpu', initial_scale=2**5) + elif args.strategy == 'colossalai_zero2': + strategy = ColossalAIStrategy(stage=2, placement_policy='cpu') + else: + raise ValueError(f'Unsupported strategy "{args.strategy}"') + + if args.rm_path is not None: + state_dict = torch.load(args.rm_path, map_location='cpu') + + # configure model + if args.model == 'bloom': + # initial_model = BLOOMActor(pretrained=args.pretrain) + print('Using peft lora to load Bloom model as inital_model') + initial_model = BLOOMActor(pretrained=args.pretrain,lora_path=args.sft_lora_path) + print('Using peft lora to load Bloom model as initial_model (Done)') + else: + raise ValueError(f'Unsupported actor model "{args.model}"') + + if args.rm_model == None: + rm_model_name = args.model + else: + rm_model_name = args.rm_model + + if rm_model_name == 'gpt2': + reward_model = GPTRM(pretrained=args.rm_pretrain) + elif rm_model_name == 'bloom': + print("load bloom reward model ",args.rm_pretrain) + reward_model = BLOOMRM(pretrained=args.rm_pretrain) + elif rm_model_name == 'opt': + reward_model = OPTRM(pretrained=args.rm_pretrain) + elif rm_model_name == 'llama': + reward_model = LlamaRM(pretrained=args.rm_pretrain) + else: + raise ValueError(f'Unsupported reward model "{rm_model_name}"') + + if args.rm_path is not None: + print('Loading reward model from', args.rm_path) + reward_model.load_state_dict(state_dict) + + if 
args.strategy != 'colossalai_gemini': + initial_model.to(torch.float16).to(torch.cuda.current_device()) + reward_model.to(torch.float16).to(torch.cuda.current_device()) + + with strategy.model_init_context(): + if args.model == 'bloom': + # actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank) + print('Using peft lora to load Bloom model as Actor') + actor = BLOOMActor(pretrained=args.pretrain,lora_path=args.sft_lora_path) + print('Using peft lora to load Bloom model as Actor (Done)') + else: + raise ValueError(f'Unsupported actor model "{args.model}"') + + if rm_model_name == 'gpt2': + critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) + elif rm_model_name == 'bloom': + print("load bloom critic ",args.rm_pretrain," lora_rank ",args.lora_rank," use_action_mask ",True) + critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) + print("load bloom critic (Done) ") + elif rm_model_name == 'opt': + critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) + elif rm_model_name == 'llama': + critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) + else: + raise ValueError(f'Unsupported reward model "{rm_model_name}"') + + if args.rm_path is not None: + print('Loading reward model from', args.rm_path) + critic.load_state_dict(state_dict) + del state_dict + + if args.strategy != 'colossalai_gemini': + critic.to(torch.float16).to(torch.cuda.current_device()) + actor.to(torch.float16).to(torch.cuda.current_device()) + + # configure optimizer + if args.strategy.startswith('colossalai'): + actor_optim = HybridAdam(actor.parameters(), lr=1e-7) + critic_optim = HybridAdam(critic.parameters(), lr=1e-7) + else: + actor_optim = Adam(actor.parameters(), lr=1e-7) + critic_optim = Adam(critic.parameters(), lr=1e-7) + + # configure tokenizer + if args.model == 'gpt2': + tokenizer = GPT2Tokenizer.from_pretrained(args.rm_pretrain) + elif args.model == 'bloom': + tokenizer = BloomTokenizerFast.from_pretrained(args.rm_pretrain) + elif args.model == 'opt': + tokenizer = AutoTokenizer.from_pretrained(args.rm_pretrain) + elif args.model == 'llama': + tokenizer = LlamaTokenizer.from_pretrained(args.pretrain) + tokenizer.eos_token = '<\s>' + else: + raise ValueError(f'Unsupported model "{args.model}"') + + if args.model == 'llama': + tokenizer = prepare_llama_tokenizer_and_embedding(tokenizer, actor) + else: + tokenizer.pad_token = tokenizer.eos_token + + data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) + + prompt_dataset = EasyPromptsDataset(args.prompt_path,tokenizer) + if dist.is_initialized() and dist.get_world_size() > 1: + prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True) + else: + prompt_sampler = None + prompt_dataloader = DataLoader(prompt_dataset, + shuffle=(prompt_sampler is None), + sampler=prompt_sampler, + batch_size=args.train_batch_size) + + pretrain_dataset = EasySupervisedDataset(args.pretrain_dataset, tokenizer) + if dist.is_initialized() and dist.get_world_size() > 1: + pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True) + else: + pretrain_sampler = None + pretrain_dataloader = DataLoader(pretrain_dataset, + shuffle=(pretrain_sampler is None), + sampler=pretrain_sampler, + batch_size=args.ptx_batch_size, + collate_fn=data_collator) + + def tokenize_fn(texts): + # MUST padding to max length to ensure inputs of all 
ranks have the same length + # Different length may lead to hang when using gemini, as different generation steps + batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True) + return {k: v.to(torch.cuda.current_device()) for k, v in batch.items()} + + (actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim)) + + # configure trainer + trainer = PPOTrainer( + strategy, + actor, + critic, + reward_model, + initial_model, + actor_optim, + critic_optim, + kl_coef=args.kl_coef, + ptx_coef=args.ptx_coef, + max_epochs=args.max_epochs, + train_batch_size=args.train_batch_size, + experience_batch_size=args.experience_batch_size, + tokenizer=tokenize_fn, + max_length=512, + do_sample=True, + temperature=1.0, + top_k=50, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + ) + + trainer.fit(prompt_dataloader=prompt_dataloader, + pretrain_dataloader=pretrain_dataloader, + num_episodes=args.num_episodes, + max_timesteps=args.max_timesteps, + update_timesteps=args.update_timesteps) + + # save model checkpoint after fitting + trainer.save_model(args.save_path, only_rank0=True, tokenizer=tokenizer) + # save optimizer checkpoint on all ranks + if args.need_optim_ckpt: + strategy.save_optimizer(actor_optim, + 'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()), + only_rank0=False) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--prompt_path', type=str, default=None, help='path to the prompt dataset') + parser.add_argument('--pretrain_dataset', type=str, default=None, help='path to the pretrained dataset') + parser.add_argument('--strategy', + choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], + default='naive', + help='strategy to use') + parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama']) + parser.add_argument('--pretrain', type=str, default=None) + parser.add_argument('--sft_lora_path', type=str, default=None) + parser.add_argument('--rm_model', default=None, choices=['gpt2', 'bloom', 'opt', 'llama']) + parser.add_argument('--rm_path', type=str, default=None) + parser.add_argument('--rm_pretrain', type=str, default=None) + parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts') + parser.add_argument('--need_optim_ckpt', type=bool, default=False) + parser.add_argument('--num_episodes', type=int, default=10) + parser.add_argument('--max_timesteps', type=int, default=10) + parser.add_argument('--update_timesteps', type=int, default=10) + parser.add_argument('--max_epochs', type=int, default=5) + parser.add_argument('--train_batch_size', type=int, default=2) + parser.add_argument('--ptx_batch_size', type=int, default=1) + parser.add_argument('--experience_batch_size', type=int, default=8) + parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") + parser.add_argument('--kl_coef', type=float, default=0.1) + parser.add_argument('--ptx_coef', type=float, default=0.9) + args = parser.parse_args() + main(args) diff --git a/applications/Chat/examples/community/train_peft_sft.py b/applications/Chat/examples/community/train_peft_sft.py new file mode 100644 index 000000000..65d901261 --- /dev/null +++ b/applications/Chat/examples/community/train_peft_sft.py @@ -0,0 +1,187 @@ +import argparse +import os + +import loralib as lora +import torch +import torch.distributed as dist +from coati.dataset import 
DataCollatorForSupervisedDataset, SFTDataset, SupervisedDataset +from coati.models.base import RewardModel +from coati.models.bloom import BLOOMLM +from coati.models.gpt import GPTLM +from coati.models.llama import LlamaLM +from coati.models.opt import OPTLM +from coati.trainer import SFTTrainer +from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy +from coati.utils import prepare_llama_tokenizer_and_embedding +from datasets import load_dataset +from torch.optim import Adam +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from transformers import AutoTokenizer, BloomTokenizerFast,AutoModelForCausalLM +from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer + +from colossalai.logging import get_dist_logger +from colossalai.nn.optimizer import HybridAdam +from colossalai.tensor import ColoParameter + +from torch.utils.data.dataloader import default_collate +from peft import LoraConfig, TaskType,get_peft_model,PeftModel +from easy_dataset import EasyDataset + +def train(args): + # configure strategy + if args.strategy == 'naive': + strategy = NaiveStrategy() + elif args.strategy == 'ddp': + strategy = DDPStrategy() + elif args.strategy == 'colossalai_gemini': + strategy = ColossalAIStrategy(stage=3, placement_policy='cuda') + elif args.strategy == 'colossalai_zero2': + strategy = ColossalAIStrategy(stage=2, placement_policy='cuda') + else: + raise ValueError(f'Unsupported strategy "{args.strategy}"') + + # configure model + with strategy.model_init_context(): + print('Warning: currently only bloom is tested, gpt2,llama and opt are not tested') + model = AutoModelForCausalLM.from_pretrained(args.pretrain).to(torch.cuda.current_device()) + #if the args.save_path exists and args.save_path+'/adapter_config.json' exists, we'll load the adapter_config.json + if os.path.exists(args.save_path) and os.path.exists(args.save_path+'/adapter_config.json') \ + and os.path.exists(args.save_path+'/adapter_model.bin'): + print("loading from saved peft model ",args.save_path) + model = PeftModel.from_pretrained(model, args.save_path) + else: + #we'll use peft lora library to do the lora + lora_rank = args.lora_rank if args.lora_rank > 0 else 32 + #config lora with rank of lora_rank + lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM, inference_mode=False, r=lora_rank, lora_alpha=32, lora_dropout=0.1) + model = get_peft_model(model, lora_config) + model.print_trainable_parameters() + + + # configure tokenizer + if args.model == 'gpt2': + tokenizer = GPT2Tokenizer.from_pretrained('gpt2') + tokenizer.pad_token = tokenizer.eos_token + elif args.model == 'bloom': + tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain) + tokenizer.pad_token = tokenizer.eos_token + elif args.model == 'opt': + tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") + elif args.model == 'llama': + tokenizer = AutoTokenizer.from_pretrained( + args.pretrain, + padding_side="right", + use_fast=False, + ) + tokenizer.eos_token = '<\s>' + else: + raise ValueError(f'Unsupported model "{args.model}"') + tokenizer.pad_token = tokenizer.eos_token + if args.model == 'llama': + tokenizer = prepare_llama_tokenizer_and_embedding(tokenizer, model) + + if args.strategy == 'colossalai_gemini': + # this is a hack to deal with the resized embedding + # to make sure all parameters are ColoParameter for Colossal-AI Gemini Compatiblity + for name, param in model.named_parameters(): + if not isinstance(param, ColoParameter): + sub_module_name 
= '.'.join(name.split('.')[:-1]) + weight_name = name.split('.')[-1] + sub_module = model.get_submodule(sub_module_name) + setattr(sub_module, weight_name, ColoParameter(param)) + else: + tokenizer.pad_token = tokenizer.eos_token + + # configure optimizer + if args.strategy.startswith('colossalai'): + optim = HybridAdam(model.parameters(), lr=args.lr, clipping_norm=1.0) + else: + optim = Adam(model.parameters(), lr=args.lr) + + logger = get_dist_logger() + logger.set_level('WARNING') + + # configure dataset + law_dataset = EasyDataset(args.dataset,tokenizer=tokenizer,is_group_texts=not args.is_short_text) + train_dataset = law_dataset + print(train_dataset) + eval_dataset = None + if args.eval_dataset is not None: + eval_dataset = EasyDataset(args.eval_dataset,tokenizer=tokenizer,is_group_texts=not args.is_short_text) + data_collator = default_collate + if dist.is_initialized() and dist.get_world_size() > 1: + train_sampler = DistributedSampler(train_dataset, + shuffle=True, + seed=42, + drop_last=True, + rank=dist.get_rank(), + num_replicas=dist.get_world_size()) + if eval_dataset is not None: + eval_sampler = DistributedSampler(eval_dataset, + shuffle=False, + seed=42, + drop_last=False, + rank=dist.get_rank(), + num_replicas=dist.get_world_size()) + else: + train_sampler = None + eval_sampler = None + + train_dataloader = DataLoader(train_dataset, + shuffle=(train_sampler is None), + sampler=train_sampler, + batch_size=args.batch_size, + collate_fn=data_collator, + pin_memory=True) + if eval_dataset is not None: + eval_dataloader = DataLoader(eval_dataset, + shuffle=(eval_sampler is None), + sampler=eval_sampler, + batch_size=args.batch_size, + collate_fn=data_collator, + pin_memory=True) + else: + eval_dataloader = None + + trainer = SFTTrainer(model=model, + strategy=strategy, + optim=optim, + train_dataloader=train_dataloader, + eval_dataloader=eval_dataloader, + batch_size=args.batch_size, + max_epochs=args.max_epochs, + accimulation_steps=args.accimulation_steps) + + trainer.fit(logger=logger, log_interval=args.log_interval) + + # save model checkpoint after fitting on only rank0 + trainer.save_model(path=args.save_path, only_rank0=True, tokenizer=tokenizer) + # save optimizer checkpoint on all ranks + if args.need_optim_ckpt: + strategy.save_optimizer(trainer.optimizer, + 'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()), + only_rank0=False) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--strategy', + choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], + default='naive') + parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom') + parser.add_argument('--pretrain', type=str, default=None) + parser.add_argument('--dataset', type=str, default=None) + parser.add_argument('--eval_dataset', type=str, default=None) + parser.add_argument('--save_path', type=str, default='output') + parser.add_argument('--need_optim_ckpt', type=bool, default=False) + parser.add_argument('--max_epochs', type=int, default=3) + parser.add_argument('--batch_size', type=int, default=4) + parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") + parser.add_argument('--log_interval', type=int, default=100, help="how many steps to log") + parser.add_argument('--lr', type=float, default=5e-6) + parser.add_argument('--accimulation_steps', type=int, default=8) + parser.add_argument('--enable_peft_lora',action='store_true', default=False) + 
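+    # --is_short_text disables text grouping in easy_dataset.py: each input line is padded
+    # to max_length on its own instead of being packed together with neighbouring lines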
+    parser.add_argument('--is_short_text', action='store_true', default=False)
+    args = parser.parse_args()
+    train(args)
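
As a usage sketch only (the model name, data files, and adapter directory below are placeholders, not files shipped with this patch), the two new community scripts can be launched along these lines, reusing the stage-1 LoRA adapter for the stage-3 actor:

```
# stage-1 SFT with a peft LoRA adapter (the adapter is written to --save_path)
torchrun --standalone --nproc_per_node=1 train_peft_sft.py \
    --model bloom --strategy colossalai_zero2 \
    --pretrain bigscience/bloom-560m \
    --dataset sft_data.txt --eval_dataset eval_data.txt \
    --save_path peft_sft_adapter --max_epochs 3 --batch_size 4

# stage-3 PPO, loading the stage-1 adapter for the actor and the initial model
torchrun --standalone --nproc_per_node=1 train_peft_prompts.py \
    --model bloom --strategy colossalai_zero2 \
    --pretrain bigscience/bloom-560m --sft_lora_path peft_sft_adapter \
    --rm_model bloom --rm_pretrain bigscience/bloom-560m --rm_path rm_ckpt.pt \
    --prompt_path prompts.txt --pretrain_dataset pretrain_data.txt \
    --save_path actor_checkpoint_prompts
```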