Files
ColossalAI/colossalai/engine/ophooks/_base_ophook.py
Jiarui Fang 569357fea0 add pytorch hooks (#179)
* add pytorch hooks
fix #175

* remove licenses in src code

* add gpu memory tracer

* replacing print with logger in ophooks.
2022-01-25 22:20:54 +08:00

30 lines
677 B
Python

from abc import ABC, abstractmethod
import torch
class BaseOpHook(ABC):
    """Abstract interface for customized operations around a PyTorch submodule.

    Subclasses let users run their own logic before and after a PyTorch
    submodule executes. All five hook methods below are declared abstract, so a
    subclass must override every one of them before it can be instantiated.
    The base implementations have empty bodies and return ``None``; overrides
    may still delegate to them through ``super()``.
    """

    def __init__(self):
        """Create the hook; the base class keeps no state of its own.

        Note: this does not call ``super().__init__()``.
        """

    @abstractmethod
    def pre_fwd_exec(self, module: torch.nn.Module, *args) -> None:
        """Hook for the point before ``module``'s forward execution (per the name).

        Args:
            module: the submodule this hook operates on.
            *args: extra positional arguments passed by the caller; presumably
                the forward inputs — TODO confirm against the invoking code.
        """

    @abstractmethod
    def post_fwd_exec(self, module: torch.nn.Module, *args) -> None:
        """Hook for the point after ``module``'s forward execution (per the name).

        Args:
            module: the submodule this hook operates on.
            *args: extra positional arguments passed by the caller; their
                meaning is not visible here — TODO confirm against the caller.
        """

    @abstractmethod
    def pre_bwd_exec(self, module: torch.nn.Module, input, output) -> None:
        """Hook for the point before ``module``'s backward execution (per the name).

        Args:
            module: the submodule this hook operates on.
            input: caller-supplied value; presumably related to the module's
                inputs/gradients — TODO confirm.
            output: caller-supplied value; presumably related to the module's
                outputs/gradients — TODO confirm.
        """

    @abstractmethod
    def post_bwd_exec(self, module: torch.nn.Module, input) -> None:
        """Hook for the point after ``module``'s backward execution (per the name).

        Args:
            module: the submodule this hook operates on.
            input: caller-supplied value; meaning not visible here — TODO confirm.
        """

    @abstractmethod
    def post_iter(self) -> None:
        """Hook for the end of an iteration (per the name); takes no arguments."""