Mirror of https://github.com/hpcaitech/ColossalAI.git
[fix] revert reward update and evaluation (#6295)
* Revert "rewrite reward fn". This reverts commit d06042b434.
* Revert "upgrade reward math verification". This reverts commit a6085ff676.
* Revert "fix bug". This reverts commit 01640ebd65.
* Revert "reuse comm-group". This reverts commit bd61918dcf.
* Revert "Support evaluation during training". This reverts commit 57a88395fe.
@@ -34,7 +34,7 @@ def launch_distributed(
     inference_microbatch_size: int,
     train_batch_size: int,
     train_minibatch_size: int,
-    train_dataset_config: Dict[str, Any],
+    dataset_config: Dict[str, Any],
     dataloaders_config: Dict[str, Any],
     inference_model_config: Dict[str, Any],
     generate_config: Dict[str, Any],
@@ -50,9 +50,6 @@ def launch_distributed(
     project_name: Optional[str] = None,
     save_interval: int = 100,
     save_dir: str = "./model",
-    eval_dataset_config: Optional[Dict[str, Any]] = None,
-    eval_interval: int = 100,
-    eval_save_dir: Optional[str] = None,
 ):
 
     if core_algo not in ALGO_MAP:
@@ -63,9 +60,9 @@ def launch_distributed(
     train_dp_size = get_dp_size_fast(num_consumer_procs, plugin_config)
     assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0
 
-    dataset_path = train_dataset_config["path"]
+    dataset_path = dataset_config["path"]
     num_samples = get_jsonl_size_fast(dataset_path)
-    global_inference_batch_size = inference_batch_size * num_producers  # TODO: this doesn't support TP on producer
+    global_inference_batch_size = inference_batch_size * num_producers
     num_update_per_episode = num_samples // global_inference_batch_size
     num_recv_per_update = inference_batch_size // inference_microbatch_size
 
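For reference, the batch bookkeeping kept by this hunk works out as follows. The sketch below is standalone Python with assumed example values; the concrete numbers are illustrations, not values taken from the commit:

# Minimal sketch of the batch bookkeeping in launch_distributed.
# All concrete numbers below are assumed for illustration only.
num_producers = 4
inference_batch_size = 32
inference_microbatch_size = 8
train_batch_size = 16
train_dp_size = 2          # stand-in for get_dp_size_fast(num_consumer_procs, plugin_config)
num_samples = 1024         # stand-in for get_jsonl_size_fast(dataset_path)

# Each globally generated batch must split evenly across the training data-parallel ranks.
assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0

global_inference_batch_size = inference_batch_size * num_producers       # 128 prompts per global step
num_update_per_episode = num_samples // global_inference_batch_size      # 8 updates per episode
num_recv_per_update = inference_batch_size // inference_microbatch_size  # 4 micro-batches received per update
print(num_update_per_episode, num_recv_per_update)                       # 8 4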
@@ -77,7 +74,7 @@ def launch_distributed(
             num_consumer_procs=num_consumer_procs,
             num_episodes=num_episodes,
             batch_size=inference_batch_size,
-            train_dataset_config=train_dataset_config,
+            dataset_config=dataset_config,
             dataloaders_config=dataloaders_config,
             model_config=inference_model_config,
             generate_config=generate_config,
@@ -86,10 +83,6 @@ def launch_distributed(
             backend=inference_backend,
             num_generations=num_generations,
             consumer_plugin_config=plugin_config,
-            eval_dataset_config=eval_dataset_config,
-            eval_interval=eval_interval * num_recv_per_update,
-            evaluation_function_type=grpo_config["reward_fn_type"],
-            eval_save_dir=eval_save_dir,
         )
         procs.append(producer)
     generate_config_consumer = copy.deepcopy(generate_config)
@@ -118,7 +111,6 @@ def launch_distributed(
             project_name=project_name,
             save_interval=save_interval,
             save_dir=save_dir,
-            eval_interval=eval_interval,
         )
         procs.append(consumer)
     ray.get([p.setup.remote() for p in procs])
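After the revert, launch_distributed again takes a single dataset_config mapping and no eval_* keyword arguments. The partial sketch below is hypothetical: the parameter names come from the diff context above, while every value is a placeholder, including the assumption that dataset_config only needs a "path" to a JSONL file (inferred from dataset_config["path"] and get_jsonl_size_fast):

# Hypothetical, partial keyword set for the reverted launch_distributed signature.
# Names are taken from the diff context; all values are placeholders.
launch_kwargs = dict(
    num_producers=4,
    num_consumer_procs=2,
    num_episodes=1,
    inference_batch_size=32,
    inference_microbatch_size=8,
    train_batch_size=16,
    train_minibatch_size=4,
    dataset_config={"path": "data/train.jsonl"},  # replaces train_dataset_config
    dataloaders_config={},
    inference_model_config={},
    generate_config={},
    project_name=None,
    save_interval=100,
    save_dir="./model",
    # eval_dataset_config / eval_interval / eval_save_dir no longer exist after this revert.
)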