reconstruct chat trainer and fix training script (#3588)

Co-authored-by: Yuanchen Xu <yuanchen.xu00@gmail.com>
Author: Yuanchen
Date: 2023-04-18 16:44:03 +08:00
Committed by: GitHub
Parent: dac127d0ee
Commit: 1ec0d386a9
8 changed files with 163 additions and 137 deletions


@@ -114,8 +114,10 @@ def main(args):
                          eos_token_id=tokenizer.eos_token_id,
                          callbacks=callbacks)
-    random_prompts = torch.randint(tokenizer.vocab_size, (1000, 64), device=torch.cuda.current_device())
-    trainer.fit(random_prompts,
+    random_prompts = torch.randint(tokenizer.vocab_size, (1000, 1, 64), device=torch.cuda.current_device())
+    random_attention_mask = torch.randint(1, (1000, 1, 64), device=torch.cuda.current_device()).to(torch.bool)
+    random_pretrain = [{'input_ids':random_prompts[i], 'labels':random_prompts[i], 'attention_mask':random_attention_mask[i]} for i in range(1000)]
+    trainer.fit(random_prompts, random_pretrain,
                 num_episodes=args.num_episodes,
                 max_timesteps=args.max_timesteps,
                 update_timesteps=args.update_timesteps)
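
Note: for context, here is a minimal, self-contained sketch of the dummy data this hunk builds. The vocab size and sequence length are placeholders (the script takes them from the tokenizer and allocates CUDA tensors), and the all-ones attention mask is an assumption for illustration rather than the script's exact construction.

```python
import torch

# Placeholder sizes for illustration; the script uses tokenizer.vocab_size and CUDA tensors.
vocab_size, num_samples, seq_len = 50257, 1000, 64

# Prompts for the RL rollout phase: one (1, seq_len) token tensor per sample.
random_prompts = torch.randint(vocab_size, (num_samples, 1, seq_len))

# Dummy "pretraining" samples passed alongside the prompts to trainer.fit();
# each entry mirrors a causal-LM batch, with labels equal to input_ids.
random_attention_mask = torch.ones(num_samples, 1, seq_len, dtype=torch.bool)  # assumption: attend to every token
random_pretrain = [{'input_ids': random_prompts[i],
                    'labels': random_prompts[i],
                    'attention_mask': random_attention_mask[i]}
                   for i in range(num_samples)]
```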
@@ -136,7 +138,7 @@ if __name__ == '__main__':
                         default='naive')
     parser.add_argument('--model', type=str, default='gpt2', choices=['gpt2', 'bloom', 'opt', 'roberta'])
     parser.add_argument('--pretrain', type=str, default=None)
-    parser.add_argument('--save_path', type=str, default='actor_checkpoint_dummy.pt')
+    parser.add_argument('--save_path', type=str, default='actor_checkpoint_dummy')
     parser.add_argument('--need_optim_ckpt', type=bool, default=False)
     parser.add_argument('--num_episodes', type=int, default=50)
     parser.add_argument('--max_timesteps', type=int, default=10)


@@ -3,6 +3,7 @@ from random import randint
 import loralib as lora
 import torch
+import torch.distributed as dist
 from coati.dataset import HhRlhfDataset, RmStaticDataset
 from coati.models import LogExpLoss, LogSigLoss
 from coati.models.base import RewardModel
@@ -17,6 +18,8 @@ from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrat
 from coati.utils import prepare_llama_tokenizer_and_embedding
 from datasets import load_dataset
 from torch.optim import Adam
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
 from transformers import AutoTokenizer, BloomTokenizerFast, DebertaV2Tokenizer, LlamaTokenizer, RobertaTokenizer
 from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
@@ -120,13 +123,38 @@ def train(args):
     else:
         raise ValueError(f'Unsupported dataset "{args.dataset}"')
+    if dist.is_initialized() and dist.get_world_size() > 1:
+        train_sampler = DistributedSampler(train_dataset, shuffle=True, seed=42, drop_last=True, rank=dist.get_rank(),
+                                           num_replicas=dist.get_world_size())
+        valid_sampler = DistributedSampler(valid_dataset, shuffle=True, seed=42, drop_last=True, rank=dist.get_rank(),
+                                           num_replicas=dist.get_world_size())
+        eval_sampler = DistributedSampler(eval_dataset, shuffle=True, seed=42, drop_last=True, rank=dist.get_rank(),
+                                          num_replicas=dist.get_world_size())
+    else:
+        train_sampler = None
+        valid_sampler = None
+        eval_sampler = None
+    train_dataloader = DataLoader(train_dataset,
+                                  shuffle=(train_sampler is None),
+                                  sampler=train_sampler,
+                                  batch_size=args.batch_size,
+                                  pin_memory=True)
+    valid_dataloader = DataLoader(valid_dataset, shuffle=(valid_sampler is None),
+                                  sampler=valid_sampler,
+                                  batch_size=args.batch_size, pin_memory=True)
+    eval_dataloader = DataLoader(eval_dataset, shuffle=(eval_sampler is None),
+                                 sampler=eval_sampler, batch_size=args.batch_size, pin_memory=True)
     trainer = RewardModelTrainer(model=model,
                                  strategy=strategy,
                                  optim=optim,
                                  loss_fn=loss_fn,
-                                 train_dataset=train_dataset,
-                                 valid_dataset=valid_dataset,
-                                 eval_dataset=eval_dataset,
+                                 train_dataloader=train_dataloader,
+                                 valid_dataloader=valid_dataloader,
+                                 eval_dataloader=eval_dataloader,
                                  batch_size=args.batch_size,
                                  max_epochs=args.max_epochs)
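
For reference, the hunk above follows the standard PyTorch pattern of creating a DistributedSampler only when torch.distributed is active and letting the DataLoader shuffle otherwise. Below is a minimal sketch of that pattern plus the per-epoch set_epoch() call that distributed shuffling normally requires; whether the reconstructed RewardModelTrainer calls it internally is not visible in this diff, and the helper names here are hypothetical.

```python
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def build_dataloader(dataset, batch_size):
    # Use a DistributedSampler only when torch.distributed is active;
    # otherwise fall back to ordinary shuffling inside the DataLoader.
    if dist.is_initialized() and dist.get_world_size() > 1:
        sampler = DistributedSampler(dataset, shuffle=True, seed=42, drop_last=True,
                                     rank=dist.get_rank(), num_replicas=dist.get_world_size())
    else:
        sampler = None
    return DataLoader(dataset,
                      shuffle=(sampler is None),
                      sampler=sampler,
                      batch_size=batch_size,
                      pin_memory=True)

def train_epochs(dataloader, max_epochs, step_fn):
    # Hypothetical epoch loop: with a DistributedSampler, set_epoch() must be
    # called each epoch so every rank reshuffles with a different permutation.
    for epoch in range(max_epochs):
        if isinstance(dataloader.sampler, DistributedSampler):
            dataloader.sampler.set_epoch(epoch)
        for batch in dataloader:
            step_fn(batch)
```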