[fix] revert reward update and evaluation (#6295)

* Revert "rewrite reward fn"

This reverts commit d06042b434.

* Revert "upgrade reward math verification"

This reverts commit a6085ff676.

* Revert "fix bug"

This reverts commit 01640ebd65.

* Revert "reuse comm-group"

This reverts commit bd61918dcf.

* Revert "Support evaluation during training"

This reverts commit 57a88395fe.
YeAnbang 2025-05-07 10:56:47 +08:00 committed by GitHub
parent 17928ad84f
commit eb6b5dd62e
9 changed files with 82 additions and 307 deletions

.gitignore

@@ -165,4 +165,3 @@ applications/ColossalChat/logs
 applications/ColossalChat/tests/logs
 applications/ColossalChat/wandb
 applications/ColossalChat/model
-applications/ColossalChat/eval


@@ -36,7 +36,6 @@ class BaseConsumer:
         minibatch_size: int = 1,
         save_interval: int = 100,
         save_dir: str = "./model",
-        eval_interval: int = -1,
     ):
         self.num_producers = num_producers
         self.num_episodes = num_episodes
@@ -52,7 +51,6 @@ class BaseConsumer:
         self.save_dir = save_dir
         assert batch_size % minibatch_size == 0, "batch_size should be divisible by microbatch_size"
         self.num_microbatches = batch_size // minibatch_size
-        self.eval_interval = eval_interval

         self.model_config = model_config
         self.plugin_config = plugin_config
@@ -95,6 +93,7 @@ class BaseConsumer:
         cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")

         self.buffer = []
+
         self.recv_cnt = 0

     def state_dict(self) -> Dict[str, torch.Tensor]:
@@ -111,27 +110,6 @@ class BaseConsumer:
             with tqdm(range(self.num_update_per_episode), desc=f"Episode {episode}", disable=self.rank != 0) as pbar:
                 for step in pbar:
                     i = 0
-                    if self.eval_interval > 0 and step % self.eval_interval == 0:
-                        eval_statistics = None
-                        eval_global_step = None
-                        for r in range(self.num_producers):
-                            print(f"[T{dist.get_rank()}] Recv eval result episode {episode} step {step} from {r}")
-                            local_eval_result = ray_broadcast_tensor_dict(
-                                None, src=0, device=self.device, group_name=f"sync_data_{r}"
-                            )
-                            assert "consumer_global_step" in local_eval_result
-                            eval_global_step = local_eval_result.pop("consumer_global_step").item()
-                            if eval_statistics is None:
-                                eval_statistics = local_eval_result
-                            else:
-                                eval_statistics = {
-                                    k: eval_statistics[k] + local_eval_result[k] for k in eval_statistics
-                                }
-                        eval_statistics = {"eval/" + k: (v[0] / v[1]).item() for k, v in eval_statistics.items()}
-                        if dist.get_rank() == 0:
-                            if hasattr(self, "wandb_run"):
-                                self.wandb_run.log(eval_statistics, step=eval_global_step)
-                            print(f"Eval statistics: {eval_statistics}")
                     for _ in range(self.num_recv_per_update):
                         # receive data from producers
                         for r in range(self.num_producers):
@@ -217,7 +195,6 @@ class SimpleConsumer(BaseConsumer):
         minibatch_size=1,
         save_interval: int = 100,
         save_dir="./model",
-        eval_interval: int = -1,
     ):
         super().__init__(
             num_producers,
@@ -232,9 +209,6 @@ class SimpleConsumer(BaseConsumer):
             model_config,
             plugin_config,
             minibatch_size,
-            save_interval,
-            save_dir,
-            eval_interval,
         )
         path = model_config.pop("path")
         self.model = AutoModelForCausalLM.from_pretrained(path, **model_config)
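
Note on the block removed above: each producer reported a per-task tensor of [num_correct, num_total], and the consumer summed those element-wise across producers before logging correct/total as the task accuracy. A minimal standalone sketch of that aggregation (not the reverted code itself; names are illustrative):

import torch

def aggregate_eval(results_per_producer):
    # results_per_producer: list of {task_name: tensor([num_correct, num_total])}, one dict per producer
    totals = None
    for local in results_per_producer:
        if totals is None:
            totals = {k: v.clone() for k, v in local.items()}
        else:
            totals = {k: totals[k] + local[k] for k in totals}
    # accuracy per task = summed correct count / summed sample count
    return {"eval/" + k: (v[0] / v[1]).item() for k, v in totals.items()}

print(aggregate_eval([
    {"math": torch.tensor([3.0, 8.0])},
    {"math": torch.tensor([5.0, 8.0])},
]))  # {'eval/math': 0.5}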


@@ -40,7 +40,6 @@ class GRPOConsumer(BaseConsumer):
         project_name=None,
         save_interval: int = 100,
         save_dir="./model",
-        eval_interval: int = -1,
     ):
         print(f"Using GRPO config: {grpo_config}")
         if grpo_config.get("loss_variation", "sample_level") == "token_level":
@@ -73,7 +72,6 @@ class GRPOConsumer(BaseConsumer):
             minibatch_size,
             save_interval=save_interval,
             save_dir=save_dir,
-            eval_interval=eval_interval,
         )
         path = model_config.pop("path")
         self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@@ -530,5 +528,4 @@ class GRPOConsumer(BaseConsumer):
         self.policy_model._force_wait_all_gather()
         model = self.policy_model.unwrap()
         state_dict = model.state_dict()
-        state_dict["consumer_global_step"] = torch.tensor([self.global_step], device=self.device)
         return state_dict


@@ -205,8 +205,7 @@ class VLLMInferenceBackend(BaseInferenceBackend):
         generate_config = generate_config.copy()
         generate_config.update(self.FORCE_GENERATE_CONFIG)
         generate_config.update({"n": num_generations})
-        self.generate_config = generate_config
-        self.sample_params = SamplingParams(**generate_config)
+        self.generate_config = SamplingParams(**generate_config)
         self.model_config = model_config
         self.tokenizer = tokenizer
         self.num_generations = num_generations
@@ -220,9 +219,8 @@ class VLLMInferenceBackend(BaseInferenceBackend):
         micro_batch_input_ids_no_padding = [
             micro_batch_input_ids[i][first_non_padding_token_idx[i] :] for i in range(micro_batch_size)
         ]
-        sample_params = kwargs.get("sample_params", self.sample_params)
         outputs = self.llm.generate(
-            prompt_token_ids=micro_batch_input_ids_no_padding, sampling_params=sample_params, use_tqdm=False
+            prompt_token_ids=micro_batch_input_ids_no_padding, sampling_params=self.generate_config, use_tqdm=False
         )
         out_tokens = []
         out_len = []
@@ -268,11 +266,11 @@ class VLLMInferenceBackend(BaseInferenceBackend):
             "response_idx": response_idx,
         }

-        data = {k: v.view(micro_batch_size, -1, v.size(-1)) for k, v in data.items()}
+        data = {k: v.view(micro_batch_size, self.num_generations, v.size(-1)) for k, v in data.items()}

         if "gt_answer" in kwargs:
             # repeat gt_answer for each prompt.
-            data["gt_answer"] = kwargs["gt_answer"].repeat_interleave(data["input_ids"].size(1), dim=1)
+            data["gt_answer"] = kwargs["gt_answer"].repeat_interleave(self.num_generations, dim=1)
         data = {k: v.to(get_current_device()) for k, v in data.items()}
         return data
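
The restored rollout packing above reshapes every field to (micro_batch_size, num_generations, seq_len) and repeats the single ground-truth answer of each prompt across the generation dimension so every sampled completion carries its own copy. A toy illustration of that broadcast (shapes are assumed for the example, not taken from the repo):

import torch

micro_batch_size, num_generations, ans_len = 2, 3, 4
# one ground-truth answer per prompt, shape (micro_batch_size, 1, ans_len)
gt_answer = torch.arange(micro_batch_size * ans_len).view(micro_batch_size, 1, ans_len)
# repeat along dim=1 so each of the num_generations samples sees the same answer
expanded = gt_answer.repeat_interleave(num_generations, dim=1)
print(expanded.shape)  # torch.Size([2, 3, 4]); row i is duplicated for all 3 generations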


@@ -34,7 +34,7 @@ def launch_distributed(
     inference_microbatch_size: int,
     train_batch_size: int,
     train_minibatch_size: int,
-    train_dataset_config: Dict[str, Any],
+    dataset_config: Dict[str, Any],
     dataloaders_config: Dict[str, Any],
     inference_model_config: Dict[str, Any],
     generate_config: Dict[str, Any],
@@ -50,9 +50,6 @@ def launch_distributed(
     project_name: Optional[str] = None,
     save_interval: int = 100,
     save_dir: str = "./model",
-    eval_dataset_config: Optional[Dict[str, Any]] = None,
-    eval_interval: int = 100,
-    eval_save_dir: Optional[str] = None,
 ):

     if core_algo not in ALGO_MAP:
@@ -63,9 +60,9 @@ def launch_distributed(
     train_dp_size = get_dp_size_fast(num_consumer_procs, plugin_config)
     assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0

-    dataset_path = train_dataset_config["path"]
+    dataset_path = dataset_config["path"]
     num_samples = get_jsonl_size_fast(dataset_path)
-    global_inference_batch_size = inference_batch_size * num_producers  # TODO: this doesn't support TP on producer
+    global_inference_batch_size = inference_batch_size * num_producers
     num_update_per_episode = num_samples // global_inference_batch_size
     num_recv_per_update = inference_batch_size // inference_microbatch_size
@@ -77,7 +74,7 @@ def launch_distributed(
             num_consumer_procs=num_consumer_procs,
             num_episodes=num_episodes,
             batch_size=inference_batch_size,
-            train_dataset_config=train_dataset_config,
+            dataset_config=dataset_config,
             dataloaders_config=dataloaders_config,
             model_config=inference_model_config,
             generate_config=generate_config,
@@ -86,10 +83,6 @@ def launch_distributed(
             backend=inference_backend,
             num_generations=num_generations,
             consumer_plugin_config=plugin_config,
-            eval_dataset_config=eval_dataset_config,
-            eval_interval=eval_interval * num_recv_per_update,
-            evaluation_function_type=grpo_config["reward_fn_type"],
-            eval_save_dir=eval_save_dir,
         )
         procs.append(producer)
     generate_config_consumer = copy.deepcopy(generate_config)
@@ -118,7 +111,6 @@ def launch_distributed(
             project_name=project_name,
             save_interval=save_interval,
             save_dir=save_dir,
-            eval_interval=eval_interval,
         )
         procs.append(consumer)
     ray.get([p.setup.remote() for p in procs])
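
The schedule arithmetic restored in launch_distributed above fixes how many consumer updates fit in one episode and how many microbatch receives happen per update, and the assert requires the global inference batch to divide evenly into training batches. A worked example with made-up sizes (numbers are illustrative only):

num_producers, inference_batch_size, inference_microbatch_size = 4, 8, 2
train_batch_size, train_dp_size = 16, 2
num_samples = 1000  # lines in the training jsonl

# global inference batch must be a multiple of the global training batch
assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0  # 32 % 32 == 0

global_inference_batch_size = inference_batch_size * num_producers       # 32 prompts per global step
num_update_per_episode = num_samples // global_inference_batch_size      # 1000 // 32 == 31 updates
num_recv_per_update = inference_batch_size // inference_microbatch_size  # 8 // 2 == 4 receives per update
print(num_update_per_episode, num_recv_per_update)  # 31 4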


@@ -1,13 +1,9 @@
-import copy
-import os
 from typing import Any, Dict, Optional

 import ray
 import ray.util.collective as cc
 import torch
-import tqdm
 from coati.dataset.loader import RawConversationDataset
-from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
 from torch.utils.data import DataLoader, DistributedSampler
 from transformers import AutoTokenizer
@@ -15,12 +11,7 @@ from colossalai.utils import get_current_device
 from .comm import ray_broadcast_tensor_dict
 from .inference_backend import BACKEND_MAP
-from .utils import pre_send, safe_write_jsonl
-
-try:
-    from vllm import SamplingParams
-except ImportError:
-    LLM = None
+from .utils import pre_send


 class BaseProducer:
@@ -31,7 +22,7 @@ class BaseProducer:
         num_consumer_procs: int,
         num_episodes: int,
         batch_size: int,
-        train_dataset_config: Dict[str, Any],
+        dataset_config: Dict[str, Any],
         dataloaders_config: Dict[str, Any],
         model_config: Dict[str, Any],
         generate_config: Dict[str, Any],
@@ -39,10 +30,6 @@ class BaseProducer:
         microbatch_size: int = 1,
         backend: str = "transformers",
         consumer_plugin_config: Dict[str, Any] = None,
-        eval_dataset_config=None,
-        eval_interval=-1,  # disable evaluation
-        evaluation_function_type="think_answer_tags",
-        eval_save_dir: str = "./eval",
     ):
         self.producer_idx = producer_idx
         self.num_producers = num_producers
@@ -53,17 +40,10 @@ class BaseProducer:
         assert batch_size % microbatch_size == 0
         self.num_microbatches = batch_size // microbatch_size

-        self.train_dataset_config = train_dataset_config
+        self.dataset_config = dataset_config
         self.model_config = model_config
         self.generate_config = generate_config
         self.tokenizer_config = tokenizer_config
-        self.consumer_plugin_config = consumer_plugin_config
-        self.eval_interval = eval_interval
-        self.eval_save_dir = eval_save_dir
-        self.consumer_global_step = 0
-        if os.path.exists(self.eval_save_dir):
-            raise ValueError(f"Eval save dir {self.eval_save_dir} already exists. Please delete it or change the name.")

         # init tokenizer
         if tokenizer_config is None:
@@ -75,13 +55,13 @@ class BaseProducer:
         self.tokenizer.padding_side = "left"

         # init dataloader
-        train_dataset_path = train_dataset_config.pop("path")
-        self.train_dataset = RawConversationDataset(self.tokenizer, train_dataset_path, **train_dataset_config)
-        self.train_dataloader = DataLoader(
-            self.train_dataset,
+        dataset_path = dataset_config.pop("path")
+        self.dataset = RawConversationDataset(self.tokenizer, dataset_path, **dataset_config)
+        self.dataloader = DataLoader(
+            self.dataset,
             batch_size=microbatch_size,
             sampler=DistributedSampler(
-                self.train_dataset,
+                self.dataset,
                 num_replicas=num_producers,
                 rank=producer_idx,
                 shuffle=True,
@@ -91,36 +71,6 @@ class BaseProducer:
                 num_workers=4,
                 drop_last=True,
             )
-        self.eval_dataset_config = eval_dataset_config
-        if self.eval_dataset_config is not None:
-            self.eval_dataloaders = {}
-            for eval_task_name in self.eval_dataset_config:
-                eval_dataset_path = eval_dataset_config[eval_task_name].pop("path")
-                eval_dataset = RawConversationDataset(
-                    self.tokenizer, eval_dataset_path, **eval_dataset_config[eval_task_name]
-                )
-                print(f"[P{self.producer_idx}] eval dataset {eval_task_name} size: {len(eval_dataset)}")
-                self.eval_dataloaders[eval_task_name] = DataLoader(
-                    eval_dataset,
-                    batch_size=microbatch_size,
-                    sampler=DistributedSampler(
-                        eval_dataset,
-                        num_replicas=num_producers,
-                        rank=producer_idx,
-                        shuffle=False,
-                        drop_last=False,
-                        seed=42,
-                    ),
-                )
-            if evaluation_function_type == "think_answer_tags":
-                self.evaluation_function = math_reward_fn
-            elif evaluation_function_type == "boxed":
-                self.evaluation_function = boxed_math_reward_fn
-            else:
-                raise ValueError(f"Unknown evaluation function type {evaluation_function_type}")
-        else:
-            raise ValueError("eval_dataset_config is not defined")

         self.device = get_current_device()

         # init backend
@@ -129,7 +79,7 @@ class BaseProducer:
         else:
             raise ValueError(f"Unexpected backend {backend}")

-        self.consumer_pp_size = consumer_plugin_config.get("pp_size", 1)  # consumer pp size
+        self.consumer_pp_size = consumer_plugin_config["pp_size"]  # consumer pp size

     def setup(self) -> None:
         cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_data_{self.producer_idx}")
@@ -146,67 +96,29 @@ class BaseProducer:
         raise NotImplementedError

     def loop(self) -> None:
-        num_update_per_episode = len(self.train_dataloader) // self.num_microbatches
+        num_update_per_episode = len(self.dataloader) // self.num_microbatches
         num_valid_microbatches = num_update_per_episode * self.num_microbatches

         print(
-            f"[P{self.producer_idx}] num_valid_microbatches {num_valid_microbatches}, nmb: {self.num_microbatches}, dl: {len(self.train_dataloader)}"
+            f"[P{self.producer_idx}] num_valid_microbatches {num_valid_microbatches}, nmb: {self.num_microbatches}, dl: {len(self.dataloader)}"
         )
         for episode in range(self.num_episodes):
-            self.train_dataloader.sampler.set_epoch(episode)
-            for i, batch in enumerate(self.train_dataloader):
+            self.dataloader.sampler.set_epoch(episode)
+            for i, batch in enumerate(self.dataloader):
                 if i >= num_valid_microbatches:
                     break
-                if self.eval_interval > 0 and self.eval_dataset_config is not None:
-                    if i % self.eval_interval == 0:
-                        eval_statistics = {}
-                        for eval_task_name in self.eval_dataloaders:
-                            print(
-                                f"[P{self.producer_idx}] Evaluate episode {episode} step {i} on task {eval_task_name}"
-                            )
-                            eval_results = []
-                            eval_statistics[eval_task_name] = torch.zeros(2, device=self.device)
-                            for eval_batch in tqdm.tqdm(
-                                self.eval_dataloaders[eval_task_name], disable=self.producer_idx != 0
-                            ):
-                                eval_outputs = self.rollout(**eval_batch, sample_params=self.eval_sample_params)
-                                eval_results = eval_results + [
-                                    self.evaluation_function(
-                                        eval_outputs["input_ids"][m][n],
-                                        eval_outputs["gt_answer"][m][n],
-                                        eval_outputs["response_idx"][m][n],
-                                        tokenizer=self.tokenizer,
-                                        eval_mode=True,
-                                    )
-                                    for m in range(eval_outputs["input_ids"].size(0))
-                                    for n in range(eval_outputs["input_ids"].size(1))
-                                ]
-                                eval_statistics[eval_task_name][0] += len(
-                                    [res for res in eval_results if res["ans_valid"] == 1]
-                                )
-                                eval_statistics[eval_task_name][1] += len(eval_results)
-                            # save eval results
-                            result_file_name = os.path.join(
-                                self.eval_save_dir,
-                                f"{eval_task_name}_episode_{episode}_step_{self.consumer_global_step}.jsonl",
-                            )
-                            # delete the file if it exists
-                            safe_write_jsonl(result_file_name, eval_results)
-                        print(f"[P{self.producer_idx}] Send eval statistics episode {episode} step {i}")
-                        eval_statistics["consumer_global_step"] = torch.tensor(
-                            [self.consumer_global_step], device=self.device
-                        )
-                        ray_broadcast_tensor_dict(
-                            eval_statistics,
-                            src=0,
-                            device=self.device,
-                            group_name=f"sync_data_{self.producer_idx}",
-                        )
                 outputs = self.rollout(**batch)

                 print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
                 outputs["temperature"] = torch.tensor(
-                    [self.model.generate_config["temperature"]] * outputs["input_ids"].size(0)
+                    [
+                        (
+                            self.model.generate_config["temperature"]
+                            if isinstance(self.model.generate_config.temperature, dict)
+                            else self.model.generate_config.temperature
+                        )
+                    ]
+                    * outputs["input_ids"].size(0)
                 ).to(outputs["input_ids"].device)
                 outputs = pre_send(outputs)
                 ray_broadcast_tensor_dict(
@@ -238,8 +150,6 @@ class BaseProducer:
                     state_dict = ray_broadcast_tensor_dict(
                         None, self.num_producers, device=self.device, group_name="sync_model"
                     )
-                    if "consumer_global_step" in state_dict:
-                        self.consumer_global_step = state_dict.pop("consumer_global_step").item()
                     self.load_state_dict(state_dict)
                     del state_dict
                     torch.cuda.empty_cache()
@@ -249,12 +159,15 @@ class BaseProducer:
                     self.model.llm.wake_up()
                 # linear annealing for 1 episode, temperature from initial to 0.9
                 if episode <= 0:
-                    ratio = 1 - (len(self.train_dataloader) - i) / len(self.train_dataloader)
-                    self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[
-                        "temperature"
-                    ] + ratio * 0.9
-                    if hasattr(self.model, "sample_params"):
-                        self.model.sample_params.temperature = self.model.generate_config["temperature"]
+                    ratio = 1 - (len(self.dataloader) - i) / len(self.dataloader)
+                    if isinstance(self.model.generate_config.temperature, dict):
+                        self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[
+                            "temperature"
+                        ] + ratio * 0.9
+                    else:
+                        self.model.generate_config.temperature = (1 - ratio) * self.generate_config[
+                            "temperature"
+                        ] + ratio * 0.9


 @ray.remote
@@ -266,7 +179,7 @@ class SimpleProducer(BaseProducer):
         num_consumer_procs,
         num_episodes,
         batch_size,
-        train_dataset_config,
+        dataset_config,
         dataloaders_config,
         model_config,
         generate_config,
@@ -275,10 +188,6 @@ class SimpleProducer(BaseProducer):
         backend="transformers",
         num_generations: int = 8,
         consumer_plugin_config=None,
-        eval_dataset_config=None,
-        eval_interval=-1,  # disable evaluation
-        evaluation_function_type="think_answer_tags",
-        eval_save_dir: str = "./eval",
     ):
         super().__init__(
             producer_idx,
@@ -286,7 +195,7 @@ class SimpleProducer(BaseProducer):
             num_consumer_procs,
             num_episodes,
             batch_size,
-            train_dataset_config,
+            dataset_config,
             dataloaders_config,
             model_config,
             generate_config,
@@ -294,15 +203,8 @@ class SimpleProducer(BaseProducer):
             microbatch_size,
             backend,
             consumer_plugin_config,
-            eval_dataset_config=eval_dataset_config,
-            eval_interval=eval_interval,
-            evaluation_function_type=evaluation_function_type,
-            eval_save_dir=eval_save_dir,
         )
         self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
-        self.eval_generation_config = copy.deepcopy(self.model.generate_config)
-        self.eval_generation_config["n"] = 1  # use 1 generation for evaluation
-        self.eval_sample_params = SamplingParams(**self.eval_generation_config)

     @torch.no_grad()
     def rollout(self, input_ids, attention_mask, **kwargs):
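
The producer loop above (see the "linear annealing" comment) moves the sampling temperature linearly from its configured value down to 0.9 across the first episode, driven by the batch index. A small sketch of that schedule (function name and sample values are illustrative, not from the repo):

def annealed_temperature(i: int, num_batches: int, initial_temperature: float) -> float:
    # ratio runs from 0 at the first batch toward 1 at the end of the episode
    ratio = 1 - (num_batches - i) / num_batches
    return (1 - ratio) * initial_temperature + ratio * 0.9

print(annealed_temperature(0, 100, 1.2))   # 1.2 at the start of the episode
print(annealed_temperature(50, 100, 1.2))  # 1.05 halfway through
print(annealed_temperature(99, 100, 1.2))  # ~0.903 near the end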


@@ -1,74 +1,10 @@
 import torch
-from latex2sympy2_extended import NormalizationConfig
-from math_verify import ExprExtractionConfig, LatexExtractionConfig, parse, verify

 from .reward_utils import extract_boxed_solution, extract_solution, validate_response_structure

-CANNOT_PARSE_GT_ANSWER = -1
-CANNOT_PARSE_PREDICTION = -2
-SUCCESS = 1
-MATCHING_FAIL = 0
-
-
-def verify_math_representation(completion, gt_answer):
-    """
-    Verify if the completion is a valid math representation of the gt_answer.
-    """
-    target = (
-        ExprExtractionConfig(),
-        LatexExtractionConfig(
-            normalization_config=NormalizationConfig(
-                nits=False,
-                malformed_operators=False,
-                basic_latex=True,
-                boxed="all",
-                units=True,
-            ),
-            boxed_match_priority=0,
-        ),
-    )
-    if not isinstance(gt_answer, str) or len(gt_answer) == 0:
-        raise ValueError("gt_answer should be a string, please verify your training data.")
-    if not isinstance(completion, str) or len(completion) == 0:
-        return MATCHING_FAIL
-    try:
-        parsed_gt_answer = parse(gt_answer, extraction_config=target)
-        if len(parsed_gt_answer) == 0:
-            return CANNOT_PARSE_GT_ANSWER
-        parsed_completion = parse(completion, extraction_config=target)
-        if len(parsed_completion) == 0:
-            return CANNOT_PARSE_PREDICTION
-        if verify(parsed_gt_answer, parsed_completion):
-            return SUCCESS
-        else:
-            return MATCHING_FAIL
-    except Exception:
-        return MATCHING_FAIL
-
-
-def verify_model_answer(decoded_final_answer, gt_answer, ans_acc, acc_score, reward):
-    math_verify_result = verify_math_representation(decoded_final_answer, gt_answer)
-    if math_verify_result == SUCCESS:
-        ans_acc += 1
-        reward += acc_score
-    elif math_verify_result == CANNOT_PARSE_GT_ANSWER or math_verify_result == CANNOT_PARSE_PREDICTION:
-        if decoded_final_answer.strip().replace(" ", "").replace("{", "").replace("}", "").replace(
-            ",", ""
-        ) == gt_answer.strip().replace(" ", "").replace("{", "").replace("}", "").replace(",", ""):
-            ans_acc += 1
-            if math_verify_result == CANNOT_PARSE_GT_ANSWER:
-                # plain text answer cannot be parsed, but is correct
-                reward += acc_score
-            else:
-                reward += (
-                    acc_score / 2
-                )  # not a valid latex math representation, but the answer is correct, receive half of the score
-    return reward, ans_acc
-
-
 def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     tokenizer = kwargs["tokenizer"]
-    eval_mode = kwargs.get("eval_mode", False)
     soft_over_length_punishment = kwargs.get("soft_over_length_punishment", False)
     acc_score = 10.0
     reward = torch.tensor(0.0)
@@ -98,28 +34,46 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
         format_acc += 1
     # Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
-    if format_valid and final_answer is not None:
-        reward, ans_acc = verify_model_answer(decoded_final_answer, gt_answer, ans_acc, acc_score, reward)
+    if (
+        format_valid
+        and final_answer is not None
+        and gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower()
+    ):
+        ans_acc += 1
+        reward += acc_score
     reward = reward + length_reward
-    if not eval_mode:
-        return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
-    else:
-        prompt = tokenizer.decode(input_ids[:s], skip_special_tokens=True)
-        return {
-            "prompt": prompt,
-            "prediction": decoded_final_answer,
-            "gold": gt_answer,
-            "parsed": final_answer,
-            "format_valid": format_acc.item(),
-            "ans_valid": ans_acc.item(),
-        }
+    return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
+
+
+def gsm8k_reward_fn(input_ids, **kwargs):
+    gt_answer = kwargs["gt_answer"]
+    tokenizer = kwargs["tokenizer"]
+    s, e = kwargs["response_start"], kwargs["response_end"]
+    reward = torch.tensor(0.0).to(input_ids.device)
+    if gt_answer is None:
+        return reward
+    decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
+    final_answer, processed_str = extract_solution(decoded_final_answer)
+    is_valid = True
+    try:
+        int(final_answer.strip())
+    except Exception:
+        is_valid = False
+    format_valid = validate_response_structure(processed_str, kwargs["tags"])
+    if not is_valid or not format_valid:
+        return reward
+    else:
+        reward += 1.0
+        if gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower():
+            reward = reward + 9.0
+        return reward


 def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     tokenizer = kwargs["tokenizer"]
-    eval_mode = kwargs.get("eval_mode", False)
     soft_over_length_punishment = kwargs.get("soft_over_length_punishment", False)
     format_score = 0.0
     acc_score = 10.0
@@ -137,7 +91,7 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
             length_reward = ((max_length - cache_length) - res_length) / cache_length * acc_score

     if gt_answer is None:
-        return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
+        return reward

     decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
     gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True)
@@ -149,19 +103,10 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
         reward += format_score

     # Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
-    if format_valid and final_answer is not None:
-        reward, ans_acc = verify_model_answer(decoded_final_answer, gt_answer, ans_acc, acc_score, reward)
+    if format_valid and final_answer is not None and gt_answer.strip().lower() == final_answer.strip().lower():
+        ans_acc += 1
+        reward += acc_score
     reward = reward + length_reward
-    if not eval_mode:
-        return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
-    else:
-        prompt = tokenizer.decode(input_ids[:s], skip_special_tokens=True)
-        return {
-            "prompt": prompt,
-            "prediction": decoded_final_answer,
-            "gold": gt_answer,
-            "parsed": final_answer,
-            "format_valid": format_acc.item(),
-            "ans_valid": ans_acc.item(),
-        }
+    return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
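
With math_verify removed, the restored reward functions above check the extracted answer against the ground truth with a plain normalized string comparison (whitespace-insensitive, case-insensitive) rather than symbolic verification. A standalone sketch of that check (helper name is illustrative):

def answers_match(final_answer: str, gt_answer: str) -> bool:
    # same normalization as the restored math_reward_fn: strip, drop spaces, lowercase
    normalize = lambda s: s.strip().replace(" ", "").lower()
    return normalize(final_answer) == normalize(gt_answer)

print(answers_match(" 3/4 ", "3/4"))         # True
print(answers_match("\\frac{3}{4}", "3/4"))  # False: no LaTeX-aware equivalence after this revert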


@@ -1,10 +1,7 @@
-import json
-import os
 from collections import defaultdict
 from typing import Any, Dict, List

 import torch
-from filelock import FileLock

 from colossalai.shardformer.layer.loss import dist_log_prob
@@ -155,13 +152,3 @@ def masked_sum(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
     """
     tensor = tensor * mask
     return tensor.sum(dim=dim)
-
-
-def safe_write_jsonl(file_path, data):
-    with FileLock(file_path + ".lock"):
-        # Ensure file exists
-        os.makedirs(os.path.dirname(file_path), exist_ok=True)
-        with open(file_path, "a", encoding="utf8") as f:
-            for entry in data:
-                json_line = json.dumps(entry, ensure_ascii=False)
-                f.write(json_line + "\n")


@@ -1,5 +1,4 @@
 import argparse
-import json
 import os

 import ray
@@ -9,16 +8,7 @@ from coati.distributed.launch import launch_distributed
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
-    parser.add_argument("-d", "--dataset", type=str, default="data_train.jsonl")
-    parser.add_argument(
-        "-ed",
-        "--eval-dataset",
-        type=str,
-        default='{"eval task name":"data_eval.jsonl"}',
-        help="Evaluation dataset for each task, please use json format to specify the dataset for each task. \
-            For example: {'task1':'data_eval_task1.jsonl', 'task2':'data_eval_task2.jsonl'}, the jsonl file should be in the same format as the training dataset. \
-            The key is the task name, and the value is the path to the jsonl file",
-    )
+    parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
     parser.add_argument("-p", "--project", type=str, default="GRPO", help="Project name.")
     parser.add_argument("-e", "--num-episodes", type=int, default=1, help="Number of episodes to train.")
@@ -104,14 +94,11 @@ if __name__ == "__main__":
         choices=["think_answer_tags", "boxed"],
         help="Reward type for GRPO.",
     )
-    parser.add_argument("-ei", "--eval-interval", type=int, default=100, help="Interval for evaluation.")

     # Logging/Checkpointing parameters
     parser.add_argument("-si", "--save-interval", type=int, default=100, help="Interval for saving checkpoints.")
     parser.add_argument("-sd", "--save-dir", type=str, default="./model", help="Directory for saving checkpoints.")
-    parser.add_argument(
-        "-esd", "--eval-save-dir", type=str, default="./eval", help="Directory for saving evaluation results."
-    )

     args = parser.parse_args()
     if args.train_minibatch_size is None:
@@ -221,7 +208,7 @@ if __name__ == "__main__":
         inference_microbatch_size=args.inference_microbatch_size,
         train_batch_size=args.train_batch_size,
         train_minibatch_size=args.train_minibatch_size,
-        train_dataset_config={
+        dataset_config={
            "path": args.dataset,
            "max_length": args.max_prompt_tokens,
            "system_prompt": args.system_prompt,
@@ -251,10 +238,4 @@ if __name__ == "__main__":
         project_name=args.project,
         save_interval=args.save_interval,
         save_dir=os.path.join(args.save_dir, args.project.replace(" ", "_")),
-        eval_dataset_config={
-            k: {"path": v, "max_length": args.max_prompt_tokens, "system_prompt": args.system_prompt}
-            for k, v in json.loads(args.eval_dataset).items()
-        },
-        eval_interval=args.eval_interval,
-        eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")),
     )