[fix] revert reward update and evaluation (#6295)

* Revert "rewrite reward fn"

This reverts commit d06042b434.

* Revert "upgrade reward math verification"

This reverts commit a6085ff676.

* Revert "fix bug"

This reverts commit 01640ebd65.

* Revert "reuse comm-group"

This reverts commit bd61918dcf.

* Revert "Support evaluation during training"

This reverts commit 57a88395fe.
This commit is contained in:
YeAnbang
2025-05-07 10:56:47 +08:00
committed by GitHub
parent 17928ad84f
commit eb6b5dd62e
9 changed files with 82 additions and 307 deletions

View File

@@ -34,7 +34,7 @@ def launch_distributed(
inference_microbatch_size: int,
train_batch_size: int,
train_minibatch_size: int,
train_dataset_config: Dict[str, Any],
dataset_config: Dict[str, Any],
dataloaders_config: Dict[str, Any],
inference_model_config: Dict[str, Any],
generate_config: Dict[str, Any],
@@ -50,9 +50,6 @@ def launch_distributed(
project_name: Optional[str] = None,
save_interval: int = 100,
save_dir: str = "./model",
eval_dataset_config: Optional[Dict[str, Any]] = None,
eval_interval: int = 100,
eval_save_dir: Optional[str] = None,
):
if core_algo not in ALGO_MAP:
@@ -63,9 +60,9 @@ def launch_distributed(
train_dp_size = get_dp_size_fast(num_consumer_procs, plugin_config)
assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0
dataset_path = train_dataset_config["path"]
dataset_path = dataset_config["path"]
num_samples = get_jsonl_size_fast(dataset_path)
global_inference_batch_size = inference_batch_size * num_producers # TODO: this doesn't support TP on producer
global_inference_batch_size = inference_batch_size * num_producers
num_update_per_episode = num_samples // global_inference_batch_size
num_recv_per_update = inference_batch_size // inference_microbatch_size
@@ -77,7 +74,7 @@ def launch_distributed(
num_consumer_procs=num_consumer_procs,
num_episodes=num_episodes,
batch_size=inference_batch_size,
train_dataset_config=train_dataset_config,
dataset_config=dataset_config,
dataloaders_config=dataloaders_config,
model_config=inference_model_config,
generate_config=generate_config,
@@ -86,10 +83,6 @@ def launch_distributed(
backend=inference_backend,
num_generations=num_generations,
consumer_plugin_config=plugin_config,
eval_dataset_config=eval_dataset_config,
eval_interval=eval_interval * num_recv_per_update,
evaluation_function_type=grpo_config["reward_fn_type"],
eval_save_dir=eval_save_dir,
)
procs.append(producer)
generate_config_consumer = copy.deepcopy(generate_config)
@@ -118,7 +111,6 @@ def launch_distributed(
project_name=project_name,
save_interval=save_interval,
save_dir=save_dir,
eval_interval=eval_interval,
)
procs.append(consumer)
ray.get([p.setup.remote() for p in procs])