mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-11 13:59:08 +00:00
[chatgpt] startegy add prepare method (#2766)
* [chatgpt] startegy add prepare method * [chatgpt] refactor examples * [chatgpt] refactor strategy.prepare * [chatgpt] support save/load checkpoint * [chatgpt] fix unwrap actor * [chatgpt] fix unwrap actor
This commit is contained in:
@@ -68,6 +68,9 @@ def main(args):
|
||||
else:
|
||||
raise ValueError(f'Unsupported model "{args.model}"')
|
||||
|
||||
(actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
|
||||
(actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
|
||||
|
||||
# configure trainer
|
||||
trainer = PPOTrainer(
|
||||
strategy,
|
||||
|
@@ -68,6 +68,9 @@ def main(args):
|
||||
batch = tokenizer(texts, return_tensors='pt', max_length=96, padding=True, truncation=True)
|
||||
return {k: v.cuda() for k, v in batch.items()}
|
||||
|
||||
(actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
|
||||
(actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
|
||||
|
||||
# configure trainer
|
||||
trainer = PPOTrainer(
|
||||
strategy,
|
||||
|
Reference in New Issue
Block a user