[hotfix] remove potiential circle import (#1307)

* make it faster

* [hotfix] remove circle import
This commit is contained in:
Jiarui Fang 2022-07-14 13:44:26 +08:00 committed by GitHub
parent 6f2f9eb214
commit 4165eabb1e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 16 additions and 101 deletions

View File

@ -7,7 +7,7 @@ from torch.nn.modules.loss import _Loss
from colossalai.logging import get_dist_logger
from torch import Tensor
from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook
from colossalai.gemini.ophooks import register_ophooks_recursively, BaseOpHook
from colossalai.engine.schedule import BaseSchedule, NonPipelineSchedule, PipelineSchedule, InterleavedPipelineSchedule
from typing import Optional, Type
from colossalai.engine.gradient_handler import BaseGradientHandler

View File

@ -12,7 +12,6 @@ from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.utils import switch_virtual_pipeline_parallel_rank
from colossalai.utils.cuda import get_current_device
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
from ._base_schedule import BaseSchedule
@ -157,6 +156,7 @@ class PipelineSchedule(BaseSchedule):
return self._move_to_device(mciro_batch_data)
def pre_processing(self, engine):
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
# TODO: remove this after testing new zero with pipeline parallelism
model = engine.model
if isinstance(model, NaiveAMPModel):
@ -482,6 +482,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
self.num_model_chunks = num_model_chunks
def pre_processing(self, engine):
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
if isinstance(engine.model, ShardedModelV2):
self.dtype = torch.half
elif isinstance(engine.model[0], NaiveAMPModel):

View File

@ -3,7 +3,7 @@ import pickle
from pathlib import Path
from colossalai.context.parallel_mode import ParallelMode
import torch
from colossalai.engine.ophooks import BaseOpHook
from colossalai.gemini.ophooks import BaseOpHook
from colossalai.registry import OPHOOKS
from colossalai.logging import get_dist_logger
from colossalai.core import global_context as gpc

View File

@ -22,7 +22,7 @@ from colossalai.logging import get_dist_logger
from colossalai.engine.schedule import NonPipelineSchedule, PipelineSchedule, InterleavedPipelineSchedule, get_tensor_shape
from colossalai.engine import Engine
from colossalai.engine.ophooks import BaseOpHook
from colossalai.gemini.ophooks import BaseOpHook
from colossalai.utils import (get_current_device, is_using_ddp, is_using_pp, is_using_sequence, sync_model_param)
from colossalai.utils.moe import sync_moe_model_param

View File

@ -2,7 +2,7 @@ from pathlib import Path
from typing import Union
from colossalai.engine import Engine
from torch.utils.tensorboard import SummaryWriter
from colossalai.engine.ophooks import MemTracerOpHook
from colossalai.gemini.ophooks import MemTracerOpHook
from colossalai.utils.profiler.legacy.prof_utils import BaseProfiler

View File

@ -5,7 +5,7 @@ import torch
from enum import Enum
from typing import List
from colossalai.gemini.stateful_tensor import StatefulTensor
from colossalai.engine.ophooks import BaseOpHook
from colossalai.gemini.ophooks import BaseOpHook
from colossalai.engine import Engine
from colossalai.utils.profiler.extention import ProfilerExtension

View File

@ -8,9 +8,9 @@ import torch.distributed as dist
import torch.nn as nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.engine.ophooks import register_ophooks_recursively
from colossalai.gemini.ophooks import register_ophooks_recursively
from colossalai.zero.utils import ZeroHook
from colossalai.engine.paramhooks import BaseParamHookMgr
from colossalai.gemini.paramhooks import BaseParamHookMgr
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device, disposable
from colossalai.gemini.memory_tracer.memstats_collector import MemStatsCollector

View File

@ -8,7 +8,7 @@ from colossalai.registry import OPHOOKS
from colossalai.utils import get_current_device
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.engine.ophooks import BaseOpHook
from colossalai.gemini.ophooks import BaseOpHook
from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr
from colossalai.gemini.memory_tracer import MemStatsCollector

View File

@ -1,11 +1,11 @@
colossalai.engine.ophooks
colossalai.gemini.ophooks
=========================
.. automodule:: colossalai.engine.ophooks
.. automodule:: colossalai.gemini.ophooks
:members:
.. toctree::
:maxdepth: 2
colossalai.engine.ophooks.zero_hook
colossalai.gemini.ophooks.zero_hook

View File

@ -1,5 +1,5 @@
colossalai.engine.ophooks.zero\_hook
colossalai.gemini.ophooks.zero\_hook
====================================
.. automodule:: colossalai.engine.ophooks.zero_hook
.. automodule:: colossalai.gemini.ophooks.zero_hook
:members:

View File

@ -8,5 +8,5 @@ colossalai.engine
:maxdepth: 2
colossalai.engine.gradient_handler
colossalai.engine.ophooks
colossalai.gemini.ophooks
colossalai.engine.schedule

View File

@ -1,86 +0,0 @@
import pytest
from colossalai.engine.paramhooks import BaseParamHookMgr
from torch import nn
import torch
import torch.nn.functional as F
import copy
class SubNet(nn.Module):
def __init__(self, out_features) -> None:
super().__init__()
self.bias = nn.Parameter(torch.zeros(out_features))
def forward(self, x, weight):
return F.linear(x, weight, self.bias)
class Net(nn.Module):
def __init__(self, checkpoint=False) -> None:
super().__init__()
self.fc1 = nn.Linear(5, 5)
self.sub_fc = SubNet(5)
self.fc2 = nn.Linear(5, 1)
def forward(self, x):
x = self.fc1(x)
x = self.sub_fc(x, self.fc1.weight)
x = self.fc1(x)
x = self.fc2(x)
return x
def net_data():
return (torch.randn(2, 5, dtype=torch.float, device='cuda'),)
def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool:
if loose:
return torch.allclose(tensor_a, tensor_b, atol=1e-3, rtol=1e-3)
return torch.allclose(tensor_a, tensor_b)
def test_base_param_hook():
torch.manual_seed(0)
model = Net(checkpoint=True).cuda()
model.train()
inputs = net_data()
def run_model(model, inputs, use_param_hook = False):
if use_param_hook:
class HooKWrapper:
def __init__(self) -> None:
self.hook_triggered_times = 0
def wrapper_func(self):
def hook(param, grad) -> torch.Tensor or None:
self.hook_triggered_times += 1
return grad
return hook
hookwrapper = HooKWrapper()
param_list = [p for p in model.parameters()]
hook_mgr = BaseParamHookMgr(param_list)
hook_mgr.register_backward_hooks(hookwrapper.wrapper_func())
model.zero_grad(set_to_none=True)
with torch.cuda.amp.autocast():
y = model(*inputs)
loss = y.sum()
loss.backward()
if use_param_hook:
hook_mgr.remove_hooks()
return hookwrapper.hook_triggered_times
model_copy = copy.deepcopy(model)
run_model(model, inputs, False)
ret2 = run_model(model_copy, inputs, True)
# Make sure param hook has only be fired once in case of parameter sharing
assert ret2 == len(list(model.parameters()))
for p, p_copy in zip(model.parameters(), model_copy.parameters()):
assert allclose(p.grad, p_copy.grad), f"{p.grad} vs {p_copy.grad}"
if __name__ == '__main__':
test_base_param_hook()