mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-11-28 03:46:58 +00:00
[chat] refactor trainer (#3648)
* [chat] ppo trainer remove useless args * [chat] update examples * [chat] update benchmark * [chat] update examples * [chat] fix sft training with wandb * [chat] polish docstr
This commit is contained in:
@@ -140,8 +140,7 @@ def main(args):
|
||||
ptx_coef=0,
|
||||
max_epochs=args.max_epochs,
|
||||
train_batch_size=args.train_batch_size,
|
||||
experience_batch_size=args.experience_batch_size,
|
||||
tokenizer=preprocess_batch,
|
||||
offload_inference_models=args.offload_inference_models,
|
||||
max_length=512,
|
||||
do_sample=True,
|
||||
temperature=1.0,
|
||||
@@ -179,10 +178,11 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--num_episodes', type=int, default=3)
|
||||
parser.add_argument('--max_timesteps', type=int, default=8)
|
||||
parser.add_argument('--update_timesteps', type=int, default=8)
|
||||
parser.add_argument('--max_epochs', type=int, default=3)
|
||||
parser.add_argument('--max_epochs', type=int, default=1)
|
||||
parser.add_argument('--train_batch_size', type=int, default=8)
|
||||
parser.add_argument('--experience_batch_size', type=int, default=8)
|
||||
parser.add_argument('--lora_rank', type=int, default=0)
|
||||
parser.add_argument('--cuda_mem_frac', type=float, default=1.0)
|
||||
parser.add_argument('--offload_inference_models', action='store_true', default=False)
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
||||
Reference in New Issue
Block a user