diff --git a/applications/ColossalChat/coati/dataset/loader.py b/applications/ColossalChat/coati/dataset/loader.py
index 35b3deced..1a6b04d43 100755
--- a/applications/ColossalChat/coati/dataset/loader.py
+++ b/applications/ColossalChat/coati/dataset/loader.py
@@ -357,7 +357,9 @@ def apply_chat_template_and_mask(
     ignore_idx: int = -100,
 ) -> Dict[str, torch.Tensor]:
 
-    system_prompt = "You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>. Now the user asks you to solve a math problem that involves reasoning.\n"
+    system_prompt = "You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>. Now the user asks you to solve a math problem that involves reasoning. After thinking, when you finally reach a conclusion, clearly output the final answer without explanation within the <answer> </answer> tags, i.e., <answer> 123 </answer>.\n\n"
+
+
 
     system_element = {
         "role": "system",
diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index 84a69979f..a99198d7e 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -1,6 +1,6 @@
 from contextlib import nullcontext
 from typing import Any, Dict, Optional
-
+import os
 import ray
 import ray.util.collective as cc
 import torch
@@ -33,6 +33,8 @@ class BaseConsumer:
         model_config: Dict[str, Any],
         plugin_config: Dict[str, Any],
         microbatch_size: int = 1,
+        save_interval: int = 100,
+        save_dir: str = "./model"
     ):
         self.num_producers = num_producers
         self.num_episodes = num_episodes
@@ -44,6 +46,8 @@ class BaseConsumer:
         self.num_recv_per_update = num_recv_per_update
         self.batch_size = batch_size
         self.microbatch_size = microbatch_size
+        self.save_interval = save_interval
+        self.save_dir = save_dir
         assert batch_size % microbatch_size == 0, "batch_size should be divisible by microbatch_size"
         self.num_microbatches = batch_size // microbatch_size
 
@@ -116,6 +120,14 @@ class BaseConsumer:
                             pbar.set_postfix({"loss": loss})
                         i += 1
                 assert len(self.buffer) == 0
+                if (step + 1) % self.save_interval == 0:
+                    if self.rank == 0:
+                        print(f"Start saving policy model at step {step + 1}.")
+                    save_path = os.path.join(self.save_dir, f"modeling-step-{step + 1}")
+                    self.booster.save_model(self.policy_model, save_path, shard=True)
+                    if self.rank == 0:
+                        print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")
+
                 if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
                     print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
                     state_dict = self.state_dict()
diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index 49240d8da..ead0c86e0 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -32,7 +32,7 @@ class GRPOConsumer(BaseConsumer):
         plugin_config,
         microbatch_size=1,
         num_generations=4,
-        use_wandb=False,
+        use_wandb=True,
     ):
         super().__init__(
             num_producers,
@@ -79,7 +79,7 @@ class GRPOConsumer(BaseConsumer):
         self.policy_loss_fn = PolicyLoss()
         self.global_step = 0
-        if self.rank == 0:
+        if use_wandb and self.rank == 0:
             self.wandb_run = wandb.init(project="GRPO-Test", sync_tensorboard=True)
 
     def setup(self):