diff --git a/applications/ColossalChat/coati/dataset/loader.py b/applications/ColossalChat/coati/dataset/loader.py
index 1a6b04d43..4518fd71f 100755
--- a/applications/ColossalChat/coati/dataset/loader.py
+++ b/applications/ColossalChat/coati/dataset/loader.py
@@ -359,8 +359,6 @@ def apply_chat_template_and_mask(
     system_prompt = "You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>. Now the user asks you to solve a math problem that involves reasoning. After thinking, when you finally reach a conclusion, clearly output the final answer without explanation within the <answer> </answer> tags, i.e., <answer> 123 </answer>.\n\n"
-
-
     system_element = {
         "role": "system",
         "content": system_prompt,
diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index a99198d7e..1e85cccb3 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -1,6 +1,7 @@
+import os
 from contextlib import nullcontext
 from typing import Any, Dict, Optional
-import os
+
 import ray
 import ray.util.collective as cc
 import torch
@@ -34,7 +35,7 @@ class BaseConsumer:
         plugin_config: Dict[str, Any],
         microbatch_size: int = 1,
         save_interval: int = 100,
-        save_dir: str = "./model"
+        save_dir: str = "./model",
     ):
         self.num_producers = num_producers
         self.num_episodes = num_episodes
diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index ead0c86e0..15f7e340e 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -79,7 +79,7 @@ class GRPOConsumer(BaseConsumer):
         self.policy_loss_fn = PolicyLoss()
         self.global_step = 0
-        if use_wandb and self.rank == 0: 
+        if use_wandb and self.rank == 0:
             self.wandb_run = wandb.init(project="GRPO-Test", sync_tensorboard=True)
 
     def setup(self):