diff --git a/applications/Chat/examples/train_prompts.py b/applications/Chat/examples/train_prompts.py index ff3cf78dc..6643796d7 100644 --- a/applications/Chat/examples/train_prompts.py +++ b/applications/Chat/examples/train_prompts.py @@ -1,19 +1,20 @@ import argparse + import pandas as pd import torch import torch.distributed as dist -from coati.models.bloom import BLOOMActor, BLOOMRM, BLOOMCritic -from coati.models.gpt import GPTActor, GPTRM, GPTCritic -from coati.models.opt import OPTActor, OPTRM, OPTCritic -from coati.models.llama import LlamaActor, LlamaRM, LlamaCritic +from coati.dataset import DataCollatorForSupervisedDataset, PromptDataset, SupervisedDataset +from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic +from coati.models.gpt import GPTRM, GPTActor, GPTCritic +from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM +from coati.models.opt import OPTRM, OPTActor, OPTCritic from coati.trainer import PPOTrainer from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy +from coati.utils import prepare_llama_tokenizer_and_embedding from torch.optim import Adam from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler -from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer, GPT2Tokenizer -from coati.dataset import SupervisedDataset, DataCollatorForSupervisedDataset, PromptDataset -from coati.utils import prepare_llama_tokenizer_and_embedding +from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer from colossalai.nn.optimizer import HybridAdam @@ -45,12 +46,12 @@ def main(args): initial_model = LlamaActor(pretrained=args.pretrain) else: raise ValueError(f'Unsupported actor model "{args.model}"') - + if args.rm_model == None: rm_model_name = args.model else: rm_model_name = args.rm_model - + if rm_model_name == 'gpt2': reward_model = GPTRM(pretrained=args.rm_pretrain) elif rm_model_name == 'bloom': @@ -61,15 +62,14 @@ def main(args): reward_model = LlamaRM(pretrained=args.rm_pretrain) else: raise ValueError(f'Unsupported reward model "{rm_model_name}"') - - + if args.rm_path is not None: reward_model.load_state_dict(state_dict) - + if args.strategy != 'colossalai_gemini': - initial_model.to(torch.float16).to(torch.cuda.current_device()) - reward_model.to(torch.float16).to(torch.cuda.current_device()) - + initial_model.to(torch.float16).to(torch.cuda.current_device()) + reward_model.to(torch.float16).to(torch.cuda.current_device()) + with strategy.model_init_context(): if args.model == 'gpt2': actor = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank) @@ -81,7 +81,7 @@ def main(args): actor = LlamaActor(pretrained=args.pretrain, lora_rank=args.lora_rank) else: raise ValueError(f'Unsupported actor model "{args.model}"') - + if rm_model_name == 'gpt2': critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) elif rm_model_name == 'bloom': @@ -92,11 +92,11 @@ def main(args): critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) else: raise ValueError(f'Unsupported reward model "{rm_model_name}"') - + if args.rm_path is not None: critic.load_state_dict(state_dict) del state_dict - + if args.strategy != 'colossalai_gemini': critic.to(torch.float16).to(torch.cuda.current_device()) actor.to(torch.float16).to(torch.cuda.current_device()) @@ -121,32 +121,38 @@ def main(args): tokenizer.eos_token = '<\s>' else: raise ValueError(f'Unsupported model "{args.model}"') - + if args.model == 'llama': tokenizer = prepare_llama_tokenizer_and_embedding(tokenizer, actor) else: tokenizer.pad_token = tokenizer.eos_token - + data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) - + prompt_dataset = PromptDataset(tokenizer=tokenizer, data_path=args.prompt_path, max_datasets_size=16384) if dist.is_initialized() and dist.get_world_size() > 1: prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True) - prompt_dataloader = DataLoader(prompt_dataset, shuffle=(prompt_sampler is None), sampler=prompt_sampler, batch_size=args.train_batch_size) - + prompt_dataloader = DataLoader(prompt_dataset, + shuffle=(prompt_sampler is None), + sampler=prompt_sampler, + batch_size=args.train_batch_size) + pretrain_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=args.pretrain_dataset, max_datasets_size=16384) if dist.is_initialized() and dist.get_world_size() > 1: pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True) - pretrain_dataloader = DataLoader(pretrain_dataset, shuffle=(pretrain_sampler is None), sampler=pretrain_sampler, batch_size=args.ptx_batch_size, collate_fn=data_collator) - + pretrain_dataloader = DataLoader(pretrain_dataset, + shuffle=(pretrain_sampler is None), + sampler=pretrain_sampler, + batch_size=args.ptx_batch_size, + collate_fn=data_collator) + def tokenize_fn(texts): # MUST padding to max length to ensure inputs of all ranks have the same length # Different length may lead to hang when using gemini, as different generation steps batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True) return {k: v.to(torch.cuda.current_device()) for k, v in batch.items()} - - (actor, actor_optim), (critic, critic_optim) = strategy.prepare( - (actor, actor_optim), (critic, critic_optim)) + + (actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim)) # configure trainer trainer = PPOTrainer( @@ -192,7 +198,8 @@ if __name__ == '__main__': parser.add_argument('--pretrain_dataset', type=str, default=None, help='path to the pretrained dataset') parser.add_argument('--strategy', choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], - default='naive', help='strategy to use') + default='naive', + help='strategy to use') parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama']) parser.add_argument('--pretrain', type=str, default=None) parser.add_argument('--rm_model', default=None, choices=['gpt2', 'bloom', 'opt', 'llama'])