mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 02:51:59 +00:00
[autoparallel] runtime_backward_apply (#1720)
This commit is contained in:
@@ -498,3 +498,11 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
|
||||
for comm_spec in comm_action_sequence:
|
||||
comm_spec.covert_spec_to_action(tensor_with_sharding_spec)
|
||||
tensor_with_sharding_spec.sharding_spec = target_spec
|
||||
return tensor_with_sharding_spec
|
||||
|
||||
def apply_for_autoparallel_runtime(self, tensor, source_spec, target_spec):
|
||||
_, comm_action_sequence, _ = self.shape_consistency(source_spec, target_spec)
|
||||
for comm_spec in comm_action_sequence:
|
||||
comm_spec.covert_spec_to_action(tensor)
|
||||
tensor.sharding_spec = target_spec
|
||||
return tensor
|
Reference in New Issue
Block a user