[zero] refactor memstats collector (#706)

* refactor memstats collector

* fix disposable

* polish code
ver217 authored 2022-04-11 10:46:08 +08:00; committed by GitHub
parent 3fc8a204dc
commit ab8c6b4a0e
8 changed files with 44 additions and 114 deletions


@@ -4,8 +4,8 @@ import os
import random
import socket
from pathlib import Path
-from typing import List, Union
+from typing import Callable, List, Union
+import functools
import torch
from torch._six import inf
from torch.nn.parameter import Parameter
@@ -112,6 +112,7 @@ def conditional_context(context_manager, enable=True):
class model_branch_context(object):
def __enter__(self):
self.env_status = env.save()
@@ -131,7 +132,7 @@ def _calc_l2_norm(grads):
colossal_C.multi_tensor_l2norm,
dummy_overflow_buf,
[grads],
-False # no per-parameter norm
+False  # no per-parameter norm
)
return norm
@@ -328,3 +329,16 @@ def switch_virtual_pipeline_parallel_rank(rank):
yield
finally:
gpc.set_virtual_pipeline_parallel_rank(prev_rank)

def disposable(func: Callable) -> Callable:
    # Decorator: the wrapped function runs only on its first call;
    # every later call is a no-op and returns None.
    executed = False

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        nonlocal executed
        if not executed:
            executed = True
            return func(*args, **kwargs)

    return wrapper
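
A minimal usage sketch (not part of the commit) illustrating the run-once behavior of `disposable`. The import path and the `warn_deprecated` function name are assumptions for illustration only.

```python
from colossalai.utils import disposable  # assumed import path; the decorator is added in this commit


@disposable
def warn_deprecated():
    # The body executes only on the first invocation; later calls are silently skipped.
    print("this API is deprecated")


warn_deprecated()  # prints the warning
warn_deprecated()  # no output: the wrapper short-circuits after the first call
```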