diff --git a/.gitignore b/.gitignore
index d48035a5b..0503f3c95 100644
--- a/.gitignore
+++ b/.gitignore
@@ -166,3 +166,4 @@ applications/ColossalChat/tests/logs
applications/ColossalChat/wandb
applications/ColossalChat/model
applications/ColossalChat/eval
+applications/ColossalChat/rollouts
diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index 8de8b774e..70e2201fe 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -120,12 +120,16 @@ class GRPOConsumer(BaseConsumer):
"either max_tokens (vllm) or max_new_tokens (transformers) must be set in generate_config."
)
# Initialize verifiable reward.
-        response_format_tags = {
-            "think_start": {"text": "<think>", "num_occur": 1},
-            "think_end": {"text": "</think>", "num_occur": 1},
-            "answer_start": {"text": "<answer>", "num_occur": 1},
-            "answer_end": {"text": "</answer>", "num_occur": 1},
-        }
+        response_format_tags = (
+            {
+                "think_start": {"text": "<think>", "num_occur": 1},
+                "think_end": {"text": "</think>", "num_occur": 1},
+                "answer_start": {"text": "<answer>", "num_occur": 1},
+                "answer_end": {"text": "</answer>", "num_occur": 1},
+            }
+            if grpo_config.get("reward_fn_type") == "think_answer_tags"
+            else None
+        )
reward_model_kwargs = {
k: v
for k, v in grpo_config.items()
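Review note: with this change, format tags are only enforced for the `think_answer_tags` reward type; every other reward type now receives `response_format_tags=None`. Below is a minimal sketch of the exact-occurrence check these tags imply; `matches_format` is a hypothetical helper for illustration, not part of coati's reward code.

```python
# Hypothetical illustration of the exact-occurrence check implied by
# response_format_tags; the actual verifiable reward in coati may differ.
def matches_format(completion: str, tags: dict) -> bool:
    # Every tag's "text" must appear exactly "num_occur" times.
    return all(completion.count(spec["text"]) == spec["num_occur"] for spec in tags.values())

tags = {
    "think_start": {"text": "<think>", "num_occur": 1},
    "think_end": {"text": "</think>", "num_occur": 1},
    "answer_start": {"text": "<answer>", "num_occur": 1},
    "answer_end": {"text": "</answer>", "num_occur": 1},
}
print(matches_format("<think>2 + 2 = 4</think><answer>4</answer>", tags))  # True
print(matches_format("just an answer: 4", tags))  # False
```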
diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py
index e6367d2d1..6eeb5d379 100644
--- a/applications/ColossalChat/coati/distributed/launch.py
+++ b/applications/ColossalChat/coati/distributed/launch.py
@@ -55,6 +55,8 @@ def launch_distributed(
eval_interval: int = 100,
eval_save_dir: Optional[str] = None,
eval_generation_config: Optional[Dict[str, Any]] = None,
+ log_rollout_interval: int = 20,
+ rollout_log_file: str = "./rollout_log.jsonl",
):
if core_algo not in ALGO_MAP:
raise NotImplementedError(f"{core_algo} is not supported yet.")
@@ -98,6 +100,8 @@ def launch_distributed(
project_name=project_name,
run_name=run_name,
wandb_group_name=wandb_group_name,
+ log_rollout_interval=log_rollout_interval,
+ rollout_log_file=rollout_log_file,
)
procs.append(producer)
generate_config_consumer = copy.deepcopy(generate_config)
diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py
index 0d91f43f1..75879278c 100644
--- a/applications/ColossalChat/coati/distributed/producer.py
+++ b/applications/ColossalChat/coati/distributed/producer.py
@@ -1,4 +1,5 @@
import copy
+import json
import os
from typing import Any, Dict, Optional
 
@@ -49,7 +50,8 @@ class BaseProducer:
project_name: str = None,
run_name: str = None,
wandb_group_name: str = None,
- wandb_log_rollout_interval: int = 20,
+ log_rollout_interval: int = 20,
+ rollout_log_file: str = "./rollout_log.jsonl",
):
self.producer_idx = producer_idx
self.num_producers = num_producers
@@ -70,9 +72,16 @@ class BaseProducer:
self.eval_save_dir = eval_save_dir
self.consumer_global_step = 0
self.eval_mode = False
- self.wandb_rollout_data = []
- self.wandb_log_rollout_interval = wandb_log_rollout_interval
+ self.log_rollout_interval = log_rollout_interval
self.latest_rollout_log_step = -1
+ if producer_idx == 0:
+ if os.path.exists(rollout_log_file):
+ raise ValueError(
+ f"Rollout log file {rollout_log_file} already exists. Please delete it or change the name."
+ )
+ else:
+                os.makedirs(os.path.dirname(rollout_log_file) or ".", exist_ok=True)
+ self.rollout_log_file = open(rollout_log_file, "w", encoding="utf8")
if self.producer_idx == 0:
self.wandb_run = wandb.init(
project=project_name,
@@ -320,6 +329,8 @@ class SimpleProducer(BaseProducer):
project_name: str = None,
run_name: str = None,
wandb_group_name: str = None,
+ log_rollout_interval: int = 20,
+ rollout_log_file: str = "./rollout_log.jsonl",
):
super().__init__(
producer_idx,
@@ -342,6 +353,8 @@ class SimpleProducer(BaseProducer):
project_name=project_name,
run_name=run_name,
wandb_group_name=wandb_group_name,
+ log_rollout_interval=log_rollout_interval,
+ rollout_log_file=rollout_log_file,
)
self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
self.eval_generation_config = copy.deepcopy(self.model.generate_config)
@@ -353,26 +366,31 @@ class SimpleProducer(BaseProducer):
def rollout(self, input_ids, attention_mask, **kwargs):
rollouts = self.model.generate(input_ids, attention_mask, **kwargs)
if self.producer_idx == 0 and not self.eval_mode:
- wandb_rollout_data = self.wandb_rollout_data + [
- [
- str(self.consumer_global_step),
- str(self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True)),
- ]
- ]
if (
- self.consumer_global_step - self.latest_rollout_log_step >= self.wandb_log_rollout_interval
+ self.consumer_global_step - self.latest_rollout_log_step >= self.log_rollout_interval
or self.latest_rollout_log_step == -1
):
- self.wandb_rollout_data = wandb_rollout_data
- self.latest_rollout_log_step = self.consumer_global_step
- self.wandb_run.log(
- {
- "rollout/rollout_examples": wandb.Table(
- columns=["train_step", "rollout_examples"], data=wandb_rollout_data
+ new_record = (
+ json.dumps(
+ {
+ "train_step": self.consumer_global_step,
+ "rollout": self.tokenizer.batch_decode(
+ rollouts["input_ids"][:, 0], skip_special_tokens=True
+ ),
+ }
)
- }
- )
+ + "\n"
+ )
+ self.rollout_log_file.write(new_record)
+ self.rollout_log_file.flush()
+ self.latest_rollout_log_step = self.consumer_global_step
return rollouts
 
+ def __del__(self):
+ if self.producer_idx == 0:
+ self.wandb_run.finish()
+ if hasattr(self, "rollout_log_file"):
+ self.rollout_log_file.close()
+
def load_state_dict(self, state_dict):
self.model.load_state_dict(state_dict)
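Review note: each record written by `SimpleProducer.rollout` is one JSON object per line, carrying the `train_step` and the list of decoded first-sample completions from `batch_decode(rollouts["input_ids"][:, 0], ...)`. A minimal sketch for inspecting the log offline, assuming a file produced by this code exists at the (hypothetical) path used here:

```python
import json

# Hypothetical path; matches the naming convention used in rl_example.py below.
with open("./rollouts/demo_project.jsonl", encoding="utf8") as f:
    for line in f:
        record = json.loads(line)
        step, samples = record["train_step"], record["rollout"]
        print(f"step {step}: {len(samples)} completions logged")
        print(samples[0][:200])  # preview the first decoded completion
```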
diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index 071912ddf..98c139f14 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -118,6 +118,9 @@ if __name__ == "__main__":
parser.add_argument(
"-esd", "--eval-save-dir", type=str, default="./eval", help="Directory for saving evaluation results."
)
+ parser.add_argument(
+        "-rsd", "--rollout-save-dir", type=str, default="./rollouts", help="Directory for saving rollout logs."
+ )
args = parser.parse_args()
 
if args.train_minibatch_size is None:
@@ -269,4 +272,6 @@ if __name__ == "__main__":
eval_interval=args.eval_interval,
eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")),
eval_generation_config=eval_generation_config,
+ log_rollout_interval=20,
+ rollout_log_file=os.path.join(args.rollout_save_dir, args.project.replace(" ", "_") + ".jsonl"),
)
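Review note: with the defaults above, a run whose `--project` is, say, "GRPO Math" writes its rollouts to `./rollouts/GRPO_Math.jsonl` (the directory the `.gitignore` change above is meant to exclude), and the producer refuses to start if that file already exists. A small sketch of the path logic, mirroring the `os.path.join` call above:

```python
import os

project = "GRPO Math"            # example --project value
rollout_save_dir = "./rollouts"  # default --rollout-save-dir
log_path = os.path.join(rollout_save_dir, project.replace(" ", "_") + ".jsonl")
print(log_path)  # ./rollouts/GRPO_Math.jsonl
```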