diff --git a/colossalai/engine/_base_engine.py b/colossalai/engine/_base_engine.py
index 074b9d0cc..146a29669 100644
--- a/colossalai/engine/_base_engine.py
+++ b/colossalai/engine/_base_engine.py
@@ -7,7 +7,7 @@
 from torch.nn.modules.loss import _Loss
 from colossalai.logging import get_dist_logger
 from torch import Tensor
-from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook
+from colossalai.gemini.ophooks import register_ophooks_recursively, BaseOpHook
 from colossalai.engine.schedule import BaseSchedule, NonPipelineSchedule, PipelineSchedule, InterleavedPipelineSchedule
 from typing import Optional, Type
 from colossalai.engine.gradient_handler import BaseGradientHandler
diff --git a/colossalai/engine/schedule/_pipeline_schedule.py b/colossalai/engine/schedule/_pipeline_schedule.py
index 6e865ae8f..97571fa02 100644
--- a/colossalai/engine/schedule/_pipeline_schedule.py
+++ b/colossalai/engine/schedule/_pipeline_schedule.py
@@ -12,7 +12,6 @@
 from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.utils import switch_virtual_pipeline_parallel_rank
 from colossalai.utils.cuda import get_current_device
-from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
 
 from ._base_schedule import BaseSchedule
@@ -157,6 +156,7 @@ class PipelineSchedule(BaseSchedule):
         return self._move_to_device(mciro_batch_data)
 
     def pre_processing(self, engine):
+        from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
         # TODO: remove this after testing new zero with pipeline parallelism
         model = engine.model
         if isinstance(model, NaiveAMPModel):
@@ -482,6 +482,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
         self.num_model_chunks = num_model_chunks
 
     def pre_processing(self, engine):
+        from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
         if isinstance(engine.model, ShardedModelV2):
             self.dtype = torch.half
         elif isinstance(engine.model[0], NaiveAMPModel):
diff --git a/colossalai/engine/ophooks/__init__.py b/colossalai/gemini/ophooks/__init__.py
similarity index 100%
rename from colossalai/engine/ophooks/__init__.py
rename to colossalai/gemini/ophooks/__init__.py
diff --git a/colossalai/engine/ophooks/_memtracer_ophook.py b/colossalai/gemini/ophooks/_memtracer_ophook.py
similarity index 98%
rename from colossalai/engine/ophooks/_memtracer_ophook.py
rename to colossalai/gemini/ophooks/_memtracer_ophook.py
index 4f16edfab..71831f1aa 100644
--- a/colossalai/engine/ophooks/_memtracer_ophook.py
+++ b/colossalai/gemini/ophooks/_memtracer_ophook.py
@@ -3,7 +3,7 @@ import pickle
 from pathlib import Path
 from colossalai.context.parallel_mode import ParallelMode
 import torch
-from colossalai.engine.ophooks import BaseOpHook
+from colossalai.gemini.ophooks import BaseOpHook
 from colossalai.registry import OPHOOKS
 from colossalai.logging import get_dist_logger
 from colossalai.core import global_context as gpc
diff --git a/colossalai/engine/ophooks/_shard_grad_ophook.py b/colossalai/gemini/ophooks/_shard_grad_ophook.py
similarity index 100%
rename from colossalai/engine/ophooks/_shard_grad_ophook.py
rename to colossalai/gemini/ophooks/_shard_grad_ophook.py
diff --git a/colossalai/engine/ophooks/_shard_param_ophook.py b/colossalai/gemini/ophooks/_shard_param_ophook.py
similarity index 100%
rename from colossalai/engine/ophooks/_shard_param_ophook.py
rename to colossalai/gemini/ophooks/_shard_param_ophook.py
diff --git a/colossalai/engine/ophooks/utils.py b/colossalai/gemini/ophooks/utils.py
similarity index 100%
rename from colossalai/engine/ophooks/utils.py
rename to colossalai/gemini/ophooks/utils.py
diff --git a/colossalai/engine/paramhooks/__init__.py b/colossalai/gemini/paramhooks/__init__.py
similarity index 100%
rename from colossalai/engine/paramhooks/__init__.py
rename to colossalai/gemini/paramhooks/__init__.py
diff --git a/colossalai/engine/paramhooks/_param_hookmgr.py b/colossalai/gemini/paramhooks/_param_hookmgr.py
similarity index 100%
rename from colossalai/engine/paramhooks/_param_hookmgr.py
rename to colossalai/gemini/paramhooks/_param_hookmgr.py
diff --git a/colossalai/initialize.py b/colossalai/initialize.py
index 086efaac3..e907efdde 100644
--- a/colossalai/initialize.py
+++ b/colossalai/initialize.py
@@ -22,7 +22,7 @@
 from colossalai.logging import get_dist_logger
 from colossalai.engine.schedule import NonPipelineSchedule, PipelineSchedule, InterleavedPipelineSchedule, get_tensor_shape
 from colossalai.engine import Engine
