[shardformer] add util functions for shardformer tests/fix sync_shared_param (#4366)

* add util functions for shardformer tests & rewrite gpt2 test

* fix shared_params & embedding/merging

* fix precision
This commit is contained in:
Baizhou Zhang
2023-08-03 17:50:15 +08:00
committed by Hongxin Liu
parent 5c6f183192
commit b1feeced8e
4 changed files with 189 additions and 113 deletions

View File

@@ -37,7 +37,8 @@ class HybridParallelModule(ModelWrapper):
self.shared_param_process_groups = []
for shared_param in self.shared_params:
if len(shared_param) > 0:
self.stage_manager.init_process_group_by_stages(list(shared_param.keys()))
self.shared_param_process_groups.append(
self.stage_manager.init_process_group_by_stages(list(shared_param.keys())))
if precision == 'fp16':
module = module.half().cuda()
elif precision == 'bf16':