Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-02 09:38:05 +00:00
[autoparallel] integrate_gpt_related_tests (#2134)
* [autoparallel] integrate_gpt_related_tests
* polish code
* polish code
* add GPT2Model into runtime test
@@ -230,7 +230,12 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
        new_slice_items = []

        for slice_item in getitem_index:
            if slice_item is None:
                new_slice_items.append(None)
                continue

            new_start, new_stop, new_step = slice_item.start, slice_item.stop, slice_item.step

            if slice_item.start in node_pairs:
                new_start = node_pairs[slice_item.start]
            elif slice_item.stop in node_pairs:
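For context, here is a minimal, self-contained sketch of the slice-remapping idea in this hunk, assuming a node_pairs dict that maps an original size-extraction node to its converted replacement; the FakeNode class and remap_slices helper are illustrative only, not the pass's real API.

class FakeNode:
    """Stand-in for a torch.fx Node that produced a tensor size."""

    def __init__(self, name):
        self.name = name


def remap_slices(getitem_index, node_pairs):
    # Rebuild every slice so that a start/stop value which references a replaced
    # size node points at the new node instead; None entries are kept as-is.
    new_slice_items = []
    for slice_item in getitem_index:
        if slice_item is None:
            new_slice_items.append(None)
            continue
        new_start, new_stop, new_step = slice_item.start, slice_item.stop, slice_item.step
        if slice_item.start in node_pairs:
            new_start = node_pairs[slice_item.start]
        elif slice_item.stop in node_pairs:
            new_stop = node_pairs[slice_item.stop]
        new_slice_items.append(slice(new_start, new_stop, new_step))
    return tuple(new_slice_items)


old_size = FakeNode('size_1')
new_size = FakeNode('size_1_converted')
index = (slice(0, old_size, None), None)
print(remap_slices(index, {old_size: new_size}))  # stop of the first slice is rewritten

Running the snippet prints a tuple whose first slice now stops at the replacement node, which is what the pass needs after size nodes have been converted for the device mesh.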
@@ -355,7 +360,10 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
    for node in nodes:
        if node.op == 'call_module':
            target_module = node.graph.owning_module.get_submodule(node.target)

            # TODO: we need to do more actions to take care of the shared parameters.
            if hasattr(target_module, 'processed') and target_module.processed:
                continue
            setattr(target_module, 'processed', True)
            for name, param in target_module.named_parameters():
                target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
                # apply the sharding spec of parameters
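As a side note, a tiny sketch of what the 'processed' flag above buys: when several call_module nodes point at the same submodule (shared parameters), only the first visit shards it. The shard_once helper below is made up for illustration and only mimics the guard, not the actual sharding.

import torch.nn as nn

shared = nn.Linear(4, 4)  # imagine this module is the target of two call_module nodes


def shard_once(module):
    # Mirror of the guard in the hunk above: skip modules already handled.
    if hasattr(module, 'processed') and module.processed:
        return False
    setattr(module, 'processed', True)
    for name, param in module.named_parameters():
        # The real pass applies the strategy's sharding spec to `param` here.
        pass
    return True


print(shard_once(shared))  # True  -> parameters sharded on the first visit
print(shard_once(shared))  # False -> later visits to the shared module are skipped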
@@ -404,7 +412,9 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
                target_module = root
                target = getattr(root, atoms[0])
            else:
                target_module = root.get_submodule(atoms[-2])
                target_module = root
                for atom in atoms[:-1]:
                    target_module = getattr(target_module, atom)
                target = getattr(target_module, atoms[-1])

            target_sharding_spec = node.sharding_spec
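To see why the attribute walk above can replace the single root.get_submodule(atoms[-2]) lookup, here is a small hypothetical example with a nested, GPT2-like parameter path; the module tree and names are invented for illustration.

import torch.nn as nn

root = nn.Module()
root.h = nn.ModuleList([nn.ModuleDict({'attn': nn.Linear(4, 4)})])

atoms = 'h.0.attn.weight'.split('.')

# root.get_submodule(atoms[-2]) would try to fetch 'attn' directly from root and
# fail, because 'attn' lives two levels down; walking atoms[:-1] handles any depth.
target_module = root
for atom in atoms[:-1]:
    target_module = getattr(target_module, atom)  # h -> h.0 -> h.0.attn
target = getattr(target_module, atoms[-1])
print(target.shape)  # torch.Size([4, 4])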