[autoparallel] runtime_backward_apply (#1720)

This commit is contained in:
YuliangLiu0306
2022-10-18 10:44:58 +08:00
committed by GitHub
parent 393f594051
commit 51b89d2202
2 changed files with 56 additions and 27 deletions

View File

@@ -498,3 +498,11 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
for comm_spec in comm_action_sequence:
comm_spec.covert_spec_to_action(tensor_with_sharding_spec)
tensor_with_sharding_spec.sharding_spec = target_spec
return tensor_with_sharding_spec
def apply_for_autoparallel_runtime(self, tensor, source_spec, target_spec):
_, comm_action_sequence, _ = self.shape_consistency(source_spec, target_spec)
for comm_spec in comm_action_sequence:
comm_spec.covert_spec_to_action(tensor)
tensor.sharding_spec = target_spec
return tensor