Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-01 17:17:05 +00:00
[fx] tested the complete workflow for auto-parallel (#1336)
* [fx] tested the complete workflow for auto-parallel
* polish code
* polish code
* polish code
@@ -175,7 +175,7 @@ class LazyInitContext():
         self._unpatch_nn_init_funcs()
         self._unpatch_torch_tensor_funcs()

-    def lazy_init_parameters(self, model: torch.nn.Module, device='cpu', call_back: Callable = None):
+    def lazy_init_parameters(self, model: torch.nn.Module, device='cpu'):
         """
         Initialize the weights of the meta-tensor model.

@@ -205,6 +205,7 @@ class LazyInitContext():
             # get sharding spec
             dist_spec = getattr(tensor, 'dist_spec', None)
             pg = getattr(tensor, 'pg', None)
+            comp_spec = getattr(tensor, 'comp_spec', None)

             # convert the tensor from meta to materialized one
             if tensor.is_meta:
@@ -224,14 +225,15 @@ class LazyInitContext():
             else:
                 tensor = ColoTensor.from_torch_tensor(tensor)

-            # apply sharding
-            if dist_spec:
-                tensor = tensor.redistribute(dist_spec=dist_spec, pg=pg)
-
             # override the original tensor
             with torch.no_grad():
                 setattr(module, name, tensor)

+            # apply sharding
+            if dist_spec:
+                tensor.process_group = pg
+                tensor.set_tensor_spec(dist_spec, comp_spec)

         _init_recursively(model)

         return model
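The behavioral change is in the last hunk: a parameter is now materialized and attached to its module first, and the sharding spec is applied in place afterwards via process_group and set_tensor_spec, instead of calling redistribute() before the assignment. Below is a minimal sketch of that per-parameter flow, assuming the ColoTensor API shown in the diff; the helper name is hypothetical, and since the is_meta branch is not visible in this hunk, torch.empty stands in for whatever materialization the context actually performs.

import torch
from colossalai.tensor import ColoTensor   # import path assumed, not shown in this diff


def materialize_and_shard(module: torch.nn.Module, name: str, tensor: torch.Tensor, device='cpu'):
    # sharding metadata attached to the (possibly meta) tensor earlier in the context
    dist_spec = getattr(tensor, 'dist_spec', None)
    pg = getattr(tensor, 'pg', None)
    comp_spec = getattr(tensor, 'comp_spec', None)

    # convert the tensor from meta to a materialized one
    if tensor.is_meta:
        # placeholder materialization; the hidden branch in the real code differs
        materialized = torch.empty(tensor.shape, dtype=tensor.dtype, device=device)
        tensor = ColoTensor.from_torch_tensor(materialized)
    else:
        tensor = ColoTensor.from_torch_tensor(tensor)

    # override the original tensor first ...
    with torch.no_grad():
        setattr(module, name, tensor)

    # ... then apply sharding in place (the ordering introduced by this commit)
    if dist_spec:
        tensor.process_group = pg
        tensor.set_tensor_spec(dist_spec, comp_spec)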
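For context, a hypothetical end-to-end usage sketch with the updated signature: the call_back argument is gone, so only the model and target device are passed. The import path and the context-manager behavior of LazyInitContext are assumptions here, not taken from this diff.

import torch
from colossalai.utils.model.lazy_init_context import LazyInitContext   # path assumed


class Net(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)

    def forward(self, x):
        return self.linear(x)


ctx = LazyInitContext()
with ctx:   # assumed: patches init/tensor funcs so parameters are created as meta tensors
    model = Net()

# materialize (and shard, if specs were attached) on the target device;
# no call_back argument after this commit
model = ctx.lazy_init_parameters(model, device='cpu')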