Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-11 05:49:55 +00:00)
[chatgpt] optimize generation kwargs (#2717)
* [chatgpt] ppo trainer use default generate args
* [chatgpt] example remove generation preparing fn
* [chatgpt] benchmark remove generation preparing fn
* [chatgpt] fix ci
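With this change the PPO trainer supplies its own default generation-preparation logic, so call sites only pass ordinary generation kwargs. A minimal sketch of the simplified construction, assuming a PPOTrainer interface like the one used in the benchmark diffs below; everything except the kwargs visible in those diffs is an illustrative placeholder, not the exact benchmark code:

```python
from chatgpt.trainer import PPOTrainer  # import as it appears in the diffs below


def build_trainer(strategy, actor, critic, reward_model, initial_model,
                  tokenizer, performance_evaluator):
    # Hypothetical helper: the positional arguments and their order only loosely
    # follow the benchmark scripts; treat everything except the generation
    # kwargs shown in the diffs as an assumption.
    return PPOTrainer(strategy,
                      actor,
                      critic,
                      reward_model,
                      initial_model,
                      top_k=50,
                      pad_token_id=tokenizer.pad_token_id,
                      eos_token_id=tokenizer.eos_token_id,
                      # prepare_inputs_fn / update_model_kwargs_fn are no longer
                      # passed; the trainer falls back to its default generation
                      # helpers.
                      callbacks=[performance_evaluator])
```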
Diff in the GPT benchmark:

```diff
@@ -5,7 +5,6 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 from chatgpt.nn import GPTActor, GPTCritic, RewardModel
-from chatgpt.nn.generation_utils import gpt_prepare_inputs_fn, update_model_kwargs_fn
 from chatgpt.trainer import PPOTrainer
 from chatgpt.trainer.callbacks import PerformanceEvaluator
 from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, Strategy
@@ -151,8 +150,6 @@ def main(args):
                          top_k=50,
                          pad_token_id=tokenizer.pad_token_id,
                          eos_token_id=tokenizer.eos_token_id,
-                         prepare_inputs_fn=gpt_prepare_inputs_fn,
-                         update_model_kwargs_fn=update_model_kwargs_fn,
                          callbacks=[performance_evaluator])
 
     random_prompts = torch.randint(tokenizer.vocab_size, (1000, 400), device=torch.cuda.current_device())
```
Diff in the OPT benchmark:

```diff
@@ -5,7 +5,6 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 from chatgpt.nn import OPTActor, OPTCritic, RewardModel
-from chatgpt.nn.generation_utils import opt_prepare_inputs_fn, update_model_kwargs_fn
 from chatgpt.trainer import PPOTrainer
 from chatgpt.trainer.callbacks import PerformanceEvaluator
 from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, Strategy
@@ -144,8 +143,6 @@ def main(args):
                          top_k=50,
                          pad_token_id=tokenizer.pad_token_id,
                          eos_token_id=tokenizer.eos_token_id,
-                         prepare_inputs_fn=opt_prepare_inputs_fn,
-                         update_model_kwargs_fn=update_model_kwargs_fn,
                          callbacks=[performance_evaluator])
 
     random_prompts = torch.randint(tokenizer.vocab_size, (1000, 400), device=torch.cuda.current_device())
```
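For context, the hooks the benchmarks no longer pass are the kind of per-step plumbing a default generation loop can supply itself: preparing the next forward pass and extending the attention mask as tokens are appended. A rough, non-authoritative sketch of that behaviour follows; the function names and bodies are illustrative and not copied from `chatgpt.nn.generation_utils`:

```python
import torch


def default_prepare_inputs_fn(model, input_ids: torch.Tensor, **model_kwargs) -> dict:
    # Illustrative: delegate to the Hugging Face model's own hook, which is
    # what a generic default can do instead of a model-specific callback.
    return model.prepare_inputs_for_generation(input_ids, **model_kwargs)


def default_update_model_kwargs_fn(outputs, **model_kwargs) -> dict:
    # Illustrative: carry the KV cache forward and extend the attention mask
    # by one position for the token that was just generated.
    model_kwargs["past_key_values"] = getattr(outputs, "past_key_values", None)
    if "attention_mask" in model_kwargs:
        mask = model_kwargs["attention_mask"]
        model_kwargs["attention_mask"] = torch.cat(
            [mask, mask.new_ones((mask.shape[0], 1))], dim=-1)
    return model_kwargs
```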