mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 01:55:12 +00:00
[gemini] gemini supports lazy init (#3379)
* [gemini] fix nvme optimizer init * [gemini] gemini supports lazy init * [gemini] add init example * [gemini] add fool model * [zero] update gemini ddp * [zero] update init example * add chunk method * add chunk method * [lazyinit] fix lazy tensor tolist * [gemini] fix buffer materialization * [misc] remove useless file * [booster] update gemini plugin * [test] update gemini plugin test * [test] fix gemini plugin test * [gemini] fix import * [gemini] fix import * [lazyinit] use new metatensor * [lazyinit] use new metatensor * [lazyinit] fix __set__ method
This commit is contained in:
@@ -1,9 +1,10 @@
|
||||
import torch
|
||||
import math
|
||||
import os
|
||||
import tempfile
|
||||
import math
|
||||
from typing import Callable, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
from typing import Optional, List, Dict, Callable
|
||||
|
||||
|
||||
class NVMeOptimizer(torch.optim.Optimizer):
|
||||
@@ -42,8 +43,9 @@ class NVMeOptimizer(torch.optim.Optimizer):
|
||||
self.offloader = None
|
||||
self.is_on_nvme: Dict[Parameter, bool] = {}
|
||||
self.offloaded_numel: int = 0
|
||||
self.total_numel: int = self._get_numel()
|
||||
self.can_offload_numel = math.floor(self.total_numel * self.nvme_offload_fraction)
|
||||
# As params may not be materialized here, these attributes are initialized at the first step
|
||||
self.total_numel: Optional[int] = None
|
||||
self.can_offload_numel: Optional[int] = None
|
||||
|
||||
self.prefetch_params: List[Parameter] = []
|
||||
self.param_to_prefetch_idx: Dict[Parameter, int] = {}
|
||||
@@ -77,6 +79,9 @@ class NVMeOptimizer(torch.optim.Optimizer):
|
||||
self.prefetch_params.append(p)
|
||||
|
||||
def _pre_step(self, *state_keys: str) -> None:
|
||||
if self.total_numel is None:
|
||||
self.total_numel = self._get_numel()
|
||||
self.can_offload_numel = math.floor(self.total_numel * self.nvme_offload_fraction)
|
||||
self._setup_prefetch_params()
|
||||
if self.offloader is None or len(self.prefetch_params) == 0:
|
||||
return
|
||||
|
Reference in New Issue
Block a user