[zero] add colo move inline (#521)

This commit is contained in:
Jiarui Fang
2022-03-25 14:02:55 +08:00
committed by GitHub
parent 7be397ca9c
commit 920c5889a7
4 changed files with 39 additions and 8 deletions

View File

@@ -65,6 +65,34 @@ def colo_model_data_tensor_move(src_t: Union[ShardedTensor, torch.Tensor], tgt_t
src_t.data = torch.tensor([], device=src_dev, dtype=src_t_payload.dtype)
def colo_model_data_tensor_move_inline(t: Union[ShardedTensor, torch.Tensor], target_device: torch.device) -> None:
    """Move a tensor to ``target_device`` in place.

    The tensor object itself is kept; only its ``.data`` is rebound to a copy
    living on ``target_device``. Nothing happens when the tensor is already on
    a device of the same type as ``target_device``.

    Args:
        t (Union[ShardedTensor, torch.Tensor]): the tensor to be moved. For a
            ShardedTensor, its ``payload`` is the tensor that gets moved.
        target_device (torch.device): the device to move the tensor to.

    Raises:
        TypeError: if ``t`` is neither a ShardedTensor nor a torch.Tensor.
        AssertionError: if ``target_device`` is not a ``torch.device``.
    """
    if isinstance(t, ShardedTensor):
        t_payload = t.payload
    elif isinstance(t, torch.Tensor):
        t_payload = t
    else:
        raise TypeError(f'colo_model_data_tensor_move_inline does not accept type {type(t)}')

    assert isinstance(target_device, torch.device)

    # Compare device *types* rather than devices so that torch.device('cpu')
    # and torch.device('cpu:0') are treated as the same placement.
    if t_payload.device.type == target_device.type:
        return

    # Keep the global model-data tracer in sync with the move: register the
    # payload when it goes to CUDA, unregister it when it goes to CPU.
    # NOTE(review): presumably the tracer accounts for CUDA-resident model
    # data only -- confirm against GLOBAL_MODEL_DATA_TRACER's implementation.
    if target_device.type == 'cuda':
        GLOBAL_MODEL_DATA_TRACER.add_tensor(t_payload)
    elif target_device.type == 'cpu':
        GLOBAL_MODEL_DATA_TRACER.delete_tensor(t_payload)

    # Rebind .data so every existing reference to the tensor sees the new device.
    t_payload.data = t_payload.data.to(target_device)
def colo_model_data_move_to_cpu(t: Union[ShardedTensor, torch.Tensor]) -> None:
"""colo_model_data_move_to_cpu