Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-20)
[feat] Sync shard model (#6289)
* [feat] support hybrid parallel model sync
* update consumer and producer
* update files
* update producer
* remove print
* update

Co-authored-by: duanjunwen <935724073@qq.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li35271158@gmail.com>
@@ -57,7 +57,7 @@ def launch_distributed(
     else:
         core_consumer = ALGO_MAP.get(core_algo, SimpleConsumer)
 
-    train_dp_size = get_dp_size_fast(num_producers, plugin_config)
+    train_dp_size = get_dp_size_fast(num_consumer_procs, plugin_config)
     assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0
 
     dataset_path = dataset_config["path"]
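The fix in this hunk derives the trainer's data-parallel size from the number of consumer (trainer) processes rather than the number of producers: with hybrid parallelism, the DP degree depends on the trainer's world size and its model-parallel configuration. A minimal sketch of that computation and of the divisibility check in the assertion, using a simplified helper and illustrative sizes of my own choosing (the real get_dp_size_fast in ColossalAI may account for more parallelism axes):

```python
# Hypothetical sketch: DP size = trainer processes / model-parallel degrees.
# The actual get_dp_size_fast in ColossalAI may differ.
def get_dp_size_fast(num_procs: int, plugin_config: dict) -> int:
    tp_size = plugin_config.get("tp_size", 1)
    pp_size = plugin_config.get("pp_size", 1)
    return num_procs // (tp_size * pp_size)

# Illustrative sizes (not from the commit): 4 trainer procs with tp=2 -> dp=2.
num_consumer_procs, num_producers = 4, 4
train_dp_size = get_dp_size_fast(num_consumer_procs, {"tp_size": 2})
inference_batch_size, train_batch_size = 16, 8

# Rollouts generated per step (16 * 4 = 64) must split evenly into training
# batches (8 * 2 = 16), which is what the assertion in the hunk enforces.
assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0
```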
@@ -82,6 +82,7 @@ def launch_distributed(
             microbatch_size=inference_microbatch_size,
             backend=inference_backend,
             num_generations=num_generations,
+            consumer_plugin_config=plugin_config,
         )
         procs.append(producer)
     generate_config_consumer = copy.deepcopy(generate_config)
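The second hunk forwards the consumer's plugin config into each producer. A plausible reading, consistent with the commit title: when the trainer runs hybrid parallelism, each consumer rank holds only a shard of the model, so a producer receiving synced weights needs the consumer's layout to know which rank owns which shard. A toy sketch under that assumption (all names here are hypothetical, not ColossalAI API):

```python
# Toy sketch: map each consumer rank to (tp, pp, dp) coordinates derived
# from the plugin config, with tp varying fastest. Hypothetical helper,
# shown only to illustrate why the producer needs consumer_plugin_config.
def consumer_rank_layout(num_consumer_procs: int, consumer_plugin_config: dict):
    tp = consumer_plugin_config.get("tp_size", 1)
    pp = consumer_plugin_config.get("pp_size", 1)
    layout = {}
    for rank in range(num_consumer_procs):
        layout[rank] = {
            "tp_rank": rank % tp,
            "pp_rank": (rank // tp) % pp,
            "dp_rank": rank // (tp * pp),
        }
    return layout

print(consumer_rank_layout(4, {"tp_size": 2, "pp_size": 1}))
# rank 0 -> tp 0 / dp 0, rank 1 -> tp 1 / dp 0, rank 2 -> tp 0 / dp 1, ...
```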