[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
This commit is contained in:
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

View File

@@ -1,5 +1,5 @@
from .data_parallel import ColoDDP
__all__ = [
'ColoDDP',
"ColoDDP",
]

View File

@@ -49,11 +49,13 @@ class ColoDDP(torch.nn.Module):
If it's None, the default data parallel group will be used. Defaults to None.
"""
def __init__(self,
module: torch.nn.Module,
process_group: ColoProcessGroup,
bucket_cap_mb: int = 25,
rebuild_bucket: bool = True) -> None:
def __init__(
self,
module: torch.nn.Module,
process_group: ColoProcessGroup,
bucket_cap_mb: int = 25,
rebuild_bucket: bool = True,
) -> None:
assert not isinstance(module, ColoDDP)
super().__init__()
self.module = module
@@ -74,19 +76,18 @@ class ColoDDP(torch.nn.Module):
def parameters(self, recurse: bool = True):
return self.module.parameters(recurse)
def named_parameters(self, prefix: str = '', recurse: bool = True):
def named_parameters(self, prefix: str = "", recurse: bool = True):
return self.module.named_parameters(prefix, recurse)
def named_buffers(self, prefix: str = '', recurse: bool = True):
def named_buffers(self, prefix: str = "", recurse: bool = True):
return self.module.named_buffers(prefix, recurse)
def named_children(self):
return self.module.named_children()
def named_modules(self,
memo: Optional[Set[torch.nn.Module]] = None,
prefix: str = '',
remove_duplicate: bool = True):
def named_modules(
self, memo: Optional[Set[torch.nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True
):
return self.module.named_modules(memo, prefix, remove_duplicate)
def forward(self, *args, **kwargs):
@@ -114,9 +115,9 @@ class ColoDDP(torch.nn.Module):
grad = grad / self.dp_world_size
self.comm_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self.comm_stream):
self.reducer.all_reduce_async(grad,
group=self.process_group.dp_process_group(),
callback_fn=partial(self._save_grad, p))
self.reducer.all_reduce_async(
grad, group=self.process_group.dp_process_group(), callback_fn=partial(self._save_grad, p)
)
grad.record_stream(self.comm_stream)
else:
ColoDDP._save_grad(p, grad)
@@ -130,7 +131,7 @@ class ColoDDP(torch.nn.Module):
@staticmethod
def _save_grad(p, grad):
if hasattr(p, '_saved_grad'):
if hasattr(p, "_saved_grad"):
p._saved_grad.add_(grad)
else:
p._saved_grad = grad
@@ -138,7 +139,7 @@ class ColoDDP(torch.nn.Module):
def zero_grad(self, set_to_none: bool = False) -> None:
self.module.zero_grad(set_to_none=True)
for p in self.module.parameters():
if getattr(p, '_saved_grad', None) is not None:
if getattr(p, "_saved_grad", None) is not None:
if set_to_none:
p._saved_grad = None
else:
@@ -167,8 +168,8 @@ class ColoDDP(torch.nn.Module):
for p in params_to_ignore:
p._ddp_to_ignore = True
def state_dict(self, destination=None, prefix='', keep_vars=False):
def state_dict(self, destination=None, prefix="", keep_vars=False):
return self.module.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True):
def load_state_dict(self, state_dict: "OrderedDict[str, torch.Tensor]", strict: bool = True):
return self.module.load_state_dict(state_dict, strict)

View File

@@ -14,8 +14,20 @@ from .linear import ColoLinear
from .module_utils import check_colo_module, get_colo_module, init_colo_module, is_colo_module, register_colo_module
__all__ = [
'ColoModule', 'register_colo_module', 'is_colo_module', 'get_colo_module', 'init_colo_module', 'check_colo_module',
'ColoLinear', 'ColoEmbedding', 'CachedEmbeddingBag', 'ParallelCachedEmbeddingBag', 'CachedParamMgr',
'LimitBuffIndexCopyer', 'EvictionStrategy', 'ParallelCachedEmbeddingBagTablewise', 'TablewiseEmbeddingBagConfig',
'ParallelCachedEmbeddingBagTablewiseSpiltCache'
"ColoModule",
"register_colo_module",
"is_colo_module",
"get_colo_module",
"init_colo_module",
"check_colo_module",
"ColoLinear",
"ColoEmbedding",
"CachedEmbeddingBag",
"ParallelCachedEmbeddingBag",
"CachedParamMgr",
"LimitBuffIndexCopyer",
"EvictionStrategy",
"ParallelCachedEmbeddingBagTablewise",
"TablewiseEmbeddingBagConfig",
"ParallelCachedEmbeddingBagTablewiseSpiltCache",
]

View File

@@ -7,7 +7,12 @@ from .parallel_cached_embedding_tablewise import ParallelCachedEmbeddingBagTable
from .parallel_cached_embedding_tablewise_split_cache import ParallelCachedEmbeddingBagTablewiseSpiltCache
__all__ = [
'CachedParamMgr', 'LimitBuffIndexCopyer', 'CachedEmbeddingBag', 'ParallelCachedEmbeddingBag', 'EvictionStrategy',
'ParallelCachedEmbeddingBagTablewise', 'TablewiseEmbeddingBagConfig',
'ParallelCachedEmbeddingBagTablewiseSpiltCache'
"CachedParamMgr",
"LimitBuffIndexCopyer",
"CachedEmbeddingBag",
"ParallelCachedEmbeddingBag",
"EvictionStrategy",
"ParallelCachedEmbeddingBagTablewise",
"TablewiseEmbeddingBagConfig",
"ParallelCachedEmbeddingBagTablewiseSpiltCache",
]

View File

@@ -4,17 +4,16 @@ import torch.nn as nn
class BaseEmbeddingBag(abc.ABC, nn.Module):
def __init__(
self,
num_embeddings,
embedding_dim,
padding_idx=None,
max_norm=None,
norm_type=2.,
norm_type=2.0,
scale_grad_by_freq=False,
sparse=False,
mode='mean',
mode="mean",
include_last_offset=False,
):
super(BaseEmbeddingBag, self).__init__()
@@ -22,9 +21,9 @@ class BaseEmbeddingBag(abc.ABC, nn.Module):
self.embedding_dim = embedding_dim
if padding_idx is not None:
if padding_idx > 0:
assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings'
assert padding_idx < self.num_embeddings, "Padding_idx must be within num_embeddings"
elif padding_idx < 0:
assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings'
assert padding_idx >= -self.num_embeddings, "Padding_idx must be within num_embeddings"
padding_idx = self.num_embeddings + padding_idx
self.padding_idx = padding_idx
self.max_norm = max_norm

View File

@@ -83,15 +83,16 @@ class CachedParamMgr(torch.nn.Module):
if self._async_copy:
self._memcpy_stream = torch.cuda.Stream()
print('use async copy')
print("use async copy")
if self._evict_strategy == EvictionStrategy.LFU:
# cache_row_idx -> frequency, freq of the cache rows.
# classic lfu cache. evict the minimal freq value row in cuda cache.
self.register_buffer("freq_cnter",
torch.empty(self.cuda_row_num, device=torch.cuda.current_device(),
dtype=torch.long).fill_(sys.maxsize),
persistent=False)
self.register_buffer(
"freq_cnter",
torch.empty(self.cuda_row_num, device=torch.cuda.current_device(), dtype=torch.long).fill_(sys.maxsize),
persistent=False,
)
self._elapsed_dict = {}
self._show_cache_miss = True
self._reset_comm_stats()
@@ -142,10 +143,10 @@ class CachedParamMgr(torch.nn.Module):
if self.cuda_row_num > 0:
# Enable cache with introducing auxiliary data structures
self.cuda_cached_weight = torch.nn.Parameter(
torch.zeros(self.cuda_row_num,
self.embedding_dim,
device=torch.cuda.current_device(),
dtype=weight.dtype))
torch.zeros(
self.cuda_row_num, self.embedding_dim, device=torch.cuda.current_device(), dtype=weight.dtype
)
)
# pin memory cpu for higher CPU-GPU copy bandwidth
self.weight = weight.pin_memory() if self.pin_weight else weight
@@ -158,17 +159,19 @@ class CachedParamMgr(torch.nn.Module):
)
# cached_idx_map: gpu_row_idx -> cpu_row_idx
self.register_buffer("cached_idx_map",
torch.empty(self.cuda_row_num, device=torch.cuda.current_device(),
dtype=torch.long).fill_(-1),
persistent=False)
self.register_buffer(
"cached_idx_map",
torch.empty(self.cuda_row_num, device=torch.cuda.current_device(), dtype=torch.long).fill_(-1),
persistent=False,
)
# cpu_row_id -> gpu_row_idx.
# gpu_row_idx as -1 means cpu_row_id not in CUDA.
self.register_buffer("inverted_cached_idx",
torch.zeros(self.num_embeddings, device=torch.cuda.current_device(),
dtype=torch.long).fill_(-1),
persistent=False)
self.register_buffer(
"inverted_cached_idx",
torch.zeros(self.num_embeddings, device=torch.cuda.current_device(), dtype=torch.long).fill_(-1),
persistent=False,
)
self.evict_backlist = torch.tensor([], device=torch.cuda.current_device())
@@ -191,9 +194,11 @@ class CachedParamMgr(torch.nn.Module):
torch.Tensor: a piece of memory in CPU weight corresponding to row id's payload. The tensor is 1-D.
"""
return self.weight.data.view(-1).narrow(0,
int(row_idx) * self.embedding_dim,
self.embedding_dim).view(1, self.embedding_dim)
return (
self.weight.data.view(-1)
.narrow(0, int(row_idx) * self.embedding_dim, self.embedding_dim)
.view(1, self.embedding_dim)
)
@property
def cuda_available_row_num(self):
@@ -238,15 +243,18 @@ class CachedParamMgr(torch.nn.Module):
preload_cpu_ids = torch.arange(preload_row_num)
preload_cuda_row_idxs = preload_cpu_ids.cuda()
if self.buffer_size > 0:
self.limit_buff_index_copyer.index_copy(0,
src_index=preload_cpu_ids,
tgt_index=preload_cuda_row_idxs,
src=self.weight.view(self.num_embeddings, -1),
tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1))
self.limit_buff_index_copyer.index_copy(
0,
src_index=preload_cpu_ids,
tgt_index=preload_cuda_row_idxs,
src=self.weight.view(self.num_embeddings, -1),
tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1),
)
else:
preload_rows = self.weight.view(self.num_embeddings, -1).index_select(0, preload_cpu_ids).cuda()
self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, preload_cuda_row_idxs,
preload_rows)
self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(
0, preload_cuda_row_idxs, preload_rows
)
# update auxiliary info
self.cached_idx_map[preload_cuda_row_idxs] = preload_cpu_ids.cuda()
@@ -260,7 +268,7 @@ class CachedParamMgr(torch.nn.Module):
else:
self.freq_cnter[preload_cuda_row_idxs] = freq_value.cuda()
print(f'Cache warmup finished cost {timer.elapsed} sec.')
print(f"Cache warmup finished cost {timer.elapsed} sec.")
def flush(self):
"""flush all CUDA rows to CPU.
@@ -290,18 +298,18 @@ class CachedParamMgr(torch.nn.Module):
print(
f"CUDA->CPU BWD {self._cuda_to_cpu_numel * self.elem_size_in_byte / 1e6 / elapsed} MB/s {self._cuda_to_cpu_numel / 1e6} M elem"
)
print(f'cuda_to_cpu_elapse {elapsed} sec')
print(f"cuda_to_cpu_elapse {elapsed} sec")
if self._cpu_to_cuda_numel > 0 and "5_evict_in" in self._elapsed_dict:
elapsed = self._elapsed_dict["5_evict_in"]
print(
f"CPU->CUDA BWD {self._cpu_to_cuda_numel * self.elem_size_in_byte / 1e6 / elapsed} MB/s {self._cpu_to_cuda_numel / 1e6} M elem"
)
print(f'cpu_to_cuda_elapse {elapsed} sec')
print(f"cpu_to_cuda_elapse {elapsed} sec")
for k, v in self._elapsed_dict.items():
print(f'{k}: {v}')
print(f"{k}: {v}")
print(f'cache miss ratio {self._cache_miss / self._total_cache}')
print(f"cache miss ratio {self._cache_miss / self._total_cache}")
@torch.no_grad()
def _id_to_cached_cuda_id(self, ids: torch.Tensor) -> torch.Tensor:
@@ -336,10 +344,11 @@ class CachedParamMgr(torch.nn.Module):
else:
cpu_row_idxs, repeat_times = torch.unique(self.idx_map.index_select(0, ids), return_counts=True)
assert len(cpu_row_idxs) <= self.cuda_row_num, \
f"You move {len(cpu_row_idxs)} embedding rows from CPU to CUDA. " \
f"It is larger than the capacity of the cache, which at most contains {self.cuda_row_num} rows, " \
assert len(cpu_row_idxs) <= self.cuda_row_num, (
f"You move {len(cpu_row_idxs)} embedding rows from CPU to CUDA. "
f"It is larger than the capacity of the cache, which at most contains {self.cuda_row_num} rows, "
f"Please increase cuda_row_num or decrease the training batch size."
)
self.evict_backlist = cpu_row_idxs
tmp = torch.isin(cpu_row_idxs, self.cached_idx_map, invert=True)
comm_cpu_row_idxs = cpu_row_idxs[tmp]
@@ -386,8 +395,9 @@ class CachedParamMgr(torch.nn.Module):
# move evict in rows to gpu
if self._async_copy:
if self.buffer_size == 0:
evict_in_rows_gpu = self.weight.view(self.num_embeddings,
-1).index_select(0, cpu_row_idxs_copy).pin_memory()
evict_in_rows_gpu = (
self.weight.view(self.num_embeddings, -1).index_select(0, cpu_row_idxs_copy).pin_memory()
)
with torch.cuda.stream(self._memcpy_stream):
evict_in_rows_gpu = evict_in_rows_gpu.to(torch.cuda.current_device(), non_blocking=True)
else:
@@ -409,9 +419,10 @@ class CachedParamMgr(torch.nn.Module):
# move evict out rows to cpu
if self._async_copy:
evict_out_rows_gpu = self.cuda_cached_weight.view(self.cuda_row_num,
-1).index_select(0, evict_gpu_row_idxs)
evict_out_rows_cpu = torch.empty_like(evict_out_rows_gpu, device='cpu', pin_memory=True)
evict_out_rows_gpu = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select(
0, evict_gpu_row_idxs
)
evict_out_rows_cpu = torch.empty_like(evict_out_rows_gpu, device="cpu", pin_memory=True)
with torch.cuda.stream(None):
evict_out_rows_cpu.copy_(evict_out_rows_gpu, non_blocking=True)
self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)
@@ -425,9 +436,10 @@ class CachedParamMgr(torch.nn.Module):
evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
if self._async_copy:
evict_out_rows_gpu = self.cuda_cached_weight.view(self.cuda_row_num,
-1).index_select(0, evict_gpu_row_idxs)
evict_out_rows_cpu = torch.empty_like(evict_out_rows_gpu, device='cpu', pin_memory=True)
evict_out_rows_gpu = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select(
0, evict_gpu_row_idxs
)
evict_out_rows_cpu = torch.empty_like(evict_out_rows_gpu, device="cpu", pin_memory=True)
with torch.cuda.stream(None):
evict_out_rows_cpu.copy_(evict_out_rows_gpu, non_blocking=True)
@@ -438,11 +450,13 @@ class CachedParamMgr(torch.nn.Module):
with self.timer("3_evict_out") as timer:
if self.buffer_size > 0:
self.limit_buff_index_copyer.index_copy(0,
src_index=evict_gpu_row_idxs,
tgt_index=evict_info.cpu(),
src=self.cuda_cached_weight.view(self.cuda_row_num, -1),
tgt=self.weight.view(self.num_embeddings, -1))
self.limit_buff_index_copyer.index_copy(
0,
src_index=evict_gpu_row_idxs,
tgt_index=evict_info.cpu(),
src=self.cuda_cached_weight.view(self.cuda_row_num, -1),
tgt=self.weight.view(self.num_embeddings, -1),
)
else:
# allocate tmp memory on CPU and copy rows on CUDA to CPU.
# TODO async gpu -> cpu
@@ -450,8 +464,9 @@ class CachedParamMgr(torch.nn.Module):
_wait_for_data(evict_out_rows_cpu, None)
else:
with self.timer("3_1_evict_out_index_select") as timer:
evict_out_rows_cpu = self.cuda_cached_weight.view(self.cuda_row_num,
-1).index_select(0, evict_gpu_row_idxs)
evict_out_rows_cpu = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select(
0, evict_gpu_row_idxs
)
with self.timer("3_2_evict_out_gpu_to_cpu_copy") as timer:
evict_out_rows_cpu = evict_out_rows_cpu.cpu()
@@ -469,17 +484,19 @@ class CachedParamMgr(torch.nn.Module):
# slots of cuda weight to evict in
with self.timer("4_identify_cuda_slot") as timer:
slots = torch.nonzero(self.cached_idx_map == -1).squeeze(1)[:cpu_row_idxs.numel()]
slots = torch.nonzero(self.cached_idx_map == -1).squeeze(1)[: cpu_row_idxs.numel()]
# TODO wait for optimize
with self.timer("5_evict_in") as timer:
# Here also allocate extra memory on CUDA. #cpu_row_idxs
if self.buffer_size > 0:
self.limit_buff_index_copyer.index_copy(0,
src_index=cpu_row_idxs_copy,
tgt_index=slots,
src=self.weight.view(self.num_embeddings, -1),
tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1))
self.limit_buff_index_copyer.index_copy(
0,
src_index=cpu_row_idxs_copy,
tgt_index=slots,
src=self.weight.view(self.num_embeddings, -1),
tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1),
)
else:
if self._async_copy:
_wait_for_data(evict_in_rows_gpu, self._memcpy_stream)
@@ -488,8 +505,9 @@ class CachedParamMgr(torch.nn.Module):
# narrow index select to a subset of self.weight
# tmp = torch.narrow(self.weight.view(self.num_embeddings, -1), 0, min(cpu_row_idxs).cpu(), max(cpu_row_idxs) - min(cpu_row_idxs) + 1)
# evict_in_rows_gpu = tmp.index_select(0, cpu_row_idxs_copy - min(cpu_row_idxs).cpu())
evict_in_rows_gpu = self.weight.view(self.num_embeddings,
-1).index_select(0, cpu_row_idxs_copy).pin_memory()
evict_in_rows_gpu = (
self.weight.view(self.num_embeddings, -1).index_select(0, cpu_row_idxs_copy).pin_memory()
)
with self.timer("5_2_evict_in_gpu_to_cpu_copy") as timer:
evict_in_rows_gpu = evict_in_rows_gpu.cuda()
@@ -537,8 +555,9 @@ class CachedParamMgr(torch.nn.Module):
self.cached_idx_map.index_copy_(0, idx, buf)
with Timer() as timer:
cuda_tensor = torch.narrow(self.cuda_cached_weight.view(-1), 0, max_offset * self.embedding_dim,
self.embedding_dim).view(1, self.embedding_dim)
cuda_tensor = torch.narrow(
self.cuda_cached_weight.view(-1), 0, max_offset * self.embedding_dim, self.embedding_dim
).view(1, self.embedding_dim)
self.cpu_weight_data(max_gpu_row_idx).data.copy_(cuda_tensor)
# update inverted_cached_idx, min_slot_id is evicted from cuda
@@ -570,8 +589,9 @@ class CachedParamMgr(torch.nn.Module):
slot_offset = slot_id
# copy payload from cpu to cuda
with Timer() as timer:
cuda_tensor = torch.narrow(self.cuda_cached_weight.view(-1), 0, slot_offset * self.embedding_dim,
self.embedding_dim).view(1, self.embedding_dim)
cuda_tensor = torch.narrow(
self.cuda_cached_weight.view(-1), 0, slot_offset * self.embedding_dim, self.embedding_dim
).view(1, self.embedding_dim)
cuda_tensor.data.copy_(self.cpu_weight_data(row_id))
# update the inverted_cached_idx

View File

@@ -36,27 +36,38 @@ class CachedEmbeddingBag(BaseEmbeddingBag):
evict_strategy (EvictionStrategy, optional): evict strategy of the software cache. Defaults to EvictionStrategy.DATASET.
"""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int = None,
max_norm: float = None,
norm_type: float = 2.,
scale_grad_by_freq: bool = False,
sparse: bool = False,
_weight: Optional[torch.Tensor] = None,
mode: str = 'mean',
include_last_offset: bool = False,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
cache_ratio: float = 0.01,
ids_freq_mapping: Optional[Union[List, torch.Tensor]] = None,
warmup_ratio: float = 0.7,
buffer_size: int = 0,
pin_weight: bool = False,
evict_strategy: EvictionStrategy = EvictionStrategy.LFU):
super(CachedEmbeddingBag, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type,
scale_grad_by_freq, sparse, mode, include_last_offset)
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int = None,
max_norm: float = None,
norm_type: float = 2.0,
scale_grad_by_freq: bool = False,
sparse: bool = False,
_weight: Optional[torch.Tensor] = None,
mode: str = "mean",
include_last_offset: bool = False,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
cache_ratio: float = 0.01,
ids_freq_mapping: Optional[Union[List, torch.Tensor]] = None,
warmup_ratio: float = 0.7,
buffer_size: int = 0,
pin_weight: bool = False,
evict_strategy: EvictionStrategy = EvictionStrategy.LFU,
):
super(CachedEmbeddingBag, self).__init__(
num_embeddings,
embedding_dim,
padding_idx,
max_norm,
norm_type,
scale_grad_by_freq,
sparse,
mode,
include_last_offset,
)
assert cache_ratio <= 1.0, f"cache ratio {cache_ratio} must less than 1.0"
self.evict_strategy = evict_strategy
@@ -78,13 +89,15 @@ class CachedEmbeddingBag(BaseEmbeddingBag):
weight[self.padding_idx].fill_(0)
return weight
def _preprocess(self,
weight,
cuda_row_num: int,
ids_freq_mapping: Optional[List[int]] = None,
warmup_ratio=0.7,
buffer_size=50_000,
pin_weight=False):
def _preprocess(
self,
weight,
cuda_row_num: int,
ids_freq_mapping: Optional[List[int]] = None,
warmup_ratio=0.7,
buffer_size=50_000,
pin_weight=False,
):
"""
Called after initialized.
Reorder the weight rows according to the ids_freq_mapping.
@@ -95,11 +108,9 @@ class CachedEmbeddingBag(BaseEmbeddingBag):
ids_freq_mapping (List[int]): a list, idx is id number, value is freq
warmup_ratio (float): the amount of rows preloaded in cuda cache
"""
self.cache_weight_mgr = CachedParamMgr(weight,
cuda_row_num,
buffer_size,
pin_weight,
evict_strategy=self.evict_strategy)
self.cache_weight_mgr = CachedParamMgr(
weight, cuda_row_num, buffer_size, pin_weight, evict_strategy=self.evict_strategy
)
self.cache_weight_mgr.reorder(ids_freq_mapping, warmup_ratio)
def forward(self, input, offsets=None, per_sample_weights=None, shape_hook=None):
@@ -107,9 +118,19 @@ class CachedEmbeddingBag(BaseEmbeddingBag):
with torch.no_grad():
input = self.cache_weight_mgr.prepare_ids(input)
embeddings = F.embedding_bag(input.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets, self.max_norm,
self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
per_sample_weights, self.include_last_offset, self.padding_idx)
embeddings = F.embedding_bag(
input.cuda(),
self.cache_weight_mgr.cuda_cached_weight,
offsets,
self.max_norm,
self.norm_type,
self.scale_grad_by_freq,
self.mode,
self.sparse,
per_sample_weights,
self.include_last_offset,
self.padding_idx,
)
if shape_hook is not None:
embeddings = shape_hook(embeddings)
return embeddings
@@ -118,8 +139,8 @@ class CachedEmbeddingBag(BaseEmbeddingBag):
def weight(self):
return self.cache_weight_mgr.weight
def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]:
yield 'weight', self.cache_weight_mgr.cuda_cached_weight
def named_parameters(self, prefix: str = "", recurse: bool = True) -> Iterator[Tuple[str, Parameter]]:
yield "weight", self.cache_weight_mgr.cuda_cached_weight
def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
yield self.cache_weight_mgr.cuda_cached_weight
@@ -127,8 +148,7 @@ class CachedEmbeddingBag(BaseEmbeddingBag):
def set_cache_op(self, cache_op: bool = True):
self.cache_op = cache_op
############################# Perf Log ###################################
############################# Perf Log ###################################
@property
def num_hits_history(self):
@@ -145,14 +165,22 @@ class CachedEmbeddingBag(BaseEmbeddingBag):
@property
def swap_in_bandwidth(self):
if self.cache_weight_mgr._cpu_to_cuda_numel > 0:
return self.cache_weight_mgr._cpu_to_cuda_numel * self.cache_weight_mgr.elem_size_in_byte / 1e6 / \
self.cache_weight_mgr._cpu_to_cuda_elapse
return (
self.cache_weight_mgr._cpu_to_cuda_numel
* self.cache_weight_mgr.elem_size_in_byte
/ 1e6
/ self.cache_weight_mgr._cpu_to_cuda_elapse
)
else:
return 0
@property
def swap_out_bandwidth(self):
if self.cache_weight_mgr._cuda_to_cpu_numel > 0:
return self.cache_weight_mgr._cuda_to_cpu_numel * self.cache_weight_mgr.elem_size_in_byte / 1e6 / \
self.cache_weight_mgr._cuda_to_cpu_elapse
return (
self.cache_weight_mgr._cuda_to_cpu_numel
* self.cache_weight_mgr.elem_size_in_byte
/ 1e6
/ self.cache_weight_mgr._cuda_to_cpu_elapse
)
return 0

