Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-07-04 11:06:25 +00:00
[hotfix] remove potential circular import (#1307)

* make it faster
* [hotfix] remove circular import

This commit is contained in:
parent 6f2f9eb214
commit 4165eabb1e
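The hunks below make two kinds of change: the ophooks and paramhooks subpackages move from colossalai.engine to colossalai.gemini, and the module-level import of ShardedModelV2 in the pipeline schedules becomes a function-local import inside pre_processing. A function-local import is deferred until the function is first called, which is a standard way to break an import cycle. A minimal sketch of the idea, where engine_mod and zero_mod are hypothetical stand-ins rather than names from this repo:

    # engine_mod.py (hypothetical)
    def pre_processing(model):
        # Deferred import: resolved on the first call rather than at module
        # load, so importing engine_mod no longer requires zero_mod to be
        # initialized first.
        from zero_mod import ShardedModel
        return isinstance(model, ShardedModel)

    # zero_mod.py (hypothetical): imports engine_mod at module level. A
    # module-level `from zero_mod import ShardedModel` back in engine_mod.py
    # would have closed the cycle this commit removes.
    import engine_mod

    class ShardedModel:
        pass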
@@ -7,7 +7,7 @@ from torch.nn.modules.loss import _Loss
 from colossalai.logging import get_dist_logger
 from torch import Tensor
-from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook
+from colossalai.gemini.ophooks import register_ophooks_recursively, BaseOpHook
 from colossalai.engine.schedule import BaseSchedule, NonPipelineSchedule, PipelineSchedule, InterleavedPipelineSchedule
 from typing import Optional, Type
 from colossalai.engine.gradient_handler import BaseGradientHandler
@@ -12,7 +12,6 @@ from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.utils import switch_virtual_pipeline_parallel_rank
 from colossalai.utils.cuda import get_current_device
-from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
 
 from ._base_schedule import BaseSchedule
 
@@ -157,6 +156,7 @@ class PipelineSchedule(BaseSchedule):
         return self._move_to_device(mciro_batch_data)
 
     def pre_processing(self, engine):
+        from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
         # TODO: remove this after testing new zero with pipeline parallelism
         model = engine.model
         if isinstance(model, NaiveAMPModel):
@@ -482,6 +482,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
         self.num_model_chunks = num_model_chunks
 
     def pre_processing(self, engine):
+        from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
         if isinstance(engine.model, ShardedModelV2):
             self.dtype = torch.half
         elif isinstance(engine.model[0], NaiveAMPModel):
@@ -3,7 +3,7 @@ import pickle
 from pathlib import Path
 from colossalai.context.parallel_mode import ParallelMode
 import torch
-from colossalai.engine.ophooks import BaseOpHook
+from colossalai.gemini.ophooks import BaseOpHook
 from colossalai.registry import OPHOOKS
 from colossalai.logging import get_dist_logger
 from colossalai.core import global_context as gpc
@@ -22,7 +22,7 @@ from colossalai.logging import get_dist_logger
 
 from colossalai.engine.schedule import NonPipelineSchedule, PipelineSchedule, InterleavedPipelineSchedule, get_tensor_shape
 from colossalai.engine import Engine
-from colossalai.engine.ophooks import BaseOpHook
+from colossalai.gemini.ophooks import BaseOpHook
 
 from colossalai.utils import (get_current_device, is_using_ddp, is_using_pp, is_using_sequence, sync_model_param)
 from colossalai.utils.moe import sync_moe_model_param
@@ -2,7 +2,7 @@ from pathlib import Path
 from typing import Union
 from colossalai.engine import Engine
 from torch.utils.tensorboard import SummaryWriter
-from colossalai.engine.ophooks import MemTracerOpHook
+from colossalai.gemini.ophooks import MemTracerOpHook
 from colossalai.utils.profiler.legacy.prof_utils import BaseProfiler
 
 
@@ -5,7 +5,7 @@ import torch
 from enum import Enum
 from typing import List
 from colossalai.gemini.stateful_tensor import StatefulTensor
-from colossalai.engine.ophooks import BaseOpHook
+from colossalai.gemini.ophooks import BaseOpHook
 from colossalai.engine import Engine
 from colossalai.utils.profiler.extention import ProfilerExtension
 
@@ -8,9 +8,9 @@ import torch.distributed as dist
 import torch.nn as nn
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.engine.ophooks import register_ophooks_recursively
+from colossalai.gemini.ophooks import register_ophooks_recursively
 from colossalai.zero.utils import ZeroHook
-from colossalai.engine.paramhooks import BaseParamHookMgr
+from colossalai.gemini.paramhooks import BaseParamHookMgr
 from colossalai.logging import get_dist_logger
 from colossalai.utils import get_current_device, disposable
 from colossalai.gemini.memory_tracer.memstats_collector import MemStatsCollector
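For downstream code, only the import path changes in these hunks. A hedged sketch of a transitional shim that keeps older code working across the rename, assuming the hook classes themselves are unchanged by this hotfix:

    # Transitional import shim (sketch): prefer the new colossalai.gemini path
    # introduced by this commit, fall back to the old colossalai.engine path.
    try:
        from colossalai.gemini.ophooks import BaseOpHook, register_ophooks_recursively
    except ImportError:
        from colossalai.engine.ophooks import BaseOpHook, register_ophooks_recursively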
@@ -8,7 +8,7 @@ from colossalai.registry import OPHOOKS
 from colossalai.utils import get_current_device
 
 from colossalai.zero.shard_utils import BaseShardStrategy
-from colossalai.engine.ophooks import BaseOpHook
+from colossalai.gemini.ophooks import BaseOpHook
 
 from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr
 from colossalai.gemini.memory_tracer import MemStatsCollector
@@ -1,11 +1,11 @@
-colossalai.engine.ophooks
+colossalai.gemini.ophooks
 =========================
 
-.. automodule:: colossalai.engine.ophooks
+.. automodule:: colossalai.gemini.ophooks
    :members:
 
 
 .. toctree::
    :maxdepth: 2
 
-   colossalai.engine.ophooks.zero_hook
+   colossalai.gemini.ophooks.zero_hook
@@ -1,5 +1,5 @@
-colossalai.engine.ophooks.zero\_hook
+colossalai.gemini.ophooks.zero\_hook
 ====================================
 
-.. automodule:: colossalai.engine.ophooks.zero_hook
+.. automodule:: colossalai.gemini.ophooks.zero_hook
    :members:
@@ -8,5 +8,5 @@ colossalai.engine
    :maxdepth: 2
 
    colossalai.engine.gradient_handler
-   colossalai.engine.ophooks
+   colossalai.gemini.ophooks
    colossalai.engine.schedule
@@ -1,86 +0,0 @@
-import pytest
-from colossalai.engine.paramhooks import BaseParamHookMgr
-from torch import nn
-import torch
-import torch.nn.functional as F
-import copy
-
-
-class SubNet(nn.Module):
-
-    def __init__(self, out_features) -> None:
-        super().__init__()
-        self.bias = nn.Parameter(torch.zeros(out_features))
-
-    def forward(self, x, weight):
-        return F.linear(x, weight, self.bias)
-
-
-class Net(nn.Module):
-
-    def __init__(self, checkpoint=False) -> None:
-        super().__init__()
-        self.fc1 = nn.Linear(5, 5)
-        self.sub_fc = SubNet(5)
-        self.fc2 = nn.Linear(5, 1)
-
-    def forward(self, x):
-        x = self.fc1(x)
-        x = self.sub_fc(x, self.fc1.weight)
-        x = self.fc1(x)
-        x = self.fc2(x)
-        return x
-
-
-def net_data():
-    return (torch.randn(2, 5, dtype=torch.float, device='cuda'),)
-
-
-def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool:
-    if loose:
-        return torch.allclose(tensor_a, tensor_b, atol=1e-3, rtol=1e-3)
-    return torch.allclose(tensor_a, tensor_b)
-
-
-def test_base_param_hook():
-    torch.manual_seed(0)
-    model = Net(checkpoint=True).cuda()
-    model.train()
-    inputs = net_data()
-
-    def run_model(model, inputs, use_param_hook = False):
-        if use_param_hook:
-
-            class HooKWrapper:
-
-                def __init__(self) -> None:
-                    self.hook_triggered_times = 0
-
-                def wrapper_func(self):
-
-                    def hook(param, grad) -> torch.Tensor or None:
-                        self.hook_triggered_times += 1
-                        return grad
-
-                    return hook
-
-            hookwrapper = HooKWrapper()
-            param_list = [p for p in model.parameters()]
-            hook_mgr = BaseParamHookMgr(param_list)
-            hook_mgr.register_backward_hooks(hookwrapper.wrapper_func())
-
-        model.zero_grad(set_to_none=True)
-
-        with torch.cuda.amp.autocast():
-            y = model(*inputs)
-            loss = y.sum()
-        loss.backward()
-
-        if use_param_hook:
-            hook_mgr.remove_hooks()
-            return hookwrapper.hook_triggered_times
-
-    model_copy = copy.deepcopy(model)
-
-    run_model(model, inputs, False)
-    ret2 = run_model(model_copy, inputs, True)
-
-    # Make sure param hook has only be fired once in case of parameter sharing
-    assert ret2 == len(list(model.parameters()))
-
-    for p, p_copy in zip(model.parameters(), model_copy.parameters()):
-        assert allclose(p.grad, p_copy.grad), f"{p.grad} vs {p_copy.grad}"
-
-
-if __name__ == '__main__':
-    test_base_param_hook()
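The deleted test above is also a compact record of the BaseParamHookMgr API: the constructor takes a parameter list, register_backward_hooks installs a hook(param, grad) callback that fires once per parameter, and remove_hooks uninstalls it. A minimal sketch of the same flow under the new colossalai.gemini.paramhooks path, assuming the class is otherwise unchanged by this hotfix:

    import torch
    from torch import nn

    # New import path after this commit (old path: colossalai.engine.paramhooks).
    from colossalai.gemini.paramhooks import BaseParamHookMgr

    model = nn.Linear(5, 1)
    triggered = 0

    def hook(param, grad):
        # Same signature the deleted test used: receives (param, grad) and
        # returns the (possibly modified) grad, or None to leave it as-is.
        global triggered
        triggered += 1
        return grad

    mgr = BaseParamHookMgr(list(model.parameters()))
    mgr.register_backward_hooks(hook)

    loss = model(torch.randn(2, 5)).sum()
    loss.backward()
    mgr.remove_hooks()

    assert triggered == len(list(model.parameters()))  # one firing per parameter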