mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 04:50:17 +00:00
[autoparallel] shard param and buffer as expected (#1753)
* [autoparallel] shard param and buffer as expected * fix unit test issue
This commit is contained in:
@@ -511,13 +511,13 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
|
||||
'''
|
||||
_, comm_action_sequence, _ = self.shape_consistency(tensor_with_sharding_spec.sharding_spec, target_spec)
|
||||
for comm_spec in comm_action_sequence:
|
||||
comm_spec.covert_spec_to_action(tensor_with_sharding_spec)
|
||||
tensor_with_sharding_spec = comm_spec.covert_spec_to_action(tensor_with_sharding_spec)
|
||||
tensor_with_sharding_spec.sharding_spec = target_spec
|
||||
return tensor_with_sharding_spec
|
||||
|
||||
def apply_for_autoparallel_runtime(self, tensor, source_spec, target_spec):
|
||||
_, comm_action_sequence, _ = self.shape_consistency(source_spec, target_spec)
|
||||
for comm_spec in comm_action_sequence:
|
||||
comm_spec.covert_spec_to_action(tensor)
|
||||
tensor = comm_spec.covert_spec_to_action(tensor)
|
||||
tensor.sharding_spec = target_spec
|
||||
return tensor
|
||||
|
Reference in New Issue
Block a user