[autoparallel] add numerical test for node strategies (#1760)

* [autoparallel] add numerical test for node strategies

* polish code

* polish code
This commit is contained in:
YuliangLiu0306
2022-10-27 10:42:54 +08:00
committed by GitHub
parent 25952b67d7
commit b4cc59b61e
10 changed files with 283 additions and 60 deletions

View File

@@ -77,6 +77,7 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh):
if target_sharding_spec.dim_partition_dict != {}:
origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {})
setattr(param, 'sharding_spec', origin_sharding_spec)
# TODO: build a ColoParamter class to manager the distributed parameters
param_sharded = torch.nn.Parameter(
shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec,
target_sharding_spec).detach().clone())