mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-22 01:48:07 +00:00
[zero] update zero context init with the updated test utils (#327)
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
from re import S
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
import torch
|
||||
from . import BaseOpHook
|
||||
@@ -7,7 +6,7 @@ from colossalai.registry import OPHOOKS
|
||||
from colossalai.logging import get_dist_logger
|
||||
from time import sleep, time
|
||||
import pickle
|
||||
from typing import Union, Optional
|
||||
from typing import Optional
|
||||
from colossalai.core import global_context as gpc
|
||||
|
||||
|
||||
@@ -19,12 +18,13 @@ def get_cuda_memory_used(device: Optional[torch.device]) -> int:
|
||||
"""
|
||||
ret: int = torch.cuda.memory_allocated(device)
|
||||
# get the peak memory to report correct data, so reset the counter for the next call
|
||||
if hasattr(torch.cuda, "reset_peak_memory_stats"): # pytorch 1.4+
|
||||
if hasattr(torch.cuda, "reset_peak_memory_stats"): # pytorch 1.4+
|
||||
torch.cuda.reset_peak_memory_stats(device)
|
||||
return ret
|
||||
|
||||
|
||||
class AsyncMemoryMonitor:
|
||||
|
||||
def __init__(self, power=10):
|
||||
"""
|
||||
An Async Mem Monitor running during computation.
|
||||
@@ -81,7 +81,7 @@ class AsyncMemoryMonitor:
|
||||
def save(self, filename):
|
||||
with open(filename, "wb") as f:
|
||||
pickle.dump(self.state_dict(), f)
|
||||
|
||||
|
||||
def clear(self):
|
||||
self.mem_stats.clear()
|
||||
self.time_stamps.clear()
|
||||
@@ -92,7 +92,7 @@ class MemTracerOpHook(BaseOpHook):
|
||||
'''
|
||||
Collect GPU memory usage information
|
||||
|
||||
Args:
|
||||
Args:
|
||||
warmup (int): This parameter indicates how many iterations to truncate
|
||||
before profiling, e.g. set to 5 and the data will start from 6-th iteration
|
||||
refreshrate (int): This parameter decides how often the data file is written.
|
||||
@@ -106,6 +106,7 @@ class MemTracerOpHook(BaseOpHook):
|
||||
_data_prefix (string): the prefix of the stats data file
|
||||
_rank (int): the rank of current node
|
||||
'''
|
||||
|
||||
def __init__(self, warmup: int = 50, refreshrate: int = 10, data_prefix: str = "memstats"):
|
||||
super().__init__()
|
||||
self.async_mem_monitor = AsyncMemoryMonitor()
|
||||
@@ -128,7 +129,7 @@ class MemTracerOpHook(BaseOpHook):
|
||||
@property
|
||||
def refreshrate(self) -> int:
|
||||
return self._refreshrate
|
||||
|
||||
|
||||
@property
|
||||
def warmup(self) -> int:
|
||||
return self._warmup
|
||||
@@ -178,8 +179,7 @@ class MemTracerOpHook(BaseOpHook):
|
||||
# every `refreshrate` times, refresh the file
|
||||
if self.valid_iter != 0 and self.valid_iter % self.refreshrate == 0:
|
||||
# output file info
|
||||
self._logger.info(
|
||||
f'dump a memory statistics as pickle to {self._dataprefix}-{self._rank}.pkl')
|
||||
self._logger.info(f'dump a memory statistics as pickle to {self._dataprefix}-{self._rank}.pkl')
|
||||
self.save_results()
|
||||
self._count += 1
|
||||
self._logger.debug(f'data file has been refreshed {self._count} times')
|
||||
|
Reference in New Issue
Block a user