diff --git a/.gitignore b/.gitignore
index d48035a5b..0503f3c95 100644
--- a/.gitignore
+++ b/.gitignore
@@ -166,3 +166,4 @@ applications/ColossalChat/tests/logs
applications/ColossalChat/wandb
applications/ColossalChat/model
applications/ColossalChat/eval
+applications/ColossalChat/rollouts
diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index 531cf363c..8a2221b92 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -113,7 +113,6 @@ class BaseConsumer:
) as pbar:
for step in pbar:
i = 0
- allow_sync_model = True
for _ in range(self.num_recv_per_update):
# receive data from producers
for r in range(self.num_producers):
@@ -140,7 +139,6 @@ class BaseConsumer:
loss = self.step(i, pbar, **batch)
self.buffer = self.buffer[self.dp_size * self.minibatch_size :]
if loss is not None:
- allow_sync_model = True
pbar.set_postfix({"loss": loss})
i += 1
if self.lr_scheduler is not None:
@@ -154,31 +152,29 @@ class BaseConsumer:
print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")
if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
- if allow_sync_model:
- if self.pp_size > 1:
- print(
- f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}"
+ if self.pp_size > 1:
+ print(
+ f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}"
+ )
+ else:
+ print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
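+ # Broadcast the freshly updated weights to the producer processes via Ray collective groups (one group per PP stage when pp_size > 1, a single "sync_model" group otherwise).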
+ torch.cuda.empty_cache()
+ state_dict = self.state_dict()
+ if self.pp_size > 1:
+ if self.tp_rank == 0 and self.dp_rank == 0:
+ ray_broadcast_tensor_dict(
+ state_dict,
+ src=self.num_producers,
+ device=self.device,
+ group_name=f"sync_model_{self.pp_rank}",
)
- else:
- print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
- torch.cuda.empty_cache()
- state_dict = self.state_dict()
- if self.pp_size > 1:
- if self.tp_rank == 0 and self.dp_rank == 0:
- ray_broadcast_tensor_dict(
- state_dict,
- src=self.num_producers,
- device=self.device,
- group_name=f"sync_model_{self.pp_rank}",
- )
- else:
- if self.rank == 0:
- ray_broadcast_tensor_dict(
- state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
- )
- del state_dict
- torch.cuda.empty_cache()
- allow_sync_model = True
+ else:
+ if self.rank == 0:
+ ray_broadcast_tensor_dict(
+ state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
+ )
+ del state_dict
+ torch.cuda.empty_cache()
@ray.remote
diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index f73f2d96f..bc2bd45d1 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -120,14 +120,20 @@ class GRPOConsumer(BaseConsumer):
"either max_tokens (vllm) or max_new_tokens (transformers) must be set in generate_config."
)
# Initialize verifiable reward.
- response_format_tags = {
- "think_start": {"text": "<think>", "num_occur": 1},
- "think_end": {"text": "</think>", "num_occur": 1},
- "answer_start": {"text": "<answer>", "num_occur": 1},
- "answer_end": {"text": "</answer>", "num_occur": 1},
- }
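+ # Only enforce <think>/<answer> tag structure when the think_answer_tags reward is selected; otherwise no format tags are required.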
+ response_format_tags = (
+ {
+ "think_start": {"text": "<think>", "num_occur": 1},
+ "think_end": {"text": "</think>", "num_occur": 1},
+ "answer_start": {"text": "<answer>", "num_occur": 1},
+ "answer_end": {"text": "</answer>", "num_occur": 1},
+ }
+ if grpo_config.get("reward_fn_type") == "think_answer_tags"
+ else None
+ )
reward_model_kwargs = {
- k: v for k, v in grpo_config.items() if k in ["soft_over_length_punishment", "max_length", "cache_length"]
+ k: v
+ for k, v in grpo_config.items()
+ if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length"]
}
self.reward_model = VerifiableReward(
reward_fns=[
diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py
index e6367d2d1..6eeb5d379 100644
--- a/applications/ColossalChat/coati/distributed/launch.py
+++ b/applications/ColossalChat/coati/distributed/launch.py
@@ -55,6 +55,8 @@ def launch_distributed(
eval_interval: int = 100,
eval_save_dir: Optional[str] = None,
eval_generation_config: Optional[Dict[str, Any]] = None,
+ log_rollout_interval: int = 20,
+ rollout_log_file: str = "./rollout_log.jsonl",
):
if core_algo not in ALGO_MAP:
raise NotImplementedError(f"{core_algo} is not supported yet.")
@@ -98,6 +100,8 @@ def launch_distributed(
project_name=project_name,
run_name=run_name,
wandb_group_name=wandb_group_name,
+ log_rollout_interval=log_rollout_interval,
+ rollout_log_file=rollout_log_file,
)
procs.append(producer)
generate_config_consumer = copy.deepcopy(generate_config)
diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py
index d796bff48..75879278c 100644
--- a/applications/ColossalChat/coati/distributed/producer.py
+++ b/applications/ColossalChat/coati/distributed/producer.py
@@ -1,4 +1,5 @@
import copy
+import json
import os
from typing import Any, Dict, Optional
@@ -49,6 +50,8 @@ class BaseProducer:
project_name: str = None,
run_name: str = None,
wandb_group_name: str = None,
+ log_rollout_interval: int = 20,
+ rollout_log_file: str = "./rollout_log.jsonl",
):
self.producer_idx = producer_idx
self.num_producers = num_producers
@@ -58,7 +61,7 @@ class BaseProducer:
self.microbatch_size = microbatch_size
assert batch_size % microbatch_size == 0
self.num_microbatches = batch_size // microbatch_size
- self.lastest_eval_step = -1
+ self.latest_eval_step = -1
self.train_dataset_config = train_dataset_config
self.model_config = model_config
@@ -68,6 +71,17 @@ class BaseProducer:
self.eval_interval = eval_interval
self.eval_save_dir = eval_save_dir
self.consumer_global_step = 0
+ self.eval_mode = False
+ self.log_rollout_interval = log_rollout_interval
+ self.latest_rollout_log_step = -1
+ if producer_idx == 0:
+ if os.path.exists(rollout_log_file):
+ raise ValueError(
+ f"Rollout log file {rollout_log_file} already exists. Please delete it or change the name."
+ )
+ else:
+ os.makedirs(os.path.dirname(rollout_log_file), exist_ok=True)
+ self.rollout_log_file = open(rollout_log_file, "w", encoding="utf8")
if self.producer_idx == 0:
self.wandb_run = wandb.init(
project=project_name,
@@ -77,7 +91,7 @@ class BaseProducer:
group=wandb_group_name,
)
- if os.path.exists(self.eval_save_dir):
+ if os.path.exists(self.eval_save_dir) and self.eval_interval > 0:
raise ValueError(f"Eval save dir {self.eval_save_dir} already exists. Please delete it or change the name.")
# init tokenizer
@@ -180,14 +194,15 @@ class BaseProducer:
break
if self.eval_interval > 0 and self.eval_dataset_config is not None:
if (
- self.consumer_global_step % self.eval_interval == 0
- and self.consumer_global_step > self.lastest_eval_step
- ):
+ self.consumer_global_step - self.latest_eval_step >= self.eval_interval
+ and self.consumer_global_step > self.latest_eval_step
+ ) or self.latest_eval_step == -1:
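+ # Evaluate on the very first step (latest_eval_step == -1) and whenever at least eval_interval training steps have passed since the last evaluation.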
to_log_msg = {}
+ self.eval_mode = True
for eval_task_name in self.eval_dataloaders:
if self.producer_idx == 0:
print(
- f"[P{self.producer_idx}] Evaluate consumer step {self.consumer_global_step} on task {eval_task_name}"
+ f"[P{self.producer_idx}] Evaluate model at training step {self.consumer_global_step} on task {eval_task_name}"
)
eval_results = []
eval_statistics_tensor = torch.zeros((2,), dtype=torch.float32).to(self.device)
@@ -220,14 +235,15 @@ class BaseProducer:
safe_append_to_jsonl_file(
os.path.join(
self.eval_save_dir,
- f"{eval_task_name}_episode_{episode}_step_{self.consumer_global_step}.jsonl",
+ f"{eval_task_name}_training_step_{self.consumer_global_step}.jsonl",
),
eval_results,
)
if self.producer_idx == 0:
self.wandb_run.log(to_log_msg, step=self.consumer_global_step)
- self.lastest_eval_step = self.consumer_global_step
+ self.eval_mode = False
+ self.latest_eval_step = self.consumer_global_step
outputs = self.rollout(**batch)
print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
@@ -256,6 +272,8 @@ class BaseProducer:
state_dict = ray_broadcast_tensor_dict(
None, self.num_producers, device=self.device, group_name=f"sync_model_{pp_idx}"
)
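+ # The synced state dict may carry the consumer's global step so producers can schedule evaluation and rollout logging against training progress.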
+ if "consumer_global_step" in state_dict:
+ self.consumer_global_step = state_dict.pop("consumer_global_step").item()
self.load_state_dict(state_dict)
else:
print(
@@ -311,6 +329,8 @@ class SimpleProducer(BaseProducer):
project_name: str = None,
run_name: str = None,
wandb_group_name: str = None,
+ log_rollout_interval: int = 20,
+ rollout_log_file: str = "./rollout_log.jsonl",
):
super().__init__(
producer_idx,
@@ -333,6 +353,8 @@ class SimpleProducer(BaseProducer):
project_name=project_name,
run_name=run_name,
wandb_group_name=wandb_group_name,
+ log_rollout_interval=log_rollout_interval,
+ rollout_log_file=rollout_log_file,
)
self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
self.eval_generation_config = copy.deepcopy(self.model.generate_config)
@@ -343,10 +365,32 @@ class SimpleProducer(BaseProducer):
@torch.no_grad()
def rollout(self, input_ids, attention_mask, **kwargs):
rollouts = self.model.generate(input_ids, attention_mask, **kwargs)
- # if self.producer_idx == 1:
- # print("Rollout example:\n", self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True))
-
+ if self.producer_idx == 0 and not self.eval_mode:
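+ # Periodically write the first sampled completion of every prompt in this batch to the rollout JSONL log (producer 0 only, training rollouts only).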
+ if (
+ self.consumer_global_step - self.latest_rollout_log_step >= self.log_rollout_interval
+ or self.latest_rollout_log_step == -1
+ ):
+ new_record = (
+ json.dumps(
+ {
+ "train_step": self.consumer_global_step,
+ "rollout": self.tokenizer.batch_decode(
+ rollouts["input_ids"][:, 0], skip_special_tokens=True
+ ),
+ }
+ )
+ + "\n"
+ )
+ self.rollout_log_file.write(new_record)
+ self.rollout_log_file.flush()
+ self.latest_rollout_log_step = self.consumer_global_step
return rollouts
+ def __del__(self):
+ if self.producer_idx == 0:
+ self.wandb_run.finish()
+ if hasattr(self, "rollout_log_file"):
+ self.rollout_log_file.close()
+
def load_state_dict(self, state_dict):
self.model.load_state_dict(state_dict)
diff --git a/applications/ColossalChat/coati/distributed/reward/reward_fn.py b/applications/ColossalChat/coati/distributed/reward/reward_fn.py
index 14d340dc4..a4042ae97 100644
--- a/applications/ColossalChat/coati/distributed/reward/reward_fn.py
+++ b/applications/ColossalChat/coati/distributed/reward/reward_fn.py
@@ -1,7 +1,77 @@
import torch
+from latex2sympy2_extended import NormalizationConfig
+from math_verify import ExprExtractionConfig, LatexExtractionConfig, parse, verify
from .reward_utils import extract_boxed_solution, extract_solution, validate_response_structure
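+# Status codes returned by verify_math_representation below.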
+CANNOT_PARSE_GT_ANSWER = -1
+CANNOT_PARSE_PREDICTION = -2
+SUCCESS = 1
+MATCHING_FAIL = 0
+
+
+def verify_math_representation(completion, gt_answer):
+ """
+ Check whether the completion is mathematically equivalent to gt_answer.
+ """
+ if not isinstance(gt_answer, str) or len(gt_answer) == 0:
+ raise ValueError("gt_answer must be a non-empty string; please verify your training data.")
+ if not isinstance(completion, str) or len(completion) == 0:
+ return MATCHING_FAIL
+ if not completion.startswith("\\boxed{"):
+ completion = "\\boxed{" + completion + "}"
+ if not gt_answer.startswith("\\boxed{"):
+ gt_answer = "\\boxed{" + gt_answer + "}"
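+ # Accept plain expressions or normalized LaTeX; boxed_match_priority=0 gives \boxed{...} content top matching priority.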
+ target = (
+ ExprExtractionConfig(),
+ LatexExtractionConfig(
+ normalization_config=NormalizationConfig(
+ nits=False,
+ malformed_operators=False,
+ basic_latex=True,
+ boxed="all",
+ units=True,
+ ),
+ boxed_match_priority=0,
+ ),
+ )
+ try:
+ parsed_gt_answer = parse(gt_answer, extraction_config=target)
+ if len(parsed_gt_answer) == 0:
+ return CANNOT_PARSE_GT_ANSWER
+ parsed_completion = parse(completion, extraction_config=target)
+ if len(parsed_completion) == 0:
+ return CANNOT_PARSE_PREDICTION
+ if verify(parsed_gt_answer, parsed_completion):
+ return SUCCESS
+ else:
+ return MATCHING_FAIL
+ except Exception:
+ return MATCHING_FAIL
+
+
+def verify_model_answer(decoded_final_answer, gt_answer, ans_acc, acc_score, reward):
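+ # Prefer symbolic verification via math_verify; fall back to a whitespace/brace/comma-insensitive exact string match.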
+ math_verify_result = verify_math_representation(decoded_final_answer, gt_answer)
+ exact_match_result = (
+ SUCCESS
+ if decoded_final_answer.strip().replace(" ", "").replace("{", "").replace("}", "").replace(",", "")
+ == gt_answer.strip().replace(" ", "").replace("{", "").replace("}", "").replace(",", "")
+ else MATCHING_FAIL
+ )
+ if math_verify_result == SUCCESS:
+ ans_acc += 1
+ reward += acc_score
+ elif exact_match_result == SUCCESS:
+ # math_verify can fail on answers that are not valid math expressions, so fall back to the exact match result
+ ans_acc += 1
+ if math_verify_result == CANNOT_PARSE_PREDICTION:
+ reward += (
+ acc_score / 2
+ ) # prediction could not be parsed as a math expression but matches the ground truth exactly, so it receives half of the accuracy score
+ else:
+ reward += acc_score
+ return reward, ans_acc
+
def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
tokenizer = kwargs["tokenizer"]
@@ -14,15 +84,18 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
s, e = response_idx[0], response_idx[1]
length_reward = 0.0
- if soft_over_length_punishment:
- max_length = kwargs.get("max_length", 1024 * 4)
- cache_length = kwargs.get("cache_length", 512)
- res_length = e.item() - s.item() + 1
- if max_length - cache_length < res_length < max_length:
- length_reward = ((max_length - cache_length) - res_length) / cache_length * acc_score
+ res_length = e.item() - s.item() + 1
+ if not eval_mode:
+ max_new_tokens = kwargs["max_new_tokens"]
+ else:
+ max_new_tokens = -1 # for eval mode, we don't need to check the length
+ if not eval_mode and soft_over_length_punishment:
+ cache_length = kwargs["cache_length"]
+ if max_new_tokens - cache_length < res_length < max_new_tokens:
+ length_reward = ((max_new_tokens - cache_length) - res_length) / cache_length * acc_score
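+ # e.g. with max_new_tokens=4096 and cache_length=1024, a 3500-token response receives (3072 - 3500) / 1024 * acc_score ≈ -0.42 * acc_score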
if gt_answer is None:
- return reward
+ raise ValueError("No gt_answer provided; please check your training dataset.")
decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True)
@@ -35,15 +108,15 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
format_acc += 1
# Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
- if (
- format_valid
- and final_answer is not None
- and gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower()
- ):
- ans_acc += 1
- reward += acc_score
+ if final_answer is not None:
+ if eval_mode or format_valid:
+ reward, ans_acc = verify_model_answer(final_answer, gt_answer, ans_acc, acc_score, reward)
+ if not eval_mode:
+ reward = reward + length_reward
- reward = reward + length_reward
+ # Check if the sequence is over length
+ if not eval_mode and res_length >= max_new_tokens:
+ reward *= 0.0
if not eval_mode:
return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
@@ -56,6 +129,8 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
"parsed": final_answer,
"format_valid": format_acc.item(),
"ans_valid": ans_acc.item(),
+ "response_length": res_length,
+ "reward": reward.item(),
}
@@ -71,31 +146,45 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
s, e = response_idx[0], response_idx[1]
length_reward = 0.0
- if soft_over_length_punishment:
- max_length = kwargs.get("max_length", 1024 * 4)
- cache_length = kwargs.get("cache_length", 512)
- res_length = e.item() - s.item() + 1
- if max_length - cache_length < res_length < max_length:
- length_reward = ((max_length - cache_length) - res_length) / cache_length * acc_score
+ res_length = e.item() - s.item() + 1
+ if not eval_mode:
+ max_new_tokens = kwargs["max_new_tokens"]
+ else:
+ max_new_tokens = -1 # for eval mode, we don't need to check the length
+ if not eval_mode and soft_over_length_punishment:
+ cache_length = kwargs["cache_length"]
+ if max_new_tokens - cache_length < res_length < max_new_tokens:
+ length_reward = ((max_new_tokens - cache_length) - res_length) / cache_length * acc_score
if gt_answer is None:
- return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
+ raise ValueError("No gt_answer provided; please check your training dataset.")
decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
+
gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True)
final_answer = extract_boxed_solution(decoded_final_answer)
format_valid = final_answer is not None
+ if "tags" in kwargs and kwargs["tags"]:
+ tags = kwargs["tags"]
+ format_valid = format_valid and all(
+ [decoded_final_answer.count(tags[tag]["text"]) == tags[tag]["num_occur"] for tag in tags]
+ )
# Check format accuracy
if format_valid:
format_acc += 1
reward += format_score
# Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
- if format_valid and final_answer is not None and gt_answer.strip().lower() == final_answer.strip().lower():
- ans_acc += 1
- reward += acc_score
+ if final_answer is not None:
+ if eval_mode or format_valid:
+ reward, ans_acc = verify_model_answer(final_answer, gt_answer, ans_acc, acc_score, reward)
+ if not eval_mode:
+ reward = reward + length_reward
+
+ # Check if the sequence is over length
+ if not eval_mode and res_length >= max_new_tokens:
+ reward *= 0.0
- reward = reward + length_reward
if not eval_mode:
return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
else:
@@ -107,4 +196,6 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
"parsed": final_answer,
"format_valid": format_acc.item(),
"ans_valid": ans_acc.item(),
+ "response_length": res_length,
+ "reward": reward.item(),
}
diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index 9f89eb546..a97c2b210 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -109,7 +109,7 @@ if __name__ == "__main__":
"--eval-interval",
type=int,
default=100,
- help="Interval for evaluation. Evaluate every ei consumer steps.",
+ help="Interval for evaluation. Evaluate the model every eval_interval training steps.",
)
# Logging/Checkpointing parameters
@@ -118,6 +118,9 @@ if __name__ == "__main__":
parser.add_argument(
"-esd", "--eval-save-dir", type=str, default="./eval", help="Directory for saving evaluation results."
)
+ parser.add_argument(
+ "-rsd", "--rollout-save-dir", type=str, default="./rollouts", help="Directory for saving rollout logs."
+ )
args = parser.parse_args()
if args.train_minibatch_size is None:
@@ -198,6 +201,8 @@ if __name__ == "__main__":
"beta": args.kl_coeff, # KL penalty coefficient
"loss_variation": "sample_level",
"reward_fn_type": args.reward_type,
+ "max_length": args.max_new_tokens + args.max_prompt_tokens,
+ "max_new_tokens": args.max_new_tokens,
}
elif args.algo == "DAPO":
# DAPO variant settings
@@ -213,6 +218,7 @@ if __name__ == "__main__":
"loss_variation": "token_level",
"soft_over_length_punishment": True,
"max_length": args.max_new_tokens + args.max_prompt_tokens,
+ "max_new_tokens": args.max_new_tokens,
"cache_length": min(1024, int(args.max_new_tokens / 4)),
"filter_truncated_response": True,
"reward_fn_type": args.reward_type,
@@ -266,4 +272,6 @@ if __name__ == "__main__":
eval_interval=args.eval_interval,
eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")),
eval_generation_config=eval_generation_config,
+ log_rollout_interval=20,
+ rollout_log_file=os.path.join(args.rollout_save_dir, args.project.replace(" ", "_") + ".jsonl"),
)