mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-02 09:38:05 +00:00
[shardformer] add util functions for shardformer tests/fix sync_shared_param (#4366)
* add util functions for shardformer tests & rewrite gpt2 test * fix shared_params & embedding/merging * fix precision
This commit is contained in:
committed by
Hongxin Liu
parent
5c6f183192
commit
b1feeced8e
@@ -37,7 +37,8 @@ class HybridParallelModule(ModelWrapper):
|
||||
self.shared_param_process_groups = []
|
||||
for shared_param in self.shared_params:
|
||||
if len(shared_param) > 0:
|
||||
self.stage_manager.init_process_group_by_stages(list(shared_param.keys()))
|
||||
self.shared_param_process_groups.append(
|
||||
self.stage_manager.init_process_group_by_stages(list(shared_param.keys())))
|
||||
if precision == 'fp16':
|
||||
module = module.half().cuda()
|
||||
elif precision == 'bf16':
|
||||
|
Reference in New Issue
Block a user