Migrated project

commit 404ecbdcc6
parent 2ebaefc542
Author: zbian
Date:   2021-10-28 18:21:23 +02:00
409 changed files with 35853 additions and 0 deletions

colossalai/utils/__init__.py
@@ -0,0 +1,22 @@
from .activation_checkpoint import checkpoint
from .common import print_rank_0, sync_model_param_in_dp, is_dp_rank_0, is_tp_rank_0, is_no_pp_or_last_stage
from .cuda import get_current_device, synchronize, empty_cache, set_to_cuda
from .memory import report_memory_usage
from .timer import MultiTimer, Timer

_GLOBAL_MULTI_TIMER = MultiTimer(on=False)


def get_global_multitimer():
    return _GLOBAL_MULTI_TIMER


def set_global_multitimer_status(mode: bool):
    _GLOBAL_MULTI_TIMER.set_status(mode)


__all__ = ['checkpoint', 'print_rank_0', 'sync_model_param_in_dp', 'get_current_device',
'synchronize', 'empty_cache', 'set_to_cuda', 'report_memory_usage', 'Timer', 'MultiTimer',
'get_global_multitimer', 'set_global_multitimer_status',
'is_dp_rank_0', 'is_tp_rank_0', 'is_no_pp_or_last_stage'
]
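
A minimal usage sketch of the module-level multitimer exposed here; it is created disabled, so it has to be switched on first (the timer key 'forward' is illustrative):

from colossalai.utils import get_global_multitimer, set_global_multitimer_status

set_global_multitimer_status(True)           # the shared MultiTimer is created with on=False
timer = get_global_multitimer()
timer.start('forward')
# ... timed work ...
timer.stop('forward', keep_in_history=True)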

colossalai/utils/activation_checkpoint.py
@@ -0,0 +1,117 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
from torch.utils.checkpoint import check_backward_validity, detach_variable
from colossalai.context.random import get_states, get_current_mode, set_seed_states, set_mode, sync_states
class CheckpointFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, run_function, *args):
check_backward_validity(args)
ctx.run_function = run_function
# preserve rng states
ctx.fwd_cpu_rng_state = torch.get_rng_state()
sync_states()
ctx.fwd_seed_states = get_states(copy=True)
ctx.fwd_current_mode = get_current_mode()
if hasattr(torch, 'is_autocast_enabled'):
ctx.had_autocast_in_fwd = torch.is_autocast_enabled()
else:
ctx.had_autocast_in_fwd = False
# Save non-tensor inputs in ctx, keep a placeholder None for tensors
# to be filled out during the backward.
ctx.inputs = []
ctx.tensor_indices = []
tensor_inputs = []
for i, arg in enumerate(args):
if torch.is_tensor(arg):
tensor_inputs.append(arg)
ctx.tensor_indices.append(i)
ctx.inputs.append(None)
else:
ctx.inputs.append(arg)
ctx.save_for_backward(*tensor_inputs)
with torch.no_grad():
outputs = run_function(*args)
return outputs
@staticmethod
def backward(ctx, *args):
if not torch.autograd._is_checkpoint_valid():
raise RuntimeError(
"Checkpointing is not compatible with .grad() or when an `inputs` parameter"
" is passed to .backward(). Please use .backward() and do not pass its `inputs`"
" argument.")
# Copy the list to avoid modifying original list.
inputs = list(ctx.inputs)
tensor_indices = ctx.tensor_indices
tensors = ctx.saved_tensors
# store the current states
bwd_cpu_rng_state = torch.get_rng_state()
sync_states()
bwd_seed_states = get_states(copy=True)
bwd_current_mode = get_current_mode()
        # restore the states to what they were in the forward pass
torch.set_rng_state(ctx.fwd_cpu_rng_state)
for parallel_mode, state in ctx.fwd_seed_states.items():
set_seed_states(parallel_mode, state)
set_mode(ctx.fwd_current_mode)
# Fill in inputs with appropriate saved tensors.
for i, idx in enumerate(tensor_indices):
inputs[idx] = tensors[i]
detached_inputs = detach_variable(tuple(inputs))
if ctx.had_autocast_in_fwd:
with torch.enable_grad(), torch.cuda.amp.autocast():
outputs = ctx.run_function(*detached_inputs)
else:
with torch.enable_grad():
outputs = ctx.run_function(*detached_inputs)
if isinstance(outputs, torch.Tensor):
outputs = (outputs,)
# recover the rng states
torch.set_rng_state(bwd_cpu_rng_state)
for parallel_mode, state in bwd_seed_states.items():
set_seed_states(parallel_mode, state)
set_mode(bwd_current_mode)
        # run backward() only on the outputs that require grad
outputs_with_grad = []
args_with_grad = []
for i in range(len(outputs)):
if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
outputs_with_grad.append(outputs[i])
args_with_grad.append(args[i])
if len(outputs_with_grad) == 0:
raise RuntimeError(
"none of output has requires_grad=True,"
" this checkpoint() is not necessary")
torch.autograd.backward(outputs_with_grad, args_with_grad)
grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None
for inp in detached_inputs)
return (None,) + grads
def checkpoint(function, *args):
    '''Checkpoint the computation while preserving the RNG states, modified from PyTorch ``torch.utils.checkpoint``.

    :param function: the forward-pass function to be re-run during backward. It should know how to handle the inputs passed as ``*args``.
    :param args: tuple containing the inputs to the function
    :return: output of running ``function`` on ``*args``
    '''
    return CheckpointFunction.apply(function, *args)
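
A short usage sketch of this checkpoint wrapper; the layer and input are illustrative, and a CUDA device plus an initialized ColossalAI parallel/RNG context are assumed:

import torch
import torch.nn as nn
from colossalai.utils import checkpoint

layer = nn.Linear(16, 16).cuda()
x = torch.randn(4, 16, device='cuda', requires_grad=True)
# activations are discarded in the forward pass and recomputed during backward,
# with the CPU and ColossalAI seed states restored to their forward-time values
y = checkpoint(layer, x)
y.sum().backward()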

colossalai/utils/common.py
@@ -0,0 +1,42 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch.distributed as dist
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
def print_rank_0(msg: str, logger=None):
    '''Print a message, or log it if a logger is given. Only executed on the global rank-0 process.

    :param msg: the message to output
    :param logger: a Python logger object, defaults to None
    '''
    if gpc.get_global_rank() == 0:
        if logger is None:
            print(msg, flush=True)
        else:
            logger.info(msg)
# print(msg, flush=True)
def sync_model_param_in_dp(model):
    '''Make sure model parameters are consistent across the data-parallel group
    by broadcasting them from the first data-parallel rank.

    :param model: a PyTorch nn.Module whose parameters are synchronized
    '''
    if gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1:
        ranks = gpc.get_ranks_in_group(ParallelMode.DATA)
        for param in model.parameters():
            dist.broadcast(param, src=ranks[0], group=gpc.get_group(ParallelMode.DATA))
def is_dp_rank_0():
return not gpc.is_initialized(ParallelMode.DATA) or gpc.is_first_rank(ParallelMode.DATA)
def is_tp_rank_0():
return not gpc.is_initialized(ParallelMode.TENSOR) or gpc.is_first_rank(ParallelMode.TENSOR)
def is_no_pp_or_last_stage():
return not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE)
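
A sketch of how these helpers are typically combined once the global context has been initialized (the model is illustrative):

import torch.nn as nn
from colossalai.utils import print_rank_0, sync_model_param_in_dp, is_dp_rank_0

model = nn.Linear(8, 8)
sync_model_param_in_dp(model)        # broadcast the first DP rank's weights to the group
print_rank_0('model parameters synchronized')
if is_dp_rank_0():
    pass                             # e.g. write a checkpoint only once per DP group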

colossalai/utils/cuda.py
@@ -0,0 +1,48 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
def set_to_cuda(models):
    '''Send the model(s) to the current device.

    :param models: an nn.Module or a list of nn.Module objects
    '''
if isinstance(models, list) and len(models) > 1:
ret = []
for model in models:
ret.append(model.to(get_current_device()))
return ret
elif isinstance(models, list):
return models[0].to(get_current_device())
else:
return models.to(get_current_device())
def get_current_device():
    '''
    Return the index of the currently selected CUDA device, or 'cpu' when CUDA is unavailable.
    '''
if torch.cuda.is_available():
return torch.cuda.current_device()
else:
return 'cpu'
def synchronize():
'''
Similar to cuda.synchronize().
Waits for all kernels in all streams on a CUDA device to complete.
'''
if torch.cuda.is_available():
torch.cuda.synchronize()
def empty_cache():
'''
Similar to cuda.empty_cache()
Releases all unoccupied cached memory currently held by the caching allocator.
'''
if torch.cuda.is_available():
torch.cuda.empty_cache()
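
A minimal sketch of the device helpers above; every call degrades gracefully to the CPU when CUDA is unavailable:

import torch.nn as nn
from colossalai.utils import set_to_cuda, get_current_device, synchronize, empty_cache

model = set_to_cuda(nn.Linear(8, 8))   # moved to the current GPU, or left on CPU
device = get_current_device()          # a CUDA device index, or the string 'cpu'
synchronize()                          # waits for outstanding CUDA kernels, no-op on CPU
empty_cache()                          # releases unoccupied cached GPU memory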

colossalai/utils/memory.py
@@ -0,0 +1,49 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import gc
import psutil
import torch
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_global_dist_logger
def bytes_to_GB(val, decimal=2):
    '''A byte-to-gigabyte converter, using binary notation (1 GB = 1024**3 bytes) by default.

    :param val: the number of bytes to convert
    :param decimal: the number of decimal places to round to, defaults to 2
    :return: the value in gigabytes
    '''
    return round(val / (1024 * 1024 * 1024), decimal)
def report_memory_usage(message):
    '''Calculate and log the current GPU and CPU memory usage (in GB).

    :param message: a prefix message to add to the log
    :type message: str
    :raises EnvironmentError: raised if no distributed environment has been initialized
    '''
if not gpc.is_initialized(ParallelMode.GLOBAL):
raise EnvironmentError("No distributed environment is initialized")
# python doesn't do real-time garbage collection so do it explicitly to get the correct RAM reports
gc.collect()
vm_stats = psutil.virtual_memory()
vm_used = bytes_to_GB(vm_stats.total - vm_stats.available)
gpu_allocated = bytes_to_GB(torch.cuda.memory_allocated())
gpu_max_allocated = bytes_to_GB(torch.cuda.max_memory_allocated())
gpu_cached = bytes_to_GB(torch.cuda.memory_cached())
gpu_max_cached = bytes_to_GB(torch.cuda.max_memory_cached())
get_global_dist_logger().info(
f"{message} - GPU: allocated {gpu_allocated}GB, max allocated {gpu_max_allocated}GB, cached: {gpu_cached} GB, "
f"max cached: {gpu_max_cached}GB, CPU Virtual Memory: used = {vm_used}GB, percent = {vm_stats.percent}%")
    # reset the peak-memory counter so that the next call reports correct peak values
if hasattr(torch.cuda, "reset_peak_memory_stats"): # pytorch 1.4+
torch.cuda.reset_peak_memory_stats()
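
A sketch of calling the reporter; it assumes the distributed environment has already been initialized, otherwise the EnvironmentError above is raised:

from colossalai.utils import report_memory_usage

report_memory_usage('after forward pass')   # logs GPU allocated/cached and CPU virtual memory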

colossalai/utils/timer.py
@@ -0,0 +1,143 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import time
from .cuda import synchronize
class Timer:
    '''
    A timer that records execution times and provides helpers to aggregate the recorded intervals.
    '''
def __init__(self):
self._started = False
self._start_time = time.time()
self._elapsed = 0
self._history = []
@property
def has_history(self):
return len(self._history) != 0
    def start(self):
        '''First synchronize CUDA, then reset the clock and start the timer.
        '''
self._elapsed = 0
synchronize()
self._start_time = time.time()
self._started = True
    def stop(self, keep_in_history: bool = False):
        '''Stop the timer and record the start-stop time interval.

        :param keep_in_history: whether to record this interval in the history, defaults to False
        :type keep_in_history: bool, optional
        :return: the start-stop interval in seconds
        :rtype: float
        '''
synchronize()
end_time = time.time()
elapsed = end_time - self._start_time
if keep_in_history:
self._history.append(elapsed)
self._elapsed = elapsed
self._started = False
return elapsed
    def get_history_mean(self):
        '''Return the mean of all recorded start-stop time intervals.

        :return: the mean of the recorded intervals
        :rtype: float
        '''
return sum(self._history) / len(self._history)
    def get_history_sum(self):
        '''Return the sum of all recorded start-stop time intervals.

        :return: the sum of the recorded intervals
        :rtype: float
        '''
return sum(self._history)
    def get_elapsed_time(self):
        '''Return the last start-stop time interval. *Only use it when the timer is not running.*

        :return: the last recorded interval
        :rtype: float
        '''
assert not self._started, 'Timer is still in progress'
return self._elapsed
    def reset(self):
        '''Clear the timer state and its history.
        '''
self._history = []
self._started = False
self._elapsed = 0
class MultiTimer:
    '''A container that manages multiple named timers.
    '''
def __init__(self, on: bool = True):
self._on = on
self._timers = dict()
    def start(self, name: str):
        '''Start the timer with the given name.

        :param name: the timer's key
        :type name: str
        '''
if self._on:
if name not in self._timers:
self._timers[name] = Timer()
return self._timers[name].start()
    def stop(self, name: str, keep_in_history: bool):
        '''Stop the timer with the given name.

        :param name: the timer's key
        :type name: str
        :param keep_in_history: whether to record this interval in the timer's history
        :type keep_in_history: bool
        '''
if self._on:
return self._timers[name].stop(keep_in_history)
else:
return None
    def get_timer(self, name):
        '''Get a timer by its name.

        :param name: the timer's key
        :return: the timer associated with the given name
        :rtype: Timer
        '''
        return self._timers[name]
    def reset(self, name=None):
        '''Reset timers.

        :param name: if a name is given, only the named timer is reset; otherwise all timers are reset, defaults to None
        '''
        if self._on:
            if name is not None:
                self._timers[name].reset()
            else:
                # iterate over the Timer objects, not the dict keys
                for timer in self._timers.values():
                    timer.reset()
def is_on(self):
return self._on
def set_status(self, mode: bool):
self._on = mode
def __iter__(self):
for name, timer in self._timers.items():
yield name, timer
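
A short sketch of both timer classes; the timer key 'step' is illustrative:

from colossalai.utils import Timer, MultiTimer

timer = Timer()
timer.start()
# ... timed region ...
elapsed = timer.stop(keep_in_history=True)

timers = MultiTimer(on=True)
timers.start('step')
# ... one training step ...
timers.stop('step', keep_in_history=True)
for name, t in timers:                  # MultiTimer yields (name, Timer) pairs
    print(name, t.get_history_mean())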