[gemini] remove registered gradients hooks (#5696)
* fix gemini
* fix
@@ -20,7 +20,12 @@ class ChunkManager:
         init_device (torch.device): optional, the device on which the chunk is initialized. The default is None.
     """
 
-    def __init__(self, chunk_configuration, init_device: Optional[torch.device] = None) -> None:
+    def __init__(
+        self,
+        chunk_configuration,
+        init_device: Optional[torch.device] = None,
+        reuse_fp16_chunk: bool = True,
+    ) -> None:
         self.device = init_device or get_accelerator().get_current_device()
         self.dp_degree_chunk_size_dict: Dict[int, int] = dict()
         self.kwargs_config = chunk_configuration
@@ -33,6 +38,10 @@ class ChunkManager:
         self.accessed_chunks: Set[Chunk] = set()
         self.accessed_mem: int = 0
         self.total_mem: Dict[str, int] = {"cpu": 0, "cuda": 0}
+        self.reuse_fp16_chunk = reuse_fp16_chunk
+        # Whether model is accumulating gradients,
+        self.accumulating_grads = False
+        self.overflow_counter = 0
 
     def register_tensor(
         self,
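To see the new signature in action, here is a minimal, self-contained sketch. ChunkManagerSketch, the configuration dict, and the call below are illustrative stand-ins rather than the real colossalai ChunkManager; they only mirror the constructor shape and the flags this commit introduces.

from typing import Dict, Optional

import torch


class ChunkManagerSketch:
    # Hypothetical stand-in mirroring only the constructor from this commit;
    # the real ChunkManager in colossalai does far more than hold these flags.
    def __init__(
        self,
        chunk_configuration: Dict[int, Dict],
        init_device: Optional[torch.device] = None,
        reuse_fp16_chunk: bool = True,
    ) -> None:
        # The real class asks the accelerator backend for the current device;
        # we fall back to CPU here to stay self-contained.
        self.device = init_device or torch.device("cpu")
        self.kwargs_config = chunk_configuration
        self.total_mem: Dict[str, int] = {"cpu": 0, "cuda": 0}
        # Flags added by this commit: fp16 chunks are reused by default, and
        # gradient-accumulation state starts switched off.
        self.reuse_fp16_chunk = reuse_fp16_chunk
        self.accumulating_grads = False
        self.overflow_counter = 0


# Passing reuse_fp16_chunk=False keeps fp16 chunks from being recycled,
# e.g. while gradients are accumulated across micro-batches.
manager = ChunkManagerSketch({32: {"chunk_size": 1024}}, reuse_fp16_chunk=False)
print(manager.reuse_fp16_chunk, manager.accumulating_grads)  # False False

Making reuse_fp16_chunk a constructor argument and tracking accumulating_grads as plain state appears to be what lets gradient handling be configured up front rather than through registered backward hooks, which is what the commit title refers to.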