[autoparallel] add split handler (#2032)

* [autoparallel] add split handler

* add numerical test and runtime passes
YuliangLiu0306
2022-11-29 11:03:51 +08:00
committed by GitHub
parent 28aa9a4294
commit 0dbcd4a6f5
9 changed files with 500 additions and 22 deletions
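
For context on why the new helper in the hunk below must accept a tuple or list of specs: torch.split produces multiple output tensors from a single fx node, so the strategy annotated on a split node is a sequence of sharding specs, one per chunk. A minimal sketch (plain torch.fx; no ColossalAI types assumed) of what a traced split looks like:

    import torch
    from torch.fx import symbolic_trace

    class SplitModel(torch.nn.Module):

        def forward(self, x):
            # torch.split returns a tuple of tensors, so the traced graph
            # has one split node followed by getitem nodes, one per chunk
            chunks = torch.split(x, 2, dim=0)
            return chunks[0] + chunks[1]

    gm = symbolic_trace(SplitModel())
    print(gm.graph)
    # x -> torch.split -> getitem(0), getitem(1) -> add

A sharding spec attached to the split node itself therefore has to be a tuple or list with one element per output, which is the case the diff below adds handling for.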

@@ -100,8 +100,24 @@ def _node_args_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
         # skip the placeholder node added in _solution_annotation pass
         if not hasattr(node, 'sharding_spec'):
             continue
-        output_dim_partition_dict = node.sharding_spec.dim_partition_dict
-        device_mesh = node.sharding_spec.device_mesh
+
+        def _process_sharding_spec(sharding_spec):
+            if isinstance(sharding_spec, ShardingSpec):
+                dim_partition_dict = sharding_spec.dim_partition_dict
+                device_mesh = sharding_spec.device_mesh
+                return dim_partition_dict, device_mesh
+            if sharding_spec is None:
+                return None, None
+            assert isinstance(sharding_spec,
+                              (tuple, list)), 'sharding_spec should be type of ShardingSpec, tuple, list or None'
+            # multi-output ops such as split carry one spec per output;
+            # all outputs share a single device mesh
+            device_mesh = sharding_spec[0].device_mesh
+            dim_partition_dict = []
+            for element in sharding_spec:
+                dim_partition_dict.append(_process_sharding_spec(element)[0])
+            return dim_partition_dict, device_mesh
+
+        output_dim_partition_dict, device_mesh = _process_sharding_spec(node.sharding_spec)
         new_args = []
         if node.op == 'call_method':
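
To make the recursion concrete, here is a self-contained sketch of the helper's behavior, with a stub ShardingSpec standing in for colossalai.tensor.sharding_spec.ShardingSpec (the mesh string and partition values are illustrative only, not the real types):

    from dataclasses import dataclass, field

    @dataclass
    class ShardingSpec:
        # stub: the real class wraps a DeviceMesh plus tensor metadata
        device_mesh: str = 'mesh_2x2'
        dim_partition_dict: dict = field(default_factory=dict)

    def _process_sharding_spec(sharding_spec):
        # normalize ShardingSpec | tuple/list of specs | None
        # into (dim_partition_dict, device_mesh)
        if isinstance(sharding_spec, ShardingSpec):
            return sharding_spec.dim_partition_dict, sharding_spec.device_mesh
        if sharding_spec is None:
            return None, None
        assert isinstance(sharding_spec, (tuple, list)), \
            'sharding_spec should be type of ShardingSpec, tuple, list or None'
        device_mesh = sharding_spec[0].device_mesh
        dim_partition_dict = [_process_sharding_spec(e)[0] for e in sharding_spec]
        return dim_partition_dict, device_mesh

    # a split node whose two output chunks are each sharded along
    # tensor dim 0 across mesh axis 0
    specs = [ShardingSpec(dim_partition_dict={0: [0]}),
             ShardingSpec(dim_partition_dict={0: [0]})]
    print(_process_sharding_spec(specs))
    # ([{0: [0]}, {0: [0]}], 'mesh_2x2')

The list branch yields one partition dict per output while keeping a single device mesh, matching what the call site in the hunk above unpacks.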