diff --git a/applications/ChatGPT/chatgpt/trainer/rm.py b/applications/ChatGPT/chatgpt/trainer/rm.py
index b76ae5373..f6639edcb 100644
--- a/applications/ChatGPT/chatgpt/trainer/rm.py
+++ b/applications/ChatGPT/chatgpt/trainer/rm.py
@@ -3,10 +3,13 @@ from abc import ABC
 import loralib as lora
 from chatgpt.dataset import RewardDataset
 from chatgpt.nn import PairWiseLoss
-from torch.optim import Adam
+from torch.optim import Adam, Optimizer
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 
+from .strategies import Strategy
+from .utils import is_rank_0
+
 
 class RewardModelTrainer(ABC):
     """
@@ -14,32 +17,41 @@ class RewardModelTrainer(ABC):
 
     Args:
         model (torch.nn.Module): the model to train
+        strategy (Strategy): the strategy to use for training
+        optim(Optimizer): the optimizer to use for training
         train_dataset (RewardDataset): the dataset to use for training
         eval_dataset (RewardDataset): the dataset to use for evaluation
         batch_size (int, defaults to 1): the batch size while training
-        num_epochs (int, defaults to 2): the number of epochs to train
+        max_epochs (int, defaults to 2): the number of epochs to train
         optim_kwargs (dict, defaults to {'lr':1e-4}): the kwargs to use while initializing optimizer
     """
 
-    def __init__(self,
-                 model,
-                 train_dataset: RewardDataset,
-                 eval_dataset: RewardDataset,
-                 batch_size: int = 1,
-                 num_epochs: int = 2,
-                 optim_kwargs: dict = {'lr': 1e-4}) -> None:
+    def __init__(
+        self,
+        model,
+        strategy: Strategy,
+        optim: Optimizer,
+        train_dataset: RewardDataset,
+        eval_dataset: RewardDataset,
+        batch_size: int = 1,
+        max_epochs: int = 2,
+    ) -> None:
         super().__init__()
-        self.model = model
+        self.strategy = strategy
+        self.epochs = max_epochs
         self.train_dataloader = DataLoader(train_dataset, batch_size=batch_size)
         self.eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size)
+
+        self.model = strategy.setup_model(model)
         self.loss_fn = PairWiseLoss()
-        self.optimizer = Adam(self.model.parameters(), **optim_kwargs)
-        self.epochs = num_epochs
+        self.optimizer = strategy.setup_optimizer(optim, self.model)
 
     def fit(self, use_lora):
-        epoch_bar = tqdm(range(self.epochs), desc='Train epoch')
+        epoch_bar = tqdm(range(self.epochs), desc='Train epoch', disable=not is_rank_0())
         for epoch in range(self.epochs):
-            step_bar = tqdm(range(self.train_dataloader.__len__()), desc='Train step of epoch %d' % epoch)
+            step_bar = tqdm(range(self.train_dataloader.__len__()),
+                            desc='Train step of epoch %d' % epoch,
+                            disable=not is_rank_0())
             # train
             if use_lora > 0:
                 print("Using Lora")
@@ -54,8 +66,8 @@ class RewardModelTrainer(ABC):
                 chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
                 reject_reward = self.model(reject_ids, attention_mask=r_mask)
                 loss = self.loss_fn(chosen_reward, reject_reward)
-                loss.backward()
-                self.optimizer.step()
+                self.strategy.backward(loss, self.model, self.optimizer)
+                self.strategy.optimizer_step(self.optimizer)
                 self.optimizer.zero_grad()
                 step_bar.update()
                 step_bar.set_postfix({'loss': loss.item()})
diff --git a/applications/ChatGPT/examples/train_dummy.sh b/applications/ChatGPT/examples/train_dummy.sh
index 559d338ee..595da573e 100755
--- a/applications/ChatGPT/examples/train_dummy.sh
+++ b/applications/ChatGPT/examples/train_dummy.sh
@@ -13,6 +13,6 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
     echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
 }
 
-set_n_least_used_CUDA_VISIBLE_DEVICES 1
+set_n_least_used_CUDA_VISIBLE_DEVICES 2
 
-python train_dummy.py --model bloom --pretrain '/data2/users/lczht/bloom-560m' --lora_rank 16
+torchrun --standalone --nproc_per_node=2 train_dummy.py --strategy colossalai_zero2
diff --git a/applications/ChatGPT/examples/train_prompts.sh b/applications/ChatGPT/examples/train_prompts.sh
index 0b82d3f1c..db73ac8e8 100755
--- a/applications/ChatGPT/examples/train_prompts.sh
+++ b/applications/ChatGPT/examples/train_prompts.sh
@@ -13,6 +13,6 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
    echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
 }
 
-set_n_least_used_CUDA_VISIBLE_DEVICES 1
+set_n_least_used_CUDA_VISIBLE_DEVICES 2
 
-python train_prompts.py prompts.csv --pretrain '/data2/users/lczht/bloom-560m' --lora_rank 16
+torchrun --standalone --nproc_per_node=2 train_prompts.py prompts.csv --strategy colossalai_zero2
diff --git a/applications/ChatGPT/examples/train_reward_model.py b/applications/ChatGPT/examples/train_reward_model.py
index fd78a2ac6..47688325e 100644
--- a/applications/ChatGPT/examples/train_reward_model.py
+++ b/applications/ChatGPT/examples/train_reward_model.py
@@ -5,33 +5,55 @@ import torch
 from chatgpt.dataset import RewardDataset
 from chatgpt.nn import BLOOMRM
 from chatgpt.trainer import RewardModelTrainer
+from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
 from datasets import load_dataset
+from torch.optim import Adam
 from transformers import BloomTokenizerFast
 
+from colossalai.nn.optimizer import HybridAdam
+
 
 def train(args):
+    # configure strategy
+    if args.strategy == 'naive':
+        strategy = NaiveStrategy()
+    elif args.strategy == 'ddp':
+        strategy = DDPStrategy()
+    elif args.strategy == 'colossalai_gemini':
+        strategy = ColossalAIStrategy(stage=3, placement_policy='cuda')
+    elif args.strategy == 'colossalai_zero2':
+        strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
+    else:
+        raise ValueError(f'Unsupported strategy "{args.strategy}"')
+
+    # configure model
     tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
     tokenizer.pad_token = tokenizer.eos_token
-    model = BLOOMRM(pretrained=args.pretrain)
-
-    model.cuda()
-
+    model = BLOOMRM(pretrained=args.pretrain).cuda()
     max_len = 1024
 
+    # configure optimizer
+    if args.strategy.startswith('colossalai'):
+        optim = HybridAdam(model.parameters(), lr=5e-5)
+    else:
+        optim = Adam(model.parameters(), lr=5e-5)
+
     # prepare for data and dataset
     data = load_dataset(args.dataset)
-    train_data = data["train"]
-    eval_data = data['test']
+    train_data = data["train"].select(range(100))
+    eval_data = data['test'].select(range(5))
     train_dataset = RewardDataset(train_data, tokenizer, max_len)
     eval_dataset = RewardDataset(eval_data, tokenizer, max_len)
 
     # batch_size here is expected to be C(k,2), k means # response of each prompt
     # be limited with the format of dataset 'Dahoas/rm-static', we'd better use batch_size as 1
     trainer = RewardModelTrainer(model=model,
+                                 strategy=strategy,
+                                 optim=optim,
                                  train_dataset=train_dataset,
                                  eval_dataset=eval_dataset,
                                  batch_size=args.batch_size,
-                                 num_epochs=args.max_epochs)
+                                 max_epochs=args.max_epochs)
 
     trainer.fit(use_lora=args.lora_rank)
 
@@ -43,6 +65,9 @@ def train(args):
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
+    parser.add_argument('--strategy',
+                        choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
+                        default='naive')
     parser.add_argument('--pretrain', type=str, default=None)
     parser.add_argument('--dataset', type=str, default='Dahoas/rm-static')
    parser.add_argument('--save_path', type=str, default='rm_ckpt.pth')
diff --git a/applications/ChatGPT/examples/train_rm.sh b/applications/ChatGPT/examples/train_rm.sh
index bf46d7e43..ed91deee2 100755
--- a/applications/ChatGPT/examples/train_rm.sh
+++ b/applications/ChatGPT/examples/train_rm.sh
@@ -13,6 +13,6 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
    echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
 }
 
-set_n_least_used_CUDA_VISIBLE_DEVICES 1
+set_n_least_used_CUDA_VISIBLE_DEVICES 2
 
-python train_reward_model.py --pretrain '/data2/users/lczht/bloom-560m' --lora_rank 16
+torchrun --standalone --nproc_per_node=2 train_reward_model.py --pretrain '/data2/users/lczht/bloom-560m' --strategy colossalai_zero2
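
Note on the Strategy interface used above: the trainer now delegates to strategy.setup_model, strategy.setup_optimizer, strategy.backward, and strategy.optimizer_step, and gates its tqdm bars on is_rank_0(). The concrete NaiveStrategy, DDPStrategy, and ColossalAIStrategy classes live in chatgpt/trainer/strategies and are not shown in this diff; the sketch below is only an illustration of one plausible shape for those hooks, inferred from how the trainer calls them, and is not the repository implementation.

# Illustrative sketch only -- NOT part of this patch. Shows a minimal base class and a
# single-process NaiveStrategy that would satisfy the four hooks RewardModelTrainer uses.
from abc import ABC, abstractmethod

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.optim import Optimizer


def is_rank_0() -> bool:
    # Rank 0 (or a non-distributed run) is the only process that should drive tqdm output.
    return not dist.is_initialized() or dist.get_rank() == 0


class Strategy(ABC):
    """Abstracts how the model/optimizer are wrapped and how backward/step are executed."""

    @abstractmethod
    def setup_model(self, model: nn.Module) -> nn.Module:
        ...

    @abstractmethod
    def setup_optimizer(self, optim: Optimizer, model: nn.Module) -> Optimizer:
        ...

    @abstractmethod
    def backward(self, loss: torch.Tensor, model: nn.Module, optim: Optimizer) -> None:
        ...

    @abstractmethod
    def optimizer_step(self, optim: Optimizer) -> None:
        ...


class NaiveStrategy(Strategy):
    """Single-process fallback: no wrapping, plain autograd backward and optimizer step."""

    def setup_model(self, model: nn.Module) -> nn.Module:
        return model

    def setup_optimizer(self, optim: Optimizer, model: nn.Module) -> Optimizer:
        return optim

    def backward(self, loss: torch.Tensor, model: nn.Module, optim: Optimizer) -> None:
        loss.backward()

    def optimizer_step(self, optim: Optimizer) -> None:
        optim.step()

The point of routing backward() and optimizer_step() through the strategy is that DDP and ZeRO-style backends need to hook those exact moments (gradient synchronization, sharded optimizer steps), while the trainer code stays identical across the naive, ddp, and colossalai_* choices exposed by the new --strategy flag.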