[test] fixed gemini plugin test (#3411)

* [test] fixed gemini plugin test

* polish code

* polish code
commit 638a07a7f9
parent 30412866e0
Author: Frank Lee
Date:   2023-04-03 17:12:22 +08:00 (committed by GitHub)

7 changed files with 124 additions and 131 deletions


@@ -1,10 +1,11 @@
-from typing import Optional, Set
 from functools import partial
+from typing import Optional, Set
 import torch
 import torch.nn as nn
-from colossalai.nn.parallel.data_parallel import _cast_float
 from colossalai.gemini.tensor_utils import free_storage
+from colossalai.nn.parallel.data_parallel import _cast_float
 from .region_manager import RegionManager
 from .util import GlobalRuntimeInfo
@@ -20,10 +21,7 @@ class BaseOffloadModule:
         is_sync (bool): synchronous mode or not.
     """

-    def __init__(self,
-                 model: nn.Module,
-                 region_manager: RegionManager,
-                 is_sync=True):
+    def __init__(self, model: nn.Module, region_manager: RegionManager, is_sync=True):
         self.model = model
         self.region_manager = region_manager
@@ -69,8 +67,8 @@ class BaseOffloadModule:
         for p in self.model.parameters():
             p.grad = None
-        GlobalRuntimeInfo.fwd_prefetch_event_map.clear()
-        GlobalRuntimeInfo.bwd_prefetch_event_map.clear()

     def grad_handle(self, p, grad):
         empty_grad = torch.empty_like(grad)
+        GlobalRuntimeInfo().fwd_prefetch_event_map.clear()
+        GlobalRuntimeInfo().bwd_prefetch_event_map.clear()
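
Note: the hunks in this commit change every access from the class-attribute form GlobalRuntimeInfo.fwd_prefetch_event_map to the instance form GlobalRuntimeInfo().fwd_prefetch_event_map, which suggests GlobalRuntimeInfo was reworked into a singleton whose state lives on one shared instance. A minimal sketch of such a singleton, assuming a __new__-based implementation and only the attribute names visible in the diff (this is an illustration, not the actual ColossalAI code):

    class GlobalRuntimeInfo:
        """Illustrative singleton holding process-wide runtime state."""

        _instance = None

        def __new__(cls):
            # The first call creates the shared instance; every later call
            # returns it, so GlobalRuntimeInfo() always yields the same object.
            if cls._instance is None:
                cls._instance = super().__new__(cls)
                cls._instance.fwd_prefetch_event_map = {}
                cls._instance.bwd_prefetch_event_map = {}
                cls._instance.d2h_stream = None  # assumed to be set to a torch.cuda.Stream at init
            return cls._instance

With this pattern, GlobalRuntimeInfo() always returns the same object, so GlobalRuntimeInfo().fwd_prefetch_event_map.clear() clears the one shared map no matter where it is called from.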
@@ -82,7 +80,7 @@ class BaseOffloadModule:
             self.overflow_counter += region.has_inf_or_nan
         master_stream = torch.cuda.current_stream()
         with torch.cuda.stream(self.grad_offload_stream):
-            GlobalRuntimeInfo.d2h_stream.wait_stream(master_stream)
+            GlobalRuntimeInfo().d2h_stream.wait_stream(master_stream)
             region.move_grad_to_cpu()
         return empty_grad
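
This last hunk shows the device-to-host offload ordering: the offload stream first waits on the stream that produced the gradient, and only then enqueues the copy to CPU, so the copy can overlap with later compute without reading unfinished data. A self-contained sketch of the same ordering using standard PyTorch stream APIs (function and variable names are illustrative, not the ColossalAI API; requires a CUDA device):

    import torch

    def offload_grad_async(grad: torch.Tensor, d2h_stream: torch.cuda.Stream) -> torch.Tensor:
        """Copy a CUDA gradient to pinned CPU memory on a side stream (sketch)."""
        master_stream = torch.cuda.current_stream()
        with torch.cuda.stream(d2h_stream):
            # The copy stream must observe all work already queued on the
            # producing stream before reading `grad`, or it could copy
            # unfinished data.
            d2h_stream.wait_stream(master_stream)
            # Pinned destination memory allows a truly asynchronous D2H copy.
            cpu_grad = torch.empty(grad.shape, dtype=grad.dtype, device="cpu", pin_memory=True)
            cpu_grad.copy_(grad, non_blocking=True)
        # Keep `grad`'s memory alive until the side stream finishes with it.
        grad.record_stream(d2h_stream)
        return cpu_grad

Recording grad on the side stream tells the CUDA caching allocator not to reuse that memory until the copy completes, which is the usual safety step when a tensor is produced on one stream and consumed on another.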