[fx] tested the complete workflow for auto-parallel (#1336)

* [fx] tested the complete workflow for auto-parallel

* polish code

* polish code

* polish code
Frank Lee
2022-07-20 10:45:17 +08:00
committed by GitHub
parent 4631fef8a0
commit 2cc1175c76
4 changed files with 187 additions and 106 deletions


@@ -175,7 +175,7 @@ class LazyInitContext():
         self._unpatch_nn_init_funcs()
         self._unpatch_torch_tensor_funcs()
-    def lazy_init_parameters(self, model: torch.nn.Module, device='cpu', call_back: Callable = None):
+    def lazy_init_parameters(self, model: torch.nn.Module, device='cpu'):
         """
         Initialize the weights of the meta-tensor model.
@@ -205,6 +205,7 @@ class LazyInitContext():
             # get sharding spec
             dist_spec = getattr(tensor, 'dist_spec', None)
             pg = getattr(tensor, 'pg', None)
+            comp_spec = getattr(tensor, 'comp_spec', None)
             # convert the tensor from meta to materialized one
             if tensor.is_meta:
@@ -224,14 +225,15 @@ class LazyInitContext():
             else:
                 tensor = ColoTensor.from_torch_tensor(tensor)
-            # apply sharding
-            if dist_spec:
-                tensor = tensor.redistribute(dist_spec=dist_spec, pg=pg)
             # override the original tensor
             with torch.no_grad():
                 setattr(module, name, tensor)
+            # apply sharding
+            if dist_spec:
+                tensor.process_group = pg
+                tensor.set_tensor_spec(dist_spec, comp_spec)
         _init_recursively(model)
         return model
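
A minimal usage sketch of the updated API, for orientation only: the import path and the context-manager usage are assumptions inferred from this diff, not confirmed by it; the only facts taken from the change above are the lazy_init_parameters(model, device) signature and the removal of the call_back argument.

# minimal usage sketch (import path and context-manager protocol are assumptions)
import torch
from colossalai.utils.model.lazy_init_context import LazyInitContext

ctx = LazyInitContext()
with ctx:
    # parameters created inside the context stay on the meta device
    model = torch.nn.Linear(16, 8)

# materialize the meta parameters on the target device;
# the call_back keyword argument no longer exists after this commit
model = ctx.lazy_init_parameters(model, device='cpu')

Note the change in how sharding is applied: instead of calling redistribute before attaching the materialized tensor, the tensor is first set on the module and then receives its process group and (dist_spec, comp_spec) pair in place via set_tensor_spec.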