[Gemini] add bert for MemtracerWrapper unintests (#1982)

This commit is contained in:
Jiarui Fang
2022-11-18 14:58:28 +08:00
committed by GitHub
parent e481489aa6
commit 3712ac7f90
4 changed files with 32 additions and 11 deletions

View File

@@ -28,6 +28,9 @@ class _Wrapper():
def show_mem_stats(self):
self._ophook_list[0].show_mem_stats()
def named_buffers(self):
return self._model.named_buffers()
def MemtracerWrapper(model):
ophook_list = [MemTracerOpHook()]

View File

@@ -7,6 +7,7 @@ from colossalai.gemini.ophooks import BaseOpHook
class MemTracerOpHook(BaseOpHook):
"""
TODO() what if parameters are sharded by multiple submodules.
register buff on its father node
"""
def __init__(self):