mirror of
https://github.com/hpcaitech/ColossalAI.git
Migrated project
colossalai/utils/__init__.py (new file, 22 lines)
@@ -0,0 +1,22 @@
from .activation_checkpoint import checkpoint
from .common import print_rank_0, sync_model_param_in_dp, is_dp_rank_0, is_tp_rank_0, is_no_pp_or_last_stage
from .cuda import get_current_device, synchronize, empty_cache, set_to_cuda
from .memory import report_memory_usage
from .timer import MultiTimer, Timer

_GLOBAL_MULTI_TIMER = MultiTimer(on=False)


def get_global_multitimer():
    return _GLOBAL_MULTI_TIMER


def set_global_multitimer_status(mode: bool):
    _GLOBAL_MULTI_TIMER.set_status(mode)


__all__ = ['checkpoint', 'print_rank_0', 'sync_model_param_in_dp', 'get_current_device',
           'synchronize', 'empty_cache', 'set_to_cuda', 'report_memory_usage', 'Timer', 'MultiTimer',
           'get_global_multitimer', 'set_global_multitimer_status',
           'is_dp_rank_0', 'is_tp_rank_0', 'is_no_pp_or_last_stage'
           ]
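For reference, a minimal usage sketch of the package-level exports above. The shared MultiTimer is created with `on=False`, so it must be switched on before it records anything; the timer key 'step' is an arbitrary illustrative name, not part of this commit.

# Illustrative sketch only; the functions come from colossalai.utils as exported above.
from colossalai.utils import get_global_multitimer, set_global_multitimer_status

set_global_multitimer_status(True)        # the shared MultiTimer starts disabled (on=False)
timer = get_global_multitimer()
timer.start('step')                       # 'step' is an arbitrary example key
# ... run one training step ...
timer.stop('step', keep_in_history=True)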
colossalai/utils/activation_checkpoint.py (new file, 117 lines)
@@ -0,0 +1,117 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import torch
from torch.utils.checkpoint import check_backward_validity, detach_variable

from colossalai.context.random import get_states, get_current_mode, set_seed_states, set_mode, sync_states


class CheckpointFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, run_function, *args):
        check_backward_validity(args)
        ctx.run_function = run_function

        # preserve rng states
        ctx.fwd_cpu_rng_state = torch.get_rng_state()
        sync_states()
        ctx.fwd_seed_states = get_states(copy=True)
        ctx.fwd_current_mode = get_current_mode()

        if hasattr(torch, 'is_autocast_enabled'):
            ctx.had_autocast_in_fwd = torch.is_autocast_enabled()
        else:
            ctx.had_autocast_in_fwd = False

        # Save non-tensor inputs in ctx, keep a placeholder None for tensors
        # to be filled out during the backward.
        ctx.inputs = []
        ctx.tensor_indices = []
        tensor_inputs = []
        for i, arg in enumerate(args):
            if torch.is_tensor(arg):
                tensor_inputs.append(arg)
                ctx.tensor_indices.append(i)
                ctx.inputs.append(None)
            else:
                ctx.inputs.append(arg)

        ctx.save_for_backward(*tensor_inputs)

        with torch.no_grad():
            outputs = run_function(*args)
        return outputs

    @staticmethod
    def backward(ctx, *args):
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad() or when an `inputs` parameter"
                " is passed to .backward(). Please use .backward() and do not pass its `inputs`"
                " argument.")
        # Copy the list to avoid modifying the original list.
        inputs = list(ctx.inputs)
        tensor_indices = ctx.tensor_indices
        tensors = ctx.saved_tensors

        # store the current states
        bwd_cpu_rng_state = torch.get_rng_state()
        sync_states()
        bwd_seed_states = get_states(copy=True)
        bwd_current_mode = get_current_mode()

        # restore the states used in the forward pass
        torch.set_rng_state(ctx.fwd_cpu_rng_state)
        for parallel_mode, state in ctx.fwd_seed_states.items():
            set_seed_states(parallel_mode, state)
        set_mode(ctx.fwd_current_mode)

        # Fill in inputs with the appropriate saved tensors.
        for i, idx in enumerate(tensor_indices):
            inputs[idx] = tensors[i]

        detached_inputs = detach_variable(tuple(inputs))
        if ctx.had_autocast_in_fwd:
            with torch.enable_grad(), torch.cuda.amp.autocast():
                outputs = ctx.run_function(*detached_inputs)
        else:
            with torch.enable_grad():
                outputs = ctx.run_function(*detached_inputs)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)

        # recover the rng states
        torch.set_rng_state(bwd_cpu_rng_state)
        for parallel_mode, state in bwd_seed_states.items():
            set_seed_states(parallel_mode, state)
        set_mode(bwd_current_mode)

        # run backward() with only the tensors that require grad
        outputs_with_grad = []
        args_with_grad = []
        for i in range(len(outputs)):
            if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
                outputs_with_grad.append(outputs[i])
                args_with_grad.append(args[i])
        if len(outputs_with_grad) == 0:
            raise RuntimeError(
                "none of the outputs has requires_grad=True,"
                " this checkpoint() is not necessary")
        torch.autograd.backward(outputs_with_grad, args_with_grad)
        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None
                      for inp in detached_inputs)

        return (None,) + grads


def checkpoint(function, *args):
    '''Checkpoint the computation while preserving the RNG states, modified from PyTorch's torch.utils.checkpoint.

    :param function: the forward-pass function. It should know how to handle the input tuples.
    :param args: tuple containing the inputs to the function
    :return: output of running ``function`` on ``*args``
    '''
    return CheckpointFunction.apply(function, *args)
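A minimal sketch of how this checkpoint() could be used. It assumes ColossalAI's parallel context and seed states have already been initialized (for example through the library's launch/initialize entry point); the `block` function and tensor shapes are purely illustrative.

# Hypothetical example: trading compute for memory with the checkpoint() above.
import torch
from colossalai.utils import checkpoint

def block(x, weight):
    # activations inside this function are recomputed during the backward pass
    return torch.nn.functional.relu(x @ weight)

x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(8, 8, requires_grad=True)
y = checkpoint(block, x, w)   # forward runs under no_grad; inputs are saved instead of activations
y.sum().backward()            # forward is replayed here with the saved RNG states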
colossalai/utils/common.py (new file, 42 lines)
@@ -0,0 +1,42 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import torch.distributed as dist

from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc


def print_rank_0(msg: str, logger=None):
    '''Print a message and optionally log it. This is executed only on the global rank-0 process.

    :param msg: the message to output
    :param logger: Python logger object, defaults to None
    '''
    if gpc.get_global_rank() == 0:
        if logger is None:
            print(msg, flush=True)
        else:
            logger.info(msg)


def sync_model_param_in_dp(model):
    '''Make sure the model parameters are consistent across the data-parallel group.

    :param model: a PyTorch nn.Module whose parameters are synchronized
    '''
    if gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1:
        for param in model.parameters():
            ranks = gpc.get_ranks_in_group(ParallelMode.DATA)
            dist.broadcast(param, src=ranks[0], group=gpc.get_group(ParallelMode.DATA))


def is_dp_rank_0():
    return not gpc.is_initialized(ParallelMode.DATA) or gpc.is_first_rank(ParallelMode.DATA)


def is_tp_rank_0():
    return not gpc.is_initialized(ParallelMode.TENSOR) or gpc.is_first_rank(ParallelMode.TENSOR)


def is_no_pp_or_last_stage():
    return not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE)
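An illustrative sketch of these helpers in a training script. It assumes the ColossalAI distributed context (gpc) has already been launched; the Linear model is a stand-in placeholder.

# Hypothetical usage sketch; requires an initialized distributed context.
import torch.nn as nn
from colossalai.utils import print_rank_0, sync_model_param_in_dp, is_dp_rank_0

model = nn.Linear(4, 4)                   # stand-in model for illustration
sync_model_param_in_dp(model)             # broadcast parameters from the first data-parallel rank
print_rank_0('parameters synchronized')   # printed once, on global rank 0
if is_dp_rank_0():
    pass                                  # e.g. save a checkpoint only once per data-parallel group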
colossalai/utils/cuda.py (new file, 48 lines)
@@ -0,0 +1,48 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import torch


def set_to_cuda(models):
    '''Send a model (or a list of models) to the current device.

    :param models: an nn.Module or a list of nn.Module
    '''
    if isinstance(models, list) and len(models) > 1:
        ret = []
        for model in models:
            ret.append(model.to(get_current_device()))
        return ret
    elif isinstance(models, list):
        return models[0].to(get_current_device())
    else:
        return models.to(get_current_device())


def get_current_device():
    '''Return the index of the currently selected CUDA device, or 'cpu' if CUDA is not available.
    '''
    if torch.cuda.is_available():
        return torch.cuda.current_device()
    else:
        return 'cpu'


def synchronize():
    '''Similar to torch.cuda.synchronize().
    Waits for all kernels in all streams on the current CUDA device to complete.
    '''
    if torch.cuda.is_available():
        torch.cuda.synchronize()


def empty_cache():
    '''Similar to torch.cuda.empty_cache().
    Releases all unoccupied cached memory currently held by the caching allocator.
    '''
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
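A short sketch of the device helpers; since each one checks torch.cuda.is_available(), the snippet also runs on CPU-only machines. The Linear module is an illustrative placeholder.

# Minimal sketch: these helpers degrade gracefully when CUDA is unavailable.
import torch.nn as nn
from colossalai.utils import set_to_cuda, get_current_device, synchronize, empty_cache

model = set_to_cuda(nn.Linear(8, 8))   # moves to the current GPU, or stays on CPU
device = get_current_device()          # integer GPU index, or the string 'cpu'
synchronize()                          # no-op without CUDA
empty_cache()                          # no-op without CUDA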
colossalai/utils/memory.py (new file, 49 lines)
@@ -0,0 +1,49 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import gc

import psutil
import torch

from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_global_dist_logger


def bytes_to_GB(val, decimal=2):
    '''Convert bytes to gigabytes, using binary notation by default.

    :param val: number of bytes to convert
    :return: the value in GB
    '''
    return round(val / (1024 * 1024 * 1024), decimal)


def report_memory_usage(message):
    '''Calculate and log RAM usage (in GB).

    :param message: a prefix message to add to the log
    :type message: str
    :raises EnvironmentError: raised if no distributed environment has been initialized
    '''
    if not gpc.is_initialized(ParallelMode.GLOBAL):
        raise EnvironmentError("No distributed environment is initialized")

    # Python does not garbage-collect in real time, so collect explicitly to get accurate RAM reports
    gc.collect()
    vm_stats = psutil.virtual_memory()
    vm_used = bytes_to_GB(vm_stats.total - vm_stats.available)

    gpu_allocated = bytes_to_GB(torch.cuda.memory_allocated())
    gpu_max_allocated = bytes_to_GB(torch.cuda.max_memory_allocated())
    gpu_cached = bytes_to_GB(torch.cuda.memory_cached())
    gpu_max_cached = bytes_to_GB(torch.cuda.max_memory_cached())

    get_global_dist_logger().info(
        f"{message} - GPU: allocated {gpu_allocated} GB, max allocated {gpu_max_allocated} GB, "
        f"cached {gpu_cached} GB, max cached {gpu_max_cached} GB, "
        f"CPU virtual memory: used = {vm_used} GB, percent = {vm_stats.percent}%")

    # reset the peak-memory counters so the next call reports its own peaks
    if hasattr(torch.cuda, "reset_peak_memory_stats"):  # pytorch 1.4+
        torch.cuda.reset_peak_memory_stats()
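A one-line usage sketch; note that report_memory_usage raises EnvironmentError if no distributed environment has been initialized, so it is assumed here that the global context has been launched. The message string is arbitrary.

# Hypothetical usage after the distributed context has been initialized.
from colossalai.utils import report_memory_usage

report_memory_usage('after forward pass')   # logs GPU allocated/cached and CPU virtual memory, in GB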
colossalai/utils/timer.py (new file, 143 lines)
@@ -0,0 +1,143 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import time

from .cuda import synchronize


class Timer:
    '''A timer object which helps to log execution times and provides tools to assess them.
    '''

    def __init__(self):
        self._started = False
        self._start_time = time.time()
        self._elapsed = 0
        self._history = []

    @property
    def has_history(self):
        return len(self._history) != 0

    def start(self):
        '''First synchronize CUDA, reset the clock, then start the timer.
        '''
        self._elapsed = 0
        synchronize()
        self._start_time = time.time()
        self._started = True

    def stop(self, keep_in_history: bool = False):
        '''Stop the timer and record the start-stop time interval.

        :param keep_in_history: whether to record the start-stop interval in the history, defaults to False
        :type keep_in_history: bool, optional
        :return: start-stop interval
        :rtype: float
        '''
        synchronize()
        end_time = time.time()
        elapsed = end_time - self._start_time
        if keep_in_history:
            self._history.append(elapsed)
        self._elapsed = elapsed
        self._started = False
        return elapsed

    def get_history_mean(self):
        '''Mean of all recorded start-stop time intervals.

        :return: mean of the time intervals
        :rtype: float
        '''
        return sum(self._history) / len(self._history)

    def get_history_sum(self):
        '''Sum of all recorded start-stop time intervals.

        :return: sum of the time intervals
        :rtype: float
        '''
        return sum(self._history)

    def get_elapsed_time(self):
        '''Return the last start-stop time interval. *Use it only when the timer is not running.*

        :return: the last time interval
        :rtype: float
        '''
        assert not self._started, 'Timer is still in progress'
        return self._elapsed

    def reset(self):
        '''Clear the timer and its history.
        '''
        self._history = []
        self._started = False
        self._elapsed = 0


class MultiTimer:
    '''An object that contains multiple named timers.
    '''

    def __init__(self, on: bool = True):
        self._on = on
        self._timers = dict()

    def start(self, name: str):
        '''Start the timer with the given name.

        :param name: timer's key
        :type name: str
        '''
        if self._on:
            if name not in self._timers:
                self._timers[name] = Timer()
            return self._timers[name].start()

    def stop(self, name: str, keep_in_history: bool):
        '''Stop the timer with the given name.

        :param name: timer's key
        :param keep_in_history: whether to record the start-stop interval in the history
        :type keep_in_history: bool
        '''
        if self._on:
            return self._timers[name].stop(keep_in_history)
        else:
            return None

    def get_timer(self, name):
        '''Get a timer by its name.

        :param name: timer's key
        :return: the timer with the given name
        :rtype: Timer
        '''
        return self._timers[name]

    def reset(self, name=None):
        '''Reset timers.

        :param name: if a name is given, only the named timer is reset; otherwise all timers are reset, defaults to None
        '''
        if self._on:
            if name is not None:
                self._timers[name].reset()
            else:
                for timer in self._timers.values():
                    timer.reset()

    def is_on(self):
        return self._on

    def set_status(self, mode: bool):
        self._on = mode

    def __iter__(self):
        for name, timer in self._timers.items():
            yield name, timer
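A minimal sketch of the Timer / MultiTimer API added above; the timer names and the placement of the timed work are illustrative only.

# Illustrative sketch of the Timer / MultiTimer API.
from colossalai.utils import Timer, MultiTimer

t = Timer()
t.start()
# ... some work ...
print(t.stop(keep_in_history=True))   # elapsed seconds, also appended to the history
print(t.get_history_mean())           # mean over the recorded intervals

timers = MultiTimer(on=True)
timers.start('forward')
# ... forward pass ...
timers.stop('forward', keep_in_history=True)
for name, timer in timers:            # MultiTimer iterates over (name, Timer) pairs
    print(name, timer.get_history_sum())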