diff --git a/.gitignore b/.gitignore index 0503f3c95..94d952295 100644 --- a/.gitignore +++ b/.gitignore @@ -167,3 +167,9 @@ applications/ColossalChat/wandb applications/ColossalChat/model applications/ColossalChat/eval applications/ColossalChat/rollouts +applications/ColossalChat/*.txt +applications/ColossalChat/*.db +applications/ColossalChat/stdin +applications/ColossalChat/*.zip +applications/ColossalChat/*.prof +applications/ColossalChat/*.png diff --git a/applications/ColossalChat/coati/dataset/loader.py b/applications/ColossalChat/coati/dataset/loader.py index 43cf78383..16fd385ba 100755 --- a/applications/ColossalChat/coati/dataset/loader.py +++ b/applications/ColossalChat/coati/dataset/loader.py @@ -367,9 +367,9 @@ def apply_chat_template_and_mask( } # Format for RL. - gt_answer = None - if "messages" in chat and "gt_answer" in chat: - gt_answer = chat["gt_answer"] + if "messages" in chat: + gt_answer = chat.get("gt_answer", None) + test_cases = chat.get("test_cases", None) chat = [chat["messages"]] tokens = [] @@ -402,12 +402,14 @@ def apply_chat_template_and_mask( labels[~torch.tensor(assistant_mask, dtype=torch.bool)] = ignore_idx if gt_answer is not None: - gt_answer = tokenizer.encode( - gt_answer, padding="max_length", truncation=True, max_length=128, return_tensors="pt" - ) - gt_answer = gt_answer.squeeze(1) return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, "gt_answer": gt_answer} - + elif test_cases is not None: + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "labels": labels, + "test_cases": test_cases, + } return { "input_ids": input_ids, "attention_mask": attention_mask, @@ -440,3 +442,20 @@ class RawConversationDataset(Dataset): tokens = apply_chat_template_and_mask(self.tokenizer, message, self.max_length, self.system_prompt) self.tokenized_texts[index] = dict(tokens) return self.tokenized_texts[index] + + +def collate_fn_grpo(batch): + input_ids = [item["input_ids"] for item in batch] + attention_mask = [item["attention_mask"] for item in batch] + labels = [item["labels"] for item in batch] + # Assume input_ids, attention_mask, labels are already of the same length, + # otherwise use pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id) + input_ids = torch.stack(input_ids) + attention_mask = torch.stack(attention_mask) + labels = torch.stack(labels) + ret = {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels} + if "test_cases" in batch[0]: + ret["test_cases"] = [item["test_cases"] for item in batch] + if "gt_answer" in batch[0]: + ret["gt_answer"] = [item["gt_answer"] for item in batch] + return ret diff --git a/applications/ColossalChat/coati/distributed/README.md b/applications/ColossalChat/coati/distributed/README.md index e68dfb812..74e30bc46 100644 --- a/applications/ColossalChat/coati/distributed/README.md +++ b/applications/ColossalChat/coati/distributed/README.md @@ -86,6 +86,50 @@ Then write IP node map to /etc/hosts Set Ascend Multi-Node Config + +This repository implements a distributed Reinforcement Learning (RL) training framework designed to fine-tune large language models using algorithms such as **GRPO** and **DAPO**. It supports multi-node and multi-GPU setups, scalable rollout generation, and policy optimization using libraries like VLLM. Currently, we support two Reinforcement Learning with Verifiable Reward (RLVR) tasks: solving math problems and code generation. 
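For context on the dataset side of this change: `apply_chat_template_and_mask` now passes through either a `gt_answer` field (math) or a `test_cases` field (code generation) alongside the tokenized conversation, and `collate_fn_grpo` keeps them as plain lists. A minimal sketch of what the corresponding jsonl records might look like; the field names follow `coati/dataset/loader.py`, while the prompts, answers, and file name are placeholders:

```python
# Illustrative only: field names ("messages", "gt_answer", "test_cases")
# follow coati/dataset/loader.py; the sample contents are made up.
import json

math_sample = {
    "messages": [{"role": "user", "content": "Compute 1 + 1."}],
    "gt_answer": "2",  # checked by math_reward_fn / boxed_math_reward_fn
}

code_sample = {
    "messages": [{"role": "user", "content": "Read an integer n and print 2 * n."}],
    # In/out pairs in the shape expected by code_reward_fn -> check_correctness;
    # an optional "fn_name" key switches the checker to call-based testing.
    "test_cases": json.dumps({"inputs": ["3\n"], "outputs": ["6\n"]}),
}

with open("train_data.jsonl", "w") as f:  # hypothetical dataset path
    for sample in (math_sample, code_sample):
        f.write(json.dumps(sample) + "\n")
```

In practice the math and code records would live in separate dataset files, since `--reward-type` selects a single reward function for the whole run.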
+
+**Please note that we are still under intensive development, stay tuned.**
+
+---
+
+## 🚀 Features
+
+* **Distributed Training with Ray**: Scalable to multiple machines and GPUs.
+* **Support for GRPO and DAPO**: Choose your preferred policy optimization algorithm.
+* **Model Backends**: Supports `vllm` as the inference backend.
+* **Rollout and Policy Decoupling**: Efficient generation and consumption of data through a parallel inferencer-trainer architecture.
+* **Evaluation Integration**: Easily plug in task-specific eval datasets.
+* **Checkpoints and Logging**: Configurable intervals and directories.
+
+---
+
+## 🛠 Installation
+
+### Prepare Development Environment
+
+Install ColossalAI & ColossalChat
+```bash
+git clone https://github.com/hpcaitech/ColossalAI.git
+cd ColossalAI
+git checkout grpo-latest
+BUILD_EXT=1 pip install -e .
+
+cd ./applications/ColossalChat
+pip install -e .
+```
+
+Install vLLM
+```bash
+pip install vllm==0.7.3
+```
+
+Install Ray
+```bash
+pip install ray
+```
+
+Install Other Dependencies
+
 ```bash
 export ATB_LLM_HCCL_ENABLE=1
 export ATB_LLM_COMM_BACKEND="hccl"
@@ -242,6 +286,7 @@ In addition to the two default training settings we provided--- original `GRPO`
         train_tensor_parallelism_size)
 ```
 
+
 ---
 ## 🧪 Example: single machine 8-GPU Zero2 Strategy
 
@@ -292,6 +337,54 @@ plugin_config={
     }, # for pp, tp
 ```
 
+```bash
+# Hint1: replace /models/Qwen/Qwen2.5-7B with your model path
+# replace /datasets/train-alignment.jsonl with your dataset path
+python rl_example.py \
+    -m /path/to/Qwen2.5-Math-7B/ \
+    -d /path/to/train_data.jsonl \
+    --master_address '10.0.0.3' \
+    -t 16 \
+    -i 16 \
+    -p GRPO-Train-Align-Debug \
+    -g 2 \
+    -ibs 1 \
+    -tbs 2 \
+    -tmbs 2 \
+    -imbs 1 \
+    -s "Please reason step by step, and put your final answer within \\boxed{}." \
+    -tMbs 8
+```
+
+## 🧪 Example: multi-machine TP+PP Strategy
+
+### Create a Ray cluster across multiple machines
+For example, suppose we have 4 nodes whose IPs are 10.0.0.3, 10.0.0.4, 10.0.0.5, and 10.0.0.6.
+We use 10.0.0.3 as the master node.
First we start a ray cluster on 10.0.0.3: +```bash +ray start --head --node-ip-address=10.0.0.3 +``` + +Then, for each slave node (10.0.0.4/10.0.0.5/10.0.0.6), we add to the ray cluser by following code: +```bash +ray start --address='10.0.0.3:6379' +``` + +Modify plugin_config in ./applications/ColossalChat/rl_example.py +```python +plugin_config={ + "tp_size": 4, + "pp_size": 2, + "microbatch_size": max( + 1, args.train_microbatch_size // 2 + ), # microbatch size should be set to train_microbatch_size // pp_size + "zero_stage": 1, + "max_norm": 1.0, + }, # for pp, tp +``` + ```bash # Hint1: replace /models/Qwen/Qwen2.5-7B to your model path # replace /datasets/train-alignment.jsonl to your dataset path diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index 4ca23e911..a48f87224 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -8,7 +8,6 @@ from coati.distributed.consumer import BaseConsumer from coati.distributed.loss import PolicyLoss from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn from coati.distributed.reward.verifiable_reward import VerifiableReward -from coati.distributed.utils import calc_action_log_probs from coati.trainer.utils import all_reduce_mean, all_reduce_sum from transformers import AutoModelForCausalLM, AutoTokenizer @@ -145,7 +144,10 @@ class GRPOConsumer(BaseConsumer): def setup(self): super().setup() if (not self.plugin.pp_size > 1 and self.rank == 0) or ( - self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0 + self.plugin.pp_size > 1 + and self.booster.plugin.stage_manager.is_last_stage() + and self.tp_rank == 0 + and self.dp_rank == 0 ): self.wandb_run = wandb.init( project=self.project_name, @@ -237,7 +239,6 @@ class GRPOConsumer(BaseConsumer): effective_samples = all_reduce_sum(torch.sum(loss_mask), self.plugin) effective_tokens_count = torch.sum(action_mask, dim=-1) * loss_mask total_effective_tokens_count = all_reduce_sum(torch.sum(effective_tokens_count), self.plugin) - total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin) self.effective_sample_count += effective_samples.item() pbar.set_postfix( { @@ -295,12 +296,11 @@ class GRPOConsumer(BaseConsumer): ) if self.booster.plugin.stage_manager.is_last_stage(): - reference_model_logits = reference_model_outputs["outputs"]["logits"] - reference_action_log_probs = calc_action_log_probs( - reference_model_logits / self.generate_config["temperature"], + reference_action_log_probs = memory_efficient_logprob( + reference_model_outputs["outputs"]["logits"] / self.generate_config["temperature"], input_ids_forward_micro_batch, num_action, - self.plugin.shard_config, + shard_config=self.plugin.shard_config, ) else: # Dummy reference logprobs for data iterator. 
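For reference, the `memory_efficient_logprob` calls introduced in this file compute the same quantity as the old `calc_action_log_probs`: per-token log-probabilities of the last `num_action` (response) tokens under the temperature-scaled logits. A rough unsharded sketch of that quantity, assuming no tensor parallelism (the real helper additionally processes the sequence in chunks and routes through `dist_log_prob` with the plugin's `shard_config`):

```python
import torch
import torch.nn.functional as F

def naive_action_log_probs(
    logits: torch.Tensor,     # [B, S, V], already divided by the sampling temperature
    input_ids: torch.Tensor,  # [B, S]
    num_action: int,          # number of generated (response) tokens at the end
) -> torch.Tensor:
    # Token t is predicted from position t - 1, so shift logits and labels by one.
    log_probs = F.log_softmax(logits[:, :-1].float(), dim=-1)           # [B, S-1, V]
    token_logps = torch.gather(
        log_probs, dim=-1, index=input_ids[:, 1:].unsqueeze(-1)
    ).squeeze(-1)                                                       # [B, S-1]
    return token_logps[:, -num_action:]                                 # [B, num_action]
```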
@@ -323,11 +323,11 @@ class GRPOConsumer(BaseConsumer): def _criterion(outputs, inputs): action_logits = outputs.logits - action_log_probs = calc_action_log_probs( + action_log_probs = memory_efficient_logprob( action_logits / self.generate_config["temperature"], inputs["input_ids"], num_action, - self.plugin.shard_config, + shard_config=self.plugin.shard_config, ) if "reference_action_log_probs" in inputs: per_token_kl = ( @@ -370,16 +370,15 @@ class GRPOConsumer(BaseConsumer): mean_kl.append(kl) mean_loss.append(all_reduce_mean(loss, self.plugin).data) else: - policy_model_logits = self.policy_model( input_ids=input_ids_forward_micro_batch, attention_mask=attention_mask_forward_micro_batch, ).logits - action_log_probs = calc_action_log_probs( + action_log_probs = memory_efficient_logprob( policy_model_logits / self.generate_config["temperature"], input_ids_forward_micro_batch, num_action, - self.plugin.shard_config, + shard_config=self.plugin.shard_config, ) if self.policy_loss_fn.beta > 0: @@ -388,11 +387,11 @@ class GRPOConsumer(BaseConsumer): input_ids=input_ids_forward_micro_batch, attention_mask=attention_mask_forward_micro_batch, ).logits - reference_action_log_probs = calc_action_log_probs( + reference_action_log_probs = memory_efficient_logprob( reference_model_logits / self.generate_config["temperature"], input_ids_forward_micro_batch, num_action, - self.plugin.shard_config, + shard_config=self.plugin.shard_config, ) per_token_kl = ( torch.exp(reference_action_log_probs - action_log_probs) @@ -424,7 +423,10 @@ class GRPOConsumer(BaseConsumer): mean_kl.append(kl.data) mean_loss.append(loss.data) if not self.plugin.pp_size > 1 or ( - self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0 + self.plugin.pp_size > 1 + and self.booster.plugin.stage_manager.is_last_stage() + and self.tp_rank == 0 + and self.dp_rank == 0 ): reward = all_reduce_mean(reward.mean(), self.plugin) format_acc = all_reduce_mean(format_acc.mean(), self.plugin) diff --git a/applications/ColossalChat/coati/distributed/inference_backend.py b/applications/ColossalChat/coati/distributed/inference_backend.py index 7988802a3..056c57b1d 100644 --- a/applications/ColossalChat/coati/distributed/inference_backend.py +++ b/applications/ColossalChat/coati/distributed/inference_backend.py @@ -74,9 +74,8 @@ class TransformersInferenceBackend(BaseInferenceBackend): micro_batch_size = input_ids.size(0) input_ids = input_ids.to(get_current_device()) attention_mask = attention_mask.to(get_current_device()) - gt_answer = None - if "gt_answer" in kwargs: - gt_answer = kwargs.pop("gt_answer") + gt_answer = kwargs.pop("gt_answer", None) + test_cases = kwargs.pop("test_cases", None) if self.num_generations > 1: input_ids = input_ids.repeat_interleave(self.num_generations, dim=0) attention_mask = attention_mask.repeat_interleave(self.num_generations, dim=0) @@ -116,8 +115,9 @@ class TransformersInferenceBackend(BaseInferenceBackend): data = {k: v.view(micro_batch_size, self.num_generations, v.size(-1)) for k, v in data.items()} if gt_answer is not None: - # repeat gt_answer for each prompt. 
- data["gt_answer"] = gt_answer.repeat_interleave(self.num_generations, dim=1) + data["gt_answer"] = gt_answer + if test_cases is not None: + data["test_cases"] = test_cases data = {k: v.to(get_current_device()) for k, v in data.items()} return data @@ -270,11 +270,11 @@ class VLLMInferenceBackend(BaseInferenceBackend): } data = {k: v.view(micro_batch_size, -1, v.size(-1)) for k, v in data.items()} - - if "gt_answer" in kwargs: - # repeat gt_answer for each prompt. - data["gt_answer"] = kwargs["gt_answer"].repeat_interleave(data["input_ids"].size(1), dim=1) data = {k: v.to(get_current_device()) for k, v in data.items()} + if "gt_answer" in kwargs: + data["gt_answer"] = kwargs["gt_answer"] + if "test_cases" in kwargs: + data["test_cases"] = kwargs["test_cases"] return data def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None: diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py index ef6ef5104..cb8ccb172 100644 --- a/applications/ColossalChat/coati/distributed/launch.py +++ b/applications/ColossalChat/coati/distributed/launch.py @@ -16,7 +16,7 @@ def get_jsonl_size_fast(path: str) -> int: with open(path) as f: lines = f.readlines() lines = [line for line in lines if line.strip()] - return len(lines) - 1 + return len(lines) def get_dp_size_fast(n_procs: int, plugin_config: Dict[str, Any]) -> int: @@ -36,7 +36,6 @@ def launch_distributed( train_batch_size: int, train_minibatch_size: int, train_dataset_config: Dict[str, Any], - dataloaders_config: Dict[str, Any], inference_model_config: Dict[str, Any], generate_config: Dict[str, Any], train_model_config: Dict[str, Any], @@ -121,7 +120,6 @@ def launch_distributed( num_episodes=num_episodes, batch_size=inference_batch_size, train_dataset_config=train_dataset_config, - dataloaders_config=dataloaders_config, model_config=inference_model_config, generate_config=generate_config, tokenizer_config=tokenizer_config, diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py index 47bdc5dd5..9fdaed058 100644 --- a/applications/ColossalChat/coati/distributed/producer.py +++ b/applications/ColossalChat/coati/distributed/producer.py @@ -8,8 +8,8 @@ import ray.util.collective as cc import torch import tqdm import wandb -from coati.dataset.loader import RawConversationDataset -from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn +from coati.dataset.loader import RawConversationDataset, collate_fn_grpo +from coati.distributed.reward.reward_fn import boxed_math_reward_fn, code_reward_fn, math_reward_fn from ray.util.collective import allreduce from ray.util.collective.types import ReduceOp from torch.utils.data import DataLoader, DistributedSampler @@ -34,7 +34,6 @@ class BaseProducer: num_episodes: int, batch_size: int, train_dataset_config: Dict[str, Any], - dataloaders_config: Dict[str, Any], model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer_config: Optional[Dict[str, Any]] = None, @@ -73,6 +72,13 @@ class BaseProducer: self.eval_mode = False self.log_rollout_interval = log_rollout_interval self.latest_rollout_log_step = -1 + self.grpo_config = grpo_config + reward_model_kwargs = { + k: v + for k, v in grpo_config.items() + if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length"] + } + self.response_format_tags = grpo_config.get("response_format_tags", None) if producer_idx == 0: if os.path.exists(rollout_log_file): raise ValueError( 
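As a reading aid for the producer changes above: the keys extracted from `grpo_config` here are exactly the ones the reward functions later consume. A hypothetical config slice with placeholder values (only the key names and the allowed `reward_fn_type` choices are taken from this diff):

```python
# Placeholder values; only the key names come from the BaseProducer changes.
grpo_config = {
    "reward_fn_type": "boxed",            # "think_answer_tags" | "boxed" | "code"
    "max_new_tokens": 4096,               # generation budget checked by the reward functions
    "soft_over_length_punishment": True,  # enable the linear length penalty
    "cache_length": 512,                  # width of the penalty window below max_new_tokens
    "response_format_tags": None,         # optional tag-count constraints
}

reward_model_kwargs = {
    k: v
    for k, v in grpo_config.items()
    if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length"]
}
```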
@@ -118,7 +124,16 @@ class BaseProducer: ), num_workers=4, drop_last=True, + collate_fn=collate_fn_grpo, ) + if grpo_config["reward_fn_type"] == "think_answer_tags": + self.evaluation_function = math_reward_fn + elif grpo_config["reward_fn_type"] == "boxed": + self.evaluation_function = boxed_math_reward_fn + elif grpo_config["reward_fn_type"] == "code": + self.evaluation_function = code_reward_fn + else: + raise ValueError(f"Unknown evaluation function type {grpo_config['reward_fn_type']}") self.eval_dataset_config = eval_dataset_config if self.eval_dataset_config is not None: @@ -140,6 +155,7 @@ class BaseProducer: drop_last=False, seed=42, ), + collate_fn=collate_fn_grpo, ) if evaluation_function_type == "think_answer_tags": self.evaluation_function = math_reward_fn @@ -212,7 +228,13 @@ class BaseProducer: eval_results = eval_results + [ self.evaluation_function( eval_outputs["input_ids"][m][n], - eval_outputs["gt_answer"][m][n], + eval_outputs[ + ( + "test_cases" + if self.grpo_config["reward_fn_type"] == "code" + else "gt_answer" + ) + ][m], eval_outputs["response_idx"][m][n], tokenizer=self.tokenizer, eval_mode=True, @@ -311,7 +333,6 @@ class SimpleProducer(BaseProducer): num_episodes, batch_size, train_dataset_config, - dataloaders_config, model_config, generate_config, tokenizer_config=None, @@ -338,7 +359,6 @@ class SimpleProducer(BaseProducer): num_episodes, batch_size, train_dataset_config, - dataloaders_config, model_config, generate_config, tokenizer_config, diff --git a/applications/ColossalChat/coati/distributed/reward/code_reward/testing_util.py b/applications/ColossalChat/coati/distributed/reward/code_reward/testing_util.py new file mode 100644 index 000000000..c6d6d8fad --- /dev/null +++ b/applications/ColossalChat/coati/distributed/reward/code_reward/testing_util.py @@ -0,0 +1,669 @@ +# Code from the verl Project (https://github.com/agentica-project/rllm), +# which itself is adapted from Prime (https://github.com/PRIME-RL/PRIME) +# +# Copyright 2024 ByteDance Group + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import ast +import faulthandler +import json +import platform + +# to run the solution files we're using a timing based approach +import signal +import sys +import traceback + +# used for debugging to time steps +from datetime import datetime +from enum import Enum + +# for capturing the stdout +from io import StringIO + +# used for testing the code that reads from input +from unittest.mock import mock_open, patch + +import numpy as np +from pyext import RuntimeModule + + +def truncatefn(s, length=300): + assert isinstance(s, str) + if len(s) <= length: + return s + + return s[: length // 2] + "...(truncated) ..." 
+ s[-length // 2 :] + + +class CODE_TYPE(Enum): + call_based = 0 + standard_input = 1 + + +# used to capture stdout as a list +# from https://stackoverflow.com/a/16571630/6416660 +# alternative use redirect_stdout() from contextlib +class Capturing(list): + def __enter__(self): + self._stdout = sys.stdout + sys.stdout = self._stringio = StringIO() + # Make closing the StringIO a no-op + self._stringio.close = lambda x: 1 + return self + + def __exit__(self, *args): + self.append(self._stringio.getvalue()) + del self._stringio # free up some memory + sys.stdout = self._stdout + + +def only_int_check(val): + return isinstance(val, int) + + +def string_int_check(val): + return isinstance(val, str) and val.isdigit() + + +def combined_int_check(val): + return only_int_check(val) or string_int_check(val) + + +def clean_traceback(error_traceback): + file_start = error_traceback.find('File ""') + # print(file_start) + error_traceback = "Traceback (most recent call last):\n " + error_traceback[file_start:] + return error_traceback + + +def run_test(in_outs, test=None, debug=False, timeout=15): + """ + if test(generated_code) is not None it'll try to run the code. + otherwise it'll just return an input and output pair. + """ + # Disable functionalities that can make destructive changes to the test. + reliability_guard() + + if debug: + print(f"start = {datetime.now().time()}") + + if in_outs: + if in_outs.get("fn_name") is None: + which_type = CODE_TYPE.standard_input # Standard input + method_name = None + else: + which_type = CODE_TYPE.call_based # Call-based + method_name = in_outs["fn_name"] + + if debug: + print(f"loaded input_output = {datetime.now().time()}") + + if test is None: + raise AssertionError("should not happen: test code is none") + elif test is not None: + results = [] + sol = "from string import *\nfrom re import *\nfrom datetime import *\nfrom collections import *\nfrom heapq import *\nfrom bisect import *\nfrom copy import *\nfrom math import *\nfrom random import *\nfrom statistics import *\nfrom itertools import *\nfrom functools import *\nfrom operator import *\nfrom io import *\nfrom sys import *\nfrom json import *\nfrom builtins import *\nfrom typing import *\nimport string\nimport re\nimport datetime\nimport collections\nimport heapq\nimport bisect\nimport copy\nimport math\nimport random\nimport statistics\nimport itertools\nimport functools\nimport operator\nimport io\nimport sys\nimport json\nsys.setrecursionlimit(6*10**5)\n" # noqa: E501 + if debug: + print(f"loading test code = {datetime.now().time()}") + + if which_type == CODE_TYPE.call_based: + sol += test + if debug: + print(f"sol = {sol}") + signal.alarm(timeout) + try: + tmp_sol = RuntimeModule.from_string("tmp_sol", "", sol) + tmp = tmp_sol if "class Solution" not in test else tmp_sol.Solution() + signal.alarm(0) + except Exception as e: + signal.alarm(0) + error_traceback = traceback.format_exc() + if debug: + print(f"type 0 compilation error = {e}") + results.append(-2) + return results, { + "error": repr(e), + # "error_code": -1, + # "error_message": "Compilation Error", + "traceback": clean_traceback(error_traceback), + } + signal.alarm(0) + + elif which_type == CODE_TYPE.standard_input: + # sol + # if code has if __name__ == "__main__": then remove it + try: + astree = ast.parse(test) + last_block = astree.body[-1] + if isinstance(last_block, ast.If): + condition = last_block.test + if ast.unparse(condition).strip() == "__name__ == '__main__'": + test = ast.unparse(astree.body[:-1]) + "\n" + 
ast.unparse(last_block.body) + except Exception: + pass + + tmp_test = test.split("\n") + + new_test = [] + for x in tmp_test: + if (not x.startswith("from ")) and (not x.startswith("import ")): + new_test.append("\t" + x + "\n") + else: + new_test.append(x + "\n") + tmp_test = new_test + + new_test = "" + started = False + for i in tmp_test: + if i.startswith("\t") and not started: + new_test += "stdin = sys.stdin\nstdout = sys.stdout\n" + new_test += "def code():\n" + new_test += i + started = True + elif started and ((i.startswith("from ")) or (i.startswith("import "))): + new_test += "\t" + i + else: + new_test += i + tmp_test = new_test + + sol += tmp_test + if debug: + print(f"sol = {sol}") + method_name = "code" + signal.alarm(timeout) + try: + tmp_sol = RuntimeModule.from_string("tmp_sol", "", sol) + tmp = tmp_sol + signal.alarm(0) + except Exception as e: + signal.alarm(0) + error_traceback = traceback.format_exc() + if debug: + print(f"type 1 compilation error = {e}") + results.append(-2) + return results, { + "error": repr(e), + # "error_code": -1, + # "error_message": "Compilation Error", + "traceback": clean_traceback(error_traceback), + } + signal.alarm(0) + if debug: + print(f"get method = {datetime.now().time()}") + + try: + method = getattr(tmp, method_name) # get_attr second arg must be str + except Exception: + signal.alarm(0) + error_traceback = traceback.format_exc() + error_info = sys.exc_info() + print(f"unable to get function error = {error_info}") + results.append(-2) + return results, { + "error": repr(error_info), + # "error_code": -1, + # "error_message": "Unable to extract code", + "traceback": clean_traceback(error_traceback), + } + + for index, inputs in enumerate(in_outs["inputs"]): + raw_inputs = inputs + raw_outputs = in_outs["outputs"][index] + if which_type == CODE_TYPE.call_based: + inputs = [json.loads(line) for line in inputs.split("\n")] + in_outs["outputs"][index] = json.loads(in_outs["outputs"][index]) + + truncate_line_size = 300 // (raw_inputs.count("\n") + 1) + raw_inputs = "\n".join( + [truncatefn(line, truncate_line_size) for line in raw_inputs.strip().split("\n")] + ) + raw_outputs = truncatefn(raw_outputs, 200) + else: + raw_inputs = truncatefn(raw_inputs) + raw_outputs = truncatefn(raw_outputs, 200) + # JSON forces dictionaries to have string keys; this undoes this (assuming a singleton list) + try: + if isinstance(inputs[0], dict): + inputs = [{int(k): v for k, v in inputs[0].items()}] + except Exception: + pass + try: + if isinstance(in_outs["outputs"][index], dict): + in_outs["outputs"][index] = [{int(k): v for k, v in in_outs["outputs"][index].items()}] + except Exception: + pass + try: + if isinstance(in_outs["outputs"][index][0], dict): + in_outs["outputs"][index] = [{int(k): v for k, v in in_outs["outputs"][index][0].items()}] + except Exception: + pass + + if debug: + print( + f"time: {datetime.now().time()} testing index = {index} inputs = {inputs}, {type(inputs)}. 
type = {which_type}" + ) + if which_type == CODE_TYPE.call_based: # Call-based + signal.alarm(timeout) + faulthandler.enable() + try: + output = method(*inputs) + raw_true_output = output + + raw_true_output_copy = json.dumps(output) + raw_true_output_copy = truncatefn(raw_true_output_copy, 200) + + # ground truth sequences are not tuples + if isinstance(output, tuple): + output = list(output) + + tmp_result = output == in_outs["outputs"][index] + if isinstance(in_outs["outputs"][index], list) and in_outs["outputs"][index]: + tmp_result = tmp_result or (output == in_outs["outputs"][index][0]) + + # ground truth sequences are not tuples + try: + if isinstance(output[0], tuple): + tmp_result = tmp_result or ([list(x) for x in output] == in_outs["outputs"][index][0]) + except Exception: + pass + results.append(tmp_result) + if tmp_result is not True: + return results, { + "output": raw_true_output_copy, + "expected": raw_outputs, + "inputs": raw_inputs, + # "error_code": -2, + "error_message": "Wrong Answer", + } + # reset the alarm + signal.alarm(0) + except Exception as e: + signal.alarm(0) + error_traceback = traceback.format_exc() + faulthandler.disable() + if debug: + print(f"Standard input runtime error or time limit exceeded error = {e}") + results.append(-1) + return results, { + "error": repr(e), + "traceback": clean_traceback(error_traceback), + } + faulthandler.disable() + signal.alarm(0) + if debug: + print( + f"outputs = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}" + ) + elif which_type == CODE_TYPE.standard_input: # Standard input + faulthandler.enable() + passed = False + + if isinstance(inputs, list): + inputs = "\n".join(inputs) + if isinstance(in_outs["outputs"][index], list): + in_outs["outputs"][index] = "\n".join(in_outs["outputs"][index]) + + signal.alarm(timeout) + with Capturing() as output: + try: + call_method(method, inputs) + # reset the alarm + signal.alarm(0) + passed = True + except Exception as e: + # runtime error or took too long + signal.alarm(0) + error_traceback = traceback.format_exc() + print(f"Call-based runtime error or time limit exceeded error = {repr(e)}{e}") + results.append(-1) + return results, { + "error": repr(e), + "traceback": clean_traceback(error_traceback), + } + signal.alarm(0) + raw_true_output = output[0] + raw_true_output_copy = truncatefn(raw_true_output, 200) + output = raw_true_output.splitlines() + if not passed: + if debug: + nl = "\n" + if not isinstance(inputs, list): + print( + f"not passed output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl, ' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}" + ) + else: + print( + f"not passed output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}" + ) + continue + + if passed and debug: + print(f"==> output = {output}, test outputs = {in_outs['outputs'][index]}") + + if custom_compare_(output, in_outs["outputs"][index]): + tmp_result = True + results.append(tmp_result) + continue + + # ground truth sequences are expressed as lists not tuples + if isinstance(output, tuple): + output = list(output) + + tmp_result = False + try: + tmp_result = output == [in_outs["outputs"][index]] + if isinstance(in_outs["outputs"][index], list): + tmp_result = tmp_result or (output == in_outs["outputs"][index]) + if isinstance(output[0], str): + tmp_result = tmp_result or ([e.strip() 
for e in output] == in_outs["outputs"][index]) + except Exception as e: + if debug: + print(f"Failed check1 exception = {e}") + + if tmp_result is True: + results.append(tmp_result) + continue + + # try one more time without \n + if isinstance(in_outs["outputs"][index], list): + for tmp_index, i in enumerate(in_outs["outputs"][index]): + in_outs["outputs"][index][tmp_index] = i.split("\n") + in_outs["outputs"][index][tmp_index] = [ + x.strip() for x in in_outs["outputs"][index][tmp_index] if x + ] + else: + in_outs["outputs"][index] = in_outs["outputs"][index].split("\n") + in_outs["outputs"][index] = list(filter(len, in_outs["outputs"][index])) + in_outs["outputs"][index] = list(map(lambda x: x.strip(), in_outs["outputs"][index])) + + try: + tmp_result = output == [in_outs["outputs"][index]] + if isinstance(in_outs["outputs"][index], list): + tmp_result = tmp_result or (output == in_outs["outputs"][index]) + except Exception as e: + if debug: + print(f"Failed check2 exception = {e}") + + if tmp_result is True: + results.append(tmp_result) + continue + + # try by converting the output into a split up list too + if isinstance(output, list): + output = list(filter(len, output)) + + if debug: + nl = "\n" + if not isinstance(inputs, list): + print( + f"@1 output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl, ' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]} {tmp_result=}" + ) + else: + print( + f"@1 output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]} {tmp_result=}" + ) + + if debug: + print(f"{tmp_result=} @a") + + try: + tmp_result = output == [in_outs["outputs"][index]] + if isinstance(in_outs["outputs"][index], list): + tmp_result = tmp_result or (output == in_outs["outputs"][index]) + except Exception as e: + if debug: + print(f"Failed check3 exception = {e}") + + if debug: + print(f"{tmp_result=} @b") + + try: + all_ints = all( + combined_int_check(e1) and combined_int_check(e2) + for e1, e2 in zip(output, in_outs["outputs"][index]) + ) + if not all_ints: + if debug: + print( + [ + combined_int_check(e1) and combined_int_check(e2) + for e1, e2 in zip(output, in_outs["outputs"][index]) + ] + ) + output_float = [float(e) for e in output] + gt_float = [float(e) for e in in_outs["outputs"][index]] + tmp_result = tmp_result or ( + (len(output_float) == len(gt_float)) and np.allclose(output_float, gt_float) + ) + except Exception: + pass + + if debug: + print(f"{tmp_result=} @c") + + try: + if isinstance(output[0], list): + all_ints = all( + combined_int_check(e1) and combined_int_check(e2) + for e1, e2 in zip(output[0], in_outs["outputs"][index]) + ) + if not all_ints: + output_float = [float(e) for e in output[0]] + gt_float = [float(e) for e in in_outs["outputs"][index][0]] + tmp_result = tmp_result or ( + (len(output_float) == len(gt_float)) and np.allclose(output_float, gt_float) + ) + except Exception: + pass + + if tmp_result is True: + results.append(tmp_result) + continue + + if debug: + print(f"{tmp_result=} @d") + # try by converting the stuff into split up list + if isinstance(in_outs["outputs"][index], list): + for tmp_index, i in enumerate(in_outs["outputs"][index]): + in_outs["outputs"][index][tmp_index] = set(i.split()) + else: + in_outs["outputs"][index] = set(in_outs["outputs"][index].split()) + + if debug: + print(f"{tmp_result=} @e") + + try: + tmp_result = output == in_outs["outputs"][index] + except Exception as e: + if debug: + 
print(f"Failed check4 exception = {e}") + continue + + if tmp_result is True: + results.append(tmp_result) + continue + + if debug: + print(f"{tmp_result=} @f") + + # try by converting the output into a split up list too + if isinstance(output, list): + for tmp_index, i in enumerate(output): + output[tmp_index] = i.split() + output = list(filter(len, output)) + for tmp_index, i in enumerate(output): + output[tmp_index] = set(i) + else: + output = output.split() + output = list(filter(len, output)) + output = set(output) + + if debug: + print(f"{tmp_result=} @g") + + if tmp_result is True and debug: + print("PASSED") + + results.append(tmp_result) + if tmp_result is not True: + return results, { + "output": raw_true_output_copy, + "expected": raw_outputs, + "inputs": raw_inputs, + # "error_code": -2, + "error_message": "Wrong Answer", + } + + if debug: + nl = "\n" + if not isinstance(inputs, list): + print( + f"@2 output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl, ' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}" + ) + else: + print( + f"@2 output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}" + ) + + print(f"results = {results}") + + return results, {} + + +def custom_compare_(output, ground_truth): + if isinstance(output, list): + output_1 = "\n".join(output) + if stripped_string_compare(output_1, ground_truth): + return True + + if isinstance(output, list): + output_2 = [o.lstrip().rstrip() for o in output] + output_2 = "\n".join(output_2) + if stripped_string_compare(output_2, ground_truth): + return True + + return False + + +def stripped_string_compare(s1, s2): + s1 = s1.lstrip().rstrip() + s2 = s2.lstrip().rstrip() + return s1 == s2 + + +def call_method(method, inputs): + if isinstance(inputs, list): + inputs = "\n".join(inputs) + + inputs_line_iterator = iter(inputs.split("\n")) + + # sys.setrecursionlimit(10000) + + # @patch('builtins.input', side_effect=inputs.split("\n")) + @patch("builtins.open", mock_open(read_data=inputs)) + @patch("sys.stdin", StringIO(inputs)) + @patch("sys.stdin.readline", lambda *args: next(inputs_line_iterator)) + @patch("sys.stdin.readlines", lambda *args: inputs.split("\n")) + @patch("sys.stdin.read", lambda *args: inputs) + # @patch('sys.stdout.write', print) + def _inner_call_method(_method): + try: + return _method() + except SystemExit: + pass + finally: + pass + + return _inner_call_method(method) + + +def reliability_guard(maximum_memory_bytes=None): + """ + This disables various destructive functions and prevents the generated code + from interfering with the test (e.g. fork bomb, killing other processes, + removing filesystem files, etc.) + WARNING + This function is NOT a security sandbox. Untrusted code, including, model- + generated code, should not be blindly executed outside of one. See the + Codex paper for more information about OpenAI's code sandbox, and proceed + with caution. 
+ """ + + if maximum_memory_bytes is not None: + import resource + + resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes)) + resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes)) + if platform.uname().system != "Darwin": + resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes)) + + faulthandler.disable() + + import builtins + + builtins.exit = None + builtins.quit = None + + import os + + os.environ["OMP_NUM_THREADS"] = "1" + + os.kill = None + os.system = None # 防止干扰repl评测 + os.putenv = None + os.remove = None + os.removedirs = None + os.rmdir = None + os.fchdir = None + os.setuid = None + os.fork = None + os.forkpty = None + os.killpg = None + os.rename = None + os.renames = None + os.truncate = None + os.replace = None + os.unlink = None + os.fchmod = None + os.fchown = None + os.chmod = None + os.chown = None + os.chroot = None + os.lchflags = None + os.lchmod = None + os.lchown = None + os.getcwd = None + os.chdir = None + + import shutil + + shutil.rmtree = None + shutil.move = None + shutil.chown = None + + import subprocess + + subprocess.Popen = None # type: ignore + + __builtins__["help"] = None + + import sys + + sys.modules["ipdb"] = None + sys.modules["joblib"] = None + sys.modules["resource"] = None + sys.modules["psutil"] = None + sys.modules["tkinter"] = None diff --git a/applications/ColossalChat/coati/distributed/reward/code_reward/utils.py b/applications/ColossalChat/coati/distributed/reward/code_reward/utils.py new file mode 100644 index 000000000..e19e9e387 --- /dev/null +++ b/applications/ColossalChat/coati/distributed/reward/code_reward/utils.py @@ -0,0 +1,61 @@ +# Code from the verl Project (https://github.com/agentica-project/rllm), +# which itself is adapted from Prime (https://github.com/PRIME-RL/PRIME) +# +# Copyright 2024 ByteDance Group + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import multiprocessing +import os +import sys +import traceback +from typing import Optional + +from .testing_util import run_test + + +def _temp_run(sample, generation, debug, result, metadata_list, timeout): + with open(os.devnull, "w") as devnull: + sys.stdout = devnull + sys.stderr = devnull + try: + res, metadata = run_test(in_outs=sample, test=generation, debug=debug, timeout=timeout) + result.append(res) + metadata_list.append(metadata) + except Exception: + # print(e) # some tracebacks are extremely long. + traceback.print_exc(10) + result.append([-1 for i in range(len(sample["inputs"]))]) + metadata_list.append({}) + + +def check_correctness(in_outs: Optional[dict], generation, timeout=10, debug=True): + """Check correctness of code generation with a global timeout. 
+ The global timeout is to catch some extreme/rare cases not handled by the timeouts + inside `run_test`""" + + manager = multiprocessing.Manager() + result = manager.list() + metadata_list = manager.list() + p = multiprocessing.Process(target=_temp_run, args=(in_outs, generation, debug, result, metadata_list, timeout)) + p.start() + p.join(timeout=timeout + 1) + if p.is_alive(): + p.kill() + # p.terminate() + if not result: + # consider that all tests failed + result = [[-1 for i in range(len(in_outs["inputs"]))]] + if debug: + print("global timeout") + return result[0], metadata_list diff --git a/applications/ColossalChat/coati/distributed/reward/reward_fn.py b/applications/ColossalChat/coati/distributed/reward/reward_fn.py index a4042ae97..5c8810fdf 100644 --- a/applications/ColossalChat/coati/distributed/reward/reward_fn.py +++ b/applications/ColossalChat/coati/distributed/reward/reward_fn.py @@ -1,7 +1,30 @@ +# Copyright 2024 ByteDance Group + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Some functions in this file are adapted from the verl project +under the Apache License 2.0: +https://github.com/volcengine/verl +""" + + +import json + import torch from latex2sympy2_extended import NormalizationConfig from math_verify import ExprExtractionConfig, LatexExtractionConfig, parse, verify +from .code_reward.utils import check_correctness as check_correctness_code from .reward_utils import extract_boxed_solution, extract_solution, validate_response_structure CANNOT_PARSE_GT_ANSWER = -1 @@ -98,15 +121,11 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs): raise ValueError("no gt_answer is provided, please check your training dataset.") decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True) - gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True) + final_answer, processed_str = extract_solution(decoded_final_answer) format_valid = validate_response_structure(processed_str, kwargs["tags"]) - # Check format accuracy - if format_valid: - format_acc += 1 - # Check answer accuracy, answer is considered correct if the answer is correct and the format is valid if final_answer is not None: if eval_mode or format_valid: @@ -114,6 +133,10 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs): if not eval_mode: reward = reward + length_reward + # Check format accuracy + if format_valid: + format_acc += 1 + # Check if the sequence is over length if not eval_mode and res_length >= max_new_tokens: reward *= 0.0 @@ -138,7 +161,6 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs): tokenizer = kwargs["tokenizer"] eval_mode = kwargs.get("eval_mode", False) soft_over_length_punishment = kwargs.get("soft_over_length_punishment", False) - format_score = 0.0 acc_score = 10.0 reward = torch.tensor(0.0) format_acc = torch.tensor(0.0) @@ -161,7 +183,6 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs): decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], 
skip_special_tokens=True) - gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True) final_answer = extract_boxed_solution(decoded_final_answer) format_valid = final_answer is not None if "tags" in kwargs and kwargs["tags"]: @@ -169,10 +190,6 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs): format_valid = format_valid and all( [decoded_final_answer.count(tags[tag]["text"]) == tags[tag]["num_occur"] for tag in tags] ) - # Check format accuracy - if format_valid: - format_acc += 1 - reward += format_score # Check answer accuracy, answer is considered correct if the answer is correct and the format is valid if final_answer is not None: @@ -181,6 +198,10 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs): if not eval_mode: reward = reward + length_reward + # Check format accuracy + if format_valid: + format_acc += 1 + # Check if the sequence is over length if not eval_mode and res_length >= max_new_tokens: reward *= 0.0 @@ -199,3 +220,78 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs): "response_length": res_length, "reward": reward.item(), } + + +def code_reward_fn(input_ids, test_cases, response_idx, **kwargs): + tokenizer = kwargs["tokenizer"] + eval_mode = kwargs.get("eval_mode", False) + soft_over_length_punishment = kwargs.get("soft_over_length_punishment", False) + acc_score = 10.0 + reward = torch.tensor(0.0) + format_acc = torch.tensor(0.0) + ans_acc = torch.tensor(0.0) + s, e = response_idx[0], response_idx[1] + + length_reward = 0.0 + res_length = e.item() - s.item() + 1 + if not eval_mode: + max_new_tokens = kwargs["max_new_tokens"] + else: + max_new_tokens = -1 # for eval mode, we don't need to check the length + if not eval_mode and soft_over_length_punishment: + cache_length = kwargs["cache_length"] + if max_new_tokens - cache_length < res_length < max_new_tokens: + length_reward = ((max_new_tokens - cache_length) - res_length) / cache_length * acc_score + + # try to get code solution from completion. if the completion is pure code, this will not take effect. + decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True) + + solution = decoded_final_answer.split("```python")[-1].split("```")[0] + format_valid = False + if "```python" in decoded_final_answer: + format_valid = solution is not None + + # Check format accuracy + if format_valid: + format_acc += 1 + + try: + try: + if not isinstance(test_cases, dict): + test_cases = json.loads(test_cases) + except Exception as e: + print(f"Error {e}: Cannot parse test cases.") + raise e + # Complete check on all in-out pairs first. If there is no failure, per-sample test can be skipped. 
+ try: + res, metadata = check_correctness_code(in_outs=test_cases, generation=solution, timeout=10, debug=True) + metadata = dict(enumerate(metadata))[0] + success = all(map(lambda x: x is True, res)) + if success: + ans_acc += 1 + if eval_mode or format_valid: + reward += acc_score + if not eval_mode: + reward = reward + length_reward + except Exception: + pass + + # Check if the sequence is over length + if not eval_mode and res_length >= max_new_tokens: + reward *= 0.0 + except Exception: + pass + if not eval_mode: + return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device) + else: + prompt = tokenizer.decode(input_ids[:s], skip_special_tokens=True) + return { + "prompt": prompt, + "prediction": decoded_final_answer, + "gold": test_cases["outputs"], + "parsed": solution, + "format_valid": format_acc.item(), + "ans_valid": ans_acc.item(), + "response_length": res_length, + "reward": reward.item(), + } diff --git a/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py b/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py index ba83f7787..2f0b3778c 100644 --- a/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py +++ b/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py @@ -2,6 +2,7 @@ Function-based reward verification module. """ +import inspect from typing import Any, Dict, List import torch @@ -15,7 +16,8 @@ class VerifiableReward: def __call__( self, input_ids: torch.LongTensor, - gt_answer: List[torch.Tensor] = None, + gt_answer: List[str] = None, + test_cases: List[str] = None, response_idx: List[torch.Tensor] = None, ) -> torch.Tensor: # Get batch size @@ -26,18 +28,44 @@ class VerifiableReward: # Loop through reward functions for reward_fn in self.reward_fns: # Apply the reward function to the entire batch at once - reward_batch = torch.stack( - [ - reward_fn( - input_ids[i], - gt_answer=gt_answer[i], - response_idx=response_idx[i], - **self.kwargs, - ) - for i in range(bs) - ], - dim=0, - ) + if "gt_answer" in inspect.getfullargspec(reward_fn).args: + reward_batch = torch.stack( + [ + reward_fn( + input_ids[i], + gt_answer=gt_answer[i], + response_idx=response_idx[i], + **self.kwargs, + ) + for i in range(bs) + ], + dim=0, + ) + elif "test_cases" in inspect.getfullargspec(reward_fn).args: + reward_batch = torch.stack( + [ + reward_fn( + input_ids[i], + test_cases=test_cases[i], + response_idx=response_idx[i], + **self.kwargs, + ) + for i in range(bs) + ], + dim=0, + ) + else: + reward_batch = torch.stack( + [ + reward_fn( + input_ids[i], + response_idx=response_idx[i], + **self.kwargs, + ) + for i in range(bs) + ], + dim=0, + ) rewards += reward_batch return rewards diff --git a/applications/ColossalChat/coati/distributed/utils.py b/applications/ColossalChat/coati/distributed/utils.py index a40ebbcfb..d46243114 100644 --- a/applications/ColossalChat/coati/distributed/utils.py +++ b/applications/ColossalChat/coati/distributed/utils.py @@ -71,31 +71,43 @@ def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.T return per_label_logps.squeeze(-1) -def calc_action_log_probs( +def memory_efficient_logprob( logits: torch.Tensor, - sequences: torch.LongTensor, - num_actions: int, - shard_config, + inputs: torch.Tensor, + num_action: int, + chunk_size: int = 2048, + shard_config: Any = None, vocab_size: int = None, ) -> torch.Tensor: - """Calculate action log probs. - + """ + Calculate action log probs in a memory-efficient way by processing in chunks. 
Args: logits (torch.Tensor): Output tensor of Actor.forward.logits. - sequences (torch.LongTensor): Input sequences. - num_actions (int): Number of actions. - shard_config - vocab_size - - + inputs (torch.LongTensor): Input sequences. + num_action (int): Number of actions. + chunk_size (int, optional): Size of each chunk to process. Default is 2048. + shard_config: Shard configuration for distributed computation. + vocab_size (int, optional): Vocabulary size. Default is None. Returns: torch.Tensor: Action log probs. """ - # labels: torch.Tensor, # [B, S] or [B, S, Vocab_size] - # logits: torch.Tensor, # [B, S, Vocab_size] - log_probs = dist_log_prob(sequences, logits, shard_config, vocab_size, logits.dtype) - log_probs = log_probs.squeeze(-1) - return log_probs[:, -num_actions:] + action_log_probs = torch.zeros((logits.size(0), num_action), device=logits.device, dtype=logits.dtype) + context_length = logits.size(1) - num_action + for i in range(action_log_probs.size(0)): + # loop over each sample in the micro-batch + for start in range(context_length, logits.size(1), chunk_size): + end = min(start + chunk_size, logits.size(1)) + # calculate log probs in chunks to save memory + log_probs = dist_log_prob( + inputs[i : i + 1, start - 1 : end], + logits[i : i + 1, start - 1 : end], + shard_config, + vocab_size, + logits.dtype, + ) # [1, chunk_size, 1] + log_probs = log_probs.squeeze(-1) + action_log_probs[i, start - context_length : end - context_length] += log_probs[0] + return action_log_probs def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor: diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index 0b7bde6b0..9f6e895e4 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -6,6 +6,12 @@ import ray import torch from coati.distributed.launch import launch_distributed +DEFAUT_SYSTEM_PROMPT = { + "think_answer_tags": "You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here answer here . Now the user asks you to solve a math problem that involves reasoning. 
After thinking, when you finally reach a conclusion, clearly output the final answer without explanation within the tags, i.e., 123 .\n\n", + "boxed": "Please reason step by step, and put your final answer within \\boxed{}.", + "code": "You are a helpful assistant.", +} + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B") @@ -101,7 +107,7 @@ if __name__ == "__main__": "--reward-type", type=str, default="think_answer_tags", - choices=["think_answer_tags", "boxed"], + choices=["think_answer_tags", "boxed", "code"], help="Reward type for GRPO.", ) parser.add_argument( @@ -167,10 +173,29 @@ if __name__ == "__main__": if args.master_address is None: # Default settings: Using single machine - ray.init(address="local", namespace="ray-example") + ray.init( + address="local", + namespace="ray-example", + runtime_env={ + "env_vars": { + # "RAY_DEBUG_POST_MORTEM": "1" # enable post-mortem debugging with ray + "TOKENIZERS_PARALLELISM": "false" + }, + }, + ) else: # For ray distributed multi-machine training, Please change _node_ip_address to your IP address of your master node - ray.init(_node_ip_address=args.master_address, namespace="ray-example", _temp_dir=args.ray_dir) + ray.init( + _node_ip_address=args.master_address, + namespace="ray-example", + _temp_dir=args.ray_dir, + runtime_env={ + "env_vars": { + # "RAY_DEBUG_POST_MORTEM": "1" # enable post-mortem debugging with ray + "TOKENIZERS_PARALLELISM": "false" + }, + }, + ) if args.top_k is None: if args.backend == "transformers": @@ -178,6 +203,8 @@ if __name__ == "__main__": elif args.backend == "vllm": args.top_k = -1 + os.environ["TOKENIZERS_PARALLELISM"] = "false" # Disable tokenizers parallelism to avoid deadlock + inference_model_config = dict(path=args.model) train_model_config = dict( path=args.model, use_flash_attention_2=False, use_cache=False, attn_implementation="eager" @@ -276,6 +303,10 @@ if __name__ == "__main__": else: raise ValueError(f"Unsupported algorithm: {args.algo}") + if args.system_prompt is None: + # Default system prompt + args.system_prompt = DEFAUT_SYSTEM_PROMPT[args.reward_type] + launch_distributed( num_producers=args.num_inferencer, num_proc_per_producer=inference_model_config.get("tensor_parallel_size", args.producer_tensor_parallel_size), @@ -290,7 +321,6 @@ if __name__ == "__main__": "max_length": args.max_prompt_tokens, "system_prompt": args.system_prompt, }, - dataloaders_config={}, inference_model_config=inference_model_config, generate_config=generate_config, num_generations=args.num_generations, diff --git a/colossalai/shardformer/policies/qwen2.py b/colossalai/shardformer/policies/qwen2.py index 823527df6..d2798a23e 100644 --- a/colossalai/shardformer/policies/qwen2.py +++ b/colossalai/shardformer/policies/qwen2.py @@ -302,6 +302,17 @@ class Qwen2Policy(Policy): ) if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism: + if not self.shard_config.enable_tensor_parallelism: + decoder_attribute_replacement = { + "self_attn.hidden_size": self.model.config.hidden_size, + "self_attn.num_heads": self.model.config.num_attention_heads, + } + if getattr(self.model.config, "num_key_value_heads", False): + decoder_attribute_replacement["self_attn.num_key_value_heads"] = ( + self.model.config.num_key_value_heads + ) + policy[Qwen2DecoderLayer] = ModulePolicyDescription(attribute_replacement=decoder_attribute_replacement) + self.append_or_create_method_replacement( description={ "forward": 
get_qwen2_flash_attention_npu_forward(self.shard_config, sp_mode, sp_size, sp_group),
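Stepping back to the code-reward path added earlier in this diff, a hedged usage sketch for the vendored checker under `coati/distributed/reward/code_reward/` (the function name and the `in_outs` shape come from `utils.py`/`testing_util.py`; the sample problem, solutions, and expected results are made up, and the `pyext` dependency used by `testing_util.py` must be installed):

```python
# Hypothetical example; result values follow run_test's convention
# (True = pass, -1 = runtime error/timeout, -2 = compilation error).
from coati.distributed.reward.code_reward.utils import check_correctness

if __name__ == "__main__":
    # Standard-input style (no "fn_name"): the generated program reads stdin and prints.
    stdin_tests = {"inputs": ["3\n"], "outputs": ["6\n"]}
    stdin_solution = "n = int(input())\nprint(n * 2)"
    res, meta = check_correctness(in_outs=stdin_tests, generation=stdin_solution, timeout=10, debug=False)
    print(res)  # expected: [True]

    # Call-based style: "fn_name" makes run_test call a function; each line of an
    # input string is JSON-decoded into one positional argument.
    fn_tests = {"fn_name": "double", "inputs": ["3"], "outputs": ["6"]}
    fn_solution = "def double(n):\n    return n * 2"
    res, meta = check_correctness(in_outs=fn_tests, generation=fn_solution, timeout=10, debug=False)
    print(res)  # expected: [True]
```

This mirrors what `code_reward_fn` does after extracting the fenced Python block from a completion and parsing the sample's `test_cases` JSON.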