[legacy] clean up legacy code (#4743)

* [legacy] remove outdated codes of pipeline (#4692)

* [legacy] remove cli of benchmark and update optim (#4690)

* [legacy] remove cli of benchmark and update optim

* [doc] fix cli doc test

* [legacy] fix engine clip grad norm

* [legacy] remove outdated colo tensor (#4694)

* [legacy] remove outdated colo tensor

* [test] fix test import

* [legacy] move outdated zero to legacy (#4696)

* [legacy] clean up utils (#4700)

* [legacy] clean up utils

* [example] update examples

* [legacy] clean up amp

* [legacy] fix amp module

* [legacy] clean up gpc (#4742)

* [legacy] clean up context

* [legacy] clean core, constants and global vars

* [legacy] refactor initialize

* [example] fix examples ci

* [example] fix examples ci

* [legacy] fix tests

* [example] fix gpt example

* [example] fix examples ci

* [devops] fix ci installation

* [example] fix examples ci
Author: Hongxin Liu
Date: 2023-09-18 16:31:06 +08:00
Committed by: GitHub
Parent: 32e7f99416
Commit: b5f9e37c70
342 changed files with 2919 additions and 4182 deletions

@@ -0,0 +1,53 @@
from .checkpointing import load_checkpoint, save_checkpoint
from .common import (
clip_grad_norm_fp32,
copy_tensor_parallel_attributes,
count_zeros_fp32,
is_dp_rank_0,
is_model_parallel_parameter,
is_no_pp_or_last_stage,
is_tp_rank_0,
is_using_ddp,
is_using_pp,
is_using_sequence,
param_is_not_tensor_parallel_duplicate,
print_rank_0,
switch_virtual_pipeline_parallel_rank,
sync_model_param,
)
from .data_sampler import DataParallelSampler, get_dataloader
from .memory import (
colo_device_memory_capacity,
colo_device_memory_used,
colo_get_cpu_memory_capacity,
colo_set_cpu_memory_capacity,
colo_set_process_memory_fraction,
report_memory_usage,
)
__all__ = [
'DataParallelSampler',
'get_dataloader',
'save_checkpoint',
'load_checkpoint',
'colo_device_memory_capacity',
'colo_device_memory_used',
'colo_get_cpu_memory_capacity',
'colo_set_cpu_memory_capacity',
'colo_set_process_memory_fraction',
'report_memory_usage',
'clip_grad_norm_fp32',
'copy_tensor_parallel_attributes',
'count_zeros_fp32',
'is_dp_rank_0',
'is_model_parallel_parameter',
'is_no_pp_or_last_stage',
'is_tp_rank_0',
'is_using_ddp',
'is_using_pp',
'is_using_sequence',
'param_is_not_tensor_parallel_duplicate',
'print_rank_0',
'switch_virtual_pipeline_parallel_rank',
'sync_model_param',
]

@@ -0,0 +1,259 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import weakref
import torch
from torch.utils.checkpoint import check_backward_validity, detach_variable
from colossalai.legacy.context.random import get_current_mode, get_states, set_mode, set_seed_states, sync_states
from colossalai.utils import get_current_device
def copy_to_device(obj, device):
if torch.is_tensor(obj):
# Notice:
# When in no_grad context, requires_grad is False after movement
ret = obj.to(device).detach()
ret.requires_grad = obj.requires_grad
return ret
elif isinstance(obj, list):
return [copy_to_device(i, device) for i in obj]
elif isinstance(obj, tuple):
return tuple([copy_to_device(v, device) for v in obj])
elif isinstance(obj, dict):
return {k: copy_to_device(v, device) for k, v in obj.items()}
else:
return obj
class CheckpointFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, run_function, activation_offload=False, *args):
check_backward_validity(args)
ctx.run_function = run_function
ctx.activation_offload = activation_offload
ctx.device = get_current_device()
# preserve rng states
ctx.fwd_cpu_rng_state = torch.get_rng_state()
sync_states()
ctx.fwd_seed_states = get_states(copy=True)
ctx.fwd_current_mode = get_current_mode()
if hasattr(torch, 'is_autocast_enabled'):
ctx.had_autocast_in_fwd = torch.is_autocast_enabled()
else:
ctx.had_autocast_in_fwd = False
if activation_offload:
inputs_cuda = copy_to_device(args, ctx.device)
else:
inputs_cuda = args
with torch.no_grad():
outputs = run_function(*inputs_cuda)
# Save non-tensor inputs in ctx, keep a placeholder None for tensors
# to be filled out during the backward.
ctx.inputs = []
ctx.tensor_indices = []
tensor_inputs = []
for i, arg in enumerate(args):
if torch.is_tensor(arg):
if activation_offload:
tensor_inputs.append(copy_to_device(arg, 'cpu'))
else:
tensor_inputs.append(arg)
ctx.tensor_indices.append(i)
ctx.inputs.append(None)
else:
ctx.inputs.append(arg)
if activation_offload:
ctx.tensor_inputs = tensor_inputs
else:
ctx.save_for_backward(*tensor_inputs)
return outputs
@staticmethod
def backward(ctx, *args):
if not torch.autograd._is_checkpoint_valid():
raise RuntimeError("Checkpointing is not compatible with .grad() or when an `inputs` parameter is "
"passed to .backward(). Please use .backward() and do not pass its `inputs` argument.")
# Copy the list to avoid modifying original list.
inputs = list(ctx.inputs)
tensor_indices = ctx.tensor_indices
if ctx.activation_offload:
tensors = ctx.tensor_inputs
else:
tensors = ctx.saved_tensors
# store the current states
bwd_cpu_rng_state = torch.get_rng_state()
sync_states()
bwd_seed_states = get_states(copy=True)
bwd_current_mode = get_current_mode()
# set the states to what it used to be
torch.set_rng_state(ctx.fwd_cpu_rng_state)
for parallel_mode, state in ctx.fwd_seed_states.items():
set_seed_states(parallel_mode, state)
set_mode(ctx.fwd_current_mode)
if ctx.activation_offload:
tensors = copy_to_device(tensors, ctx.device)
# Fill in inputs with appropriate saved tensors.
for i, idx in enumerate(tensor_indices):
inputs[idx] = tensors[i]
detached_inputs = detach_variable(tuple(inputs))
if ctx.had_autocast_in_fwd:
with torch.enable_grad(), torch.cuda.amp.autocast():
outputs = ctx.run_function(*detached_inputs)
else:
with torch.enable_grad():
outputs = ctx.run_function(*detached_inputs)
if isinstance(outputs, torch.Tensor):
outputs = (outputs,)
# recover the rng states
torch.set_rng_state(bwd_cpu_rng_state)
for parallel_mode, state in bwd_seed_states.items():
set_seed_states(parallel_mode, state)
set_mode(bwd_current_mode)
# run backward() with only tensor that requires grad
outputs_with_grad = []
args_with_grad = []
for i in range(len(outputs)):
if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
outputs_with_grad.append(outputs[i])
args_with_grad.append(args[i])
if len(outputs_with_grad) == 0:
raise RuntimeError("none of output has requires_grad=True,"
" this checkpoint() is not necessary")
torch.autograd.backward(outputs_with_grad, args_with_grad)
grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None for inp in detached_inputs)
return (None, None) + grads
def checkpoint(function, activation_offload, *args, use_reentrant: bool = True):
"""Checkpoint the computation while preserve the rng states, modified from Pytorch torch.utils.checkpoint.
Args:
function: Describe the forward pass function. It should know how to handle the input tuples.
activation_offload: The variable to check whether we should offload activation to cpu
args (list): Tuple containing the parameters of the function
use_reentrant: Bool type to check if we need to use_reentrant, if use_reentrant=False, there
might be more flexibility for user to define there checkpoint function
Returns:
Output of running function with provided args.
"""
if use_reentrant:
return CheckpointFunction.apply(function, activation_offload, *args)
else:
return _checkpoint_without_reentrant(
function,
activation_offload,
*args,
)
def _checkpoint_without_reentrant(function, activation_offload=False, *args):
# store rng_state
fwd_cpu_state = torch.get_rng_state()
sync_states()
fwd_seed_states = get_states(copy=True)
fwd_current_mode = get_current_mode()
# check if use autocast
if hasattr(torch, 'is_autocast_enabled'):
has_autocast_in_fwd = torch.is_autocast_enabled()
else:
has_autocast_in_fwd = False
# using WeakKeyDictionary to store all the activation the first time we call unpack
storage: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
weak_holder_list = []
# class for weakref.ref
class Holder():
pass
# return a Holder object for later unpack process
def pack(x):
res = Holder()
weak_holder_list.append(weakref.ref(res))
return res
# unpack hook
def unpack(x):
unpack_counter = 0
# re-compute all the activation inside the function when we first call unpack
if len(storage) == 0:
def inner_pack(inner):
nonlocal unpack_counter
unpack_counter += 1
# If the holder went out of scope, the SavedVariable is dead and so
# the value will never be read from the storage. Skip filling it.
if weak_holder_list[unpack_counter - 1]() is None:
return
# Use detach here to ensure we don't keep the temporary autograd
# graph created during the second forward
storage[weak_holder_list[unpack_counter - 1]()] = inner.detach()
return
def inner_unpack(packed):
raise RuntimeError("You are calling backwards on a tensor that is never exposed. Please open an issue.")
# restore rng state
torch.set_rng_state(fwd_cpu_state)
for parallel_mode, state in fwd_seed_states.items():
set_seed_states(parallel_mode, state)
set_mode(fwd_current_mode)
# reload arg into device if needed
if activation_offload:
for arg in args:
if torch.is_tensor(arg):
arg = arg.to(device=device)
# rerun forward, the inner_pack will store all the activations in storage
if has_autocast_in_fwd:
with torch.enable_grad(), \
torch.cuda.amp.autocast(), \
torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack):
_unused = function(*args)
else:
with torch.enable_grad(), \
torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack):
_unused = function(*args)
if x not in storage:
raise RuntimeError("Attempt to retrieve a tensor saved by autograd multiple times without checkpoint"
" recomputation being triggered in between, this is not currently supported. Please"
" open an issue with details on your use case so that we can prioritize adding this.")
return storage[x]
# get device if we need to offload the activation
if activation_offload:
device = get_current_device()
# run function with pack and unpack as saved_tensors_hooks
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
output = function(*args)
# offload activation if needed
if activation_offload:
for arg in args:
if torch.is_tensor(arg):
arg = arg.to(device="cpu")
return output
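For orientation, here is a minimal usage sketch of the `checkpoint` helper above. The import path is an assumption (the diff does not show file names), and the ColossalAI seed/parallel context is assumed to be initialized so that the saved RNG states and parallel mode are meaningful.

import torch

# assumed import path for the module shown above
from colossalai.legacy.utils.activation_checkpoint import checkpoint


def block(x, w):
    # a forward computation whose activations we do not want to keep
    return torch.nn.functional.gelu(x @ w)


x = torch.randn(4, 16, requires_grad=True)
w = torch.randn(16, 16, requires_grad=True)

# reentrant variant without CPU offload; the inputs are stashed and the
# block is recomputed during backward
y = checkpoint(block, False, x, w)
y.sum().backward()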

@@ -0,0 +1,3 @@
from .module_checkpoint import load_checkpoint, save_checkpoint
__all__ = ['save_checkpoint', 'load_checkpoint']

@@ -0,0 +1,140 @@
from typing import Dict, Optional
import torch
import torch.distributed as dist
from colossalai.interface import OptimizerWrapper
from colossalai.tensor import ColoTensor
from .utils import gather_tensor, scatter_tensor
def save_checkpoint(path: str,
epoch: int,
model: torch.nn.Module,
optimizer: Optional[OptimizerWrapper] = None,
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
*args,
**kwargs):
"""save_checkpoint
Save a model whose parameters are `ColoTensor`s.
Args:
path (str): directory to save the checkpoint files.
epoch (int): the epoch number.
model (torch.nn.Module): a torch module initialized by ColoInitContext.
optimizer (OptimizerWrapper, optional): optimizer. Defaults to None.
lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr scheduler. Defaults to None.
"""
rank = dist.get_rank()
model_state = model.state_dict()
# save the dist context about the tensors in a new dict, while still maintaining the original dict.
for k, v in model_state.items():
if isinstance(v, ColoTensor):
gather_tensor(v) # gather shared tensors to rank0
# don't recover tensors in rank0, since the dict is only a copy of model
if rank == 0:
# sanity check
for k, v in model_state.items():
if isinstance(v, ColoTensor):
assert v.save_ready
assert v.is_replicate()
delattr(v, 'save_ready')
# model saving
save_state = {'epoch': epoch, 'model': model_state}
torch.save(save_state, path + '/epoch_{}_model.pth'.format(epoch), *args, **kwargs)
# delete old dicts
del model_state
# synchronize all the processes
dist.barrier()
if optimizer is not None:
mapping = dict()
optim_state = optimizer.state_dict()
for k, v in optim_state['state'].items():
for n, t in v.items():
if isinstance(t, ColoTensor):
mapping[(k, n)] = t.dist_spec
gather_tensor(t)
if rank == 0:
save_state = {'epoch': epoch, 'optim': optim_state}
torch.save(save_state, path + '/epoch_{}_optim.pth'.format(epoch), *args, **kwargs)
# recover colo tensors in rank0
for k, v in optimizer.state_dict()['state'].items():
for n, t in v.items():
if isinstance(t, ColoTensor):
assert hasattr(t, 'save_ready')
t.set_dist_spec(mapping[(k, n)])
delattr(t, 'save_ready')
del optim_state
del mapping
dist.barrier()
def load_checkpoint(path: str,
epoch: int,
model: torch.nn.Module,
optimizer: Optional[OptimizerWrapper] = None,
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
torch_load_kwargs: Optional[Dict] = None,
load_state_dict_kwargs: Optional[Dict] = None):
"""load_checkpoint
Load a model whose parameters are `ColoTensor`s.
Args:
path (str): directory containing the checkpoint files.
epoch (int): the epoch number.
model (torch.nn.Module): a torch module initialized by ColoInitContext.
optimizer (OptimizerWrapper, optional): optimizer. Defaults to None.
lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr scheduler. Defaults to None.
torch_load_kwargs (dict, optional): The kwargs of torch.load inside the function.
load_state_dict_kwargs (dict, optional): The kwargs of load_state_dict inside the function.
"""
# initialize the default parameters
if not torch_load_kwargs:
torch_load_kwargs = dict()
if not load_state_dict_kwargs:
load_state_dict_kwargs = dict()
rank = dist.get_rank()
mapping = dict()
for n, p in model.named_parameters():
if isinstance(p, ColoTensor):
mapping[n] = p.dist_spec
gather_tensor(p)
if rank == 0:
load_state = torch.load(path + '/epoch_{}_model.pth'.format(epoch), **torch_load_kwargs)
model.load_state_dict(load_state['model'], **load_state_dict_kwargs)
dist.barrier()
# scatter loaded parameters
for n, p in model.named_parameters():
if isinstance(p, ColoTensor):
scatter_tensor(p, mapping[n])
if rank == 0:
assert hasattr(p, 'save_ready')
delattr(p, 'save_ready')
del mapping
if optimizer is not None:
mapping = dict()
for k, v in optimizer.state_dict()['state'].items():
for n, t in v.items():
if isinstance(t, ColoTensor):
mapping[(k, n)] = t.dist_spec
gather_tensor(t)
if rank == 0:
colo_checkpoint = torch.load(path + '/epoch_{}_optim.pth'.format(epoch), **torch_load_kwargs)
optimizer.load_state_dict(colo_checkpoint['optim'], **load_state_dict_kwargs)
dist.barrier()
for k, v in optimizer.state_dict()['state'].items():
for n, t in v.items():
if isinstance(t, ColoTensor):
scatter_tensor(t, mapping[(k, n)])
del mapping
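A hedged sketch of driving `save_checkpoint`/`load_checkpoint` above. `model`, `optimizer`, and the `./ckpt` directory are placeholders; the process group is assumed to be initialized and the model parameters are assumed to be `ColoTensor`s (e.g. created under ColoInitContext).

# placeholders: `model` has ColoTensor parameters, `optimizer` is an OptimizerWrapper,
# and the directory './ckpt' already exists on every node
save_checkpoint('./ckpt', epoch=10, model=model, optimizer=optimizer)

# later, after rebuilding the same model/optimizer layout on every rank
load_checkpoint('./ckpt', epoch=10, model=model, optimizer=optimizer)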

@@ -0,0 +1,65 @@
import torch
import torch.distributed as dist
from colossalai.legacy.tensor import ColoTensorSpec
from colossalai.legacy.tensor.distspec import DistPlacementPattern, _DistSpec
from colossalai.tensor import ColoTensor
def robust_broadcast(tensor):
with torch.no_grad():
is_cpu_ten = tensor.device.type == 'cpu'
if is_cpu_ten:
b_data = tensor.cuda()
else:
b_data = tensor
dist.broadcast(b_data, 0)
if is_cpu_ten:
tensor.copy_(b_data)
def gather_tensor(colo_tensor: ColoTensor) -> None:
"""Make colo_tensor replicated when the rank is 0
"""
if not colo_tensor.is_replicate():
pg = colo_tensor.get_process_group()
# for the group which contains rank 0
if pg.dp_local_rank() == 0:
old_dist_spec = colo_tensor.dist_spec
colo_tensor.to_replicate_()
if dist.get_rank() != 0:
colo_tensor.set_dist_spec(old_dist_spec)
# synchronize all processes for unexpected problems
dist.barrier()
if dist.get_rank() == 0:
setattr(colo_tensor, 'save_ready', True) # set saving signature
def scatter_tensor(colo_tensor: ColoTensor, dist_spec: _DistSpec) -> None:
"""Reversal operation of `gather_tensor`.
"""
if dist_spec.placement == DistPlacementPattern.REPLICATE:
robust_broadcast(colo_tensor.data)
else:
global_size = colo_tensor.size_global()
if dist.get_rank() == 0:
entire_data = colo_tensor.data
else:
entire_data = torch.empty(global_size, device=colo_tensor.device)
robust_broadcast(entire_data)
if dist.get_rank() == 0:
colo_tensor.set_dist_spec(dist_spec)
else:
rep_tensor = ColoTensor(
entire_data, ColoTensorSpec(pg=colo_tensor.get_process_group(), compute_attr=colo_tensor.compute_spec))
rep_tensor.set_dist_spec(dist_spec)
with torch.no_grad():
colo_tensor.data.copy_(rep_tensor.data)
# synchronize all processes for unexpected problems
dist.barrier()
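The two helpers above are meant to bracket rank-0 I/O on a sharded tensor; a rough sketch of that pattern (with `colo_param` as a placeholder `ColoTensor` parameter) follows.

import torch
import torch.distributed as dist

# placeholder: `colo_param` is a sharded ColoTensor parameter
old_spec = colo_param.dist_spec
gather_tensor(colo_param)                # replicated on rank 0 only
if dist.get_rank() == 0:
    torch.save(colo_param.data, 'param.pt')
dist.barrier()
scatter_tensor(colo_param, old_spec)     # restore the original sharding everywhere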

@@ -0,0 +1,268 @@
from collections import OrderedDict
from itertools import chain
import torch
import torch.distributed as dist
from colossalai.legacy.constants import IS_TENSOR_PARALLEL
from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.legacy.core import global_context as gpc
try:
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
except ImportError:
_EXTRA_STATE_KEY_SUFFIX = '_extra_state'
from .common import is_using_pp
__all__ = ["save_checkpoint", "load_checkpoint"]
def broadcast_state_dict(state_dict, parallel_mode):
state_dict = [state_dict.copy() if isinstance(state_dict, dict) else state_dict]
src_rank = gpc.get_ranks_in_group(parallel_mode)[0]
dist.broadcast_object_list(state_dict, src=src_rank, group=gpc.get_cpu_group(parallel_mode))
return state_dict[0]
def partition_tensor_parallel_state_dict(state_dict: OrderedDict,
parallel_mode: ParallelMode,
dims: dict = dict(),
partition_states: dict = dict()):
src_rank = gpc.get_ranks_in_group(parallel_mode)[0]
depth = gpc.get_world_size(parallel_mode)
group = gpc.get_cpu_group(parallel_mode)
is_rank0 = gpc.get_local_rank(parallel_mode) == 0
partition_info = [None]
if is_rank0:
partition_info_dict = OrderedDict()
for key, param in state_dict.items():
dim = dims[key]
is_partitioned = partition_states[key]
shape = list(param.shape)
if is_partitioned:
shape[dim] = shape[dim] // depth
partition_info_dict[key] = (is_partitioned, param.dtype, shape, dim)
partition_info[0] = partition_info_dict
dist.broadcast_object_list(partition_info, src_rank, group=group)
partitioned_state = OrderedDict()
for key, (is_partitioned, dtype, shape, dim) in partition_info[0].items():
if is_partitioned:
output = torch.empty(shape, dtype=dtype)
if is_rank0:
scatter_list = [t.contiguous() for t in state_dict[key].chunk(depth, dim)]
else:
scatter_list = None
dist.scatter(output, scatter_list, src_rank, group=group)
else:
if is_rank0:
output = state_dict[key]
else:
output = torch.empty(shape, dtype=dtype)
dist.broadcast(output, src_rank, group=group)
partitioned_state[key] = output
return partitioned_state
def gather_tensor_parallel_state_dict(
state_dict: OrderedDict,
parallel_mode: ParallelMode,
dims: dict = dict(),
partition_states: dict = dict(),
keep_vars: bool = False,
):
dst_rank = gpc.get_ranks_in_group(parallel_mode)[0]
depth = gpc.get_world_size(parallel_mode)
for key in list(state_dict.keys()):
param = state_dict.pop(key)
param = param if keep_vars else param.detach()
dim = dims.get(key, 0)
do_partition = partition_states.get(key, True)
if do_partition:
temp = param.transpose(0, dim).contiguous()
gather_list = None
if gpc.get_local_rank(parallel_mode) == 0:
shape = list(param.shape)
shape[0], shape[dim] = shape[dim], shape[0]
shape[0] *= depth
param = torch.empty(shape, dtype=param.dtype, device=param.device)
gather_list = list(torch.chunk(param, depth, dim=0))
dist.gather(temp, gather_list, dst=dst_rank, group=gpc.get_cpu_group(parallel_mode))
param = torch.transpose(param, 0, dim)
# update params in state_dict only on local rank 0
if gpc.get_local_rank(parallel_mode) == 0:
state_dict[key] = param
return state_dict
def _send_state_dict(state_dict, dst, parallel_mode):
state_tensor, state_size = dist.distributed_c10d._object_to_tensor(state_dict)
dist.send(state_size, dst, group=gpc.get_cpu_group(parallel_mode))
dist.send(state_tensor, dst, group=gpc.get_cpu_group(parallel_mode))
def _recv_state_dict(src, parallel_mode):
state_size = torch.tensor([0], dtype=torch.long)
dist.recv(state_size, src, group=gpc.get_cpu_group(parallel_mode))
state_tensor = torch.empty(state_size.item(), dtype=torch.uint8)
dist.recv(state_tensor, src, group=gpc.get_cpu_group(parallel_mode))
state_dict = dist.distributed_c10d._tensor_to_object(state_tensor, state_size)
return state_dict
def partition_pipeline_parallel_state_dict(model, state_dict):
pipeline_state = OrderedDict()
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# receive all states from prev stage
if not gpc.is_first_rank(ParallelMode.PIPELINE):
state_dict = _recv_state_dict(gpc.get_prev_global_rank(ParallelMode.PIPELINE), ParallelMode.PIPELINE)
# move states to output
for name, _ in model.named_parameters(recurse=True):
if name in state_dict:
pipeline_state[name] = state_dict.pop(name)
for name, _ in model.named_buffers(recurse=True):
if name in state_dict:
pipeline_state[name] = state_dict.pop(name)
for name, _ in model.named_modules():
extra_state_key = name + "." + _EXTRA_STATE_KEY_SUFFIX
if extra_state_key in state_dict:
pipeline_state[extra_state_key] = state_dict.pop(extra_state_key)
# send rest states to next stage
if not gpc.is_last_rank(ParallelMode.PIPELINE):
_send_state_dict(state_dict, gpc.get_next_global_rank(ParallelMode.PIPELINE), ParallelMode.PIPELINE)
return pipeline_state
def gather_pipeline_parallel_state_dict(state_dict):
gathered_states = ([None for _ in range(gpc.get_world_size(ParallelMode.PIPELINE))]
if gpc.get_local_rank(ParallelMode.PIPELINE) == 0 else None)
dist.gather_object(
state_dict,
gathered_states,
dst=gpc.get_ranks_in_group(ParallelMode.PIPELINE)[0],
group=gpc.get_cpu_group(ParallelMode.PIPELINE),
)
state_dict = (OrderedDict(chain.from_iterable(state.items() for state in gathered_states))
if gpc.get_local_rank(ParallelMode.PIPELINE) == 0 else OrderedDict())
return state_dict
def save_checkpoint(file,
epoch: int,
model: torch.nn.Module,
optimizer: torch.optim.Optimizer = None,
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
**kwargs):
"""Stores the checkpoint to disk. Saves all the training components' parameters or buffers, such as model, optimizer,
lr_scheduler etc. into a checkpoint dictionary.
Args:
file: a file-like object (has to implement write and flush) or a string or os.PathLike object containing a
file name.
epoch (int): Epoch number (indicates how many epochs you have trained this model).
model (:class:`torch.nn.Module`): Model to be saved.
optimizer (Union[:class:`torch.optim.Optimizer`, :class:`colossalai.nn.optimizer`]): Optimizer to be saved.
lr_scheduler (Union[:class:`torch.optim.lr_scheduler`, :class:`colossalai.nn.lr_scheduler`], optional):
lr_scheduler to be saved, defaults to None.
pickle_module: module used for pickling metadata and objects
pickle_protocol: can be specified to override the default protocol
"""
# ckpt container
checkpoint = {"epoch": epoch}
model_state = model.state_dict()
if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0:
model_state = gather_pipeline_parallel_state_dict(model_state)
if gpc.get_global_rank() == 0:
checkpoint["model"] = model_state
# if optimizer is not None:
# checkpoint['optimizer'] = optimizer.state_dict()
# if lr_scheduler is not None:
# checkpoint['lr_scheduler'] = lr_scheduler.state_dict()
torch.save(checkpoint, file, **kwargs)
def broadcast_model(model: torch.nn.Module):
src_rank = gpc.get_ranks_in_group(ParallelMode.TENSOR)[0]
for p in model.parameters():
if not getattr(p, IS_TENSOR_PARALLEL, False) and p.storage().size() > 0:
group = gpc.get_group(ParallelMode.TENSOR) if p.device.type == 'cuda' else gpc.get_cpu_group(
ParallelMode.TENSOR)
dist.broadcast(p, src_rank, group=group)
def load_checkpoint(
file,
model: torch.nn.Module,
optimizer: torch.optim.Optimizer = None,
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
strict: bool = True,
):
"""Loads training states from a checkpoint file.
Args:
file: a file-like object (has to implement read(), readline(), tell(), and seek()), or a string or os.PathLike
object containing a file name.
model (:class:`torch.nn.Module`): Model to load saved weights and buffers.
optimizer (Union[:class:`torch.optim.Optimizer`, :class:`colossalai.nn.optimizer`]): Optimizer to restore.
lr_scheduler (:class:`torch.optim.lr_scheduler._LRScheduler`, optional):
lr_scheduler to restore, defaults to None.
strict (bool, optional): Whether to strictly enforce that the keys in :attr:`state_dict`
of the checkpoint match the names of parameters and buffers in model, defaults to True.
Returns:
int: The saved epoch number.
Raises:
RuntimeError: Raised if the model/optimizer cannot be successfully restored.
"""
state_dict = (torch.load(file, map_location=torch.device("cpu"))
if gpc.get_local_rank(ParallelMode.MODEL) == 0 else None)
# model states
model_state = state_dict.pop("model") if state_dict is not None else dict()
# pipeline
if is_using_pp():
model_state = partition_pipeline_parallel_state_dict(model, model_state)
try:
model.load_state_dict(model_state, strict=strict)
broadcast_model(model)
except RuntimeError as e:
error_msgs = str(e)
if error_msgs.startswith("Error(s) in loading state_dict for "):
error_msgs = error_msgs.split("\n\t")[1:]
dst_rank = gpc.get_ranks_in_group(ParallelMode.MODEL)[0]
all_error_msgs = [None for _ in range(gpc.get_world_size(ParallelMode.MODEL))]
dist.gather_object(error_msgs, all_error_msgs, dst=dst_rank, group=gpc.get_cpu_group(ParallelMode.MODEL))
if gpc.get_global_rank() == 0:
all_error_msgs = list(chain.from_iterable(all_error_msgs))
raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format(
model.__class__.__name__, "\n\t".join(all_error_msgs)))
else:
raise e
# broadcast the rest states
state_dict = broadcast_state_dict(state_dict, ParallelMode.MODEL)
# # optimizer states
# if optimizer is not None and 'optimizer' in state_dict:
# optimizer.load_state_dict(state_dict['optimizer'])
# # lr scheduler states
# if lr_scheduler is not None and 'lr_scheduler' in state_dict:
# lr_scheduler.load_state_dict(state_dict['lr_scheduler'])
# last epoch
last_epoch = state_dict.pop("epoch", -1)
return last_epoch
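A sketch of the save/load cycle with the legacy `gpc`-based helpers above; it assumes the data/tensor/pipeline groups were set up by the legacy launcher, and `model` is a placeholder for the distributed model built for this run.

# every rank participates: pipeline stages gather their shards and the
# global rank 0 writes the file
save_checkpoint('ckpt.pth', epoch=5, model=model)

# on load, rank 0 of the MODEL group reads the file, pipeline states are
# partitioned to the owning stages, and the rest is broadcast
last_epoch = load_checkpoint('ckpt.pth', model=model, strict=True)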

@@ -0,0 +1,434 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from collections import defaultdict
from contextlib import contextmanager
from typing import Dict, List, Optional, Union
import torch
import torch.distributed as dist
from torch import inf
from torch.nn.parameter import Parameter
from colossalai.legacy.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARALLEL_ATTRIBUTES
from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.global_variables import tensor_parallel_env as env
from colossalai.legacy.tensor import ProcessGroup
from colossalai.tensor import ColoParameter
from colossalai.utils.multi_tensor_apply import multi_tensor_applier
try:
from colossalai._C import fused_optim
except:
fused_optim = None
def print_rank_0(msg: str, logger=None):
"""Print messages and save logs(optional). This is executed only if you are the rank-0 gpu.
Args:
msg (str): A string message to output.
logger (:class:`colossalai.logging.DistributedLogger`, optional):
The logger to record the message, defaults to None.
"""
if gpc.get_global_rank() == 0:
if logger is None:
print(msg, flush=True)
else:
logger.info(msg)
def sync_model_param(model, parallel_mode):
r"""Make sure data parameters are consistent during Data Parallel Mode.
Args:
model (:class:`torch.nn.Module`): A pyTorch model on whose parameters you check the consistency.
parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): Parallel mode to be checked.
Note:
The parallel_mode should be one of the values in ``ParallelMode``. More details about ``ParallelMode`` can be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
"""
if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1:
for param in model.parameters():
ranks = gpc.get_ranks_in_group(parallel_mode)
dist.broadcast(param, src=ranks[0], group=gpc.get_group(parallel_mode))
def is_dp_rank_0():
return not gpc.is_initialized(ParallelMode.DATA) or gpc.is_first_rank(ParallelMode.DATA)
def is_tp_rank_0():
return not gpc.is_initialized(ParallelMode.TENSOR) or gpc.is_first_rank(ParallelMode.TENSOR)
def is_no_pp_or_last_stage():
return not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE)
def is_using_ddp():
return gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1
def is_using_pp():
return gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1
def is_using_sequence():
return gpc.is_initialized(ParallelMode.SEQUENCE) and gpc.get_world_size(ParallelMode.SEQUENCE) > 1
class model_branch_context(object):
def __enter__(self):
self.env_status = env.save()
def __exit__(self, *exc_info):
env.load(**self.env_status)
def is_model_parallel_parameter(p):
return hasattr(p, IS_TENSOR_PARALLEL) and getattr(p, IS_TENSOR_PARALLEL)
def _calc_l2_norm(grads):
# lazily build the fused optim kernel if the pre-built extension was not available at import time
global fused_optim
if fused_optim is None:
from colossalai.kernel.op_builder import FusedOptimBuilder
fused_optim = FusedOptimBuilder().load()
norm = 0.0
if len(grads) > 0:
dummy_overflow_buf = torch.cuda.IntTensor([0])
norm, _ = multi_tensor_applier(
fused_optim.multi_tensor_l2norm,
dummy_overflow_buf,
[grads],
False # no per-parameter norm
)
return norm
def _calc_lp(grads, norm_type):
norm = 0.0
for grad in grads:
grad_norm = torch.norm(grad, norm_type)
norm += grad_norm**norm_type
return norm
def _move_norm_to_cuda(norm: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]:
if torch.is_tensor(norm) and norm.device.type != 'cuda':
norm = norm.to(torch.cuda.current_device())
return norm
def _get_tensor_norm(norm: Union[float, torch.Tensor], move_to_cuda) -> torch.Tensor:
if isinstance(norm, float):
norm = torch.Tensor([norm])
if move_to_cuda:
norm = norm.to(torch.cuda.current_device())
return norm
# ======== Gradient Clipping =========
def _compute_local_lp(params: List[ColoParameter], norm_type: float) -> float:
if len(params) == 0:
return 0.0
grads = [p.grad for p in params]
use_cuda_kernel = grads[0].device.type == 'cuda'
if norm_type == inf:
local_lp = max([g.abs().max() for g in grads])
elif norm_type == 2.0 and use_cuda_kernel:
local_lp = _calc_l2_norm(grads)**norm_type
else:
local_lp = _calc_lp(grads, norm_type)
if isinstance(local_lp, torch.Tensor):
return local_lp.item()
return local_lp
def _compute_buckets_lp(params: List[ColoParameter], norm_type: float) -> float:
if len(params) == 0:
return 0.0
buckets: Dict[Optional[ProcessGroup], List[ColoParameter]] = defaultdict(list)
for p in params:
if p.is_replicate():
buckets[None].append(p)
else:
buckets[p.get_process_group().tp_process_group()].append(p)
total_lp = 0.0
for group, bucket in buckets.items():
local_lp = _compute_local_lp(bucket, norm_type)
if group is not None:
local_lp_tensor = torch.tensor([local_lp], device=torch.cuda.current_device())
if norm_type == inf:
dist.all_reduce(local_lp_tensor, op=dist.ReduceOp.MAX, group=group)
else:
dist.all_reduce(local_lp_tensor, group=group)
local_lp = local_lp_tensor.item()
if norm_type == inf:
total_lp = max(total_lp, local_lp)
else:
total_lp += local_lp
return total_lp
def _compute_pp_grad_lp(total_lp: float, norm_type: float) -> float:
if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
total_lp_tensor = torch.tensor([total_lp], device=torch.cuda.current_device())
if norm_type == inf:
dist.all_reduce(total_lp_tensor, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.PIPELINE))
else:
dist.all_reduce(total_lp_tensor, group=gpc.get_group(ParallelMode.PIPELINE))
total_lp = total_lp_tensor.item()
return total_lp
def _compute_grad_lp(parameters, norm_type: float = 2.0) -> float:
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
grad_dtype = None
cpu_grad_params: List[ColoParameter] = []
cuda_grad_params: List[ColoParameter] = []
for p in parameters:
if p.grad is None:
continue
assert isinstance(p, ColoParameter)
if grad_dtype is None:
grad_dtype = p.grad.dtype
assert p.grad.dtype == grad_dtype, f'Expected all grads are {grad_dtype}, got {p.grad.dtype}'
if p.grad.device.type == 'cuda':
cuda_grad_params.append(p)
else:
cpu_grad_params.append(p)
norm_type = float(norm_type)
cpu_lp = _compute_buckets_lp(cpu_grad_params, norm_type)
cuda_lp = _compute_buckets_lp(cuda_grad_params, norm_type)
if norm_type == inf:
total_lp = max(cpu_lp, cuda_lp)
else:
total_lp = cpu_lp + cuda_lp
return _compute_pp_grad_lp(total_lp, norm_type)
def compute_grad_norm(parameters, norm_type: float = 2.0) -> float:
norm_type = float(norm_type)
total_norm = _compute_grad_lp(parameters, norm_type)
if norm_type != inf:
total_norm = total_norm**(1 / norm_type)
return total_norm
def _clip_grad_norm(parameters, max_norm: float, total_norm: float) -> None:
clip_coef = max_norm / (total_norm + 1e-6)
if clip_coef < 1.0:
cuda_grads: List[torch.Tensor] = []
cpu_grads: List[torch.Tensor] = []
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
for p in parameters:
if p.grad is None:
continue
if p.grad.device.type == 'cuda':
cuda_grads.append(p.grad.detach())
else:
cpu_grads.append(p.grad.detach())
if len(cuda_grads) > 0:
dummy_overflow_buf = torch.cuda.IntTensor([0])
multi_tensor_applier(fused_optim.multi_tensor_scale, dummy_overflow_buf, [cuda_grads, cuda_grads],
clip_coef)
for g in cpu_grads:
g.mul_(clip_coef)
def clip_grad_norm(parameters, max_norm: float, norm_type: float = 2.0) -> float:
total_norm = compute_grad_norm(parameters, norm_type)
_clip_grad_norm(parameters, max_norm, total_norm)
return total_norm
def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
"""Clips gradient norm of an iterable of parameters whose gradients are in fp32.
This is adapted from :func:`torch.nn.utils.clip_grad.clip_grad_norm_` and
added functionality to handle model parallel parameters.
Note:
the gradients are modified in place.
Args:
parameters (Iterable[:class:`torch.tensor`] or :class:`torch.tensor`):
An iterable of Tensors or a single Tensor that will have gradients normalized.
max_norm (Union[float, int]): Max norm of the gradients.
norm_type (Union[float, int, 'inf']): Type of the used p-norm. Can be ``'inf'`` for infinity norm.
Returns:
float: Total norm of the parameters.
"""
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
# Filter parameters based on:
# - grad should not be none
# - parameter should not be shared
# - should not be a replica due to tensor model parallelism
params: List[Parameter] = []
has_zero_shared_param: bool = False
for param in parameters:
if param.grad is not None:
# Make sure the grads are in fp32
assert param.grad.dtype == torch.float, \
f'expected gradient to be dtype torch.float, but got {param.grad.type()}'
if hasattr(param, 'colo_attr') and param.colo_attr.sharded_data_tensor.is_sharded:
has_zero_shared_param = True
params.append(param)
if len(params) == 0:
enable_cuda_kernels = False
else:
enable_cuda_kernels = params[0].grad.device.type == 'cuda'
# Norm parameters.
max_norm = float(max_norm)
norm_type = float(norm_type)
# Parameters can be on CPU or CUDA
# If parameters are on CPU, disable CUDA kernels
# Calculate norm.
if norm_type == inf:
total_norm = max(p.grad.data.abs().max() for p in params)
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
# Take max across all model-parallel GPUs.
if gpc.is_initialized(ParallelMode.MODEL) and gpc.get_world_size(ParallelMode.MODEL) > 1:
dist.all_reduce(total_norm_cuda,
op=dist.ReduceOp.MAX,
group=gpc.get_group(ParallelMode.MODEL),
async_op=False)
if has_zero_shared_param:
dist.all_reduce(total_norm_cuda,
op=dist.ReduceOp.MAX,
group=gpc.get_group(ParallelMode.DATA),
async_op=False)
total_norm = total_norm_cuda[0].item()
else:
tensor_parallel_grads = []
no_tensor_parallel_grads = []
zero_sharded_grads = []
for p in params:
if is_model_parallel_parameter(p):
reductor = (gpc.get_world_size(ParallelMode.TENSOR) / getattr(p, NUM_PARTITIONS))**(1 / norm_type)
tensor_parallel_grads.append(p.grad.data / reductor)
elif hasattr(p, 'colo_attr') and p.colo_attr.sharded_data_tensor.is_sharded:
zero_sharded_grads.append(p.grad.data)
else:
no_tensor_parallel_grads.append(p.grad.data)
if norm_type == 2.0 and enable_cuda_kernels:
tensor_parallel_norm = _calc_l2_norm(tensor_parallel_grads)**norm_type
no_tensor_parallel_norm = _calc_l2_norm(no_tensor_parallel_grads)**norm_type
zero_sharded_norm = _calc_l2_norm(zero_sharded_grads)**norm_type
else:
tensor_parallel_norm = _calc_lp(tensor_parallel_grads, norm_type)
no_tensor_parallel_norm = _calc_lp(no_tensor_parallel_grads, norm_type)
zero_sharded_norm = _calc_lp(zero_sharded_grads, norm_type)
# If norm is type of float, then we convert them into torch.Tensor.
tensor_parallel_norm = _get_tensor_norm(tensor_parallel_norm, enable_cuda_kernels)
no_tensor_parallel_norm = _get_tensor_norm(no_tensor_parallel_norm, enable_cuda_kernels)
zero_sharded_norm = _get_tensor_norm(zero_sharded_norm, enable_cuda_kernels)
# If grads are on CPU, the norms are also on CPU. Cast them to CUDA tensors
if not enable_cuda_kernels:
tensor_parallel_norm = _move_norm_to_cuda(tensor_parallel_norm)
no_tensor_parallel_norm = _move_norm_to_cuda(no_tensor_parallel_norm)
zero_sharded_norm = _move_norm_to_cuda(zero_sharded_norm)
# Sum across all model-parallel GPUs.
if gpc.is_initialized(ParallelMode.TENSOR) and len(tensor_parallel_grads) > 0:
dist.all_reduce(tensor_parallel_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR))
# Sum across all zero sharded GPUs
if len(zero_sharded_grads) > 0:
dist.all_reduce(zero_sharded_norm, group=gpc.get_group(ParallelMode.DATA))
no_tensor_parallel_norm += zero_sharded_norm
total_norm = tensor_parallel_norm + no_tensor_parallel_norm
if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.PIPELINE))
total_norm = total_norm**(1.0 / norm_type)
if torch.is_tensor(total_norm):
total_norm = total_norm.item()
# Scale.
clip_coeff = max_norm / (total_norm + 1.0e-6)
if clip_coeff < 1.0:
if enable_cuda_kernels:
grads = [p.grad.detach() for p in params]
dummy_overflow_buf = torch.cuda.IntTensor([0])
multi_tensor_applier(fused_optim.multi_tensor_scale, dummy_overflow_buf, [grads, grads], clip_coeff)
else:
for p in params:
p.grad.detach().mul_(clip_coeff)
return total_norm
def count_zeros_fp32(parameters):
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
# Filter parameters based on:
# - grad should not be none
# - parameter should not be shared
# - should not be a replica due to tensor model parallelism
total_num_zeros = 0.0
for param in parameters:
grad_not_none = param.grad is not None
is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
if grad_not_none and is_not_tp_duplicate:
grad = param.grad.detach()
num_zeros = grad.numel() - torch.count_nonzero(grad)
total_num_zeros = num_zeros + total_num_zeros
total_num_zeros = torch.IntTensor([int(total_num_zeros)]).cuda()
# Sum across all model-parallel GPUs.
ops = []
ops.append(
dist.all_reduce(total_num_zeros, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR), async_op=True))
if gpc.is_initialized(ParallelMode.PIPELINE):
ops.append(
dist.all_reduce(total_num_zeros,
op=dist.ReduceOp.SUM,
group=gpc.get_group(ParallelMode.PIPELINE),
async_op=True))
for req in ops:
req.wait()
total_num_zeros = total_num_zeros.item()
return total_num_zeros
def copy_tensor_parallel_attributes(src_tensor, dst_tensor):
for attr in TENSOR_PARALLEL_ATTRIBUTES:
if hasattr(src_tensor, attr):
val = getattr(src_tensor, attr)
setattr(dst_tensor, attr, val)
def param_is_not_tensor_parallel_duplicate(param):
return (hasattr(param, IS_TENSOR_PARALLEL) and getattr(param, IS_TENSOR_PARALLEL)) or (gpc.get_local_rank(
ParallelMode.TENSOR) == 0)
@contextmanager
def switch_virtual_pipeline_parallel_rank(rank):
prev_rank = gpc.virtual_pipeline_parallel_rank
try:
gpc.set_virtual_pipeline_parallel_rank(rank)
yield
finally:
gpc.set_virtual_pipeline_parallel_rank(prev_rank)
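A rough sketch of where `clip_grad_norm_fp32` sits in a legacy training step (placeholder `model`, `optimizer`, and `batch`; gradients are assumed to already be fp32, as the assert inside requires).

# placeholder training step
optimizer.zero_grad()
loss = model(batch).mean()
loss.backward()

# clip to a max L2 norm of 1.0; the pre-clipping total norm is returned
grad_norm = clip_grad_norm_fp32(model.parameters(), max_norm=1.0, norm_type=2)
optimizer.step()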

@@ -0,0 +1,4 @@
from .base_sampler import BaseSampler
from .data_parallel_sampler import DataParallelSampler, get_dataloader
__all__ = ['BaseSampler', 'DataParallelSampler', 'get_dataloader']

@@ -0,0 +1,19 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from abc import ABC, abstractmethod
class BaseSampler(ABC):
def __init__(self, dataset, batch_size):
self.dataset = dataset
self.batch_size = batch_size
@abstractmethod
def __len__(self):
pass
@abstractmethod
def __iter__(self):
pass
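`BaseSampler` only fixes the constructor and the two abstract methods; a trivial sequential subclass, purely for illustration, could look like this.

class SequentialSampler(BaseSampler):
    """Yields dataset indices in order, ignoring batch_size."""

    def __len__(self):
        return len(self.dataset)

    def __iter__(self):
        return iter(range(len(self.dataset)))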

@@ -0,0 +1,161 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# adapted from torch.utils.data.DistributedSampler
import math
import random
from typing import Iterator, TypeVar
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, Sampler
from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.legacy.core import global_context as gpc
T_co = TypeVar('T_co', covariant=True)
class DataParallelSampler(Sampler):
"""A data sampler for distributed data parallelism.
Args:
dataset (:class:`torch.utils.data.Dataset`): The Dataset for sampling.
shuffle (bool, optional): Whether to shuffle data, defaults to False.
seed (int, optional): The random seed used for sampling, defaults to 0.
drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
is not divisible by the batch size. If False and the size of dataset is not divisible by
the batch size, then the last batch will be smaller, defaults to False.
"""
def __init__(self, dataset: Dataset, shuffle: bool = False, seed: int = 0, drop_last: bool = False) -> None:
self.dataset = dataset
self.num_replicas = gpc.get_world_size(ParallelMode.DATA)
self.rank = gpc.get_local_rank(ParallelMode.DATA)
self.epoch = 0
self.drop_last = drop_last
# If the dataset length is evenly divisible by # of replicas, then there
# is no need to drop any data, since the dataset will be split equally.
# type: ignore[arg-type]
if self.drop_last and len(self.dataset) % self.num_replicas != 0:
# Split to nearest available length that is evenly divisible.
# This is to ensure each rank receives the same amount of data when
# using this Sampler.
self.num_samples = math.ceil(
# `type:ignore` is required because Dataset cannot provide a default __len__
# see NOTE in pytorch/torch/utils/data/sampler.py
(len(self.dataset) - self.num_replicas) / \
self.num_replicas # type: ignore[arg-type]
)
else:
self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type]
self.total_size = self.num_samples * self.num_replicas
self.shuffle = shuffle
self.seed = seed
def __iter__(self) -> Iterator[T_co]:
if self.shuffle:
# deterministically shuffle based on epoch and seed
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
# type: ignore[arg-type]
indices = torch.randperm(len(self.dataset), generator=g).tolist()
# update for next epoch so that there is no need to call
# set_epoch manually
self.epoch += 1
else:
indices = list(range(len(self.dataset))) # type: ignore[arg-type]
if not self.drop_last:
# add extra samples to make it evenly divisible
padding_size = self.total_size - len(indices)
if padding_size <= len(indices):
indices += indices[:padding_size]
else:
indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
else:
# remove tail of data to make it evenly divisible.
indices = indices[:self.total_size]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
assert len(indices) == self.num_samples
return iter(indices)
def __len__(self) -> int:
return self.num_samples
def set_epoch(self, epoch: int) -> None:
r"""Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
use a different random ordering for each epoch. Otherwise, the next iteration of this
sampler will yield the same ordering.
Args:
epoch (int): Epoch number.
"""
self.epoch = epoch
def get_dataloader(dataset,
shuffle=False,
seed=1024,
add_sampler=True,
drop_last=False,
pin_memory=False,
num_workers=0,
**kwargs):
r"""Set up a deterministic dataloader (also configure seed workers, samplers and whether shuffle or not)
Note:
When pipeline parallel is enabled, shuffle cannot be True as it will result in mismatch between input data
on the 1st stage and label on the last stage.
Args:
dataset (:class:`torch.utils.data.Dataset`): The dataset to be loaded.
shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
seed (int, optional): Random worker seed for sampling, defaults to 1024.
add_sampler: Whether to add a ``DataParallelSampler`` for the dataset. Defaults to True.
drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
is not divisible by the batch size. If False and the size of dataset is not divisible by
the batch size, then the last batch will be smaller, defaults to False.
pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
num_workers (int, optional): Number of worker processes for this dataloader. Defaults to 0.
kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
`DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_.
Returns:
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
"""
_kwargs = kwargs.copy()
if add_sampler and gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1:
sampler = DataParallelSampler(dataset, shuffle=shuffle)
else:
sampler = None
# Deterministic dataloader
def seed_worker(worker_id):
worker_seed = seed
np.random.seed(worker_seed)
torch.manual_seed(worker_seed)
random.seed(worker_seed)
if sampler is None:
return DataLoader(dataset,
worker_init_fn=seed_worker,
shuffle=shuffle,
drop_last=drop_last,
pin_memory=pin_memory,
num_workers=num_workers,
**_kwargs)
else:
return DataLoader(dataset,
sampler=sampler,
worker_init_fn=seed_worker,
drop_last=drop_last,
pin_memory=pin_memory,
num_workers=num_workers,
**_kwargs)
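A sketch of building a loader with `get_dataloader`. A `TensorDataset` stands in for a real dataset; if the DATA parallel group is initialized with world size > 1, a `DataParallelSampler` is attached, otherwise plain shuffling is used. `batch_size` is forwarded to `DataLoader` through `**kwargs`.

import torch
from torch.utils.data import TensorDataset

dataset = TensorDataset(torch.randn(1024, 32), torch.randint(0, 10, (1024,)))
train_loader = get_dataloader(dataset,
                              shuffle=True,
                              seed=1024,
                              drop_last=True,
                              pin_memory=True,
                              num_workers=2,
                              batch_size=64)
for data, label in train_loader:
    pass  # training step goes here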

@@ -0,0 +1,178 @@
import gc
from collections import namedtuple
import psutil
import torch
import torch.distributed as dist
from packaging import version
from colossalai.legacy.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device
_GLOBAL_CUDA_MEM_FRACTION = 1.0
_GLOBAL_CPU_MEM_CAPACITY = -1
def _bytes_to_MB(val, decimal=2):
"""A byte-to-Megabyte converter, default using binary notation.
:param val: X bytes to convert
:return: X' MB
"""
return round(val / (1024 * 1024), decimal)
# copy from PatrickStar
def _get_cpu_memory_info():
ps_mem_info = namedtuple("ps_mem_info", ["total", "free", "cached", "buffers", "used"])
try:
# psutil reads the memory info from /proc/memory_info,
# which results in returning the host memory instead of
# that of container.
# Here we try to read the container memory with method in:
# https://stackoverflow.com/a/46213331/5163915
mems = {}
with open("/sys/fs/cgroup/memory/memory.meminfo", "rb") as f:
for line in f:
fields = line.split()
mems[fields[0]] = int(fields[1]) * 1024
total = mems[b"MemTotal:"]
free = mems[b"MemFree:"]
cached = mems[b"Cached:"]
buffers = mems[b"Buffers:"]
used = total - free - cached - buffers
if used < 0:
used = total - free
mem_info = ps_mem_info(total=total, free=free, cached=cached, buffers=buffers, used=used)
except FileNotFoundError:
mems = psutil.virtual_memory()
mem_info = ps_mem_info(
total=mems.total,
free=mems.free,
cached=mems.cached,
buffers=mems.buffers,
used=mems.used,
)
return mem_info
def report_memory_usage(message, logger=None, report_cpu=False):
"""Calculate and print RAM usage (in GB)
Args:
message (str): A prefix message to add in the log.
logger (:class:`colossalai.logging.DistributedLogger`): The logger used to record memory information.
report_cpu (bool, optional): Whether to report CPU memory.
Raises:
EnvironmentError: Raise error if no distributed environment has been initialized.
"""
if not dist.is_initialized():
raise EnvironmentError("No distributed environment is initialized")
gpu_allocated = _bytes_to_MB(torch.cuda.memory_allocated())
gpu_max_allocated = _bytes_to_MB(torch.cuda.max_memory_allocated())
gpu_cached = _bytes_to_MB(torch.cuda.memory_reserved())
gpu_max_cached = _bytes_to_MB(torch.cuda.max_memory_reserved())
full_log = f"{message}: GPU: allocated {gpu_allocated} MB, max allocated {gpu_max_allocated} MB, " \
+ f"cached: {gpu_cached} MB, max cached: {gpu_max_cached} MB"
if report_cpu:
# python doesn't do real-time garbage collection so do it explicitly to get the correct RAM reports
gc.collect()
vm_stats = psutil.virtual_memory()
vm_used = _bytes_to_MB(vm_stats.total - vm_stats.available)
full_log += f", CPU Virtual Memory: used = {vm_used} MB, percent = {vm_stats.percent}%"
if logger is None:
logger = get_dist_logger()
logger.info(full_log)
# get the peak memory to report correct data, so reset the counter for the next call
if hasattr(torch.cuda, "reset_peak_memory_stats"): # pytorch 1.4+
torch.cuda.reset_peak_memory_stats()
def colo_device_memory_capacity(device: torch.device) -> int:
"""
Get the memory capacity of the device.
Args:
device (torch.device): a device
Returns:
int: size in bytes
"""
assert isinstance(device, torch.device)
if device.type == 'cpu':
# In the context of 1-CPU-N-GPU, the memory capacity of the current process is 1/N overall CPU memory.
return colo_get_cpu_memory_capacity() / gpc.num_processes_on_current_node
if device.type == 'cuda':
return torch.cuda.get_device_properties(get_current_device()).total_memory * _GLOBAL_CUDA_MEM_FRACTION
def colo_device_memory_used(device: torch.device) -> int:
"""
Get the memory used on the device by the current process.
Args:
device (torch.device): a device
Returns:
int: memory size in bytes
"""
if device.type == 'cpu':
mem_info = _get_cpu_memory_info()
# In the context of 1-CPU-N-GPU, the memory usage of the current process is 1/N CPU memory used.
# Each process consumes the same amount of memory.
ret = mem_info.used / gpc.num_processes_on_current_node
return ret
elif device.type == 'cuda':
ret: int = torch.cuda.memory_allocated(device)
# get the peak memory to report correct data, so reset the counter for the next call
if hasattr(torch.cuda, "reset_peak_memory_stats"): # pytorch 1.4+
torch.cuda.reset_peak_memory_stats(device)
return ret
def colo_set_process_memory_fraction(ratio: float) -> None:
"""colo_set_process_memory_fraction
Set how much CUDA memory can be used on the GPU belonging to the current process.
Args:
ratio (float): a ratio between 0 and 1.
"""
if version.parse(torch.__version__) < version.parse('1.8'):
logger = get_dist_logger('colo_set_process_memory_fraction')
logger.warning('colo_set_process_memory_fraction failed because torch version is less than 1.8')
return
global _GLOBAL_CUDA_MEM_FRACTION
_GLOBAL_CUDA_MEM_FRACTION = ratio
torch.cuda.set_per_process_memory_fraction(_GLOBAL_CUDA_MEM_FRACTION, get_current_device())
def colo_set_cpu_memory_capacity(size: int) -> None:
global _GLOBAL_CPU_MEM_CAPACITY
mem_info = _get_cpu_memory_info()
total_size = mem_info.total
if size <= total_size:
_GLOBAL_CPU_MEM_CAPACITY = size
else:
_GLOBAL_CPU_MEM_CAPACITY = total_size
def colo_get_cpu_memory_capacity() -> int:
"""
Get the CPU memory capacity. We may not use all of it.
Returns:
int: the CPU memory capacity in bytes.
"""
global _GLOBAL_CPU_MEM_CAPACITY
if _GLOBAL_CPU_MEM_CAPACITY == -1:
mem_info = _get_cpu_memory_info()
return mem_info.total
else:
return _GLOBAL_CPU_MEM_CAPACITY
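A sketch combining the memory helpers above: cap the CUDA memory fraction for this process, then report usage around a step. `torch.distributed` must already be initialized, otherwise `report_memory_usage` raises.

import torch

colo_set_process_memory_fraction(0.9)        # use at most 90% of this process's GPU
colo_set_cpu_memory_capacity(32 * 1024**3)   # pretend only 32 GB of host RAM are available

report_memory_usage("before step", report_cpu=True)
# ... run a training step here ...
report_memory_usage("after step", report_cpu=True)

device = torch.device('cuda', torch.cuda.current_device())
print(colo_device_memory_used(device), "/", colo_device_memory_capacity(device), "bytes")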

@@ -0,0 +1,2 @@
from .legacy import *
from .profiler import profile

@@ -0,0 +1,20 @@
from abc import ABC, abstractmethod
class ProfilerExtension(ABC):
@abstractmethod
def prepare_trace(self):
pass
@abstractmethod
def start_trace(self):
pass
@abstractmethod
def stop_trace(self):
pass
@abstractmethod
def extend_chrome_trace(self, trace: dict) -> dict:
pass
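`ProfilerExtension` is just the interface; a no-op extension that satisfies it (purely illustrative) might look like this.

class NullProfilerExtension(ProfilerExtension):
    """Records nothing and leaves the chrome trace untouched."""

    def prepare_trace(self):
        pass

    def start_trace(self):
        pass

    def stop_trace(self):
        pass

    def extend_chrome_trace(self, trace: dict) -> dict:
        return trace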

@@ -0,0 +1,6 @@
from .comm_profiler import CommProfiler
from .mem_profiler import MemProfiler
from .pcie_profiler import PcieProfiler
from .prof_utils import BaseProfiler, ProfilerContext
__all__ = ['BaseProfiler', 'CommProfiler', 'PcieProfiler', 'MemProfiler', 'ProfilerContext']

@@ -0,0 +1,311 @@
import inspect
from functools import partial
from pathlib import Path
from typing import List, Optional
import torch
import torch.distributed as dist
from torch.autograd.profiler import profile
from torch.distributed import ReduceOp
from colossalai.utils import get_current_device
from .prof_utils import BaseProfiler, _format_bandwidth, _format_memory, _format_time
def _get_code_location(depth: int):
ret = []
length = min(len(inspect.stack()), depth + 1)
for i in range(3, length):
upper_frame = inspect.stack()[i]
function_name = inspect.stack()[i - 1].function
ret.append(upper_frame.filename)
ret.append('(')
ret.append(str(upper_frame.lineno))
ret.append('): ')
ret.append(function_name)
if i != length - 1:
ret.append('\n')
return ''.join(ret)
torch_all_reduce = dist.all_reduce
torch_all_gather = dist.all_gather
torch_reduce_scatter = dist.reduce_scatter
torch_broadcast = dist.broadcast
torch_reduce = dist.reduce
class CommEvent(object):
"""Communication Event. Used for communication time and communication
volume recording.
"""
def __init__(self, count: int = 0, comm_vol: float = 0., cuda_time: int = 0):
self.self_count = count
self.self_comm_vol = comm_vol
self.self_cuda_time = cuda_time
def add(self, rhs):
self.self_count += rhs.self_count
self.self_comm_vol += rhs.self_comm_vol
self.self_cuda_time += rhs.self_cuda_time
class CommProfiler(BaseProfiler):
"""Communication profiler. Records all communication events.
"""
def __init__(self, depth: int = 0, total_count: int = 0, total_comm_vol: float = 0, total_cuda_time: int = 0):
super().__init__(profiler_name="Collective_Communication", priority=0)
self.depth = 3 + depth
self.total_count = total_count
self.total_comm_vol = total_comm_vol
self.total_cuda_time = total_cuda_time
self.ops_record = dict()
self.profiler = None
self.pending_op = None
self.pending_metadata = None
self.warn_flag = False
def reset(self):
self.total_count = 0
self.total_comm_vol = 0
self.total_cuda_time = 0
self.ops_record = dict()
self.profiler = None
self.pending_op = None
self.pending_metadata = None
self.warn_flag = False
def enable(self):
dist.all_reduce = partial(all_reduce, profiler=self)
dist.all_gather = partial(all_gather, profiler=self)
dist.reduce_scatter = partial(reduce_scatter, profiler=self)
dist.broadcast = partial(broadcast, profiler=self)
dist.reduce = partial(reduce, profiler=self)
def disable(self):
dist.all_reduce = torch_all_reduce
dist.all_gather = torch_all_gather
dist.reduce_scatter = torch_reduce_scatter
dist.broadcast = torch_broadcast
dist.reduce = torch_reduce
def to_tensorboard(self, writer):
writer.add_text(tag="Collective Communication", text_string=self.result_str("\n\n"))
def to_file(self, filename: Path):
with open(filename, "w") as f:
f.write(self.result_str())
def show(self):
print(self.result_str())
def result_str(self, sep: str = "\n"):
res = []
def append(s: str = None):
if s is not None:
res.append(s)
res.append(sep)
if self.warn_flag:
append("Warning: there exists multiple communication operations in the same time. As a result, "
"the profiling result is not accurate.")
if self.total_cuda_time == 0:
return "No collective communication has been called yet!"
append("Collective communication profiling result:")
append("total cuda time: {}".format(_format_time(self.total_cuda_time)))
append("average bandwidth: {}".format(_format_bandwidth(self.total_comm_vol, self.total_cuda_time)))
append("total number of calls: {}".format(self.total_count))
append("All events:")
separation = '-' * 74
row_format = '{:^10}' + '{:^12}' * 2 + '{:^16}' + '{:^12}' * 2
append(separation)
append(row_format.format('Location', 'GPU time', 'Percentage', 'Comm volume', 'Bandwidth', 'Num of calls'))
append(separation)
show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].self_cuda_time)
for location, event in show_list:
append(location)
append(
row_format.format('', _format_time(event.self_cuda_time),
'{:.1f}%'.format(event.self_cuda_time / self.total_cuda_time * 100.0),
_format_memory(event.self_comm_vol),
_format_bandwidth(event.self_comm_vol, event.self_cuda_time), event.self_count))
append()
return ''.join(res)
@property
def has_aync_op(self):
return self.pending_op is not None
def activate_profiler(self, kn: str, vol: float):
self.pending_metadata = (kn, _get_code_location(self.depth), vol)
self.profiler = profile(enabled=True, use_cuda=True, use_cpu=True, use_kineto=True)
self.profiler.__enter__()
def close_profiler(self, group=None):
assert self.profiler is not None, "There is no running dist op"
kernel_name, code_location, vol = self.pending_metadata
self.profiler.__exit__(None, None, None)
if self.profiler.enabled and dist.get_world_size(group) > 1:
assert_flag = 0
current_comm_event = None
events = self.profiler.function_events
for event in events:
if kernel_name in event.name:
assert assert_flag == 0, "Multiple dist ops have been called"
current_comm_event = CommEvent(1, vol, event.self_cuda_time_total)
assert_flag += 1
assert current_comm_event is not None, "dist op has not been found"
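# Take the minimum kernel time across ranks: ranks entering the collective early spend
# extra kernel time waiting for stragglers, so the minimum is the closest estimate of
# the actual transfer time.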
buffer = torch.tensor([current_comm_event.self_cuda_time], device=get_current_device())
torch_all_reduce(buffer, op=ReduceOp.MIN, group=group)
current_comm_event.self_cuda_time = buffer.item()
self.total_count += current_comm_event.self_count
self.total_comm_vol += current_comm_event.self_comm_vol
self.total_cuda_time += current_comm_event.self_cuda_time
if code_location in self.ops_record:
self.ops_record[code_location].add(current_comm_event)
else:
self.ops_record[code_location] = current_comm_event
self.profiler = None
self.pending_op = None
self.pending_metadata = None
def wait_async_op(self):
if self.pending_op is not None:
op = self.pending_op
op.wait()
self.close_profiler()
class CommHandler(object):
"""Communication handler. A dummy handler to wait aync operations.
"""
def __init__(self, profiler: CommProfiler):
super().__init__()
self.prof = profiler
def wait(self):
self.prof.wait_async_op()
def async_check(profiler: CommProfiler):
if profiler.pending_op is not None:
profiler.warn_flag = True
profiler.wait_async_op()
def all_reduce(tensor: torch.Tensor,
op: ReduceOp = ReduceOp.SUM,
group=None,
async_op: bool = False,
profiler: CommProfiler = None) -> Optional[CommHandler]:
async_check(profiler)
comm_size = dist.get_world_size(group)
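# 2 * (N - 1) / N is roughly the per-rank traffic of a ring all-reduce
# (a reduce-scatter phase followed by an all-gather phase).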
correction = 2 * (comm_size - 1) / comm_size
comm_vol = correction * tensor.element_size() * tensor.numel()
profiler.activate_profiler("ncclKernel_AllReduce_", comm_vol)
profiler.pending_op = torch_all_reduce(tensor, op, group, async_op)
if async_op:
return CommHandler(profiler)
profiler.close_profiler(group)
def reduce_scatter(output: torch.Tensor,
input_list: List[torch.Tensor],
op: ReduceOp = ReduceOp.SUM,
group=None,
async_op: bool = False,
profiler: CommProfiler = None) -> Optional[CommHandler]:
async_check(profiler)
comm_size = dist.get_world_size(group)
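# In a ring reduce-scatter each rank transfers about (N - 1) / N of the total input volume.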
correction = (comm_size - 1) / comm_size
comm_vol = 0
for tensor in input_list:
comm_vol += tensor.element_size() * tensor.numel()
comm_vol *= correction
profiler.activate_profiler("ncclKernel_ReduceScatter_", comm_vol)
profiler.pending_op = torch_reduce_scatter(output, input_list, op, group, async_op)
if async_op:
return CommHandler(profiler)
profiler.close_profiler(group)
def all_gather(tensor_list: List[torch.Tensor],
tensor: torch.Tensor,
group=None,
async_op: bool = False,
profiler: CommProfiler = None) -> Optional[CommHandler]:
async_check(profiler)
comm_size = dist.get_world_size(group)
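# Each rank transfers about (N - 1) / N of the gathered volume, since it already holds its own shard.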
correction = (comm_size - 1) / comm_size
comm_vol = 0
for ten in tensor_list:
comm_vol += ten.element_size() * ten.numel()
comm_vol *= correction
profiler.activate_profiler("ncclKernel_AllGather_", comm_vol)
profiler.pending_op = torch_all_gather(tensor_list, tensor, group, async_op)
if async_op:
return CommHandler(profiler)
profiler.close_profiler(group)
def broadcast(tensor: torch.Tensor,
src: int,
group=None,
async_op: bool = False,
profiler: CommProfiler = None) -> Optional[CommHandler]:
async_check(profiler)
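# The broadcast volume is counted as the full tensor size: every non-root rank receives one copy.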
comm_vol = 1.0 * tensor.element_size() * tensor.numel()
profiler.activate_profiler("ncclKernel_Broadcast_", comm_vol)
profiler.pending_op = torch_broadcast(tensor, src, group, async_op)
if async_op:
return CommHandler(profiler)
profiler.close_profiler(group)
def reduce(tensor: torch.Tensor,
dst: int,
op: ReduceOp = ReduceOp.SUM,
group=None,
async_op: bool = False,
profiler: CommProfiler = None) -> Optional[CommHandler]:
async_check(profiler)
comm_vol = 1.0 * tensor.element_size() * tensor.numel()
profiler.activate_profiler("ncclKernel_Reduce_", comm_vol)
profiler.pending_op = torch_reduce(tensor, dst, op, group, async_op)
if async_op:
return CommHandler(profiler)
profiler.close_profiler(group)

View File

@@ -0,0 +1,150 @@
from pathlib import Path
from typing import List
from torch.autograd.profiler import profile
from .prof_utils import BaseProfiler, _format_bandwidth, _format_memory, _format_time
def _get_size(dtype: str):
if dtype == "fp16":
return 2
elif dtype == "fp32":
return 4
else:
raise NotImplementedError
def _get_numel(my_list: List[int]) -> int:
from functools import reduce
from operator import mul
return reduce(mul, my_list)
def _reduce_location(locations: List[str]) -> str:
ret = []
for lo in locations:
ret.append(lo)
ret.append("\n")
ret = ret[:-1]
return ''.join(ret)
class PcieEvent(object):
"""Pcie Event.
"""
def __init__(self, count: int = 0, pcie_vol: int = 0, cuda_time: int = 0):
self.count = count
self.pcie_vol = pcie_vol
self.cuda_time = cuda_time
def add(self, rhs):
self.count += rhs.count
self.pcie_vol += rhs.pcie_vol
self.cuda_time += rhs.cuda_time
class PcieProfiler(BaseProfiler):
"""Pcie profiler. Records all data transmission between CPU and GPU.
TODO: Merge pcie profiler into communication profiler
"""
def __init__(self, dtype: str = "fp32", depth: int = 1):
super().__init__(profiler_name="Pcie", priority=10)
self.depth = depth
self.data_size = _get_size(dtype)
self.h2d_count = 0
self.h2d_time = 0
self.d2h_count = 0
self.d2h_time = 0
self.ops_record = dict()
self.profiler = None
def reset(self):
self.h2d_count = 0
self.h2d_time = 0
self.d2h_count = 0
self.d2h_time = 0
self.ops_record = dict()
self.profiler = None
def enable(self):
self.profiler = profile(enabled=True,
use_cuda=True,
use_cpu=True,
use_kineto=True,
record_shapes=True,
with_stack=True)
self.profiler.__enter__()
def disable(self):
self.profiler.__exit__(None, None, None)
if self.profiler.enabled:
events = self.profiler.function_events
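# aten::copy_ events with recorded shapes and stacks are attributed to their source locations;
# raw Memcpy HtoD/DtoH kernels are only accumulated into the per-direction totals.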
for event in events:
if event.name == "aten::copy_":
t_shape = event.input_shapes[0]
if len(t_shape) == 0 or event.cuda_time_total == 0 or len(event.stack) == 0:
continue
current_comm_event = PcieEvent(1, self.data_size * _get_numel(t_shape), event.cuda_time_total)
code_location = _reduce_location(event.stack[:self.depth])
if code_location in self.ops_record:
self.ops_record[code_location].add(current_comm_event)
else:
self.ops_record[code_location] = current_comm_event
elif 'Memcpy HtoD' in event.name:
self.h2d_count += 1
self.h2d_time += event.cuda_time_total
elif 'Memcpy DtoH' in event.name:
self.d2h_count += 1
self.d2h_time += event.cuda_time_total
self.profiler = None
def to_tensorboard(self, writer):
writer.add_text(tag="Data Transmission", text_string=self.result_str("\n\n"))
def to_file(self, filename: Path):
with open(filename, "w") as f:
f.write(self.result_str())
def show(self):
print(self.result_str())
def result_str(self, sep: str = "\n"):
res = []
def append(s: str = None):
if s is not None:
res.append(s)
res.append(sep)
append("Pcie profiling result:")
append("time of data transmission (CPU -> GPU): {}".format(_format_time(self.h2d_time)))
append("number of transmission (CPU -> GPU): {}".format(self.h2d_count))
append("time of data transmission (GPU -> CPU): {}".format(_format_time(self.d2h_time)))
append("number of transmission (GPU -> CPU): {}".format(self.d2h_count))
append("Possible data transmission events in PCIE:")
separation = '-' * 62
row_format = '{:^10}' + '{:^12}' + '{:^16}' + '{:^12}' * 2
append(separation)
append(row_format.format('Location', 'GPU time', 'Trans volume', 'Bandwidth', 'Num of calls'))
append(separation)
show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].cuda_time)
for location, event in show_list:
append(location)
append(
row_format.format('', _format_time(event.cuda_time), _format_memory(event.pcie_vol),
_format_bandwidth(event.pcie_vol, event.cuda_time), event.count))
append()
return ''.join(res)

View File

@@ -0,0 +1,132 @@
from abc import ABC, abstractmethod
from pathlib import Path
from typing import List, Union
from colossalai.legacy.core import global_context as gpc
# copied from a newer PyTorch release to keep compatibility with older versions
def _format_time(time_us):
"""Defines how to format time in FunctionEvent"""
US_IN_SECOND = 1000.0 * 1000.0
US_IN_MS = 1000.0
if time_us >= US_IN_SECOND:
return '{:.3f}s'.format(time_us / US_IN_SECOND)
if time_us >= US_IN_MS:
return '{:.3f}ms'.format(time_us / US_IN_MS)
return '{:.3f}us'.format(time_us)
# copied from a newer PyTorch release to keep compatibility with older versions
def _format_memory(nbytes):
"""Returns a formatted memory size string"""
KB = 1024
MB = 1024 * KB
GB = 1024 * MB
if (abs(nbytes) >= GB):
return '{:.2f} GB'.format(nbytes * 1.0 / GB)
elif (abs(nbytes) >= MB):
return '{:.2f} MB'.format(nbytes * 1.0 / MB)
elif (abs(nbytes) >= KB):
return '{:.2f} KB'.format(nbytes * 1.0 / KB)
else:
return str(nbytes) + ' B'
def _format_bandwidth(volume: Union[float, int], time_us: int):
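# volume is in bytes and time_us in microseconds; (1000 / 1024) ** 2 converts bytes/us to MB/s.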
sec_div_mb = (1000.0 / 1024.0)**2
mb_per_sec = volume / time_us * sec_div_mb
if mb_per_sec >= 1024.0:
return '{:.3f} GB/s'.format(mb_per_sec / 1024.0)
else:
return '{:.3f} MB/s'.format(mb_per_sec)
class BaseProfiler(ABC):
def __init__(self, profiler_name: str, priority: int):
self.name = profiler_name
self.priority = priority
@abstractmethod
def enable(self):
pass
@abstractmethod
def disable(self):
pass
@abstractmethod
def to_tensorboard(self, writer):
pass
@abstractmethod
def to_file(self, filename: Path):
pass
@abstractmethod
def show(self):
pass
class ProfilerContext(object):
"""Profiler context manager
Usage::
world_size = 4
inputs = torch.randn(10, 10, dtype=torch.float32, device=get_current_device())
outputs = torch.empty(world_size, 10, 10, dtype=torch.float32, device=get_current_device())
outputs_list = list(torch.chunk(outputs, chunks=world_size, dim=0))
cc_prof = CommProfiler()
with ProfilerContext([cc_prof]) as prof:
op = dist.all_reduce(inputs, async_op=True)
dist.all_gather(outputs_list, inputs)
op.wait()
dist.reduce_scatter(inputs, outputs_list)
dist.broadcast(inputs, 0)
dist.reduce(inputs, 0)
prof.show()
"""
def __init__(self, profilers: List[BaseProfiler] = None, enable: bool = True):
self.enable = enable
self.profilers = sorted(profilers, key=lambda prof: prof.priority)
def __enter__(self):
if self.enable:
for prof in self.profilers:
prof.enable()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if self.enable:
for prof in self.profilers:
prof.disable()
def to_tensorboard(self, writer):
from torch.utils.tensorboard import SummaryWriter
assert isinstance(writer, SummaryWriter), \
f'torch.utils.tensorboard.SummaryWriter is required, but found {type(writer)}.'
for prof in self.profilers:
prof.to_tensorboard(writer)
def to_file(self, log_dir: Union[str, Path]):
if isinstance(log_dir, str):
log_dir = Path(log_dir)
if not log_dir.exists():
log_dir.mkdir(parents=True, exist_ok=True)
for prof in self.profilers:
log_file = log_dir.joinpath(f'{prof.name}_rank_{gpc.get_global_rank()}.log')
prof.to_file(log_file)
def show(self):
for prof in self.profilers:
prof.show()

View File

@@ -0,0 +1,201 @@
import gzip
import json
import os
import tempfile
from typing import Any, Callable, Iterable, List, Optional
from torch.autograd import ProfilerActivity
from torch.profiler import profile as torch_profile
from torch.profiler.profiler import ProfilerAction
from colossalai.legacy.engine import Engine
from colossalai.legacy.utils.profiler.extention import ProfilerExtension
from colossalai.legacy.utils.profiler.stateful_tensor_mem_extention import StatefulTensorMemoryProfilerExtention
from colossalai.logging import get_dist_logger
class profile(torch_profile):
"""Profiler context manager.
Args:
activities (iterable): list of activity groups (CPU, CUDA) to use in profiling, supported values:
``torch.profiler.ProfilerActivity.CPU``, ``torch.profiler.ProfilerActivity.CUDA``.
Default value: ProfilerActivity.CPU and (when available) ProfilerActivity.CUDA.
schedule (callable): callable that takes step (int) as a single parameter and returns
``ProfilerAction`` value that specifies the profiler action to perform at each step.
on_trace_ready (callable): callable that is called at each step when ``schedule``
returns ``ProfilerAction.RECORD_AND_SAVE`` during the profiling.
engine (Optional[Engine], optional): An ``Engine`` instance. Defaults to None.
record_shapes (bool): save information about operator's input shapes.
profile_memory (bool): track tensor memory allocation/deallocation.
with_stack (bool): record source information (file and line number) for the ops.
with_flops (bool): use formula to estimate the FLOPs (floating point operations) of specific operators
(matrix multiplication and 2D convolution).
with_modules (bool): record module hierarchy (including function names)
corresponding to the callstack of the op. e.g. If module A's forward calls
module B's forward which contains an aten::add op,
then aten::add's module hierarchy is A.B.
Note that this support exists, at the moment, only for TorchScript models
and not eager mode models.
profile_stateful_tensor_memory (bool): track stateful tensor memory usage. ``engine`` must not be None if you enable this.
.. note::
Use :func:`~torch.profiler.schedule` to generate the callable schedule.
Non-default schedules are useful when profiling long training jobs
and allow the user to obtain multiple traces at the different iterations
of the training process.
The default schedule simply records all the events continuously for the
duration of the context manager.
.. note::
Use :func:`~torch.profiler.tensorboard_trace_handler` to generate result files for TensorBoard:
``on_trace_ready=torch.profiler.tensorboard_trace_handler(dir_name)``
After profiling, result files can be found in the specified directory. Use the command:
``tensorboard --logdir dir_name``
to see the results in TensorBoard.
For more information, see
`PyTorch Profiler TensorBoard Plugin <https://github.com/pytorch/kineto/tree/master/tb_plugin>`__
.. note::
Enabling shape and stack tracing results in additional overhead.
When record_shapes=True is specified, profiler will temporarily hold references to the tensors;
that may further prevent certain optimizations that depend on the reference count and introduce
extra tensor copies.
Examples:
.. code-block:: python
with torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
]
) as p:
code_to_profile()
print(p.key_averages().table(
sort_by="self_cuda_time_total", row_limit=-1))
Using the profiler's ``schedule``, ``on_trace_ready`` and ``step`` functions:
.. code-block:: python
# Non-default profiler schedule allows user to turn profiler on and off
# on different iterations of the training loop;
# trace_handler is called every time a new trace becomes available
def trace_handler(prof):
print(prof.key_averages().table(
sort_by="self_cuda_time_total", row_limit=-1))
# prof.export_chrome_trace("/tmp/test_trace_" + str(prof.step_num) + ".json")
with torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
# In this example with wait=1, warmup=1, active=2,
# profiler will skip the first step/iteration,
# start warming up on the second, record
# the third and the fourth iterations,
# after which the trace will become available
# and on_trace_ready (when set) is called;
# the cycle repeats starting with the next step
schedule=torch.profiler.schedule(
wait=1,
warmup=1,
active=2),
on_trace_ready=trace_handler
# on_trace_ready=torch.profiler.tensorboard_trace_handler('./log')
# used when outputting for tensorboard
) as p:
for iter in range(N):
code_iteration_to_profile(iter)
# send a signal to the profiler that the next iteration has started
p.step()
"""
def __init__(self,
*,
activities: Optional[Iterable[ProfilerActivity]] = None,
schedule: Optional[Callable[[int], ProfilerAction]] = None,
on_trace_ready: Optional[Callable[..., Any]] = None,
engine: Optional[Engine] = None,
record_shapes: bool = False,
profile_memory: bool = False,
with_stack: bool = False,
with_flops: bool = False,
with_modules: bool = False,
profile_stateful_tensor_memory: bool = False) -> None:
super().__init__(activities=activities,
schedule=schedule,
on_trace_ready=on_trace_ready,
record_shapes=record_shapes,
profile_memory=profile_memory,
with_stack=with_stack,
with_flops=with_flops,
with_modules=with_modules)
self._logger = get_dist_logger()
self.extentions: List[ProfilerExtension] = []
if profile_stateful_tensor_memory:
if engine is None:
self._logger.warning('Ignore "profile_stateful_tensor_memory" since engine is None', ranks=[0])
else:
self.extentions.append(StatefulTensorMemoryProfilerExtention(engine))
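# The profiler's internal warmup/trace hooks were renamed across PyTorch releases
# (e.g. _start_warmup/_start_trace vs. prepare_trace/start_trace), so each step below
# dispatches to whichever method the parent class provides and then notifies the extensions.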
def prepare_trace(self) -> None:
if hasattr(super(), 'prepare_trace'):
super().prepare_trace()
elif hasattr(super(), '_start_warmup'):
super()._start_warmup()
for ext in self.extentions:
ext.prepare_trace()
def _start_warmup(self):
self.prepare_trace()
def start_trace(self):
if hasattr(super(), '_start_trace'):
super()._start_trace()
elif hasattr(super(), 'start_trace'):
super().start_trace()
for ext in self.extentions:
ext.start_trace()
def _start_trace(self):
self.start_trace()
def stop_trace(self):
if hasattr(super(), '_stop_trace'):
super()._stop_trace()
elif hasattr(super(), 'stop_trace'):
super().stop_trace()
for ext in self.extentions:
ext.stop_trace()
def _stop_trace(self):
self.stop_trace()
def export_chrome_trace(self, path: str):
"""
Exports the collected trace in Chrome JSON format.
"""
assert self.profiler
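# Export the base trace to a temporary file, splice in the events contributed by the
# extensions, then write the final (optionally gzip-compressed) trace to the requested path.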
fp = tempfile.NamedTemporaryFile('w+t', suffix='.json', delete=False)
fp.close()
retvalue = self.profiler.export_chrome_trace(fp.name)
with open(fp.name) as fin:
trace = json.load(fin)
for ext in self.extentions:
trace = ext.extend_chrome_trace(trace)
open_func = gzip.open if path.endswith('.gz') else open
with open_func(path, 'wt') as fout:
json.dump(trace, fout)
os.remove(fp.name)
return retvalue

View File

@@ -0,0 +1,135 @@
import os
import threading
import time
from enum import Enum
from typing import List
import torch
from colossalai.gemini.ophooks import BaseOpHook
from colossalai.gemini.stateful_tensor import StatefulTensor
from colossalai.legacy.engine import Engine
from colossalai.legacy.utils.profiler.extention import ProfilerExtension
class DeviceType(Enum):
CPU = 0
CUDA = 1
def get_timestamp_us():
return int(time.time() * 1e6)
def generic_instant_event(name, pid, tid, timestamp, args):
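# Chrome trace format: 'ph': 'i' marks an instant event and 's': 't' scopes it to its thread.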
return {'ph': 'i', 's': 't', 'name': name, 'pid': pid, 'tid': tid, 'ts': timestamp, 'args': args}
class StatefulTensorMemoryEvent:
EVENT_NAME = '[statefulTensorMemory]'
def __init__(self, timestamp: int, device_type: DeviceType, bytes_: int) -> None:
self.pid = os.getpid()
self.tid = threading.get_ident()
self.timestamp = timestamp
self.device_type = device_type
self.device_id = torch.cuda.current_device() if device_type == DeviceType.CUDA else -1
self.bytes = bytes_
def state_dict(self):
return generic_instant_event(StatefulTensorMemoryEvent.EVENT_NAME, self.pid, self.tid, self.timestamp, {
'Device Type': self.device_type.value,
'Device Id': self.device_id,
'Bytes': self.bytes
})
class StatefulTensorMemoryTracer:
def __init__(self) -> None:
self.events: List[StatefulTensorMemoryEvent] = []
self._tracing = False
def sample(self):
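# Snapshot the total stateful-tensor memory on CUDA and CPU and, while tracing,
# record one instant event per device.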
cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda']
cpu_mem = StatefulTensor.GST_MGR.total_mem['cpu']
timestamp = get_timestamp_us()
if self._tracing:
self.events.append(StatefulTensorMemoryEvent(timestamp, DeviceType.CUDA, cuda_mem))
self.events.append(StatefulTensorMemoryEvent(timestamp, DeviceType.CPU, cpu_mem))
def start_trace(self):
self.events.clear()
self._tracing = True
def stop_trace(self):
self._tracing = False
def state_dict(self):
return [event.state_dict() for event in self.events]
class StatefulTensorMemoryTracerHook(BaseOpHook):
def __init__(self, tracer: StatefulTensorMemoryTracer):
super().__init__()
self.tracer = tracer
self._enable = False
def pre_fwd_exec(self, module: torch.nn.Module, *args):
if self._enable:
self.tracer.sample()
def post_fwd_exec(self, module: torch.nn.Module, *args):
if self._enable:
self.tracer.sample()
def pre_bwd_exec(self, module: torch.nn.Module, input_, output):
if self._enable:
self.tracer.sample()
def post_bwd_exec(self, module: torch.nn.Module, input_):
if self._enable:
self.tracer.sample()
def post_iter(self):
if self._enable:
self.tracer.sample()
def enable(self):
self._enable = True
def disable(self):
self._enable = False
class StatefulTensorMemoryProfilerExtention(ProfilerExtension):
def __init__(self, engine: Engine) -> None:
self.engine = engine
self.tracer = StatefulTensorMemoryTracer()
self.hook = StatefulTensorMemoryTracerHook(self.tracer)
self.hook_registered = False
def prepare_trace(self):
self.hook.enable()
if not self.hook_registered:
self.engine.add_hook(self.hook)
self.hook_registered = True
def start_trace(self):
self.prepare_trace()
self.tracer.start_trace()
def stop_trace(self):
self.tracer.stop_trace()
self.hook.disable()
if self.hook_registered:
self.engine.remove_hook(self.hook)
# remove_hook is not implemented now
# FIXME(ver217): uncomment below line when remove_hook is implemented
# self.hook_registered = False
def extend_chrome_trace(self, trace: dict) -> dict:
trace['traceEvents'].extend(self.tracer.state_dict())
return trace