[elixir] refactored the chunk module (#3956)
@@ -1,2 +1,7 @@
-from .core import BlockRequire, Chunk, ChunkGroup, MemoryPool, PrivateBlock, PublicBlock, TensorBlock, TensorState
+from .core import BlockSpec, Chunk, ChunkGroup, MemoryPool, PrivateBlock, PublicBlock, TensorBlock, TensorState
 from .fetcher import ChunkFetcher
+
+__all__ = [
+    'BlockSpec', 'Chunk', 'ChunkGroup', 'MemoryPool', 'PrivateBlock', 'PublicBlock', 'TensorBlock', 'TensorState',
+    'ChunkFetcher'
+]
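For orientation, the full re-export surface defined by the new __all__ above can be pulled in with a single import; the snippet below is only an illustration of the refactored API names, not part of the commit:

    from colossalai.elixir.chunk import (BlockSpec, Chunk, ChunkFetcher, ChunkGroup, MemoryPool, PrivateBlock,
                                         PublicBlock, TensorBlock, TensorState)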
@@ -1,4 +1,8 @@
 from .chunk import Chunk
 from .group import ChunkGroup
-from .memory_pool import BlockRequire, MemoryPool, PrivateBlock, PublicBlock, TensorBlock
+from .memory_pool import BlockSpec, MemoryPool, PrivateBlock, PublicBlock, TensorBlock
 from .states import TensorState
+
+__all__ = [
+    'Chunk', 'ChunkGroup', 'BlockSpec', 'MemoryPool', 'PrivateBlock', 'PublicBlock', 'TensorBlock', 'TensorState'
+]
@@ -8,8 +8,8 @@ from torch.distributed import ProcessGroup
 from colossalai.elixir.cuda import gpu_device
 from colossalai.elixir.tensor import FakeTensor

-from .memory_pool import MemoryPool, PrivateBlock, PublicBlock, TensorBlock
-from .states import TensorState, ts_update_sanity_check
+from .memory_pool import MemoryPool, TensorBlock
+from .states import TensorState, validate_tensor_state_update


 class ChunkFullError(Exception):
@@ -383,7 +383,11 @@ class Chunk:
         prev_state = self.tensors_info[tensor].state
         if prev_state == tensor_state:
             return
-        if ts_update_sanity_check(prev_state, tensor_state):
+
+        # validate whether the update is legal
+        # if illegal, raise an exception
+        is_update_valid = validate_tensor_state_update(prev_state, tensor_state, raise_exception=True)
+        if is_update_valid:
             self.__update_one_tensor_info(self.tensors_info[tensor], tensor_state)

     def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None:
@@ -129,7 +129,7 @@ class ChunkGroup(object):
         """Check whether the rcache has enough blocks to store the gathered chunk."""
         if chunk.rcache_fused:
             return True
-        return self.rcache.public_free_cnt > 0
+        return self.rcache.public_free_count > 0

     def access_chunk(self, chunk: Chunk) -> bool:
         """Access a chunk into rCache."""
@@ -141,7 +141,7 @@ class ChunkGroup(object):
         if chunk.rcache_fused:
             block = None
         else:
-            block = self.rcache.get_public_block()
+            block = self.rcache.pop_public_block()
         chunk.access_chunk(block)
         self.__add_to_accset(chunk)
         return True
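The two changes above rename the rcache bookkeeping that ChunkGroup relies on: the free-block counter is now public_free_count and blocks are taken with pop_public_block. A minimal sketch of that check-then-pop-then-free cycle, using only the MemoryPool methods visible in this diff (the pool device, block count and BlockSpec values are arbitrary and chosen for illustration):

    import torch

    from colossalai.elixir.chunk import BlockSpec, MemoryPool

    # a small pool with two public blocks of 1024 fp32 elements each
    mp = MemoryPool('cpu')
    mp.allocate_public_blocks(block_num=2, block_spec=BlockSpec(numel=1024, dtype=torch.float))

    # the same check ChunkGroup performs before gathering a non-fused chunk
    if mp.public_free_count > 0:
        block = mp.pop_public_block()    # moves the block from the free list to the used set
        # ... chunk.access_chunk(block) would copy the chunk payload into this block ...
        mp.free_public_block(block)      # hands the block back once the chunk is released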
@@ -1,35 +1,52 @@
 from abc import ABC
 from collections import defaultdict
 from enum import Enum
 from typing import Iterable, NamedTuple

 import torch
 from torch.autograd.profiler_util import _format_memory


-class BlockRequire(NamedTuple):
+class BlockSpec(NamedTuple):
+    """
+    BlockSpec is the specification of a block. It contains the number of elements and the data type of the block.
+
+    Args:
+        numel (int): the number of elements in the block
+        dtype (torch.dtype): the data type of the block
+    """
     numel: int
     dtype: torch.dtype


+class BlockType(Enum):
+    """
+    BlockType is the type of a block. There are two types of blocks: public and private.
+    """
+    PUBLIC = 0
+    PRIVATE = 1
+
+
 class TensorBlock(ABC):
-    """TensorBlock is the memory unit of memory pool.
-    It is a continuous memory block used to store tensors.
+    """
+    TensorBlock is the memory unit of memory pool. It is a contiguous memory block used to store tensors.
+    Each chunk needs a corresponding TensorBlock to store its data during training.

     args:
-        numel: the number of elements in the block
-        dtype: the data type of the block
-        device_type: the device type of the block
+        size (int): the number of elements in the block
+        dtype (torch.dtype): the data type of the block
+        device_type (str): the device type of the block
     """
     total_count: int = 0

-    def __init__(self, numel: int, dtype: torch.dtype, device_type: str) -> None:
+    def __init__(self, size: int, dtype: torch.dtype, device_type: str, block_type: BlockType) -> None:
         self.block_id = TensorBlock.total_count
         TensorBlock.total_count += 1

         self.device_type = device_type
-        self.payload: torch.Tensor = torch.empty((numel,), dtype=dtype, device=device_type)
-        self.memo_occ: int = self.payload.numel() * self.payload.element_size()
+        self.payload: torch.Tensor = torch.empty((size,), dtype=dtype, device=device_type)
+        self.size_in_bytes: int = self.payload.numel() * self.payload.element_size()
+        self.block_type = block_type

     @property
     def numel(self):
@@ -50,122 +67,145 @@ class TensorBlock(ABC):
         return self.block_id == other.block_id

     def __repr__(self) -> str:
-        return f'(id={self.block_id}, numel={self.numel}, device={self.device_type}, dtype={self.dtype}, memo={self.memo_occ})'
+        return f'{self.block_type}(\n\tID = {self.block_id}, \n\tsize = {self.numel}, \n\tdevice = {self.device_type}, \n\tdtype = {self.dtype}, \n\tsize in bytes={self.size_in_bytes}\n)'


 class PublicBlock(TensorBlock):
-    """Public blocks have the same length.
-    Chunks of the same length can share the same public block.
+    """
+    Public blocks have the same length. Chunks of the same length can share the same public block.
     """

     def __init__(self, numel: int, dtype: torch.dtype, device_type: str) -> None:
-        super().__init__(numel, dtype, device_type)
-        self.block_type = 'public'
-
-    def __repr__(self) -> str:
-        return f'PublicBlock{super().__repr__()}'
+        super().__init__(numel, dtype, device_type, BlockType.PUBLIC)


 class PrivateBlock(TensorBlock):
-    """Private blocks may have different lengths.
-    Each private chunk should use its own private block.
+    """
+    Private blocks may have different lengths. Each private chunk should use its own private block.
     """

     def __init__(self, numel: int, dtype: torch.dtype, device_type: str) -> None:
-        super().__init__(numel, dtype, device_type)
-        self.block_type = 'private'
-
-    def __repr__(self) -> str:
-        return f'PrivateBlock{super().__repr__()}'
+        super().__init__(numel, dtype, device_type, BlockType.PRIVATE)

 class MemoryPool(object):
-    """A memory pool consists of public blocks and private blocks.
+    """
+    A memory pool consists of public blocks and private blocks.
+    rCache uses memory pool to manage memory blocks.
+    Users should allocate memory blocks before using it.

     args:
-        device_type: the device type of the memory pool
+        device_type (str): the device type of the memory pool
     """

     def __init__(self, device_type: str) -> None:
         assert device_type in [
             'cuda', 'cpu'
         ], f'Expected device type to be cuda or cpu, but got an invalid device type: {device_type}'
         self.device_type: str = device_type

-        # public space
+        # public space = number of public blocks x the public block size in bytes
+        # all public blocks have the same block size
         self.public_space: int = 0
         self.public_block_size: int = 0
         self.public_dtype: torch.dtype = None

-        self.public_free_blocks: list = None
-        self.public_used_blocks: set = None
-
-        self.public_free_cnt: int = 0
-        self.public_used_cnt: int = 0
+        # create block holder and counter
+        self.public_free_blocks: list = list()
+        self.public_used_blocks: set = set()
+        self.public_free_count: int = 0
+        self.public_used_count: int = 0

         # private space
+        # private look up dict returns an empty list if the block is not found
         self.private_space: int = 0
-        self.private_blocks: list = None
-        self.private_lookup_dict: dict[BlockRequire, list] = None
+        self.private_blocks: list = list()
+        self.private_lookup_dict: dict[BlockSpec, list] = defaultdict(list)

-        self.__allocate_flag = False
+        # flags for block allocation
+        self.__public_allocated_flag = False
+        self.__private_allocated_flag = False
-    def allocate(self,
-                 public_dtype: torch.dtype = torch.float,
-                 public_block_size: int = 1024,
-                 public_block_number: int = 0,
-                 private_block_list: Iterable[BlockRequire] = ()):
-        assert self.__allocate_flag is False
-        assert public_block_number >= 0
+    def allocate_public_blocks(self, block_num: int, block_spec: BlockSpec = None):
+        """
+        Allocate public tensor blocks for the memory pool. This method will allocate block_num blocks whose numel and dtype are given by block_spec.
+        """
+        assert not self.__public_allocated_flag, 'Public blocks have been allocated to this MemoryPool object, it is not allowed to allocate again.'
+        assert block_num >= 0, f'Expected public_block_number >= 0, but got {block_num}'

-        self.public_free_blocks = list()
-        self.public_used_blocks = set()
-        for _ in range(public_block_number):
-            block = PublicBlock(public_block_size, public_dtype, self.device_type)
+        if block_spec is None:
+            block_spec = BlockSpec(numel=1024, dtype=torch.float)
+
+        # allocate public blocks
+        for _ in range(block_num):
+            block = PublicBlock(block_spec.numel, block_spec.dtype, self.device_type)
             self.public_free_blocks.append(block)
+            self.public_space += block.size_in_bytes
+            self.public_free_count += 1

-        if public_block_number <= 0:
-            self.public_space = 0
-        else:
-            self.public_space = self.public_free_blocks[0].memo_occ * public_block_number
-        self.public_block_size = public_block_size
-        self.public_dtype = public_dtype
+        # store the block spec info
+        self.public_block_size = block_spec.numel
+        self.public_dtype = block_spec.dtype

-        self.public_free_cnt = public_block_number
-        self.public_used_cnt = 0
+    def allocate_private_blocks(self, block_specs: Iterable[BlockSpec]):
+        """
+        Allocate private blocks for the memory pool. This method will allocate private blocks according to the block_specs.

-        self.private_space = 0
-        self.private_blocks = list()
-        self.private_lookup_dict = defaultdict(list)
+        Args:
+            block_specs (Iterable[BlockSpec]): the block specs of the private blocks to be allocated
+        """
+        # allocate private blocks
+        assert not self.__private_allocated_flag, 'Private blocks have been allocated to this MemoryPool object, it is not allowed to allocate again.'

-        for require in private_block_list:
-            block = PrivateBlock(require.numel, require.dtype, self.device_type)
-            self.private_space += block.memo_occ
+        for spec in block_specs:
+            block = PrivateBlock(spec.numel, spec.dtype, self.device_type)
+            self.private_space += block.size_in_bytes
             self.private_blocks.append(block)
-            self.private_lookup_dict[require].append(block)
+            self.private_lookup_dict[spec].append(block)

-        self.__allocate_flag = True
+        self.__private_allocated_flag = True

     def __repr__(self) -> str:
-        return f'MP(public_space={_format_memory(self.public_space)}, private_space={_format_memory(self.private_space)})'
+        return f'Memory Pool(\n\tpublic_space = {_format_memory(self.public_space)}, \n\tprivate_space={_format_memory(self.private_space)}\n)'

-    def get_private_block(self, numel: int, dtype: torch.dtype):
-        block_list = self.private_lookup_dict.get(BlockRequire(numel=numel, dtype=dtype))
-        return block_list.pop()
+    def get_private_block(self, numel: int, dtype: torch.dtype) -> PrivateBlock:
+        """
+        Get a private block with the given numel and dtype.
+        """
+        block_list = self.private_lookup_dict.get(BlockSpec(numel=numel, dtype=dtype))

-    def get_public_block(self):
-        self.public_free_cnt -= 1
-        self.public_used_cnt += 1
+        if len(block_list) == 0:
+            raise ValueError(f'No private block with numel={numel} and dtype={dtype} is found.')
+        else:
+            return block_list.pop()
+
+    def pop_public_block(self) -> PublicBlock:
+        """
+        Get a public block from the memory pool.
+        """
+        self.public_free_count -= 1
+        self.public_used_count += 1

         block = self.public_free_blocks.pop()
         self.public_used_blocks.add(block)

         return block

-    def free_public_block(self, block: TensorBlock):
+    def free_public_block(self, block: TensorBlock) -> PublicBlock:
+        """
+        Free a public block to the memory pool.
+
+        Args:
+            block (TensorBlock): the public block to be freed
+        """
         assert isinstance(block, PublicBlock)
-        assert block in self.public_used_blocks
+        assert block in self.public_used_blocks, f'Could not find the given block in the used public blocks'

-        self.public_free_cnt += 1
-        self.public_used_cnt -= 1
+        # update counter
+        self.public_free_count += 1
+        self.public_used_count -= 1

+        # update free and used blocks
         self.public_used_blocks.remove(block)
         self.public_free_blocks.append(block)
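Taken together, the refactor splits the old all-in-one allocate method into allocate_public_blocks and allocate_private_blocks, both driven by BlockSpec. A hedged usage sketch of the new surface (method and field names come from this hunk; the device, sizes and dtypes are made up for illustration):

    import torch

    from colossalai.elixir.chunk import BlockSpec, MemoryPool

    mp = MemoryPool('cpu')

    # public blocks: block_num identical blocks, all described by a single BlockSpec
    mp.allocate_public_blocks(block_num=4, block_spec=BlockSpec(numel=1024, dtype=torch.float))

    # private blocks: one block per BlockSpec, later retrieved by an exact (numel, dtype) match
    mp.allocate_private_blocks([BlockSpec(numel=256, dtype=torch.float16),
                                BlockSpec(numel=512, dtype=torch.float)])

    fused_block = mp.get_private_block(numel=256, dtype=torch.float16)
    print(mp)    # __repr__ now reports public_space / private_space via _format_memory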
@@ -2,6 +2,10 @@ from enum import Enum


 class TensorState(Enum):
+    """
+    TensorState represents the state of a tensor in Elixir.
+    There are five states of a tensor: free, compute, hold, hold_after_bwd, ready_for_reduce.
+    """
     FREE = 0
     COMPUTE = 1
     HOLD = 2
@@ -9,17 +13,35 @@ class TensorState(Enum):
     READY_FOR_REDUCE = 4


-# expected: free -> hold -> compute -> hold ->
+# this includes the possible state transitions in tensor state:
+# the items in the list are in the format of (old_state, new_state)
+# the complete state transition is:
+# free -> hold -> compute -> hold ->
 # -> compute -> hold_after_bwd -> ready_for_reduce
-legal_ts_update_list = [(TensorState.FREE, TensorState.HOLD), (TensorState.FREE, TensorState.COMPUTE),
-                        (TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE),
-                        (TensorState.COMPUTE, TensorState.HOLD), (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD),
-                        (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE),
-                        (TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE),
-                        (TensorState.READY_FOR_REDUCE, TensorState.HOLD)]
+LEGAL_TENSOR_STATE_UPDATE_LIST = [(TensorState.FREE, TensorState.HOLD), (TensorState.FREE, TensorState.COMPUTE),
+                                  (TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE),
+                                  (TensorState.COMPUTE, TensorState.HOLD),
+                                  (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD),
+                                  (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE),
+                                  (TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE),
+                                  (TensorState.READY_FOR_REDUCE, TensorState.HOLD)]


-def ts_update_sanity_check(old_state, new_state) -> bool:
-    if (old_state, new_state) not in legal_ts_update_list:
-        raise RuntimeError(f'illegal tensor state updating: {old_state} -> {new_state}')
+def validate_tensor_state_update(old_state: TensorState, new_state: TensorState, raise_exception: bool = False) -> bool:
+    """
+    Validate whether the tensor state update is legal.
+
+    Args:
+        old_state (TensorState): the old state of the tensor
+        new_state (TensorState): the new state of the tensor
+        raise_exception (bool, optional): whether to raise an exception when the state update is illegal. Defaults to False.
+
+    Returns:
+        bool: whether the state update is legal or not.
+    """
+    if (old_state, new_state) not in LEGAL_TENSOR_STATE_UPDATE_LIST:
+        if raise_exception:
+            raise RuntimeError(f'Found illegal tensor state updating: {old_state} -> {new_state}')
+        else:
+            return False
     return True
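To make the new helper's contract concrete, a short hedged sketch (the colossalai.elixir.chunk.core.states module path is inferred from the package layout in the hunks above; only TensorState itself is re-exported at package level):

    from colossalai.elixir.chunk import TensorState
    from colossalai.elixir.chunk.core.states import validate_tensor_state_update

    # a legal transition simply returns True
    assert validate_tensor_state_update(TensorState.FREE, TensorState.HOLD)

    # an illegal transition returns False by default ...
    assert not validate_tensor_state_update(TensorState.FREE, TensorState.READY_FOR_REDUCE)

    # ... or raises RuntimeError when raise_exception=True, which is how Chunk now calls it
    try:
        validate_tensor_state_update(TensorState.FREE, TensorState.READY_FOR_REDUCE, raise_exception=True)
    except RuntimeError as err:
        print(err)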
colossalai/elixir/context/__init__.py (new file)
@@ -0,0 +1,3 @@
+from .meta_context import MetaContext
+
+__all__ = ['MetaContext']
@@ -1,6 +1,6 @@
 import torch

-tensor_creation_methods = dict(tensor=torch.tensor,
+TESNOR_CREATION_METHODS = dict(tensor=torch.tensor,
                                sparse_coo_tensor=torch.sparse_coo_tensor,
                                asarray=torch.asarray,
                                as_tensor=torch.as_tensor,
@@ -29,4 +29,34 @@ tensor_creation_methods = dict(tensor=torch.tensor,
                               polar=torch.polar,
                               heaviside=torch.heaviside)

-from .meta_ctx import MetaContext
+
+# TODO: unify this with lazy init context
+class MetaContext(object):
+    """A context manager that wraps all tensor creation methods in torch.
+    By default, all tensors will be created in meta.
+
+    args:
+        device_type: The device type of the tensors to be created.
+    """
+
+    def __init__(self, device_type: str = 'meta') -> None:
+        super().__init__()
+        self.device_type = device_type
+        return None
+
+    def __enter__(self):
+
+        def meta_wrap(func):
+
+            def wrapped_func(*args, **kwargs):
+                kwargs['device'] = self.device_type
+                return func(*args, **kwargs)
+
+            return wrapped_func
+
+        for name, method in TESNOR_CREATION_METHODS.items():
+            setattr(torch, name, meta_wrap(method))
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        for name, method in TESNOR_CREATION_METHODS.items():
+            setattr(torch, name, method)
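A hedged usage sketch of the relocated context manager (the import path follows the new colossalai/elixir/context/__init__.py above; which factories get redirected depends on the full TESNOR_CREATION_METHODS table, of which torch.tensor is explicitly listed):

    import torch

    from colossalai.elixir.context import MetaContext

    # inside the context, wrapped torch factories are forced onto device_type ('meta' by default),
    # so no real storage is allocated for the payload
    with MetaContext():
        meta_weight = torch.tensor([0.0] * 16)
    print(meta_weight.device)    # expected: meta

    # on exit the original factories are restored
    real_weight = torch.tensor([0.0] * 16)
    print(real_weight.device)    # expected: cpu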
@@ -1,34 +0,0 @@
-import torch
-
-from colossalai.elixir.ctx import tensor_creation_methods
-
-
-class MetaContext(object):
-    """A context manager that wraps all tensor creation methods in torch.
-    By default, all tensors will be created in meta.
-
-    args:
-        device_type: The device type of the tensors to be created.
-    """
-
-    def __init__(self, device_type: str = 'meta') -> None:
-        super().__init__()
-        self.device_type = device_type
-        return None
-
-    def __enter__(self):
-
-        def meta_wrap(func):
-
-            def wrapped_func(*args, **kwargs):
-                kwargs['device'] = self.device_type
-                return func(*args, **kwargs)
-
-            return wrapped_func
-
-        for name, method in tensor_creation_methods.items():
-            setattr(torch, name, meta_wrap(method))
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        for name, method in tensor_creation_methods.items():
-            setattr(torch, name, method)
@@ -1,4 +1,3 @@
-import torch
 import torch.nn.functional as F

 fused_torch_functions = {F.layer_norm: F.layer_norm}
@@ -12,6 +11,3 @@ def register_fused_layer_norm():
     except:
         print('Cannot import fused layer norm, please install apex from source.')
         pass
-
-
-register_fused_layer_norm()
@@ -5,7 +5,7 @@ from typing import List, Tuple
 import torch
 import torch.nn as nn

-from colossalai.elixir.chunk import BlockRequire, ChunkGroup, MemoryPool
+from colossalai.elixir.chunk import BlockSpec, ChunkGroup, MemoryPool
 from colossalai.elixir.tracer.param_tracer import generate_tf_order
 from colossalai.elixir.tracer.utils import meta_copy
 from colossalai.elixir.utils import print_rank_0
@@ -119,7 +119,7 @@ class SearchBase(ABC):
         for plan in chunk_plans:
             kwargs = plan.kwargs
             if kwargs.get('rcache_fused', False):
-                block_require_list.append(BlockRequire(plan.chunk_size, plan.chunk_dtype))
+                block_require_list.append(BlockSpec(plan.chunk_size, plan.chunk_dtype))

         mp = MemoryPool('cuda')
         mp.allocate(public_dtype=self.unified_dtype,