[hotfix] remove potential circular import (#1307)

* make it faster

* [hotfix] remove circular import
Jiarui Fang 2022-07-14 13:44:26 +08:00 committed by GitHub
parent 6f2f9eb214
commit 4165eabb1e
18 changed files with 16 additions and 101 deletions
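
What the hunks below amount to: the ophooks and paramhooks helpers move from colossalai.engine.* to colossalai.gemini.*, and the ShardedModelV2 import in the pipeline schedules is pushed down into pre_processing() so that it runs at call time rather than at module import time. A minimal sketch of that deferred-import pattern, using hypothetical module names (pkg.schedule / pkg.zero) rather than the real Colossal-AI package layout:

# pkg/schedule.py -- illustrative only; module names are placeholders.
# A module-level "from pkg.zero import ShardedModel" here would execute while
# pkg.zero is itself still importing pkg.schedule, completing the cycle
# pkg.schedule -> pkg.zero -> pkg.schedule and failing at import time.

def pre_processing(model):
    # Deferred import: it only runs when the function is called, by which
    # point both modules have finished initializing, so the cycle never bites.
    from pkg.zero import ShardedModel
    return isinstance(model, ShardedModel)

Relocating the shared hook utilities into the lower-level colossalai.gemini package serves the same goal for the remaining imports, so colossalai.engine no longer needs to be imported by the modules it depends on.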


@@ -7,7 +7,7 @@ from torch.nn.modules.loss import _Loss
 from colossalai.logging import get_dist_logger
 from torch import Tensor
-from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook
+from colossalai.gemini.ophooks import register_ophooks_recursively, BaseOpHook
 from colossalai.engine.schedule import BaseSchedule, NonPipelineSchedule, PipelineSchedule, InterleavedPipelineSchedule
 from typing import Optional, Type
 from colossalai.engine.gradient_handler import BaseGradientHandler


@@ -12,7 +12,6 @@ from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.utils import switch_virtual_pipeline_parallel_rank
 from colossalai.utils.cuda import get_current_device
-from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
 from ._base_schedule import BaseSchedule
@@ -157,6 +156,7 @@ class PipelineSchedule(BaseSchedule):
         return self._move_to_device(mciro_batch_data)

     def pre_processing(self, engine):
+        from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
         # TODO: remove this after testing new zero with pipeline parallelism
         model = engine.model
         if isinstance(model, NaiveAMPModel):
@@ -482,6 +482,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
         self.num_model_chunks = num_model_chunks

     def pre_processing(self, engine):
+        from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
         if isinstance(engine.model, ShardedModelV2):
             self.dtype = torch.half
         elif isinstance(engine.model[0], NaiveAMPModel):


@@ -3,7 +3,7 @@ import pickle
 from pathlib import Path
 from colossalai.context.parallel_mode import ParallelMode
 import torch
-from colossalai.engine.ophooks import BaseOpHook
+from colossalai.gemini.ophooks import BaseOpHook
 from colossalai.registry import OPHOOKS
 from colossalai.logging import get_dist_logger
 from colossalai.core import global_context as gpc


@@ -22,7 +22,7 @@ from colossalai.logging import get_dist_logger
 from colossalai.engine.schedule import NonPipelineSchedule, PipelineSchedule, InterleavedPipelineSchedule, get_tensor_shape
 from colossalai.engine import Engine
-from colossalai.engine.ophooks import BaseOpHook
+from colossalai.gemini.ophooks import BaseOpHook
 from colossalai.utils import (get_current_device, is_using_ddp, is_using_pp, is_using_sequence, sync_model_param)
 from colossalai.utils.moe import sync_moe_model_param


@@ -2,7 +2,7 @@ from pathlib import Path
 from typing import Union
 from colossalai.engine import Engine
 from torch.utils.tensorboard import SummaryWriter
-from colossalai.engine.ophooks import MemTracerOpHook
+from colossalai.gemini.ophooks import MemTracerOpHook
 from colossalai.utils.profiler.legacy.prof_utils import BaseProfiler


@@ -5,7 +5,7 @@ import torch
 from enum import Enum
 from typing import List
 from colossalai.gemini.stateful_tensor import StatefulTensor
-from colossalai.engine.ophooks import BaseOpHook
+from colossalai.gemini.ophooks import BaseOpHook
 from colossalai.engine import Engine
 from colossalai.utils.profiler.extention import ProfilerExtension


@@ -8,9 +8,9 @@ import torch.distributed as dist
 import torch.nn as nn
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.engine.ophooks import register_ophooks_recursively
+from colossalai.gemini.ophooks import register_ophooks_recursively
 from colossalai.zero.utils import ZeroHook
-from colossalai.engine.paramhooks import BaseParamHookMgr
+from colossalai.gemini.paramhooks import BaseParamHookMgr
 from colossalai.logging import get_dist_logger
 from colossalai.utils import get_current_device, disposable
 from colossalai.gemini.memory_tracer.memstats_collector import MemStatsCollector


@@ -8,7 +8,7 @@ from colossalai.registry import OPHOOKS
 from colossalai.utils import get_current_device
 from colossalai.zero.shard_utils import BaseShardStrategy
-from colossalai.engine.ophooks import BaseOpHook
+from colossalai.gemini.ophooks import BaseOpHook
 from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr
 from colossalai.gemini.memory_tracer import MemStatsCollector


@@ -1,11 +1,11 @@
-colossalai.engine.ophooks
+colossalai.gemini.ophooks
 =========================

-.. automodule:: colossalai.engine.ophooks
+.. automodule:: colossalai.gemini.ophooks
    :members:

 .. toctree::
    :maxdepth: 2

-   colossalai.engine.ophooks.zero_hook
+   colossalai.gemini.ophooks.zero_hook


@@ -1,5 +1,5 @@
-colossalai.engine.ophooks.zero\_hook
+colossalai.gemini.ophooks.zero\_hook
 ====================================

-.. automodule:: colossalai.engine.ophooks.zero_hook
+.. automodule:: colossalai.gemini.ophooks.zero_hook
    :members:


@@ -8,5 +8,5 @@ colossalai.engine
    :maxdepth: 2

    colossalai.engine.gradient_handler
-   colossalai.engine.ophooks
+   colossalai.gemini.ophooks
    colossalai.engine.schedule


@@ -1,86 +0,0 @@
-import pytest
-from colossalai.engine.paramhooks import BaseParamHookMgr
-from torch import nn
-import torch
-import torch.nn.functional as F
-import copy
-
-
-class SubNet(nn.Module):
-    def __init__(self, out_features) -> None:
-        super().__init__()
-        self.bias = nn.Parameter(torch.zeros(out_features))
-
-    def forward(self, x, weight):
-        return F.linear(x, weight, self.bias)
-
-
-class Net(nn.Module):
-    def __init__(self, checkpoint=False) -> None:
-        super().__init__()
-        self.fc1 = nn.Linear(5, 5)
-        self.sub_fc = SubNet(5)
-        self.fc2 = nn.Linear(5, 1)
-
-    def forward(self, x):
-        x = self.fc1(x)
-        x = self.sub_fc(x, self.fc1.weight)
-        x = self.fc1(x)
-        x = self.fc2(x)
-        return x
-
-
-def net_data():
-    return (torch.randn(2, 5, dtype=torch.float, device='cuda'),)
-
-
-def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool:
-    if loose:
-        return torch.allclose(tensor_a, tensor_b, atol=1e-3, rtol=1e-3)
-    return torch.allclose(tensor_a, tensor_b)
-
-
-def test_base_param_hook():
-    torch.manual_seed(0)
-    model = Net(checkpoint=True).cuda()
-    model.train()
-    inputs = net_data()
-
-    def run_model(model, inputs, use_param_hook=False):
-        if use_param_hook:
-            class HooKWrapper:
-                def __init__(self) -> None:
-                    self.hook_triggered_times = 0
-
-                def wrapper_func(self):
-                    def hook(param, grad) -> torch.Tensor or None:
-                        self.hook_triggered_times += 1
-                        return grad
-                    return hook
-
-            hookwrapper = HooKWrapper()
-            param_list = [p for p in model.parameters()]
-            hook_mgr = BaseParamHookMgr(param_list)
-            hook_mgr.register_backward_hooks(hookwrapper.wrapper_func())
-
-        model.zero_grad(set_to_none=True)
-
-        with torch.cuda.amp.autocast():
-            y = model(*inputs)
-            loss = y.sum()
-        loss.backward()
-
-        if use_param_hook:
-            hook_mgr.remove_hooks()
-            return hookwrapper.hook_triggered_times
-
-    model_copy = copy.deepcopy(model)
-
-    run_model(model, inputs, False)
-    ret2 = run_model(model_copy, inputs, True)
-
-    # Make sure param hook has only be fired once in case of parameter sharing
-    assert ret2 == len(list(model.parameters()))
-
-    for p, p_copy in zip(model.parameters(), model_copy.parameters()):
-        assert allclose(p.grad, p_copy.grad), f"{p.grad} vs {p_copy.grad}"
-
-
-if __name__ == '__main__':
-    test_base_param_hook()