[chat] refactor trainer (#3648)

* [chat] ppo trainer remove useless args

* [chat] update examples

* [chat] update benchmark

* [chat] update examples

* [chat] fix sft training with wandb

* [chat] polish docstr

Author: Hongxin Liu
Date: 2023-04-26 18:11:49 +08:00
Committed by: GitHub
Parent: f8288315d9
Commit: 2a951955ad
12 changed files with 72 additions and 536 deletions


@@ -140,8 +140,7 @@ def main(args):
         ptx_coef=0,
         max_epochs=args.max_epochs,
         train_batch_size=args.train_batch_size,
-        experience_batch_size=args.experience_batch_size,
-        tokenizer=preprocess_batch,
+        offload_inference_models=args.offload_inference_models,
         max_length=512,
         do_sample=True,
         temperature=1.0,
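
The removed experience_batch_size and tokenizer arguments give way to a single offload_inference_models switch. Below is a minimal sketch, in plain PyTorch, of the offloading pattern such a flag typically enables: inference-only models stay on CPU and visit the GPU only while experience is generated. The model objects and the make_experience name are illustrative stand-ins, not the actual Coati trainer code.

    import torch
    import torch.nn as nn

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    reward_model = nn.Linear(16, 1)    # stand-in for the real reward model
    initial_model = nn.Linear(16, 16)  # stand-in for the frozen initial policy

    def make_experience(batch: torch.Tensor) -> torch.Tensor:
        # Onload just for the rollout/scoring step...
        reward_model.to(device)
        initial_model.to(device)
        with torch.no_grad():
            scores = reward_model(initial_model(batch.to(device)))
        # ...then offload again so the training step has the full GPU to itself.
        reward_model.to('cpu')
        initial_model.to('cpu')
        return scores.cpu()

    print(make_experience(torch.randn(4, 16)).shape)  # torch.Size([4, 1])

The trade-off is the usual one for offloading: extra host-device transfer time per rollout in exchange for a lower peak GPU memory footprint during training.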
@@ -179,10 +178,11 @@ if __name__ == '__main__':
     parser.add_argument('--num_episodes', type=int, default=3)
     parser.add_argument('--max_timesteps', type=int, default=8)
     parser.add_argument('--update_timesteps', type=int, default=8)
-    parser.add_argument('--max_epochs', type=int, default=3)
+    parser.add_argument('--max_epochs', type=int, default=1)
     parser.add_argument('--train_batch_size', type=int, default=8)
     parser.add_argument('--experience_batch_size', type=int, default=8)
     parser.add_argument('--lora_rank', type=int, default=0)
     parser.add_argument('--cuda_mem_frac', type=float, default=1.0)
+    parser.add_argument('--offload_inference_models', action='store_true', default=False)
     args = parser.parse_args()
     main(args)
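
Because the new flag is declared with action='store_true', offloading stays opt-in: omitting the flag preserves the previous single-device behavior. A quick standalone check of that semantics (illustrative, not from the repo):

    import argparse

    # The flag defaults to False and flips to True only when passed on the
    # command line, so offloading is strictly opt-in.
    parser = argparse.ArgumentParser()
    parser.add_argument('--offload_inference_models', action='store_true', default=False)

    assert parser.parse_args([]).offload_inference_models is False
    assert parser.parse_args(['--offload_inference_models']).offload_inference_models is True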