Mirror of https://github.com/hpcaitech/ColossalAI.git
[hotfix] remove potential circular import (#1307)
* make it faster
* [hotfix] remove circular import
parent 6f2f9eb214
commit 4165eabb1e
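For context, a minimal sketch of the deferred-import pattern this hotfix applies in PipelineSchedule.pre_processing and InterleavedPipelineSchedule.pre_processing: the ShardedModelV2 import moves from module level into the method body, so it is resolved at call time, after both modules have finished initialising, and the import cycle no longer breaks module loading. The module names engine_mod and zero_mod below are hypothetical stand-ins, not ColossalAI code; the other half of the fix, relocating ophooks/paramhooks from colossalai.engine to colossalai.gemini, removes the reverse dependency.

# engine_mod.py (hypothetical)
class Schedule:
    def pre_processing(self, model):
        # Deferred import: resolved only when the method runs, after both
        # modules are fully initialised, so the engine_mod <-> zero_mod
        # cycle cannot raise ImportError at import time.
        from zero_mod import ShardedModel
        return isinstance(model, ShardedModel)

# zero_mod.py (hypothetical)
from engine_mod import Schedule  # safe: engine_mod no longer imports zero_mod at top level

class ShardedModel:
    def default_schedule(self) -> Schedule:
        return Schedule()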
@@ -7,7 +7,7 @@ from torch.nn.modules.loss import _Loss
 
 from colossalai.logging import get_dist_logger
 from torch import Tensor
-from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook
+from colossalai.gemini.ophooks import register_ophooks_recursively, BaseOpHook
 from colossalai.engine.schedule import BaseSchedule, NonPipelineSchedule, PipelineSchedule, InterleavedPipelineSchedule
 from typing import Optional, Type
 from colossalai.engine.gradient_handler import BaseGradientHandler
@@ -12,7 +12,6 @@ from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.utils import switch_virtual_pipeline_parallel_rank
 from colossalai.utils.cuda import get_current_device
-from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
 
 from ._base_schedule import BaseSchedule
 
@@ -157,6 +156,7 @@ class PipelineSchedule(BaseSchedule):
         return self._move_to_device(mciro_batch_data)
 
     def pre_processing(self, engine):
+        from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
         # TODO: remove this after testing new zero with pipeline parallelism
         model = engine.model
         if isinstance(model, NaiveAMPModel):
@@ -482,6 +482,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
         self.num_model_chunks = num_model_chunks
 
     def pre_processing(self, engine):
+        from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
         if isinstance(engine.model, ShardedModelV2):
             self.dtype = torch.half
         elif isinstance(engine.model[0], NaiveAMPModel):
@@ -3,7 +3,7 @@ import pickle
 from pathlib import Path
 from colossalai.context.parallel_mode import ParallelMode
 import torch
-from colossalai.engine.ophooks import BaseOpHook
+from colossalai.gemini.ophooks import BaseOpHook
 from colossalai.registry import OPHOOKS
 from colossalai.logging import get_dist_logger
 from colossalai.core import global_context as gpc
@@ -22,7 +22,7 @@ from colossalai.logging import get_dist_logger
 
 from colossalai.engine.schedule import NonPipelineSchedule, PipelineSchedule, InterleavedPipelineSchedule, get_tensor_shape
 from colossalai.engine import Engine
-from colossalai.engine.ophooks import BaseOpHook
+from colossalai.gemini.ophooks import BaseOpHook
 
 from colossalai.utils import (get_current_device, is_using_ddp, is_using_pp, is_using_sequence, sync_model_param)
 from colossalai.utils.moe import sync_moe_model_param
@@ -2,7 +2,7 @@ from pathlib import Path
 from typing import Union
 from colossalai.engine import Engine
 from torch.utils.tensorboard import SummaryWriter
-from colossalai.engine.ophooks import MemTracerOpHook
+from colossalai.gemini.ophooks import MemTracerOpHook
 from colossalai.utils.profiler.legacy.prof_utils import BaseProfiler
 
 
@@ -5,7 +5,7 @@ import torch
 from enum import Enum
 from typing import List
 from colossalai.gemini.stateful_tensor import StatefulTensor
-from colossalai.engine.ophooks import BaseOpHook
+from colossalai.gemini.ophooks import BaseOpHook
 from colossalai.engine import Engine
 from colossalai.utils.profiler.extention import ProfilerExtension
 
@@ -8,9 +8,9 @@ import torch.distributed as dist
 import torch.nn as nn
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.engine.ophooks import register_ophooks_recursively
+from colossalai.gemini.ophooks import register_ophooks_recursively
 from colossalai.zero.utils import ZeroHook
-from colossalai.engine.paramhooks import BaseParamHookMgr
+from colossalai.gemini.paramhooks import BaseParamHookMgr
 from colossalai.logging import get_dist_logger
 from colossalai.utils import get_current_device, disposable
 from colossalai.gemini.memory_tracer.memstats_collector import MemStatsCollector
@@ -8,7 +8,7 @@ from colossalai.registry import OPHOOKS
 from colossalai.utils import get_current_device
 
 from colossalai.zero.shard_utils import BaseShardStrategy
-from colossalai.engine.ophooks import BaseOpHook
+from colossalai.gemini.ophooks import BaseOpHook
 
 from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr
 from colossalai.gemini.memory_tracer import MemStatsCollector
@@ -1,11 +1,11 @@
-colossalai.engine.ophooks
+colossalai.gemini.ophooks
 =========================
 
-.. automodule:: colossalai.engine.ophooks
+.. automodule:: colossalai.gemini.ophooks
    :members:
 
 
 .. toctree::
    :maxdepth: 2
 
-   colossalai.engine.ophooks.zero_hook
+   colossalai.gemini.ophooks.zero_hook
@@ -1,5 +1,5 @@
-colossalai.engine.ophooks.zero\_hook
+colossalai.gemini.ophooks.zero\_hook
 ====================================
 
-.. automodule:: colossalai.engine.ophooks.zero_hook
+.. automodule:: colossalai.gemini.ophooks.zero_hook
    :members:
@@ -8,5 +8,5 @@ colossalai.engine
    :maxdepth: 2
 
    colossalai.engine.gradient_handler
-   colossalai.engine.ophooks
+   colossalai.gemini.ophooks
    colossalai.engine.schedule
@@ -1,86 +0,0 @@
-import pytest
-from colossalai.engine.paramhooks import BaseParamHookMgr
-from torch import nn
-import torch
-import torch.nn.functional as F
-import copy
-
-class SubNet(nn.Module):
-    def __init__(self, out_features) -> None:
-        super().__init__()
-        self.bias = nn.Parameter(torch.zeros(out_features))
-
-    def forward(self, x, weight):
-        return F.linear(x, weight, self.bias)
-
-
-class Net(nn.Module):
-    def __init__(self, checkpoint=False) -> None:
-        super().__init__()
-        self.fc1 = nn.Linear(5, 5)
-        self.sub_fc = SubNet(5)
-        self.fc2 = nn.Linear(5, 1)
-
-    def forward(self, x):
-        x = self.fc1(x)
-        x = self.sub_fc(x, self.fc1.weight)
-        x = self.fc1(x)
-        x = self.fc2(x)
-        return x
-
-def net_data():
-    return (torch.randn(2, 5, dtype=torch.float, device='cuda'),)
-
-def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool:
-    if loose:
-        return torch.allclose(tensor_a, tensor_b, atol=1e-3, rtol=1e-3)
-    return torch.allclose(tensor_a, tensor_b)
-
-
-def test_base_param_hook():
-    torch.manual_seed(0)
-    model = Net(checkpoint=True).cuda()
-    model.train()
-    inputs = net_data()
-
-    def run_model(model, inputs, use_param_hook = False):
-        if use_param_hook:
-            class HooKWrapper:
-                def __init__(self) -> None:
-                    self.hook_triggered_times = 0
-
-                def wrapper_func(self):
-                    def hook(param, grad) -> torch.Tensor or None:
-                        self.hook_triggered_times += 1
-                        return grad
-                    return hook
-
-            hookwrapper = HooKWrapper()
-            param_list = [p for p in model.parameters()]
-            hook_mgr = BaseParamHookMgr(param_list)
-            hook_mgr.register_backward_hooks(hookwrapper.wrapper_func())
-
-        model.zero_grad(set_to_none=True)
-
-        with torch.cuda.amp.autocast():
-            y = model(*inputs)
-            loss = y.sum()
-        loss.backward()
-
-        if use_param_hook:
-            hook_mgr.remove_hooks()
-            return hookwrapper.hook_triggered_times
-
-    model_copy = copy.deepcopy(model)
-
-    run_model(model, inputs, False)
-    ret2 = run_model(model_copy, inputs, True)
-
-    # Make sure param hook has only be fired once in case of parameter sharing
-    assert ret2 == len(list(model.parameters()))
-
-    for p, p_copy in zip(model.parameters(), model_copy.parameters()):
-        assert allclose(p.grad, p_copy.grad), f"{p.grad} vs {p_copy.grad}"
-
-if __name__ == '__main__':
-    test_base_param_hook()