[autoparallel] shard param and buffer as expected (#1753)

* [autoparallel] shard param and buffer as expected

* fix unit test issue
This commit is contained in:
YuliangLiu0306
2022-10-21 15:45:13 +08:00
committed by GitHub
parent cdb7d5e7d2
commit 980ed21723
6 changed files with 129 additions and 106 deletions

View File

@@ -511,13 +511,13 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
'''
_, comm_action_sequence, _ = self.shape_consistency(tensor_with_sharding_spec.sharding_spec, target_spec)
for comm_spec in comm_action_sequence:
comm_spec.covert_spec_to_action(tensor_with_sharding_spec)
tensor_with_sharding_spec = comm_spec.covert_spec_to_action(tensor_with_sharding_spec)
tensor_with_sharding_spec.sharding_spec = target_spec
return tensor_with_sharding_spec
def apply_for_autoparallel_runtime(self, tensor, source_spec, target_spec):
_, comm_action_sequence, _ = self.shape_consistency(source_spec, target_spec)
for comm_spec in comm_action_sequence:
comm_spec.covert_spec_to_action(tensor)
tensor = comm_spec.covert_spec_to_action(tensor)
tensor.sharding_spec = target_spec
return tensor