Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-05-01 13:15:26 +00:00)
[NFC] polish comments for Chunk class (#2116)
parent 09d69e1c25
commit e99edfcb51
@@ -71,8 +71,9 @@ class Chunk:
             chunk_size (int): the number of elements in the chunk
             process_group (ColoProcessGroup): the process group of this chunk
             dtype (torch.dtype): the data type of the chunk
-            init_device (torch.device): optional, the device where the tensor is initialized
+            init_device (torch.device): optional, the device where the tensor is stored during chunk construction.
                 The default value is None, which is the current GPU
+            cpu_shard_init (bool): a flag that indicates whether the local chunk shard is resident on CPU.
             keep_gathered (bool): optional, if True, this chunk is always gathered in CUDA memory
             pin_memory (bool): optional, if True, this chunk always has a shard copied in pinned CPU memory
         """
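Read together, the polished argument list implies a construction call like the one below. This is a minimal sketch, not code from the commit: `pg` is assumed to be a ColoProcessGroup created elsewhere, and the sizes are illustrative.

    import torch
    from colossalai.gemini.chunk import Chunk    # import path as used in the utils hunk below

    # `pg` is assumed to be a ColoProcessGroup set up by the caller
    chunk = Chunk(
        chunk_size=1024 * 1024,      # number of elements, not bytes
        process_group=pg,            # dp process group that owns the shards
        dtype=torch.float16,
        init_device=None,            # None -> the current GPU during construction
        cpu_shard_init=False,        # keep the local shard on GPU after scattering
        keep_gathered=False,         # allow the chunk to be scattered once it is closed
        pin_memory=True,             # also keep a shard copy in pinned CPU memory
    )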
@@ -81,13 +82,12 @@ class Chunk:

         self.chunk_size = chunk_size
         self.utilized_size = 0
-        # Here, we use torch process group,
-        # since ColoProcessGroup might get deprecated soon
         self.torch_pg = process_group.dp_process_group()
         self.pg_size = dist.get_world_size(self.torch_pg)
         self.pg_rank = dist.get_rank(self.torch_pg)

-        # the chunk size should be able to be divied by the size of GPU
+        # the chunk size should be divisible by the dp degree
         if not keep_gathered:
             assert chunk_size % self.pg_size == 0
         self.shard_size = chunk_size // self.pg_size
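The rewritten comment states the invariant precisely: sharding only works when the chunk divides evenly across the data-parallel ranks. A worked example of the arithmetic, assuming a dp degree of 4:

    # a 1024-element chunk split across a dp group of 4 ranks
    chunk_size = 1024
    pg_size = 4    # dist.get_world_size(torch_pg) in the code above
    assert chunk_size % pg_size == 0, "chunk_size must be divisible by the dp degree"
    shard_size = chunk_size // pg_size    # each rank holds 256 elements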
@@ -97,13 +97,21 @@ class Chunk:

         self.dtype = dtype
         device = init_device or get_current_device()

+        # chunk_temp is a global chunk, which only exists during building the chunks.
         self.chunk_temp = torch.zeros(chunk_size, dtype=dtype, device=device)    # keep all zero
-        self.chunk_total = None    # we force chunk_total located in CUDA
-        self.cuda_shard = None    # using two attributes for the better interpretation
+        self.cuda_global_chunk = None    # we force cuda_global_chunk to be located in CUDA

+        # cuda local chunk, which is sharded on GPUs
+        self.cuda_shard = None
+        # cpu local chunk, which is sharded on CPUs
         self.cpu_shard = None
+        # whether the chunk is gathered; a gathered chunk is duplicated on each process,
+        # and we should use the cuda_global_chunk.
         self.is_gathered = True

-        # configure the init deivce of the shard
+        # configure the init device of the shard
         # no-offload default: fp16, fp32 -> CUDA
         # offload default: fp16, fp32 -> CPU
         self.shard_device = torch.device("cpu") if cpu_shard_init else get_current_device()
@@ -111,17 +119,19 @@ class Chunk:
         self.chunk_mem = self.chunk_size * self.chunk_temp.element_size()
         self.shard_mem = self.chunk_mem // self.pg_size

-        # each tensor is associated with a TensorInfo to track meta info
+        # each tensor is associated with a TensorInfo to track its meta info
+        # (state, offset, end)
         self.tensors_info: Dict[torch.Tensor, TensorInfo] = {}
-        # the total number of all tensors
+        # the total number of tensors in the chunk
         self.num_tensors = 0
-        # monitor the states of all tensors
-        self.tensors_state_monitor: Dict[TensorState, int] = dict()
-        for state in TensorState:
-            self.tensors_state_monitor[state] = 0

-        # some chunks can keep gathered all the time
-        # so their computation patterns are the same as that of the parameters in DDP
+        # Record the number of tensors in different states
+        self.tensor_state_cnter: Dict[TensorState, int] = dict()
+        for state in TensorState:
+            self.tensor_state_cnter[state] = 0
+
+        # If a chunk is kept gathered,
+        # it is treated the same as the parameters in DDP during training.
         self.keep_gathered = keep_gathered
         if self.keep_gathered:
             pin_memory = False    # since this chunk is gathered, it doesn't need to pin
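The renamed `tensor_state_cnter` is a plain per-state counter over the chunk's tensors. A self-contained sketch of the same bookkeeping pattern; the state names mirror the ones referenced in this diff:

    from enum import Enum, auto
    from typing import Dict

    class TensorState(Enum):    # mirrors the states referenced in this diff
        FREE = auto()
        COMPUTE = auto()
        HOLD = auto()
        HOLD_AFTER_BWD = auto()
        READY_FOR_REDUCE = auto()

    tensor_state_cnter: Dict[TensorState, int] = {state: 0 for state in TensorState}
    tensor_state_cnter[TensorState.HOLD] += 1    # a tensor enters the chunk in HOLD

    def trans_state(old: TensorState, new: TensorState) -> None:
        # a transition decrements the old state's count and increments the new one's
        tensor_state_cnter[old] -= 1
        tensor_state_cnter[new] += 1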
@@ -182,7 +192,7 @@ class Chunk:
         assert self.chunk_temp is None

         if self.is_gathered:
-            return self.chunk_total
+            return self.cuda_global_chunk
         elif self.cuda_shard is not None:
             return self.cuda_shard
         else:
@@ -207,19 +217,19 @@ class Chunk:
         if self.keep_gathered:
             return False
         else:
-            return self.tensors_state_monitor[TensorState.HOLD] + \
-                self.tensors_state_monitor[TensorState.HOLD_AFTER_BWD] == self.num_tensors
+            return self.tensor_state_cnter[TensorState.HOLD] + \
+                self.tensor_state_cnter[TensorState.HOLD_AFTER_BWD] == self.num_tensors

     @property
     def can_reduce(self):
-        return self.tensors_state_monitor[TensorState.READY_FOR_REDUCE] == self.num_tensors
+        return self.tensor_state_cnter[TensorState.READY_FOR_REDUCE] == self.num_tensors

     @property
     def has_inf_or_nan(self) -> bool:
         """Check if the chunk has inf or nan values on CUDA.
         """
         if self.is_gathered:
-            valid_tensor = self.chunk_total[:self.utilized_size]
+            valid_tensor = self.cuda_global_chunk[:self.utilized_size]
         else:
             assert self.cuda_shard is not None    # only check on CUDA
             valid_tensor = self.cuda_shard[:self.valid_end]
@@ -231,7 +241,7 @@ class Chunk:
         """
         assert self.l2_norm is None, "you are calculating the l2 norm twice"
         if self.is_gathered:
-            valid_tensor = self.chunk_total[:self.utilized_size]
+            valid_tensor = self.cuda_global_chunk[:self.utilized_size]
         else:
             assert self.cuda_shard is not None    # calculate on CUDA
             valid_tensor = self.cuda_shard[:self.valid_end]
@@ -261,7 +271,7 @@ class Chunk:
         self.num_tensors += 1
         tensor_state = TensorState.HOLD
         self.tensors_info[tensor] = TensorInfo(tensor_state, self.utilized_size, new_utilized_size)
-        self.tensors_state_monitor[tensor_state] += 1
+        self.tensor_state_cnter[tensor_state] += 1
         self.utilized_size = new_utilized_size

     def close_chunk(self):
@@ -277,10 +287,10 @@ class Chunk:
         self.valid_end = self.utilized_size - self.shard_begin

         if self.chunk_temp.device.type == 'cpu':
-            self.chunk_total = self.chunk_temp.to(get_current_device())
+            self.cuda_global_chunk = self.chunk_temp.to(get_current_device())
             self.__update_tensors_ptr()
         else:
-            self.chunk_total = self.chunk_temp
+            self.cuda_global_chunk = self.chunk_temp
         self.chunk_temp = None

         self.__scatter()
@@ -366,19 +376,19 @@ class Chunk:

         if self.pg_size == 1:
             # tricky code here
-            # just move chunk_total to cuda_shard
+            # just move cuda_global_chunk to cuda_shard
             # the communication is not necessary
             self.__scatter()
         elif self.keep_gathered:
             # we use all-reduce here
-            dist.all_reduce(self.chunk_total, group=self.torch_pg)
+            dist.all_reduce(self.cuda_global_chunk, group=self.torch_pg)
         else:
             self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=get_current_device())

-            input_list = list(torch.chunk(self.chunk_total, chunks=self.pg_size, dim=0))
+            input_list = list(torch.chunk(self.cuda_global_chunk, chunks=self.pg_size, dim=0))
             dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg)

-            free_storage(self.chunk_total)
+            free_storage(self.cuda_global_chunk)
             self.is_gathered = False
         self.__update_tensors_state(TensorState.HOLD)
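For orientation, the final branch above is the usual ZeRO-style reduction: every rank contributes the full chunk and keeps only one reduced shard. A standalone sketch of just that branch, assuming `torch.distributed` is already initialized:

    import torch
    import torch.distributed as dist

    def reduce_scatter_chunk(cuda_global_chunk: torch.Tensor, pg_size: int, torch_pg) -> torch.Tensor:
        # each rank ends up with one reduced 1/pg_size slice of the chunk
        shard_size = cuda_global_chunk.numel() // pg_size
        cuda_shard = torch.empty(shard_size, dtype=cuda_global_chunk.dtype,
                                 device=cuda_global_chunk.device)
        input_list = list(torch.chunk(cuda_global_chunk, chunks=pg_size, dim=0))
        dist.reduce_scatter(cuda_shard, input_list, group=torch_pg)
        return cuda_shard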
@@ -413,8 +423,8 @@ class Chunk:
         assert self.is_gathered

         tensor_info = self.tensors_info[tensor]
-        self.chunk_total[tensor_info.offset:tensor_info.end].copy_(data_slice.data.flatten())
-        tensor.data = self.chunk_total[tensor_info.offset:tensor_info.end].view(tensor.shape)
+        self.cuda_global_chunk[tensor_info.offset:tensor_info.end].copy_(data_slice.data.flatten())
+        tensor.data = self.cuda_global_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape)

     def get_valid_length(self) -> int:
         """Get the valid length of the chunk's payload.
@@ -443,7 +453,7 @@ class Chunk:
         friend_chunk = self.paired_chunk
         if self.is_gathered is True:
             assert friend_chunk.is_gathered is True
-            self.chunk_total.copy_(friend_chunk.chunk_total)
+            self.cuda_global_chunk.copy_(friend_chunk.cuda_global_chunk)
             self.optim_sync_flag = True
         elif friend_chunk.device_type == 'cuda' and self.device_type == 'cuda':
             self.cuda_shard.copy_(friend_chunk.cuda_shard)
@@ -465,8 +475,8 @@ class Chunk:
         # sanity check
         assert self.cuda_shard is not None

-        alloc_storage(self.chunk_total)
-        gather_list = list(torch.chunk(input=self.chunk_total, chunks=self.pg_size, dim=0))
+        alloc_storage(self.cuda_global_chunk)
+        gather_list = list(torch.chunk(input=self.cuda_global_chunk, chunks=self.pg_size, dim=0))
         dist.all_gather(gather_list, self.cuda_shard, self.torch_pg)

         self.cuda_shard = None
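The gather path is the inverse of the reduce-scatter sketch above: `all_gather` writes every rank's shard into that rank's view of the re-allocated global chunk. A sketch under the same assumptions:

    def gather_chunk(cuda_shard: torch.Tensor, cuda_global_chunk: torch.Tensor,
                     pg_size: int, torch_pg) -> None:
        # the views returned by torch.chunk alias cuda_global_chunk,
        # so all_gather fills the global chunk in place
        gather_list = list(torch.chunk(input=cuda_global_chunk, chunks=pg_size, dim=0))
        dist.all_gather(gather_list, cuda_shard, group=torch_pg)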
@@ -480,11 +490,11 @@ class Chunk:
         # sanity check
         assert self.cuda_shard is None

-        self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=self.chunk_total.device)
+        self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=self.cuda_global_chunk.device)

-        self.cuda_shard.copy_(self.chunk_total[self.shard_begin:self.shard_end])
+        self.cuda_shard.copy_(self.cuda_global_chunk[self.shard_begin:self.shard_end])

-        free_storage(self.chunk_total)
+        free_storage(self.cuda_global_chunk)
         self.is_gathered = False

     def __paired_shard_move(self):
@@ -505,15 +515,15 @@ class Chunk:
     def __update_tensors_ptr(self) -> None:
         # sanity check
         assert self.is_gathered
-        assert type(self.chunk_total) == torch.Tensor
+        assert type(self.cuda_global_chunk) == torch.Tensor

         for tensor, tensor_info in self.tensors_info.items():
-            tensor.data = self.chunk_total[tensor_info.offset:tensor_info.end].view(tensor.shape)
+            tensor.data = self.cuda_global_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape)

     def __update_one_tensor_info(self, tensor_info: TensorInfo, next_state: TensorState):
-        self.tensors_state_monitor[tensor_info.state] -= 1
+        self.tensor_state_cnter[tensor_info.state] -= 1
         tensor_info.state = next_state
-        self.tensors_state_monitor[tensor_info.state] += 1
+        self.tensor_state_cnter[tensor_info.state] += 1

     def __update_tensors_state(self, next_state: TensorState, prev_state: Optional[TensorState] = None):
         for tensor_info in self.tensors_info.values():
@@ -543,9 +553,9 @@ class Chunk:
             output.append("\tchunk temp:\n")
             print_tensor(tensor=self.chunk_temp, prefix='\t\t')

-        if self.chunk_total is not None and self.chunk_total.storage().size() > 0:
+        if self.cuda_global_chunk is not None and self.cuda_global_chunk.storage().size() > 0:
             output.append("\tchunk total:\n")
-            print_tensor(tensor=self.chunk_total, prefix='\t\t')
+            print_tensor(tensor=self.cuda_global_chunk, prefix='\t\t')

         if self.cuda_shard is not None:
             output.append("\tcuda shard:\n")
@@ -561,6 +571,6 @@ class Chunk:
         if detailed:
             output.append("\ttensor state monitor:\n")
             for st in TensorState:
-                output.append("\t\t# of {}: {}\n".format(st, self.tensors_state_monitor[st]))
+                output.append("\t\t# of {}: {}\n".format(st, self.tensor_state_cnter[st]))

         return ''.join(output)
@@ -299,7 +299,7 @@ class ZeroDDP(ColoDDP):
         reduced = self.chunk_manager.reduce_chunk(chunk)
         if reduced:
             if chunk.is_gathered:
-                chunk.chunk_total.div_(chunk.pg_size)
+                chunk.cuda_global_chunk.div_(chunk.pg_size)
             else:
                 chunk.cuda_shard.div_(chunk.pg_size)
             # check overflow elements
@@ -529,7 +529,7 @@ class ZeroDDP(ColoDDP):
             load(parameter_name, tensor, partial(load_fp32_parameter, parameter_slice))

         if chunk.is_gathered:
-            chunk.chunk_total.copy_(temp_chunk)
+            chunk.cuda_global_chunk.copy_(temp_chunk)
         elif chunk.cuda_shard is not None:
             chunk.cuda_shard.copy_(temp_chunk[chunk.shard_begin:chunk.shard_end])
         else:
@@ -1,12 +1,13 @@
 import torch
 import torch.distributed as dist

 from colossalai.gemini.chunk import Chunk
 from colossalai.utils import get_current_device


 def get_temp_total_chunk_on_cuda(chunk: Chunk):
     if chunk.is_gathered:
-        return chunk.chunk_total
+        return chunk.cuda_global_chunk

     if chunk.cuda_shard is not None:
         shard_temp = chunk.cuda_shard
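The hunk cuts off inside the sharded branch; the remainder of this helper gathers the local shard into a temporary full-size buffer. A hedged sketch of that idea, with hypothetical names rather than the file's actual code:

    def temp_total_from_shard(shard_temp: torch.Tensor, pg_size: int, torch_pg) -> torch.Tensor:
        # allocate scratch space for the whole chunk and gather every rank's shard into it
        total_temp = torch.zeros(shard_temp.numel() * pg_size,
                                 dtype=shard_temp.dtype, device=get_current_device())
        gather_list = list(torch.chunk(total_temp, chunks=pg_size, dim=0))
        dist.all_gather(gather_list, shard_temp, group=torch_pg)
        return total_temp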
@@ -9,10 +9,11 @@ from colossalai.tensor.tensor_spec import ColoTensorSpec


 class ColoParamOpHook(ABC):
-    """Hook which is triggered by each operation when operands contain ColoParameter.
+    """
+    Hook which is triggered by each operation when operands contain ColoParameter.
     To customize it, you must inherit this abstract class, and implement ``pre_forward``,
-    ``post_forward``, ``pre_backward`` and ``post_backward``. These four methods take a list
-    of ColoParameter.
+    ``post_forward``, ``pre_backward`` and ``post_backward``.
+    These four methods take a list of ColoParameter as input args.
     """

     @abstractmethod
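As the reflowed docstring says, a custom hook subclasses `ColoParamOpHook` and implements the four methods, each taking a list of parameters. A minimal illustrative subclass (the counting behavior and the import path are assumptions for the example, not code from the commit):

    from typing import List

    import torch

    from colossalai.tensor import ColoParamOpHook    # import path assumed

    class CountingHook(ColoParamOpHook):
        """Toy hook that counts how many parameters each op touches."""

        def __init__(self) -> None:
            super().__init__()
            self.seen = 0

        def pre_forward(self, params: List[torch.Tensor]) -> None:
            self.seen += len(params)

        def post_forward(self, params: List[torch.Tensor]) -> None:
            pass

        def pre_backward(self, params: List[torch.Tensor]) -> None:
            self.seen += len(params)

        def post_backward(self, params: List[torch.Tensor]) -> None:
            pass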
@@ -33,7 +34,8 @@ class ColoParamOpHook(ABC):


 class ColoParamOpHookManager:
-    """Manage your param op hooks. It only has static methods.
+    """
+    Manage your param op hooks. It only has static methods.
     The only static method you should call is ``use_hooks(*hooks)``.
     """
     hooks: Tuple[ColoParamOpHook, ...] = tuple()
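Pairing with the subclass sketch above, `use_hooks` is the single entry point the docstring names; assuming it behaves as a context manager that scopes the hooks to the wrapped ops (`model` and `x` are placeholders):

    hook = CountingHook()
    with ColoParamOpHookManager.use_hooks(hook):
        out = model(x)    # ops whose operands contain ColoParameter trigger the hook
    print(hook.seen)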
@@ -2,23 +2,22 @@ from typing import Optional

 import torch
 import torch.distributed as dist

+from colossalai.gemini.memory_tracer import MemStatsCollector
+from colossalai.gemini.ophooks import BaseOpHook
+from colossalai.gemini.stateful_tensor import TensorState
+from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr
 from colossalai.logging import get_dist_logger
 from colossalai.registry import OPHOOKS
 from colossalai.utils import get_current_device
 from colossalai.zero.shard_utils import BaseShardStrategy
-from colossalai.gemini.ophooks import BaseOpHook

-from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr
-from colossalai.gemini.memory_tracer import MemStatsCollector
-from colossalai.gemini.stateful_tensor import TensorState


 @OPHOOKS.register_module
 class ZeroHook(BaseOpHook):
     """
     A hook to process sharded param for ZeRO method.
+    Warning: this class has been deprecated since version 0.1.12
     """

     def __init__(self,
@@ -69,7 +69,7 @@ def exam_chunk_basic(init_device, keep_gathered, pin_memory):
         assert my_chunk.can_move
         my_chunk.shard_move(get_current_device())
     else:
-        assert my_chunk.chunk_total.size(0) == 1024
+        assert my_chunk.cuda_global_chunk.size(0) == 1024
         assert my_chunk.device_type == 'cuda'
         assert not my_chunk.can_move
@@ -82,27 +82,27 @@ def exam_chunk_basic(init_device, keep_gathered, pin_memory):
     for param, param_cp in zip(param_list, param_cp_list):
         check_euqal(param, param_cp)

-    assert my_chunk.tensors_state_monitor[TensorState.HOLD] == 4
+    assert my_chunk.tensor_state_cnter[TensorState.HOLD] == 4
     my_chunk.tensor_trans_state(param_list[0], TensorState.COMPUTE)
-    assert my_chunk.tensors_state_monitor[TensorState.HOLD] == 3
-    assert my_chunk.tensors_state_monitor[TensorState.COMPUTE] == 1
+    assert my_chunk.tensor_state_cnter[TensorState.HOLD] == 3
+    assert my_chunk.tensor_state_cnter[TensorState.COMPUTE] == 1
     assert not my_chunk.can_release

     for param in param_list:
         my_chunk.tensor_trans_state(param, TensorState.COMPUTE)
         my_chunk.tensor_trans_state(param, TensorState.READY_FOR_REDUCE)

-    assert my_chunk.tensors_state_monitor[TensorState.READY_FOR_REDUCE] == 4
+    assert my_chunk.tensor_state_cnter[TensorState.READY_FOR_REDUCE] == 4
     assert my_chunk.can_reduce
     my_chunk.reduce()
-    assert my_chunk.tensors_state_monitor[TensorState.HOLD] == 4
+    assert my_chunk.tensor_state_cnter[TensorState.HOLD] == 4

     if keep_gathered is False:
         assert my_chunk.cuda_shard.size(0) == 1024 // world_size
         assert my_chunk.device_type == 'cuda'
         assert my_chunk.can_move
     else:
-        assert my_chunk.chunk_total.size(0) == 1024
+        assert my_chunk.cuda_global_chunk.size(0) == 1024
         assert my_chunk.device_type == 'cuda'
         assert not my_chunk.can_move