-from colossalai.engine.ophooks import BaseOpHook
+from colossalai.gemini.ophooks import BaseOpHook
 from colossalai.utils import (get_current_device, is_using_ddp, is_using_pp, is_using_sequence, sync_model_param)
 from colossalai.utils.moe import sync_moe_model_param
 
diff --git a/colossalai/utils/profiler/legacy/mem_profiler.py b/colossalai/utils/profiler/legacy/mem_profiler.py
index c4d7ca2ef..f80f6ecf5 100644
--- a/colossalai/utils/profiler/legacy/mem_profiler.py
+++ b/colossalai/utils/profiler/legacy/mem_profiler.py
@@ -2,7 +2,7 @@
 from pathlib import Path
 from typing import Union
 from colossalai.engine import Engine
 from torch.utils.tensorboard import SummaryWriter
-from colossalai.engine.ophooks import MemTracerOpHook
+from colossalai.gemini.ophooks import MemTracerOpHook
 from colossalai.utils.profiler.legacy.prof_utils import BaseProfiler
 
diff --git a/colossalai/utils/profiler/stateful_tensor_mem_extention.py b/colossalai/utils/profiler/stateful_tensor_mem_extention.py
index 749823553..127055c8c 100644
--- a/colossalai/utils/profiler/stateful_tensor_mem_extention.py
+++ b/colossalai/utils/profiler/stateful_tensor_mem_extention.py
@@ -5,7 +5,7 @@
 import torch
 from enum import Enum
 from typing import List
 from colossalai.gemini.stateful_tensor import StatefulTensor
-from colossalai.engine.ophooks import BaseOpHook
+from colossalai.gemini.ophooks import BaseOpHook
 from colossalai.engine import Engine
 from colossalai.utils.profiler.extention import ProfilerExtension
diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py
index 9940ea5e5..a0214f609 100644
--- a/colossalai/zero/sharded_model/sharded_model_v2.py
+++ b/colossalai/zero/sharded_model/sharded_model_v2.py
@@ -8,9 +8,9 @@ import torch.distributed as dist
 import torch.nn as nn
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.engine.ophooks import register_ophooks_recursively
+from colossalai.gemini.ophooks import register_ophooks_recursively
 from colossalai.zero.utils import ZeroHook
-from colossalai.engine.paramhooks import BaseParamHookMgr
+from colossalai.gemini.paramhooks import BaseParamHookMgr
 from colossalai.logging import get_dist_logger
 from colossalai.utils import get_current_device, disposable
 from colossalai.gemini.memory_tracer.memstats_collector import MemStatsCollector
diff --git a/colossalai/zero/utils/zero_hook.py b/colossalai/zero/utils/zero_hook.py
index e29266021..189d1ad2d 100644
--- a/colossalai/zero/utils/zero_hook.py
+++ b/colossalai/zero/utils/zero_hook.py
@@ -8,7 +8,7 @@
 from colossalai.registry import OPHOOKS
 from colossalai.utils import get_current_device
 from colossalai.zero.shard_utils import BaseShardStrategy
-from colossalai.engine.ophooks import BaseOpHook
+from colossalai.gemini.ophooks import BaseOpHook
 from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr
 from colossalai.gemini.memory_tracer import MemStatsCollector
 
diff --git a/docs/colossalai/colossalai.engine.ophooks.rst b/docs/colossalai/colossalai.engine.ophooks.rst
index f4b8a8396..0173aa2a4 100644
--- a/docs/colossalai/colossalai.engine.ophooks.rst
+++ b/docs/colossalai/colossalai.engine.ophooks.rst
@@ -1,11 +1,11 @@
-colossalai.engine.ophooks
+colossalai.gemini.ophooks
 =========================
 
-.. automodule:: colossalai.engine.ophooks
+.. automodule:: colossalai.gemini.ophooks
    :members:
 
 
 .. toctree::
    :maxdepth: 2
 
