From a992da9f0f7adc9c185e3bfb07d937fee3272a03 Mon Sep 17 00:00:00 2001 From: YeAnbang Date: Mon, 14 Jul 2025 16:25:03 +0800 Subject: [PATCH 1/5] fix code evaluation --- .../coati/distributed/producer.py | 4 +- .../reward/code_reward/testing_util.py | 19 ++++++--- .../distributed/reward/code_reward/utils.py | 40 ++++++++++++------- .../coati/distributed/reward/reward_fn.py | 17 ++++++-- applications/ColossalChat/rl_example.py | 19 +++++++-- .../ColossalChat/start_code_verifier.py | 35 ++++++++++++++++ 6 files changed, 104 insertions(+), 30 deletions(-) create mode 100644 applications/ColossalChat/start_code_verifier.py diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py index 2a3746391..11fb5d3aa 100644 --- a/applications/ColossalChat/coati/distributed/producer.py +++ b/applications/ColossalChat/coati/distributed/producer.py @@ -83,7 +83,7 @@ class BaseProducer: reward_model_kwargs = { k: v for k, v in grpo_config.items() - if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length"] + if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length", "code_verifier_api_url"] } self.response_format_tags = grpo_config.get("response_format_tags", None) if producer_idx == 0: @@ -250,7 +250,7 @@ class BaseProducer: for m in range(eval_outputs["input_ids"].size(0)) for n in range(eval_outputs["input_ids"].size(1)) ] - eval_statistics_tensor[0] += len([res for res in eval_results if res["ans_valid"] == 1]) + eval_statistics_tensor[0] += sum([max(0, res["ans_valid"]) for res in eval_results]) eval_statistics_tensor[1] += len(eval_results) allreduce(eval_statistics_tensor, op=ReduceOp.SUM, group_name="producer_group") to_log_msg[f"eval/{eval_task_name}"] = ( diff --git a/applications/ColossalChat/coati/distributed/reward/code_reward/testing_util.py b/applications/ColossalChat/coati/distributed/reward/code_reward/testing_util.py index c6d6d8fad..12973337e 100644 --- a/applications/ColossalChat/coati/distributed/reward/code_reward/testing_util.py +++ b/applications/ColossalChat/coati/distributed/reward/code_reward/testing_util.py @@ -89,7 +89,7 @@ def clean_traceback(error_traceback): return error_traceback -def run_test(in_outs, test=None, debug=False, timeout=15): +def run_test(in_outs, test=None, debug=False, timeout=15, run_all_tests=False): """ if test(generated_code) is not None it'll try to run the code. otherwise it'll just return an input and output pair. 
@@ -180,8 +180,8 @@ def run_test(in_outs, test=None, debug=False, timeout=15): tmp_test = new_test sol += tmp_test - if debug: - print(f"sol = {sol}") + # if debug: + # print(f"sol = {sol}") method_name = "code" signal.alarm(timeout) try: @@ -202,8 +202,7 @@ def run_test(in_outs, test=None, debug=False, timeout=15): } signal.alarm(0) if debug: - print(f"get method = {datetime.now().time()}") - + print(f"get method {method_name} = {datetime.now().time()}") try: method = getattr(tmp, method_name) # get_attr second arg must be str except Exception: @@ -329,6 +328,9 @@ def run_test(in_outs, test=None, debug=False, timeout=15): error_traceback = traceback.format_exc() print(f"Call-based runtime error or time limit exceeded error = {repr(e)}{e}") results.append(-1) + signal.alarm(0) + if run_all_tests: + continue return results, { "error": repr(e), "traceback": clean_traceback(error_traceback), @@ -519,6 +521,10 @@ def run_test(in_outs, test=None, debug=False, timeout=15): results.append(tmp_result) if tmp_result is not True: + if debug: + print("final result:", results) + if run_all_tests: + continue return results, { "output": raw_true_output_copy, "expected": raw_outputs, @@ -539,7 +545,8 @@ def run_test(in_outs, test=None, debug=False, timeout=15): ) print(f"results = {results}") - + if debug: + print("final results", results) return results, {} diff --git a/applications/ColossalChat/coati/distributed/reward/code_reward/utils.py b/applications/ColossalChat/coati/distributed/reward/code_reward/utils.py index e19e9e387..dbb3b2149 100644 --- a/applications/ColossalChat/coati/distributed/reward/code_reward/utils.py +++ b/applications/ColossalChat/coati/distributed/reward/code_reward/utils.py @@ -16,27 +16,24 @@ # limitations under the License. import multiprocessing -import os -import sys import traceback from typing import Optional +import requests + from .testing_util import run_test def _temp_run(sample, generation, debug, result, metadata_list, timeout): - with open(os.devnull, "w") as devnull: - sys.stdout = devnull - sys.stderr = devnull - try: - res, metadata = run_test(in_outs=sample, test=generation, debug=debug, timeout=timeout) - result.append(res) - metadata_list.append(metadata) - except Exception: - # print(e) # some tracebacks are extremely long. - traceback.print_exc(10) - result.append([-1 for i in range(len(sample["inputs"]))]) - metadata_list.append({}) + try: + res, metadata = run_test(in_outs=sample, test=generation, debug=debug, timeout=timeout) + result.append(res) + metadata_list.append(metadata) + except Exception: + # print(e) # some tracebacks are extremely long. 
+ traceback.print_exc(10) + result.append([-1 for i in range(len(sample["inputs"]))]) + metadata_list.append({}) def check_correctness(in_outs: Optional[dict], generation, timeout=10, debug=True): @@ -49,7 +46,7 @@ def check_correctness(in_outs: Optional[dict], generation, timeout=10, debug=Tru metadata_list = manager.list() p = multiprocessing.Process(target=_temp_run, args=(in_outs, generation, debug, result, metadata_list, timeout)) p.start() - p.join(timeout=timeout + 1) + p.join(timeout=600) # Global timeout of 10 minutes that's for all test cases combined if p.is_alive(): p.kill() # p.terminate() @@ -59,3 +56,16 @@ def check_correctness(in_outs: Optional[dict], generation, timeout=10, debug=Tru if debug: print("global timeout") return result[0], metadata_list + + +def check_correctness_code_api( + in_outs: Optional[dict], generation, timeout=10, debug=True, url="http://localhost:8000/check_correctness" +): + payload = {"in_outs": in_outs, "generation": generation, "timeout": timeout, "debug": debug} + response = requests.post(url, json=payload) + if response.status_code == 200: + results = response.json() + return results["result"], results["metadata"] + else: + print(f"Error: {response.status_code} - {response.text}") + return [-1 for i in range(len(in_outs["inputs"]))], {} diff --git a/applications/ColossalChat/coati/distributed/reward/reward_fn.py b/applications/ColossalChat/coati/distributed/reward/reward_fn.py index 5c8810fdf..f7a2fb89c 100644 --- a/applications/ColossalChat/coati/distributed/reward/reward_fn.py +++ b/applications/ColossalChat/coati/distributed/reward/reward_fn.py @@ -24,7 +24,7 @@ import torch from latex2sympy2_extended import NormalizationConfig from math_verify import ExprExtractionConfig, LatexExtractionConfig, parse, verify -from .code_reward.utils import check_correctness as check_correctness_code +from .code_reward.utils import check_correctness_code_api as check_correctness_code from .reward_utils import extract_boxed_solution, extract_solution, validate_response_structure CANNOT_PARSE_GT_ANSWER = -1 @@ -223,6 +223,7 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs): def code_reward_fn(input_ids, test_cases, response_idx, **kwargs): + url = kwargs.get("url", "http://localhost:8000/check_correctness") tokenizer = kwargs["tokenizer"] eval_mode = kwargs.get("eval_mode", False) soft_over_length_punishment = kwargs.get("soft_over_length_punishment", False) @@ -255,6 +256,9 @@ def code_reward_fn(input_ids, test_cases, response_idx, **kwargs): if format_valid: format_acc += 1 + res = [] + metadata = [] + try: try: if not isinstance(test_cases, dict): @@ -264,15 +268,18 @@ def code_reward_fn(input_ids, test_cases, response_idx, **kwargs): raise e # Complete check on all in-out pairs first. If there is no failure, per-sample test can be skipped. 
try: - res, metadata = check_correctness_code(in_outs=test_cases, generation=solution, timeout=10, debug=True) + res, metadata = check_correctness_code( + in_outs=test_cases, generation=solution, timeout=10, debug=False, url=url + ) metadata = dict(enumerate(metadata))[0] - success = all(map(lambda x: x is True, res)) + success = all(map(lambda x: x == 1, res)) if success: ans_acc += 1 if eval_mode or format_valid: reward += acc_score if not eval_mode: reward = reward + length_reward + except Exception: pass @@ -288,7 +295,9 @@ def code_reward_fn(input_ids, test_cases, response_idx, **kwargs): return { "prompt": prompt, "prediction": decoded_final_answer, - "gold": test_cases["outputs"], + "test_cases": test_cases, + "test_results": res, + "test_metadata": metadata, "parsed": solution, "format_valid": format_acc.item(), "ans_valid": ans_acc.item(), diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index d7b7c2a5d..08814f9f1 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -12,6 +12,9 @@ DEFAUT_SYSTEM_PROMPT = { "code": "You are a helpful assistant.", } +# bypass the proxy for local addresses +os.environ["no_proxy"] = "127.0.0.1,localhost" + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B") @@ -138,6 +141,13 @@ if __name__ == "__main__": choices=["think_answer_tags", "boxed", "code"], help="Reward type for GRPO.", ) + parser.add_argument( + "-cv", + "--code-verifier-api-url", + type=str, + default=None, + help="API URL for code verifier. If not provided, the code verifier will be disabled.", + ) parser.add_argument( "-ei", "--eval-interval", @@ -165,6 +175,7 @@ if __name__ == "__main__": parser.add_argument( "--enable_profiling", action="store_true", default=False, help="Enable profiling for the training process." ) + args = parser.parse_args() if args.train_minibatch_size is None: @@ -188,7 +199,7 @@ if __name__ == "__main__": namespace="ray-example", runtime_env={ "env_vars": { - # "RAY_DEBUG_POST_MORTEM": "1" # enable post-mortem debugging with ray + # "RAY_DEBUG_POST_MORTEM": "1", # enable post-mortem debugging with ray "TOKENIZERS_PARALLELISM": "false" }, }, @@ -201,7 +212,7 @@ if __name__ == "__main__": _temp_dir=args.ray_dir, runtime_env={ "env_vars": { - # "RAY_DEBUG_POST_MORTEM": "1" # enable post-mortem debugging with ray + # "RAY_DEBUG_POST_MORTEM": "1", # enable post-mortem debugging with ray "TOKENIZERS_PARALLELISM": "false" }, }, @@ -321,7 +332,9 @@ if __name__ == "__main__": } else: raise ValueError(f"Unsupported algorithm: {args.algo}") - + if args.reward_type == "code": + assert args.code_verifier_api_url is not None, "Please provide a code verifier API URL for code reward type." 
+ grpo_config.update({"code_verifier_api_url": args.code_verifier_api_url}) if args.system_prompt is None: # Default system prompt args.system_prompt = DEFAUT_SYSTEM_PROMPT[args.reward_type] diff --git a/applications/ColossalChat/start_code_verifier.py b/applications/ColossalChat/start_code_verifier.py new file mode 100644 index 000000000..d1924f610 --- /dev/null +++ b/applications/ColossalChat/start_code_verifier.py @@ -0,0 +1,35 @@ +from typing import List, Optional + +from coati.distributed.reward.code_reward.utils import check_correctness # Assuming utils.py is in the same directory +from fastapi import FastAPI, HTTPException +from pydantic import BaseModel + +app = FastAPI() + + +class CheckCorrectnessRequest(BaseModel): + in_outs: Optional[dict] + generation: str + timeout: int = 10 + debug: bool = True + eval_mode: bool = False + + +class CheckCorrectnessResponse(BaseModel): + result: List[int] + metadata: List[dict] + + +@app.post("/check_correctness", response_model=CheckCorrectnessResponse) +def check_correctness_api(request: CheckCorrectnessRequest): + try: + result, metadata = check_correctness( + in_outs=request.in_outs, + generation=request.generation, + timeout=request.timeout, + debug=request.debug, + eval_mode=request.eval_mode, + ) + return CheckCorrectnessResponse(result=result, metadata=metadata) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) From d8504752083f53dbfab03c82acd470f5df74f8fa Mon Sep 17 00:00:00 2001 From: YeAnbang Date: Mon, 14 Jul 2025 18:23:39 +0800 Subject: [PATCH 2/5] fix style --- .../coati/distributed/reward/code_reward/testing_util.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/reward/code_reward/testing_util.py b/applications/ColossalChat/coati/distributed/reward/code_reward/testing_util.py index 12973337e..19be101f6 100644 --- a/applications/ColossalChat/coati/distributed/reward/code_reward/testing_util.py +++ b/applications/ColossalChat/coati/distributed/reward/code_reward/testing_util.py @@ -180,8 +180,6 @@ def run_test(in_outs, test=None, debug=False, timeout=15, run_all_tests=False): tmp_test = new_test sol += tmp_test - # if debug: - # print(f"sol = {sol}") method_name = "code" signal.alarm(timeout) try: From 4cf5ce20bfad4a5a27b6fbac6cae5194b9916e66 Mon Sep 17 00:00:00 2001 From: YeAnbang <44796419+YeAnbang@users.noreply.github.com> Date: Thu, 17 Jul 2025 15:05:10 +0800 Subject: [PATCH 3/5] add entropy (#6363) --- .../coati/distributed/grpo_consumer.py | 39 ++++++++++++++++++- .../ColossalChat/coati/distributed/utils.py | 10 +++++ 2 files changed, 48 insertions(+), 1 deletion(-) diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index f8ce1afde..754f78097 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -6,7 +6,7 @@ import torch import wandb from coati.distributed.consumer import BaseConsumer from coati.distributed.loss import PolicyLoss -from coati.distributed.utils import memory_efficient_logprob +from coati.distributed.utils import entropy_from_logits, memory_efficient_logprob from coati.trainer.utils import all_reduce_mean, all_reduce_sum from transformers import AutoModelForCausalLM, AutoTokenizer @@ -75,6 +75,7 @@ class GRPOConsumer(BaseConsumer): self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6)) self.accum_loss = torch.zeros(1, 
device=self.device) self.accum_kl = torch.zeros(1, device=self.device) + self.accum_entropy = torch.zeros(1, device=self.device) self.accum_advantages = torch.zeros(1, device=self.device) self.raw_train_batch_reward = [] self.raw_train_batch_format_acc = [] @@ -244,6 +245,7 @@ class GRPOConsumer(BaseConsumer): else self.booster.no_sync(self.policy_model, self.optimizer) ) with ctx: + mini_batch_entropies = [] for forward_micro_batch_start in range(0, data["input_ids"].size(0), train_microbatch_size): input_ids_forward_micro_batch = data["input_ids"][ forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size @@ -310,9 +312,11 @@ class GRPOConsumer(BaseConsumer): data_policy_forward["reference_action_log_probs"] = reference_action_log_probs kl = [] + policy_model_logits = torch.empty_like(input_ids_forward_micro_batch, device=self.device) def _criterion(outputs, inputs): action_logits = outputs.logits + policy_model_logits.copy_(action_logits) action_log_probs = memory_efficient_logprob( action_logits / self.generate_config["temperature"], inputs["input_ids"], @@ -359,6 +363,20 @@ class GRPOConsumer(BaseConsumer): kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin).data mean_kl.append(kl) mean_loss.append(all_reduce_mean(loss, self.plugin).data) + mini_batch_entropies.append( + all_reduce_mean( + ( + ( + ( + entropy_from_logits(policy_model_logits[:, -num_action:]) + * action_mask_forward_micro_batch + ).sum(-1) + ) + / action_mask_forward_micro_batch.sum(-1) + ).detach(), + self.plugin, + ) + ) else: policy_model_logits = self.policy_model( input_ids=input_ids_forward_micro_batch, @@ -412,6 +430,20 @@ class GRPOConsumer(BaseConsumer): kl = all_reduce_mean(kl.mean(), self.plugin) mean_kl.append(kl.data) mean_loss.append(loss.data) + mini_batch_entropies.append( + all_reduce_mean( + ( + ( + ( + entropy_from_logits(policy_model_logits[:, -num_action:]) + * action_mask_forward_micro_batch + ).sum(-1) + ) + / action_mask_forward_micro_batch.sum(-1) + ).detach(), + self.plugin, + ) + ) if not self.plugin.pp_size > 1 or ( self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() @@ -423,7 +455,9 @@ class GRPOConsumer(BaseConsumer): ans_acc = all_reduce_mean(ans_acc.mean(), self.plugin) advantages = all_reduce_mean(advantages.mean(), self.plugin) response_length = all_reduce_mean(response_length.mean(), self.plugin) + entropy = torch.cat(mini_batch_entropies, dim=0).mean() self.accum_loss.add_(sum(mean_loss) / len(mean_loss)) + self.accum_entropy.add_(entropy.data) if self.policy_loss_fn.beta > 0: self.accum_kl.add_(sum(mean_kl) / len(mean_kl)) self.accum_advantages.add_(advantages.data) @@ -464,6 +498,7 @@ class GRPOConsumer(BaseConsumer): f"Response Length: {raw_batch_response_len_mean:.4f}", f"Sample_utilization: {sample_utilization:.4f}", f"Overlength samples ratio: {overlength_samples_ratio:.4f}", + f"Entropy: {self.accum_entropy.item() / self.accum_count:.4f}", ] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else []) print("\n".join(to_log_msg)) metrics = { @@ -475,6 +510,7 @@ class GRPOConsumer(BaseConsumer): "train/advantages": self.accum_advantages.item() / self.accum_count, "train/learning_rate": self.lr_scheduler.get_last_lr()[0], "train/sample_utilization": sample_utilization, + "train/entropy": self.accum_entropy.item() / self.accum_count, "train/overlength_samples_ratio": overlength_samples_ratio, "rollout/temperature": data["temperature"].cpu().numpy()[0][0], } @@ -484,6 
+520,7 @@ class GRPOConsumer(BaseConsumer): self.wandb_run.log(metrics) self.accum_loss.zero_() self.accum_kl.zero_() + self.accum_entropy.zero_() self.accum_advantages.zero_() self.accum_count = 0 return loss_scalar diff --git a/applications/ColossalChat/coati/distributed/utils.py b/applications/ColossalChat/coati/distributed/utils.py index d46243114..466914cc0 100644 --- a/applications/ColossalChat/coati/distributed/utils.py +++ b/applications/ColossalChat/coati/distributed/utils.py @@ -110,6 +110,16 @@ def memory_efficient_logprob( return action_log_probs +def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor: + """ + Calculate entropy + Reference: https://github.com/volcengine/verl/blob/96b730bbed80292a439f0c0057d3920ab8b28d52/verl/utils/torch_functional.py#L145 + """ + p = torch.nn.functional.softmax(logits, dim=-1) + entropy = torch.logsumexp(logits, dim=-1) - torch.sum(p * logits, dim=-1) + return entropy + + def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor: """ Compute the masked mean of a tensor along a specified dimension. From 57e92104a23e2e9085c71d70e14a1d288018926f Mon Sep 17 00:00:00 2001 From: YeAnbang <44796419+YeAnbang@users.noreply.github.com> Date: Tue, 22 Jul 2025 10:02:02 +0800 Subject: [PATCH 4/5] hotfix entropy calculation (#6364) --- .../coati/distributed/grpo_consumer.py | 32 ++++++++----------- 1 file changed, 13 insertions(+), 19 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index 754f78097..a3f1a1cbb 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -250,6 +250,9 @@ class GRPOConsumer(BaseConsumer): input_ids_forward_micro_batch = data["input_ids"][ forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size ] + old_action_log_probs_micro_batch = old_action_log_probs[ + forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size + ] attention_mask_forward_micro_batch = data["attention_mask"][ forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size ] @@ -306,17 +309,22 @@ class GRPOConsumer(BaseConsumer): "action_mask": action_mask_forward_micro_batch, "advantages": advantages_forward_micro_batch, "loss_mask": loss_mask_forward_micro_batch, + "old_action_log_probs": old_action_log_probs_micro_batch, "source": self.rank, } if reference_action_log_probs is not None: data_policy_forward["reference_action_log_probs"] = reference_action_log_probs kl = [] - policy_model_logits = torch.empty_like(input_ids_forward_micro_batch, device=self.device) def _criterion(outputs, inputs): action_logits = outputs.logits - policy_model_logits.copy_(action_logits) + mini_batch_entropies.append( + ( + ((entropy_from_logits(action_logits[:, -num_action:]) * inputs["action_mask"]).sum(-1)) + / inputs["action_mask"].sum(-1) + ).detach() + ) action_log_probs = memory_efficient_logprob( action_logits / self.generate_config["temperature"], inputs["input_ids"], @@ -339,7 +347,7 @@ class GRPOConsumer(BaseConsumer): loss, _ = self.policy_loss_fn( action_log_probs, - action_log_probs, + inputs["old_action_log_probs"], inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1), per_token_kl, inputs["action_mask"], @@ -363,20 +371,6 @@ class GRPOConsumer(BaseConsumer): kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin).data mean_kl.append(kl) 
mean_loss.append(all_reduce_mean(loss, self.plugin).data) - mini_batch_entropies.append( - all_reduce_mean( - ( - ( - ( - entropy_from_logits(policy_model_logits[:, -num_action:]) - * action_mask_forward_micro_batch - ).sum(-1) - ) - / action_mask_forward_micro_batch.sum(-1) - ).detach(), - self.plugin, - ) - ) else: policy_model_logits = self.policy_model( input_ids=input_ids_forward_micro_batch, @@ -415,7 +409,7 @@ class GRPOConsumer(BaseConsumer): loss, _ = self.policy_loss_fn( action_log_probs, - old_action_log_probs, + old_action_log_probs_micro_batch, advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1), per_token_kl, action_mask_forward_micro_batch, @@ -455,7 +449,7 @@ class GRPOConsumer(BaseConsumer): ans_acc = all_reduce_mean(ans_acc.mean(), self.plugin) advantages = all_reduce_mean(advantages.mean(), self.plugin) response_length = all_reduce_mean(response_length.mean(), self.plugin) - entropy = torch.cat(mini_batch_entropies, dim=0).mean() + entropy = all_reduce_mean(torch.cat(mini_batch_entropies, dim=0).mean(), self.plugin) self.accum_loss.add_(sum(mean_loss) / len(mean_loss)) self.accum_entropy.add_(entropy.data) if self.policy_loss_fn.beta > 0: From cd32236e53d565079fbbbeccf7bc7ef6ec7eafa1 Mon Sep 17 00:00:00 2001 From: YeAnbang <44796419+YeAnbang@users.noreply.github.com> Date: Tue, 29 Jul 2025 16:56:52 +0800 Subject: [PATCH 5/5] [Fix] Add L2 Regularization (#6372) * fix no L2 regularization error * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- applications/ColossalChat/coati/distributed/consumer.py | 2 +- .../ColossalChat/coati/distributed/grpo_consumer.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py index e360392e7..ba7d882c9 100644 --- a/applications/ColossalChat/coati/distributed/consumer.py +++ b/applications/ColossalChat/coati/distributed/consumer.py @@ -365,7 +365,7 @@ class SimpleConsumer(BaseConsumer): self.model = AutoModelForCausalLM.from_pretrained(path, **model_config) self.model.train() self.model.gradient_checkpointing_enable() - self.optimizer = HybridAdam(self.model.parameters(), lr=1e-3) + self.optimizer = HybridAdam(self.model.parameters(), lr=1e-3, weight_decay=0.01) self.accum_loss = torch.zeros(1, device=self.device) def setup(self): diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index a3f1a1cbb..424d46098 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -72,7 +72,11 @@ class GRPOConsumer(BaseConsumer): self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config) self.policy_model.train() self.policy_model.gradient_checkpointing_enable() - self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6)) + self.optimizer = HybridAdam( + self.policy_model.parameters(), + lr=grpo_config.get("lr", 1e-6), + weight_decay=grpo_config.get("weight_decay", 0.01), + ) self.accum_loss = torch.zeros(1, device=self.device) self.accum_kl = torch.zeros(1, device=self.device) self.accum_entropy = torch.zeros(1, device=self.device)
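
Usage sketch for the verifier path introduced in PATCH 1/5: the client side, check_correctness_code_api() in code_reward/utils.py, POSTs the test cases and the candidate program to the URL configured via -cv/--code-verifier-api-url and reads back the "result" and "metadata" fields, while start_code_verifier.py exposes the matching /check_correctness endpoint as a FastAPI app. The snippet below reproduces that request/response contract by hand. It is a minimal sketch only: the uvicorn launch command, the port, and the sample stdin/stdout test case are illustrative assumptions, not part of the patches.

    # Serve the FastAPI app added in applications/ColossalChat/start_code_verifier.py
    # (assumed launch command; the patch does not ship a launcher):
    #   uvicorn start_code_verifier:app --host 0.0.0.0 --port 8000
    import requests

    payload = {
        # Hypothetical test case in the {"inputs": [...], "outputs": [...]} layout
        # consumed by run_test()/check_correctness().
        "in_outs": {"inputs": ["1 2\n"], "outputs": ["3\n"]},
        # Candidate program whose stdout is compared against "outputs".
        "generation": "a, b = map(int, input().split())\nprint(a + b)",
        # Same fields that check_correctness_code_api() sends.
        "timeout": 10,
        "debug": False,
    }
    resp = requests.post("http://localhost:8000/check_correctness", json=payload)
    if resp.status_code == 200:
        body = resp.json()
        # "result" holds one entry per test case; "metadata" holds per-run details.
        print(body["result"], body["metadata"])
    else:
        # Mirrors the client-side fallback in utils.py: log the error and treat
        # every test case as failed (-1).
        print(resp.status_code, resp.text)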