Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-05 19:13:01 +00:00
[test] fixed gemini plugin test (#3411)
* [test] fixed gemini plugin test
* polish code
* polish code
```diff
@@ -1,10 +1,11 @@
-from typing import Optional, Set
+from functools import partial
+from typing import Optional, Set
 
 import torch
 import torch.nn as nn
 
-from colossalai.nn.parallel.data_parallel import _cast_float
 from colossalai.gemini.tensor_utils import free_storage
+from colossalai.nn.parallel.data_parallel import _cast_float
 
 from .region_manager import RegionManager
 from .util import GlobalRuntimeInfo
```
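For context, `free_storage` from `colossalai.gemini.tensor_utils` is presumably used further down to strip the backing memory from the placeholder gradient (`empty_grad`) that `grad_handle` returns. A minimal sketch of what such a helper typically does, assuming the usual resize-storage-to-zero trick; the real ColossalAI implementation may differ:

```python
import torch


def free_storage_sketch(data: torch.Tensor) -> None:
    """Hedged sketch: release the memory behind a tensor while keeping its metadata."""
    if data.storage().size() > 0:
        # Shrinking the storage to zero elements drops the allocation; the tensor
        # object (shape, dtype, device) survives as a cheap placeholder.
        data.storage().resize_(0)


placeholder = torch.empty(1024)
free_storage_sketch(placeholder)
assert placeholder.storage().size() == 0  # no bytes left behind the placeholder
```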
```diff
@@ -20,10 +21,7 @@ class BaseOffloadModule:
         is_sync (bool): synchronous mode or not.
     """
 
-    def __init__(self,
-                 model: nn.Module,
-                 region_manager: RegionManager,
-                 is_sync=True):
+    def __init__(self, model: nn.Module, region_manager: RegionManager, is_sync=True):
 
         self.model = model
         self.region_manager = region_manager
```
```diff
@@ -69,8 +67,8 @@ class BaseOffloadModule:
         for p in self.model.parameters():
             p.grad = None
 
-        GlobalRuntimeInfo.fwd_prefetch_event_map.clear()
-        GlobalRuntimeInfo.bwd_prefetch_event_map.clear()
+        GlobalRuntimeInfo().fwd_prefetch_event_map.clear()
+        GlobalRuntimeInfo().bwd_prefetch_event_map.clear()
 
     def grad_handle(self, p, grad):
         empty_grad = torch.empty_like(grad)
```
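The switch from class-attribute access (`GlobalRuntimeInfo.fwd_prefetch_event_map`) to instance access (`GlobalRuntimeInfo().fwd_prefetch_event_map`) only keeps the old semantics if every `GlobalRuntimeInfo()` call hands back the same shared object. A common way to get that behaviour is a singleton metaclass; the sketch below is an assumption for illustration, not ColossalAI's actual implementation:

```python
class SingletonMeta(type):
    """Hedged sketch: metaclass that caches one instance per class."""

    _instances = {}

    def __call__(cls, *args, **kwargs):
        # Build the instance on first call, then always hand back the cached one.
        if cls not in cls._instances:
            cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]


class GlobalRuntimeInfoSketch(metaclass=SingletonMeta):
    """Stand-in for the diff's GlobalRuntimeInfo; attribute names mirror the hunk above."""

    def __init__(self):
        self.fwd_prefetch_event_map = {}
        self.bwd_prefetch_event_map = {}


a = GlobalRuntimeInfoSketch()
b = GlobalRuntimeInfoSketch()
assert a is b  # every call sees the same prefetch event maps
a.fwd_prefetch_event_map["region_0"] = "prefetch_event"
assert b.fwd_prefetch_event_map["region_0"] == "prefetch_event"
```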
```diff
@@ -82,7 +80,7 @@ class BaseOffloadModule:
             self.overflow_counter += region.has_inf_or_nan
             master_stream = torch.cuda.current_stream()
             with torch.cuda.stream(self.grad_offload_stream):
-                GlobalRuntimeInfo.d2h_stream.wait_stream(master_stream)
+                GlobalRuntimeInfo().d2h_stream.wait_stream(master_stream)
                 region.move_grad_to_cpu()
         return empty_grad
```
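The last hunk applies the same instance-access fix inside `grad_handle`. The surrounding pattern — have a dedicated device-to-host stream wait on the stream that produced the gradient, then run the copy to CPU on that side stream — is the usual way to overlap gradient offloading with compute. Below is a self-contained sketch of that pattern; the pinned buffer and function name are illustrative assumptions, not ColossalAI's actual Region logic:

```python
import torch


def offload_grad_sketch(grad: torch.Tensor, d2h_stream: torch.cuda.Stream) -> torch.Tensor:
    """Hedged sketch: asynchronously copy a CUDA gradient to pinned CPU memory."""
    master_stream = torch.cuda.current_stream()
    # Pinned host memory is required for the device-to-host copy to be truly asynchronous.
    cpu_buffer = torch.empty(grad.shape, dtype=grad.dtype, device="cpu", pin_memory=True)
    with torch.cuda.stream(d2h_stream):
        # Do not start copying until the producing stream has finished writing `grad`.
        d2h_stream.wait_stream(master_stream)
        cpu_buffer.copy_(grad, non_blocking=True)
    # The caller must synchronize with d2h_stream (or record an event) before reading cpu_buffer.
    return cpu_buffer


if torch.cuda.is_available():
    g = torch.randn(1024, device="cuda")
    side_stream = torch.cuda.Stream()
    offloaded = offload_grad_sketch(g, side_stream)
    torch.cuda.current_stream().wait_stream(side_stream)
    torch.cuda.synchronize()
    assert torch.allclose(offloaded, g.cpu())
```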