[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

View File

@@ -26,28 +26,28 @@ from .memory import (
)
__all__ = [
'DataParallelSampler',
'get_dataloader',
'save_checkpoint',
'load_checkpoint',
'colo_device_memory_capacity',
'colo_device_memory_used',
'colo_get_cpu_memory_capacity',
'colo_set_cpu_memory_capacity',
'colo_set_process_memory_fraction',
'report_memory_usage',
'clip_grad_norm_fp32',
'copy_tensor_parallel_attributes',
'count_zeros_fp32',
'is_dp_rank_0',
'is_model_parallel_parameter',
'is_no_pp_or_last_stage',
'is_tp_rank_0',
'is_using_ddp',
'is_using_pp',
'is_using_sequence',
'param_is_not_tensor_parallel_duplicate',
'print_rank_0',
'switch_virtual_pipeline_parallel_rank',
'sync_model_param',
"DataParallelSampler",
"get_dataloader",
"save_checkpoint",
"load_checkpoint",
"colo_device_memory_capacity",
"colo_device_memory_used",
"colo_get_cpu_memory_capacity",
"colo_set_cpu_memory_capacity",
"colo_set_process_memory_fraction",
"report_memory_usage",
"clip_grad_norm_fp32",
"copy_tensor_parallel_attributes",
"count_zeros_fp32",
"is_dp_rank_0",
"is_model_parallel_parameter",
"is_no_pp_or_last_stage",
"is_tp_rank_0",
"is_using_ddp",
"is_using_pp",
"is_using_sequence",
"param_is_not_tensor_parallel_duplicate",
"print_rank_0",
"switch_virtual_pipeline_parallel_rank",
"sync_model_param",
]

View File

@@ -28,7 +28,6 @@ def copy_to_device(obj, device):
class CheckpointFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, run_function, activation_offload=False, *args):
check_backward_validity(args)
@@ -42,7 +41,7 @@ class CheckpointFunction(torch.autograd.Function):
ctx.fwd_seed_states = get_states(copy=True)
ctx.fwd_current_mode = get_current_mode()
if hasattr(torch, 'is_autocast_enabled'):
if hasattr(torch, "is_autocast_enabled"):
ctx.had_autocast_in_fwd = torch.is_autocast_enabled()
else:
ctx.had_autocast_in_fwd = False
@@ -62,7 +61,7 @@ class CheckpointFunction(torch.autograd.Function):
for i, arg in enumerate(args):
if torch.is_tensor(arg):
if activation_offload:
tensor_inputs.append(copy_to_device(arg, 'cpu'))
tensor_inputs.append(copy_to_device(arg, "cpu"))
else:
tensor_inputs.append(arg)
ctx.tensor_indices.append(i)
@@ -79,8 +78,10 @@ class CheckpointFunction(torch.autograd.Function):
@staticmethod
def backward(ctx, *args):
if not torch.autograd._is_checkpoint_valid():
raise RuntimeError("Checkpointing is not compatible with .grad() or when an `inputs` parameter is "
"passed to .backward(). Please use .backward() and do not pass its `inputs` argument.")
raise RuntimeError(
"Checkpointing is not compatible with .grad() or when an `inputs` parameter is "
"passed to .backward(). Please use .backward() and do not pass its `inputs` argument."
)
# Copy the list to avoid modifying original list.
inputs = list(ctx.inputs)
tensor_indices = ctx.tensor_indices
@@ -131,8 +132,7 @@ class CheckpointFunction(torch.autograd.Function):
outputs_with_grad.append(outputs[i])
args_with_grad.append(args[i])
if len(outputs_with_grad) == 0:
raise RuntimeError("none of output has requires_grad=True,"
" this checkpoint() is not necessary")
raise RuntimeError("none of output has requires_grad=True," " this checkpoint() is not necessary")
torch.autograd.backward(outputs_with_grad, args_with_grad)
grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None for inp in detached_inputs)
return (None, None) + grads
@@ -169,7 +169,7 @@ def _checkpoint_without_reentrant(function, activation_offload=False, *args):
fwd_current_mode = get_current_mode()
# check if use autocast
if hasattr(torch, 'is_autocast_enabled'):
if hasattr(torch, "is_autocast_enabled"):
has_autocast_in_fwd = torch.is_autocast_enabled()
else:
has_autocast_in_fwd = False
@@ -179,7 +179,7 @@ def _checkpoint_without_reentrant(function, activation_offload=False, *args):
weak_holder_list = []
# class for weakref.ref
class Holder():
class Holder:
pass
# return a Holder object for later unpack process
@@ -226,19 +226,20 @@ def _checkpoint_without_reentrant(function, activation_offload=False, *args):
# rerun forward, the inner_pack will store all the activations in storage
if has_autocast_in_fwd:
with torch.enable_grad(), \
torch.cuda.amp.autocast(), \
torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack):
with torch.enable_grad(), torch.cuda.amp.autocast(), torch.autograd.graph.saved_tensors_hooks(
inner_pack, inner_unpack
):
_unused = function(*args)
else:
with torch.enable_grad(), \
torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack):
with torch.enable_grad(), torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack):
_unused = function(*args)
if x not in storage:
raise RuntimeError("Attempt to retrieve a tensor saved by autograd multiple times without checkpoint"
" recomputation being triggered in between, this is not currently supported. Please"
" open an issue with details on your use case so that we can prioritize adding this.")
raise RuntimeError(
"Attempt to retrieve a tensor saved by autograd multiple times without checkpoint"
" recomputation being triggered in between, this is not currently supported. Please"
" open an issue with details on your use case so that we can prioritize adding this."
)
return storage[x]
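
Editor's note — for context on the non-reentrant branch reformatted above: it intercepts saved activations through torch.autograd.graph.saved_tensors_hooks. Below is a minimal stand-alone sketch of that hook pattern in plain PyTorch (illustrative function names, not this module's API):

import torch

def pack_to_cpu(tensor):
    # Stash each saved activation on the CPU, remembering its original device.
    return tensor.device, tensor.to("cpu")

def unpack_from_cpu(packed):
    # Move the activation back to its original device when backward needs it.
    device, tensor = packed
    return tensor.to(device)

x = torch.randn(4, 4, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack_to_cpu, unpack_from_cpu):
    y = (x * x).sum()  # the tensor saved for backward goes through pack_to_cpu
y.backward()
print(x.grad.shape)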

View File

@@ -1,3 +1,3 @@
from .module_checkpoint import load_checkpoint, save_checkpoint
__all__ = ['save_checkpoint', 'load_checkpoint']
__all__ = ["save_checkpoint", "load_checkpoint"]

View File

@@ -9,13 +9,15 @@ from colossalai.tensor import ColoTensor
from .utils import gather_tensor, scatter_tensor
def save_checkpoint(path: str,
epoch: int,
model: torch.nn.Module,
optimizer: Optional[OptimizerWrapper] = None,
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
*args,
**kwargs):
def save_checkpoint(
path: str,
epoch: int,
model: torch.nn.Module,
optimizer: Optional[OptimizerWrapper] = None,
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
*args,
**kwargs,
):
"""save_checkpoint
save a model, whose parameters are `ColoTensor`s.
Args:
@@ -30,7 +32,7 @@ def save_checkpoint(path: str,
# save the dist context about the tensors in a new dict, while still maintain the original dict.
for k, v in model_state.items():
if isinstance(v, ColoTensor):
gather_tensor(v) # gather shared tensors to rank0
gather_tensor(v) # gather shared tensors to rank0
# don't recover tensors in rank0, since the dict is only a copy of model
if rank == 0:
@@ -39,10 +41,10 @@ def save_checkpoint(path: str,
if isinstance(v, ColoTensor):
assert v.save_ready
assert v.is_replicate()
delattr(v, 'save_ready')
delattr(v, "save_ready")
# model saving
save_state = {'epoch': epoch, 'model': model_state}
torch.save(save_state, path + '/epoch_{}_model.pth'.format(epoch), *args, **kwargs)
save_state = {"epoch": epoch, "model": model_state}
torch.save(save_state, path + "/epoch_{}_model.pth".format(epoch), *args, **kwargs)
# delete old dicts
del model_state
@@ -52,35 +54,37 @@ def save_checkpoint(path: str,
if optimizer is not None:
mapping = dict()
optim_state = optimizer.state_dict()
for k, v in optim_state['state'].items():
for k, v in optim_state["state"].items():
for n, t in v.items():
if isinstance(t, ColoTensor):
mapping[(k, n)] = t.dist_spec
gather_tensor(t)
if rank == 0:
save_state = {'epoch': epoch, 'optim': optim_state}
torch.save(save_state, path + '/epoch_{}_optim.pth'.format(epoch), *args, **kwargs)
save_state = {"epoch": epoch, "optim": optim_state}
torch.save(save_state, path + "/epoch_{}_optim.pth".format(epoch), *args, **kwargs)
# recover colo tensors in rank0
for k, v in optimizer.state_dict()['state'].items():
for k, v in optimizer.state_dict()["state"].items():
for n, t in v.items():
if isinstance(t, ColoTensor):
assert hasattr(t, 'save_ready')
assert hasattr(t, "save_ready")
t.set_dist_spec(mapping[(k, n)])
delattr(t, 'save_ready')
delattr(t, "save_ready")
del optim_state
del mapping
dist.barrier()
def load_checkpoint(path: str,
epoch: int,
model: torch.nn.Module,
optimizer: Optional[OptimizerWrapper] = None,
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
torch_load_kwargs: Optional[Dict] = None,
load_state_dict_kwargs: Optional[Dict] = None):
def load_checkpoint(
path: str,
epoch: int,
model: torch.nn.Module,
optimizer: Optional[OptimizerWrapper] = None,
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
torch_load_kwargs: Optional[Dict] = None,
load_state_dict_kwargs: Optional[Dict] = None,
):
"""load_checkpoint
load a model, whose parameters are `ColoTensor`s.
Args:
@@ -106,8 +110,8 @@ def load_checkpoint(path: str,
gather_tensor(p)
if rank == 0:
load_state = torch.load(path + '/epoch_{}_model.pth'.format(epoch), **torch_load_kwargs)
model.load_state_dict(load_state['model'], **load_state_dict_kwargs)
load_state = torch.load(path + "/epoch_{}_model.pth".format(epoch), **torch_load_kwargs)
model.load_state_dict(load_state["model"], **load_state_dict_kwargs)
dist.barrier()
# scatter loaded parameters
@@ -115,24 +119,24 @@ def load_checkpoint(path: str,
if isinstance(p, ColoTensor):
scatter_tensor(p, mapping[n])
if rank == 0:
assert hasattr(p, 'save_ready')
delattr(p, 'save_ready')
assert hasattr(p, "save_ready")
delattr(p, "save_ready")
del mapping
if optimizer is not None:
mapping = dict()
for k, v in optimizer.state_dict()['state'].items():
for k, v in optimizer.state_dict()["state"].items():
for n, t in v.items():
if isinstance(t, ColoTensor):
mapping[(k, n)] = t.dist_spec
gather_tensor(t)
if rank == 0:
colo_checkpoint = torch.load(path + '/epoch_{}_optim.pth'.format(epoch), **torch_load_kwargs)
optimizer.load_state_dict(colo_checkpoint['optim'], **load_state_dict_kwargs)
colo_checkpoint = torch.load(path + "/epoch_{}_optim.pth".format(epoch), **torch_load_kwargs)
optimizer.load_state_dict(colo_checkpoint["optim"], **load_state_dict_kwargs)
dist.barrier()
for k, v in optimizer.state_dict()['state'].items():
for k, v in optimizer.state_dict()["state"].items():
for n, t in v.items():
if isinstance(t, ColoTensor):
scatter_tensor(t, mapping[(k, n)])
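
Editor's note — the two functions above handle the distributed part (gathering ColoTensors to rank 0, barriers, scattering back); the files they write follow the plain torch pattern sketched below. A simplified single-process illustration, not the module's API:

import torch
import torch.nn as nn

model = nn.Linear(4, 4)
path, epoch = ".", 10

# Same on-disk layout as save_checkpoint above: {"epoch": ..., "model": state_dict}
# written to <path>/epoch_<N>_model.pth (optimizer state goes to epoch_<N>_optim.pth).
torch.save({"epoch": epoch, "model": model.state_dict()}, f"{path}/epoch_{epoch}_model.pth")

state = torch.load(f"{path}/epoch_{epoch}_model.pth", map_location="cpu")
model.load_state_dict(state["model"])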

View File

@@ -8,7 +8,7 @@ from colossalai.tensor import ColoTensor
def robust_broadcast(tensor):
with torch.no_grad():
is_cpu_ten = tensor.device.type == 'cpu'
is_cpu_ten = tensor.device.type == "cpu"
if is_cpu_ten:
b_data = tensor.cuda()
else:
@@ -21,8 +21,7 @@ def robust_broadcast(tensor):
def gather_tensor(colo_tensor: ColoTensor) -> None:
"""Make colo_tensor replicated when the rank is 0
"""
"""Make colo_tensor replicated when the rank is 0"""
if not colo_tensor.is_replicate():
pg = colo_tensor.get_process_group()
# for the group which contains rank 0
@@ -36,12 +35,11 @@ def gather_tensor(colo_tensor: ColoTensor) -> None:
dist.barrier()
if dist.get_rank() == 0:
setattr(colo_tensor, 'save_ready', True) # set saving signature
setattr(colo_tensor, "save_ready", True) # set saving signature
def scatter_tensor(colo_tensor: ColoTensor, dist_spec: _DistSpec) -> None:
"""Reversal operation of `gather_tensor`.
"""
"""Reversal operation of `gather_tensor`."""
if dist_spec.placement == DistPlacementPattern.REPLICATE:
robust_broadcast(colo_tensor.data)
else:
@@ -57,7 +55,8 @@ def scatter_tensor(colo_tensor: ColoTensor, dist_spec: _DistSpec) -> None:
colo_tensor.set_dist_spec(dist_spec)
else:
rep_tensor = ColoTensor(
entire_data, ColoTensorSpec(pg=colo_tensor.get_process_group(), compute_attr=colo_tensor.compute_spec))
entire_data, ColoTensorSpec(pg=colo_tensor.get_process_group(), compute_attr=colo_tensor.compute_spec)
)
rep_tensor.set_dist_spec(dist_spec)
with torch.no_grad():
colo_tensor.data.copy_(rep_tensor.data)

View File

@@ -11,7 +11,7 @@ from colossalai.legacy.core import global_context as gpc
try:
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
except ImportError:
_EXTRA_STATE_KEY_SUFFIX = '_extra_state'
_EXTRA_STATE_KEY_SUFFIX = "_extra_state"
from .common import is_using_pp
@@ -25,10 +25,9 @@ def broadcast_state_dict(state_dict, parallel_mode):
return state_dict[0]
def partition_tensor_parallel_state_dict(state_dict: OrderedDict,
parallel_mode: ParallelMode,
dims: dict = dict(),
partition_states: dict = dict()):
def partition_tensor_parallel_state_dict(
state_dict: OrderedDict, parallel_mode: ParallelMode, dims: dict = dict(), partition_states: dict = dict()
):
src_rank = gpc.get_ranks_in_group(parallel_mode)[0]
depth = gpc.get_world_size(parallel_mode)
group = gpc.get_cpu_group(parallel_mode)
@@ -65,11 +64,11 @@ def partition_tensor_parallel_state_dict(state_dict: OrderedDict,
def gather_tensor_parallel_state_dict(
state_dict: OrderedDict,
parallel_mode: ParallelMode,
dims: dict = dict(),
partition_states: dict = dict(),
keep_vars: bool = False,
state_dict: OrderedDict,
parallel_mode: ParallelMode,
dims: dict = dict(),
partition_states: dict = dict(),
keep_vars: bool = False,
):
dst_rank = gpc.get_ranks_in_group(parallel_mode)[0]
depth = gpc.get_world_size(parallel_mode)
@@ -138,8 +137,11 @@ def partition_pipeline_parallel_state_dict(model, state_dict):
def gather_pipeline_parallel_state_dict(state_dict):
gathered_states = ([None for _ in range(gpc.get_world_size(ParallelMode.PIPELINE))]
if gpc.get_local_rank(ParallelMode.PIPELINE) == 0 else None)
gathered_states = (
[None for _ in range(gpc.get_world_size(ParallelMode.PIPELINE))]
if gpc.get_local_rank(ParallelMode.PIPELINE) == 0
else None
)
dist.gather_object(
state_dict,
gathered_states,
@@ -147,18 +149,23 @@ def gather_pipeline_parallel_state_dict(state_dict):
group=gpc.get_cpu_group(ParallelMode.PIPELINE),
)
state_dict = (OrderedDict(chain.from_iterable(state.items() for state in gathered_states))
if gpc.get_local_rank(ParallelMode.PIPELINE) == 0 else OrderedDict())
state_dict = (
OrderedDict(chain.from_iterable(state.items() for state in gathered_states))
if gpc.get_local_rank(ParallelMode.PIPELINE) == 0
else OrderedDict()
)
return state_dict
def save_checkpoint(file,
epoch: int,
model: torch.nn.Module,
optimizer: torch.optim.Optimizer = None,
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
**kwargs):
def save_checkpoint(
file,
epoch: int,
model: torch.nn.Module,
optimizer: torch.optim.Optimizer = None,
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
**kwargs,
):
"""Stores the checkpoint to disk. Saves all the training components' parameters or buffers, such as model, optimizer,
lr_scheduler etc. into a checkpoint dictionary.
@@ -196,8 +203,11 @@ def broadcast_model(model: torch.nn.Module):
src_rank = gpc.get_ranks_in_group(ParallelMode.TENSOR)[0]
for p in model.parameters():
if not getattr(p, IS_TENSOR_PARALLEL, False) and p.storage().size() > 0:
group = gpc.get_group(ParallelMode.TENSOR) if p.device.type == 'cuda' else gpc.get_cpu_group(
ParallelMode.TENSOR)
group = (
gpc.get_group(ParallelMode.TENSOR)
if p.device.type == "cuda"
else gpc.get_cpu_group(ParallelMode.TENSOR)
)
dist.broadcast(p, src_rank, group=group)
@@ -226,8 +236,9 @@ def load_checkpoint(
Raises:
RuntimeError: Raise error if the model/optimizer cannot successfully be recuperated
"""
state_dict = (torch.load(file, map_location=torch.device("cpu"))
if gpc.get_local_rank(ParallelMode.MODEL) == 0 else None)
state_dict = (
torch.load(file, map_location=torch.device("cpu")) if gpc.get_local_rank(ParallelMode.MODEL) == 0 else None
)
# model states
model_state = state_dict.pop("model") if state_dict is not None else dict()
@@ -246,8 +257,11 @@ def load_checkpoint(
dist.gather_object(error_msgs, all_error_msgs, dst=dst_rank, group=gpc.get_cpu_group(ParallelMode.MODEL))
if gpc.get_global_rank() == 0:
all_error_msgs = list(chain.from_iterable(all_error_msgs))
raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format(
model.__class__.__name__, "\n\t".join(all_error_msgs)))
raise RuntimeError(
"Error(s) in loading state_dict for {}:\n\t{}".format(
model.__class__.__name__, "\n\t".join(all_error_msgs)
)
)
else:
raise e
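
Editor's note — gather_pipeline_parallel_state_dict above collects one partial state dict per pipeline stage with dist.gather_object and merges them on the first pipeline rank. The merge step in isolation, with illustrative single-process data:

from collections import OrderedDict
from itertools import chain

# One partial state dict per pipeline stage (toy values standing in for tensors).
stage_states = [OrderedDict(a=1, b=2), OrderedDict(c=3)]

merged = OrderedDict(chain.from_iterable(state.items() for state in stage_states))
print(merged)  # OrderedDict([('a', 1), ('b', 2), ('c', 3)])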

View File

@@ -80,7 +80,6 @@ def is_using_sequence():
class model_branch_context(object):
def __enter__(self):
self.env_status = env.save()
@@ -98,16 +97,14 @@ def _calc_l2_norm(grads):
if fused_optim is None:
from colossalai.kernel.op_builder import FusedOptimBuilder
fused_optim = FusedOptimBuilder().load()
norm = 0.0
if len(grads) > 0:
dummy_overflow_buf = torch.cuda.IntTensor([0])
norm, _ = multi_tensor_applier(
fused_optim.multi_tensor_l2norm,
dummy_overflow_buf,
[grads],
False # no per-parameter norm
fused_optim.multi_tensor_l2norm, dummy_overflow_buf, [grads], False # no per-parameter norm
)
return norm
@@ -121,7 +118,7 @@ def _calc_lp(grads, norm_type):
def _move_norm_to_cuda(norm: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]:
if torch.is_tensor(norm) and norm.device.type != 'cuda':
if torch.is_tensor(norm) and norm.device.type != "cuda":
norm = norm.to(torch.cuda.current_device())
return norm
@@ -141,11 +138,11 @@ def _compute_local_lp(params: List[ColoParameter], norm_type: float) -> float:
if len(params) == 0:
return 0.0
grads = [p.grad for p in params]
use_cuda_kernel = grads[0].device.type == 'cuda'
use_cuda_kernel = grads[0].device.type == "cuda"
if norm_type == inf:
local_lp = max([g.abs().max() for g in grads])
elif norm_type == 2.0 and use_cuda_kernel:
local_lp = _calc_l2_norm(grads)**norm_type
local_lp = _calc_l2_norm(grads) ** norm_type
else:
local_lp = _calc_lp(grads, norm_type)
if isinstance(local_lp, torch.Tensor):
@@ -202,8 +199,8 @@ def _compute_grad_lp(parameters, norm_type: float = 2.0) -> float:
assert isinstance(p, ColoParameter)
if grad_dtype is None:
grad_dtype = p.grad.dtype
assert p.grad.dtype == grad_dtype, f'Expected all grads are {grad_dtype}, got {p.grad.dtype}'
if p.grad.device.type == 'cuda':
assert p.grad.dtype == grad_dtype, f"Expected all grads are {grad_dtype}, got {p.grad.dtype}"
if p.grad.device.type == "cuda":
cuda_grad_params.append(p)
else:
cpu_grad_params.append(p)
@@ -221,7 +218,7 @@ def compute_grad_norm(parameters, norm_type: float = 2.0) -> float:
norm_type = float(norm_type)
total_norm = _compute_grad_lp(parameters, norm_type)
if norm_type != inf:
total_norm = total_norm**(1 / norm_type)
total_norm = total_norm ** (1 / norm_type)
return total_norm
@@ -235,14 +232,15 @@ def _clip_grad_norm(parameters, max_norm: float, total_norm: float) -> None:
for p in parameters:
if p.grad is None:
continue
if p.grad.device.type == 'cuda':
if p.grad.device.type == "cuda":
cuda_grads.append(p.grad.detach())
else:
cpu_grads.append(p.grad.detach())
if len(cuda_grads) > 0:
dummy_overflow_buf = torch.cuda.IntTensor([0])
multi_tensor_applier(fused_optim.multi_tensor_scale, dummy_overflow_buf, [cuda_grads, cuda_grads],
clip_coef)
multi_tensor_applier(
fused_optim.multi_tensor_scale, dummy_overflow_buf, [cuda_grads, cuda_grads], clip_coef
)
for g in cpu_grads:
g.mul_(clip_coef)
@@ -284,16 +282,17 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
for param in parameters:
if param.grad is not None:
# Make sure the grads are in fp32
assert param.grad.dtype == torch.float, \
f'expected gradient to be dtype torch.float, but got {param.grad.type()}'
if hasattr(param, 'colo_attr') and param.colo_attr.sharded_data_tensor.is_sharded:
assert (
param.grad.dtype == torch.float
), f"expected gradient to be dtype torch.float, but got {param.grad.type()}"
if hasattr(param, "colo_attr") and param.colo_attr.sharded_data_tensor.is_sharded:
has_zero_shared_param = True
params.append(param)
if len(params) == 0:
enable_cuda_kernels = False
else:
enable_cuda_kernels = params[0].grad.device.type == 'cuda'
enable_cuda_kernels = params[0].grad.device.type == "cuda"
# Norm parameters.
max_norm = float(max_norm)
norm_type = float(norm_type)
@@ -307,15 +306,13 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
# Take max across all model-parallel GPUs.
if gpc.is_initialized(ParallelMode.MODEL) and gpc.get_world_size(ParallelMode.MODEL) > 1:
dist.all_reduce(total_norm_cuda,
op=dist.ReduceOp.MAX,
group=gpc.get_group(ParallelMode.MODEL),
async_op=False)
dist.all_reduce(
total_norm_cuda, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MODEL), async_op=False
)
if has_zero_shared_param:
dist.all_reduce(total_norm_cuda,
op=dist.ReduceOp.MAX,
group=gpc.get_group(ParallelMode.DATA),
async_op=False)
dist.all_reduce(
total_norm_cuda, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.DATA), async_op=False
)
total_norm = total_norm_cuda[0].item()
else:
tensor_parallel_grads = []
@@ -323,17 +320,17 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
zero_sharded_grads = []
for p in params:
if is_model_parallel_parameter(p):
reductor = (gpc.get_world_size(ParallelMode.TENSOR) / getattr(p, NUM_PARTITIONS))**(1 / norm_type)
reductor = (gpc.get_world_size(ParallelMode.TENSOR) / getattr(p, NUM_PARTITIONS)) ** (1 / norm_type)
tensor_parallel_grads.append(p.grad.data / reductor)
elif hasattr(p, 'colo_attr') and p.colo_attr.sharded_data_tensor.is_sharded:
elif hasattr(p, "colo_attr") and p.colo_attr.sharded_data_tensor.is_sharded:
zero_sharded_grads.append(p.grad.data)
else:
no_tensor_parallel_grads.append(p.grad.data)
if norm_type == 2.0 and enable_cuda_kernels:
tensor_parallel_norm = _calc_l2_norm(tensor_parallel_grads)**norm_type
no_tensor_parallel_norm = _calc_l2_norm(no_tensor_parallel_grads)**norm_type
zero_sharded_norm = _calc_l2_norm(zero_sharded_grads)**norm_type
tensor_parallel_norm = _calc_l2_norm(tensor_parallel_grads) ** norm_type
no_tensor_parallel_norm = _calc_l2_norm(no_tensor_parallel_grads) ** norm_type
zero_sharded_norm = _calc_l2_norm(zero_sharded_grads) ** norm_type
else:
tensor_parallel_norm = _calc_lp(tensor_parallel_grads, norm_type)
no_tensor_parallel_norm = _calc_lp(no_tensor_parallel_grads, norm_type)
@@ -358,7 +355,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
total_norm = tensor_parallel_norm + no_tensor_parallel_norm
if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.PIPELINE))
total_norm = total_norm**(1.0 / norm_type)
total_norm = total_norm ** (1.0 / norm_type)
if torch.is_tensor(total_norm):
total_norm = total_norm.item()
@@ -397,13 +394,14 @@ def count_zeros_fp32(parameters):
# Sum across all model-parallel GPUs.
ops = []
ops.append(
dist.all_reduce(total_num_zeros, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR), async_op=True))
dist.all_reduce(total_num_zeros, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR), async_op=True)
)
if gpc.is_initialized(ParallelMode.PIPELINE):
ops.append(
dist.all_reduce(total_num_zeros,
op=dist.ReduceOp.SUM,
group=gpc.get_group(ParallelMode.PIPELINE),
async_op=True))
dist.all_reduce(
total_num_zeros, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.PIPELINE), async_op=True
)
)
for req in ops:
req.wait()
@@ -420,8 +418,9 @@ def copy_tensor_parallel_attributes(src_tensor, dst_tensor):
def param_is_not_tensor_parallel_duplicate(param):
return (hasattr(param, IS_TENSOR_PARALLEL) and getattr(param, IS_TENSOR_PARALLEL)) or (gpc.get_local_rank(
ParallelMode.TENSOR) == 0)
return (hasattr(param, IS_TENSOR_PARALLEL) and getattr(param, IS_TENSOR_PARALLEL)) or (
gpc.get_local_rank(ParallelMode.TENSOR) == 0
)
@contextmanager
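
Editor's note — the norm handling reformatted above follows the usual pattern: accumulate per-group Lp norms raised to the p-th power, take the p-th root, then scale gradients by max_norm / total_norm. A single-device sketch of that arithmetic (the epsilon and the all-reduces across the TENSOR/PIPELINE groups are simplifications, not the module's exact code):

import torch

def clip_grad_norm_sketch(grads, max_norm, norm_type=2.0):
    # Sum of per-tensor Lp norms raised to p, then the p-th root -- the same
    # aggregation clip_grad_norm_fp32 performs, minus the parallel reductions.
    total = sum(g.norm(norm_type) ** norm_type for g in grads)
    total_norm = total ** (1.0 / norm_type)
    clip_coef = max_norm / (total_norm + 1e-6)  # small eps is an assumption here
    if clip_coef < 1:
        for g in grads:
            g.mul_(clip_coef)
    return float(total_norm)

grads = [torch.randn(3, 3) for _ in range(2)]
print(clip_grad_norm_sketch(grads, max_norm=1.0))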

View File

@@ -1,4 +1,4 @@
from .base_sampler import BaseSampler
from .data_parallel_sampler import DataParallelSampler, get_dataloader
__all__ = ['BaseSampler', 'DataParallelSampler', 'get_dataloader']
__all__ = ["BaseSampler", "DataParallelSampler", "get_dataloader"]

View File

@@ -5,7 +5,6 @@ from abc import ABC, abstractmethod
class BaseSampler(ABC):
def __init__(self, dataset, batch_size):
self.dataset = dataset
self.batch_size = batch_size

View File

@@ -13,7 +13,7 @@ from torch.utils.data import DataLoader, Dataset, Sampler
from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.legacy.core import global_context as gpc
T_co = TypeVar('T_co', covariant=True)
T_co = TypeVar("T_co", covariant=True)
class DataParallelSampler(Sampler):
@@ -44,11 +44,11 @@ class DataParallelSampler(Sampler):
self.num_samples = math.ceil(
# `type:ignore` is required because Dataset cannot provide a default __len__
# see NOTE in pytorch/torch/utils/data/sampler.py
(len(self.dataset) - self.num_replicas) / \
self.num_replicas # type: ignore[arg-type]
(len(self.dataset) - self.num_replicas)
/ self.num_replicas # type: ignore[arg-type]
)
else:
self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type]
self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type]
self.total_size = self.num_samples * self.num_replicas
self.shuffle = shuffle
self.seed = seed
@@ -65,7 +65,7 @@ class DataParallelSampler(Sampler):
# set_epoch manually
self.epoch += 1
else:
indices = list(range(len(self.dataset))) # type: ignore[arg-type]
indices = list(range(len(self.dataset))) # type: ignore[arg-type]
if not self.drop_last:
# add extra samples to make it evenly divisible
@@ -76,11 +76,11 @@ class DataParallelSampler(Sampler):
indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
else:
# remove tail of data to make it evenly divisible.
indices = indices[:self.total_size]
indices = indices[: self.total_size]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
indices = indices[self.rank : self.total_size : self.num_replicas]
assert len(indices) == self.num_samples
return iter(indices)
@@ -99,14 +99,9 @@ class DataParallelSampler(Sampler):
self.epoch = epoch
def get_dataloader(dataset,
shuffle=False,
seed=1024,
add_sampler=True,
drop_last=False,
pin_memory=False,
num_workers=0,
**kwargs):
def get_dataloader(
dataset, shuffle=False, seed=1024, add_sampler=True, drop_last=False, pin_memory=False, num_workers=0, **kwargs
):
r"""Set up a deterministic dataloader (also configure seed workers, samplers and whether shuffle or not)
Note:
@@ -144,18 +139,22 @@ def get_dataloader(dataset,
random.seed(worker_seed)
if sampler is None:
return DataLoader(dataset,
worker_init_fn=seed_worker,
shuffle=shuffle,
drop_last=drop_last,
pin_memory=pin_memory,
num_workers=num_workers,
**_kwargs)
return DataLoader(
dataset,
worker_init_fn=seed_worker,
shuffle=shuffle,
drop_last=drop_last,
pin_memory=pin_memory,
num_workers=num_workers,
**_kwargs,
)
else:
return DataLoader(dataset,
sampler=sampler,
worker_init_fn=seed_worker,
drop_last=drop_last,
pin_memory=pin_memory,
num_workers=num_workers,
**_kwargs)
return DataLoader(
dataset,
sampler=sampler,
worker_init_fn=seed_worker,
drop_last=drop_last,
pin_memory=pin_memory,
num_workers=num_workers,
**_kwargs,
)
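
Editor's note — DataParallelSampler above shards a dataset by taking every num_replicas-th index starting at the local rank, after padding the index list to a multiple of num_replicas when drop_last is off. The core rule in isolation; this sketch simplifies the padding to the common case where fewer than len(dataset) extra samples are needed:

import math

def shard_indices(num_samples, rank, num_replicas):
    # Pad to a multiple of num_replicas, then stride through the indices:
    # indices[rank : total_size : num_replicas], as in DataParallelSampler.__iter__.
    indices = list(range(num_samples))
    total_size = math.ceil(num_samples / num_replicas) * num_replicas
    indices += indices[: total_size - len(indices)]
    return indices[rank:total_size:num_replicas]

print(shard_indices(10, rank=0, num_replicas=4))  # [0, 4, 8]
print(shard_indices(10, rank=1, num_replicas=4))  # [1, 5, 9]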

View File

@@ -76,8 +76,10 @@ def report_memory_usage(message, logger=None, report_cpu=False):
gpu_cached = _bytes_to_MB(torch.cuda.memory_reserved())
gpu_max_cached = _bytes_to_MB(torch.cuda.max_memory_reserved())
full_log = f"{message}: GPU: allocated {gpu_allocated} MB, max allocated {gpu_max_allocated} MB, " \
full_log = (
f"{message}: GPU: allocated {gpu_allocated} MB, max allocated {gpu_max_allocated} MB, "
+ f"cached: {gpu_cached} MB, max cached: {gpu_max_cached} MB"
)
if report_cpu:
# python doesn't do real-time garbage collection so do it explicitly to get the correct RAM reports
@@ -91,7 +93,7 @@ def report_memory_usage(message, logger=None, report_cpu=False):
logger.info(full_log)
# get the peak memory to report correct data, so reset the counter for the next call
if hasattr(torch.cuda, "reset_peak_memory_stats"): # pytorch 1.4+
if hasattr(torch.cuda, "reset_peak_memory_stats"): # pytorch 1.4+
torch.cuda.reset_peak_memory_stats()
@@ -106,10 +108,10 @@ def colo_device_memory_capacity(device: torch.device) -> int:
int: size in byte
"""
assert isinstance(device, torch.device)
if device.type == 'cpu':
if device.type == "cpu":
# In the context of 1-CPU-N-GPU, the memory capacity of the current process is 1/N overall CPU memory.
return colo_get_cpu_memory_capacity() / gpc.num_processes_on_current_node
if device.type == 'cuda':
if device.type == "cuda":
return torch.cuda.get_device_properties(get_current_device()).total_memory * _GLOBAL_CUDA_MEM_FRACTION
@@ -123,16 +125,16 @@ def colo_device_memory_used(device: torch.device) -> int:
Returns:
int: memory size in bytes
"""
if device.type == 'cpu':
if device.type == "cpu":
mem_info = _get_cpu_memory_info()
# In the context of 1-CPU-N-GPU, the memory usage of the current process is 1/N CPU memory used.
# Each process consumes the same amount of memory.
ret = mem_info.used / gpc.num_processes_on_current_node
return ret
elif device.type == 'cuda':
elif device.type == "cuda":
ret: int = torch.cuda.memory_allocated(device)
# get the peak memory to report correct data, so reset the counter for the next call
if hasattr(torch.cuda, "reset_peak_memory_stats"): # pytorch 1.4+
if hasattr(torch.cuda, "reset_peak_memory_stats"): # pytorch 1.4+
torch.cuda.reset_peak_memory_stats(device)
return ret
@@ -145,9 +147,9 @@ def colo_set_process_memory_fraction(ratio: float) -> None:
Args:
ratio (float): a ratio between 0. ~ 1.
"""
if version.parse(torch.__version__) < version.parse('1.8'):
logger = get_dist_logger('colo_set_process_memory_fraction')
logger.warning('colo_set_process_memory_fraction failed because torch version is less than 1.8')
if version.parse(torch.__version__) < version.parse("1.8"):
logger = get_dist_logger("colo_set_process_memory_fraction")
logger.warning("colo_set_process_memory_fraction failed because torch version is less than 1.8")
return
global _GLOBAL_CUDA_MEM_FRACTION
_GLOBAL_CUDA_MEM_FRACTION = ratio
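
Editor's note — for the CUDA branch of colo_device_memory_capacity above, the usable capacity is simply the device's total memory scaled by the process-level fraction. A small stand-alone illustration (the 0.5 fraction is an arbitrary example value):

import torch

cuda_mem_fraction = 0.5  # example; set via colo_set_process_memory_fraction in the module above
if torch.cuda.is_available():
    total = torch.cuda.get_device_properties(torch.cuda.current_device()).total_memory
    print(f"usable by this process: {total * cuda_mem_fraction / 1024**3:.2f} GB")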

View File

@@ -2,7 +2,6 @@ from abc import ABC, abstractmethod
class ProfilerExtension(ABC):
@abstractmethod
def prepare_trace(self):
pass

View File

@@ -3,4 +3,4 @@ from .mem_profiler import MemProfiler
from .pcie_profiler import PcieProfiler
from .prof_utils import BaseProfiler, ProfilerContext
__all__ = ['BaseProfiler', 'CommProfiler', 'PcieProfiler', 'MemProfiler', 'ProfilerContext']
__all__ = ["BaseProfiler", "CommProfiler", "PcieProfiler", "MemProfiler", "ProfilerContext"]

View File

@@ -20,14 +20,14 @@ def _get_code_location(depth: int):
upper_frame = inspect.stack()[i]
function_name = inspect.stack()[i - 1].function
ret.append(upper_frame.filename)
ret.append('(')
ret.append("(")
ret.append(str(upper_frame.lineno))
ret.append('): ')
ret.append("): ")
ret.append(function_name)
if i != length - 1:
ret.append('\n')
ret.append("\n")
return ''.join(ret)
return "".join(ret)
torch_all_reduce = dist.all_reduce
@@ -42,7 +42,7 @@ class CommEvent(object):
volume recording.
"""
def __init__(self, count: int = 0, comm_vol: float = 0., cuda_time: int = 0):
def __init__(self, count: int = 0, comm_vol: float = 0.0, cuda_time: int = 0):
self.self_count = count
self.self_comm_vol = comm_vol
self.self_cuda_time = cuda_time
@@ -54,8 +54,7 @@ class CommEvent(object):
class CommProfiler(BaseProfiler):
"""Communication profiler. Records all communication events.
"""
"""Communication profiler. Records all communication events."""
def __init__(self, depth: int = 0, total_count: int = 0, total_comm_vol: float = 0, total_cuda_time: int = 0):
super().__init__(profiler_name="Collective_Communication", priority=0)
@@ -114,8 +113,10 @@ class CommProfiler(BaseProfiler):
res.append(sep)
if self.warn_flag:
append("Warning: there exists multiple communication operations in the same time. As a result, "
"the profiling result is not accurate.")
append(
"Warning: there exists multiple communication operations in the same time. As a result, "
"the profiling result is not accurate."
)
if self.total_cuda_time == 0:
return "No collective communication has been called yet!"
@@ -126,24 +127,29 @@ class CommProfiler(BaseProfiler):
append("total number of calls: {}".format(self.total_count))
append("All events:")
separation = '-' * 74
row_format = '{:^10}' + '{:^12}' * 2 + '{:^16}' + '{:^12}' * 2
separation = "-" * 74
row_format = "{:^10}" + "{:^12}" * 2 + "{:^16}" + "{:^12}" * 2
append(separation)
append(row_format.format('Location', 'GPU time', 'Percentage', 'Comm volume', 'Bandwidth', 'Num of calls'))
append(row_format.format("Location", "GPU time", "Percentage", "Comm volume", "Bandwidth", "Num of calls"))
append(separation)
show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].self_cuda_time)
for location, event in show_list:
append(location)
append(
row_format.format('', _format_time(event.self_cuda_time),
'{:.1f}%'.format(event.self_cuda_time / self.total_cuda_time * 100.0),
_format_memory(event.self_comm_vol),
_format_bandwidth(event.self_comm_vol, event.self_cuda_time), event.self_count))
row_format.format(
"",
_format_time(event.self_cuda_time),
"{:.1f}%".format(event.self_cuda_time / self.total_cuda_time * 100.0),
_format_memory(event.self_comm_vol),
_format_bandwidth(event.self_comm_vol, event.self_cuda_time),
event.self_count,
)
)
append()
return ''.join(res)
return "".join(res)
@property
def has_aync_op(self):
@@ -195,8 +201,7 @@ class CommProfiler(BaseProfiler):
class CommHandler(object):
"""Communication handler. A dummy handler to wait aync operations.
"""
"""Communication handler. A dummy handler to wait aync operations."""
def __init__(self, profiler: CommProfiler):
super().__init__()
@@ -212,11 +217,9 @@ def async_check(profiler: CommProfiler):
profiler.wait_async_op()
def all_reduce(tensor: torch.Tensor,
op: ReduceOp = ReduceOp.SUM,
group=None,
async_op: bool = False,
profiler: CommProfiler = None) -> Optional[CommHandler]:
def all_reduce(
tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, group=None, async_op: bool = False, profiler: CommProfiler = None
) -> Optional[CommHandler]:
async_check(profiler)
comm_size = dist.get_world_size(group)
@@ -231,12 +234,14 @@ def all_reduce(tensor: torch.Tensor,
profiler.close_profiler(group)
def reduce_scatter(output: torch.Tensor,
input_list: List[torch.Tensor],
op: ReduceOp = ReduceOp.SUM,
group=None,
async_op: bool = False,
profiler: CommProfiler = None) -> Optional[CommHandler]:
def reduce_scatter(
output: torch.Tensor,
input_list: List[torch.Tensor],
op: ReduceOp = ReduceOp.SUM,
group=None,
async_op: bool = False,
profiler: CommProfiler = None,
) -> Optional[CommHandler]:
async_check(profiler)
comm_size = dist.get_world_size(group)
@@ -254,11 +259,13 @@ def reduce_scatter(output: torch.Tensor,
profiler.close_profiler(group)
def all_gather(tensor_list: List[torch.Tensor],
tensor: torch.Tensor,
group=None,
async_op: bool = False,
profiler: CommProfiler = None) -> Optional[CommHandler]:
def all_gather(
tensor_list: List[torch.Tensor],
tensor: torch.Tensor,
group=None,
async_op: bool = False,
profiler: CommProfiler = None,
) -> Optional[CommHandler]:
async_check(profiler)
comm_size = dist.get_world_size(group)
@@ -276,11 +283,9 @@ def all_gather(tensor_list: List[torch.Tensor],
profiler.close_profiler(group)
def broadcast(tensor: torch.Tensor,
src: int,
group=None,
async_op: bool = False,
profiler: CommProfiler = None) -> Optional[CommHandler]:
def broadcast(
tensor: torch.Tensor, src: int, group=None, async_op: bool = False, profiler: CommProfiler = None
) -> Optional[CommHandler]:
async_check(profiler)
comm_vol = 1.0 * tensor.element_size() * tensor.numel()
@@ -293,12 +298,14 @@ def broadcast(tensor: torch.Tensor,
profiler.close_profiler(group)
def reduce(tensor: torch.Tensor,
dst: int,
op: ReduceOp = ReduceOp.SUM,
group=None,
async_op: bool = False,
profiler: CommProfiler = None) -> Optional[CommHandler]:
def reduce(
tensor: torch.Tensor,
dst: int,
op: ReduceOp = ReduceOp.SUM,
group=None,
async_op: bool = False,
profiler: CommProfiler = None,
) -> Optional[CommHandler]:
async_check(profiler)
comm_vol = 1.0 * tensor.element_size() * tensor.numel()
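
Editor's note — CommProfiler works by keeping a reference to each original collective (torch_all_reduce = dist.all_reduce above) and substituting a wrapper that records call count, payload volume and timing before delegating. The interception pattern alone, stripped of the bookkeeping (a sketch: the real wrapper also handles async ops and CUDA timing, and a process group must be initialized before any collective is actually called):

import torch.distributed as dist

comm_volume = []  # bytes handed to each recorded call
_original_all_reduce = dist.all_reduce

def profiled_all_reduce(tensor, *args, **kwargs):
    # Record the payload size, then delegate to the saved original collective.
    comm_volume.append(tensor.element_size() * tensor.numel())
    return _original_all_reduce(tensor, *args, **kwargs)

dist.all_reduce = profiled_all_reduce  # install; restore with dist.all_reduce = _original_all_reduce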

View File

@@ -18,6 +18,7 @@ def _get_size(dtype: str):
def _get_numel(my_list: List[int]) -> int:
from functools import reduce
from operator import mul
return reduce(mul, my_list)
@@ -27,12 +28,11 @@ def _reduce_location(locations: List[str]) -> str:
ret.append(lo)
ret.append("\n")
ret = ret[:-1]
return ''.join(ret)
return "".join(ret)
class PcieEvent(object):
"""Pcie Event.
"""
"""Pcie Event."""
def __init__(self, count: int = 0, pcie_vol: int = 0, cuda_time: int = 0):
self.count = count
@@ -73,12 +73,9 @@ class PcieProfiler(BaseProfiler):
self.profiler = None
def enable(self):
self.profiler = profile(enabled=True,
use_cuda=True,
use_cpu=True,
use_kineto=True,
record_shapes=True,
with_stack=True)
self.profiler = profile(
enabled=True, use_cuda=True, use_cpu=True, use_kineto=True, record_shapes=True, with_stack=True
)
self.profiler.__enter__()
def disable(self):
@@ -92,15 +89,15 @@ class PcieProfiler(BaseProfiler):
if len(t_shape) == 0 or event.cuda_time_total == 0 or len(event.stack) == 0:
continue
current_comm_event = PcieEvent(1, self.data_size * _get_numel(t_shape), event.cuda_time_total)
code_location = _reduce_location(event.stack[:self.depth])
code_location = _reduce_location(event.stack[: self.depth])
if code_location in self.ops_record:
self.ops_record[code_location].add(current_comm_event)
else:
self.ops_record[code_location] = current_comm_event
elif 'Memcpy HtoD' in event.name:
elif "Memcpy HtoD" in event.name:
self.h2d_count += 1
self.h2d_time += event.cuda_time_total
elif 'Memcpy DtoH' in event.name:
elif "Memcpy DtoH" in event.name:
self.d2h_count += 1
self.d2h_time += event.cuda_time_total
@@ -132,19 +129,25 @@ class PcieProfiler(BaseProfiler):
append("Possible data transmission events in PCIE:")
separation = '-' * 62
row_format = '{:^10}' + '{:^12}' + '{:^16}' + '{:^12}' * 2
separation = "-" * 62
row_format = "{:^10}" + "{:^12}" + "{:^16}" + "{:^12}" * 2
append(separation)
append(row_format.format('Location', 'GPU time', 'Trans volume', 'Bandwidth', 'Num of calls'))
append(row_format.format("Location", "GPU time", "Trans volume", "Bandwidth", "Num of calls"))
append(separation)
show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].cuda_time)
for location, event in show_list:
append(location)
append(
row_format.format('', _format_time(event.cuda_time), _format_memory(event.pcie_vol),
_format_bandwidth(event.pcie_vol, event.cuda_time), event.count))
row_format.format(
"",
_format_time(event.cuda_time),
_format_memory(event.pcie_vol),
_format_bandwidth(event.pcie_vol, event.cuda_time),
event.count,
)
)
append()
return ''.join(res)
return "".join(res)

View File

@@ -11,10 +11,10 @@ def _format_time(time_us):
US_IN_SECOND = 1000.0 * 1000.0
US_IN_MS = 1000.0
if time_us >= US_IN_SECOND:
return '{:.3f}s'.format(time_us / US_IN_SECOND)
return "{:.3f}s".format(time_us / US_IN_SECOND)
if time_us >= US_IN_MS:
return '{:.3f}ms'.format(time_us / US_IN_MS)
return '{:.3f}us'.format(time_us)
return "{:.3f}ms".format(time_us / US_IN_MS)
return "{:.3f}us".format(time_us)
# copied from high version pytorch to support low version
@@ -23,28 +23,27 @@ def _format_memory(nbytes):
KB = 1024
MB = 1024 * KB
GB = 1024 * MB
if (abs(nbytes) >= GB):
return '{:.2f} GB'.format(nbytes * 1.0 / GB)
elif (abs(nbytes) >= MB):
return '{:.2f} MB'.format(nbytes * 1.0 / MB)
elif (abs(nbytes) >= KB):
return '{:.2f} KB'.format(nbytes * 1.0 / KB)
if abs(nbytes) >= GB:
return "{:.2f} GB".format(nbytes * 1.0 / GB)
elif abs(nbytes) >= MB:
return "{:.2f} MB".format(nbytes * 1.0 / MB)
elif abs(nbytes) >= KB:
return "{:.2f} KB".format(nbytes * 1.0 / KB)
else:
return str(nbytes) + ' B'
return str(nbytes) + " B"
def _format_bandwidth(volume: float or int, time_us: int):
sec_div_mb = (1000.0 / 1024.0)**2
sec_div_mb = (1000.0 / 1024.0) ** 2
mb_per_sec = volume / time_us * sec_div_mb
if mb_per_sec >= 1024.0:
return '{:.3f} GB/s'.format(mb_per_sec / 1024.0)
return "{:.3f} GB/s".format(mb_per_sec / 1024.0)
else:
return '{:.3f} MB/s'.format(mb_per_sec)
return "{:.3f} MB/s".format(mb_per_sec)
class BaseProfiler(ABC):
def __init__(self, profiler_name: str, priority: int):
self.name = profiler_name
self.priority = priority
@@ -111,8 +110,9 @@ class ProfilerContext(object):
def to_tensorboard(self, writer):
from torch.utils.tensorboard import SummaryWriter
assert isinstance(writer, SummaryWriter), \
f'torch.utils.tensorboard.SummaryWriter is required, but found {type(writer)}.'
assert isinstance(
writer, SummaryWriter
), f"torch.utils.tensorboard.SummaryWriter is required, but found {type(writer)}."
for prof in self.profilers:
prof.to_tensorboard(writer)
@@ -124,7 +124,7 @@ class ProfilerContext(object):
if not log_dir.exists():
log_dir.mkdir(parents=True, exist_ok=True)
for prof in self.profilers:
log_file = log_dir.joinpath(f'{prof.name}_rank_{gpc.get_global_rank()}.log')
log_file = log_dir.joinpath(f"{prof.name}_rank_{gpc.get_global_rank()}.log")
prof.to_file(log_file)
def show(self):

View File

@@ -120,26 +120,30 @@ class profile(torch_profile):
p.step()
"""
def __init__(self,
*,
activities: Optional[Iterable[ProfilerActivity]] = None,
schedule: Optional[Callable[[int], ProfilerAction]] = None,
on_trace_ready: Optional[Callable[..., Any]] = None,
engine: Optional[Engine] = None,
record_shapes: bool = False,
profile_memory: bool = False,
with_stack: bool = False,
with_flops: bool = False,
with_modules: bool = False,
profile_stateful_tensor_memory: bool = False) -> None:
super().__init__(activities=activities,
schedule=schedule,
on_trace_ready=on_trace_ready,
record_shapes=record_shapes,
profile_memory=profile_memory,
with_stack=with_stack,
with_flops=with_flops,
with_modules=with_modules)
def __init__(
self,
*,
activities: Optional[Iterable[ProfilerActivity]] = None,
schedule: Optional[Callable[[int], ProfilerAction]] = None,
on_trace_ready: Optional[Callable[..., Any]] = None,
engine: Optional[Engine] = None,
record_shapes: bool = False,
profile_memory: bool = False,
with_stack: bool = False,
with_flops: bool = False,
with_modules: bool = False,
profile_stateful_tensor_memory: bool = False,
) -> None:
super().__init__(
activities=activities,
schedule=schedule,
on_trace_ready=on_trace_ready,
record_shapes=record_shapes,
profile_memory=profile_memory,
with_stack=with_stack,
with_flops=with_flops,
with_modules=with_modules,
)
self._logger = get_dist_logger()
self.extentions: List[ProfilerExtension] = []
if profile_stateful_tensor_memory:
@@ -149,9 +153,9 @@ class profile(torch_profile):
self.extentions.append(StatefulTensorMemoryProfilerExtention(engine))
def prepare_trace(self) -> None:
if hasattr(super(), 'prepare_trace'):
if hasattr(super(), "prepare_trace"):
super().prepare_trace()
elif hasattr(super(), '_start_warmup'):
elif hasattr(super(), "_start_warmup"):
super()._start_warmup()
for ext in self.extentions:
ext.prepare_trace()
@@ -160,9 +164,9 @@ class profile(torch_profile):
self.prepare_trace()
def start_trace(self):
if hasattr(super(), '_start_trace'):
if hasattr(super(), "_start_trace"):
super()._start_trace()
elif hasattr(super(), 'start_trace'):
elif hasattr(super(), "start_trace"):
super().start_trace()
for ext in self.extentions:
ext.start_trace()
@@ -171,9 +175,9 @@ class profile(torch_profile):
self.start_trace()
def stop_trace(self):
if hasattr(super(), '_stop_trace'):
if hasattr(super(), "_stop_trace"):
super()._stop_trace()
elif hasattr(super(), 'stop_trace'):
elif hasattr(super(), "stop_trace"):
super().stop_trace()
for ext in self.extentions:
ext.stop_trace()
@@ -186,15 +190,15 @@ class profile(torch_profile):
Exports the collected trace in Chrome JSON format.
"""
assert self.profiler
fp = tempfile.NamedTemporaryFile('w+t', suffix='.json', delete=False)
fp = tempfile.NamedTemporaryFile("w+t", suffix=".json", delete=False)
fp.close()
retvalue = self.profiler.export_chrome_trace(fp.name)
with open(fp.name) as fin:
trace = json.load(fin)
for ext in self.extentions:
trace = ext.extend_chrome_trace(trace)
open_func = gzip.open if path.endswith('.gz') else open
with open_func(path, 'wt') as fout:
open_func = gzip.open if path.endswith(".gz") else open
with open_func(path, "wt") as fout:
json.dump(trace, fout)
os.remove(fp.name)

View File

@@ -22,11 +22,11 @@ def get_timestamp_us():
def generic_instant_event(name, pid, tid, timestamp, args):
return {'ph': 'i', 's': 't', 'name': name, 'pid': pid, 'tid': tid, 'ts': timestamp, 'args': args}
return {"ph": "i", "s": "t", "name": name, "pid": pid, "tid": tid, "ts": timestamp, "args": args}
class StatefulTensorMemoryEvent:
EVENT_NAME = '[statefulTensorMemory]'
EVENT_NAME = "[statefulTensorMemory]"
def __init__(self, timestamp: int, device_type: DeviceType, bytes_: int) -> None:
self.pid = os.getpid()
@@ -37,22 +37,23 @@ class StatefulTensorMemoryEvent:
self.bytes = bytes_
def state_dict(self):
return generic_instant_event(StatefulTensorMemoryEvent.EVENT_NAME, self.pid, self.tid, self.timestamp, {
'Device Type': self.device_type.value,
'Device Id': self.device_id,
'Bytes': self.bytes
})
return generic_instant_event(
StatefulTensorMemoryEvent.EVENT_NAME,
self.pid,
self.tid,
self.timestamp,
{"Device Type": self.device_type.value, "Device Id": self.device_id, "Bytes": self.bytes},
)
class StatefulTensorMemoryTracer:
def __init__(self) -> None:
self.events: List[StatefulTensorMemoryEvent] = []
self._tracing = False
def sample(self):
cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda']
cpu_mem = StatefulTensor.GST_MGR.total_mem['cpu']
cuda_mem = StatefulTensor.GST_MGR.total_mem["cuda"]
cpu_mem = StatefulTensor.GST_MGR.total_mem["cpu"]
timestamp = get_timestamp_us()
if self._tracing:
self.events.append(StatefulTensorMemoryEvent(timestamp, DeviceType.CUDA, cuda_mem))
@@ -70,7 +71,6 @@ class StatefulTensorMemoryTracer:
class StatefulTensorMemoryTracerHook(BaseOpHook):
def __init__(self, tracer: StatefulTensorMemoryTracer):
super().__init__()
self.tracer = tracer
@@ -104,7 +104,6 @@ class StatefulTensorMemoryTracerHook(BaseOpHook):
class StatefulTensorMemoryProfilerExtention(ProfilerExtension):
def __init__(self, engine: Engine) -> None:
self.engine = engine
self.tracer = StatefulTensorMemoryTracer()
@@ -131,5 +130,5 @@ class StatefulTensorMemoryProfilerExtention(ProfilerExtension):
# self.hook_registered = False
def extend_chrome_trace(self, trace: dict) -> dict:
trace['traceEvents'].extend(self.tracer.state_dict())
trace["traceEvents"].extend(self.tracer.state_dict())
return trace
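
Editor's note — extend_chrome_trace above appends instant events (built by generic_instant_event) to the trace's traceEvents list. The event shape it emits, reproduced stand-alone with placeholder values for the device fields:

import json
import os
import time

def instant_event(name, pid, tid, timestamp_us, args):
    # Chrome-trace "instant" event, same dict shape as generic_instant_event above.
    return {"ph": "i", "s": "t", "name": name, "pid": pid, "tid": tid, "ts": timestamp_us, "args": args}

trace = {"traceEvents": []}
trace["traceEvents"].append(
    instant_event("[statefulTensorMemory]", os.getpid(), 0, int(time.time() * 1e6),
                  {"Device Type": 1, "Device Id": 0, "Bytes": 1024})  # values are placeholders
)
print(json.dumps(trace, indent=2))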