mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-13 13:11:05 +00:00
[zero] add has_inf_or_nan in AgChunk; enhance the unit test of AgChunk (#1426)
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from typing import Optional, Dict
|
||||
from typing import Optional, Dict, List
|
||||
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.tensor import ProcessGroup as ColoProcessGroup
|
||||
@@ -45,10 +45,11 @@ class AgChunk:
|
||||
self.shard_size = chunk_size // self.pg_size
|
||||
self.shard_begin = self.shard_size * self.pg_rank
|
||||
self.shard_end = self.shard_begin + self.shard_size
|
||||
self.valid_end = self.shard_size
|
||||
|
||||
self.dtype = dtype
|
||||
device = init_device or get_current_device()
|
||||
self.chunk_temp = torch.empty(chunk_size, dtype=dtype, device=device)
|
||||
self.chunk_temp = torch.zeros(chunk_size, dtype=dtype, device=device) # keep all zero
|
||||
self.chunk_total = None # we force chunk_total located in CUDA
|
||||
self.cuda_shard = None # using two attributes for the better interpretation
|
||||
self.cpu_shard = None
|
||||
@@ -114,7 +115,7 @@ class AgChunk:
|
||||
if self.chunk_temp is not None:
|
||||
return self.chunk_temp.device.type
|
||||
else:
|
||||
if self.chunk_total is not None:
|
||||
if self.is_gathered:
|
||||
return 'cuda'
|
||||
elif self.cuda_shard is not None:
|
||||
return 'cuda'
|
||||
@@ -153,6 +154,12 @@ class AgChunk:
|
||||
# sanity check
|
||||
assert self.chunk_temp is not None
|
||||
|
||||
# calculate the valid end for each shard
|
||||
if self.utilized_size <= self.shard_begin:
|
||||
self.valid_end = 0
|
||||
elif self.utilized_size < self.shard_end:
|
||||
self.valid_end = self.utilized_size - self.shard_begin
|
||||
|
||||
if self.chunk_temp.device.type == 'cpu':
|
||||
self.chunk_total = self.chunk_temp.to(get_current_device())
|
||||
else:
|
||||
@@ -257,7 +264,7 @@ class AgChunk:
|
||||
self.shard_size, dtype=self.dtype, device=get_current_device())
|
||||
|
||||
input_list = list(torch.chunk(self.chunk_total, chunks=self.pg_size, dim=0))
|
||||
dist.reduce_scatter(self.cuda_shard, input_list, self.torch_pg)
|
||||
dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg)
|
||||
|
||||
free_storage(self.chunk_total)
|
||||
self.is_gathered = False
|
||||
@@ -298,17 +305,38 @@ class AgChunk:
|
||||
assert self.is_gathered
|
||||
|
||||
tensor_info = self.tensors_info[tensor]
|
||||
self.chunk_total[tensor_info.offset:tensor_info.end].copy_(data_slice.flatten())
|
||||
self.chunk_total[tensor_info.offset:tensor_info.end].copy_(data_slice.data.flatten())
|
||||
tensor.data = self.chunk_total[tensor_info.offset:tensor_info.end].view(tensor.shape)
|
||||
|
||||
@property
|
||||
def can_move(self) -> bool:
|
||||
return not self.is_gathered
|
||||
|
||||
@property
|
||||
def can_release(self) -> bool:
|
||||
return self.tensors_state_monitor[TensorState.HOLD] == self.num_tensors
|
||||
if self.keep_gathered:
|
||||
return False
|
||||
else:
|
||||
return self.tensors_state_monitor[TensorState.HOLD] + \
|
||||
self.tensors_state_monitor[TensorState.HOLD_AFTER_BWD] == self.num_tensors
|
||||
|
||||
@property
|
||||
def can_reduce(self):
|
||||
return self.tensors_state_monitor[TensorState.READY_FOR_REDUCE] == self.num_tensors
|
||||
|
||||
@property
|
||||
def has_inf_or_nan(self) -> bool:
|
||||
"""
|
||||
Check if the chunk has inf or nan values in CUDA.
|
||||
"""
|
||||
if self.is_gathered:
|
||||
valid_tensor = self.chunk_total[: self.utilized_size]
|
||||
else:
|
||||
assert self.cuda_shard is not None # only check in CUDA
|
||||
valid_tensor = self.cuda_shard[: self.valid_end]
|
||||
|
||||
return torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item()
|
||||
|
||||
def __gather(self):
|
||||
if not self.is_gathered:
|
||||
# sanity check
|
||||
@@ -375,6 +403,12 @@ class AgChunk:
|
||||
if prev_state is None or tensor_info.state == prev_state:
|
||||
self.__update_one_tensor_info(tensor_info, next_state)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(id(self))
|
||||
|
||||
def __eq__(self, __o: object) -> bool:
|
||||
return self is __o
|
||||
|
||||
def __repr__(self, detailed: bool = False):
|
||||
output = [
|
||||
"AgChunk Information:\n",
|
||||
@@ -413,3 +447,6 @@ class AgChunk:
|
||||
output.append("\t\t# of {}: {}\n".format(st, self.tensors_state_monitor[st]))
|
||||
|
||||
return ''.join(output)
|
||||
|
||||
def get_tensors(self) -> List[torch.Tensor]:
|
||||
return list(self.tensors_info.keys())
|
||||
|
Reference in New Issue
Block a user