[chatgpt] strategy add prepare method (#2766)

* [chatgpt] strategy add prepare method

* [chatgpt] refactor examples

* [chatgpt] refactor strategy.prepare

* [chatgpt] support save/load checkpoint

* [chatgpt] fix unwrap actor

* [chatgpt] fix unwrap actor
This commit is contained in:
ver217
2023-02-17 11:27:27 +08:00
committed by GitHub
parent a2b43e393d
commit 4ee311c026
9 changed files with 164 additions and 15 deletions

View File

@@ -68,6 +68,9 @@ def main(args):
else:
raise ValueError(f'Unsupported model "{args.model}"')
(actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
(actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
# configure trainer
trainer = PPOTrainer(
strategy,

View File

@@ -68,6 +68,9 @@ def main(args):
batch = tokenizer(texts, return_tensors='pt', max_length=96, padding=True, truncation=True)
return {k: v.cuda() for k, v in batch.items()}
(actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
(actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
# configure trainer
trainer = PPOTrainer(
strategy,