Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-10 05:20:33 +00:00
[zero] refactor memstats collector (#706)
* refactor memstats collector
* fix disposable
* polish code
@@ -4,8 +4,8 @@ import os
 import random
 import socket
 from pathlib import Path
-from typing import List, Union
-
+from typing import Callable, List, Union
+import functools
 import torch
 from torch._six import inf
 from torch.nn.parameter import Parameter
@@ -112,6 +112,7 @@ def conditional_context(context_manager, enable=True):
 
+
 class model_branch_context(object):
 
     def __enter__(self):
         self.env_status = env.save()
 
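For context, `model_branch_context` follows the save-on-enter / restore-on-exit pattern: `__enter__` snapshots the global `env` state so code inside the `with` block can mutate it freely. A minimal generic sketch of that pattern, with hypothetical names rather than ColossalAI API:

class save_restore_context:
    # Snapshot mutable global state on entry and restore it on exit.

    def __init__(self, save_fn, load_fn):
        self._save = save_fn    # returns a snapshot object
        self._load = load_fn    # accepts a snapshot object

    def __enter__(self):
        self._snapshot = self._save()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._load(self._snapshot)
        return False    # never suppress exceptions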
@@ -131,7 +132,7 @@ def _calc_l2_norm(grads):
            colossal_C.multi_tensor_l2norm,
            dummy_overflow_buf,
            [grads],
-            False # no per-parameter norm
+            False  # no per-parameter norm
        )
    return norm
 
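For reference, `colossal_C.multi_tensor_l2norm` is a fused CUDA kernel that computes the global L2 norm over a list of gradient tensors in a single pass; the trailing `False` disables per-parameter norms. A plain-PyTorch sketch of the same reduction, assuming `grads` is a list of tensors:

import torch

def l2_norm_reference(grads):
    # Global L2 norm across all gradients: the square root of the sum
    # of squares of every element, accumulated in fp32 for safety.
    total = sum(torch.sum(g.float() ** 2) for g in grads)
    return torch.sqrt(total)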
@@ -328,3 +329,16 @@ def switch_virtual_pipeline_parallel_rank(rank):
         yield
     finally:
         gpc.set_virtual_pipeline_parallel_rank(prev_rank)
+
+
+def disposable(func: Callable) -> Callable:
+    executed = False
+
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        nonlocal executed
+        if not executed:
+            executed = True
+            return func(*args, **kwargs)
+
+    return wrapper