Support evaluation during training

Author: YeAnbang
Date:   2025-04-30 18:13:40 +08:00
parent 5fd4bcb9d8
commit 57a88395fe
9 changed files with 234 additions and 65 deletions

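This commit renames the dataset_config argument of launch_distributed to train_dataset_config and threads evaluation settings through the launcher: each producer receives eval_dataset_config, eval_interval, eval_save_dir, and an evaluation_function_type taken from grpo_config["reward_fn_type"], while each consumer receives eval_interval. eval_dataset_config defaults to None, which presumably leaves evaluation disabled.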

@@ -34,7 +34,7 @@ def launch_distributed(
     inference_microbatch_size: int,
     train_batch_size: int,
     train_minibatch_size: int,
-    dataset_config: Dict[str, Any],
+    train_dataset_config: Dict[str, Any],
     dataloaders_config: Dict[str, Any],
     inference_model_config: Dict[str, Any],
     generate_config: Dict[str, Any],
@@ -50,6 +50,9 @@ def launch_distributed(
     project_name: Optional[str] = None,
     save_interval: int = 100,
     save_dir: str = "./model",
+    eval_dataset_config: Optional[Dict[str, Any]] = None,
+    eval_interval: int = 100,
+    eval_save_dir: Optional[str] = None,
 ):
     if core_algo not in ALGO_MAP:
@@ -60,9 +63,9 @@ def launch_distributed(
     train_dp_size = get_dp_size_fast(num_consumer_procs, plugin_config)
     assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0
-    dataset_path = dataset_config["path"]
+    dataset_path = train_dataset_config["path"]
     num_samples = get_jsonl_size_fast(dataset_path)
-    global_inference_batch_size = inference_batch_size * num_producers
+    global_inference_batch_size = inference_batch_size * num_producers  # TODO: this doesn't support TP on producer
     num_update_per_episode = num_samples // global_inference_batch_size
     num_recv_per_update = inference_batch_size // inference_microbatch_size
@@ -74,7 +77,7 @@ def launch_distributed(
             num_consumer_procs=num_consumer_procs,
             num_episodes=num_episodes,
             batch_size=inference_batch_size,
-            dataset_config=dataset_config,
+            train_dataset_config=train_dataset_config,
             dataloaders_config=dataloaders_config,
             model_config=inference_model_config,
             generate_config=generate_config,
@@ -83,6 +86,10 @@ def launch_distributed(
             backend=inference_backend,
             num_generations=num_generations,
             consumer_plugin_config=plugin_config,
+            eval_dataset_config=eval_dataset_config,
+            eval_interval=eval_interval,
+            evaluation_function_type=grpo_config["reward_fn_type"],
+            eval_save_dir=eval_save_dir,
         )
         procs.append(producer)
     generate_config_consumer = copy.deepcopy(generate_config)
@@ -111,6 +118,7 @@ def launch_distributed(
             project_name=project_name,
             save_interval=save_interval,
             save_dir=save_dir,
+            eval_interval=eval_interval,
         )
         procs.append(consumer)
     ray.get([p.setup.remote() for p in procs])
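For illustration, here is a minimal sketch of how a training script might pass the new arguments after this change. Only the parameters visible in this diff are shown; the remaining launch_distributed arguments are elided, and every concrete value below is an invented example, not taken from the repository:

launch_distributed(
    # ... all pre-existing arguments unchanged ...
    train_dataset_config={"path": "data/train.jsonl"},  # renamed from dataset_config
    eval_dataset_config={"path": "data/eval.jsonl"},    # default None presumably skips evaluation
    eval_interval=100,                                  # evaluate every 100 steps; also forwarded to consumers
    eval_save_dir="./eval_results",                     # where evaluation outputs are written
)

Note that callers do not pass evaluation_function_type themselves; the launcher derives it from grpo_config["reward_fn_type"] when constructing each producer.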