mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-13 21:22:49 +00:00
[gemini] improve compatibility and add static placement policy (#4479)
* [gemini] remove distributed-related part from colotensor (#4379) * [gemini] remove process group dependency * [gemini] remove tp part from colo tensor * [gemini] patch inplace op * [gemini] fix param op hook and update tests * [test] remove useless tests * [test] remove useless tests * [misc] fix requirements * [test] fix model zoo * [test] fix model zoo * [test] fix model zoo * [test] fix model zoo * [test] fix model zoo * [misc] update requirements * [gemini] refactor gemini optimizer and gemini ddp (#4398) * [gemini] update optimizer interface * [gemini] renaming gemini optimizer * [gemini] refactor gemini ddp class * [example] update gemini related example * [example] update gemini related example * [plugin] fix gemini plugin args * [test] update gemini ckpt tests * [gemini] fix checkpoint io * [example] fix opt example requirements * [example] fix opt example * [example] fix opt example * [example] fix opt example * [gemini] add static placement policy (#4443) * [gemini] add static placement policy * [gemini] fix param offload * [test] update gemini tests * [plugin] update gemini plugin * [plugin] update gemini plugin docstr * [misc] fix flash attn requirement * [test] fix gemini checkpoint io test * [example] update resnet example result (#4457) * [example] update bert example result (#4458) * [doc] update gemini doc (#4468) * [example] update gemini related examples (#4473) * [example] update gpt example * [example] update dreambooth example * [example] update vit * [example] update opt * [example] update palm * [example] update vit and opt benchmark * [hotfix] fix bert in model zoo (#4480) * [hotfix] fix bert in model zoo * [test] remove chatglm gemini test * [test] remove sam gemini test * [test] remove vit gemini test * [hotfix] fix opt tutorial example (#4497) * [hotfix] fix opt tutorial example * [hotfix] fix opt tutorial example
This commit is contained in:
@@ -2,8 +2,9 @@ from collections import deque
|
||||
from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from colossalai.tensor import ColoTensor
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
from .chunk import Chunk, ChunkFullError, TensorState
|
||||
@@ -27,16 +28,17 @@ class ChunkManager:
|
||||
self.dp_degree_chunk_size_dict[k] = v.pop('chunk_size')
|
||||
v['init_device'] = self.device
|
||||
|
||||
self.chunk_groups: Dict[str, Deque] = dict()
|
||||
self.chunk_groups: Dict[str, Deque[Chunk]] = dict()
|
||||
self.tensor_chunk_map: Dict[torch.Tensor, Chunk] = dict()
|
||||
self.accessed_chunks: Set[Chunk] = set()
|
||||
self.accessed_mem: int = 0
|
||||
self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}
|
||||
|
||||
def register_tensor(self,
|
||||
tensor: ColoTensor,
|
||||
tensor: torch.Tensor,
|
||||
group_type: str,
|
||||
config_key: int,
|
||||
process_group: ProcessGroup,
|
||||
cpu_offload: bool = False,
|
||||
pin_memory: bool = False) -> None:
|
||||
"""
|
||||
@@ -51,7 +53,7 @@ class ChunkManager:
|
||||
pin_memory: whether the chunk is pinned in the cpu memory
|
||||
"""
|
||||
assert tensor not in self.tensor_chunk_map
|
||||
assert isinstance(tensor, ColoTensor), "Please feed ColoTensor to this ChunkManager"
|
||||
assert isinstance(tensor, torch.Tensor), "Please feed Tensor to this ChunkManager"
|
||||
assert config_key in self.dp_degree_chunk_size_dict
|
||||
|
||||
chunk_size = self.dp_degree_chunk_size_dict[config_key]
|
||||
@@ -73,12 +75,12 @@ class ChunkManager:
|
||||
|
||||
if tensor.numel() > chunk_size:
|
||||
chunk_size = tensor.numel()
|
||||
dp_size = tensor.get_dp_world_size()
|
||||
dp_size = dist.get_world_size(process_group)
|
||||
chunk_size = chunk_size + (-chunk_size % dp_size)
|
||||
|
||||
chunk = Chunk(
|
||||
chunk_size=chunk_size,
|
||||
process_group=tensor.process_group,
|
||||
process_group=process_group,
|
||||
dtype=tensor.dtype,
|
||||
cpu_shard_init=cpu_offload,
|
||||
pin_memory=pin_memory,
|
||||
@@ -220,7 +222,7 @@ class ChunkManager:
|
||||
msg.append(f'[{i}] {chunk}\n')
|
||||
return ''.join(msg)
|
||||
|
||||
def __get_chunk_group(self, group_name: str) -> Deque:
|
||||
def __get_chunk_group(self, group_name: str) -> Deque[Chunk]:
|
||||
"""Register a chunk group.
|
||||
"""
|
||||
if group_name not in self.chunk_groups:
|
||||
|
Reference in New Issue
Block a user