mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 03:52:01 +00:00
[legacy] clean up legacy code (#4743)
* [legacy] remove outdated codes of pipeline (#4692) * [legacy] remove cli of benchmark and update optim (#4690) * [legacy] remove cli of benchmark and update optim * [doc] fix cli doc test * [legacy] fix engine clip grad norm * [legacy] remove outdated colo tensor (#4694) * [legacy] remove outdated colo tensor * [test] fix test import * [legacy] move outdated zero to legacy (#4696) * [legacy] clean up utils (#4700) * [legacy] clean up utils * [example] update examples * [legacy] clean up amp * [legacy] fix amp module * [legacy] clean up gpc (#4742) * [legacy] clean up context * [legacy] clean core, constants and global vars * [legacy] refactor initialize * [example] fix examples ci * [example] fix examples ci * [legacy] fix tests * [example] fix gpt example * [example] fix examples ci * [devops] fix ci installation * [example] fix examples ci
This commit is contained in:
53
colossalai/legacy/utils/__init__.py
Normal file
53
colossalai/legacy/utils/__init__.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from .checkpointing import load_checkpoint, save_checkpoint
|
||||
from .common import (
|
||||
clip_grad_norm_fp32,
|
||||
copy_tensor_parallel_attributes,
|
||||
count_zeros_fp32,
|
||||
is_dp_rank_0,
|
||||
is_model_parallel_parameter,
|
||||
is_no_pp_or_last_stage,
|
||||
is_tp_rank_0,
|
||||
is_using_ddp,
|
||||
is_using_pp,
|
||||
is_using_sequence,
|
||||
param_is_not_tensor_parallel_duplicate,
|
||||
print_rank_0,
|
||||
switch_virtual_pipeline_parallel_rank,
|
||||
sync_model_param,
|
||||
)
|
||||
from .data_sampler import DataParallelSampler, get_dataloader
|
||||
from .memory import (
|
||||
colo_device_memory_capacity,
|
||||
colo_device_memory_used,
|
||||
colo_get_cpu_memory_capacity,
|
||||
colo_set_cpu_memory_capacity,
|
||||
colo_set_process_memory_fraction,
|
||||
report_memory_usage,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'DataParallelSampler',
|
||||
'get_dataloader',
|
||||
'save_checkpoint',
|
||||
'load_checkpoint',
|
||||
'colo_device_memory_capacity',
|
||||
'colo_device_memory_used',
|
||||
'colo_get_cpu_memory_capacity',
|
||||
'colo_set_cpu_memory_capacity',
|
||||
'colo_set_process_memory_fraction',
|
||||
'report_memory_usage',
|
||||
'clip_grad_norm_fp32',
|
||||
'copy_tensor_parallel_attributes',
|
||||
'count_zeros_fp32',
|
||||
'is_dp_rank_0',
|
||||
'is_model_parallel_parameter',
|
||||
'is_no_pp_or_last_stage',
|
||||
'is_tp_rank_0',
|
||||
'is_using_ddp',
|
||||
'is_using_pp',
|
||||
'is_using_sequence',
|
||||
'param_is_not_tensor_parallel_duplicate',
|
||||
'print_rank_0',
|
||||
'switch_virtual_pipeline_parallel_rank',
|
||||
'sync_model_param',
|
||||
]
|
259
colossalai/legacy/utils/activation_checkpoint.py
Normal file
259
colossalai/legacy/utils/activation_checkpoint.py
Normal file
@@ -0,0 +1,259 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import weakref
|
||||
|
||||
import torch
|
||||
from torch.utils.checkpoint import check_backward_validity, detach_variable
|
||||
|
||||
from colossalai.legacy.context.random import get_current_mode, get_states, set_mode, set_seed_states, sync_states
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
def copy_to_device(obj, device):
|
||||
if torch.is_tensor(obj):
|
||||
# Notice:
|
||||
# When in no_grad context, requires_gard is False after movement
|
||||
ret = obj.to(device).detach()
|
||||
ret.requires_grad = obj.requires_grad
|
||||
return ret
|
||||
elif isinstance(obj, list):
|
||||
return [copy_to_device(i, device) for i in obj]
|
||||
elif isinstance(obj, tuple):
|
||||
return tuple([copy_to_device(v, device) for v in obj])
|
||||
elif isinstance(obj, dict):
|
||||
return {k: copy_to_device(v, device) for k, v in obj.items()}
|
||||
else:
|
||||
return obj
|
||||
|
||||
|
||||
class CheckpointFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, run_function, activation_offload=False, *args):
|
||||
check_backward_validity(args)
|
||||
ctx.run_function = run_function
|
||||
ctx.activation_offload = activation_offload
|
||||
ctx.device = get_current_device()
|
||||
|
||||
# preserve rng states
|
||||
ctx.fwd_cpu_rng_state = torch.get_rng_state()
|
||||
sync_states()
|
||||
ctx.fwd_seed_states = get_states(copy=True)
|
||||
ctx.fwd_current_mode = get_current_mode()
|
||||
|
||||
if hasattr(torch, 'is_autocast_enabled'):
|
||||
ctx.had_autocast_in_fwd = torch.is_autocast_enabled()
|
||||
else:
|
||||
ctx.had_autocast_in_fwd = False
|
||||
|
||||
if activation_offload:
|
||||
inputs_cuda = copy_to_device(args, ctx.device)
|
||||
else:
|
||||
inputs_cuda = args
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = run_function(*inputs_cuda)
|
||||
# Save non-tensor inputs in ctx, keep a placeholder None for tensors
|
||||
# to be filled out during the backward.
|
||||
ctx.inputs = []
|
||||
ctx.tensor_indices = []
|
||||
tensor_inputs = []
|
||||
for i, arg in enumerate(args):
|
||||
if torch.is_tensor(arg):
|
||||
if activation_offload:
|
||||
tensor_inputs.append(copy_to_device(arg, 'cpu'))
|
||||
else:
|
||||
tensor_inputs.append(arg)
|
||||
ctx.tensor_indices.append(i)
|
||||
ctx.inputs.append(None)
|
||||
else:
|
||||
ctx.inputs.append(arg)
|
||||
|
||||
if activation_offload:
|
||||
ctx.tensor_inputs = tensor_inputs
|
||||
else:
|
||||
ctx.save_for_backward(*tensor_inputs)
|
||||
return outputs
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *args):
|
||||
if not torch.autograd._is_checkpoint_valid():
|
||||
raise RuntimeError("Checkpointing is not compatible with .grad() or when an `inputs` parameter is "
|
||||
"passed to .backward(). Please use .backward() and do not pass its `inputs` argument.")
|
||||
# Copy the list to avoid modifying original list.
|
||||
inputs = list(ctx.inputs)
|
||||
tensor_indices = ctx.tensor_indices
|
||||
|
||||
if ctx.activation_offload:
|
||||
tensors = ctx.tensor_inputs
|
||||
else:
|
||||
tensors = ctx.saved_tensors
|
||||
|
||||
# store the current states
|
||||
bwd_cpu_rng_state = torch.get_rng_state()
|
||||
sync_states()
|
||||
bwd_seed_states = get_states(copy=True)
|
||||
bwd_current_mode = get_current_mode()
|
||||
|
||||
# set the states to what it used to be
|
||||
torch.set_rng_state(ctx.fwd_cpu_rng_state)
|
||||
for parallel_mode, state in ctx.fwd_seed_states.items():
|
||||
set_seed_states(parallel_mode, state)
|
||||
set_mode(ctx.fwd_current_mode)
|
||||
if ctx.activation_offload:
|
||||
tensors = copy_to_device(tensors, ctx.device)
|
||||
|
||||
# Fill in inputs with appropriate saved tensors.
|
||||
for i, idx in enumerate(tensor_indices):
|
||||
inputs[idx] = tensors[i]
|
||||
detached_inputs = detach_variable(tuple(inputs))
|
||||
if ctx.had_autocast_in_fwd:
|
||||
with torch.enable_grad(), torch.cuda.amp.autocast():
|
||||
outputs = ctx.run_function(*detached_inputs)
|
||||
else:
|
||||
with torch.enable_grad():
|
||||
outputs = ctx.run_function(*detached_inputs)
|
||||
|
||||
if isinstance(outputs, torch.Tensor):
|
||||
outputs = (outputs,)
|
||||
# recover the rng states
|
||||
torch.set_rng_state(bwd_cpu_rng_state)
|
||||
for parallel_mode, state in bwd_seed_states.items():
|
||||
set_seed_states(parallel_mode, state)
|
||||
set_mode(bwd_current_mode)
|
||||
|
||||
# run backward() with only tensor that requires grad
|
||||
outputs_with_grad = []
|
||||
args_with_grad = []
|
||||
for i in range(len(outputs)):
|
||||
if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
|
||||
outputs_with_grad.append(outputs[i])
|
||||
args_with_grad.append(args[i])
|
||||
if len(outputs_with_grad) == 0:
|
||||
raise RuntimeError("none of output has requires_grad=True,"
|
||||
" this checkpoint() is not necessary")
|
||||
torch.autograd.backward(outputs_with_grad, args_with_grad)
|
||||
grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None for inp in detached_inputs)
|
||||
return (None, None) + grads
|
||||
|
||||
|
||||
def checkpoint(function, activation_offload, *args, use_reentrant: bool = True):
|
||||
"""Checkpoint the computation while preserve the rng states, modified from Pytorch torch.utils.checkpoint.
|
||||
|
||||
Args:
|
||||
function: Describe the forward pass function. It should know how to handle the input tuples.
|
||||
activation_offload: The variable to check whether we should offload activation to cpu
|
||||
args (list): Tuple containing the parameters of the function
|
||||
use_reentrant: Bool type to check if we need to use_reentrant, if use_reentrant=False, there
|
||||
might be more flexibility for user to define there checkpoint function
|
||||
|
||||
Returns:
|
||||
Output of running function with provided args.
|
||||
"""
|
||||
if use_reentrant:
|
||||
return CheckpointFunction.apply(function, activation_offload, *args)
|
||||
else:
|
||||
return _checkpoint_without_reentrant(
|
||||
function,
|
||||
activation_offload,
|
||||
*args,
|
||||
)
|
||||
|
||||
|
||||
def _checkpoint_without_reentrant(function, activation_offload=False, *args):
|
||||
# store rng_state
|
||||
fwd_cpu_state = torch.get_rng_state()
|
||||
sync_states()
|
||||
fwd_seed_states = get_states(copy=True)
|
||||
fwd_current_mode = get_current_mode()
|
||||
|
||||
# check if use autocast
|
||||
if hasattr(torch, 'is_autocast_enabled'):
|
||||
has_autocast_in_fwd = torch.is_autocast_enabled()
|
||||
else:
|
||||
has_autocast_in_fwd = False
|
||||
|
||||
# using WeakKeyDictionary to store all the activation the first time we call unpack
|
||||
storage: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
|
||||
weak_holder_list = []
|
||||
|
||||
# class for weakref.ref
|
||||
class Holder():
|
||||
pass
|
||||
|
||||
# return a Holder object for later unpack process
|
||||
def pack(x):
|
||||
res = Holder()
|
||||
weak_holder_list.append(weakref.ref(res))
|
||||
return res
|
||||
|
||||
# unpack hook
|
||||
def unpack(x):
|
||||
unpack_counter = 0
|
||||
|
||||
# re-compute all the activation inside the function when we first call unpack
|
||||
if len(storage) == 0:
|
||||
|
||||
def inner_pack(inner):
|
||||
nonlocal unpack_counter
|
||||
unpack_counter += 1
|
||||
|
||||
# If the holder went out of scope, the SavedVariable is dead and so
|
||||
# the value will never be read from the storage. Skip filling it.
|
||||
if weak_holder_list[unpack_counter - 1]() is None:
|
||||
return
|
||||
|
||||
# Use detach here to ensure we don't keep the temporary autograd
|
||||
# graph created during the second forward
|
||||
storage[weak_holder_list[unpack_counter - 1]()] = inner.detach()
|
||||
return
|
||||
|
||||
def inner_unpack(packed):
|
||||
raise RuntimeError("You are calling backwards on a tensor that is never exposed. Please open an issue.")
|
||||
|
||||
# restore rng state
|
||||
torch.set_rng_state(fwd_cpu_state)
|
||||
for parallel_mode, state in fwd_seed_states.items():
|
||||
set_seed_states(parallel_mode, state)
|
||||
set_mode(fwd_current_mode)
|
||||
|
||||
# reload arg into device if needed
|
||||
if activation_offload:
|
||||
for arg in args:
|
||||
if torch.is_tensor(arg):
|
||||
arg = arg.to(device=device)
|
||||
|
||||
# rerun forward, the inner_pack will store all the activations in storage
|
||||
if has_autocast_in_fwd:
|
||||
with torch.enable_grad(), \
|
||||
torch.cuda.amp.autocast(), \
|
||||
torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack):
|
||||
_unused = function(*args)
|
||||
else:
|
||||
with torch.enable_grad(), \
|
||||
torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack):
|
||||
_unused = function(*args)
|
||||
|
||||
if x not in storage:
|
||||
raise RuntimeError("Attempt to retrieve a tensor saved by autograd multiple times without checkpoint"
|
||||
" recomputation being triggered in between, this is not currently supported. Please"
|
||||
" open an issue with details on your use case so that we can prioritize adding this.")
|
||||
|
||||
return storage[x]
|
||||
|
||||
# get device if we need to offload the activation
|
||||
if activation_offload:
|
||||
device = get_current_device()
|
||||
|
||||
# run function with pack and unpack as saved_tensors_hooks
|
||||
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
|
||||
output = function(*args)
|
||||
|
||||
# offload activation if needed
|
||||
if activation_offload:
|
||||
for arg in args:
|
||||
if torch.is_tensor(arg):
|
||||
arg = arg.to(device="cpu")
|
||||
|
||||
return output
|
3
colossalai/legacy/utils/checkpoint/__init__.py
Normal file
3
colossalai/legacy/utils/checkpoint/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .module_checkpoint import load_checkpoint, save_checkpoint
|
||||
|
||||
__all__ = ['save_checkpoint', 'load_checkpoint']
|
140
colossalai/legacy/utils/checkpoint/module_checkpoint.py
Normal file
140
colossalai/legacy/utils/checkpoint/module_checkpoint.py
Normal file
@@ -0,0 +1,140 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from colossalai.interface import OptimizerWrapper
|
||||
from colossalai.tensor import ColoTensor
|
||||
|
||||
from .utils import gather_tensor, scatter_tensor
|
||||
|
||||
|
||||
def save_checkpoint(path: str,
|
||||
epoch: int,
|
||||
model: torch.nn.Module,
|
||||
optimizer: Optional[OptimizerWrapper] = None,
|
||||
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
|
||||
*args,
|
||||
**kwargs):
|
||||
"""save_checkpoint
|
||||
save a model, whose parameters are `ColoTensor`s.
|
||||
Args:
|
||||
path (str): directory to save the checkpoint files.
|
||||
epoch (int): the number of epoch
|
||||
model (torch.nn.Module): a torch module initialized by ColoInitContext
|
||||
optimizer (OptimizerWrapper, optional): optimizers. Defaults to None.
|
||||
lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr schedule. Defaults to None.
|
||||
"""
|
||||
rank = dist.get_rank()
|
||||
model_state = model.state_dict()
|
||||
# save the dist context about the tensors in a new dict, while still maintain the original dict.
|
||||
for k, v in model_state.items():
|
||||
if isinstance(v, ColoTensor):
|
||||
gather_tensor(v) # gather shared tensors to rank0
|
||||
# don't recover tensors in rank0, since the dict is only a copy of model
|
||||
|
||||
if rank == 0:
|
||||
# sanity check
|
||||
for k, v in model_state.items():
|
||||
if isinstance(v, ColoTensor):
|
||||
assert v.save_ready
|
||||
assert v.is_replicate()
|
||||
delattr(v, 'save_ready')
|
||||
# model saving
|
||||
save_state = {'epoch': epoch, 'model': model_state}
|
||||
torch.save(save_state, path + '/epoch_{}_model.pth'.format(epoch), *args, **kwargs)
|
||||
|
||||
# delete old dicts
|
||||
del model_state
|
||||
# synchronize all the processes
|
||||
dist.barrier()
|
||||
|
||||
if optimizer is not None:
|
||||
mapping = dict()
|
||||
optim_state = optimizer.state_dict()
|
||||
for k, v in optim_state['state'].items():
|
||||
for n, t in v.items():
|
||||
if isinstance(t, ColoTensor):
|
||||
mapping[(k, n)] = t.dist_spec
|
||||
gather_tensor(t)
|
||||
|
||||
if rank == 0:
|
||||
save_state = {'epoch': epoch, 'optim': optim_state}
|
||||
torch.save(save_state, path + '/epoch_{}_optim.pth'.format(epoch), *args, **kwargs)
|
||||
# recover colo tensors in rank0
|
||||
for k, v in optimizer.state_dict()['state'].items():
|
||||
for n, t in v.items():
|
||||
if isinstance(t, ColoTensor):
|
||||
assert hasattr(t, 'save_ready')
|
||||
t.set_dist_spec(mapping[(k, n)])
|
||||
delattr(t, 'save_ready')
|
||||
|
||||
del optim_state
|
||||
del mapping
|
||||
dist.barrier()
|
||||
|
||||
|
||||
def load_checkpoint(path: str,
|
||||
epoch: int,
|
||||
model: torch.nn.Module,
|
||||
optimizer: Optional[OptimizerWrapper] = None,
|
||||
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
|
||||
torch_load_kwargs: Optional[Dict] = None,
|
||||
load_state_dict_kwargs: Optional[Dict] = None):
|
||||
"""load_checkpoint
|
||||
load a model, whose parameters are `ColoTensor`s.
|
||||
Args:
|
||||
path (str): directory to save the checkpoint files.
|
||||
epoch (int): the number of epoch
|
||||
model (torch.nn.Module): a torch module initialized by ColoInitContext
|
||||
optimizer (OptimizerWrapper, optional): optimizers. Defaults to None.
|
||||
lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr schedule. Defaults to None.
|
||||
torch_load_kwargs: (dict, optional): The kwargs of torch.load inside the function
|
||||
load_state_dict_kwargs (dict, optional): The kwargs of load_state_dict inside the function
|
||||
"""
|
||||
# initialize the default parameters
|
||||
if not torch_load_kwargs:
|
||||
torch_load_kwargs = dict()
|
||||
if not load_state_dict_kwargs:
|
||||
load_state_dict_kwargs = dict()
|
||||
|
||||
rank = dist.get_rank()
|
||||
mapping = dict()
|
||||
for n, p in model.named_parameters():
|
||||
if isinstance(p, ColoTensor):
|
||||
mapping[n] = p.dist_spec
|
||||
gather_tensor(p)
|
||||
|
||||
if rank == 0:
|
||||
load_state = torch.load(path + '/epoch_{}_model.pth'.format(epoch), **torch_load_kwargs)
|
||||
model.load_state_dict(load_state['model'], **load_state_dict_kwargs)
|
||||
dist.barrier()
|
||||
|
||||
# scatter loaded parameters
|
||||
for n, p in model.named_parameters():
|
||||
if isinstance(p, ColoTensor):
|
||||
scatter_tensor(p, mapping[n])
|
||||
if rank == 0:
|
||||
assert hasattr(p, 'save_ready')
|
||||
delattr(p, 'save_ready')
|
||||
del mapping
|
||||
|
||||
if optimizer is not None:
|
||||
mapping = dict()
|
||||
for k, v in optimizer.state_dict()['state'].items():
|
||||
for n, t in v.items():
|
||||
if isinstance(t, ColoTensor):
|
||||
mapping[(k, n)] = t.dist_spec
|
||||
gather_tensor(t)
|
||||
|
||||
if rank == 0:
|
||||
colo_checkpoint = torch.load(path + '/epoch_{}_optim.pth'.format(epoch), **torch_load_kwargs)
|
||||
optimizer.load_state_dict(colo_checkpoint['optim'], **load_state_dict_kwargs)
|
||||
dist.barrier()
|
||||
|
||||
for k, v in optimizer.state_dict()['state'].items():
|
||||
for n, t in v.items():
|
||||
if isinstance(t, ColoTensor):
|
||||
scatter_tensor(t, mapping[(k, n)])
|
||||
|
||||
del mapping
|
65
colossalai/legacy/utils/checkpoint/utils.py
Normal file
65
colossalai/legacy/utils/checkpoint/utils.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from colossalai.legacy.tensor import ColoTensorSpec
|
||||
from colossalai.legacy.tensor.distspec import DistPlacementPattern, _DistSpec
|
||||
from colossalai.tensor import ColoTensor
|
||||
|
||||
|
||||
def robust_broadcast(tensor):
|
||||
with torch.no_grad():
|
||||
is_cpu_ten = tensor.device.type == 'cpu'
|
||||
if is_cpu_ten:
|
||||
b_data = tensor.cuda()
|
||||
else:
|
||||
b_data = tensor
|
||||
|
||||
dist.broadcast(b_data, 0)
|
||||
|
||||
if is_cpu_ten:
|
||||
tensor.copy_(b_data)
|
||||
|
||||
|
||||
def gather_tensor(colo_tensor: ColoTensor) -> None:
|
||||
"""Make colo_tensor replicated when the rank is 0
|
||||
"""
|
||||
if not colo_tensor.is_replicate():
|
||||
pg = colo_tensor.get_process_group()
|
||||
# for the group which contains rank 0
|
||||
if pg.dp_local_rank() == 0:
|
||||
old_dist_spec = colo_tensor.dist_spec
|
||||
colo_tensor.to_replicate_()
|
||||
if dist.get_rank() != 0:
|
||||
colo_tensor.set_dist_spec(old_dist_spec)
|
||||
|
||||
# synchronize all processes for unexpected problems
|
||||
dist.barrier()
|
||||
|
||||
if dist.get_rank() == 0:
|
||||
setattr(colo_tensor, 'save_ready', True) # set saving signature
|
||||
|
||||
|
||||
def scatter_tensor(colo_tensor: ColoTensor, dist_spec: _DistSpec) -> None:
|
||||
"""Reversal operation of `gather_tensor`.
|
||||
"""
|
||||
if dist_spec.placement == DistPlacementPattern.REPLICATE:
|
||||
robust_broadcast(colo_tensor.data)
|
||||
else:
|
||||
global_size = colo_tensor.size_global()
|
||||
|
||||
if dist.get_rank() == 0:
|
||||
entire_data = colo_tensor.data
|
||||
else:
|
||||
entire_data = torch.empty(global_size, device=colo_tensor.device)
|
||||
robust_broadcast(entire_data)
|
||||
|
||||
if dist.get_rank() == 0:
|
||||
colo_tensor.set_dist_spec(dist_spec)
|
||||
else:
|
||||
rep_tensor = ColoTensor(
|
||||
entire_data, ColoTensorSpec(pg=colo_tensor.get_process_group(), compute_attr=colo_tensor.compute_spec))
|
||||
rep_tensor.set_dist_spec(dist_spec)
|
||||
with torch.no_grad():
|
||||
colo_tensor.data.copy_(rep_tensor.data)
|
||||
# synchronize all processes for unexpected problems
|
||||
dist.barrier()
|
268
colossalai/legacy/utils/checkpointing.py
Normal file
268
colossalai/legacy/utils/checkpointing.py
Normal file
@@ -0,0 +1,268 @@
|
||||
from collections import OrderedDict
|
||||
from itertools import chain
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from colossalai.legacy.constants import IS_TENSOR_PARALLEL
|
||||
from colossalai.legacy.context.parallel_mode import ParallelMode
|
||||
from colossalai.legacy.core import global_context as gpc
|
||||
|
||||
try:
|
||||
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
|
||||
except ImportError:
|
||||
_EXTRA_STATE_KEY_SUFFIX = '_extra_state'
|
||||
|
||||
from .common import is_using_pp
|
||||
|
||||
__all__ = ["save_checkpoint", "load_checkpoint"]
|
||||
|
||||
|
||||
def broadcast_state_dict(state_dict, parallel_mode):
|
||||
state_dict = [state_dict.copy() if isinstance(state_dict, dict) else state_dict]
|
||||
src_rank = gpc.get_ranks_in_group(parallel_mode)[0]
|
||||
dist.broadcast_object_list(state_dict, src=src_rank, group=gpc.get_cpu_group(parallel_mode))
|
||||
return state_dict[0]
|
||||
|
||||
|
||||
def partition_tensor_parallel_state_dict(state_dict: OrderedDict,
|
||||
parallel_mode: ParallelMode,
|
||||
dims: dict = dict(),
|
||||
partition_states: dict = dict()):
|
||||
src_rank = gpc.get_ranks_in_group(parallel_mode)[0]
|
||||
depth = gpc.get_world_size(parallel_mode)
|
||||
group = gpc.get_cpu_group(parallel_mode)
|
||||
is_rank0 = gpc.get_local_rank(parallel_mode) == 0
|
||||
partition_info = [None]
|
||||
if is_rank0:
|
||||
partition_info_dict = OrderedDict()
|
||||
for key, param in state_dict.items():
|
||||
dim = dims[key]
|
||||
is_partitioned = partition_states[key]
|
||||
shape = list(param.shape)
|
||||
if is_partitioned:
|
||||
shape[dim] = shape[dim] // depth
|
||||
partition_info_dict[key] = (is_partitioned, param.dtype, shape, dim)
|
||||
partition_info[0] = partition_info_dict
|
||||
dist.broadcast_object_list(partition_info, src_rank, group=group)
|
||||
partitioned_state = OrderedDict()
|
||||
for key, (is_partitioned, dtype, shape, dim) in partition_info[0].items():
|
||||
if is_partitioned:
|
||||
output = torch.empty(shape, dtype=dtype)
|
||||
if is_rank0:
|
||||
scatter_list = [t.contiguous() for t in state_dict[key].chunk(depth, dim)]
|
||||
else:
|
||||
scatter_list = None
|
||||
dist.scatter(output, scatter_list, src_rank, group=group)
|
||||
else:
|
||||
if is_rank0:
|
||||
output = state_dict[key]
|
||||
else:
|
||||
output = torch.empty(shape, dtype=dtype)
|
||||
dist.broadcast(output, src_rank, group=group)
|
||||
partitioned_state[key] = output
|
||||
return partitioned_state
|
||||
|
||||
|
||||
def gather_tensor_parallel_state_dict(
|
||||
state_dict: OrderedDict,
|
||||
parallel_mode: ParallelMode,
|
||||
dims: dict = dict(),
|
||||
partition_states: dict = dict(),
|
||||
keep_vars: bool = False,
|
||||
):
|
||||
dst_rank = gpc.get_ranks_in_group(parallel_mode)[0]
|
||||
depth = gpc.get_world_size(parallel_mode)
|
||||
|
||||
for key in list(state_dict.keys()):
|
||||
param = state_dict.pop(key)
|
||||
param = param if keep_vars else param.detach()
|
||||
dim = dims.get(key, 0)
|
||||
do_partition = partition_states.get(key, True)
|
||||
if do_partition:
|
||||
temp = param.transpose(0, dim).contiguous()
|
||||
gather_list = None
|
||||
if gpc.get_local_rank(parallel_mode) == 0:
|
||||
shape = list(param.shape)
|
||||
shape[0], shape[dim] = shape[dim], shape[0]
|
||||
shape[0] *= depth
|
||||
param = torch.empty(shape, dtype=param.dtype, device=param.device)
|
||||
gather_list = list(torch.chunk(param, depth, dim=0))
|
||||
dist.gather(temp, gather_list, dst=dst_rank, group=gpc.get_cpu_group(parallel_mode))
|
||||
param = torch.transpose(param, 0, dim)
|
||||
# update params in state_dict only on local rank 0
|
||||
if gpc.get_local_rank(parallel_mode) == 0:
|
||||
state_dict[key] = param
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def _send_state_dict(state_dict, dst, parallel_mode):
|
||||
state_tensor, state_size = dist.distributed_c10d._object_to_tensor(state_dict)
|
||||
dist.send(state_size, dst, group=gpc.get_cpu_group(parallel_mode))
|
||||
dist.send(state_tensor, dst, group=gpc.get_cpu_group(parallel_mode))
|
||||
|
||||
|
||||
def _recv_state_dict(src, parallel_mode):
|
||||
state_size = torch.tensor([0], dtype=torch.long)
|
||||
dist.recv(state_size, src, group=gpc.get_cpu_group(parallel_mode))
|
||||
state_tensor = torch.empty(state_size.item(), dtype=torch.uint8)
|
||||
dist.recv(state_tensor, src, group=gpc.get_cpu_group(parallel_mode))
|
||||
state_dict = dist.distributed_c10d._tensor_to_object(state_tensor, state_size)
|
||||
return state_dict
|
||||
|
||||
|
||||
def partition_pipeline_parallel_state_dict(model, state_dict):
|
||||
pipeline_state = OrderedDict()
|
||||
|
||||
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||
# receive all states from prev stage
|
||||
if not gpc.is_first_rank(ParallelMode.PIPELINE):
|
||||
state_dict = _recv_state_dict(gpc.get_prev_global_rank(ParallelMode.PIPELINE), ParallelMode.PIPELINE)
|
||||
# move states to output
|
||||
for name, _ in model.named_parameters(recurse=True):
|
||||
if name in state_dict:
|
||||
pipeline_state[name] = state_dict.pop(name)
|
||||
for name, _ in model.named_buffers(recurse=True):
|
||||
if name in state_dict:
|
||||
pipeline_state[name] = state_dict.pop(name)
|
||||
for name, _ in model.named_modules():
|
||||
extra_state_key = name + "." + _EXTRA_STATE_KEY_SUFFIX
|
||||
if extra_state_key in state_dict:
|
||||
pipeline_state[extra_state_key] = state_dict.pop(extra_state_key)
|
||||
# send rest states to next stage
|
||||
if not gpc.is_last_rank(ParallelMode.PIPELINE):
|
||||
_send_state_dict(state_dict, gpc.get_next_global_rank(ParallelMode.PIPELINE), ParallelMode.PIPELINE)
|
||||
|
||||
return pipeline_state
|
||||
|
||||
|
||||
def gather_pipeline_parallel_state_dict(state_dict):
|
||||
gathered_states = ([None for _ in range(gpc.get_world_size(ParallelMode.PIPELINE))]
|
||||
if gpc.get_local_rank(ParallelMode.PIPELINE) == 0 else None)
|
||||
dist.gather_object(
|
||||
state_dict,
|
||||
gathered_states,
|
||||
dst=gpc.get_ranks_in_group(ParallelMode.PIPELINE)[0],
|
||||
group=gpc.get_cpu_group(ParallelMode.PIPELINE),
|
||||
)
|
||||
|
||||
state_dict = (OrderedDict(chain.from_iterable(state.items() for state in gathered_states))
|
||||
if gpc.get_local_rank(ParallelMode.PIPELINE) == 0 else OrderedDict())
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def save_checkpoint(file,
|
||||
epoch: int,
|
||||
model: torch.nn.Module,
|
||||
optimizer: torch.optim.Optimizer = None,
|
||||
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
|
||||
**kwargs):
|
||||
"""Stores the checkpoint to disk. Saves all the training components' parameters or buffers, such as model, optimizer,
|
||||
lr_scheduler etc. into a checkpoint dictionary.
|
||||
|
||||
Args:
|
||||
file: a file-like object (has to implement write and flush) or a string or os.PathLike object containing a
|
||||
file name.
|
||||
epoch (int): Epoch number (indicates how many epochs have you trained this model).
|
||||
model (:class:`torch.nn.Module`): Model to be saved.
|
||||
optimizer (Union[:class:`torch.optim.Optimizer`, :class:`colossalai.nn.optimizer`]): Optimizer to be saved.
|
||||
lr_scheduler (Union[:class:`torch.optim.lr_scheduler`, :class:`colossalai.nn.lr_scheduler`], optional):
|
||||
lr_scheduler to be saved, defaults to None.
|
||||
pickle_module: module used for pickling metadata and objects
|
||||
pickle_protocol: can be specified to override the default protocol
|
||||
"""
|
||||
# ckpt container
|
||||
checkpoint = {"epoch": epoch}
|
||||
|
||||
model_state = model.state_dict()
|
||||
if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||
model_state = gather_pipeline_parallel_state_dict(model_state)
|
||||
|
||||
if gpc.get_global_rank() == 0:
|
||||
checkpoint["model"] = model_state
|
||||
|
||||
# if optimizer is not None:
|
||||
# checkpoint['optimizer'] = optimizer.state_dict()
|
||||
|
||||
# if lr_scheduler is not None:
|
||||
# checkpoint['lr_scheduler'] = lr_scheduler.state_dict()
|
||||
|
||||
torch.save(checkpoint, file, **kwargs)
|
||||
|
||||
|
||||
def broadcast_model(model: torch.nn.Module):
|
||||
src_rank = gpc.get_ranks_in_group(ParallelMode.TENSOR)[0]
|
||||
for p in model.parameters():
|
||||
if not getattr(p, IS_TENSOR_PARALLEL, False) and p.storage().size() > 0:
|
||||
group = gpc.get_group(ParallelMode.TENSOR) if p.device.type == 'cuda' else gpc.get_cpu_group(
|
||||
ParallelMode.TENSOR)
|
||||
dist.broadcast(p, src_rank, group=group)
|
||||
|
||||
|
||||
def load_checkpoint(
|
||||
file,
|
||||
model: torch.nn.Module,
|
||||
optimizer: torch.optim.Optimizer = None,
|
||||
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
|
||||
strict: bool = True,
|
||||
):
|
||||
"""Loads training states from a checkpoint file.
|
||||
|
||||
Args:
|
||||
file: a file-like object (has to implement read(), readline(), tell(), and seek()), or a string or os.PathLike
|
||||
object containing a file name.
|
||||
model (:class:`torch.nn.Module`): Model to load saved weights and buffers.
|
||||
optimizer (Union[:class:`torch.optim.Optimizer`, :class:`colossalai.nn.optimizer`]): Optimizer to recuperate.
|
||||
lr_scheduler (:class:`torch.optim.lr_scheduler._LRScheduler`, optional):
|
||||
lr_scheduler to recuperate, defaults to None.
|
||||
strict (bool, optional): Whether to strictly enforce that the keys in :attr:`state_dict`
|
||||
of the checkpoint match the names of parameters and buffers in model, defaults to True.
|
||||
|
||||
Returns:
|
||||
int: The saved epoch number.
|
||||
|
||||
Raises:
|
||||
RuntimeError: Raise error if the model/optimizer cannot successfully be recuperated
|
||||
"""
|
||||
state_dict = (torch.load(file, map_location=torch.device("cpu"))
|
||||
if gpc.get_local_rank(ParallelMode.MODEL) == 0 else None)
|
||||
|
||||
# model states
|
||||
model_state = state_dict.pop("model") if state_dict is not None else dict()
|
||||
# pipeline
|
||||
if is_using_pp():
|
||||
model_state = partition_pipeline_parallel_state_dict(model, model_state)
|
||||
try:
|
||||
model.load_state_dict(model_state, strict=strict)
|
||||
broadcast_model(model)
|
||||
except RuntimeError as e:
|
||||
error_msgs = str(e)
|
||||
if error_msgs.startswith("Error(s) in loading state_dict for "):
|
||||
error_msgs = error_msgs.split("\n\t")[1:]
|
||||
dst_rank = gpc.get_ranks_in_group(ParallelMode.MODEL)[0]
|
||||
all_error_msgs = [None for _ in range(gpc.get_world_size(ParallelMode.MODEL))]
|
||||
dist.gather_object(error_msgs, all_error_msgs, dst=dst_rank, group=gpc.get_cpu_group(ParallelMode.MODEL))
|
||||
if gpc.get_global_rank() == 0:
|
||||
all_error_msgs = list(chain.from_iterable(all_error_msgs))
|
||||
raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format(
|
||||
model.__class__.__name__, "\n\t".join(all_error_msgs)))
|
||||
else:
|
||||
raise e
|
||||
|
||||
# broadcast the rest states
|
||||
state_dict = broadcast_state_dict(state_dict, ParallelMode.MODEL)
|
||||
|
||||
# # optimizer states
|
||||
# if optimizer is not None and 'optimizer' in state_dict:
|
||||
# optimizer.load_state_dict(state_dict['optimizer'])
|
||||
|
||||
# # lr scheduler states
|
||||
# if lr_scheduler is not None and 'lr_scheduler' in state_dict:
|
||||
# lr_scheduler.load_state_dict(state_dict['lr_scheduler'])
|
||||
|
||||
# last epoch
|
||||
last_epoch = state_dict.pop("epoch", -1)
|
||||
|
||||
return last_epoch
|
434
colossalai/legacy/utils/common.py
Normal file
434
colossalai/legacy/utils/common.py
Normal file
@@ -0,0 +1,434 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import inf
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from colossalai.legacy.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARALLEL_ATTRIBUTES
|
||||
from colossalai.legacy.context.parallel_mode import ParallelMode
|
||||
from colossalai.legacy.core import global_context as gpc
|
||||
from colossalai.legacy.global_variables import tensor_parallel_env as env
|
||||
from colossalai.legacy.tensor import ProcessGroup
|
||||
from colossalai.tensor import ColoParameter
|
||||
from colossalai.utils.multi_tensor_apply import multi_tensor_applier
|
||||
|
||||
try:
|
||||
from colossalai._C import fused_optim
|
||||
except:
|
||||
fused_optim = None
|
||||
|
||||
|
||||
def print_rank_0(msg: str, logger=None):
|
||||
"""Print messages and save logs(optional). This is executed only if you are the rank-0 gpu.
|
||||
|
||||
Args:
|
||||
msg (str): A string message to output.
|
||||
logger (:class:`colossalai.logging.DistributedLogger`, optional):
|
||||
The logger to record the message, defaults to None.
|
||||
"""
|
||||
if gpc.get_global_rank() == 0:
|
||||
if logger is None:
|
||||
print(msg, flush=True)
|
||||
else:
|
||||
logger.info(msg)
|
||||
|
||||
|
||||
def sync_model_param(model, parallel_mode):
|
||||
r"""Make sure data parameters are consistent during Data Parallel Mode.
|
||||
|
||||
Args:
|
||||
model (:class:`torch.nn.Module`): A pyTorch model on whose parameters you check the consistency.
|
||||
parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): Parallel mode to be checked.
|
||||
|
||||
Note:
|
||||
The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
|
||||
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
|
||||
"""
|
||||
if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1:
|
||||
for param in model.parameters():
|
||||
ranks = gpc.get_ranks_in_group(parallel_mode)
|
||||
dist.broadcast(param, src=ranks[0], group=gpc.get_group(parallel_mode))
|
||||
|
||||
|
||||
def is_dp_rank_0():
|
||||
return not gpc.is_initialized(ParallelMode.DATA) or gpc.is_first_rank(ParallelMode.DATA)
|
||||
|
||||
|
||||
def is_tp_rank_0():
|
||||
return not gpc.is_initialized(ParallelMode.TENSOR) or gpc.is_first_rank(ParallelMode.TENSOR)
|
||||
|
||||
|
||||
def is_no_pp_or_last_stage():
|
||||
return not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE)
|
||||
|
||||
|
||||
def is_using_ddp():
|
||||
return gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1
|
||||
|
||||
|
||||
def is_using_pp():
|
||||
return gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1
|
||||
|
||||
|
||||
def is_using_sequence():
|
||||
return gpc.is_initialized(ParallelMode.SEQUENCE) and gpc.get_world_size(ParallelMode.SEQUENCE) > 1
|
||||
|
||||
|
||||
class model_branch_context(object):
|
||||
|
||||
def __enter__(self):
|
||||
self.env_status = env.save()
|
||||
|
||||
def __exit__(self, *exc_info):
|
||||
env.load(**self.env_status)
|
||||
|
||||
|
||||
def is_model_parallel_parameter(p):
|
||||
return hasattr(p, IS_TENSOR_PARALLEL) and getattr(p, IS_TENSOR_PARALLEL)
|
||||
|
||||
|
||||
def _calc_l2_norm(grads):
|
||||
# we should not
|
||||
global fused_optim
|
||||
|
||||
if fused_optim is None:
|
||||
from colossalai.kernel.op_builder import FusedOptimBuilder
|
||||
fused_optim = FusedOptimBuilder().load()
|
||||
|
||||
norm = 0.0
|
||||
if len(grads) > 0:
|
||||
dummy_overflow_buf = torch.cuda.IntTensor([0])
|
||||
norm, _ = multi_tensor_applier(
|
||||
fused_optim.multi_tensor_l2norm,
|
||||
dummy_overflow_buf,
|
||||
[grads],
|
||||
False # no per-parameter norm
|
||||
)
|
||||
return norm
|
||||
|
||||
|
||||
def _calc_lp(grads, norm_type):
|
||||
norm = 0.0
|
||||
for grad in grads:
|
||||
grad_norm = torch.norm(grad, norm_type)
|
||||
norm += grad_norm**norm_type
|
||||
return norm
|
||||
|
||||
|
||||
def _move_norm_to_cuda(norm: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]:
|
||||
if torch.is_tensor(norm) and norm.device.type != 'cuda':
|
||||
norm = norm.to(torch.cuda.current_device())
|
||||
return norm
|
||||
|
||||
|
||||
def _get_tensor_norm(norm: Union[float, torch.Tensor], move_to_cuda) -> torch.Tensor:
|
||||
if isinstance(norm, float):
|
||||
norm = torch.Tensor([norm])
|
||||
if move_to_cuda:
|
||||
norm = norm.to(torch.cuda.current_device())
|
||||
return norm
|
||||
|
||||
|
||||
# ======== Gradient Clipping =========
|
||||
|
||||
|
||||
def _compute_local_lp(params: List[ColoParameter], norm_type: float) -> float:
|
||||
if len(params) == 0:
|
||||
return 0.0
|
||||
grads = [p.grad for p in params]
|
||||
use_cuda_kernel = grads[0].device.type == 'cuda'
|
||||
if norm_type == inf:
|
||||
local_lp = max([g.abs().max() for g in grads])
|
||||
elif norm_type == 2.0 and use_cuda_kernel:
|
||||
local_lp = _calc_l2_norm(grads)**norm_type
|
||||
else:
|
||||
local_lp = _calc_lp(grads, norm_type)
|
||||
if isinstance(local_lp, torch.Tensor):
|
||||
return local_lp.item()
|
||||
return local_lp
|
||||
|
||||
|
||||
def _compute_buckets_lp(params: List[ColoParameter], norm_type: float) -> float:
|
||||
if len(params) == 0:
|
||||
return 0.0
|
||||
buckets: Dict[Optional[ProcessGroup], List[ColoParameter]] = defaultdict(list)
|
||||
for p in params:
|
||||
if p.is_replicate():
|
||||
buckets[None].append(p)
|
||||
else:
|
||||
buckets[p.get_process_group().tp_process_group()].append(p)
|
||||
total_lp = 0.0
|
||||
for group, bucket in buckets.items():
|
||||
local_lp = _compute_local_lp(bucket, norm_type)
|
||||
if group is not None:
|
||||
local_lp_tensor = torch.tensor([local_lp], device=torch.cuda.current_device())
|
||||
if norm_type == inf:
|
||||
dist.all_reduce(local_lp_tensor, op=dist.ReduceOp.MAX, group=group)
|
||||
else:
|
||||
dist.all_reduce(local_lp_tensor, group=group)
|
||||
local_lp = local_lp_tensor.item()
|
||||
if norm_type == inf:
|
||||
total_lp = max(total_lp, local_lp)
|
||||
else:
|
||||
total_lp += local_lp
|
||||
return total_lp
|
||||
|
||||
|
||||
def _compute_pp_grad_lp(total_lp: float, norm_type: float) -> float:
|
||||
if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
|
||||
total_lp_tensor = torch.tensor([total_lp], device=torch.cuda.current_device())
|
||||
if norm_type == inf:
|
||||
dist.all_reduce(total_lp_tensor, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.PIPELINE))
|
||||
else:
|
||||
dist.all_reduce(total_lp_tensor, group=gpc.get_group(ParallelMode.PIPELINE))
|
||||
total_lp = total_lp_tensor.item()
|
||||
return total_lp
|
||||
|
||||
|
||||
def _compute_grad_lp(parameters, norm_type: float = 2.0) -> float:
|
||||
if isinstance(parameters, torch.Tensor):
|
||||
parameters = [parameters]
|
||||
grad_dtype = None
|
||||
cpu_grad_params: List[ColoParameter] = []
|
||||
cuda_grad_params: List[ColoParameter] = []
|
||||
for p in parameters:
|
||||
if p.grad is None:
|
||||
continue
|
||||
assert isinstance(p, ColoParameter)
|
||||
if grad_dtype is None:
|
||||
grad_dtype = p.grad.dtype
|
||||
assert p.grad.dtype == grad_dtype, f'Expected all grads are {grad_dtype}, got {p.grad.dtype}'
|
||||
if p.grad.device.type == 'cuda':
|
||||
cuda_grad_params.append(p)
|
||||
else:
|
||||
cpu_grad_params.append(p)
|
||||
norm_type = float(norm_type)
|
||||
cpu_lp = _compute_buckets_lp(cpu_grad_params, norm_type)
|
||||
cuda_lp = _compute_buckets_lp(cuda_grad_params, norm_type)
|
||||
if norm_type == inf:
|
||||
total_lp = max(cpu_lp, cuda_lp)
|
||||
else:
|
||||
total_lp = cpu_lp + cuda_lp
|
||||
return _compute_pp_grad_lp(total_lp, norm_type)
|
||||
|
||||
|
||||
def compute_grad_norm(parameters, norm_type: float = 2.0) -> float:
|
||||
norm_type = float(norm_type)
|
||||
total_norm = _compute_grad_lp(parameters, norm_type)
|
||||
if norm_type != inf:
|
||||
total_norm = total_norm**(1 / norm_type)
|
||||
return total_norm
|
||||
|
||||
|
||||
def _clip_grad_norm(parameters, max_norm: float, total_norm: float) -> None:
|
||||
clip_coef = max_norm / (total_norm + 1e-6)
|
||||
if clip_coef < 1.0:
|
||||
cuda_grads: List[torch.Tensor] = []
|
||||
cpu_grads: List[torch.Tensor] = []
|
||||
if isinstance(parameters, torch.Tensor):
|
||||
parameters = [parameters]
|
||||
for p in parameters:
|
||||
if p.grad is None:
|
||||
continue
|
||||
if p.grad.device.type == 'cuda':
|
||||
cuda_grads.append(p.grad.detach())
|
||||
else:
|
||||
cpu_grads.append(p.grad.detach())
|
||||
if len(cuda_grads) > 0:
|
||||
dummy_overflow_buf = torch.cuda.IntTensor([0])
|
||||
multi_tensor_applier(fused_optim.multi_tensor_scale, dummy_overflow_buf, [cuda_grads, cuda_grads],
|
||||
clip_coef)
|
||||
for g in cpu_grads:
|
||||
g.mul_(clip_coef)
|
||||
|
||||
|
||||
def clip_grad_norm(parameters, max_norm: float, norm_type: float = 2.0) -> float:
|
||||
total_norm = compute_grad_norm(parameters, norm_type)
|
||||
_clip_grad_norm(parameters, max_norm, total_norm)
|
||||
return total_norm
|
||||
|
||||
|
||||
def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
|
||||
"""Clips gradient norm of an iterable of parameters whose gradients are in fp32.
|
||||
|
||||
This is adapted from :func:`torch.nn.utils.clip_grad.clip_grad_norm_` and
|
||||
added functionality to handle model parallel parameters.
|
||||
|
||||
Note:
|
||||
the gradients are modified in place.
|
||||
|
||||
Args:
|
||||
parameters (Iterable[:class:`torch.tensor`] or :class:`torch.tensor`):
|
||||
An iterable of Tensors or a single Tensor that will have gradients normalized.
|
||||
max_norm (Union[float, int]): Max norm of the gradients.
|
||||
norm_type (Union[float, int, 'inf']): Type of the used p-norm. Can be ``'inf'`` for infinity norm.
|
||||
|
||||
Returns:
|
||||
float: Total norm of the parameters.
|
||||
"""
|
||||
|
||||
if isinstance(parameters, torch.Tensor):
|
||||
parameters = [parameters]
|
||||
|
||||
# Filter parameters based on:
|
||||
# - grad should not be none
|
||||
# - parameter should not be shared
|
||||
# - should not be a replica due to tensor model parallelism
|
||||
params: List[Parameter] = []
|
||||
has_zero_shared_param: bool = False
|
||||
for param in parameters:
|
||||
if param.grad is not None:
|
||||
# Make sure the grads are in fp32
|
||||
assert param.grad.dtype == torch.float, \
|
||||
f'expected gradient to be dtype torch.float, but got {param.grad.type()}'
|
||||
if hasattr(param, 'colo_attr') and param.colo_attr.sharded_data_tensor.is_sharded:
|
||||
has_zero_shared_param = True
|
||||
params.append(param)
|
||||
|
||||
if len(params) == 0:
|
||||
enable_cuda_kernels = False
|
||||
else:
|
||||
enable_cuda_kernels = params[0].grad.device.type == 'cuda'
|
||||
# Norm parameters.
|
||||
max_norm = float(max_norm)
|
||||
norm_type = float(norm_type)
|
||||
|
||||
# Parameters can be on CPU or CUDA
|
||||
# If parameters are on CPU, disable CUDA kernels
|
||||
|
||||
# Calculate norm.
|
||||
if norm_type == inf:
|
||||
total_norm = max(p.grad.data.abs().max() for p in params)
|
||||
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
|
||||
# Take max across all model-parallel GPUs.
|
||||
if gpc.is_initialized(ParallelMode.MODEL) and gpc.get_world_size(ParallelMode.MODEL) > 1:
|
||||
dist.all_reduce(total_norm_cuda,
|
||||
op=dist.ReduceOp.MAX,
|
||||
group=gpc.get_group(ParallelMode.MODEL),
|
||||
async_op=False)
|
||||
if has_zero_shared_param:
|
||||
dist.all_reduce(total_norm_cuda,
|
||||
op=dist.ReduceOp.MAX,
|
||||
group=gpc.get_group(ParallelMode.DATA),
|
||||
async_op=False)
|
||||
total_norm = total_norm_cuda[0].item()
|
||||
else:
|
||||
tensor_parallel_grads = []
|
||||
no_tensor_parallel_grads = []
|
||||
zero_sharded_grads = []
|
||||
for p in params:
|
||||
if is_model_parallel_parameter(p):
|
||||
reductor = (gpc.get_world_size(ParallelMode.TENSOR) / getattr(p, NUM_PARTITIONS))**(1 / norm_type)
|
||||
tensor_parallel_grads.append(p.grad.data / reductor)
|
||||
elif hasattr(p, 'colo_attr') and p.colo_attr.sharded_data_tensor.is_sharded:
|
||||
zero_sharded_grads.append(p.grad.data)
|
||||
else:
|
||||
no_tensor_parallel_grads.append(p.grad.data)
|
||||
|
||||
if norm_type == 2.0 and enable_cuda_kernels:
|
||||
tensor_parallel_norm = _calc_l2_norm(tensor_parallel_grads)**norm_type
|
||||
no_tensor_parallel_norm = _calc_l2_norm(no_tensor_parallel_grads)**norm_type
|
||||
zero_sharded_norm = _calc_l2_norm(zero_sharded_grads)**norm_type
|
||||
else:
|
||||
tensor_parallel_norm = _calc_lp(tensor_parallel_grads, norm_type)
|
||||
no_tensor_parallel_norm = _calc_lp(no_tensor_parallel_grads, norm_type)
|
||||
zero_sharded_norm = _calc_lp(zero_sharded_grads, norm_type)
|
||||
# If norm is type of float, then we convert them into torch.Tensor.
|
||||
tensor_parallel_norm = _get_tensor_norm(tensor_parallel_norm, enable_cuda_kernels)
|
||||
no_tensor_parallel_norm = _get_tensor_norm(no_tensor_parallel_norm, enable_cuda_kernels)
|
||||
zero_sharded_norm = _get_tensor_norm(zero_sharded_norm, enable_cuda_kernels)
|
||||
# If grads are on CPU, the norms is also on CPU. Cast them to CUDA tensors
|
||||
if not enable_cuda_kernels:
|
||||
tensor_parallel_norm = _move_norm_to_cuda(tensor_parallel_norm)
|
||||
no_tensor_parallel_norm = _move_norm_to_cuda(no_tensor_parallel_norm)
|
||||
zero_sharded_norm = _move_norm_to_cuda(zero_sharded_norm)
|
||||
|
||||
# Sum across all model-parallel GPUs.
|
||||
if gpc.is_initialized(ParallelMode.TENSOR) and len(tensor_parallel_grads) > 0:
|
||||
dist.all_reduce(tensor_parallel_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR))
|
||||
# Sum across all zero sharded GPUs
|
||||
if len(zero_sharded_grads) > 0:
|
||||
dist.all_reduce(zero_sharded_norm, group=gpc.get_group(ParallelMode.DATA))
|
||||
no_tensor_parallel_norm += zero_sharded_norm
|
||||
total_norm = tensor_parallel_norm + no_tensor_parallel_norm
|
||||
if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
|
||||
dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.PIPELINE))
|
||||
total_norm = total_norm**(1.0 / norm_type)
|
||||
if torch.is_tensor(total_norm):
|
||||
total_norm = total_norm.item()
|
||||
|
||||
# Scale.
|
||||
clip_coeff = max_norm / (total_norm + 1.0e-6)
|
||||
if clip_coeff < 1.0:
|
||||
if enable_cuda_kernels:
|
||||
grads = [p.grad.detach() for p in params]
|
||||
dummy_overflow_buf = torch.cuda.IntTensor([0])
|
||||
multi_tensor_applier(fused_optim.multi_tensor_scale, dummy_overflow_buf, [grads, grads], clip_coeff)
|
||||
else:
|
||||
for p in params:
|
||||
p.grad.detach().mul_(clip_coeff)
|
||||
return total_norm
|
||||
|
||||
|
||||
def count_zeros_fp32(parameters):
|
||||
if isinstance(parameters, torch.Tensor):
|
||||
parameters = [parameters]
|
||||
|
||||
# Filter parameters based on:
|
||||
# - grad should not be none
|
||||
# - parameter should not be shared
|
||||
# - should not be a replica due to tensor model parallelism
|
||||
total_num_zeros = 0.0
|
||||
for param in parameters:
|
||||
grad_not_none = param.grad is not None
|
||||
is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
|
||||
if grad_not_none and is_not_tp_duplicate:
|
||||
grad = param.grad.detach()
|
||||
num_zeros = grad.numel() - torch.count_nonzero(grad)
|
||||
total_num_zeros = num_zeros + total_num_zeros
|
||||
|
||||
total_num_zeros = torch.IntTensor([int(total_num_zeros)]).cuda()
|
||||
|
||||
# Sum across all model-parallel GPUs.
|
||||
ops = []
|
||||
ops.append(
|
||||
dist.all_reduce(total_num_zeros, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR), async_op=True))
|
||||
if gpc.is_initialized(ParallelMode.PIPELINE):
|
||||
ops.append(
|
||||
dist.all_reduce(total_num_zeros,
|
||||
op=dist.ReduceOp.SUM,
|
||||
group=gpc.get_group(ParallelMode.PIPELINE),
|
||||
async_op=True))
|
||||
|
||||
for req in ops:
|
||||
req.wait()
|
||||
total_num_zeros = total_num_zeros.item()
|
||||
|
||||
return total_num_zeros
|
||||
|
||||
|
||||
def copy_tensor_parallel_attributes(src_tensor, dst_tensor):
|
||||
for attr in TENSOR_PARALLEL_ATTRIBUTES:
|
||||
if hasattr(src_tensor, attr):
|
||||
val = getattr(src_tensor, attr)
|
||||
setattr(dst_tensor, attr, val)
|
||||
|
||||
|
||||
def param_is_not_tensor_parallel_duplicate(param):
|
||||
return (hasattr(param, IS_TENSOR_PARALLEL) and getattr(param, IS_TENSOR_PARALLEL)) or (gpc.get_local_rank(
|
||||
ParallelMode.TENSOR) == 0)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def switch_virtual_pipeline_parallel_rank(rank):
|
||||
prev_rank = gpc.virtual_pipeline_parallel_rank
|
||||
try:
|
||||
gpc.set_virtual_pipeline_parallel_rank(rank)
|
||||
yield
|
||||
finally:
|
||||
gpc.set_virtual_pipeline_parallel_rank(prev_rank)
|
4
colossalai/legacy/utils/data_sampler/__init__.py
Normal file
4
colossalai/legacy/utils/data_sampler/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .base_sampler import BaseSampler
|
||||
from .data_parallel_sampler import DataParallelSampler, get_dataloader
|
||||
|
||||
__all__ = ['BaseSampler', 'DataParallelSampler', 'get_dataloader']
|
19
colossalai/legacy/utils/data_sampler/base_sampler.py
Normal file
19
colossalai/legacy/utils/data_sampler/base_sampler.py
Normal file
@@ -0,0 +1,19 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class BaseSampler(ABC):
|
||||
|
||||
def __init__(self, dataset, batch_size):
|
||||
self.dataset = dataset
|
||||
self.batch_size = batch_size
|
||||
|
||||
@abstractmethod
|
||||
def __len__(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def __iter__(self):
|
||||
pass
|
161
colossalai/legacy/utils/data_sampler/data_parallel_sampler.py
Normal file
161
colossalai/legacy/utils/data_sampler/data_parallel_sampler.py
Normal file
@@ -0,0 +1,161 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
# adapted from torch.utils.data.DistributedSampler
|
||||
|
||||
import math
|
||||
import random
|
||||
from typing import Iterator, TypeVar
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, Dataset, Sampler
|
||||
|
||||
from colossalai.legacy.context.parallel_mode import ParallelMode
|
||||
from colossalai.legacy.core import global_context as gpc
|
||||
|
||||
T_co = TypeVar('T_co', covariant=True)
|
||||
|
||||
|
||||
class DataParallelSampler(Sampler):
|
||||
"""A data sampler for distributed data parallelism.
|
||||
|
||||
Args:
|
||||
dataset (:class:`torch.utils.data.Dataset`): The Dataset for sampling.
|
||||
shuffle (bool, optional): Whether to shuffle data, defaults to False.
|
||||
seed (int, optional): The random seed used for sampling, defaults to 0.
|
||||
drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
|
||||
is not divisible by the batch size. If False and the size of dataset is not divisible by
|
||||
the batch size, then the last batch will be smaller, defaults to False.
|
||||
"""
|
||||
|
||||
def __init__(self, dataset: Dataset, shuffle: bool = False, seed: int = 0, drop_last: bool = False) -> None:
|
||||
self.dataset = dataset
|
||||
self.num_replicas = gpc.get_world_size(ParallelMode.DATA)
|
||||
self.rank = gpc.get_local_rank(ParallelMode.DATA)
|
||||
self.epoch = 0
|
||||
self.drop_last = drop_last
|
||||
# If the dataset length is evenly divisible by # of replicas, then there
|
||||
# is no need to drop any data, since the dataset will be split equally.
|
||||
# type: ignore[arg-type]
|
||||
if self.drop_last and len(self.dataset) % self.num_replicas != 0:
|
||||
# Split to nearest available length that is evenly divisible.
|
||||
# This is to ensure each rank receives the same amount of data when
|
||||
# using this Sampler.
|
||||
self.num_samples = math.ceil(
|
||||
# `type:ignore` is required because Dataset cannot provide a default __len__
|
||||
# see NOTE in pytorch/torch/utils/data/sampler.py
|
||||
(len(self.dataset) - self.num_replicas) / \
|
||||
self.num_replicas # type: ignore[arg-type]
|
||||
)
|
||||
else:
|
||||
self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type]
|
||||
self.total_size = self.num_samples * self.num_replicas
|
||||
self.shuffle = shuffle
|
||||
self.seed = seed
|
||||
|
||||
def __iter__(self) -> Iterator[T_co]:
|
||||
if self.shuffle:
|
||||
# deterministically shuffle based on epoch and seed
|
||||
g = torch.Generator()
|
||||
g.manual_seed(self.seed + self.epoch)
|
||||
# type: ignore[arg-type]
|
||||
indices = torch.randperm(len(self.dataset), generator=g).tolist()
|
||||
|
||||
# update for next epoch so that there is no need to call
|
||||
# set_epoch manually
|
||||
self.epoch += 1
|
||||
else:
|
||||
indices = list(range(len(self.dataset))) # type: ignore[arg-type]
|
||||
|
||||
if not self.drop_last:
|
||||
# add extra samples to make it evenly divisible
|
||||
padding_size = self.total_size - len(indices)
|
||||
if padding_size <= len(indices):
|
||||
indices += indices[:padding_size]
|
||||
else:
|
||||
indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
|
||||
else:
|
||||
# remove tail of data to make it evenly divisible.
|
||||
indices = indices[:self.total_size]
|
||||
assert len(indices) == self.total_size
|
||||
|
||||
# subsample
|
||||
indices = indices[self.rank:self.total_size:self.num_replicas]
|
||||
assert len(indices) == self.num_samples
|
||||
|
||||
return iter(indices)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.num_samples
|
||||
|
||||
def set_epoch(self, epoch: int) -> None:
|
||||
r"""Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
|
||||
use a different random ordering for each epoch. Otherwise, the next iteration of this
|
||||
sampler will yield the same ordering.
|
||||
|
||||
Args:
|
||||
epoch (int): Epoch number.
|
||||
"""
|
||||
self.epoch = epoch
|
||||
|
||||
|
||||
def get_dataloader(dataset,
|
||||
shuffle=False,
|
||||
seed=1024,
|
||||
add_sampler=True,
|
||||
drop_last=False,
|
||||
pin_memory=False,
|
||||
num_workers=0,
|
||||
**kwargs):
|
||||
r"""Set up a deterministic dataloader (also configure seed workers, samplers and whether shuffle or not)
|
||||
|
||||
Note:
|
||||
When pipeline parallel is enabled, shuffle cannot be True as it will result in mismatch between input data
|
||||
on the 1st stage and label on the last stage.
|
||||
|
||||
Args:
|
||||
dataset (:class:`torch.utils.data.Dataset`): The dataset to be loaded.
|
||||
shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
|
||||
seed (int, optional): Random worker seed for sampling, defaults to 1024.
|
||||
add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True.
|
||||
drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
|
||||
is not divisible by the batch size. If False and the size of dataset is not divisible by
|
||||
the batch size, then the last batch will be smaller, defaults to False.
|
||||
pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
|
||||
num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
|
||||
kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
|
||||
`DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_.
|
||||
|
||||
Returns:
|
||||
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
|
||||
"""
|
||||
_kwargs = kwargs.copy()
|
||||
|
||||
if add_sampler and gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1:
|
||||
sampler = DataParallelSampler(dataset, shuffle=shuffle)
|
||||
else:
|
||||
sampler = None
|
||||
|
||||
# Deterministic dataloader
|
||||
def seed_worker(worker_id):
|
||||
worker_seed = seed
|
||||
np.random.seed(worker_seed)
|
||||
torch.manual_seed(worker_seed)
|
||||
random.seed(worker_seed)
|
||||
|
||||
if sampler is None:
|
||||
return DataLoader(dataset,
|
||||
worker_init_fn=seed_worker,
|
||||
shuffle=shuffle,
|
||||
drop_last=drop_last,
|
||||
pin_memory=pin_memory,
|
||||
num_workers=num_workers,
|
||||
**_kwargs)
|
||||
else:
|
||||
return DataLoader(dataset,
|
||||
sampler=sampler,
|
||||
worker_init_fn=seed_worker,
|
||||
drop_last=drop_last,
|
||||
pin_memory=pin_memory,
|
||||
num_workers=num_workers,
|
||||
**_kwargs)
|
178
colossalai/legacy/utils/memory.py
Normal file
178
colossalai/legacy/utils/memory.py
Normal file
@@ -0,0 +1,178 @@
|
||||
import gc
|
||||
from collections import namedtuple
|
||||
|
||||
import psutil
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from packaging import version
|
||||
|
||||
from colossalai.legacy.core import global_context as gpc
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
_GLOBAL_CUDA_MEM_FRACTION = 1.0
|
||||
_GLOBAL_CPU_MEM_CAPACITY = -1
|
||||
|
||||
|
||||
def _bytes_to_MB(val, decimal=2):
|
||||
"""A byte-to-Megabyte converter, default using binary notation.
|
||||
|
||||
:param val: X bytes to convert
|
||||
:return: X' MB
|
||||
"""
|
||||
return round(val / (1024 * 1024), decimal)
|
||||
|
||||
|
||||
# copy from PatrickStar
|
||||
def _get_cpu_memory_info():
|
||||
ps_mem_info = namedtuple("ps_mem_info", ["total", "free", "cached", "buffers", "used"])
|
||||
try:
|
||||
# psutil reads the memory info from /proc/memory_info,
|
||||
# which results in returning the host memory instead of
|
||||
# that of container.
|
||||
# Here we try to read the container memory with method in:
|
||||
# https://stackoverflow.com/a/46213331/5163915
|
||||
mems = {}
|
||||
with open("/sys/fs/cgroup/memory/memory.meminfo", "rb") as f:
|
||||
for line in f:
|
||||
fields = line.split()
|
||||
mems[fields[0]] = int(fields[1]) * 1024
|
||||
total = mems[b"MemTotal:"]
|
||||
free = mems[b"MemFree:"]
|
||||
cached = mems[b"Cached:"]
|
||||
buffers = mems[b"Buffers:"]
|
||||
used = total - free - cached - buffers
|
||||
if used < 0:
|
||||
used = total - free
|
||||
mem_info = ps_mem_info(total=total, free=free, cached=cached, buffers=buffers, used=used)
|
||||
except FileNotFoundError:
|
||||
mems = psutil.virtual_memory()
|
||||
mem_info = ps_mem_info(
|
||||
total=mems.total,
|
||||
free=mems.free,
|
||||
cached=mems.cached,
|
||||
buffers=mems.buffers,
|
||||
used=mems.used,
|
||||
)
|
||||
return mem_info
|
||||
|
||||
|
||||
def report_memory_usage(message, logger=None, report_cpu=False):
|
||||
"""Calculate and print RAM usage (in GB)
|
||||
|
||||
Args:
|
||||
message (str): A prefix message to add in the log.
|
||||
logger (:class:`colossalai.logging.DistributedLogger`): The logger used to record memory information.
|
||||
report_cpu (bool, optional): Whether to report CPU memory.
|
||||
|
||||
Raises:
|
||||
EnvironmentError: Raise error if no distributed environment has been initialized.
|
||||
"""
|
||||
if not dist.is_initialized():
|
||||
raise EnvironmentError("No distributed environment is initialized")
|
||||
|
||||
gpu_allocated = _bytes_to_MB(torch.cuda.memory_allocated())
|
||||
gpu_max_allocated = _bytes_to_MB(torch.cuda.max_memory_allocated())
|
||||
gpu_cached = _bytes_to_MB(torch.cuda.memory_reserved())
|
||||
gpu_max_cached = _bytes_to_MB(torch.cuda.max_memory_reserved())
|
||||
|
||||
full_log = f"{message}: GPU: allocated {gpu_allocated} MB, max allocated {gpu_max_allocated} MB, " \
|
||||
+ f"cached: {gpu_cached} MB, max cached: {gpu_max_cached} MB"
|
||||
|
||||
if report_cpu:
|
||||
# python doesn't do real-time garbage collection so do it explicitly to get the correct RAM reports
|
||||
gc.collect()
|
||||
vm_stats = psutil.virtual_memory()
|
||||
vm_used = _bytes_to_MB(vm_stats.total - vm_stats.available)
|
||||
full_log += f", CPU Virtual Memory: used = {vm_used} MB, percent = {vm_stats.percent}%"
|
||||
|
||||
if logger is None:
|
||||
logger = get_dist_logger()
|
||||
logger.info(full_log)
|
||||
|
||||
# get the peak memory to report correct data, so reset the counter for the next call
|
||||
if hasattr(torch.cuda, "reset_peak_memory_stats"): # pytorch 1.4+
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
|
||||
def colo_device_memory_capacity(device: torch.device) -> int:
|
||||
"""
|
||||
Get the capacity of the memory of the device
|
||||
|
||||
Args:
|
||||
device (torch.device): a device
|
||||
|
||||
Returns:
|
||||
int: size in byte
|
||||
"""
|
||||
assert isinstance(device, torch.device)
|
||||
if device.type == 'cpu':
|
||||
# In the context of 1-CPU-N-GPU, the memory capacity of the current process is 1/N overall CPU memory.
|
||||
return colo_get_cpu_memory_capacity() / gpc.num_processes_on_current_node
|
||||
if device.type == 'cuda':
|
||||
return torch.cuda.get_device_properties(get_current_device()).total_memory * _GLOBAL_CUDA_MEM_FRACTION
|
||||
|
||||
|
||||
def colo_device_memory_used(device: torch.device) -> int:
|
||||
"""
|
||||
Get the device memory on device belonging to the current process.
|
||||
|
||||
Args:
|
||||
device (torch.device): a device
|
||||
|
||||
Returns:
|
||||
int: memory size in bytes
|
||||
"""
|
||||
if device.type == 'cpu':
|
||||
mem_info = _get_cpu_memory_info()
|
||||
# In the context of 1-CPU-N-GPU, the memory usage of the current process is 1/N CPU memory used.
|
||||
# Each process consumes the same amount of memory.
|
||||
ret = mem_info.used / gpc.num_processes_on_current_node
|
||||
return ret
|
||||
elif device.type == 'cuda':
|
||||
ret: int = torch.cuda.memory_allocated(device)
|
||||
# get the peak memory to report correct data, so reset the counter for the next call
|
||||
if hasattr(torch.cuda, "reset_peak_memory_stats"): # pytorch 1.4+
|
||||
torch.cuda.reset_peak_memory_stats(device)
|
||||
return ret
|
||||
|
||||
|
||||
def colo_set_process_memory_fraction(ratio: float) -> None:
|
||||
"""colo_set_process_memory_fraction
|
||||
|
||||
set how much cuda memory used on the gpu belonging to the current process.
|
||||
|
||||
Args:
|
||||
ratio (float): a ratio between 0. ~ 1.
|
||||
"""
|
||||
if version.parse(torch.__version__) < version.parse('1.8'):
|
||||
logger = get_dist_logger('colo_set_process_memory_fraction')
|
||||
logger.warning('colo_set_process_memory_fraction failed because torch version is less than 1.8')
|
||||
return
|
||||
global _GLOBAL_CUDA_MEM_FRACTION
|
||||
_GLOBAL_CUDA_MEM_FRACTION = ratio
|
||||
torch.cuda.set_per_process_memory_fraction(_GLOBAL_CUDA_MEM_FRACTION, get_current_device())
|
||||
|
||||
|
||||
def colo_set_cpu_memory_capacity(size: int) -> None:
|
||||
global _GLOBAL_CPU_MEM_CAPACITY
|
||||
mem_info = _get_cpu_memory_info()
|
||||
total_size = mem_info.total
|
||||
if size <= total_size:
|
||||
_GLOBAL_CPU_MEM_CAPACITY = size
|
||||
else:
|
||||
_GLOBAL_CPU_MEM_CAPACITY = total_size
|
||||
|
||||
|
||||
def colo_get_cpu_memory_capacity() -> int:
|
||||
"""
|
||||
Get the cpu memory capacity. We may not use all of it.
|
||||
Returns:
|
||||
int: _description_
|
||||
"""
|
||||
global _GLOBAL_CPU_MEM_CAPACITY
|
||||
if _GLOBAL_CPU_MEM_CAPACITY == -1:
|
||||
mem_info = _get_cpu_memory_info()
|
||||
return mem_info.total
|
||||
else:
|
||||
return _GLOBAL_CPU_MEM_CAPACITY
|
2
colossalai/legacy/utils/profiler/__init__.py
Normal file
2
colossalai/legacy/utils/profiler/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .legacy import *
|
||||
from .profiler import profile
|
20
colossalai/legacy/utils/profiler/extention.py
Normal file
20
colossalai/legacy/utils/profiler/extention.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class ProfilerExtension(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def prepare_trace(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def start_trace(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def stop_trace(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def extend_chrome_trace(self, trace: dict) -> dict:
|
||||
pass
|
6
colossalai/legacy/utils/profiler/legacy/__init__.py
Normal file
6
colossalai/legacy/utils/profiler/legacy/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from .comm_profiler import CommProfiler
|
||||
from .mem_profiler import MemProfiler
|
||||
from .pcie_profiler import PcieProfiler
|
||||
from .prof_utils import BaseProfiler, ProfilerContext
|
||||
|
||||
__all__ = ['BaseProfiler', 'CommProfiler', 'PcieProfiler', 'MemProfiler', 'ProfilerContext']
|
311
colossalai/legacy/utils/profiler/legacy/comm_profiler.py
Normal file
311
colossalai/legacy/utils/profiler/legacy/comm_profiler.py
Normal file
@@ -0,0 +1,311 @@
|
||||
import inspect
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.autograd.profiler import profile
|
||||
from torch.distributed import ReduceOp
|
||||
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
from .prof_utils import BaseProfiler, _format_bandwidth, _format_memory, _format_time
|
||||
|
||||
|
||||
def _get_code_location(depth: int):
|
||||
ret = []
|
||||
length = min(len(inspect.stack()), depth + 1)
|
||||
for i in range(3, length):
|
||||
upper_frame = inspect.stack()[i]
|
||||
function_name = inspect.stack()[i - 1].function
|
||||
ret.append(upper_frame.filename)
|
||||
ret.append('(')
|
||||
ret.append(str(upper_frame.lineno))
|
||||
ret.append('): ')
|
||||
ret.append(function_name)
|
||||
if i != length - 1:
|
||||
ret.append('\n')
|
||||
|
||||
return ''.join(ret)
|
||||
|
||||
|
||||
torch_all_reduce = dist.all_reduce
|
||||
torch_all_gather = dist.all_gather
|
||||
torch_reduce_scatter = dist.reduce_scatter
|
||||
torch_broadcast = dist.broadcast
|
||||
torch_reduce = dist.reduce
|
||||
|
||||
|
||||
class CommEvent(object):
|
||||
"""Communication Event. Used for communication time and communication
|
||||
volume recording.
|
||||
"""
|
||||
|
||||
def __init__(self, count: int = 0, comm_vol: float = 0., cuda_time: int = 0):
|
||||
self.self_count = count
|
||||
self.self_comm_vol = comm_vol
|
||||
self.self_cuda_time = cuda_time
|
||||
|
||||
def add(self, rhs):
|
||||
self.self_count += rhs.self_count
|
||||
self.self_comm_vol += rhs.self_comm_vol
|
||||
self.self_cuda_time += rhs.self_cuda_time
|
||||
|
||||
|
||||
class CommProfiler(BaseProfiler):
|
||||
"""Communication profiler. Records all communication events.
|
||||
"""
|
||||
|
||||
def __init__(self, depth: int = 0, total_count: int = 0, total_comm_vol: float = 0, total_cuda_time: int = 0):
|
||||
super().__init__(profiler_name="Collective_Communication", priority=0)
|
||||
self.depth = 3 + depth
|
||||
self.total_count = total_count
|
||||
self.total_comm_vol = total_comm_vol
|
||||
self.total_cuda_time = total_cuda_time
|
||||
|
||||
self.ops_record = dict()
|
||||
self.profiler = None
|
||||
self.pending_op = None
|
||||
self.pending_metadata = None
|
||||
self.warn_flag = False
|
||||
|
||||
def reset(self):
|
||||
self.total_count = 0
|
||||
self.total_comm_vol = 0
|
||||
self.total_cuda_time = 0
|
||||
|
||||
self.ops_record = dict()
|
||||
self.profiler = None
|
||||
self.pending_op = None
|
||||
self.pending_metadata = None
|
||||
self.warn_flag = False
|
||||
|
||||
def enable(self):
|
||||
dist.all_reduce = partial(all_reduce, profiler=self)
|
||||
dist.all_gather = partial(all_gather, profiler=self)
|
||||
dist.reduce_scatter = partial(reduce_scatter, profiler=self)
|
||||
dist.broadcast = partial(broadcast, profiler=self)
|
||||
dist.reduce = partial(reduce, profiler=self)
|
||||
|
||||
def disable(self):
|
||||
dist.all_reduce = torch_all_reduce
|
||||
dist.all_gather = torch_all_gather
|
||||
dist.reduce_scatter = torch_reduce_scatter
|
||||
dist.broadcast = torch_broadcast
|
||||
dist.reduce = torch_reduce
|
||||
|
||||
def to_tensorboard(self, writer):
|
||||
writer.add_text(tag="Collective Communication", text_string=self.result_str("\n\n"))
|
||||
|
||||
def to_file(self, filename: Path):
|
||||
with open(filename, "w") as f:
|
||||
f.write(self.result_str())
|
||||
|
||||
def show(self):
|
||||
print(self.result_str())
|
||||
|
||||
def result_str(self, sep: str = "\n"):
|
||||
res = []
|
||||
|
||||
def append(s: str = None):
|
||||
if s is not None:
|
||||
res.append(s)
|
||||
res.append(sep)
|
||||
|
||||
if self.warn_flag:
|
||||
append("Warning: there exists multiple communication operations in the same time. As a result, "
|
||||
"the profiling result is not accurate.")
|
||||
|
||||
if self.total_cuda_time == 0:
|
||||
return "No collective communication has been called yet!"
|
||||
|
||||
append("Collective communication profiling result:")
|
||||
append("total cuda time: {}".format(_format_time(self.total_cuda_time)))
|
||||
append("average bandwidth: {}".format(_format_bandwidth(self.total_comm_vol, self.total_cuda_time)))
|
||||
append("total number of calls: {}".format(self.total_count))
|
||||
append("All events:")
|
||||
|
||||
separation = '-' * 74
|
||||
row_format = '{:^10}' + '{:^12}' * 2 + '{:^16}' + '{:^12}' * 2
|
||||
|
||||
append(separation)
|
||||
append(row_format.format('Location', 'GPU time', 'Percentage', 'Comm volume', 'Bandwidth', 'Num of calls'))
|
||||
append(separation)
|
||||
|
||||
show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].self_cuda_time)
|
||||
for location, event in show_list:
|
||||
append(location)
|
||||
append(
|
||||
row_format.format('', _format_time(event.self_cuda_time),
|
||||
'{:.1f}%'.format(event.self_cuda_time / self.total_cuda_time * 100.0),
|
||||
_format_memory(event.self_comm_vol),
|
||||
_format_bandwidth(event.self_comm_vol, event.self_cuda_time), event.self_count))
|
||||
append()
|
||||
|
||||
return ''.join(res)
|
||||
|
||||
@property
|
||||
def has_aync_op(self):
|
||||
return self.pending_op is not None
|
||||
|
||||
def activate_profiler(self, kn: str, vol: float):
|
||||
self.pending_metadata = (kn, _get_code_location(self.depth), vol)
|
||||
self.profiler = profile(enabled=True, use_cuda=True, use_cpu=True, use_kineto=True)
|
||||
self.profiler.__enter__()
|
||||
|
||||
def close_profiler(self, group=None):
|
||||
assert self.profiler is not None, "There is no running dist op"
|
||||
kernel_name, code_location, vol = self.pending_metadata
|
||||
self.profiler.__exit__(None, None, None)
|
||||
|
||||
if self.profiler.enabled and dist.get_world_size(group) > 1:
|
||||
assert_flag = 0
|
||||
current_comm_event = None
|
||||
events = self.profiler.function_events
|
||||
for event in events:
|
||||
if kernel_name in event.name:
|
||||
assert assert_flag == 0, "Multiple dist ops has been called "
|
||||
current_comm_event = CommEvent(1, vol, event.self_cuda_time_total)
|
||||
assert_flag += 1
|
||||
|
||||
assert current_comm_event is not None, "dist op has not been found"
|
||||
|
||||
buffer = torch.tensor([current_comm_event.self_cuda_time], device=get_current_device())
|
||||
torch_all_reduce(buffer, op=ReduceOp.MIN, group=group)
|
||||
current_comm_event.self_cuda_time = buffer.item()
|
||||
|
||||
self.total_count += current_comm_event.self_count
|
||||
self.total_comm_vol += current_comm_event.self_comm_vol
|
||||
self.total_cuda_time += current_comm_event.self_cuda_time
|
||||
if code_location in self.ops_record:
|
||||
self.ops_record[code_location].add(current_comm_event)
|
||||
else:
|
||||
self.ops_record[code_location] = current_comm_event
|
||||
|
||||
self.profiler = None
|
||||
self.pending_op = None
|
||||
self.pending_metadata = None
|
||||
|
||||
def wait_async_op(self):
|
||||
if self.pending_op is not None:
|
||||
op = self.pending_op
|
||||
op.wait()
|
||||
self.close_profiler()
|
||||
|
||||
|
||||
class CommHandler(object):
|
||||
"""Communication handler. A dummy handler to wait aync operations.
|
||||
"""
|
||||
|
||||
def __init__(self, profiler: CommProfiler):
|
||||
super().__init__()
|
||||
self.prof = profiler
|
||||
|
||||
def wait(self):
|
||||
self.prof.wait_async_op()
|
||||
|
||||
|
||||
def async_check(profiler: CommProfiler):
|
||||
if profiler.pending_op is not None:
|
||||
profiler.warn_flag = True
|
||||
profiler.wait_async_op()
|
||||
|
||||
|
||||
def all_reduce(tensor: torch.Tensor,
|
||||
op: ReduceOp = ReduceOp.SUM,
|
||||
group=None,
|
||||
async_op: bool = False,
|
||||
profiler: CommProfiler = None) -> Optional[CommHandler]:
|
||||
async_check(profiler)
|
||||
|
||||
comm_size = dist.get_world_size(group)
|
||||
correction = 2 * (comm_size - 1) / comm_size
|
||||
comm_vol = correction * tensor.element_size() * tensor.numel()
|
||||
profiler.activate_profiler("ncclKernel_AllReduce_", comm_vol)
|
||||
profiler.pending_op = torch_all_reduce(tensor, op, group, async_op)
|
||||
|
||||
if async_op:
|
||||
return CommHandler(profiler)
|
||||
|
||||
profiler.close_profiler(group)
|
||||
|
||||
|
||||
def reduce_scatter(output: torch.Tensor,
|
||||
input_list: List[torch.Tensor],
|
||||
op: ReduceOp = ReduceOp.SUM,
|
||||
group=None,
|
||||
async_op: bool = False,
|
||||
profiler: CommProfiler = None) -> Optional[CommHandler]:
|
||||
async_check(profiler)
|
||||
|
||||
comm_size = dist.get_world_size(group)
|
||||
correction = (comm_size - 1) / comm_size
|
||||
comm_vol = 0
|
||||
for tensor in input_list:
|
||||
comm_vol += tensor.element_size() * tensor.numel()
|
||||
comm_vol *= correction
|
||||
profiler.activate_profiler("ncclKernel_ReduceScatter_", comm_vol)
|
||||
profiler.pending_op = torch_reduce_scatter(output, input_list, op, group, async_op)
|
||||
|
||||
if async_op:
|
||||
return CommHandler(profiler)
|
||||
|
||||
profiler.close_profiler(group)
|
||||
|
||||
|
||||
def all_gather(tensor_list: List[torch.Tensor],
|
||||
tensor: torch.Tensor,
|
||||
group=None,
|
||||
async_op: bool = False,
|
||||
profiler: CommProfiler = None) -> Optional[CommHandler]:
|
||||
async_check(profiler)
|
||||
|
||||
comm_size = dist.get_world_size(group)
|
||||
correction = (comm_size - 1) / comm_size
|
||||
comm_vol = 0
|
||||
for ten in tensor_list:
|
||||
comm_vol += ten.element_size() * ten.numel()
|
||||
comm_vol *= correction
|
||||
profiler.activate_profiler("ncclKernel_AllGather_", comm_vol)
|
||||
profiler.pending_op = torch_all_gather(tensor_list, tensor, group, async_op)
|
||||
|
||||
if async_op:
|
||||
return CommHandler(profiler)
|
||||
|
||||
profiler.close_profiler(group)
|
||||
|
||||
|
||||
def broadcast(tensor: torch.Tensor,
|
||||
src: int,
|
||||
group=None,
|
||||
async_op: bool = False,
|
||||
profiler: CommProfiler = None) -> Optional[CommHandler]:
|
||||
async_check(profiler)
|
||||
|
||||
comm_vol = 1.0 * tensor.element_size() * tensor.numel()
|
||||
profiler.activate_profiler("ncclKernel_Broadcast_", comm_vol)
|
||||
profiler.pending_op = torch_broadcast(tensor, src, group, async_op)
|
||||
|
||||
if async_op:
|
||||
return CommHandler(profiler)
|
||||
|
||||
profiler.close_profiler(group)
|
||||
|
||||
|
||||
def reduce(tensor: torch.Tensor,
|
||||
dst: int,
|
||||
op: ReduceOp = ReduceOp.SUM,
|
||||
group=None,
|
||||
async_op: bool = False,
|
||||
profiler: CommProfiler = None) -> Optional[CommHandler]:
|
||||
async_check(profiler)
|
||||
|
||||
comm_vol = 1.0 * tensor.element_size() * tensor.numel()
|
||||
profiler.activate_profiler("ncclKernel_Reduce_", comm_vol)
|
||||
profiler.pending_op = torch_reduce(tensor, dst, op, group, async_op)
|
||||
|
||||
if async_op:
|
||||
return CommHandler(profiler)
|
||||
|
||||
profiler.close_profiler(group)
|
150
colossalai/legacy/utils/profiler/legacy/pcie_profiler.py
Normal file
150
colossalai/legacy/utils/profiler/legacy/pcie_profiler.py
Normal file
@@ -0,0 +1,150 @@
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from torch.autograd.profiler import profile
|
||||
|
||||
from .prof_utils import BaseProfiler, _format_bandwidth, _format_memory, _format_time
|
||||
|
||||
|
||||
def _get_size(dtype: str):
|
||||
if dtype == "fp16":
|
||||
return 2
|
||||
elif dtype == "fp32":
|
||||
return 4
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def _get_numel(my_list: List[int]) -> int:
|
||||
from functools import reduce
|
||||
from operator import mul
|
||||
return reduce(mul, my_list)
|
||||
|
||||
|
||||
def _reduce_location(locations: List[str]) -> str:
|
||||
ret = []
|
||||
for lo in locations:
|
||||
ret.append(lo)
|
||||
ret.append("\n")
|
||||
ret = ret[:-1]
|
||||
return ''.join(ret)
|
||||
|
||||
|
||||
class PcieEvent(object):
|
||||
"""Pcie Event.
|
||||
"""
|
||||
|
||||
def __init__(self, count: int = 0, pcie_vol: int = 0, cuda_time: int = 0):
|
||||
self.count = count
|
||||
self.pcie_vol = pcie_vol
|
||||
self.cuda_time = cuda_time
|
||||
|
||||
def add(self, rhs):
|
||||
self.count += rhs.count
|
||||
self.pcie_vol += rhs.pcie_vol
|
||||
self.cuda_time += rhs.cuda_time
|
||||
|
||||
|
||||
class PcieProfiler(BaseProfiler):
|
||||
"""Pcie profiler. Records all data transmission between CPU and GPU.
|
||||
|
||||
TODO: Merge pcie profiler into communication profiler
|
||||
"""
|
||||
|
||||
def __init__(self, dtype: str = "fp32", depth: int = 1):
|
||||
super().__init__(profiler_name="Pcie", priority=10)
|
||||
self.depth = depth
|
||||
self.data_size = _get_size(dtype)
|
||||
self.h2d_count = 0
|
||||
self.h2d_time = 0
|
||||
self.d2h_count = 0
|
||||
self.d2h_time = 0
|
||||
|
||||
self.ops_record = dict()
|
||||
self.profiler = None
|
||||
|
||||
def reset(self):
|
||||
self.h2d_count = 0
|
||||
self.h2d_time = 0
|
||||
self.d2h_count = 0
|
||||
self.d2h_time = 0
|
||||
|
||||
self.ops_record = dict()
|
||||
self.profiler = None
|
||||
|
||||
def enable(self):
|
||||
self.profiler = profile(enabled=True,
|
||||
use_cuda=True,
|
||||
use_cpu=True,
|
||||
use_kineto=True,
|
||||
record_shapes=True,
|
||||
with_stack=True)
|
||||
self.profiler.__enter__()
|
||||
|
||||
def disable(self):
|
||||
self.profiler.__exit__(None, None, None)
|
||||
|
||||
if self.profiler.enabled:
|
||||
events = self.profiler.function_events
|
||||
for event in events:
|
||||
if event.name == "aten::copy_":
|
||||
t_shape = event.input_shapes[0]
|
||||
if len(t_shape) == 0 or event.cuda_time_total == 0 or len(event.stack) == 0:
|
||||
continue
|
||||
current_comm_event = PcieEvent(1, self.data_size * _get_numel(t_shape), event.cuda_time_total)
|
||||
code_location = _reduce_location(event.stack[:self.depth])
|
||||
if code_location in self.ops_record:
|
||||
self.ops_record[code_location].add(current_comm_event)
|
||||
else:
|
||||
self.ops_record[code_location] = current_comm_event
|
||||
elif 'Memcpy HtoD' in event.name:
|
||||
self.h2d_count += 1
|
||||
self.h2d_time += event.cuda_time_total
|
||||
elif 'Memcpy DtoH' in event.name:
|
||||
self.d2h_count += 1
|
||||
self.d2h_time += event.cuda_time_total
|
||||
|
||||
self.profiler = None
|
||||
|
||||
def to_tensorboard(self, writer):
|
||||
writer.add_text(tag="Data Transmission", text_string=self.result_str("\n\n"))
|
||||
|
||||
def to_file(self, filename: Path):
|
||||
with open(filename, "w") as f:
|
||||
f.write(self.result_str())
|
||||
|
||||
def show(self):
|
||||
print(self.result_str())
|
||||
|
||||
def result_str(self, sep: str = "\n"):
|
||||
res = []
|
||||
|
||||
def append(s: str = None):
|
||||
if s is not None:
|
||||
res.append(s)
|
||||
res.append(sep)
|
||||
|
||||
append("Pcie profiling result:")
|
||||
append("time of data transmission (CPU -> GPU): {}".format(_format_time(self.h2d_time)))
|
||||
append("number of transmission (CPU -> GPU): {}".format(self.h2d_count))
|
||||
append("time of data transmission (GPU -> CPU): {}".format(_format_time(self.d2h_time)))
|
||||
append("number of transmission (GPU -> CPU): {}".format(self.d2h_count))
|
||||
|
||||
append("Possible data transmission events in PCIE:")
|
||||
|
||||
separation = '-' * 62
|
||||
row_format = '{:^10}' + '{:^12}' + '{:^16}' + '{:^12}' * 2
|
||||
|
||||
append(separation)
|
||||
append(row_format.format('Location', 'GPU time', 'Trans volume', 'Bandwidth', 'Num of calls'))
|
||||
append(separation)
|
||||
|
||||
show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].cuda_time)
|
||||
for location, event in show_list:
|
||||
append(location)
|
||||
append(
|
||||
row_format.format('', _format_time(event.cuda_time), _format_memory(event.pcie_vol),
|
||||
_format_bandwidth(event.pcie_vol, event.cuda_time), event.count))
|
||||
append()
|
||||
|
||||
return ''.join(res)
|
132
colossalai/legacy/utils/profiler/legacy/prof_utils.py
Normal file
132
colossalai/legacy/utils/profiler/legacy/prof_utils.py
Normal file
@@ -0,0 +1,132 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import List, Union
|
||||
|
||||
from colossalai.legacy.core import global_context as gpc
|
||||
|
||||
|
||||
# copied from high version pytorch to support low version
|
||||
def _format_time(time_us):
|
||||
"""Defines how to format time in FunctionEvent"""
|
||||
US_IN_SECOND = 1000.0 * 1000.0
|
||||
US_IN_MS = 1000.0
|
||||
if time_us >= US_IN_SECOND:
|
||||
return '{:.3f}s'.format(time_us / US_IN_SECOND)
|
||||
if time_us >= US_IN_MS:
|
||||
return '{:.3f}ms'.format(time_us / US_IN_MS)
|
||||
return '{:.3f}us'.format(time_us)
|
||||
|
||||
|
||||
# copied from high version pytorch to support low version
|
||||
def _format_memory(nbytes):
|
||||
"""Returns a formatted memory size string"""
|
||||
KB = 1024
|
||||
MB = 1024 * KB
|
||||
GB = 1024 * MB
|
||||
if (abs(nbytes) >= GB):
|
||||
return '{:.2f} GB'.format(nbytes * 1.0 / GB)
|
||||
elif (abs(nbytes) >= MB):
|
||||
return '{:.2f} MB'.format(nbytes * 1.0 / MB)
|
||||
elif (abs(nbytes) >= KB):
|
||||
return '{:.2f} KB'.format(nbytes * 1.0 / KB)
|
||||
else:
|
||||
return str(nbytes) + ' B'
|
||||
|
||||
|
||||
def _format_bandwidth(volume: float or int, time_us: int):
|
||||
sec_div_mb = (1000.0 / 1024.0)**2
|
||||
mb_per_sec = volume / time_us * sec_div_mb
|
||||
|
||||
if mb_per_sec >= 1024.0:
|
||||
return '{:.3f} GB/s'.format(mb_per_sec / 1024.0)
|
||||
else:
|
||||
return '{:.3f} MB/s'.format(mb_per_sec)
|
||||
|
||||
|
||||
class BaseProfiler(ABC):
|
||||
|
||||
def __init__(self, profiler_name: str, priority: int):
|
||||
self.name = profiler_name
|
||||
self.priority = priority
|
||||
|
||||
@abstractmethod
|
||||
def enable(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def disable(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def to_tensorboard(self, writer):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def to_file(self, filename: Path):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def show(self):
|
||||
pass
|
||||
|
||||
|
||||
class ProfilerContext(object):
|
||||
"""Profiler context manager
|
||||
|
||||
Usage::
|
||||
|
||||
world_size = 4
|
||||
inputs = torch.randn(10, 10, dtype=torch.float32, device=get_current_device())
|
||||
outputs = torch.empty(world_size, 10, 10, dtype=torch.float32, device=get_current_device())
|
||||
outputs_list = list(torch.chunk(outputs, chunks=world_size, dim=0))
|
||||
|
||||
cc_prof = CommProfiler()
|
||||
|
||||
with ProfilerContext([cc_prof]) as prof:
|
||||
op = dist.all_reduce(inputs, async_op=True)
|
||||
dist.all_gather(outputs_list, inputs)
|
||||
op.wait()
|
||||
dist.reduce_scatter(inputs, outputs_list)
|
||||
dist.broadcast(inputs, 0)
|
||||
dist.reduce(inputs, 0)
|
||||
|
||||
prof.show()
|
||||
"""
|
||||
|
||||
def __init__(self, profilers: List[BaseProfiler] = None, enable: bool = True):
|
||||
self.enable = enable
|
||||
self.profilers = sorted(profilers, key=lambda prof: prof.priority)
|
||||
|
||||
def __enter__(self):
|
||||
if self.enable:
|
||||
for prof in self.profilers:
|
||||
prof.enable()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
if self.enable:
|
||||
for prof in self.profilers:
|
||||
prof.disable()
|
||||
|
||||
def to_tensorboard(self, writer):
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
assert isinstance(writer, SummaryWriter), \
|
||||
f'torch.utils.tensorboard.SummaryWriter is required, but found {type(writer)}.'
|
||||
|
||||
for prof in self.profilers:
|
||||
prof.to_tensorboard(writer)
|
||||
|
||||
def to_file(self, log_dir: Union[str, Path]):
|
||||
if isinstance(log_dir, str):
|
||||
log_dir = Path(log_dir)
|
||||
|
||||
if not log_dir.exists():
|
||||
log_dir.mkdir(parents=True, exist_ok=True)
|
||||
for prof in self.profilers:
|
||||
log_file = log_dir.joinpath(f'{prof.name}_rank_{gpc.get_global_rank()}.log')
|
||||
prof.to_file(log_file)
|
||||
|
||||
def show(self):
|
||||
for prof in self.profilers:
|
||||
prof.show()
|
201
colossalai/legacy/utils/profiler/profiler.py
Normal file
201
colossalai/legacy/utils/profiler/profiler.py
Normal file
@@ -0,0 +1,201 @@
|
||||
import gzip
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Any, Callable, Iterable, List, Optional
|
||||
|
||||
from torch.autograd import ProfilerActivity
|
||||
from torch.profiler import profile as torch_profile
|
||||
from torch.profiler.profiler import ProfilerAction
|
||||
|
||||
from colossalai.legacy.engine import Engine
|
||||
from colossalai.legacy.utils.profiler.extention import ProfilerExtension
|
||||
from colossalai.legacy.utils.profiler.stateful_tensor_mem_extention import StatefulTensorMemoryProfilerExtention
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
|
||||
class profile(torch_profile):
|
||||
"""Profiler context manager.
|
||||
|
||||
Args:
|
||||
activities (iterable): list of activity groups (CPU, CUDA) to use in profiling, supported values:
|
||||
``torch.profiler.ProfilerActivity.CPU``, ``torch.profiler.ProfilerActivity.CUDA``.
|
||||
Default value: ProfilerActivity.CPU and (when available) ProfilerActivity.CUDA.
|
||||
schedule (callable): callable that takes step (int) as a single parameter and returns
|
||||
``ProfilerAction`` value that specifies the profiler action to perform at each step.
|
||||
on_trace_ready (callable): callable that is called at each step when ``schedule``
|
||||
returns ``ProfilerAction.RECORD_AND_SAVE`` during the profiling.
|
||||
engine (Optional[Engine], optional): An ``Engine`` instance. Defaults to None.
|
||||
record_shapes (bool): save information about operator's input shapes.
|
||||
profile_memory (bool): track tensor memory allocation/deallocation.
|
||||
with_stack (bool): record source information (file and line number) for the ops.
|
||||
with_flops (bool): use formula to estimate the FLOPs (floating point operations) of specific operators
|
||||
(matrix multiplication and 2D convolution).
|
||||
with_modules (bool): record module hierarchy (including function names)
|
||||
corresponding to the callstack of the op. e.g. If module A's forward call's
|
||||
module B's forward which contains an aten::add op,
|
||||
then aten::add's module hierarchy is A.B
|
||||
Note that this support exist, at the moment, only for TorchScript models
|
||||
and not eager mode models.
|
||||
profile_stateful_tensor_memory (bool): track stateful tensor memory usage. ``engine`` must not be None if you enable this.
|
||||
|
||||
.. note::
|
||||
Use :func:`~torch.profiler.schedule` to generate the callable schedule.
|
||||
Non-default schedules are useful when profiling long training jobs
|
||||
and allow the user to obtain multiple traces at the different iterations
|
||||
of the training process.
|
||||
The default schedule simply records all the events continuously for the
|
||||
duration of the context manager.
|
||||
|
||||
.. note::
|
||||
Use :func:`~torch.profiler.tensorboard_trace_handler` to generate result files for TensorBoard:
|
||||
|
||||
``on_trace_ready=torch.profiler.tensorboard_trace_handler(dir_name)``
|
||||
|
||||
After profiling, result files can be found in the specified directory. Use the command:
|
||||
|
||||
``tensorboard --logdir dir_name``
|
||||
|
||||
to see the results in TensorBoard.
|
||||
For more information, see
|
||||
`PyTorch Profiler TensorBoard Plugin <https://github.com/pytorch/kineto/tree/master/tb_plugin>`__
|
||||
|
||||
.. note::
|
||||
Enabling shape and stack tracing results in additional overhead.
|
||||
When record_shapes=True is specified, profiler will temporarily hold references to the tensors;
|
||||
that may further prevent certain optimizations that depend on the reference count and introduce
|
||||
extra tensor copies.
|
||||
|
||||
Examples:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
with torch.profiler.profile(
|
||||
activities=[
|
||||
torch.profiler.ProfilerActivity.CPU,
|
||||
torch.profiler.ProfilerActivity.CUDA,
|
||||
]
|
||||
) as p:
|
||||
code_to_profile()
|
||||
print(p.key_averages().table(
|
||||
sort_by="self_cuda_time_total", row_limit=-1))
|
||||
|
||||
Using the profiler's ``schedule``, ``on_trace_ready`` and ``step`` functions:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Non-default profiler schedule allows user to turn profiler on and off
|
||||
# on different iterations of the training loop;
|
||||
# trace_handler is called every time a new trace becomes available
|
||||
def trace_handler(prof):
|
||||
print(prof.key_averages().table(
|
||||
sort_by="self_cuda_time_total", row_limit=-1))
|
||||
# prof.export_chrome_trace("/tmp/test_trace_" + str(prof.step_num) + ".json")
|
||||
|
||||
with torch.profiler.profile(
|
||||
activities=[
|
||||
torch.profiler.ProfilerActivity.CPU,
|
||||
torch.profiler.ProfilerActivity.CUDA,
|
||||
],
|
||||
|
||||
# In this example with wait=1, warmup=1, active=2,
|
||||
# profiler will skip the first step/iteration,
|
||||
# start warming up on the second, record
|
||||
# the third and the forth iterations,
|
||||
# after which the trace will become available
|
||||
# and on_trace_ready (when set) is called;
|
||||
# the cycle repeats starting with the next step
|
||||
|
||||
schedule=torch.profiler.schedule(
|
||||
wait=1,
|
||||
warmup=1,
|
||||
active=2),
|
||||
on_trace_ready=trace_handler
|
||||
# on_trace_ready=torch.profiler.tensorboard_trace_handler('./log')
|
||||
# used when outputting for tensorboard
|
||||
) as p:
|
||||
for iter in range(N):
|
||||
code_iteration_to_profile(iter)
|
||||
# send a signal to the profiler that the next iteration has started
|
||||
p.step()
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
activities: Optional[Iterable[ProfilerActivity]] = None,
|
||||
schedule: Optional[Callable[[int], ProfilerAction]] = None,
|
||||
on_trace_ready: Optional[Callable[..., Any]] = None,
|
||||
engine: Optional[Engine] = None,
|
||||
record_shapes: bool = False,
|
||||
profile_memory: bool = False,
|
||||
with_stack: bool = False,
|
||||
with_flops: bool = False,
|
||||
with_modules: bool = False,
|
||||
profile_stateful_tensor_memory: bool = False) -> None:
|
||||
super().__init__(activities=activities,
|
||||
schedule=schedule,
|
||||
on_trace_ready=on_trace_ready,
|
||||
record_shapes=record_shapes,
|
||||
profile_memory=profile_memory,
|
||||
with_stack=with_stack,
|
||||
with_flops=with_flops,
|
||||
with_modules=with_modules)
|
||||
self._logger = get_dist_logger()
|
||||
self.extentions: List[ProfilerExtension] = []
|
||||
if profile_stateful_tensor_memory:
|
||||
if engine is None:
|
||||
self._logger.warning('Ignore "profile_model_data" since engine is None', ranks=[0])
|
||||
else:
|
||||
self.extentions.append(StatefulTensorMemoryProfilerExtention(engine))
|
||||
|
||||
def prepare_trace(self) -> None:
|
||||
if hasattr(super(), 'prepare_trace'):
|
||||
super().prepare_trace()
|
||||
elif hasattr(super(), '_start_warmup'):
|
||||
super()._start_warmup()
|
||||
for ext in self.extentions:
|
||||
ext.prepare_trace()
|
||||
|
||||
def _start_warmup(self):
|
||||
self.prepare_trace()
|
||||
|
||||
def start_trace(self):
|
||||
if hasattr(super(), '_start_trace'):
|
||||
super()._start_trace()
|
||||
elif hasattr(super(), 'start_trace'):
|
||||
super().start_trace()
|
||||
for ext in self.extentions:
|
||||
ext.start_trace()
|
||||
|
||||
def _start_trace(self):
|
||||
self.start_trace()
|
||||
|
||||
def stop_trace(self):
|
||||
if hasattr(super(), '_stop_trace'):
|
||||
super()._stop_trace()
|
||||
elif hasattr(super(), 'stop_trace'):
|
||||
super().stop_trace()
|
||||
for ext in self.extentions:
|
||||
ext.stop_trace()
|
||||
|
||||
def _stop_trace(self):
|
||||
self.stop_trace()
|
||||
|
||||
def export_chrome_trace(self, path: str):
|
||||
"""
|
||||
Exports the collected trace in Chrome JSON format.
|
||||
"""
|
||||
assert self.profiler
|
||||
fp = tempfile.NamedTemporaryFile('w+t', suffix='.json', delete=False)
|
||||
fp.close()
|
||||
retvalue = self.profiler.export_chrome_trace(fp.name)
|
||||
with open(fp.name) as fin:
|
||||
trace = json.load(fin)
|
||||
for ext in self.extentions:
|
||||
trace = ext.extend_chrome_trace(trace)
|
||||
open_func = gzip.open if path.endswith('.gz') else open
|
||||
with open_func(path, 'wt') as fout:
|
||||
json.dump(trace, fout)
|
||||
|
||||
os.remove(fp.name)
|
||||
return retvalue
|
@@ -0,0 +1,135 @@
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.gemini.ophooks import BaseOpHook
|
||||
from colossalai.gemini.stateful_tensor import StatefulTensor
|
||||
from colossalai.legacy.engine import Engine
|
||||
from colossalai.legacy.utils.profiler.extention import ProfilerExtension
|
||||
|
||||
|
||||
class DeviceType(Enum):
|
||||
CPU = 0
|
||||
CUDA = 1
|
||||
|
||||
|
||||
def get_timestamp_us():
|
||||
return int(time.time() * 1e6)
|
||||
|
||||
|
||||
def generic_instant_event(name, pid, tid, timestamp, args):
|
||||
return {'ph': 'i', 's': 't', 'name': name, 'pid': pid, 'tid': tid, 'ts': timestamp, 'args': args}
|
||||
|
||||
|
||||
class StatefulTensorMemoryEvent:
|
||||
EVENT_NAME = '[statefulTensorMemory]'
|
||||
|
||||
def __init__(self, timestamp: int, device_type: DeviceType, bytes_: int) -> None:
|
||||
self.pid = os.getpid()
|
||||
self.tid = threading.get_ident()
|
||||
self.timestamp = timestamp
|
||||
self.device_type = device_type
|
||||
self.device_id = torch.cuda.current_device() if device_type == DeviceType.CUDA else -1
|
||||
self.bytes = bytes_
|
||||
|
||||
def state_dict(self):
|
||||
return generic_instant_event(StatefulTensorMemoryEvent.EVENT_NAME, self.pid, self.tid, self.timestamp, {
|
||||
'Device Type': self.device_type.value,
|
||||
'Device Id': self.device_id,
|
||||
'Bytes': self.bytes
|
||||
})
|
||||
|
||||
|
||||
class StatefulTensorMemoryTracer:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.events: List[StatefulTensorMemoryEvent] = []
|
||||
self._tracing = False
|
||||
|
||||
def sample(self):
|
||||
cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda']
|
||||
cpu_mem = StatefulTensor.GST_MGR.total_mem['cpu']
|
||||
timestamp = get_timestamp_us()
|
||||
if self._tracing:
|
||||
self.events.append(StatefulTensorMemoryEvent(timestamp, DeviceType.CUDA, cuda_mem))
|
||||
self.events.append(StatefulTensorMemoryEvent(timestamp, DeviceType.CPU, cpu_mem))
|
||||
|
||||
def start_trace(self):
|
||||
self.events.clear()
|
||||
self._tracing = True
|
||||
|
||||
def stop_trace(self):
|
||||
self._tracing = False
|
||||
|
||||
def state_dict(self):
|
||||
return [event.state_dict() for event in self.events]
|
||||
|
||||
|
||||
class StatefulTensorMemoryTracerHook(BaseOpHook):
|
||||
|
||||
def __init__(self, tracer: StatefulTensorMemoryTracer):
|
||||
super().__init__()
|
||||
self.tracer = tracer
|
||||
self._enable = False
|
||||
|
||||
def pre_fwd_exec(self, module: torch.nn.Module, *args):
|
||||
if self._enable:
|
||||
self.tracer.sample()
|
||||
|
||||
def post_fwd_exec(self, module: torch.nn.Module, *args):
|
||||
if self._enable:
|
||||
self.tracer.sample()
|
||||
|
||||
def pre_bwd_exec(self, module: torch.nn.Module, input_, output):
|
||||
if self._enable:
|
||||
self.tracer.sample()
|
||||
|
||||
def post_bwd_exec(self, module: torch.nn.Module, input_):
|
||||
if self._enable:
|
||||
self.tracer.sample()
|
||||
|
||||
def post_iter(self):
|
||||
if self._enable:
|
||||
self.tracer.sample()
|
||||
|
||||
def enable(self):
|
||||
self._enable = True
|
||||
|
||||
def disable(self):
|
||||
self._enable = False
|
||||
|
||||
|
||||
class StatefulTensorMemoryProfilerExtention(ProfilerExtension):
|
||||
|
||||
def __init__(self, engine: Engine) -> None:
|
||||
self.engine = engine
|
||||
self.tracer = StatefulTensorMemoryTracer()
|
||||
self.hook = StatefulTensorMemoryTracerHook(self.tracer)
|
||||
self.hook_registered = False
|
||||
|
||||
def prepare_trace(self):
|
||||
self.hook.enable()
|
||||
if not self.hook_registered:
|
||||
self.engine.add_hook(self.hook)
|
||||
self.hook_registered = True
|
||||
|
||||
def start_trace(self):
|
||||
self.prepare_trace()
|
||||
self.tracer.start_trace()
|
||||
|
||||
def stop_trace(self):
|
||||
self.tracer.stop_trace()
|
||||
self.hook.disable()
|
||||
if self.hook_registered:
|
||||
self.engine.remove_hook(self.hook)
|
||||
# remove_hook is not implemented now
|
||||
# FIXME(ver217): uncomment below line when remove_hook is implemented
|
||||
# self.hook_registered = False
|
||||
|
||||
def extend_chrome_trace(self, trace: dict) -> dict:
|
||||
trace['traceEvents'].extend(self.tracer.state_dict())
|
||||
return trace
|
Reference in New Issue
Block a user