diff --git a/.gitignore b/.gitignore
index d48035a5b..0503f3c95 100644
--- a/.gitignore
+++ b/.gitignore
@@ -166,3 +166,4 @@ applications/ColossalChat/tests/logs
 applications/ColossalChat/wandb
 applications/ColossalChat/model
 applications/ColossalChat/eval
+applications/ColossalChat/rollouts
diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index 8de8b774e..70e2201fe 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -120,12 +120,16 @@ class GRPOConsumer(BaseConsumer):
                 "either max_tokens (vllm) or max_new_tokens (transformers) must be set in generate_config."
             )
         # Initialize verifiable reward.
-        response_format_tags = {
-            "think_start": {"text": "<think>", "num_occur": 1},
-            "think_end": {"text": "</think>", "num_occur": 1},
-            "answer_start": {"text": "<answer>", "num_occur": 1},
-            "answer_end": {"text": "</answer>", "num_occur": 1},
-        }
+        response_format_tags = (
+            {
+                "think_start": {"text": "<think>", "num_occur": 1},
+                "think_end": {"text": "</think>", "num_occur": 1},
+                "answer_start": {"text": "<answer>", "num_occur": 1},
+                "answer_end": {"text": "</answer>", "num_occur": 1},
+            }
+            if grpo_config.get("reward_fn_type") == "think_answer_tags"
+            else None
+        )
         reward_model_kwargs = {
             k: v
             for k, v in grpo_config.items()
diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py
index e6367d2d1..6eeb5d379 100644
--- a/applications/ColossalChat/coati/distributed/launch.py
+++ b/applications/ColossalChat/coati/distributed/launch.py
@@ -55,6 +55,8 @@ def launch_distributed(
     eval_interval: int = 100,
     eval_save_dir: Optional[str] = None,
     eval_generation_config: Optional[Dict[str, Any]] = None,
+    log_rollout_interval: int = 20,
+    rollout_log_file: str = "./rollout_log.jsonl",
 ):
     if core_algo not in ALGO_MAP:
         raise NotImplementedError(f"{core_algo} is not supported yet.")
@@ -98,6 +100,8 @@
             project_name=project_name,
             run_name=run_name,
             wandb_group_name=wandb_group_name,
+            log_rollout_interval=log_rollout_interval,
+            rollout_log_file=rollout_log_file,
         )
         procs.append(producer)
     generate_config_consumer = copy.deepcopy(generate_config)
diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py
index 0d91f43f1..75879278c 100644
--- a/applications/ColossalChat/coati/distributed/producer.py
+++ b/applications/ColossalChat/coati/distributed/producer.py
@@ -1,4 +1,5 @@
 import copy
+import json
 import os
 from typing import Any, Dict, Optional
 
@@ -49,7 +50,8 @@ class BaseProducer:
         project_name: str = None,
         run_name: str = None,
         wandb_group_name: str = None,
-        wandb_log_rollout_interval: int = 20,
+        log_rollout_interval: int = 20,
+        rollout_log_file: str = "./rollout_log.jsonl",
     ):
         self.producer_idx = producer_idx
         self.num_producers = num_producers
@@ -70,9 +72,16 @@ class BaseProducer:
         self.eval_save_dir = eval_save_dir
         self.consumer_global_step = 0
         self.eval_mode = False
-        self.wandb_rollout_data = []
-        self.wandb_log_rollout_interval = wandb_log_rollout_interval
+        self.log_rollout_interval = log_rollout_interval
         self.latest_rollout_log_step = -1
+        if producer_idx == 0:
+            if os.path.exists(rollout_log_file):
+                raise ValueError(
+                    f"Rollout log file {rollout_log_file} already exists. Please delete it or change the name."
+                )
+            else:
+                os.makedirs(os.path.dirname(rollout_log_file), exist_ok=True)
+                self.rollout_log_file = open(rollout_log_file, "w", encoding="utf8")
         if self.producer_idx == 0:
             self.wandb_run = wandb.init(
                 project=project_name,
@@ -320,6 +329,8 @@ class SimpleProducer(BaseProducer):
         project_name: str = None,
         run_name: str = None,
         wandb_group_name: str = None,
+        log_rollout_interval: int = 20,
+        rollout_log_file: str = "./rollout_log.jsonl",
     ):
         super().__init__(
             producer_idx,
@@ -342,6 +353,8 @@
             project_name=project_name,
             run_name=run_name,
             wandb_group_name=wandb_group_name,
+            log_rollout_interval=log_rollout_interval,
+            rollout_log_file=rollout_log_file,
         )
         self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
         self.eval_generation_config = copy.deepcopy(self.model.generate_config)
@@ -353,26 +366,31 @@
     def rollout(self, input_ids, attention_mask, **kwargs):
         rollouts = self.model.generate(input_ids, attention_mask, **kwargs)
         if self.producer_idx == 0 and not self.eval_mode:
-            wandb_rollout_data = self.wandb_rollout_data + [
-                [
-                    str(self.consumer_global_step),
-                    str(self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True)),
-                ]
-            ]
             if (
-                self.consumer_global_step - self.latest_rollout_log_step >= self.wandb_log_rollout_interval
+                self.consumer_global_step - self.latest_rollout_log_step >= self.log_rollout_interval
                 or self.latest_rollout_log_step == -1
             ):
-                self.wandb_rollout_data = wandb_rollout_data
-                self.latest_rollout_log_step = self.consumer_global_step
-                self.wandb_run.log(
-                    {
-                        "rollout/rollout_examples": wandb.Table(
-                            columns=["train_step", "rollout_examples"], data=wandb_rollout_data
-                        )
-                    }
-                )
+                new_record = (
+                    json.dumps(
+                        {
+                            "train_step": self.consumer_global_step,
+                            "rollout": self.tokenizer.batch_decode(
+                                rollouts["input_ids"][:, 0], skip_special_tokens=True
+                            ),
+                        }
+                    )
+                    + "\n"
+                )
+                self.rollout_log_file.write(new_record)
+                self.rollout_log_file.flush()
+                self.latest_rollout_log_step = self.consumer_global_step
         return rollouts
 
+    def __del__(self):
+        if self.producer_idx == 0:
+            self.wandb_run.finish()
+        if hasattr(self, "rollout_log_file"):
+            self.rollout_log_file.close()
+
     def load_state_dict(self, state_dict):
         self.model.load_state_dict(state_dict)
diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index 071912ddf..98c139f14 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -118,6 +118,9 @@ if __name__ == "__main__":
     parser.add_argument(
         "-esd", "--eval-save-dir", type=str, default="./eval", help="Directory for saving evaluation results."
     )
+    parser.add_argument(
+        "-rsd", "--rollout-save-dir", type=str, default="./rollouts", help="Directory for saving rollout loggings."
+    )
 
     args = parser.parse_args()
     if args.train_minibatch_size is None:
@@ -269,4 +272,6 @@
         eval_interval=args.eval_interval,
         eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")),
         eval_generation_config=eval_generation_config,
+        log_rollout_interval=20,
+        rollout_log_file=os.path.join(args.rollout_save_dir, args.project.replace(" ", "_") + ".jsonl"),
     )
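
Note on the `grpo_consumer.py` change: `response_format_tags` is now built only when `reward_fn_type` is `"think_answer_tags"`, and is `None` otherwise. A minimal sketch of what such a tag spec implies (the actual check lives in the verifiable-reward code; `satisfies_format` is a hypothetical helper written here for illustration):

```python
from typing import Dict, Optional

def satisfies_format(completion: str, tags: Optional[Dict[str, dict]]) -> bool:
    """Hypothetical helper: each tag's "text" must occur exactly "num_occur" times."""
    if tags is None:  # tags are None unless reward_fn_type == "think_answer_tags"
        return True
    return all(completion.count(t["text"]) == t["num_occur"] for t in tags.values())

tags = {
    "think_start": {"text": "<think>", "num_occur": 1},
    "think_end": {"text": "</think>", "num_occur": 1},
    "answer_start": {"text": "<answer>", "num_occur": 1},
    "answer_end": {"text": "</answer>", "num_occur": 1},
}
assert satisfies_format("<think>2 * 21</think><answer>42</answer>", tags)
assert not satisfies_format("no tags at all", tags)
```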
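
With this change, producer 0 appends one JSON record per logging interval to a JSONL file instead of pushing an ever-growing `wandb.Table`. Each record has the shape `{"train_step": int, "rollout": [str, ...]}`, where `rollout` holds the first decoded sample for every prompt in the batch. A minimal sketch for inspecting the log offline (the path is illustrative, combining the default `--rollout-save-dir` with a project named `demo`):

```python
import json

# Assumed path: default --rollout-save-dir ("./rollouts") + "<project>.jsonl"
with open("./rollouts/demo.jsonl", encoding="utf8") as f:
    for line in f:
        record = json.loads(line)  # one record per logged train step
        print(f"train_step={record['train_step']}, num_rollouts={len(record['rollout'])}")
        print(record["rollout"][0][:200])  # preview the first decoded rollout
```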