Mirror of https://github.com/hpcaitech/ColossalAI.git
reconstruct chat trainer and fix training script (#3588)
Co-authored-by: Yuanchen Xu <yuanchen.xu00@gmail.com>
@@ -114,8 +114,10 @@ def main(args):
                           eos_token_id=tokenizer.eos_token_id,
                           callbacks=callbacks)

-    random_prompts = torch.randint(tokenizer.vocab_size, (1000, 64), device=torch.cuda.current_device())
-    trainer.fit(random_prompts,
+    random_prompts = torch.randint(tokenizer.vocab_size, (1000, 1, 64), device=torch.cuda.current_device())
+    random_attention_mask = torch.randint(1, (1000, 1, 64), device=torch.cuda.current_device()).to(torch.bool)
+    random_pretrain = [{'input_ids':random_prompts[i], 'labels':random_prompts[i], 'attention_mask':random_attention_mask[i]} for i in range(1000)]
+    trainer.fit(random_prompts, random_pretrain,
                 num_episodes=args.num_episodes,
                 max_timesteps=args.max_timesteps,
                 update_timesteps=args.update_timesteps)
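The dummy training script now feeds trainer.fit a small pretraining dataset alongside the random prompts, and the prompts gain an explicit per-sample batch dimension of 1. The sketch below reproduces the new tensors on CPU with stand-in sizes (vocab_size, num_samples, and seq_len are illustrative values chosen for the example, not taken from the script):

    import torch

    vocab_size = 50257             # stand-in for tokenizer.vocab_size (GPT-2's value)
    num_samples, seq_len = 8, 64   # the script uses 1000 samples; 8 keeps the check fast

    # Prompts are now shaped (N, 1, seq_len) instead of (N, seq_len).
    random_prompts = torch.randint(vocab_size, (num_samples, 1, seq_len))

    # torch.randint(1, ...) draws from [0, 1), i.e. all zeros, so this mask is all
    # False; it exercises the attention_mask plumbing rather than modelling padding.
    random_attention_mask = torch.randint(1, (num_samples, 1, seq_len)).to(torch.bool)

    # One dict per sample, shaped like a causal-LM pretraining example.
    random_pretrain = [{'input_ids': random_prompts[i],
                        'labels': random_prompts[i],
                        'attention_mask': random_attention_mask[i]}
                       for i in range(num_samples)]

    print(random_prompts.shape)                    # torch.Size([8, 1, 64])
    print(random_pretrain[0]['input_ids'].shape)   # torch.Size([1, 64])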
@@ -136,7 +138,7 @@ if __name__ == '__main__':
                         default='naive')
     parser.add_argument('--model', type=str, default='gpt2', choices=['gpt2', 'bloom', 'opt', 'roberta'])
     parser.add_argument('--pretrain', type=str, default=None)
-    parser.add_argument('--save_path', type=str, default='actor_checkpoint_dummy.pt')
+    parser.add_argument('--save_path', type=str, default='actor_checkpoint_dummy')
     parser.add_argument('--need_optim_ckpt', type=bool, default=False)
     parser.add_argument('--num_episodes', type=int, default=50)
     parser.add_argument('--max_timesteps', type=int, default=10)
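The only parser change is the default --save_path losing its .pt suffix, which is consistent with the actor checkpoint being written as a directory (Hugging Face save_pretrained style) rather than a single state-dict file. A hypothetical illustration of that convention, with the model and path chosen for the example rather than taken from this commit:

    from transformers import GPT2LMHeadModel

    model = GPT2LMHeadModel.from_pretrained('gpt2')
    # save_pretrained writes config.json plus the weight file(s) into the folder.
    model.save_pretrained('actor_checkpoint_dummy')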
@@ -3,6 +3,7 @@ from random import randint

 import loralib as lora
 import torch
+import torch.distributed as dist
 from coati.dataset import HhRlhfDataset, RmStaticDataset
 from coati.models import LogExpLoss, LogSigLoss
 from coati.models.base import RewardModel
@@ -17,6 +18,8 @@ from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrat
 from coati.utils import prepare_llama_tokenizer_and_embedding
 from datasets import load_dataset
 from torch.optim import Adam
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
 from transformers import AutoTokenizer, BloomTokenizerFast, DebertaV2Tokenizer, LlamaTokenizer, RobertaTokenizer
 from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer

@@ -120,13 +123,38 @@ def train(args):
     else:
         raise ValueError(f'Unsupported dataset "{args.dataset}"')

+    if dist.is_initialized() and dist.get_world_size() > 1:
+        train_sampler = DistributedSampler(train_dataset, shuffle=True, seed=42, drop_last=True, rank=dist.get_rank(),
+                                           num_replicas=dist.get_world_size())
+        valid_sampler = DistributedSampler(valid_dataset, shuffle=True, seed=42, drop_last=True, rank=dist.get_rank(),
+                                           num_replicas=dist.get_world_size())
+        eval_sampler = DistributedSampler(eval_dataset, shuffle=True, seed=42, drop_last=True, rank=dist.get_rank(),
+                                          num_replicas=dist.get_world_size())
+    else:
+        train_sampler = None
+        valid_sampler = None
+        eval_sampler = None
+
+    train_dataloader = DataLoader(train_dataset,
+                                  shuffle=(train_sampler is None),
+                                  sampler=train_sampler,
+                                  batch_size=args.batch_size,
+                                  pin_memory=True)
+
+    valid_dataloader = DataLoader(valid_dataset, shuffle=(valid_sampler is None),
+                                  sampler=valid_sampler,
+                                  batch_size=args.batch_size, pin_memory=True)
+
+    eval_dataloader = DataLoader(eval_dataset, shuffle=(eval_sampler is None),
+                                 sampler=eval_sampler, batch_size=args.batch_size, pin_memory=True)
+
     trainer = RewardModelTrainer(model=model,
                                  strategy=strategy,
                                  optim=optim,
                                  loss_fn=loss_fn,
-                                 train_dataset=train_dataset,
-                                 valid_dataset=valid_dataset,
-                                 eval_dataset=eval_dataset,
+                                 train_dataloader=train_dataloader,
+                                 valid_dataloader=valid_dataloader,
+                                 eval_dataloader=eval_dataloader,
                                  batch_size=args.batch_size,
                                  max_epochs=args.max_epochs)

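The reward-model script now builds its own sharded dataloaders and hands those to RewardModelTrainer, instead of passing raw datasets. Two details make the pattern work: DataLoader's shuffle and sampler arguments are mutually exclusive, hence the shuffle=(sampler is None) guard, and DistributedSampler with drop_last=True gives each rank a disjoint, equally sized shard. Below is a minimal, self-contained sketch of the same pattern on a toy dataset; the set_epoch call is standard DistributedSampler usage, not something this diff adds.

    import torch
    import torch.distributed as dist
    from torch.utils.data import DataLoader, TensorDataset
    from torch.utils.data.distributed import DistributedSampler

    dataset = TensorDataset(torch.arange(100))   # toy stand-in for the RM datasets

    if dist.is_initialized() and dist.get_world_size() > 1:
        # Each rank draws a disjoint shard; drop_last keeps shard sizes equal.
        sampler = DistributedSampler(dataset, shuffle=True, seed=42, drop_last=True,
                                     rank=dist.get_rank(), num_replicas=dist.get_world_size())
    else:
        sampler = None   # single process: let the DataLoader shuffle instead

    # shuffle and sampler are mutually exclusive, hence the guard.
    loader = DataLoader(dataset, shuffle=(sampler is None), sampler=sampler,
                        batch_size=8, pin_memory=True)

    for epoch in range(2):
        if sampler is not None:
            sampler.set_epoch(epoch)   # vary the shuffle order across epochs
        for (batch,) in loader:
            pass   # training step would go here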