[colotensor] use cpu memory to store state_dict (#1367)

Author: HELSON
Date: 2022-07-26 14:13:38 +08:00
Committed by: GitHub
Parent: 943a96323e
Commit: 87775a0682

4 changed files with 26 additions and 5 deletions


@@ -318,7 +318,8 @@ class ZeroDDP(ColoDDP):
             self.chunk_manager.access_chunk(chunk)
         for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params):
             if p is not None:
-                destination[prefix + name] = fp32_p.clone() if keep_vars else fp32_p.clone().detach()
+                rec_p = fp32_p.clone() if fp32_p.device.type == 'cpu' else fp32_p.cpu()
+                destination[prefix + name] = rec_p if keep_vars else rec_p.detach()
         for chunk in chunks:
             self.chunk_manager.release_chunk(chunk)
         for name, buf in self.named_buffers():
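In isolation, the change above records each full-precision parameter in CPU memory before it enters the state dict, instead of keeping the clone on its original (possibly CUDA) device. A minimal standalone sketch of just this step; record_fp32_param and its arguments are illustrative names, not part of the patch:

import torch

# Hypothetical helper mirroring the two new lines of the diff.
def record_fp32_param(destination: dict, key: str, fp32_p: torch.Tensor, keep_vars: bool = False) -> None:
    # Always keep the state_dict copy in CPU memory: clone when the shard is
    # already on CPU, otherwise move it off the GPU with .cpu().
    rec_p = fp32_p.clone() if fp32_p.device.type == 'cpu' else fp32_p.cpu()
    # Detach unless the caller asked to keep autograd history (keep_vars=True).
    destination[key] = rec_p if keep_vars else rec_p.detach()

state = {}
record_fp32_param(state, 'weight', torch.randn(2, 2))
assert state['weight'].device.type == 'cpu'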


@@ -4,6 +4,20 @@ from colossalai.tensor import ColoTensor, ColoTensorSpec
 from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern
 
 
+def robust_broadcast(tensor):
+    with torch.no_grad():
+        is_cpu_ten = tensor.device.type == 'cpu'
+        if is_cpu_ten:
+            b_data = tensor.cuda()
+        else:
+            b_data = tensor
+
+        dist.broadcast(b_data, 0)
+
+        if is_cpu_ten:
+            tensor.copy_(b_data)
+
+
 def gather_tensor(colo_tensor: ColoTensor) -> None:
     """Make colo_tensor replicated when the rank is 0
     """
@@ -27,7 +41,7 @@ def scatter_tensor(colo_tensor: ColoTensor, dist_spec: _DistSpec) -> None:
     """Reversal operation of `gather_tensor`.
     """
     if dist_spec.placement == DistPlacementPattern.REPLICATE:
-        dist.broadcast(colo_tensor.data, 0)
+        robust_broadcast(colo_tensor.data)
     else:
         global_size = colo_tensor.size_global()
 
@@ -35,7 +49,7 @@ def scatter_tensor(colo_tensor: ColoTensor, dist_spec: _DistSpec) -> None:
             entire_data = colo_tensor.data
         else:
             entire_data = torch.empty(global_size, device=colo_tensor.device)
-        dist.broadcast(entire_data, 0)
+        robust_broadcast(entire_data)
 
         if dist.get_rank() == 0:
             colo_tensor.set_dist_spec(dist_spec)
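A minimal usage sketch of the new helper. The robust_broadcast body is copied from the hunk above; the calling code is hypothetical and assumes dist.init_process_group('nccl') has already run with one GPU per rank:

import torch
import torch.distributed as dist

def robust_broadcast(tensor):
    with torch.no_grad():
        is_cpu_ten = tensor.device.type == 'cpu'
        if is_cpu_ten:
            b_data = tensor.cuda()     # NCCL cannot broadcast CPU tensors directly
        else:
            b_data = tensor
        dist.broadcast(b_data, 0)
        if is_cpu_ten:
            tensor.copy_(b_data)       # write the received values back into the CPU tensor

t = torch.zeros(4)                     # CPU tensor, e.g. an entry of a CPU-resident state_dict
if dist.get_rank() == 0:
    t.fill_(1.0)
robust_broadcast(t)                    # staged through a temporary CUDA copy
assert torch.equal(t, torch.ones(4))   # every rank now holds rank 0's values, still on CPU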