Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-22 01:48:07 +00:00)
[profiler] add MemProfiler (#356)
* add memory trainer hook
* fix bug
* add memory trainer hook
* fix import bug
* fix import bug
* add trainer hook
* fix #370 git log bug
* modify `to_tensorboard` function to support better output
* remove useless output
* change the name of `MemProfiler`
* complete memory profiler
* replace error with warning
* finish trainer hook
* modify interface of MemProfiler
* modify `__init__.py` in profiler
* remove unnecessary pass statement
* add usage to doc string
* add usage to trainer hook
* new location to store temp data file
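The commit message names a `MemProfiler` with a `to_tensorboard` output path. As a rough sketch of how such a profiler is typically driven; the import path, constructor, and call signatures below are assumptions, not shown in the diffs in this commit:

```python
# Hypothetical usage sketch only: `MemProfiler` and `to_tensorboard` are
# named in the commit message above, but the import path, constructor, and
# signatures here are assumptions, not confirmed by the diffs below.
from torch.utils.tensorboard import SummaryWriter

from colossalai.utils.profiler import MemProfiler  # assumed import path

profiler = MemProfiler(model)  # assumed: profiler attaches op hooks to `model`
# ... run training iterations while the profiler records memory usage ...
writer = SummaryWriter(log_dir="./tb_logs")
profiler.to_tensorboard(writer)  # method name taken from the commit message
```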
Diff 1 (the `Engine` class; the source file path was not preserved in this rendering):

```diff
@@ -1,6 +1,7 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
+from asyncio.log import logger
 from typing import List
 from torch.nn import Module
 from torch.nn.modules.loss import _Loss
@@ -9,9 +10,9 @@ from torch.optim import Optimizer
 from colossalai.logging import get_dist_logger
 from torch import Tensor
 from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook
-from typing import Optional
+from typing import Optional, Type
 from colossalai.engine.gradient_handler import BaseGradientHandler
 
 from colossalai.logging import get_dist_logger
 
 class Engine:
     """Basic engine class for training and evaluation. It runs a specific process method
@@ -64,6 +65,11 @@ class Engine:
         self._ophook_list = ophook_list
         register_ophooks_recursively(self._model, self._ophook_list)
 
+    @property
+    def ophooks(self):
+        """show current activated ophooks"""
+        return self._ophook_list
+
     @property
     def model(self):
         """Model attached to the engine"""
@@ -79,6 +85,21 @@ class Engine:
         """Criterion attached to the engine"""
         return self._criterion
 
+    def add_hook(self, ophook: Type[BaseOpHook]) -> None:
+        """add necessary hook"""
+        # whether this hook exist
+        for h in self._ophook_list:
+            if type(h) == type(ophook):
+                logger = get_dist_logger()
+                logger.warning(f"duplicate hooks, at least two instance of {type(ophook)}")
+        self._ophook_list.append(ophook)
+        register_ophooks_recursively(self._model, self._ophook_list)
+
+    def remove_hook(self, ophook: Type[BaseOpHook]) -> None:
+        """remove hook"""
+        logger = get_dist_logger()
+        logger.warning(f"removing hooks is currently not supported")
+
     def zero_grad(self):
         """Set the gradient of parameters to zero
         """
@@ -150,4 +171,4 @@ class Engine:
         """Sets the model to evaluation mode.
         """
         self.training = False
-        self._model.eval()
+        self._model.eval()
```
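The new `add_hook`/`remove_hook` pair lets callers attach op hooks after the engine is built. A minimal sketch, assuming the `BaseOpHook` callback names of this era (`pre_fwd_exec`, `post_fwd_exec`, and so on) and an `engine` already produced by `colossalai.initialize`:

```python
# Minimal sketch: attach a custom op hook to an already-built engine.
# The BaseOpHook callback names below are assumptions; only add_hook,
# remove_hook, and register_ophooks_recursively appear in the diff above.
from colossalai.engine.ophooks import BaseOpHook


class ShapeLoggerHook(BaseOpHook):
    """Toy hook that logs each module around its forward/backward pass."""

    def pre_fwd_exec(self, module, *args):
        print(f"[pre-fwd]  {module.__class__.__name__}")

    def post_fwd_exec(self, module, *args):
        print(f"[post-fwd] {module.__class__.__name__}")

    def pre_bwd_exec(self, module, *args):
        pass

    def post_bwd_exec(self, module, *args):
        pass

    def post_iter(self):
        pass


# Despite the Type[BaseOpHook] annotation, add_hook treats `ophook` as an
# instance: it compares type(h) == type(ophook) and appends it to the list,
# so an instance is what should be passed here.
engine.add_hook(ShapeLoggerHook())
```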
Diff 2 (the `MemTracerOpHook` class):

```diff
@@ -1,12 +1,15 @@
 import json
+import pickle
+from pathlib import Path
 from colossalai.context.parallel_mode import ParallelMode
 import torch
 from colossalai.engine.ophooks import BaseOpHook
 from colossalai.registry import OPHOOKS
 from colossalai.logging import get_dist_logger
 from colossalai.core import global_context as gpc
 
+from typing import Union
 from colossalai.utils.memory_tracer import AsyncMemoryMonitor
 
 import os
 import math
@@ -103,12 +106,14 @@ class MemTracerOpHook(BaseOpHook):
             if self.valid_iter != 0 and self.valid_iter % self.refreshrate == 0:
                 # output file info
                 self._logger.info(f"dump a memory statistics as pickle to {self._data_prefix}-{self._rank}.pkl")
-                self.save_results()
+                home_dir = Path.home()
+                with open(home_dir.joinpath(f".cache/colossal/mem-{self._rank}.pkl"), "wb") as f:
+                    pickle.dump(self.async_mem_monitor.state_dict, f)
                 self._count += 1
                 self._logger.debug(f"data file has been refreshed {self._count} times")
         # finish a iteration
         self._curiter += 1
 
-    def save_results(self):
-        datafile = f"{self._data_prefix}-{self._rank}.pkl"
-        self.async_mem_monitor.save(datafile)
+    def save_results(self, data_file: Union[str, Path]):
+        with open(data_file, "w") as f:
+            f.write(json.dumps(self.async_mem_monitor.state_dict))
```
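Since `state_dict` is pickled verbatim to `~/.cache/colossal/mem-{rank}.pkl`, reading it back is straightforward. A short sketch follows; the internal structure of `state_dict` is not shown in this diff, so the code only inspects whatever it finds:

```python
# Sketch: load the per-rank memory trace dumped by MemTracerOpHook above.
# The path matches the new home-directory location in the diff; the layout
# of the pickled state_dict is not specified there, so we just inspect it.
import pickle
from pathlib import Path

rank = 0  # choose the rank whose trace you want to inspect
data_file = Path.home().joinpath(f".cache/colossal/mem-{rank}.pkl")

with open(data_file, "rb") as f:
    mem_stats = pickle.load(f)

print(type(mem_stats))
if isinstance(mem_stats, dict):
    for key, value in mem_stats.items():
        print(key, type(value))
```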