Mirror of https://github.com/hpcaitech/ColossalAI.git
[gemini] accelerate adjust_layout() (#878)
* add lru cache
* polish code
* update unit test
* fix sharded optim
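The first commit-message item, "add lru cache", is the part that actually speeds up `adjust_layout()`; its diff is not reproduced on this page. Below is only a minimal, hypothetical sketch of the technique itself: memoizing a repeated, pure layout computation with `functools.lru_cache`. The helper `compute_layout_info` and its arguments are illustrative and are not ColossalAI's actual API.

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def compute_layout_info(cuda_capacity: int, frozen_cuda_mem: int) -> int:
    """Hypothetical pure helper: how many bytes of model data may stay on CUDA.

    adjust_layout() runs on every iteration; when the inputs repeat,
    an LRU cache returns the previous answer instead of recomputing it.
    """
    return max(cuda_capacity - frozen_cuda_mem, 0)


# Repeated calls with identical arguments are served from the cache.
compute_layout_info(40_000_000_000, 3_000_000_000)
compute_layout_info(40_000_000_000, 3_000_000_000)
print(compute_layout_info.cache_info())  # hits=1, misses=1
```

Caching only pays off because the helper is a pure function of hashable arguments; any state that changes every step has to stay outside the cached call.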
@@ -285,7 +285,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                 shard_mem = self.master_params[p].payload.numel() * self.master_params[p].payload.element_size()
                 if fp32_shards_used_cuda_margin_mem + shard_mem < fp32_shards_available_cuda_margin_mem:
                     colo_model_data_tensor_move_inline(self.master_params[p], torch.cuda.current_device())
-                    p.grad.data = p.grad.data.to(torch.cuda.current_device())
+                    colo_model_data_tensor_move_inline(p.colo_attr.saved_grad, torch.cuda.current_device())
                     p.colo_attr.offload_grad = False
                     fp32_shards_used_cuda_margin_mem += shard_mem
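The hunk above is a greedy budget check: each fp32 master shard (and, after this change, its saved grad) is moved to CUDA only while the running total of shard memory stays below the available margin. A stand-alone sketch of that accounting, assuming plain tensors; `move_shards_within_margin` and its parameters are illustrative names, with `Tensor.to()` standing in for `colo_model_data_tensor_move_inline`:

```python
import torch


def move_shards_within_margin(shards, margin_bytes: int, device: str = "cuda") -> int:
    """Greedily move shards to `device` while their cumulative size stays under the margin."""
    used = 0
    for shard in shards:
        shard_mem = shard.numel() * shard.element_size()
        if used + shard_mem < margin_bytes:
            # In the real optimizer the move is done in place on a stateful wrapper;
            # here a plain .to() is enough to show the accounting.
            shard.data = shard.data.to(device)
            used += shard_mem
    return used


# Toy run on CPU: 4 KiB shards against a 10 KB margin -> only the first two move.
shards = [torch.zeros(1024) for _ in range(4)]
print(move_shards_within_margin(shards, margin_bytes=10_000, device="cpu"))  # 8192
```

The strict `<` mirrors the diff: a shard that would exactly exhaust the margin is left where it is.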
@@ -297,7 +297,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                 p.colo_attr.saved_grad.trans_state(TensorState.COMPUTE)
                 # If reuse_fp16_shard, grad fp16 which wasn't be offloaded may be evicted to CPU
                 if not p.colo_attr.offload_grad:
-                    colo_model_data_tensor_move_inline(p.colo_attr.grad_payload, torch.cuda.current_device())
+                    colo_model_data_tensor_move_inline(p.colo_attr.saved_grad, torch.cuda.current_device())
                 # FIXME(ver217): p.data here is an empty tensor on CUDA and has no useful infomation
                 # If we change p.grad directly
                 # it may raise error because of different shape/dtype/device of p.data and p.grad
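The FIXME in this hunk records why the fp16 gradient is moved through its stateful wrapper instead of by reassigning `p.grad`: at this point `p.data` is an empty CUDA tensor, so writing to `p.grad` directly risks errors from the mismatched shape/dtype/device of `p.data` and `p.grad`. A rough sketch of the "swap the payload, keep the wrapper" idea, assuming only that the wrapper owns a payload tensor; `PayloadHolder` and `move_inline` are illustrative stand-ins, not the actual `StatefulTensor` / `colo_model_data_tensor_move_inline` implementations:

```python
import torch


class PayloadHolder:
    """Toy stand-in for a stateful tensor: owns a payload that callers reach indirectly."""

    def __init__(self, payload: torch.Tensor):
        self.payload = payload


def move_inline(holder: PayloadHolder, device: torch.device) -> None:
    # Replace the payload with a copy on the target device. Everything that holds
    # the wrapper keeps a valid reference, and p.grad itself is never reassigned.
    holder.payload = holder.payload.to(device)


saved_grad = PayloadHolder(torch.zeros(8, dtype=torch.float16))
target = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
move_inline(saved_grad, target)
print(saved_grad.payload.device)
```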
@@ -114,5 +114,6 @@ class ZeroHook(BaseOpHook):
     def post_iter(self):
         if self._stateful_tensor_mgr:
             self.logger.info(
-                f"CPU-GPU data moving this iteration {self._stateful_tensor_mgr.cpu_gpu_move_volume/1e9} GB", ranks=[0])
+                f"CPU-GPU data moving this iteration {self._stateful_tensor_mgr.cpu_gpu_move_volume/1e9} GB, get layout info time: {self._stateful_tensor_mgr._layout_time}, evict cpu time: {self._stateful_tensor_mgr._evict_time}",
+                ranks=[0])
             self._stateful_tensor_mgr.finish_iter()
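The new log line adds two timers next to the traffic volume, so a user can see whether computing the layout (`_layout_time`) or evicting tensors to the CPU (`_evict_time`) dominates an iteration. A minimal sketch of this end-of-iteration hook pattern, with plain `logging` replacing ColossalAI's rank-aware logger; `TensorMgrStats` is a hypothetical stand-in for the stateful tensor manager, with attribute names copied from the diff:

```python
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("zero_hook_sketch")


class TensorMgrStats:
    """Hypothetical per-iteration counters, mimicking the attributes read in post_iter()."""

    def __init__(self):
        self.cpu_gpu_move_volume = 0   # bytes moved between CPU and GPU this iteration
        self._layout_time = 0.0        # seconds spent computing layout info
        self._evict_time = 0.0         # seconds spent evicting tensors to CPU

    def finish_iter(self):
        # Reset the counters so the next iteration starts from zero.
        self.cpu_gpu_move_volume = 0
        self._layout_time = 0.0
        self._evict_time = 0.0


def post_iter(mgr: TensorMgrStats) -> None:
    logger.info(f"CPU-GPU data moving this iteration {mgr.cpu_gpu_move_volume / 1e9} GB, "
                f"get layout info time: {mgr._layout_time}, evict cpu time: {mgr._evict_time}")
    mgr.finish_iter()


post_iter(TensorMgrStats())
```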