mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-05 03:26:48 +00:00
add save
This commit is contained in:
parent
0f566cc2d4
commit
0cc0c843ed
@ -357,7 +357,9 @@ def apply_chat_template_and_mask(
|
||||
ignore_idx: int = -100,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
|
||||
system_prompt = "You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and<answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. Now the user asks you to solve a math problem that involves reasoning.\n"
|
||||
system_prompt = "You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and<answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. Now the user asks you to solve a math problem that involves reasoning. After thinking, when you finally reach a conclusion, clearly output the final answer without explanation within the <answer> </answer> tags, i.e., <answer> 123 </answer>.\n\n"
|
||||
|
||||
|
||||
|
||||
system_element = {
|
||||
"role": "system",
|
||||
|
@ -1,6 +1,6 @@
|
||||
from contextlib import nullcontext
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import os
|
||||
import ray
|
||||
import ray.util.collective as cc
|
||||
import torch
|
||||
@ -33,6 +33,8 @@ class BaseConsumer:
|
||||
model_config: Dict[str, Any],
|
||||
plugin_config: Dict[str, Any],
|
||||
microbatch_size: int = 1,
|
||||
save_interval: int = 100,
|
||||
save_dir: str = "./model"
|
||||
):
|
||||
self.num_producers = num_producers
|
||||
self.num_episodes = num_episodes
|
||||
@ -44,6 +46,8 @@ class BaseConsumer:
|
||||
self.num_recv_per_update = num_recv_per_update
|
||||
self.batch_size = batch_size
|
||||
self.microbatch_size = microbatch_size
|
||||
self.save_interval = save_interval
|
||||
self.save_dir = save_dir
|
||||
assert batch_size % microbatch_size == 0, "batch_size should be divisible by microbatch_size"
|
||||
self.num_microbatches = batch_size // microbatch_size
|
||||
|
||||
@ -116,6 +120,14 @@ class BaseConsumer:
|
||||
pbar.set_postfix({"loss": loss})
|
||||
i += 1
|
||||
assert len(self.buffer) == 0
|
||||
if (step + 1) % self.save_interval == 0:
|
||||
if self.rank == 0:
|
||||
print(f"Start saving policy model at step {step + 1}.")
|
||||
save_path = os.path.join(self.save_dir, f"modeling-step-{step + 1}")
|
||||
self.booster.save_model(self.policy_model, save_path, shard=True)
|
||||
if self.rank == 0:
|
||||
print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")
|
||||
|
||||
if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
|
||||
print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
|
||||
state_dict = self.state_dict()
|
||||
|
@ -32,7 +32,7 @@ class GRPOConsumer(BaseConsumer):
|
||||
plugin_config,
|
||||
microbatch_size=1,
|
||||
num_generations=4,
|
||||
use_wandb=False,
|
||||
use_wandb=True,
|
||||
):
|
||||
super().__init__(
|
||||
num_producers,
|
||||
@ -79,7 +79,7 @@ class GRPOConsumer(BaseConsumer):
|
||||
|
||||
self.policy_loss_fn = PolicyLoss()
|
||||
self.global_step = 0
|
||||
if self.rank == 0:
|
||||
if use_wandb and self.rank == 0:
|
||||
self.wandb_run = wandb.init(project="GRPO-Test", sync_tensorboard=True)
|
||||
|
||||
def setup(self):
|
||||
|
Loading…
Reference in New Issue
Block a user