mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 13:00:52 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -26,28 +26,28 @@ from .memory import (
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'DataParallelSampler',
|
||||
'get_dataloader',
|
||||
'save_checkpoint',
|
||||
'load_checkpoint',
|
||||
'colo_device_memory_capacity',
|
||||
'colo_device_memory_used',
|
||||
'colo_get_cpu_memory_capacity',
|
||||
'colo_set_cpu_memory_capacity',
|
||||
'colo_set_process_memory_fraction',
|
||||
'report_memory_usage',
|
||||
'clip_grad_norm_fp32',
|
||||
'copy_tensor_parallel_attributes',
|
||||
'count_zeros_fp32',
|
||||
'is_dp_rank_0',
|
||||
'is_model_parallel_parameter',
|
||||
'is_no_pp_or_last_stage',
|
||||
'is_tp_rank_0',
|
||||
'is_using_ddp',
|
||||
'is_using_pp',
|
||||
'is_using_sequence',
|
||||
'param_is_not_tensor_parallel_duplicate',
|
||||
'print_rank_0',
|
||||
'switch_virtual_pipeline_parallel_rank',
|
||||
'sync_model_param',
|
||||
"DataParallelSampler",
|
||||
"get_dataloader",
|
||||
"save_checkpoint",
|
||||
"load_checkpoint",
|
||||
"colo_device_memory_capacity",
|
||||
"colo_device_memory_used",
|
||||
"colo_get_cpu_memory_capacity",
|
||||
"colo_set_cpu_memory_capacity",
|
||||
"colo_set_process_memory_fraction",
|
||||
"report_memory_usage",
|
||||
"clip_grad_norm_fp32",
|
||||
"copy_tensor_parallel_attributes",
|
||||
"count_zeros_fp32",
|
||||
"is_dp_rank_0",
|
||||
"is_model_parallel_parameter",
|
||||
"is_no_pp_or_last_stage",
|
||||
"is_tp_rank_0",
|
||||
"is_using_ddp",
|
||||
"is_using_pp",
|
||||
"is_using_sequence",
|
||||
"param_is_not_tensor_parallel_duplicate",
|
||||
"print_rank_0",
|
||||
"switch_virtual_pipeline_parallel_rank",
|
||||
"sync_model_param",
|
||||
]
|
||||
|
@@ -28,7 +28,6 @@ def copy_to_device(obj, device):
|
||||
|
||||
|
||||
class CheckpointFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, run_function, activation_offload=False, *args):
|
||||
check_backward_validity(args)
|
||||
@@ -42,7 +41,7 @@ class CheckpointFunction(torch.autograd.Function):
|
||||
ctx.fwd_seed_states = get_states(copy=True)
|
||||
ctx.fwd_current_mode = get_current_mode()
|
||||
|
||||
if hasattr(torch, 'is_autocast_enabled'):
|
||||
if hasattr(torch, "is_autocast_enabled"):
|
||||
ctx.had_autocast_in_fwd = torch.is_autocast_enabled()
|
||||
else:
|
||||
ctx.had_autocast_in_fwd = False
|
||||
@@ -62,7 +61,7 @@ class CheckpointFunction(torch.autograd.Function):
|
||||
for i, arg in enumerate(args):
|
||||
if torch.is_tensor(arg):
|
||||
if activation_offload:
|
||||
tensor_inputs.append(copy_to_device(arg, 'cpu'))
|
||||
tensor_inputs.append(copy_to_device(arg, "cpu"))
|
||||
else:
|
||||
tensor_inputs.append(arg)
|
||||
ctx.tensor_indices.append(i)
|
||||
@@ -79,8 +78,10 @@ class CheckpointFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def backward(ctx, *args):
|
||||
if not torch.autograd._is_checkpoint_valid():
|
||||
raise RuntimeError("Checkpointing is not compatible with .grad() or when an `inputs` parameter is "
|
||||
"passed to .backward(). Please use .backward() and do not pass its `inputs` argument.")
|
||||
raise RuntimeError(
|
||||
"Checkpointing is not compatible with .grad() or when an `inputs` parameter is "
|
||||
"passed to .backward(). Please use .backward() and do not pass its `inputs` argument."
|
||||
)
|
||||
# Copy the list to avoid modifying original list.
|
||||
inputs = list(ctx.inputs)
|
||||
tensor_indices = ctx.tensor_indices
|
||||
@@ -131,8 +132,7 @@ class CheckpointFunction(torch.autograd.Function):
|
||||
outputs_with_grad.append(outputs[i])
|
||||
args_with_grad.append(args[i])
|
||||
if len(outputs_with_grad) == 0:
|
||||
raise RuntimeError("none of output has requires_grad=True,"
|
||||
" this checkpoint() is not necessary")
|
||||
raise RuntimeError("none of output has requires_grad=True," " this checkpoint() is not necessary")
|
||||
torch.autograd.backward(outputs_with_grad, args_with_grad)
|
||||
grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None for inp in detached_inputs)
|
||||
return (None, None) + grads
|
||||
@@ -169,7 +169,7 @@ def _checkpoint_without_reentrant(function, activation_offload=False, *args):
|
||||
fwd_current_mode = get_current_mode()
|
||||
|
||||
# check if use autocast
|
||||
if hasattr(torch, 'is_autocast_enabled'):
|
||||
if hasattr(torch, "is_autocast_enabled"):
|
||||
has_autocast_in_fwd = torch.is_autocast_enabled()
|
||||
else:
|
||||
has_autocast_in_fwd = False
|
||||
@@ -179,7 +179,7 @@ def _checkpoint_without_reentrant(function, activation_offload=False, *args):
|
||||
weak_holder_list = []
|
||||
|
||||
# class for weakref.ref
|
||||
class Holder():
|
||||
class Holder:
|
||||
pass
|
||||
|
||||
# return a Holder object for later unpack process
|
||||
@@ -226,19 +226,20 @@ def _checkpoint_without_reentrant(function, activation_offload=False, *args):
|
||||
|
||||
# rerun forward, the inner_pack will store all the activations in storage
|
||||
if has_autocast_in_fwd:
|
||||
with torch.enable_grad(), \
|
||||
torch.cuda.amp.autocast(), \
|
||||
torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack):
|
||||
with torch.enable_grad(), torch.cuda.amp.autocast(), torch.autograd.graph.saved_tensors_hooks(
|
||||
inner_pack, inner_unpack
|
||||
):
|
||||
_unused = function(*args)
|
||||
else:
|
||||
with torch.enable_grad(), \
|
||||
torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack):
|
||||
with torch.enable_grad(), torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack):
|
||||
_unused = function(*args)
|
||||
|
||||
if x not in storage:
|
||||
raise RuntimeError("Attempt to retrieve a tensor saved by autograd multiple times without checkpoint"
|
||||
" recomputation being triggered in between, this is not currently supported. Please"
|
||||
" open an issue with details on your use case so that we can prioritize adding this.")
|
||||
raise RuntimeError(
|
||||
"Attempt to retrieve a tensor saved by autograd multiple times without checkpoint"
|
||||
" recomputation being triggered in between, this is not currently supported. Please"
|
||||
" open an issue with details on your use case so that we can prioritize adding this."
|
||||
)
|
||||
|
||||
return storage[x]
|
||||
|
||||
|
@@ -1,3 +1,3 @@
|
||||
from .module_checkpoint import load_checkpoint, save_checkpoint
|
||||
|
||||
__all__ = ['save_checkpoint', 'load_checkpoint']
|
||||
__all__ = ["save_checkpoint", "load_checkpoint"]
|
||||
|
@@ -9,13 +9,15 @@ from colossalai.tensor import ColoTensor
|
||||
from .utils import gather_tensor, scatter_tensor
|
||||
|
||||
|
||||
def save_checkpoint(path: str,
|
||||
epoch: int,
|
||||
model: torch.nn.Module,
|
||||
optimizer: Optional[OptimizerWrapper] = None,
|
||||
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
|
||||
*args,
|
||||
**kwargs):
|
||||
def save_checkpoint(
|
||||
path: str,
|
||||
epoch: int,
|
||||
model: torch.nn.Module,
|
||||
optimizer: Optional[OptimizerWrapper] = None,
|
||||
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
"""save_checkpoint
|
||||
save a model, whose parameters are `ColoTensor`s.
|
||||
Args:
|
||||
@@ -30,7 +32,7 @@ def save_checkpoint(path: str,
|
||||
# save the dist context about the tensors in a new dict, while still maintain the original dict.
|
||||
for k, v in model_state.items():
|
||||
if isinstance(v, ColoTensor):
|
||||
gather_tensor(v) # gather shared tensors to rank0
|
||||
gather_tensor(v) # gather shared tensors to rank0
|
||||
# don't recover tensors in rank0, since the dict is only a copy of model
|
||||
|
||||
if rank == 0:
|
||||
@@ -39,10 +41,10 @@ def save_checkpoint(path: str,
|
||||
if isinstance(v, ColoTensor):
|
||||
assert v.save_ready
|
||||
assert v.is_replicate()
|
||||
delattr(v, 'save_ready')
|
||||
delattr(v, "save_ready")
|
||||
# model saving
|
||||
save_state = {'epoch': epoch, 'model': model_state}
|
||||
torch.save(save_state, path + '/epoch_{}_model.pth'.format(epoch), *args, **kwargs)
|
||||
save_state = {"epoch": epoch, "model": model_state}
|
||||
torch.save(save_state, path + "/epoch_{}_model.pth".format(epoch), *args, **kwargs)
|
||||
|
||||
# delete old dicts
|
||||
del model_state
|
||||
@@ -52,35 +54,37 @@ def save_checkpoint(path: str,
|
||||
if optimizer is not None:
|
||||
mapping = dict()
|
||||
optim_state = optimizer.state_dict()
|
||||
for k, v in optim_state['state'].items():
|
||||
for k, v in optim_state["state"].items():
|
||||
for n, t in v.items():
|
||||
if isinstance(t, ColoTensor):
|
||||
mapping[(k, n)] = t.dist_spec
|
||||
gather_tensor(t)
|
||||
|
||||
if rank == 0:
|
||||
save_state = {'epoch': epoch, 'optim': optim_state}
|
||||
torch.save(save_state, path + '/epoch_{}_optim.pth'.format(epoch), *args, **kwargs)
|
||||
save_state = {"epoch": epoch, "optim": optim_state}
|
||||
torch.save(save_state, path + "/epoch_{}_optim.pth".format(epoch), *args, **kwargs)
|
||||
# recover colo tensors in rank0
|
||||
for k, v in optimizer.state_dict()['state'].items():
|
||||
for k, v in optimizer.state_dict()["state"].items():
|
||||
for n, t in v.items():
|
||||
if isinstance(t, ColoTensor):
|
||||
assert hasattr(t, 'save_ready')
|
||||
assert hasattr(t, "save_ready")
|
||||
t.set_dist_spec(mapping[(k, n)])
|
||||
delattr(t, 'save_ready')
|
||||
delattr(t, "save_ready")
|
||||
|
||||
del optim_state
|
||||
del mapping
|
||||
dist.barrier()
|
||||
|
||||
|
||||
def load_checkpoint(path: str,
|
||||
epoch: int,
|
||||
model: torch.nn.Module,
|
||||
optimizer: Optional[OptimizerWrapper] = None,
|
||||
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
|
||||
torch_load_kwargs: Optional[Dict] = None,
|
||||
load_state_dict_kwargs: Optional[Dict] = None):
|
||||
def load_checkpoint(
|
||||
path: str,
|
||||
epoch: int,
|
||||
model: torch.nn.Module,
|
||||
optimizer: Optional[OptimizerWrapper] = None,
|
||||
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
|
||||
torch_load_kwargs: Optional[Dict] = None,
|
||||
load_state_dict_kwargs: Optional[Dict] = None,
|
||||
):
|
||||
"""load_checkpoint
|
||||
load a model, whose parameters are `ColoTensor`s.
|
||||
Args:
|
||||
@@ -106,8 +110,8 @@ def load_checkpoint(path: str,
|
||||
gather_tensor(p)
|
||||
|
||||
if rank == 0:
|
||||
load_state = torch.load(path + '/epoch_{}_model.pth'.format(epoch), **torch_load_kwargs)
|
||||
model.load_state_dict(load_state['model'], **load_state_dict_kwargs)
|
||||
load_state = torch.load(path + "/epoch_{}_model.pth".format(epoch), **torch_load_kwargs)
|
||||
model.load_state_dict(load_state["model"], **load_state_dict_kwargs)
|
||||
dist.barrier()
|
||||
|
||||
# scatter loaded parameters
|
||||
@@ -115,24 +119,24 @@ def load_checkpoint(path: str,
|
||||
if isinstance(p, ColoTensor):
|
||||
scatter_tensor(p, mapping[n])
|
||||
if rank == 0:
|
||||
assert hasattr(p, 'save_ready')
|
||||
delattr(p, 'save_ready')
|
||||
assert hasattr(p, "save_ready")
|
||||
delattr(p, "save_ready")
|
||||
del mapping
|
||||
|
||||
if optimizer is not None:
|
||||
mapping = dict()
|
||||
for k, v in optimizer.state_dict()['state'].items():
|
||||
for k, v in optimizer.state_dict()["state"].items():
|
||||
for n, t in v.items():
|
||||
if isinstance(t, ColoTensor):
|
||||
mapping[(k, n)] = t.dist_spec
|
||||
gather_tensor(t)
|
||||
|
||||
if rank == 0:
|
||||
colo_checkpoint = torch.load(path + '/epoch_{}_optim.pth'.format(epoch), **torch_load_kwargs)
|
||||
optimizer.load_state_dict(colo_checkpoint['optim'], **load_state_dict_kwargs)
|
||||
colo_checkpoint = torch.load(path + "/epoch_{}_optim.pth".format(epoch), **torch_load_kwargs)
|
||||
optimizer.load_state_dict(colo_checkpoint["optim"], **load_state_dict_kwargs)
|
||||
dist.barrier()
|
||||
|
||||
for k, v in optimizer.state_dict()['state'].items():
|
||||
for k, v in optimizer.state_dict()["state"].items():
|
||||
for n, t in v.items():
|
||||
if isinstance(t, ColoTensor):
|
||||
scatter_tensor(t, mapping[(k, n)])
|
||||
|
@@ -8,7 +8,7 @@ from colossalai.tensor import ColoTensor
|
||||
|
||||
def robust_broadcast(tensor):
|
||||
with torch.no_grad():
|
||||
is_cpu_ten = tensor.device.type == 'cpu'
|
||||
is_cpu_ten = tensor.device.type == "cpu"
|
||||
if is_cpu_ten:
|
||||
b_data = tensor.cuda()
|
||||
else:
|
||||
@@ -21,8 +21,7 @@ def robust_broadcast(tensor):
|
||||
|
||||
|
||||
def gather_tensor(colo_tensor: ColoTensor) -> None:
|
||||
"""Make colo_tensor replicated when the rank is 0
|
||||
"""
|
||||
"""Make colo_tensor replicated when the rank is 0"""
|
||||
if not colo_tensor.is_replicate():
|
||||
pg = colo_tensor.get_process_group()
|
||||
# for the group which contains rank 0
|
||||
@@ -36,12 +35,11 @@ def gather_tensor(colo_tensor: ColoTensor) -> None:
|
||||
dist.barrier()
|
||||
|
||||
if dist.get_rank() == 0:
|
||||
setattr(colo_tensor, 'save_ready', True) # set saving signature
|
||||
setattr(colo_tensor, "save_ready", True) # set saving signature
|
||||
|
||||
|
||||
def scatter_tensor(colo_tensor: ColoTensor, dist_spec: _DistSpec) -> None:
|
||||
"""Reversal operation of `gather_tensor`.
|
||||
"""
|
||||
"""Reversal operation of `gather_tensor`."""
|
||||
if dist_spec.placement == DistPlacementPattern.REPLICATE:
|
||||
robust_broadcast(colo_tensor.data)
|
||||
else:
|
||||
@@ -57,7 +55,8 @@ def scatter_tensor(colo_tensor: ColoTensor, dist_spec: _DistSpec) -> None:
|
||||
colo_tensor.set_dist_spec(dist_spec)
|
||||
else:
|
||||
rep_tensor = ColoTensor(
|
||||
entire_data, ColoTensorSpec(pg=colo_tensor.get_process_group(), compute_attr=colo_tensor.compute_spec))
|
||||
entire_data, ColoTensorSpec(pg=colo_tensor.get_process_group(), compute_attr=colo_tensor.compute_spec)
|
||||
)
|
||||
rep_tensor.set_dist_spec(dist_spec)
|
||||
with torch.no_grad():
|
||||
colo_tensor.data.copy_(rep_tensor.data)
|
||||
|
@@ -11,7 +11,7 @@ from colossalai.legacy.core import global_context as gpc
|
||||
try:
|
||||
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
|
||||
except ImportError:
|
||||
_EXTRA_STATE_KEY_SUFFIX = '_extra_state'
|
||||
_EXTRA_STATE_KEY_SUFFIX = "_extra_state"
|
||||
|
||||
from .common import is_using_pp
|
||||
|
||||
@@ -25,10 +25,9 @@ def broadcast_state_dict(state_dict, parallel_mode):
|
||||
return state_dict[0]
|
||||
|
||||
|
||||
def partition_tensor_parallel_state_dict(state_dict: OrderedDict,
|
||||
parallel_mode: ParallelMode,
|
||||
dims: dict = dict(),
|
||||
partition_states: dict = dict()):
|
||||
def partition_tensor_parallel_state_dict(
|
||||
state_dict: OrderedDict, parallel_mode: ParallelMode, dims: dict = dict(), partition_states: dict = dict()
|
||||
):
|
||||
src_rank = gpc.get_ranks_in_group(parallel_mode)[0]
|
||||
depth = gpc.get_world_size(parallel_mode)
|
||||
group = gpc.get_cpu_group(parallel_mode)
|
||||
@@ -65,11 +64,11 @@ def partition_tensor_parallel_state_dict(state_dict: OrderedDict,
|
||||
|
||||
|
||||
def gather_tensor_parallel_state_dict(
|
||||
state_dict: OrderedDict,
|
||||
parallel_mode: ParallelMode,
|
||||
dims: dict = dict(),
|
||||
partition_states: dict = dict(),
|
||||
keep_vars: bool = False,
|
||||
state_dict: OrderedDict,
|
||||
parallel_mode: ParallelMode,
|
||||
dims: dict = dict(),
|
||||
partition_states: dict = dict(),
|
||||
keep_vars: bool = False,
|
||||
):
|
||||
dst_rank = gpc.get_ranks_in_group(parallel_mode)[0]
|
||||
depth = gpc.get_world_size(parallel_mode)
|
||||
@@ -138,8 +137,11 @@ def partition_pipeline_parallel_state_dict(model, state_dict):
|
||||
|
||||
|
||||
def gather_pipeline_parallel_state_dict(state_dict):
|
||||
gathered_states = ([None for _ in range(gpc.get_world_size(ParallelMode.PIPELINE))]
|
||||
if gpc.get_local_rank(ParallelMode.PIPELINE) == 0 else None)
|
||||
gathered_states = (
|
||||
[None for _ in range(gpc.get_world_size(ParallelMode.PIPELINE))]
|
||||
if gpc.get_local_rank(ParallelMode.PIPELINE) == 0
|
||||
else None
|
||||
)
|
||||
dist.gather_object(
|
||||
state_dict,
|
||||
gathered_states,
|
||||
@@ -147,18 +149,23 @@ def gather_pipeline_parallel_state_dict(state_dict):
|
||||
group=gpc.get_cpu_group(ParallelMode.PIPELINE),
|
||||
)
|
||||
|
||||
state_dict = (OrderedDict(chain.from_iterable(state.items() for state in gathered_states))
|
||||
if gpc.get_local_rank(ParallelMode.PIPELINE) == 0 else OrderedDict())
|
||||
state_dict = (
|
||||
OrderedDict(chain.from_iterable(state.items() for state in gathered_states))
|
||||
if gpc.get_local_rank(ParallelMode.PIPELINE) == 0
|
||||
else OrderedDict()
|
||||
)
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def save_checkpoint(file,
|
||||
epoch: int,
|
||||
model: torch.nn.Module,
|
||||
optimizer: torch.optim.Optimizer = None,
|
||||
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
|
||||
**kwargs):
|
||||
def save_checkpoint(
|
||||
file,
|
||||
epoch: int,
|
||||
model: torch.nn.Module,
|
||||
optimizer: torch.optim.Optimizer = None,
|
||||
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Stores the checkpoint to disk. Saves all the training components' parameters or buffers, such as model, optimizer,
|
||||
lr_scheduler etc. into a checkpoint dictionary.
|
||||
|
||||
@@ -196,8 +203,11 @@ def broadcast_model(model: torch.nn.Module):
|
||||
src_rank = gpc.get_ranks_in_group(ParallelMode.TENSOR)[0]
|
||||
for p in model.parameters():
|
||||
if not getattr(p, IS_TENSOR_PARALLEL, False) and p.storage().size() > 0:
|
||||
group = gpc.get_group(ParallelMode.TENSOR) if p.device.type == 'cuda' else gpc.get_cpu_group(
|
||||
ParallelMode.TENSOR)
|
||||
group = (
|
||||
gpc.get_group(ParallelMode.TENSOR)
|
||||
if p.device.type == "cuda"
|
||||
else gpc.get_cpu_group(ParallelMode.TENSOR)
|
||||
)
|
||||
dist.broadcast(p, src_rank, group=group)
|
||||
|
||||
|
||||
@@ -226,8 +236,9 @@ def load_checkpoint(
|
||||
Raises:
|
||||
RuntimeError: Raise error if the model/optimizer cannot successfully be recuperated
|
||||
"""
|
||||
state_dict = (torch.load(file, map_location=torch.device("cpu"))
|
||||
if gpc.get_local_rank(ParallelMode.MODEL) == 0 else None)
|
||||
state_dict = (
|
||||
torch.load(file, map_location=torch.device("cpu")) if gpc.get_local_rank(ParallelMode.MODEL) == 0 else None
|
||||
)
|
||||
|
||||
# model states
|
||||
model_state = state_dict.pop("model") if state_dict is not None else dict()
|
||||
@@ -246,8 +257,11 @@ def load_checkpoint(
|
||||
dist.gather_object(error_msgs, all_error_msgs, dst=dst_rank, group=gpc.get_cpu_group(ParallelMode.MODEL))
|
||||
if gpc.get_global_rank() == 0:
|
||||
all_error_msgs = list(chain.from_iterable(all_error_msgs))
|
||||
raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format(
|
||||
model.__class__.__name__, "\n\t".join(all_error_msgs)))
|
||||
raise RuntimeError(
|
||||
"Error(s) in loading state_dict for {}:\n\t{}".format(
|
||||
model.__class__.__name__, "\n\t".join(all_error_msgs)
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise e
|
||||
|
||||
|
@@ -80,7 +80,6 @@ def is_using_sequence():
|
||||
|
||||
|
||||
class model_branch_context(object):
|
||||
|
||||
def __enter__(self):
|
||||
self.env_status = env.save()
|
||||
|
||||
@@ -98,16 +97,14 @@ def _calc_l2_norm(grads):
|
||||
|
||||
if fused_optim is None:
|
||||
from colossalai.kernel.op_builder import FusedOptimBuilder
|
||||
|
||||
fused_optim = FusedOptimBuilder().load()
|
||||
|
||||
norm = 0.0
|
||||
if len(grads) > 0:
|
||||
dummy_overflow_buf = torch.cuda.IntTensor([0])
|
||||
norm, _ = multi_tensor_applier(
|
||||
fused_optim.multi_tensor_l2norm,
|
||||
dummy_overflow_buf,
|
||||
[grads],
|
||||
False # no per-parameter norm
|
||||
fused_optim.multi_tensor_l2norm, dummy_overflow_buf, [grads], False # no per-parameter norm
|
||||
)
|
||||
return norm
|
||||
|
||||
@@ -121,7 +118,7 @@ def _calc_lp(grads, norm_type):
|
||||
|
||||
|
||||
def _move_norm_to_cuda(norm: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]:
|
||||
if torch.is_tensor(norm) and norm.device.type != 'cuda':
|
||||
if torch.is_tensor(norm) and norm.device.type != "cuda":
|
||||
norm = norm.to(torch.cuda.current_device())
|
||||
return norm
|
||||
|
||||
@@ -141,11 +138,11 @@ def _compute_local_lp(params: List[ColoParameter], norm_type: float) -> float:
|
||||
if len(params) == 0:
|
||||
return 0.0
|
||||
grads = [p.grad for p in params]
|
||||
use_cuda_kernel = grads[0].device.type == 'cuda'
|
||||
use_cuda_kernel = grads[0].device.type == "cuda"
|
||||
if norm_type == inf:
|
||||
local_lp = max([g.abs().max() for g in grads])
|
||||
elif norm_type == 2.0 and use_cuda_kernel:
|
||||
local_lp = _calc_l2_norm(grads)**norm_type
|
||||
local_lp = _calc_l2_norm(grads) ** norm_type
|
||||
else:
|
||||
local_lp = _calc_lp(grads, norm_type)
|
||||
if isinstance(local_lp, torch.Tensor):
|
||||
@@ -202,8 +199,8 @@ def _compute_grad_lp(parameters, norm_type: float = 2.0) -> float:
|
||||
assert isinstance(p, ColoParameter)
|
||||
if grad_dtype is None:
|
||||
grad_dtype = p.grad.dtype
|
||||
assert p.grad.dtype == grad_dtype, f'Expected all grads are {grad_dtype}, got {p.grad.dtype}'
|
||||
if p.grad.device.type == 'cuda':
|
||||
assert p.grad.dtype == grad_dtype, f"Expected all grads are {grad_dtype}, got {p.grad.dtype}"
|
||||
if p.grad.device.type == "cuda":
|
||||
cuda_grad_params.append(p)
|
||||
else:
|
||||
cpu_grad_params.append(p)
|
||||
@@ -221,7 +218,7 @@ def compute_grad_norm(parameters, norm_type: float = 2.0) -> float:
|
||||
norm_type = float(norm_type)
|
||||
total_norm = _compute_grad_lp(parameters, norm_type)
|
||||
if norm_type != inf:
|
||||
total_norm = total_norm**(1 / norm_type)
|
||||
total_norm = total_norm ** (1 / norm_type)
|
||||
return total_norm
|
||||
|
||||
|
||||
@@ -235,14 +232,15 @@ def _clip_grad_norm(parameters, max_norm: float, total_norm: float) -> None:
|
||||
for p in parameters:
|
||||
if p.grad is None:
|
||||
continue
|
||||
if p.grad.device.type == 'cuda':
|
||||
if p.grad.device.type == "cuda":
|
||||
cuda_grads.append(p.grad.detach())
|
||||
else:
|
||||
cpu_grads.append(p.grad.detach())
|
||||
if len(cuda_grads) > 0:
|
||||
dummy_overflow_buf = torch.cuda.IntTensor([0])
|
||||
multi_tensor_applier(fused_optim.multi_tensor_scale, dummy_overflow_buf, [cuda_grads, cuda_grads],
|
||||
clip_coef)
|
||||
multi_tensor_applier(
|
||||
fused_optim.multi_tensor_scale, dummy_overflow_buf, [cuda_grads, cuda_grads], clip_coef
|
||||
)
|
||||
for g in cpu_grads:
|
||||
g.mul_(clip_coef)
|
||||
|
||||
@@ -284,16 +282,17 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
|
||||
for param in parameters:
|
||||
if param.grad is not None:
|
||||
# Make sure the grads are in fp32
|
||||
assert param.grad.dtype == torch.float, \
|
||||
f'expected gradient to be dtype torch.float, but got {param.grad.type()}'
|
||||
if hasattr(param, 'colo_attr') and param.colo_attr.sharded_data_tensor.is_sharded:
|
||||
assert (
|
||||
param.grad.dtype == torch.float
|
||||
), f"expected gradient to be dtype torch.float, but got {param.grad.type()}"
|
||||
if hasattr(param, "colo_attr") and param.colo_attr.sharded_data_tensor.is_sharded:
|
||||
has_zero_shared_param = True
|
||||
params.append(param)
|
||||
|
||||
if len(params) == 0:
|
||||
enable_cuda_kernels = False
|
||||
else:
|
||||
enable_cuda_kernels = params[0].grad.device.type == 'cuda'
|
||||
enable_cuda_kernels = params[0].grad.device.type == "cuda"
|
||||
# Norm parameters.
|
||||
max_norm = float(max_norm)
|
||||
norm_type = float(norm_type)
|
||||
@@ -307,15 +306,13 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
|
||||
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
|
||||
# Take max across all model-parallel GPUs.
|
||||
if gpc.is_initialized(ParallelMode.MODEL) and gpc.get_world_size(ParallelMode.MODEL) > 1:
|
||||
dist.all_reduce(total_norm_cuda,
|
||||
op=dist.ReduceOp.MAX,
|
||||
group=gpc.get_group(ParallelMode.MODEL),
|
||||
async_op=False)
|
||||
dist.all_reduce(
|
||||
total_norm_cuda, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MODEL), async_op=False
|
||||
)
|
||||
if has_zero_shared_param:
|
||||
dist.all_reduce(total_norm_cuda,
|
||||
op=dist.ReduceOp.MAX,
|
||||
group=gpc.get_group(ParallelMode.DATA),
|
||||
async_op=False)
|
||||
dist.all_reduce(
|
||||
total_norm_cuda, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.DATA), async_op=False
|
||||
)
|
||||
total_norm = total_norm_cuda[0].item()
|
||||
else:
|
||||
tensor_parallel_grads = []
|
||||
@@ -323,17 +320,17 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
|
||||
zero_sharded_grads = []
|
||||
for p in params:
|
||||
if is_model_parallel_parameter(p):
|
||||
reductor = (gpc.get_world_size(ParallelMode.TENSOR) / getattr(p, NUM_PARTITIONS))**(1 / norm_type)
|
||||
reductor = (gpc.get_world_size(ParallelMode.TENSOR) / getattr(p, NUM_PARTITIONS)) ** (1 / norm_type)
|
||||
tensor_parallel_grads.append(p.grad.data / reductor)
|
||||
elif hasattr(p, 'colo_attr') and p.colo_attr.sharded_data_tensor.is_sharded:
|
||||
elif hasattr(p, "colo_attr") and p.colo_attr.sharded_data_tensor.is_sharded:
|
||||
zero_sharded_grads.append(p.grad.data)
|
||||
else:
|
||||
no_tensor_parallel_grads.append(p.grad.data)
|
||||
|
||||
if norm_type == 2.0 and enable_cuda_kernels:
|
||||
tensor_parallel_norm = _calc_l2_norm(tensor_parallel_grads)**norm_type
|
||||
no_tensor_parallel_norm = _calc_l2_norm(no_tensor_parallel_grads)**norm_type
|
||||
zero_sharded_norm = _calc_l2_norm(zero_sharded_grads)**norm_type
|
||||
tensor_parallel_norm = _calc_l2_norm(tensor_parallel_grads) ** norm_type
|
||||
no_tensor_parallel_norm = _calc_l2_norm(no_tensor_parallel_grads) ** norm_type
|
||||
zero_sharded_norm = _calc_l2_norm(zero_sharded_grads) ** norm_type
|
||||
else:
|
||||
tensor_parallel_norm = _calc_lp(tensor_parallel_grads, norm_type)
|
||||
no_tensor_parallel_norm = _calc_lp(no_tensor_parallel_grads, norm_type)
|
||||
@@ -358,7 +355,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
|
||||
total_norm = tensor_parallel_norm + no_tensor_parallel_norm
|
||||
if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
|
||||
dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.PIPELINE))
|
||||
total_norm = total_norm**(1.0 / norm_type)
|
||||
total_norm = total_norm ** (1.0 / norm_type)
|
||||
if torch.is_tensor(total_norm):
|
||||
total_norm = total_norm.item()
|
||||
|
||||
@@ -397,13 +394,14 @@ def count_zeros_fp32(parameters):
|
||||
# Sum across all model-parallel GPUs.
|
||||
ops = []
|
||||
ops.append(
|
||||
dist.all_reduce(total_num_zeros, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR), async_op=True))
|
||||
dist.all_reduce(total_num_zeros, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR), async_op=True)
|
||||
)
|
||||
if gpc.is_initialized(ParallelMode.PIPELINE):
|
||||
ops.append(
|
||||
dist.all_reduce(total_num_zeros,
|
||||
op=dist.ReduceOp.SUM,
|
||||
group=gpc.get_group(ParallelMode.PIPELINE),
|
||||
async_op=True))
|
||||
dist.all_reduce(
|
||||
total_num_zeros, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.PIPELINE), async_op=True
|
||||
)
|
||||
)
|
||||
|
||||
for req in ops:
|
||||
req.wait()
|
||||
@@ -420,8 +418,9 @@ def copy_tensor_parallel_attributes(src_tensor, dst_tensor):
|
||||
|
||||
|
||||
def param_is_not_tensor_parallel_duplicate(param):
|
||||
return (hasattr(param, IS_TENSOR_PARALLEL) and getattr(param, IS_TENSOR_PARALLEL)) or (gpc.get_local_rank(
|
||||
ParallelMode.TENSOR) == 0)
|
||||
return (hasattr(param, IS_TENSOR_PARALLEL) and getattr(param, IS_TENSOR_PARALLEL)) or (
|
||||
gpc.get_local_rank(ParallelMode.TENSOR) == 0
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
|
@@ -1,4 +1,4 @@
|
||||
from .base_sampler import BaseSampler
|
||||
from .data_parallel_sampler import DataParallelSampler, get_dataloader
|
||||
|
||||
__all__ = ['BaseSampler', 'DataParallelSampler', 'get_dataloader']
|
||||
__all__ = ["BaseSampler", "DataParallelSampler", "get_dataloader"]
|
||||
|
@@ -5,7 +5,6 @@ from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class BaseSampler(ABC):
|
||||
|
||||
def __init__(self, dataset, batch_size):
|
||||
self.dataset = dataset
|
||||
self.batch_size = batch_size
|
||||
|
@@ -13,7 +13,7 @@ from torch.utils.data import DataLoader, Dataset, Sampler
|
||||
from colossalai.legacy.context.parallel_mode import ParallelMode
|
||||
from colossalai.legacy.core import global_context as gpc
|
||||
|
||||
T_co = TypeVar('T_co', covariant=True)
|
||||
T_co = TypeVar("T_co", covariant=True)
|
||||
|
||||
|
||||
class DataParallelSampler(Sampler):
|
||||
@@ -44,11 +44,11 @@ class DataParallelSampler(Sampler):
|
||||
self.num_samples = math.ceil(
|
||||
# `type:ignore` is required because Dataset cannot provide a default __len__
|
||||
# see NOTE in pytorch/torch/utils/data/sampler.py
|
||||
(len(self.dataset) - self.num_replicas) / \
|
||||
self.num_replicas # type: ignore[arg-type]
|
||||
(len(self.dataset) - self.num_replicas)
|
||||
/ self.num_replicas # type: ignore[arg-type]
|
||||
)
|
||||
else:
|
||||
self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type]
|
||||
self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type]
|
||||
self.total_size = self.num_samples * self.num_replicas
|
||||
self.shuffle = shuffle
|
||||
self.seed = seed
|
||||
@@ -65,7 +65,7 @@ class DataParallelSampler(Sampler):
|
||||
# set_epoch manually
|
||||
self.epoch += 1
|
||||
else:
|
||||
indices = list(range(len(self.dataset))) # type: ignore[arg-type]
|
||||
indices = list(range(len(self.dataset))) # type: ignore[arg-type]
|
||||
|
||||
if not self.drop_last:
|
||||
# add extra samples to make it evenly divisible
|
||||
@@ -76,11 +76,11 @@ class DataParallelSampler(Sampler):
|
||||
indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
|
||||
else:
|
||||
# remove tail of data to make it evenly divisible.
|
||||
indices = indices[:self.total_size]
|
||||
indices = indices[: self.total_size]
|
||||
assert len(indices) == self.total_size
|
||||
|
||||
# subsample
|
||||
indices = indices[self.rank:self.total_size:self.num_replicas]
|
||||
indices = indices[self.rank : self.total_size : self.num_replicas]
|
||||
assert len(indices) == self.num_samples
|
||||
|
||||
return iter(indices)
|
||||
@@ -99,14 +99,9 @@ class DataParallelSampler(Sampler):
|
||||
self.epoch = epoch
|
||||
|
||||
|
||||
def get_dataloader(dataset,
|
||||
shuffle=False,
|
||||
seed=1024,
|
||||
add_sampler=True,
|
||||
drop_last=False,
|
||||
pin_memory=False,
|
||||
num_workers=0,
|
||||
**kwargs):
|
||||
def get_dataloader(
|
||||
dataset, shuffle=False, seed=1024, add_sampler=True, drop_last=False, pin_memory=False, num_workers=0, **kwargs
|
||||
):
|
||||
r"""Set up a deterministic dataloader (also configure seed workers, samplers and whether shuffle or not)
|
||||
|
||||
Note:
|
||||
@@ -144,18 +139,22 @@ def get_dataloader(dataset,
|
||||
random.seed(worker_seed)
|
||||
|
||||
if sampler is None:
|
||||
return DataLoader(dataset,
|
||||
worker_init_fn=seed_worker,
|
||||
shuffle=shuffle,
|
||||
drop_last=drop_last,
|
||||
pin_memory=pin_memory,
|
||||
num_workers=num_workers,
|
||||
**_kwargs)
|
||||
return DataLoader(
|
||||
dataset,
|
||||
worker_init_fn=seed_worker,
|
||||
shuffle=shuffle,
|
||||
drop_last=drop_last,
|
||||
pin_memory=pin_memory,
|
||||
num_workers=num_workers,
|
||||
**_kwargs,
|
||||
)
|
||||
else:
|
||||
return DataLoader(dataset,
|
||||
sampler=sampler,
|
||||
worker_init_fn=seed_worker,
|
||||
drop_last=drop_last,
|
||||
pin_memory=pin_memory,
|
||||
num_workers=num_workers,
|
||||
**_kwargs)
|
||||
return DataLoader(
|
||||
dataset,
|
||||
sampler=sampler,
|
||||
worker_init_fn=seed_worker,
|
||||
drop_last=drop_last,
|
||||
pin_memory=pin_memory,
|
||||
num_workers=num_workers,
|
||||
**_kwargs,
|
||||
)
|
||||
|
@@ -76,8 +76,10 @@ def report_memory_usage(message, logger=None, report_cpu=False):
|
||||
gpu_cached = _bytes_to_MB(torch.cuda.memory_reserved())
|
||||
gpu_max_cached = _bytes_to_MB(torch.cuda.max_memory_reserved())
|
||||
|
||||
full_log = f"{message}: GPU: allocated {gpu_allocated} MB, max allocated {gpu_max_allocated} MB, " \
|
||||
full_log = (
|
||||
f"{message}: GPU: allocated {gpu_allocated} MB, max allocated {gpu_max_allocated} MB, "
|
||||
+ f"cached: {gpu_cached} MB, max cached: {gpu_max_cached} MB"
|
||||
)
|
||||
|
||||
if report_cpu:
|
||||
# python doesn't do real-time garbage collection so do it explicitly to get the correct RAM reports
|
||||
@@ -91,7 +93,7 @@ def report_memory_usage(message, logger=None, report_cpu=False):
|
||||
logger.info(full_log)
|
||||
|
||||
# get the peak memory to report correct data, so reset the counter for the next call
|
||||
if hasattr(torch.cuda, "reset_peak_memory_stats"): # pytorch 1.4+
|
||||
if hasattr(torch.cuda, "reset_peak_memory_stats"): # pytorch 1.4+
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
|
||||
@@ -106,10 +108,10 @@ def colo_device_memory_capacity(device: torch.device) -> int:
|
||||
int: size in byte
|
||||
"""
|
||||
assert isinstance(device, torch.device)
|
||||
if device.type == 'cpu':
|
||||
if device.type == "cpu":
|
||||
# In the context of 1-CPU-N-GPU, the memory capacity of the current process is 1/N overall CPU memory.
|
||||
return colo_get_cpu_memory_capacity() / gpc.num_processes_on_current_node
|
||||
if device.type == 'cuda':
|
||||
if device.type == "cuda":
|
||||
return torch.cuda.get_device_properties(get_current_device()).total_memory * _GLOBAL_CUDA_MEM_FRACTION
|
||||
|
||||
|
||||
@@ -123,16 +125,16 @@ def colo_device_memory_used(device: torch.device) -> int:
|
||||
Returns:
|
||||
int: memory size in bytes
|
||||
"""
|
||||
if device.type == 'cpu':
|
||||
if device.type == "cpu":
|
||||
mem_info = _get_cpu_memory_info()
|
||||
# In the context of 1-CPU-N-GPU, the memory usage of the current process is 1/N CPU memory used.
|
||||
# Each process consumes the same amount of memory.
|
||||
ret = mem_info.used / gpc.num_processes_on_current_node
|
||||
return ret
|
||||
elif device.type == 'cuda':
|
||||
elif device.type == "cuda":
|
||||
ret: int = torch.cuda.memory_allocated(device)
|
||||
# get the peak memory to report correct data, so reset the counter for the next call
|
||||
if hasattr(torch.cuda, "reset_peak_memory_stats"): # pytorch 1.4+
|
||||
if hasattr(torch.cuda, "reset_peak_memory_stats"): # pytorch 1.4+
|
||||
torch.cuda.reset_peak_memory_stats(device)
|
||||
return ret
|
||||
|
||||
@@ -145,9 +147,9 @@ def colo_set_process_memory_fraction(ratio: float) -> None:
|
||||
Args:
|
||||
ratio (float): a ratio between 0. ~ 1.
|
||||
"""
|
||||
if version.parse(torch.__version__) < version.parse('1.8'):
|
||||
logger = get_dist_logger('colo_set_process_memory_fraction')
|
||||
logger.warning('colo_set_process_memory_fraction failed because torch version is less than 1.8')
|
||||
if version.parse(torch.__version__) < version.parse("1.8"):
|
||||
logger = get_dist_logger("colo_set_process_memory_fraction")
|
||||
logger.warning("colo_set_process_memory_fraction failed because torch version is less than 1.8")
|
||||
return
|
||||
global _GLOBAL_CUDA_MEM_FRACTION
|
||||
_GLOBAL_CUDA_MEM_FRACTION = ratio
|
||||
|
@@ -2,7 +2,6 @@ from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class ProfilerExtension(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def prepare_trace(self):
|
||||
pass
|
||||
|
@@ -3,4 +3,4 @@ from .mem_profiler import MemProfiler
|
||||
from .pcie_profiler import PcieProfiler
|
||||
from .prof_utils import BaseProfiler, ProfilerContext
|
||||
|
||||
__all__ = ['BaseProfiler', 'CommProfiler', 'PcieProfiler', 'MemProfiler', 'ProfilerContext']
|
||||
__all__ = ["BaseProfiler", "CommProfiler", "PcieProfiler", "MemProfiler", "ProfilerContext"]
|
||||
|
@@ -20,14 +20,14 @@ def _get_code_location(depth: int):
|
||||
upper_frame = inspect.stack()[i]
|
||||
function_name = inspect.stack()[i - 1].function
|
||||
ret.append(upper_frame.filename)
|
||||
ret.append('(')
|
||||
ret.append("(")
|
||||
ret.append(str(upper_frame.lineno))
|
||||
ret.append('): ')
|
||||
ret.append("): ")
|
||||
ret.append(function_name)
|
||||
if i != length - 1:
|
||||
ret.append('\n')
|
||||
ret.append("\n")
|
||||
|
||||
return ''.join(ret)
|
||||
return "".join(ret)
|
||||
|
||||
|
||||
torch_all_reduce = dist.all_reduce
|
||||
@@ -42,7 +42,7 @@ class CommEvent(object):
|
||||
volume recording.
|
||||
"""
|
||||
|
||||
def __init__(self, count: int = 0, comm_vol: float = 0., cuda_time: int = 0):
|
||||
def __init__(self, count: int = 0, comm_vol: float = 0.0, cuda_time: int = 0):
|
||||
self.self_count = count
|
||||
self.self_comm_vol = comm_vol
|
||||
self.self_cuda_time = cuda_time
|
||||
@@ -54,8 +54,7 @@ class CommEvent(object):
|
||||
|
||||
|
||||
class CommProfiler(BaseProfiler):
|
||||
"""Communication profiler. Records all communication events.
|
||||
"""
|
||||
"""Communication profiler. Records all communication events."""
|
||||
|
||||
def __init__(self, depth: int = 0, total_count: int = 0, total_comm_vol: float = 0, total_cuda_time: int = 0):
|
||||
super().__init__(profiler_name="Collective_Communication", priority=0)
|
||||
@@ -114,8 +113,10 @@ class CommProfiler(BaseProfiler):
|
||||
res.append(sep)
|
||||
|
||||
if self.warn_flag:
|
||||
append("Warning: there exists multiple communication operations in the same time. As a result, "
|
||||
"the profiling result is not accurate.")
|
||||
append(
|
||||
"Warning: there exists multiple communication operations in the same time. As a result, "
|
||||
"the profiling result is not accurate."
|
||||
)
|
||||
|
||||
if self.total_cuda_time == 0:
|
||||
return "No collective communication has been called yet!"
|
||||
@@ -126,24 +127,29 @@ class CommProfiler(BaseProfiler):
|
||||
append("total number of calls: {}".format(self.total_count))
|
||||
append("All events:")
|
||||
|
||||
separation = '-' * 74
|
||||
row_format = '{:^10}' + '{:^12}' * 2 + '{:^16}' + '{:^12}' * 2
|
||||
separation = "-" * 74
|
||||
row_format = "{:^10}" + "{:^12}" * 2 + "{:^16}" + "{:^12}" * 2
|
||||
|
||||
append(separation)
|
||||
append(row_format.format('Location', 'GPU time', 'Percentage', 'Comm volume', 'Bandwidth', 'Num of calls'))
|
||||
append(row_format.format("Location", "GPU time", "Percentage", "Comm volume", "Bandwidth", "Num of calls"))
|
||||
append(separation)
|
||||
|
||||
show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].self_cuda_time)
|
||||
for location, event in show_list:
|
||||
append(location)
|
||||
append(
|
||||
row_format.format('', _format_time(event.self_cuda_time),
|
||||
'{:.1f}%'.format(event.self_cuda_time / self.total_cuda_time * 100.0),
|
||||
_format_memory(event.self_comm_vol),
|
||||
_format_bandwidth(event.self_comm_vol, event.self_cuda_time), event.self_count))
|
||||
row_format.format(
|
||||
"",
|
||||
_format_time(event.self_cuda_time),
|
||||
"{:.1f}%".format(event.self_cuda_time / self.total_cuda_time * 100.0),
|
||||
_format_memory(event.self_comm_vol),
|
||||
_format_bandwidth(event.self_comm_vol, event.self_cuda_time),
|
||||
event.self_count,
|
||||
)
|
||||
)
|
||||
append()
|
||||
|
||||
return ''.join(res)
|
||||
return "".join(res)
|
||||
|
||||
@property
|
||||
def has_aync_op(self):
|
||||
@@ -195,8 +201,7 @@ class CommProfiler(BaseProfiler):
|
||||
|
||||
|
||||
class CommHandler(object):
|
||||
"""Communication handler. A dummy handler to wait aync operations.
|
||||
"""
|
||||
"""Communication handler. A dummy handler to wait aync operations."""
|
||||
|
||||
def __init__(self, profiler: CommProfiler):
|
||||
super().__init__()
|
||||
@@ -212,11 +217,9 @@ def async_check(profiler: CommProfiler):
|
||||
profiler.wait_async_op()
|
||||
|
||||
|
||||
def all_reduce(tensor: torch.Tensor,
|
||||
op: ReduceOp = ReduceOp.SUM,
|
||||
group=None,
|
||||
async_op: bool = False,
|
||||
profiler: CommProfiler = None) -> Optional[CommHandler]:
|
||||
def all_reduce(
|
||||
tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, group=None, async_op: bool = False, profiler: CommProfiler = None
|
||||
) -> Optional[CommHandler]:
|
||||
async_check(profiler)
|
||||
|
||||
comm_size = dist.get_world_size(group)
|
||||
@@ -231,12 +234,14 @@ def all_reduce(tensor: torch.Tensor,
|
||||
profiler.close_profiler(group)
|
||||
|
||||
|
||||
def reduce_scatter(output: torch.Tensor,
|
||||
input_list: List[torch.Tensor],
|
||||
op: ReduceOp = ReduceOp.SUM,
|
||||
group=None,
|
||||
async_op: bool = False,
|
||||
profiler: CommProfiler = None) -> Optional[CommHandler]:
|
||||
def reduce_scatter(
|
||||
output: torch.Tensor,
|
||||
input_list: List[torch.Tensor],
|
||||
op: ReduceOp = ReduceOp.SUM,
|
||||
group=None,
|
||||
async_op: bool = False,
|
||||
profiler: CommProfiler = None,
|
||||
) -> Optional[CommHandler]:
|
||||
async_check(profiler)
|
||||
|
||||
comm_size = dist.get_world_size(group)
|
||||
@@ -254,11 +259,13 @@ def reduce_scatter(output: torch.Tensor,
|
||||
profiler.close_profiler(group)
|
||||
|
||||
|
||||
def all_gather(tensor_list: List[torch.Tensor],
|
||||
tensor: torch.Tensor,
|
||||
group=None,
|
||||
async_op: bool = False,
|
||||
profiler: CommProfiler = None) -> Optional[CommHandler]:
|
||||
def all_gather(
|
||||
tensor_list: List[torch.Tensor],
|
||||
tensor: torch.Tensor,
|
||||
group=None,
|
||||
async_op: bool = False,
|
||||
profiler: CommProfiler = None,
|
||||
) -> Optional[CommHandler]:
|
||||
async_check(profiler)
|
||||
|
||||
comm_size = dist.get_world_size(group)
|
||||
@@ -276,11 +283,9 @@ def all_gather(tensor_list: List[torch.Tensor],
|
||||
profiler.close_profiler(group)
|
||||
|
||||
|
||||
def broadcast(tensor: torch.Tensor,
|
||||
src: int,
|
||||
group=None,
|
||||
async_op: bool = False,
|
||||
profiler: CommProfiler = None) -> Optional[CommHandler]:
|
||||
def broadcast(
|
||||
tensor: torch.Tensor, src: int, group=None, async_op: bool = False, profiler: CommProfiler = None
|
||||
) -> Optional[CommHandler]:
|
||||
async_check(profiler)
|
||||
|
||||
comm_vol = 1.0 * tensor.element_size() * tensor.numel()
|
||||
@@ -293,12 +298,14 @@ def broadcast(tensor: torch.Tensor,
|
||||
profiler.close_profiler(group)
|
||||
|
||||
|
||||
def reduce(tensor: torch.Tensor,
|
||||
dst: int,
|
||||
op: ReduceOp = ReduceOp.SUM,
|
||||
group=None,
|
||||
async_op: bool = False,
|
||||
profiler: CommProfiler = None) -> Optional[CommHandler]:
|
||||
def reduce(
|
||||
tensor: torch.Tensor,
|
||||
dst: int,
|
||||
op: ReduceOp = ReduceOp.SUM,
|
||||
group=None,
|
||||
async_op: bool = False,
|
||||
profiler: CommProfiler = None,
|
||||
) -> Optional[CommHandler]:
|
||||
async_check(profiler)
|
||||
|
||||
comm_vol = 1.0 * tensor.element_size() * tensor.numel()
|
||||
|
@@ -18,6 +18,7 @@ def _get_size(dtype: str):
|
||||
def _get_numel(my_list: List[int]) -> int:
|
||||
from functools import reduce
|
||||
from operator import mul
|
||||
|
||||
return reduce(mul, my_list)
|
||||
|
||||
|
||||
@@ -27,12 +28,11 @@ def _reduce_location(locations: List[str]) -> str:
|
||||
ret.append(lo)
|
||||
ret.append("\n")
|
||||
ret = ret[:-1]
|
||||
return ''.join(ret)
|
||||
return "".join(ret)
|
||||
|
||||
|
||||
class PcieEvent(object):
|
||||
"""Pcie Event.
|
||||
"""
|
||||
"""Pcie Event."""
|
||||
|
||||
def __init__(self, count: int = 0, pcie_vol: int = 0, cuda_time: int = 0):
|
||||
self.count = count
|
||||
@@ -73,12 +73,9 @@ class PcieProfiler(BaseProfiler):
|
||||
self.profiler = None
|
||||
|
||||
def enable(self):
|
||||
self.profiler = profile(enabled=True,
|
||||
use_cuda=True,
|
||||
use_cpu=True,
|
||||
use_kineto=True,
|
||||
record_shapes=True,
|
||||
with_stack=True)
|
||||
self.profiler = profile(
|
||||
enabled=True, use_cuda=True, use_cpu=True, use_kineto=True, record_shapes=True, with_stack=True
|
||||
)
|
||||
self.profiler.__enter__()
|
||||
|
||||
def disable(self):
|
||||
@@ -92,15 +89,15 @@ class PcieProfiler(BaseProfiler):
|
||||
if len(t_shape) == 0 or event.cuda_time_total == 0 or len(event.stack) == 0:
|
||||
continue
|
||||
current_comm_event = PcieEvent(1, self.data_size * _get_numel(t_shape), event.cuda_time_total)
|
||||
code_location = _reduce_location(event.stack[:self.depth])
|
||||
code_location = _reduce_location(event.stack[: self.depth])
|
||||
if code_location in self.ops_record:
|
||||
self.ops_record[code_location].add(current_comm_event)
|
||||
else:
|
||||
self.ops_record[code_location] = current_comm_event
|
||||
elif 'Memcpy HtoD' in event.name:
|
||||
elif "Memcpy HtoD" in event.name:
|
||||
self.h2d_count += 1
|
||||
self.h2d_time += event.cuda_time_total
|
||||
elif 'Memcpy DtoH' in event.name:
|
||||
elif "Memcpy DtoH" in event.name:
|
||||
self.d2h_count += 1
|
||||
self.d2h_time += event.cuda_time_total
|
||||
|
||||
@@ -132,19 +129,25 @@ class PcieProfiler(BaseProfiler):
|
||||
|
||||
append("Possible data transmission events in PCIE:")
|
||||
|
||||
separation = '-' * 62
|
||||
row_format = '{:^10}' + '{:^12}' + '{:^16}' + '{:^12}' * 2
|
||||
separation = "-" * 62
|
||||
row_format = "{:^10}" + "{:^12}" + "{:^16}" + "{:^12}" * 2
|
||||
|
||||
append(separation)
|
||||
append(row_format.format('Location', 'GPU time', 'Trans volume', 'Bandwidth', 'Num of calls'))
|
||||
append(row_format.format("Location", "GPU time", "Trans volume", "Bandwidth", "Num of calls"))
|
||||
append(separation)
|
||||
|
||||
show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].cuda_time)
|
||||
for location, event in show_list:
|
||||
append(location)
|
||||
append(
|
||||
row_format.format('', _format_time(event.cuda_time), _format_memory(event.pcie_vol),
|
||||
_format_bandwidth(event.pcie_vol, event.cuda_time), event.count))
|
||||
row_format.format(
|
||||
"",
|
||||
_format_time(event.cuda_time),
|
||||
_format_memory(event.pcie_vol),
|
||||
_format_bandwidth(event.pcie_vol, event.cuda_time),
|
||||
event.count,
|
||||
)
|
||||
)
|
||||
append()
|
||||
|
||||
return ''.join(res)
|
||||
return "".join(res)
|
||||
|
@@ -11,10 +11,10 @@ def _format_time(time_us):
|
||||
US_IN_SECOND = 1000.0 * 1000.0
|
||||
US_IN_MS = 1000.0
|
||||
if time_us >= US_IN_SECOND:
|
||||
return '{:.3f}s'.format(time_us / US_IN_SECOND)
|
||||
return "{:.3f}s".format(time_us / US_IN_SECOND)
|
||||
if time_us >= US_IN_MS:
|
||||
return '{:.3f}ms'.format(time_us / US_IN_MS)
|
||||
return '{:.3f}us'.format(time_us)
|
||||
return "{:.3f}ms".format(time_us / US_IN_MS)
|
||||
return "{:.3f}us".format(time_us)
|
||||
|
||||
|
||||
# copied from high version pytorch to support low version
|
||||
@@ -23,28 +23,27 @@ def _format_memory(nbytes):
|
||||
KB = 1024
|
||||
MB = 1024 * KB
|
||||
GB = 1024 * MB
|
||||
if (abs(nbytes) >= GB):
|
||||
return '{:.2f} GB'.format(nbytes * 1.0 / GB)
|
||||
elif (abs(nbytes) >= MB):
|
||||
return '{:.2f} MB'.format(nbytes * 1.0 / MB)
|
||||
elif (abs(nbytes) >= KB):
|
||||
return '{:.2f} KB'.format(nbytes * 1.0 / KB)
|
||||
if abs(nbytes) >= GB:
|
||||
return "{:.2f} GB".format(nbytes * 1.0 / GB)
|
||||
elif abs(nbytes) >= MB:
|
||||
return "{:.2f} MB".format(nbytes * 1.0 / MB)
|
||||
elif abs(nbytes) >= KB:
|
||||
return "{:.2f} KB".format(nbytes * 1.0 / KB)
|
||||
else:
|
||||
return str(nbytes) + ' B'
|
||||
return str(nbytes) + " B"
|
||||
|
||||
|
||||
def _format_bandwidth(volume: float or int, time_us: int):
|
||||
sec_div_mb = (1000.0 / 1024.0)**2
|
||||
sec_div_mb = (1000.0 / 1024.0) ** 2
|
||||
mb_per_sec = volume / time_us * sec_div_mb
|
||||
|
||||
if mb_per_sec >= 1024.0:
|
||||
return '{:.3f} GB/s'.format(mb_per_sec / 1024.0)
|
||||
return "{:.3f} GB/s".format(mb_per_sec / 1024.0)
|
||||
else:
|
||||
return '{:.3f} MB/s'.format(mb_per_sec)
|
||||
return "{:.3f} MB/s".format(mb_per_sec)
|
||||
|
||||
|
||||
class BaseProfiler(ABC):
|
||||
|
||||
def __init__(self, profiler_name: str, priority: int):
|
||||
self.name = profiler_name
|
||||
self.priority = priority
|
||||
@@ -111,8 +110,9 @@ class ProfilerContext(object):
|
||||
def to_tensorboard(self, writer):
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
assert isinstance(writer, SummaryWriter), \
|
||||
f'torch.utils.tensorboard.SummaryWriter is required, but found {type(writer)}.'
|
||||
assert isinstance(
|
||||
writer, SummaryWriter
|
||||
), f"torch.utils.tensorboard.SummaryWriter is required, but found {type(writer)}."
|
||||
|
||||
for prof in self.profilers:
|
||||
prof.to_tensorboard(writer)
|
||||
@@ -124,7 +124,7 @@ class ProfilerContext(object):
|
||||
if not log_dir.exists():
|
||||
log_dir.mkdir(parents=True, exist_ok=True)
|
||||
for prof in self.profilers:
|
||||
log_file = log_dir.joinpath(f'{prof.name}_rank_{gpc.get_global_rank()}.log')
|
||||
log_file = log_dir.joinpath(f"{prof.name}_rank_{gpc.get_global_rank()}.log")
|
||||
prof.to_file(log_file)
|
||||
|
||||
def show(self):
|
||||
|
@@ -120,26 +120,30 @@ class profile(torch_profile):
|
||||
p.step()
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
activities: Optional[Iterable[ProfilerActivity]] = None,
|
||||
schedule: Optional[Callable[[int], ProfilerAction]] = None,
|
||||
on_trace_ready: Optional[Callable[..., Any]] = None,
|
||||
engine: Optional[Engine] = None,
|
||||
record_shapes: bool = False,
|
||||
profile_memory: bool = False,
|
||||
with_stack: bool = False,
|
||||
with_flops: bool = False,
|
||||
with_modules: bool = False,
|
||||
profile_stateful_tensor_memory: bool = False) -> None:
|
||||
super().__init__(activities=activities,
|
||||
schedule=schedule,
|
||||
on_trace_ready=on_trace_ready,
|
||||
record_shapes=record_shapes,
|
||||
profile_memory=profile_memory,
|
||||
with_stack=with_stack,
|
||||
with_flops=with_flops,
|
||||
with_modules=with_modules)
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
activities: Optional[Iterable[ProfilerActivity]] = None,
|
||||
schedule: Optional[Callable[[int], ProfilerAction]] = None,
|
||||
on_trace_ready: Optional[Callable[..., Any]] = None,
|
||||
engine: Optional[Engine] = None,
|
||||
record_shapes: bool = False,
|
||||
profile_memory: bool = False,
|
||||
with_stack: bool = False,
|
||||
with_flops: bool = False,
|
||||
with_modules: bool = False,
|
||||
profile_stateful_tensor_memory: bool = False,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
activities=activities,
|
||||
schedule=schedule,
|
||||
on_trace_ready=on_trace_ready,
|
||||
record_shapes=record_shapes,
|
||||
profile_memory=profile_memory,
|
||||
with_stack=with_stack,
|
||||
with_flops=with_flops,
|
||||
with_modules=with_modules,
|
||||
)
|
||||
self._logger = get_dist_logger()
|
||||
self.extentions: List[ProfilerExtension] = []
|
||||
if profile_stateful_tensor_memory:
|
||||
@@ -149,9 +153,9 @@ class profile(torch_profile):
|
||||
self.extentions.append(StatefulTensorMemoryProfilerExtention(engine))
|
||||
|
||||
def prepare_trace(self) -> None:
|
||||
if hasattr(super(), 'prepare_trace'):
|
||||
if hasattr(super(), "prepare_trace"):
|
||||
super().prepare_trace()
|
||||
elif hasattr(super(), '_start_warmup'):
|
||||
elif hasattr(super(), "_start_warmup"):
|
||||
super()._start_warmup()
|
||||
for ext in self.extentions:
|
||||
ext.prepare_trace()
|
||||
@@ -160,9 +164,9 @@ class profile(torch_profile):
|
||||
self.prepare_trace()
|
||||
|
||||
def start_trace(self):
|
||||
if hasattr(super(), '_start_trace'):
|
||||
if hasattr(super(), "_start_trace"):
|
||||
super()._start_trace()
|
||||
elif hasattr(super(), 'start_trace'):
|
||||
elif hasattr(super(), "start_trace"):
|
||||
super().start_trace()
|
||||
for ext in self.extentions:
|
||||
ext.start_trace()
|
||||
@@ -171,9 +175,9 @@ class profile(torch_profile):
|
||||
self.start_trace()
|
||||
|
||||
def stop_trace(self):
|
||||
if hasattr(super(), '_stop_trace'):
|
||||
if hasattr(super(), "_stop_trace"):
|
||||
super()._stop_trace()
|
||||
elif hasattr(super(), 'stop_trace'):
|
||||
elif hasattr(super(), "stop_trace"):
|
||||
super().stop_trace()
|
||||
for ext in self.extentions:
|
||||
ext.stop_trace()
|
||||
@@ -186,15 +190,15 @@ class profile(torch_profile):
|
||||
Exports the collected trace in Chrome JSON format.
|
||||
"""
|
||||
assert self.profiler
|
||||
fp = tempfile.NamedTemporaryFile('w+t', suffix='.json', delete=False)
|
||||
fp = tempfile.NamedTemporaryFile("w+t", suffix=".json", delete=False)
|
||||
fp.close()
|
||||
retvalue = self.profiler.export_chrome_trace(fp.name)
|
||||
with open(fp.name) as fin:
|
||||
trace = json.load(fin)
|
||||
for ext in self.extentions:
|
||||
trace = ext.extend_chrome_trace(trace)
|
||||
open_func = gzip.open if path.endswith('.gz') else open
|
||||
with open_func(path, 'wt') as fout:
|
||||
open_func = gzip.open if path.endswith(".gz") else open
|
||||
with open_func(path, "wt") as fout:
|
||||
json.dump(trace, fout)
|
||||
|
||||
os.remove(fp.name)
|
||||
|
@@ -22,11 +22,11 @@ def get_timestamp_us():
|
||||
|
||||
|
||||
def generic_instant_event(name, pid, tid, timestamp, args):
|
||||
return {'ph': 'i', 's': 't', 'name': name, 'pid': pid, 'tid': tid, 'ts': timestamp, 'args': args}
|
||||
return {"ph": "i", "s": "t", "name": name, "pid": pid, "tid": tid, "ts": timestamp, "args": args}
|
||||
|
||||
|
||||
class StatefulTensorMemoryEvent:
|
||||
EVENT_NAME = '[statefulTensorMemory]'
|
||||
EVENT_NAME = "[statefulTensorMemory]"
|
||||
|
||||
def __init__(self, timestamp: int, device_type: DeviceType, bytes_: int) -> None:
|
||||
self.pid = os.getpid()
|
||||
@@ -37,22 +37,23 @@ class StatefulTensorMemoryEvent:
|
||||
self.bytes = bytes_
|
||||
|
||||
def state_dict(self):
|
||||
return generic_instant_event(StatefulTensorMemoryEvent.EVENT_NAME, self.pid, self.tid, self.timestamp, {
|
||||
'Device Type': self.device_type.value,
|
||||
'Device Id': self.device_id,
|
||||
'Bytes': self.bytes
|
||||
})
|
||||
return generic_instant_event(
|
||||
StatefulTensorMemoryEvent.EVENT_NAME,
|
||||
self.pid,
|
||||
self.tid,
|
||||
self.timestamp,
|
||||
{"Device Type": self.device_type.value, "Device Id": self.device_id, "Bytes": self.bytes},
|
||||
)
|
||||
|
||||
|
||||
class StatefulTensorMemoryTracer:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.events: List[StatefulTensorMemoryEvent] = []
|
||||
self._tracing = False
|
||||
|
||||
def sample(self):
|
||||
cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda']
|
||||
cpu_mem = StatefulTensor.GST_MGR.total_mem['cpu']
|
||||
cuda_mem = StatefulTensor.GST_MGR.total_mem["cuda"]
|
||||
cpu_mem = StatefulTensor.GST_MGR.total_mem["cpu"]
|
||||
timestamp = get_timestamp_us()
|
||||
if self._tracing:
|
||||
self.events.append(StatefulTensorMemoryEvent(timestamp, DeviceType.CUDA, cuda_mem))
|
||||
@@ -70,7 +71,6 @@ class StatefulTensorMemoryTracer:
|
||||
|
||||
|
||||
class StatefulTensorMemoryTracerHook(BaseOpHook):
|
||||
|
||||
def __init__(self, tracer: StatefulTensorMemoryTracer):
|
||||
super().__init__()
|
||||
self.tracer = tracer
|
||||
@@ -104,7 +104,6 @@ class StatefulTensorMemoryTracerHook(BaseOpHook):
|
||||
|
||||
|
||||
class StatefulTensorMemoryProfilerExtention(ProfilerExtension):
|
||||
|
||||
def __init__(self, engine: Engine) -> None:
|
||||
self.engine = engine
|
||||
self.tracer = StatefulTensorMemoryTracer()
|
||||
@@ -131,5 +130,5 @@ class StatefulTensorMemoryProfilerExtention(ProfilerExtension):
|
||||
# self.hook_registered = False
|
||||
|
||||
def extend_chrome_trace(self, trace: dict) -> dict:
|
||||
trace['traceEvents'].extend(self.tracer.state_dict())
|
||||
trace["traceEvents"].extend(self.tracer.state_dict())
|
||||
return trace
|
||||
|
Reference in New Issue
Block a user