[autoparallel] test compatibility for gemini and auto parallel (#2700)

Author: YuliangLiu0306
Date: 2023-02-15 09:43:29 +08:00
Committed by: GitHub
Parent: d701ef81b1
Commit: 7fa6be49d2
3 changed files with 212 additions and 4 deletions


@@ -377,8 +377,9 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh, o
     # TODO: build a ColoParamter class to manager the distributed parameters
     # we could use .data here, because all the operations just happen before the real training
     # loop, so we don't need to track these operations in the autograd graph.
-    param.data = shape_consistency_manager.apply_for_autoparallel_runtime(
-        param.data, param.sharding_spec, target_sharding_spec).detach().clone()
+    param = torch.nn.Parameter(
+        shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec,
+                                                                 target_sharding_spec).detach().clone())
     setattr(target_module, name, param)
     comm_actions = node.best_strategy.communication_actions
@@ -432,8 +433,9 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh, o
     # TODO: build a ColoParamter class to manager the distributed parameters
     # we could use .data here, because all the operations just happen before the real training
     # loop, so we don't need to track these operations in the autograd graph.
-    target.data = shape_consistency_manager.apply_for_autoparallel_runtime(
-        target.data, target.sharding_spec, target_sharding_spec).detach().clone()
+    target = torch.nn.Parameter(
+        shape_consistency_manager.apply_for_autoparallel_runtime(target.data, target.sharding_spec,
+                                                                 target_sharding_spec).detach().clone())
     assert hasattr(target_module, atoms[-1])
     setattr(target_module, atoms[-1], target)
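
Both hunks make the same change: instead of overwriting the existing parameter's .data in place, the resharded tensor is wrapped in a brand-new torch.nn.Parameter and re-registered on the module with setattr. The sketch below shows that pattern in isolation; shard_for_runtime is a hypothetical stand-in for shape_consistency_manager.apply_for_autoparallel_runtime, and the half-split "sharding" is invented purely for illustration.

    # Minimal sketch of the pattern this diff switches to: build a fresh
    # torch.nn.Parameter from the resharded tensor and re-register it on the
    # module, rather than assigning to the old parameter's .data in place.
    import torch
    import torch.nn as nn


    def shard_for_runtime(tensor: torch.Tensor) -> torch.Tensor:
        # Hypothetical stand-in: pretend this rank keeps the first half of dim 0.
        return tensor[: tensor.shape[0] // 2]


    def reshard_module_param(module: nn.Module, name: str) -> None:
        param = getattr(module, name)
        # Wrap the resharded data in a new Parameter and overwrite the module
        # attribute, so the module now holds a different Parameter object.
        new_param = nn.Parameter(shard_for_runtime(param.data).detach().clone())
        setattr(module, name, new_param)


    if __name__ == "__main__":
        linear = nn.Linear(4, 4, bias=False)
        reshard_module_param(linear, "weight")
        print(linear.weight.shape)  # torch.Size([2, 4])

Replacing the Parameter object, rather than mutating its storage, means anything that later inspects or wraps the module's parameters (as the gemini/auto-parallel compatibility test presumably does) sees an ordinary, correctly shaped Parameter instead of one whose .data was swapped underneath it.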