mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-21 01:24:04 +00:00
Support evaluation during training
This commit is contained in:
@@ -34,7 +34,7 @@ def launch_distributed(
|
||||
inference_microbatch_size: int,
|
||||
train_batch_size: int,
|
||||
train_minibatch_size: int,
|
||||
dataset_config: Dict[str, Any],
|
||||
train_dataset_config: Dict[str, Any],
|
||||
dataloaders_config: Dict[str, Any],
|
||||
inference_model_config: Dict[str, Any],
|
||||
generate_config: Dict[str, Any],
|
||||
@@ -50,6 +50,9 @@ def launch_distributed(
|
||||
project_name: Optional[str] = None,
|
||||
save_interval: int = 100,
|
||||
save_dir: str = "./model",
|
||||
eval_dataset_config: Optional[Dict[str, Any]] = None,
|
||||
eval_interval: int = 100,
|
||||
eval_save_dir: Optional[str] = None,
|
||||
):
|
||||
|
||||
if core_algo not in ALGO_MAP:
|
||||
@@ -60,9 +63,9 @@ def launch_distributed(
|
||||
train_dp_size = get_dp_size_fast(num_consumer_procs, plugin_config)
|
||||
assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0
|
||||
|
||||
dataset_path = dataset_config["path"]
|
||||
dataset_path = train_dataset_config["path"]
|
||||
num_samples = get_jsonl_size_fast(dataset_path)
|
||||
global_inference_batch_size = inference_batch_size * num_producers
|
||||
global_inference_batch_size = inference_batch_size * num_producers # TODO: this doesn't support TP on producer
|
||||
num_update_per_episode = num_samples // global_inference_batch_size
|
||||
num_recv_per_update = inference_batch_size // inference_microbatch_size
|
||||
|
||||
@@ -74,7 +77,7 @@ def launch_distributed(
|
||||
num_consumer_procs=num_consumer_procs,
|
||||
num_episodes=num_episodes,
|
||||
batch_size=inference_batch_size,
|
||||
dataset_config=dataset_config,
|
||||
train_dataset_config=train_dataset_config,
|
||||
dataloaders_config=dataloaders_config,
|
||||
model_config=inference_model_config,
|
||||
generate_config=generate_config,
|
||||
@@ -83,6 +86,10 @@ def launch_distributed(
|
||||
backend=inference_backend,
|
||||
num_generations=num_generations,
|
||||
consumer_plugin_config=plugin_config,
|
||||
eval_dataset_config=eval_dataset_config,
|
||||
eval_interval=eval_interval,
|
||||
evaluation_function_type=grpo_config["reward_fn_type"],
|
||||
eval_save_dir=eval_save_dir,
|
||||
)
|
||||
procs.append(producer)
|
||||
generate_config_consumer = copy.deepcopy(generate_config)
|
||||
@@ -111,6 +118,7 @@ def launch_distributed(
|
||||
project_name=project_name,
|
||||
save_interval=save_interval,
|
||||
save_dir=save_dir,
|
||||
eval_interval=eval_interval,
|
||||
)
|
||||
procs.append(consumer)
|
||||
ray.get([p.setup.remote() for p in procs])
|
||||
|
Reference in New Issue
Block a user