[chatgpt] optimize generation kwargs (#2717)

* [chatgpt] ppo trainer use default generate args

* [chatgpt] example remove generation preparing fn

* [chatgpt] benchmark remove generation preparing fn

* [chatgpt] fix ci
Authored by ver217 on 2023-02-15 13:59:58 +08:00, committed by GitHub
parent 21d6a48f4d
commit 9c0943ecdb
7 changed files with 48 additions and 52 deletions
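The change itself is small: the benchmark scripts stop wiring model-specific generation helpers into PPOTrainer and rely on its defaults instead. A minimal sketch of the resulting call-site kwargs, assuming a GPT-2 tokenizer as in the GPT benchmark (the OPT script mirrors it); every other trainer argument is omitted here:

from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token  # assumption: the usual GPT-2 pad-token setup

# After this commit only plain sampling / tokenizer settings are passed through to
# PPOTrainer; prepare_inputs_fn and update_model_kwargs_fn are gone, and the trainer
# falls back to its default generation behaviour.
generate_kwargs = dict(
    top_k=50,
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id,
)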

@@ -5,7 +5,6 @@ import torch
import torch.distributed as dist
import torch.nn as nn
from chatgpt.nn import GPTActor, GPTCritic, RewardModel
-from chatgpt.nn.generation_utils import gpt_prepare_inputs_fn, update_model_kwargs_fn
from chatgpt.trainer import PPOTrainer
from chatgpt.trainer.callbacks import PerformanceEvaluator
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, Strategy
@@ -151,8 +150,6 @@ def main(args):
top_k=50,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
-prepare_inputs_fn=gpt_prepare_inputs_fn,
-update_model_kwargs_fn=update_model_kwargs_fn,
callbacks=[performance_evaluator])
random_prompts = torch.randint(tokenizer.vocab_size, (1000, 400), device=torch.cuda.current_device())

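For context, helpers of this kind usually just shape the per-step forward() inputs and carry decoder state between sampling steps. The following is a rough sketch of that idea; it is an assumption about what such helpers typically look like, not the actual bodies of gpt_prepare_inputs_fn / update_model_kwargs_fn in chatgpt.nn.generation_utils, which this diff does not show:

import torch

def sketch_prepare_inputs_fn(input_ids: torch.Tensor, **model_kwargs) -> dict:
    # Build the forward() kwargs for one decoding step: the ids generated so far
    # plus whatever state (e.g. attention_mask) is carried in model_kwargs.
    return {"input_ids": input_ids, **model_kwargs}

def sketch_update_model_kwargs_fn(outputs, **model_kwargs) -> dict:
    # Carry decoder state across steps, e.g. extend the attention mask by one
    # position for the newly generated token.
    if "attention_mask" in model_kwargs:
        mask = model_kwargs["attention_mask"]
        model_kwargs["attention_mask"] = torch.cat(
            [mask, mask.new_ones((mask.shape[0], 1))], dim=-1)
    return model_kwargs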
@@ -5,7 +5,6 @@ import torch
import torch.distributed as dist
import torch.nn as nn
from chatgpt.nn import OPTActor, OPTCritic, RewardModel
-from chatgpt.nn.generation_utils import opt_prepare_inputs_fn, update_model_kwargs_fn
from chatgpt.trainer import PPOTrainer
from chatgpt.trainer.callbacks import PerformanceEvaluator
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, Strategy
@@ -144,8 +143,6 @@ def main(args):
top_k=50,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
-prepare_inputs_fn=opt_prepare_inputs_fn,
-update_model_kwargs_fn=update_model_kwargs_fn,
callbacks=[performance_evaluator])
random_prompts = torch.randint(tokenizer.vocab_size, (1000, 400), device=torch.cuda.current_device())
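With the per-model helpers removed, the trainer has to supply a sensible default itself. One common pattern, and a plausible reading of "use default generate args", is to delegate to the model's own prepare_inputs_for_generation when it exists; this is a hedged sketch of such a fallback, not the literal PPOTrainer code:

def default_prepare_inputs(model, input_ids, **model_kwargs):
    # HuggingFace models expose prepare_inputs_for_generation, so delegating to it
    # covers GPT, OPT, and similar decoder models without per-model wiring.
    if hasattr(model, "prepare_inputs_for_generation"):
        return model.prepare_inputs_for_generation(input_ids, **model_kwargs)
    # Otherwise feed the raw ids plus whatever state is already being carried.
    return {"input_ids": input_ids, **model_kwargs}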