Mirror of https://github.com/hpcaitech/ColossalAI.git
[utils] fixed lazy init context (#1867)

commit e6ec99d389, parent 50c4cb0167
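
Summary of the change: `LazyInitContext.to_meta` now defaults to True, because corner cases such as `self.weight = torch.zeros(...)` cannot be captured yet, and `lazy_init_parameters` now applies the recorded init function (and the ColoTensor/ColoParameter conversion) to the freshly materialized tensor instead of the original, possibly-meta one. Imports in both files are re-sorted, and the workflow test gains explicit checks that parameters are meta before materialization and real afterwards.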
colossalai/utils/model/lazy_init_context.py

@@ -1,23 +1,24 @@
 #!/usr/bin/env python
 # coding: utf-8
 
+import inspect
+import types
+from typing import Callable, List
+
 import torch
 import torch.nn as nn
-from colossalai.tensor import ColoParameter, ColoTensor
-import types
-import inspect
-from typing import List, Callable
+
+from colossalai.tensor import ColoParameter, ColoTensor
 from colossalai.utils.model.utils import substitute_init_recursively
 
 
 class LazyInitContext():
     """
     A context to allow for lazy weight initialization of PyTorch modules. It intercepts the tensor
     initialization functions for lazy initialization
 
     Note:
         This API is only experimental and subject to future changes.
 
     Usage:
         with LazyInitContext() as ctx:
@@ -30,19 +31,20 @@ class LazyInitContext():
         # initialize weights
         ctx.lazy_init_parameters(model)
 
         # make sure the weight is not a meta tensor
         # and initialized correctly
         assert not model.weight.is_meta and torch.all(model.weight == 0)
 
     Args:
-        to_meta (bool): optional, whether to initialize the model with meta tensors, default is False.
+        to_meta (bool): optional, whether to initialize the model with meta tensors, default is True. This
+            argument exists for now because some corner cases such as self.weight = torch.zeros(...) cannot be captured yet.
         extra_torch_tensor_func (List[str]): extra torch tensor functions related
             to value setting, such as `zero_` and `triu_`. `zero_` is pre-added by default.
     """
 
     tensor_set_value_func = ['zero_', 'fill_']
 
-    def __init__(self, to_meta: bool = False, extra_torch_tensor_func: List[str] = None):
+    def __init__(self, to_meta: bool = True, extra_torch_tensor_func: List[str] = None):
         # TODO: hijack the torch constructor functions as well
         self._to_meta = to_meta
         self._intercepted_nn_init_func_cache = {}
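
For reference, a minimal runnable sketch of the documented usage, assuming the patched API with to_meta=True as the new default; the layer size is illustrative:

    import torch
    import torch.nn as nn

    from colossalai.utils.model.lazy_init_context import LazyInitContext

    with LazyInitContext() as ctx:
        model = nn.Linear(8, 8)
        model.weight.zero_()    # value-setting call is recorded, not executed

    # inside the context no real storage was allocated
    assert model.weight.is_meta

    # materialize the weights and replay the recorded init functions
    ctx.lazy_init_parameters(model)
    assert not model.weight.is_meta and torch.all(model.weight == 0)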
@@ -212,18 +214,19 @@ class LazyInitContext():
             materialized_tensor = torch.empty_like(tensor, device=device)
             # if this tensor is a meta tensor, it must have an init function
             assert tensor in self._intercepted_nn_init_func_cache
-            tensor = materialized_tensor
+        else:
+            materialized_tensor = tensor
 
         # apply init function
         if tensor in self._intercepted_nn_init_func_cache:
             init_func, args, kwargs = self._intercepted_nn_init_func_cache[tensor][-1]
-            init_func(tensor, *args, **kwargs)
+            init_func(materialized_tensor, *args, **kwargs)
 
         # convert it to ColoTensor or ColoParameter
         if is_param:
-            tensor = ColoParameter.from_torch_tensor(tensor, requires_grad=tensor.requires_grad)
+            tensor = ColoParameter.from_torch_tensor(materialized_tensor, requires_grad=tensor.requires_grad)
         else:
-            tensor = ColoTensor.from_torch_tensor(tensor)
+            tensor = ColoTensor.from_torch_tensor(materialized_tensor)
 
         # override the original tensor
         with torch.no_grad():
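
The bug this hunk fixes is visible in the old lines: rebinding `tensor` to the freshly allocated `materialized_tensor` before the cache lookup meant the lookup, keyed by the original meta tensor, could no longer find the recorded init function, so meta weights were materialized but never initialized. Below is a minimal sketch of the patched control flow, using hypothetical names (`materialize`, `init_cache`) rather than ColossalAI's real method, which also handles sharding specs and ColoTensor conversion:

    import torch
    import torch.nn as nn

    def materialize(tensor: torch.Tensor, init_cache: dict, device: str = 'cpu') -> torch.Tensor:
        if tensor.is_meta:
            # a meta tensor carries shape/dtype but no data, so allocate real storage
            materialized_tensor = torch.empty_like(tensor, device=device)
            assert tensor in init_cache, 'a meta tensor must have a recorded init function'
        else:
            materialized_tensor = tensor
        # the fix: replay the cached init on the *materialized* tensor, not on
        # the original, which may still live on the meta device
        if tensor in init_cache:
            init_func, args, kwargs = init_cache[tensor][-1]
            init_func(materialized_tensor, *args, **kwargs)
        return materialized_tensor

    # usage: record an init for a meta tensor, then materialize it
    cache = {}
    meta_weight = torch.empty(4, 4, device='meta')
    cache[meta_weight] = [(nn.init.ones_, (), {})]
    weight = materialize(meta_weight, cache)
    assert not weight.is_meta and torch.all(weight == 1)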
(second file: the LazyInitContext complete-workflow test; its path is not shown in this capture)

@@ -1,16 +1,18 @@
-import colossalai
-import torch
-import torch.nn as nn
-import pytest
-import torch.multiprocessing as mp
-import torch.distributed as dist
-from colossalai.testing import rerun_if_address_is_in_use
 from functools import partial
+
+import pytest
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+import torch.nn as nn
+
+import colossalai
 from colossalai.fx import ColoTracer
-from colossalai.utils.model.lazy_init_context import LazyInitContext
 from colossalai.fx.passes.shard_1d_pass import transformer_mlp_pass
-from colossalai.utils import free_port
 from colossalai.tensor import ProcessGroup
+from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.utils import free_port
+from colossalai.utils.model.lazy_init_context import LazyInitContext
 
 
 class MLP(torch.nn.Module):
@@ -35,6 +37,9 @@ def run_workflow(world_size):
     with LazyInitContext() as ctx:
         model = MLP(16)
 
+    for param in model.parameters():
+        assert param.is_meta
+
     # tracing
     tracer = ColoTracer()
     graph = tracer.trace(model)
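
The new assertions make the contract explicit: every parameter created inside the context lives on the meta device until materialization. The same property can be checked with plain PyTorch; a sketch, assuming PyTorch 2.x where torch.device works as a context manager:

    import torch
    import torch.nn as nn

    # modules constructed under the meta device allocate no real storage
    with torch.device('meta'):
        layer = nn.Linear(16, 16)

    for param in layer.parameters():
        assert param.is_meta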
@@ -46,6 +51,8 @@ def run_workflow(world_size):
 
     # materialization and sharding
     ctx.lazy_init_parameters(annotated_gm)
+    for param in model.parameters():
+        assert not param.is_meta
 
     # # check sharding
     assert list(model.linear1.weight.shape) == [16 // world_size, 16]
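
The shape assertion encodes what transformer_mlp_pass is expected to do here: split linear1's [16, 16] weight evenly along dim 0 (the output features of an nn.Linear weight) across ranks. A standalone sketch of that arithmetic using torch.chunk:

    import torch

    # an even 1D split of a [16, 16] weight along dim 0, as the test assumes
    world_size = 2
    weight = torch.randn(16, 16)
    shards = torch.chunk(weight, world_size, dim=0)
    assert all(list(s.shape) == [16 // world_size, 16] for s in shards)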
@@ -57,7 +64,7 @@ def run_workflow(world_size):
     data = torch.rand(4, 16)
     non_fx_out = model(data)
     fx_out = annotated_gm(data)
-    assert torch.equal(non_fx_out, fx_out)
+    assert torch.equal(non_fx_out, fx_out), f'{non_fx_out} vs {fx_out}'
 
 
 def run_dist(rank, world_size, port):
@@ -74,4 +81,4 @@ def test_complete_workflow(world_size):
 
 
 if __name__ == '__main__':
-    test_complete_workflow(2)
+    test_complete_workflow(1)
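
For context, the launcher around these changed lines follows the usual ColossalAI multiprocessing test pattern; a sketch, assuming run_dist(rank, world_size, port) sets up the process group before calling run_workflow (the diff shows its signature but not its body):

    from functools import partial

    import torch.multiprocessing as mp

    from colossalai.utils import free_port

    def run_dist(rank, world_size, port):
        # placeholder body: the real test initializes colossalai with the
        # given rank/world_size/port and then calls run_workflow(world_size)
        pass

    def test_complete_workflow(world_size):
        # mp.spawn supplies the rank as the first positional argument
        run_func = partial(run_dist, world_size=world_size, port=free_port())
        mp.spawn(run_func, nprocs=world_size)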