mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-24 11:08:50 +00:00
[embedding] polish async copy (#1657)
This commit is contained in:
@@ -15,6 +15,23 @@ class EvictionStrategy(Enum):
|
||||
DATASET = 2
|
||||
|
||||
|
||||
def _wait_for_data(t, stream: Optional[torch.cuda.streams.Stream]) -> None:
|
||||
if stream is None:
|
||||
return
|
||||
torch.cuda.current_stream().wait_stream(stream)
|
||||
# As mentioned in https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html,
|
||||
# PyTorch uses the "caching allocator" for memroy allocation for tensors. When a tensor is
|
||||
# freed, its memory is likely to be reused by newly constructed tenosrs. By default,
|
||||
# this allocator traces whether a tensor is still in use by only the CUDA stream where it
|
||||
# was created. When a tensor is used by additional CUDA streams, we need to call record_stream
|
||||
# to tell the allocator about all these streams. Otherwise, the allocator might free the
|
||||
# underlying memory of the tensor once it is no longer used by the creator stream. This is
|
||||
# a notable programming trick when we write programs using multi CUDA streams.
|
||||
cur_stream = torch.cuda.current_stream()
|
||||
assert isinstance(t, torch.Tensor)
|
||||
t.record_stream(cur_stream)
|
||||
|
||||
|
||||
class CachedParamMgr(torch.nn.Module):
|
||||
"""
|
||||
Manage Embedding Weights on CPU and CUDA memory uses a software cache.
|
||||
@@ -37,7 +54,7 @@ class CachedParamMgr(torch.nn.Module):
|
||||
weight: torch.Tensor,
|
||||
cuda_row_num: int = 0,
|
||||
buffer_size: int = 0,
|
||||
pin_weight: bool = False,
|
||||
pin_weight: bool = True,
|
||||
evict_strategy: EvictionStrategy = EvictionStrategy.DATASET,
|
||||
async_copy: bool = False,
|
||||
) -> None:
|
||||
@@ -62,6 +79,8 @@ class CachedParamMgr(torch.nn.Module):
|
||||
self._async_copy = async_copy
|
||||
|
||||
if self._async_copy:
|
||||
self._memcpy_stream = torch.cuda.Stream()
|
||||
|
||||
print('use async copy')
|
||||
|
||||
if self._evict_strategy == EvictionStrategy.LFU:
|
||||
@@ -350,11 +369,10 @@ class CachedParamMgr(torch.nn.Module):
|
||||
# move evict in rows to gpu
|
||||
if self._async_copy:
|
||||
if self.buffer_size == 0:
|
||||
idxslt_stream = torch.cuda.Stream()
|
||||
with torch.cuda.stream(idxslt_stream):
|
||||
rows_cpu = self.weight.view(self.num_embeddings, -1).index_select(0, cpu_row_idxs_copy).pin_memory()
|
||||
# evict_in_rows_gpu = torch.empty_like(rows_cpu, device=torch.cuda.current_device())
|
||||
# evict_in_rows_gpu.copy_(rows_cpu, non_blocking=True)
|
||||
evict_in_rows_gpu = self.weight.view(self.num_embeddings,
|
||||
-1).index_select(0, cpu_row_idxs_copy).pin_memory()
|
||||
with torch.cuda.stream(self._memcpy_stream):
|
||||
evict_in_rows_gpu = evict_in_rows_gpu.to(torch.cuda.current_device(), non_blocking=True)
|
||||
else:
|
||||
raise NotImplemented
|
||||
|
||||
@@ -378,7 +396,8 @@ class CachedParamMgr(torch.nn.Module):
|
||||
evict_out_rows_gpu = self.cuda_cached_weight.view(self.cuda_row_num,
|
||||
-1).index_select(0, evict_gpu_row_idxs)
|
||||
evict_out_rows_cpu = torch.empty_like(evict_out_rows_gpu, device='cpu', pin_memory=True)
|
||||
evict_out_rows_cpu.copy_(evict_out_rows_gpu, non_blocking=True)
|
||||
with torch.cuda.stream(None):
|
||||
evict_out_rows_cpu.copy_(evict_out_rows_gpu, non_blocking=True)
|
||||
self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)
|
||||
|
||||
elif self._evict_strategy == EvictionStrategy.LFU:
|
||||
@@ -393,7 +412,8 @@ class CachedParamMgr(torch.nn.Module):
|
||||
evict_out_rows_gpu = self.cuda_cached_weight.view(self.cuda_row_num,
|
||||
-1).index_select(0, evict_gpu_row_idxs)
|
||||
evict_out_rows_cpu = torch.empty_like(evict_out_rows_gpu, device='cpu', pin_memory=True)
|
||||
evict_out_rows_cpu.copy_(evict_out_rows_gpu, non_blocking=True)
|
||||
with torch.cuda.stream(None):
|
||||
evict_out_rows_cpu.copy_(evict_out_rows_gpu, non_blocking=True)
|
||||
with self.timer("3_1_2_find_evict_index_copy") as timer:
|
||||
self.freq_cnter.index_copy_(0, invalid_idxs, backup_freqs)
|
||||
|
||||
@@ -410,7 +430,7 @@ class CachedParamMgr(torch.nn.Module):
|
||||
# allocate tmp memory on CPU and copy rows on CUDA to CPU.
|
||||
# TODO async gpu -> cpu
|
||||
if self._async_copy:
|
||||
pass
|
||||
_wait_for_data(evict_out_rows_cpu, None)
|
||||
else:
|
||||
with self.timer("3_2_1_evict_out_index_select") as timer:
|
||||
evict_out_rows_cpu = self.cuda_cached_weight.view(self.cuda_row_num,
|
||||
@@ -445,10 +465,7 @@ class CachedParamMgr(torch.nn.Module):
|
||||
tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1))
|
||||
else:
|
||||
if self._async_copy:
|
||||
torch.cuda.current_stream().wait_stream(idxslt_stream)
|
||||
evict_in_rows_gpu = torch.empty_like(rows_cpu, device=torch.cuda.current_device())
|
||||
evict_in_rows_gpu.copy_(rows_cpu, non_blocking=True)
|
||||
pass
|
||||
_wait_for_data(evict_in_rows_gpu, self._memcpy_stream)
|
||||
else:
|
||||
with self.timer("3_4_1_evict_in_index_select") as timer:
|
||||
# narrow index select to a subset of self.weight
|
||||
|
Reference in New Issue
Block a user