[feat] Sync shard model (#6289)

* [feat] support hybrid parallel model sync

* update consumer and producer

* update files

* update producer

* remove print

* update

---------

Co-authored-by: duanjunwen <935724073@qq.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li35271158@gmail.com>
Tong Li
2025-04-30 14:47:01 +08:00
committed by GitHub
parent 14f237ce7e
commit 5fd4bcb9d8
4 changed files with 66 additions and 20 deletions


@@ -57,7 +57,7 @@ def launch_distributed(
     else:
         core_consumer = ALGO_MAP.get(core_algo, SimpleConsumer)
 
-    train_dp_size = get_dp_size_fast(num_producers, plugin_config)
+    train_dp_size = get_dp_size_fast(num_consumer_procs, plugin_config)
     assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0
     dataset_path = dataset_config["path"]
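
For context on this one-line fix: the training data-parallel size must be derived from the number of consumer (trainer) processes, not the number of producers, or the batch-divisibility assert above checks the wrong quantity. Below is a hedged sketch of what a get_dp_size_fast-style helper computes, assuming the plugin config carries tensor- and pipeline-parallel degrees under "tp_size"/"pp_size" keys; the helper body and key names are illustrative assumptions, not the repository's actual implementation.

from typing import Any, Dict

def get_dp_size_fast_sketch(n_procs: int, plugin_config: Dict[str, Any]) -> int:
    # Data-parallel size = number of trainer processes divided by the
    # product of the other parallel degrees (assumed config keys).
    tp_size = plugin_config.get("tp_size", 1)
    pp_size = plugin_config.get("pp_size", 1)
    return n_procs // (tp_size * pp_size)

# With 8 consumer processes and tp=2, pp=2, dp = 2. Passing the producer
# count instead (the pre-fix behavior) would feed an unrelated number into
# the divisibility check.
assert get_dp_size_fast_sketch(8, {"tp_size": 2, "pp_size": 2}) == 2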
@@ -82,6 +82,7 @@ def launch_distributed(
             microbatch_size=inference_microbatch_size,
             backend=inference_backend,
             num_generations=num_generations,
+            consumer_plugin_config=plugin_config,
         )
         procs.append(producer)
     generate_config_consumer = copy.deepcopy(generate_config)
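
The new consumer_plugin_config argument hands each producer the consumer's parallel layout, which is what lets a producer make sense of weights synced from a hybrid-parallel (sharded) trainer. A minimal sketch of how a producer might use it follows; the class, method names, and shard-merging protocol are assumptions for illustration, not the code added by this PR.

from typing import Any, Dict, List, Optional

class ProducerSketch:
    """Illustrative producer that records the consumer's parallel layout."""

    def __init__(self, consumer_plugin_config: Optional[Dict[str, Any]] = None):
        cfg = consumer_plugin_config or {}
        # Assumed key: the consumer's pipeline degree tells the producer how
        # many state-dict fragments make up one complete set of weights.
        self.consumer_pp_size = cfg.get("pp_size", 1)

    def sync_model(self, shards: List[Dict[str, Any]]) -> Dict[str, Any]:
        # Merge one fragment per pipeline stage into a full state dict
        # before handing it to the inference engine.
        assert len(shards) == self.consumer_pp_size, "expected one shard per stage"
        full_state: Dict[str, Any] = {}
        for shard in shards:
            full_state.update(shard)
        return full_state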