From f792507ff34cb58af601b91a70967f4855a7471f Mon Sep 17 00:00:00 2001 From: Jiarui Fang Date: Fri, 29 Jul 2022 13:27:05 +0800 Subject: [PATCH] [chunk] add PG check for tensor appending (#1383) --- colossalai/gemini/chunk_mgr.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/colossalai/gemini/chunk_mgr.py b/colossalai/gemini/chunk_mgr.py index 42392d76d..2fb8772a0 100644 --- a/colossalai/gemini/chunk_mgr.py +++ b/colossalai/gemini/chunk_mgr.py @@ -3,7 +3,7 @@ from typing import Optional, Dict, Deque, Set, List, Tuple, Iterable from collections import deque from colossalai.utils import get_current_device -from colossalai.tensor import ProcessGroup as ColoProcessGroup +from colossalai.tensor import ProcessGroup as ColoProcessGroup, ColoTensor from .chunk import Chunk, ChunkFullError, TensorState @@ -13,6 +13,7 @@ class ChunkManager: Args: chunk_size (int): the size of a chunk. + process_group (ColoProcessGroup): process group of the chunk. enable_distributed_storage (bool): optional, allow for distributed storage of a chunk. The default is false. init_device (torch.device): optional, the device on which the chunk is initialized. The default is None. """ @@ -57,6 +58,9 @@ class ChunkManager: group_name (str): the name of the chunk group. """ assert tensor not in self.tensor_chunk_map + if isinstance(tensor, ColoTensor): + assert tensor.get_process_group().dp_process_group() == self.process_group.dp_process_group( + ), f"Chunk Manager can only manage ColoTensor with the same DP process group" if self.chunk_size is not None and tensor.numel() > self.chunk_size: raise ValueError( f'Cannot create chunk, got tensor numel ({tensor.numel()}) > chunk size ({self.chunk_size})')