[autoparallel] integrate_gpt_related_tests (#2134)

* [autoparallel] integrate_gpt_related_tests

* polish code

* polish code

* add GPT2Model into runtime test
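
A minimal, hypothetical sketch (not the repository's actual test) of the kind of small GPT2Model a runtime test would construct and run forward; the config sizes are arbitrary and only Hugging Face transformers APIs are used:

import torch
from transformers import GPT2Config, GPT2Model

# tiny config so the test stays cheap; values are illustrative only
config = GPT2Config(vocab_size=128, n_embd=64, n_layer=2, n_head=4)
model = GPT2Model(config)

input_ids = torch.randint(0, config.vocab_size, (2, 16))
attention_mask = torch.ones_like(input_ids)
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
assert outputs.last_hidden_state.shape == (2, 16, config.n_embd)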
YuliangLiu0306 committed 2022-12-23 12:36:59 +08:00 (committed by GitHub)
parent 59e343328d
commit 550f8f8905
5 changed files with 217 additions and 207 deletions


@@ -230,7 +230,12 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
            new_slice_items = []
            for slice_item in getitem_index:
                if slice_item is None:
                    new_slice_items.append(None)
                    continue
                new_start, new_stop, new_step = slice_item.start, slice_item.stop, slice_item.step
                if slice_item.start in node_pairs:
                    new_start = node_pairs[slice_item.start]
                elif slice_item.stop in node_pairs:
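
The None guard above exists because a getitem index such as x[:, None] mixes slice objects with None, which inserts a new axis and has no start/stop/step to convert. A minimal sketch of the conversion idea, using hypothetical names (convert_slice_items, and node_pairs as a plain dict standing in for the pass's node mapping); for brevity it remaps both bounds instead of the if/elif in the real hunk:

def convert_slice_items(getitem_index, node_pairs):
    # node_pairs maps old size values to their converted counterparts
    new_slice_items = []
    for slice_item in getitem_index:
        if slice_item is None:  # None adds a new axis; pass it through untouched
            new_slice_items.append(None)
            continue
        new_start, new_stop, new_step = slice_item.start, slice_item.stop, slice_item.step
        new_start = node_pairs.get(new_start, new_start)
        new_stop = node_pairs.get(new_stop, new_stop)
        new_slice_items.append(slice(new_start, new_stop, new_step))
    return tuple(new_slice_items)

# convert_slice_items((slice(0, 'old_len'), None), {'old_len': 'new_len'})
# -> (slice(0, 'new_len', None), None)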
@@ -355,7 +360,10 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
    for node in nodes:
        if node.op == 'call_module':
            target_module = node.graph.owning_module.get_submodule(node.target)
            # TODO: we need to do more actions to take care of the shared parameters.
            if hasattr(target_module, 'processed') and target_module.processed:
                continue
            setattr(target_module, 'processed', True)
            for name, param in target_module.named_parameters():
                target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
                # apply the sharding spec of parameters
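
The new processed flag makes each module's parameters get sharded only once, even when the same submodule object is reached from several call_module nodes (shared modules, as the TODO notes). The same dedup idea in isolation, using an id()-based set instead of setting an attribute on the module; all names here are illustrative:

import torch

def shard_each_module_once(modules, shard_fn):
    seen = set()  # stand-in for the 'processed' attribute used in the pass
    for module in modules:
        if id(module) in seen:
            continue
        seen.add(id(module))
        for name, param in module.named_parameters():
            shard_fn(module, name, param)

linear = torch.nn.Linear(4, 4)
shard_each_module_once([linear, linear],  # the same module referenced twice
                       lambda m, n, p: print('sharding', n))
# prints 'sharding weight' and 'sharding bias' exactly once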
@@ -404,7 +412,9 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
                target_module = root
                target = getattr(root, atoms[0])
            else:
-               target_module = root.get_submodule(atoms[-2])
+               target_module = root
+               for atom in atoms[:-1]:
+                   target_module = getattr(target_module, atom)
                target = getattr(target_module, atoms[-1])
            target_sharding_spec = node.sharding_spec
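
The attribute walk replaces root.get_submodule(atoms[-2]), which presumably only resolved targets sitting directly under the root: get_submodule expects the full dotted path, so a single atom like 'ln_1' from a nested target such as 'h.0.ln_1.weight' would not be found. A minimal standalone sketch of the same resolution logic (resolve_owner is a hypothetical helper, and the Sequential model is only an example):

import torch

def resolve_owner(root, qualified_name):
    # walk every atom except the last to reach the module that owns the leaf attribute
    atoms = qualified_name.split('.')
    owner = root
    for atom in atoms[:-1]:
        owner = getattr(owner, atom)
    return owner, atoms[-1]

root = torch.nn.Sequential(torch.nn.Sequential(torch.nn.Linear(2, 2)))
owner, leaf = resolve_owner(root, '0.0.weight')
assert getattr(owner, leaf) is root[0][0].weight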