mirror of https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-19 04:02:17 +00:00

[chatgpt] Add saving ckpt callback for PPO (#2880)

* add checkpoint callback for chatgpt
* add save ckpt callbacks for ppo

Co-authored-by: Fazzie-Maqianli <55798671+Fazziekey@users.noreply.github.com>

This commit is contained in:
parent e588703454
commit 287d60499e
chatgpt/trainer/callbacks/__init__.py

@ -1,4 +1,5 @@
 from .base import Callback
 from .performance_evaluator import PerformanceEvaluator
+from .save_checkpoint import SaveCheckpoint
 
-__all__ = ['Callback', 'PerformanceEvaluator']
+__all__ = ['Callback', 'PerformanceEvaluator', 'SaveCheckpoint']
chatgpt/trainer/callbacks/save_checkpoint.py (new file)

@ -0,0 +1,75 @@
+import os
+
+import torch.distributed as dist
+from chatgpt.trainer.strategies import ColossalAIStrategy, Strategy
+from chatgpt.trainer.utils import is_rank_0
+from torch import nn
+from torch.optim import Optimizer
+
+from .base import Callback
+
+
+class SaveCheckpoint(Callback):
+    """
+    The callback for saving checkpoint for chatgpt.
+
+    Only support saving actor and critic model.
+    A typical architecture of the saved checkpoint would be:
+        - checkpoint
+            - episode_x
+                - actor.pt
+                - actor-optim-rank-0.pt
+                - actor-optim-rank-1.pt
+                - critic.pt
+                - critic-optim-rank-0.pt
+                - critic-optim-rank-1.pt
+            - ...
+
+    Args:
+        path(str): the base path you want to save checkpoint, the checkpoint would be saved at `path/checkpoint`
+        interval(int): the interval episode of saving checkpoint
+        strategy(Strategy): the strategy used to train
+        actor(nn.Module): the actor model
+        critic(nn.Module): the critic model
+        actor_optim(Optimizer): the optimizer of actor
+        critic_optim(Optimizer): the optimizer of critic
+
+    """
+
+    def __init__(self,
+                 path: str,
+                 interval: int,
+                 strategy: Strategy,
+                 actor: nn.Module = None,
+                 critic: nn.Module = None,
+                 actor_optim: Optimizer = None,
+                 critic_optim: Optimizer = None) -> None:
+        super().__init__()
+        self.path = os.path.join(path, 'checkpoint')
+        self.interval = interval
+        self.strategy = strategy
+        self.model_dict = {'actor': [actor, actor_optim], 'critic': [critic, critic_optim]}
+
+    def on_episode_end(self, episode: int) -> None:
+        if (episode + 1) % self.interval != 0:
+            return
+        base_path = os.path.join(self.path, f'episode_{episode}')
+        if not os.path.exists(base_path):
+            os.makedirs(base_path)
+
+        for model in self.model_dict.keys():
+
+            # save model
+            if self.model_dict[model][0] is None:
+                # saving only optimizer states is meaningless, so it would be skipped
+                continue
+            model_path = os.path.join(base_path, f'{model}.pt')
+            self.strategy.save_model(model=self.model_dict[model][0], path=model_path, only_rank0=True)
+
+            # save optimizer
+            if self.model_dict[model][1] is None:
+                continue
+            only_rank0 = not isinstance(self.strategy, ColossalAIStrategy)
+            rank = 0 if is_rank_0() else dist.get_rank()
+            optim_path = os.path.join(base_path, f'{model}-optim-rank-{rank}.pt')
+            self.strategy.save_optimizer(optimizer=self.model_dict[model][1], path=optim_path, only_rank0=only_rank0)
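For reference, a minimal sketch of inspecting the files this callback writes, assuming `save_model`/`save_optimizer` persist `torch.save`'d state dicts; the run directory `my_run` is a placeholder, and with the default interval of 1 the first checkpoint directory is `episode_0`. How the state dict maps back onto a model depends on how the chosen strategy unwraps it in `save_model`.

import torch

# Hypothetical paths: a run constructed with path='my_run' and interval=1,
# so the first saved directory is my_run/checkpoint/episode_0.
# Assumes the strategy wrote plain torch-serialized state dicts.
actor_state = torch.load('my_run/checkpoint/episode_0/actor.pt', map_location='cpu')
actor_optim_state = torch.load('my_run/checkpoint/episode_0/actor-optim-rank-0.pt', map_location='cpu')
print(type(actor_state))  # typically a dict mapping parameter names to tensors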
example training script (the remaining hunks are from this file)

@ -4,6 +4,7 @@ from copy import deepcopy
 import torch
 from chatgpt.nn import BLOOMActor, BLOOMCritic, GPTActor, GPTCritic, OPTActor, OPTCritic, RewardModel
 from chatgpt.trainer import PPOTrainer
+from chatgpt.trainer.callbacks import SaveCheckpoint
 from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
 from torch.optim import Adam
 from transformers import AutoTokenizer, BloomTokenizerFast
@ -71,26 +72,38 @@ def main(args):
     (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
         (actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
 
+    callbacks = []
+    if args.save_ckpt_path:
+        ckpt_callback = SaveCheckpoint(
+            args.save_ckpt_path,
+            args.save_ckpt_interval,
+            strategy,
+            actor,
+            critic,
+            actor_optim,
+            critic_optim,
+        )
+        callbacks.append(ckpt_callback)
+
     # configure trainer
-    trainer = PPOTrainer(
-        strategy,
-        actor,
-        critic,
-        reward_model,
-        initial_model,
-        actor_optim,
-        critic_optim,
-        max_epochs=args.max_epochs,
-        train_batch_size=args.train_batch_size,
-        experience_batch_size=args.experience_batch_size,
-        tokenizer=preprocess_batch,
-        max_length=128,
-        do_sample=True,
-        temperature=1.0,
-        top_k=50,
-        pad_token_id=tokenizer.pad_token_id,
-        eos_token_id=tokenizer.eos_token_id,
-    )
+    trainer = PPOTrainer(strategy,
+                         actor,
+                         critic,
+                         reward_model,
+                         initial_model,
+                         actor_optim,
+                         critic_optim,
+                         max_epochs=args.max_epochs,
+                         train_batch_size=args.train_batch_size,
+                         tokenizer=preprocess_batch,
+                         max_length=128,
+                         do_sample=True,
+                         temperature=1.0,
+                         top_k=50,
+                         pad_token_id=tokenizer.pad_token_id,
+                         eos_token_id=tokenizer.eos_token_id,
+                         callbacks=callbacks)
 
     random_prompts = torch.randint(tokenizer.vocab_size, (1000, 64), device=torch.cuda.current_device())
     trainer.fit(random_prompts,
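As a side note on the interval semantics: `on_episode_end` receives the zero-based episode index and saves when `(episode + 1) % interval == 0`, so the directory names use the zero-based index. The loop below only illustrates that check; it is not how PPOTrainer dispatches callbacks.

# Illustrative only: which episodes trigger a checkpoint for a given interval.
interval = 2
saved = [episode for episode in range(6) if (episode + 1) % interval == 0]
print(saved)  # [1, 3, 5] -> directories episode_1, episode_3, episode_5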
@ -120,5 +133,10 @@ if __name__ == '__main__':
     parser.add_argument('--train_batch_size', type=int, default=8)
     parser.add_argument('--experience_batch_size', type=int, default=8)
     parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
+    parser.add_argument('--save_ckpt_path',
+                        type=str,
+                        default=None,
+                        help="path to save checkpoint, None means not to save")
+    parser.add_argument('--save_ckpt_interval', type=int, default=1, help="the interval of episode to save checkpoint")
     args = parser.parse_args()
     main(args)
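With these flags, enabling checkpointing from the example script could look like the following; the script name `train_dummy.py` and launcher are placeholders for illustration, and only `--save_ckpt_path` and `--save_ckpt_interval` come from this diff.

python train_dummy.py --save_ckpt_path ./output --save_ckpt_interval 2

With an interval of 2 this writes `./output/checkpoint/episode_1`, `./output/checkpoint/episode_3`, and so on, each containing the actor/critic model files and per-rank optimizer files described above.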