View File

@@ -39,7 +39,7 @@ class LimitBuffIndexCopyer(object):
for begin_pos in range(0, dim_size, self._buff_size):
cur_len = min(self._buff_size, dim_size - begin_pos)
src_idx_piece = src_index.narrow(0, begin_pos, cur_len)
if src_device.type == 'cpu' and tgt_device.type == 'cuda':
if src_device.type == "cpu" and tgt_device.type == "cuda":
cpu_tmp_buffer = src.index_select(dim, src_idx_piece).pin_memory()
tmp_buffer = torch.empty_like(cpu_tmp_buffer, device=tgt_device)
tmp_buffer.copy_(cpu_tmp_buffer)

View File

@@ -2,22 +2,24 @@ import torch
class TablewiseEmbeddingBagConfig:
'''
"""
example:
def prepare_tablewise_config(args, cache_ratio, ...):
embedding_bag_config_list: List[TablewiseEmbeddingBagConfig] = []
...
return embedding_bag_config_list
'''
"""
def __init__(self,
num_embeddings: int,
cuda_row_num: int,
assigned_rank: int = 0,
buffer_size=50_000,
ids_freq_mapping=None,
initial_weight: torch.tensor = None,
name: str = ""):
def __init__(
self,
num_embeddings: int,
cuda_row_num: int,
assigned_rank: int = 0,
buffer_size=50_000,
ids_freq_mapping=None,
initial_weight: torch.tensor = None,
name: str = "",
):
self.num_embeddings = num_embeddings
self.cuda_row_num = cuda_row_num
self.assigned_rank = assigned_rank

View File

@@ -1,13 +1,13 @@
from typing import Iterator, List, Optional, Tuple
from typing import List, Optional, Tuple
import torch
import torch.nn.functional as F
from colossalai.legacy.nn._ops._utils import dual_all_to_all
from colossalai.legacy.tensor import ColoTensorSpec, ComputePattern, ProcessGroup, ShardSpec
from colossalai.tensor import ColoParameter, ColoTensor
from colossalai.tensor import ColoTensor
from .cache_mgr import CachedParamMgr, EvictionStrategy
from .cache_mgr import EvictionStrategy
from .cached_embedding import CachedEmbeddingBag
@@ -15,9 +15,9 @@ def get_partition(embedding_dim, rank, world_size) -> Tuple[int, int, bool]:
if world_size == 1:
return 0, embedding_dim, True
assert embedding_dim >= world_size, \
f"Embedding dimension {embedding_dim} must be larger than the world size " \
f"{world_size} of the process group"
assert embedding_dim >= world_size, (
f"Embedding dimension {embedding_dim} must be larger than the world size " f"{world_size} of the process group"
)
chunk_size = embedding_dim // world_size
threshold = embedding_dim % world_size
# if embedding dim is divisible by world size
@@ -31,37 +31,55 @@ def get_partition(embedding_dim, rank, world_size) -> Tuple[int, int, bool]:
class ParallelCachedEmbeddingBag(CachedEmbeddingBag):
def __init__(self,
num_embeddings,
embedding_dim,
padding_idx=None,
max_norm=None,
norm_type=2.,
scale_grad_by_freq=False,
sparse=False,
_weight=None,
mode='mean',
include_last_offset=False,
dtype=None,
device=None,
cache_ratio=0.01,
ids_freq_mapping=None,
warmup_ratio=0.7,
buffer_size=50_000,
pin_weight=False,
evict_strategy: EvictionStrategy = EvictionStrategy.DATASET):
def __init__(
self,
num_embeddings,
embedding_dim,
padding_idx=None,
max_norm=None,
norm_type=2.0,
scale_grad_by_freq=False,
sparse=False,
_weight=None,
mode="mean",
include_last_offset=False,
dtype=None,
device=None,
cache_ratio=0.01,
ids_freq_mapping=None,
warmup_ratio=0.7,
buffer_size=50_000,
pin_weight=False,
evict_strategy: EvictionStrategy = EvictionStrategy.DATASET,
):
self.rank = torch.distributed.get_rank()
self.world_size = torch.distributed.get_world_size()
self.partition_start_index, self.partition_end_index, divisible = get_partition(
embedding_dim, self.rank, self.world_size)
embedding_dim, self.rank, self.world_size
)
self.embedding_dim_per_partition = self.partition_end_index - self.partition_start_index
super(ParallelCachedEmbeddingBag,
self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq,
sparse, _weight, mode, include_last_offset, dtype, device, cache_ratio, ids_freq_mapping,
warmup_ratio, buffer_size, pin_weight, evict_strategy)
super(ParallelCachedEmbeddingBag, self).__init__(
num_embeddings,
embedding_dim,
padding_idx,
max_norm,
norm_type,
scale_grad_by_freq,
sparse,
_weight,
mode,
include_last_offset,
dtype,
device,
cache_ratio,
ids_freq_mapping,
warmup_ratio,
buffer_size,
pin_weight,
evict_strategy,
)
self.cache_op = True
def _weight_alloc(self, dtype, device):
@@ -70,9 +88,11 @@ class ParallelCachedEmbeddingBag(CachedEmbeddingBag):
weight.data.uniform_(-1 / self.num_embeddings, 1 / self.num_embeddings)
if self.padding_idx is not None:
weight[self.padding_idx].fill_(0)
colo_tensor_spec = ColoTensorSpec(pg=ProcessGroup(tp_degree=self.world_size),
dist_attr=ShardSpec(dims=[-1], num_partitions=[self.world_size]),
compute_attr=ComputePattern.TP1D)
colo_tensor_spec = ColoTensorSpec(
pg=ProcessGroup(tp_degree=self.world_size),
dist_attr=ShardSpec(dims=[-1], num_partitions=[self.world_size]),
compute_attr=ComputePattern.TP1D,
)
return ColoTensor.from_torch_tensor(weight, spec=colo_tensor_spec)
def forward(
@@ -87,15 +107,24 @@ class ParallelCachedEmbeddingBag(CachedEmbeddingBag):
if self.cache_op:
with torch.no_grad():
indices = self.cache_weight_mgr.prepare_ids(indices)
output_shard = F.embedding_bag(indices.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets, self.max_norm,
self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
per_sample_weights, self.include_last_offset, self.padding_idx)
output_shard = F.embedding_bag(
indices.cuda(),
self.cache_weight_mgr.cuda_cached_weight,
offsets,
self.max_norm,
self.norm_type,
self.scale_grad_by_freq,
self.mode,
self.sparse,
per_sample_weights,
self.include_last_offset,
self.padding_idx,
)
if shape_hook is not None:
output_shard = shape_hook(output_shard)
output_full = dual_all_to_all(output_shard,
self.weight.get_process_group(),
scatter_dim=scatter_dim,
gather_dim=gather_dim)
output_full = dual_all_to_all(
output_shard, self.weight.get_process_group(), scatter_dim=scatter_dim, gather_dim=gather_dim
)
return output_full
def set_cache_op(self, cache_op: bool = True):
@@ -108,31 +137,33 @@ class ParallelCachedEmbeddingBag(CachedEmbeddingBag):
freeze: bool = True,
padding_idx: Optional[int] = None,
max_norm: Optional[float] = None,
norm_type: float = 2.,
norm_type: float = 2.0,
scale_grad_by_freq: bool = False,
sparse: bool = False,
mode: str = 'mean',
mode: str = "mean",
include_last_offset: bool = False,
cuda_row_num: int = 100_000,
ids_freq_mapping: Optional[List[int]] = None,
warmup_ratio: float = 0.7,
buffer_size: int = 0,
) -> 'ParallelCachedEmbeddingBag':
) -> "ParallelCachedEmbeddingBag":
rows, cols = embedding.shape
embedding_bag = cls(rows,
cols,
padding_idx,
max_norm,
norm_type,
scale_grad_by_freq,
sparse,
embedding,
mode,
include_last_offset,
cuda_row_num=cuda_row_num,
ids_freq_mapping=ids_freq_mapping,
warmup_ratio=warmup_ratio,
buffer_size=buffer_size)
embedding_bag = cls(
rows,
cols,
padding_idx,
max_norm,
norm_type,
scale_grad_by_freq,
sparse,
embedding,
mode,
include_last_offset,
cuda_row_num=cuda_row_num,
ids_freq_mapping=ids_freq_mapping,
warmup_ratio=warmup_ratio,
buffer_size=buffer_size,
)
embedding_bag.cache_weight_mgr.cuda_cached_weight.requires_grad_ = not freeze
return embedding_bag

View File

@@ -1,4 +1,3 @@
import time
from typing import List
import torch
@@ -19,24 +18,26 @@ class ParallelCachedEmbeddingBagTablewise(CachedEmbeddingBag):
Those parameters in TablewiseEmbeddingBagConfig are ignored: cuda_row_num, buffer_size, initial_weight.
"""
def __init__(self,
embedding_bag_config_list: List[TablewiseEmbeddingBagConfig],
embedding_dim: int,
padding_idx=None,
max_norm=None,
norm_type=2.,
scale_grad_by_freq=False,
sparse=False,
_weight=None,
mode='mean',
include_last_offset=False,
dtype=None,
device=None,
cache_ratio=0.01,
warmup_ratio=0.7,
buffer_size=50_000,
pin_weight=False,
evict_strategy: EvictionStrategy = EvictionStrategy.LFU):
def __init__(
self,
embedding_bag_config_list: List[TablewiseEmbeddingBagConfig],
embedding_dim: int,
padding_idx=None,
max_norm=None,
norm_type=2.0,
scale_grad_by_freq=False,
sparse=False,
_weight=None,
mode="mean",
include_last_offset=False,
dtype=None,
device=None,
cache_ratio=0.01,
warmup_ratio=0.7,
buffer_size=50_000,
pin_weight=False,
evict_strategy: EvictionStrategy = EvictionStrategy.LFU,
):
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
self.rank_of_tables = [config.assigned_rank for config in embedding_bag_config_list]
@@ -62,11 +63,27 @@ class ParallelCachedEmbeddingBagTablewise(CachedEmbeddingBag):
break
self.cache_ratio = cache_ratio
# table-associate cache
cuda_row_num = int(cache_ratio * self.num_embeddings)
super(ParallelCachedEmbeddingBagTablewise,
self).__init__(self.num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq,
sparse, _weight, mode, include_last_offset, dtype, device, cache_ratio, ids_freq_mapping,
warmup_ratio, buffer_size, pin_weight, evict_strategy)
int(cache_ratio * self.num_embeddings)
super(ParallelCachedEmbeddingBagTablewise, self).__init__(
self.num_embeddings,
embedding_dim,
padding_idx,
max_norm,
norm_type,
scale_grad_by_freq,
sparse,
_weight,
mode,
include_last_offset,
dtype,
device,
cache_ratio,
ids_freq_mapping,
warmup_ratio,
buffer_size,
pin_weight,
evict_strategy,
)
# for assigned tables reconnection:
self.idx_offset_list = []
@@ -96,7 +113,8 @@ class ParallelCachedEmbeddingBagTablewise(CachedEmbeddingBag):
# not recommanded. it takes time.
batch_size = (offsets.shape[0]) // self.global_tables_num
local_indices, local_offsets, local_per_sample_weights = self.split_along_rank(
batch_size, indices, offsets, per_sample_weights)
batch_size, indices, offsets, per_sample_weights
)
else:
# recommanded.
batch_size = (offsets.shape[0]) // len(self.assigned_table_list)
@@ -104,9 +122,19 @@ class ParallelCachedEmbeddingBagTablewise(CachedEmbeddingBag):
if self.cache_op:
with torch.no_grad():
indices = self.cache_weight_mgr.prepare_ids(local_indices)
local_output = F.embedding_bag(indices.cuda(), self.cache_weight_mgr.cuda_cached_weight, local_offsets,
self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
local_per_sample_weights, self.include_last_offset, self.padding_idx)
local_output = F.embedding_bag(
indices.cuda(),
self.cache_weight_mgr.cuda_cached_weight,
local_offsets,
self.max_norm,
self.norm_type,
self.scale_grad_by_freq,
self.mode,
self.sparse,
local_per_sample_weights,
self.include_last_offset,
self.padding_idx,
)
local_output = torch.cat(local_output.split(batch_size), 1)
remains = batch_size % self.world_size
scatter_strides = [batch_size // self.world_size + int(i < remains) for i in range(self.world_size)]
@@ -115,21 +143,19 @@ class ParallelCachedEmbeddingBagTablewise(CachedEmbeddingBag):
output_full = shape_hook(output_full)
return output_full
def split_along_rank(self,
batch_size,
indices: torch.Tensor,
offsets: torch.Tensor = None,
per_sample_weights=None):
'''
def split_along_rank(
self, batch_size, indices: torch.Tensor, offsets: torch.Tensor = None, per_sample_weights=None
):
"""
if input indices and offsets haven't been splitted along assigned rank, this function will do it.
it takes time. please consider splitting data during batch loading.
'''
"""
local_indices_list: List(torch.Tensor) = []
local_offsets_list: List(torch.Tensor) = []
if per_sample_weights != None:
local_per_sample_weights_list: List(torch.Tensor) = []
offset_pre_end = 0 # local_offsets trick
offset_pre_end = 0 # local_offsets trick
for i, handle_table in enumerate(self.assigned_table_list):
indices_start_position = offsets[batch_size * handle_table]
if (not self.include_last_offset) and (batch_size * (handle_table + 1) >= indices.shape[0]):
@@ -138,7 +164,7 @@ class ParallelCachedEmbeddingBagTablewise(CachedEmbeddingBag):
else:
indices_end_position = offsets[batch_size * (handle_table + 1)]
# alternative approach: reduce malloc
'''
"""
# 1. local_indices_list:
local_indices = indices.narrow(0, indices_start_position, indices_end_position - indices_start_position)
torch.sub(local_indices, self.idx_offset_list[i], out=local_indices)
@@ -158,25 +184,29 @@ class ParallelCachedEmbeddingBagTablewise(CachedEmbeddingBag):
torch.add(local_offsets, offset_pre_end - offsets[batch_size * handle_table], out=local_offsets)
offset_pre_end = offsets[batch_size * (handle_table + 1)] + offset_pre_end - temp_holder
local_offsets_list.append(local_offsets)
'''
"""
# 1. local_indices_list:
local_indices_list.append(
indices.narrow(0, indices_start_position,
indices_end_position - indices_start_position).sub(self.idx_offset_list[i]))
indices.narrow(0, indices_start_position, indices_end_position - indices_start_position).sub(
self.idx_offset_list[i]
)
)
# 2. local_offsets_list:
if i + 1 == len(self.assigned_table_list):
# till-the-end special case
if not self.include_last_offset:
local_offsets = offsets.narrow(0, batch_size * handle_table,
batch_size).add(offset_pre_end - offsets[batch_size *
(handle_table)])
local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size).add(
offset_pre_end - offsets[batch_size * (handle_table)]
)
else:
local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size +
1).add(offset_pre_end - offsets[batch_size * (handle_table)])
local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size + 1).add(
offset_pre_end - offsets[batch_size * (handle_table)]
)
local_offsets_list.append(local_offsets)
else:
local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size +
1).add(offset_pre_end - offsets[batch_size * (handle_table)])
local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size + 1).add(
offset_pre_end - offsets[batch_size * (handle_table)]
)
offset_pre_end = local_offsets[-1]
local_offsets_list.append(local_offsets[:-1])
# 3. local_per_sample_weights_list:

View File

@@ -19,21 +19,23 @@ class ParallelCachedEmbeddingBagTablewiseSpiltCache(abc.ABC, nn.Module):
every table assigned to this class instance is managed by a CachedEmbeddingBag.
"""
def __init__(self,
embedding_bag_config_list: List[TablewiseEmbeddingBagConfig],
embedding_dim: int,
padding_idx=None,
max_norm=None,
norm_type=2.,
scale_grad_by_freq=False,
sparse=False,
mode='mean',
include_last_offset=False,
dtype=None,
device=None,
warmup_ratio=0.7,
pin_weight=False,
evict_strategy: EvictionStrategy = EvictionStrategy.LFU):
def __init__(
self,
embedding_bag_config_list: List[TablewiseEmbeddingBagConfig],
embedding_dim: int,
padding_idx=None,
max_norm=None,
norm_type=2.0,
scale_grad_by_freq=False,
sparse=False,
mode="mean",
include_last_offset=False,
dtype=None,
device=None,
warmup_ratio=0.7,
pin_weight=False,
evict_strategy: EvictionStrategy = EvictionStrategy.LFU,
):
super(ParallelCachedEmbeddingBagTablewiseSpiltCache, self).__init__()
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
@@ -56,24 +58,27 @@ class ParallelCachedEmbeddingBagTablewiseSpiltCache(abc.ABC, nn.Module):
if config.assigned_rank != self.rank:
continue
self.cached_embedding_bag_list.append(
CachedEmbeddingBag(num_embeddings=config.num_embeddings,
embedding_dim=embedding_dim,
padding_idx=padding_idx,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse,
_weight=config.initial_weight,
mode=mode,
include_last_offset=include_last_offset,
dtype=dtype,
device=device,
cuda_row_num=config.cuda_row_num,
ids_freq_mapping=config.ids_freq_mapping,
warmup_ratio=warmup_ratio,
buffer_size=config.buffer_size,
pin_weight=pin_weight,
evict_strategy=evict_strategy))
CachedEmbeddingBag(
num_embeddings=config.num_embeddings,
embedding_dim=embedding_dim,
padding_idx=padding_idx,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse,
_weight=config.initial_weight,
mode=mode,
include_last_offset=include_last_offset,
dtype=dtype,
device=device,
cuda_row_num=config.cuda_row_num,
ids_freq_mapping=config.ids_freq_mapping,
warmup_ratio=warmup_ratio,
buffer_size=config.buffer_size,
pin_weight=pin_weight,
evict_strategy=evict_strategy,
)
)
# prepare list shape for all_to_all output
self.embedding_dim_per_rank = [0 for i in range(self.world_size)]
@@ -95,22 +100,26 @@ class ParallelCachedEmbeddingBagTablewiseSpiltCache(abc.ABC, nn.Module):
indices_end_position = offsets[batch_size * (handle_table + 1)]
with record_function("part 2"):
# local_indices = indices[indices_start_position:indices_end_position] - self.global_tables_offsets[handle_table]
local_indices = indices.narrow(0, indices_start_position, indices_end_position -
indices_start_position).sub(self.global_tables_offsets[handle_table])
local_indices = indices.narrow(
0, indices_start_position, indices_end_position - indices_start_position
).sub(self.global_tables_offsets[handle_table])
if self.include_last_offset:
# local_offsets = offsets[batch_size * handle_table:batch_size * (handle_table + 1) + 1] - offsets[batch_size * (handle_table)]
local_offsets = offsets.narrow(0, batch_size * handle_table,
batch_size + 1).sub(offsets[batch_size * (handle_table)])
local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size + 1).sub(
offsets[batch_size * (handle_table)]
)
else:
# local_offsets = offsets[batch_size * handle_table:batch_size * (handle_table + 1)] - offsets[batch_size * (handle_table)]
local_offsets = offsets.narrow(0, batch_size * handle_table,
batch_size).sub(offsets[batch_size * (handle_table)])
local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size).sub(
offsets[batch_size * (handle_table)]
)
local_per_sample_weights = None
if per_sample_weights != None:
local_per_sample_weights = per_sample_weights[indices_start_position:indices_end_position]
with record_function("(tablewise) tablewise forward"):
local_output_list.append(self.cached_embedding_bag_list[i](local_indices, local_offsets,
local_per_sample_weights))
local_output_list.append(
self.cached_embedding_bag_list[i](local_indices, local_offsets, local_per_sample_weights)
)
# get result of shape = (batch_size, (len(assigned_table_list)*embedding_dim))
local_output = torch.cat(local_output_list, 1)

View File

@@ -5,7 +5,6 @@ from colossalai.legacy.tensor.distspec import _DistSpec
class ColoModule(object):
def __init__(self):
self._shard_params: List[str] = []
self._allowed_patterns: Dict[ComputePattern, Dict[str, Dict[str, _DistSpec]]] = {}
@@ -13,18 +12,18 @@ class ColoModule(object):
def _register_shard_params(self, params: List[str]):
self._shard_params = params
def _register_allowed_patterns(self,
compute_pattern: ComputePattern,
dist_specs: Dict[str, _DistSpec],
mode='default'):
assert list(
dist_specs.keys()).sort() == self._shard_params.sort(), 'Every registered param should have dist_spec.'
def _register_allowed_patterns(
self, compute_pattern: ComputePattern, dist_specs: Dict[str, _DistSpec], mode="default"
):
assert (
list(dist_specs.keys()).sort() == self._shard_params.sort()
), "Every registered param should have dist_spec."
if not compute_pattern in self._allowed_patterns:
self._allowed_patterns[compute_pattern] = {}
self._allowed_patterns[compute_pattern][mode] = dist_specs
def _set_default(self, compute_pattern: ComputePattern, target_mode):
self._allowed_patterns[compute_pattern]['default'] = self._allowed_patterns[compute_pattern][target_mode]
self._allowed_patterns[compute_pattern]["default"] = self._allowed_patterns[compute_pattern][target_mode]
def has_compute_pattern(self, compute_pattern: ComputePattern):
return compute_pattern in self._allowed_patterns
@@ -33,10 +32,10 @@ class ColoModule(object):
assert self.has_compute_pattern(compute_pattern)
return self._allowed_patterns[compute_pattern]
def has_compute_pattern_with_mode(self, compute_pattern: ComputePattern, mode='default'):
def has_compute_pattern_with_mode(self, compute_pattern: ComputePattern, mode="default"):
return compute_pattern in self._allowed_patterns and mode in self._allowed_patterns[compute_pattern]
def get_dist_specs_with_mode(self, compute_pattern: ComputePattern, mode='default'):
def get_dist_specs_with_mode(self, compute_pattern: ComputePattern, mode="default"):
assert self.has_compute_pattern_with_mode(compute_pattern, mode)
return self._allowed_patterns[compute_pattern][mode]

View File

@@ -1,13 +1,12 @@
from colossalai.legacy.tensor import ComputePattern, ProcessGroup, ShardSpec, distspec
from colossalai.legacy.tensor import ComputePattern, ProcessGroup, ShardSpec
from .colo_module import ColoModule
class ColoEmbedding(ColoModule):
def __init__(self):
super(ColoEmbedding, self).__init__()
self._register_shard_params(['weight'])
self._register_shard_params(["weight"])
def register(self, compute_pattern, pg: ProcessGroup):
if not compute_pattern in self._allowed_patterns:
@@ -20,18 +19,18 @@ class ColoEmbedding(ColoModule):
self._register_allowed_patterns(
compute_pattern=_compute_pattern,
dist_specs={
'weight': ShardSpec([0], [pg.tp_world_size()]),
"weight": ShardSpec([0], [pg.tp_world_size()]),
},
mode='row',
mode="row",
)
# TP1D Col Linear
self._register_allowed_patterns(
compute_pattern=_compute_pattern,
dist_specs={
'weight': ShardSpec([-1], [pg.tp_world_size()]),
"weight": ShardSpec([-1], [pg.tp_world_size()]),
},
mode='col',
mode="col",
)
self._set_default(compute_pattern=_compute_pattern, target_mode='row')
self._set_default(compute_pattern=_compute_pattern, target_mode="row")

View File

@@ -1,13 +1,12 @@
from colossalai.legacy.tensor import ComputePattern, ProcessGroup, ShardSpec, distspec
from colossalai.legacy.tensor import ComputePattern, ProcessGroup, ShardSpec
from .colo_module import ColoModule
class ColoLinear(ColoModule):
def __init__(self):
super(ColoLinear, self).__init__()
self._register_shard_params(['weight', 'bias'])
self._register_shard_params(["weight", "bias"])
def register(self, compute_pattern, pg: ProcessGroup):
if not compute_pattern in self._allowed_patterns:
@@ -19,21 +18,15 @@ class ColoLinear(ColoModule):
_compute_pattern = ComputePattern.TP1D
self._register_allowed_patterns(
compute_pattern=_compute_pattern,
dist_specs={
'weight': ShardSpec([-1], [pg.tp_world_size()]),
'bias': None
},
mode='row',
dist_specs={"weight": ShardSpec([-1], [pg.tp_world_size()]), "bias": None},
mode="row",
)
# TP1D Col Linear
self._register_allowed_patterns(
compute_pattern=_compute_pattern,
dist_specs={
'weight': ShardSpec([0], [pg.tp_world_size()]),
'bias': ShardSpec([0], [pg.tp_world_size()])
},
mode='col',
dist_specs={"weight": ShardSpec([0], [pg.tp_world_size()]), "bias": ShardSpec([0], [pg.tp_world_size()])},
mode="col",
)
self._set_default(compute_pattern=_compute_pattern, target_mode='row')
self._set_default(compute_pattern=_compute_pattern, target_mode="row")

View File

@@ -2,7 +2,7 @@ from typing import Dict
import torch
from colossalai.legacy.tensor import ComputeSpec, ProcessGroup, distspec
from colossalai.legacy.tensor import ComputeSpec, ProcessGroup
from colossalai.tensor import ColoParameter
from . import ColoModule
@@ -41,7 +41,7 @@ def check_colo_module(module: torch.nn.Module, pg: ProcessGroup, recursive=True)
for param_name in param_names:
param = module.get_parameter(param_name)
if not isinstance(param, ColoParameter):
raise Exception(f'Invalid ColoParameter spec: {param} in {module} is not a ColoParameter.')
raise Exception(f"Invalid ColoParameter spec: {param} in {module} is not a ColoParameter.")
if param.has_compute_spec():
cur_compute_pattern = param.compute_spec.compute_pattern
if compute_pattern is None:
@@ -49,7 +49,8 @@ def check_colo_module(module: torch.nn.Module, pg: ProcessGroup, recursive=True)
else:
if cur_compute_pattern != compute_pattern:
raise Exception(
f'Invalid ColoParameter spec: Params in {module} have different compute_pattern.')
f"Invalid ColoParameter spec: Params in {module} have different compute_pattern."
)
else:
continue
@@ -57,7 +58,8 @@ def check_colo_module(module: torch.nn.Module, pg: ProcessGroup, recursive=True)
colo_module.register(compute_pattern, pg)
if not colo_module.has_compute_pattern(compute_pattern):
raise Exception(
f'Invalid ColoParameter spec: ComputePattern {compute_pattern} in {module} is not allowed.')
f"Invalid ColoParameter spec: ComputePattern {compute_pattern} in {module} is not allowed."
)
match_specs = False
allowed_specs = colo_module.get_dist_specs(compute_pattern)
@@ -77,17 +79,15 @@ def check_colo_module(module: torch.nn.Module, pg: ProcessGroup, recursive=True)
match_specs = True
break
if match_specs == False:
raise Exception(f'Invalid ColoParameter spec: Params in {module} are incorrectly sharded.')
raise Exception(f"Invalid ColoParameter spec: Params in {module} are incorrectly sharded.")
if recursive == True:
for submodule in module.children():
check_colo_module(submodule, pg=pg, recursive=True)
def init_colo_module(module: torch.nn.Module,
compute_spec: ComputeSpec,
pg: ProcessGroup,
recursive=True,
mode='default'):
def init_colo_module(
module: torch.nn.Module, compute_spec: ComputeSpec, pg: ProcessGroup, recursive=True, mode="default"
):
compute_pattern = compute_spec.compute_pattern
if is_colo_module(module):
# for each param

View File

@@ -13,7 +13,6 @@ from torch.distributed import ProcessGroup
class Bucket:
def __init__(self, size: int, dtype: torch.dtype, device: torch.device, group: ProcessGroup):
self.buffer = torch.zeros(size, dtype=dtype, device=device)
self.group = group
@@ -26,7 +25,7 @@ class Bucket:
assert len(self.callbacks) == 0
return
# reduce-scatter bucket
dist.all_reduce(self.buffer[:self.offset], group=self.group)
dist.all_reduce(self.buffer[: self.offset], group=self.group)
# execute post-reduction callbacks
for callback_fn in self.callbacks:
@@ -37,24 +36,22 @@ class Bucket:
self.buffer = torch.zeros_like(self.buffer)
def alloc(self) -> None:
if self.buffer.storage().size() == 0:
self.buffer.storage().resize_(self.buffer.numel())
def free(self) -> None:
assert self.offset == 0 and self.callbacks == [], "Incorrect call of teardown"
self.buffer.storage().resize_(0)
def append(self, tensor: Tensor, callback_fn: Callable):
tensor_size = tensor.numel()
offset = self.offset
self.buffer[offset:offset + tensor_size].copy_(tensor.flatten())
self.buffer[offset : offset + tensor_size].copy_(tensor.flatten())
self.offset += tensor_size
# callback will be given the reduced result
if callback_fn is not None:
result_view = self.buffer[offset:offset + tensor_size].view(tensor.shape)
result_view = self.buffer[offset : offset + tensor_size].view(tensor.shape)
self.callbacks.append(functools.partial(callback_fn, result_view))
@property
@@ -63,7 +60,6 @@ class Bucket:
class Reducer:
def __init__(self, bucket_size_mb: int = 25):
self.bucket_size_mb = bucket_size_mb
self.buckets: Dict[Tuple[torch.dtype, torch.device, ProcessGroup], Bucket] = {}
@@ -101,7 +97,7 @@ class Reducer:
@functools.lru_cache()
def _get_bucket_size(self, element_size: int) -> int:
if self.bucket_size_mb <= 0: # Values <= 0 disable bucketing.
if self.bucket_size_mb <= 0: # Values <= 0 disable bucketing.
return 0
MB = 1024 * 1024
bucket_size = self.bucket_size_mb * MB / element_size