-   colossalai.engine.ophooks.zero_hook
+   colossalai.gemini.ophooks.zero_hook
diff --git a/docs/colossalai/colossalai.engine.ophooks.zero_hook.rst b/docs/colossalai/colossalai.engine.ophooks.zero_hook.rst
index 270d1839c..f7868dd3a 100644
--- a/docs/colossalai/colossalai.engine.ophooks.zero_hook.rst
+++ b/docs/colossalai/colossalai.engine.ophooks.zero_hook.rst
@@ -1,5 +1,5 @@
-colossalai.engine.ophooks.zero\_hook
+colossalai.gemini.ophooks.zero\_hook
 ====================================
 
-.. automodule:: colossalai.engine.ophooks.zero_hook
+.. automodule:: colossalai.gemini.ophooks.zero_hook
    :members:
diff --git a/docs/colossalai/colossalai.engine.rst b/docs/colossalai/colossalai.engine.rst
index 00028968a..740cb0334 100644
--- a/docs/colossalai/colossalai.engine.rst
+++ b/docs/colossalai/colossalai.engine.rst
@@ -8,5 +8,5 @@ colossalai.engine
    :maxdepth: 2
 
    colossalai.engine.gradient_handler
-   colossalai.engine.ophooks
+   colossalai.gemini.ophooks
    colossalai.engine.schedule
diff --git a/tests/test_engine/test_param_hook.py b/tests/test_engine/test_param_hook.py
deleted file mode 100644
index 54639157f..000000000
--- a/tests/test_engine/test_param_hook.py
+++ /dev/null
@@ -1,86 +0,0 @@
-import pytest
-from colossalai.engine.paramhooks import BaseParamHookMgr
-from torch import nn
-import torch
-import torch.nn.functional as F
-import copy
-
-class SubNet(nn.Module):
-    def __init__(self, out_features) -> None:
-        super().__init__()
-        self.bias = nn.Parameter(torch.zeros(out_features))
-
-    def forward(self, x, weight):
-        return F.linear(x, weight, self.bias)
-
-
-class Net(nn.Module):
-    def __init__(self, checkpoint=False) -> None:
-        super().__init__()
-        self.fc1 = nn.Linear(5, 5)
-        self.sub_fc = SubNet(5)
-        self.fc2 = nn.Linear(5, 1)
-
-    def forward(self, x):
-        x = self.fc1(x)
-        x = self.sub_fc(x, self.fc1.weight)
-        x = self.fc1(x)
-        x = self.fc2(x)
-        return x
-
-def net_data():
-    return (torch.randn(2, 5, dtype=torch.float, device='cuda'),)
-
-def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool:
-    if loose:
-        return torch.allclose(tensor_a, tensor_b, atol=1e-3, rtol=1e-3)
-    return torch.allclose(tensor_a, tensor_b)
-
-
-def test_base_param_hook():
-    torch.manual_seed(0)
-    model = Net(checkpoint=True).cuda()
-    model.train()
-    inputs = net_data()
-
-    def run_model(model, inputs, use_param_hook = False):
-        if use_param_hook:
-            class HooKWrapper:
-                def __init__(self) -> None:
-                    self.hook_triggered_times = 0
-
-                def wrapper_func(self):
-                    def hook(param, grad) -> torch.Tensor or None:
-                        self.hook_triggered_times += 1
-                        return grad
-                    return hook
-
-            hookwrapper = HooKWrapper()
-            param_list = [p for p in model.parameters()]
-            hook_mgr = BaseParamHookMgr(param_list)
-            hook_mgr.register_backward_hooks(hookwrapper.wrapper_func())
-
-        model.zero_grad(set_to_none=True)
-
-        with torch.cuda.amp.autocast():
-            y = model(*inputs)
-        loss = y.sum()
-        loss.backward()
-
-        if use_param_hook:
-            hook_mgr.remove_hooks()
-            return hookwrapper.hook_triggered_times
-
-    model_copy = copy.deepcopy(model)
-
-    run_model(model, inputs, False)
-    ret2 = run_model(model_copy, inputs, True)
-
-    # Make sure param hook has only be fired once in case of parameter sharing
-    assert ret2 == len(list(model.parameters()))
-
-    for p, p_copy in zip(model.parameters(), model_copy.parameters()):
-        assert allclose(p.grad, p_copy.grad), f"{p.grad} vs {p_copy.grad}"
-
-if __name__ == '__main__':
-    test_base_param_hook()