Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-02 09:38:05 +00:00
[autoparallel] add split handler (#2032)
* [autoparallel] add split handler
* add numerical test and runtime passes
@@ -100,8 +100,24 @@ def _node_args_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
         # skip the placeholder node added in _solution_annotation pass
         if not hasattr(node, 'sharding_spec'):
             continue
-        output_dim_partition_dict = node.sharding_spec.dim_partition_dict
-        device_mesh = node.sharding_spec.device_mesh
+
+        def _process_sharding_spec(sharding_spec):
+            if isinstance(sharding_spec, ShardingSpec):
+                dim_partition_dict = sharding_spec.dim_partition_dict
+                device_mesh = sharding_spec.device_mesh
+                return dim_partition_dict, device_mesh
+            if sharding_spec is None:
+                return None, None
+            assert isinstance(sharding_spec,
+                              (tuple, list)), 'sharding_spec should be type of ShardingSpec, tuple, list or None'
+
+            device_mesh = sharding_spec[0].device_mesh
+            dim_partition_dict = []
+            for element in sharding_spec:
+                dim_partition_dict.append(_process_sharding_spec(element))
+            return dim_partition_dict, sharding_spec
+
+        output_dim_partition_dict, device_mesh = _process_sharding_spec(node.sharding_spec)
         new_args = []
 
         if node.op == 'call_method':
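For readers skimming the hunk: the new _process_sharding_spec helper normalizes whatever is attached to node.sharding_spec — a single ShardingSpec, None, or a tuple/list of specs (e.g. the multiple outputs of a split node) — into a (dim_partition_dict, device_mesh) pair, recursing element-wise over sequences. The snippet below is a minimal, self-contained sketch of that recursion for illustration only; ToyShardingSpec and ToyDeviceMesh are hypothetical stand-ins, not ColossalAI's real ShardingSpec/DeviceMesh classes, and the sequence branch here returns the first element's device mesh as a simplification.

from dataclasses import dataclass, field


@dataclass
class ToyDeviceMesh:
    # hypothetical stand-in for ColossalAI's DeviceMesh
    shape: tuple = (2, 2)


@dataclass
class ToyShardingSpec:
    # hypothetical stand-in for ColossalAI's ShardingSpec
    device_mesh: ToyDeviceMesh
    dim_partition_dict: dict = field(default_factory=dict)


def process_sharding_spec(sharding_spec):
    # single spec -> its partition dict and mesh
    if isinstance(sharding_spec, ToyShardingSpec):
        return sharding_spec.dim_partition_dict, sharding_spec.device_mesh
    # no spec attached -> nothing to convert
    if sharding_spec is None:
        return None, None
    # sequence of specs (e.g. a multi-output split) -> recurse per element
    assert isinstance(sharding_spec, (tuple, list)), \
        'sharding_spec should be ShardingSpec, tuple, list or None'
    per_element = [process_sharding_spec(element) for element in sharding_spec]
    return per_element, sharding_spec[0].device_mesh


if __name__ == '__main__':
    mesh = ToyDeviceMesh()
    single = ToyShardingSpec(mesh, {0: [0]})
    outputs = [ToyShardingSpec(mesh, {0: [0]}), ToyShardingSpec(mesh, {1: [1]})]
    print(process_sharding_spec(single))    # ({0: [0]}, ToyDeviceMesh(shape=(2, 2)))
    print(process_sharding_spec(None))      # (None, None)
    print(process_sharding_spec(outputs))   # ([(dict, mesh), (dict, mesh)], mesh)

The three branches mirror the shape of the helper added in this commit: a flat spec is unpacked directly, a missing spec is passed through, and a sequence (such as the per-chunk specs produced by the new split handler) is handled by recursing over its elements.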