Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-04 18:40:28 +00:00
[autoparallel] test compatibility for gemini and auto parallel (#2700)
@@ -377,8 +377,9 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh, o
             # TODO: build a ColoParameter class to manage the distributed parameters
             # we could use .data here, because all the operations just happen before the real training
             # loop, so we don't need to track these operations in the autograd graph.
-            param.data = shape_consistency_manager.apply_for_autoparallel_runtime(
-                param.data, param.sharding_spec, target_sharding_spec).detach().clone()
+            param = torch.nn.Parameter(
+                shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec,
+                                                                         target_sharding_spec).detach().clone())
 
             setattr(target_module, name, param)
             comm_actions = node.best_strategy.communication_actions
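The same replacement is applied to the `target` branch in the next hunk: instead of mutating the existing Parameter's `.data` in place, the resharded tensor is wrapped in a fresh torch.nn.Parameter and re-registered on the module. A minimal, illustrative sketch of that pattern follows; the `reshard` helper and the toy Linear module are assumptions standing in for `shape_consistency_manager.apply_for_autoparallel_runtime` and the real target module, not code from this commit:

import torch
import torch.nn as nn


def reshard(tensor: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for
    # shape_consistency_manager.apply_for_autoparallel_runtime(tensor, src_spec, dst_spec);
    # here it simply returns a copy of the tensor.
    return tensor.clone()


module = nn.Linear(4, 4)
name = "weight"
param = getattr(module, name)

# Old pattern: mutate the existing Parameter's storage in place.
# param.data = reshard(param.data).detach().clone()

# New pattern: build a brand-new Parameter from the resharded tensor
# and re-register it on the module under the same attribute name.
param = nn.Parameter(reshard(param.data).detach().clone())
setattr(module, name, param)

# The module now owns the freshly constructed Parameter.
assert module.weight is param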
@@ -432,8 +433,9 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh, o
             # TODO: build a ColoParameter class to manage the distributed parameters
             # we could use .data here, because all the operations just happen before the real training
             # loop, so we don't need to track these operations in the autograd graph.
-            target.data = shape_consistency_manager.apply_for_autoparallel_runtime(
-                target.data, target.sharding_spec, target_sharding_spec).detach().clone()
+            target = torch.nn.Parameter(
+                shape_consistency_manager.apply_for_autoparallel_runtime(target.data, target.sharding_spec,
+                                                                         target_sharding_spec).detach().clone())
 
             assert hasattr(target_module, atoms[-1])
             setattr(target_module, atoms[-1], target)
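In this second hunk the parameter is reached through a dotted attribute path, and the new torch.nn.Parameter is re-attached to its owning submodule via the last attribute atom. A rough illustration of that resolution step is sketched below; the dotted-name splitting and the toy Sequential model are assumptions based on the usual torch.fx `get_attr` convention, not code from this commit:

import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())

# A get_attr target is a dotted path such as "0.weight".
target_name = "0.weight"
atoms = target_name.split(".")

# Walk every atom except the last to reach the owning submodule ...
target_module = model
for atom in atoms[:-1]:
    target_module = getattr(target_module, atom)

# ... then rebuild the parameter and re-register it under the final atom,
# mirroring the assert/setattr pair in the hunk above.
target = getattr(target_module, atoms[-1])
target = nn.Parameter(target.data.detach().clone())
assert hasattr(target_module, atoms[-1])
setattr(target_module, atoms[-1], target)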