[gemini] gemini supports lazy init (#3379)

* [gemini] fix nvme optimizer init

* [gemini] gemini supports lazy init

* [gemini] add init example

* [gemini] add fool model

* [zero] update gemini ddp

* [zero] update init example

* add chunk method

* add chunk method

* [lazyinit] fix lazy tensor tolist

* [gemini] fix buffer materialization

* [misc] remove useless file

* [booster] update gemini plugin

* [test] update gemini plugin test

* [test] fix gemini plugin test

* [gemini] fix import

* [gemini] fix import

* [lazyinit] use new metatensor

* [lazyinit] use new metatensor

* [lazyinit] fix __set__ method
Hongxin Liu
2023-04-12 16:03:25 +08:00
committed by GitHub
parent 366a035552
commit 152239bbfa
7 changed files with 80 additions and 72 deletions
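The headline change is that a model built under ColossalAI's lazy init context can now be handed directly to the Gemini plugin, with parameters materialized only when Gemini shards them. A minimal usage sketch follows; the lazy-context import path and the launch call are assumptions (they have moved between releases), so treat this as an illustration rather than the exact API of this commit.

    import torch
    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import GeminiPlugin
    from colossalai.nn.optimizer import HybridAdam
    # Import path is an assumption; older releases expose the lazy context elsewhere.
    from colossalai.lazy import LazyInitContext

    colossalai.launch_from_torch(config={})  # requires a torchrun/distributed environment

    # Parameters are not materialized here; Gemini materializes and shards them in boost().
    with LazyInitContext():
        model = torch.nn.Sequential(
            torch.nn.Linear(1024, 1024),
            torch.nn.GELU(),
            torch.nn.Linear(1024, 10),
        )

    optimizer = HybridAdam(model.parameters(), lr=1e-3, nvme_offload_fraction=0.5)
    booster = Booster(plugin=GeminiPlugin())
    model, optimizer, *_ = booster.boost(model, optimizer)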


@@ -1,9 +1,10 @@
-import torch
-import math
 import os
 import tempfile
+import math
+from typing import Callable, Dict, List, Optional
+
+import torch
 from torch.nn.parameter import Parameter
-from typing import Optional, List, Dict, Callable


 class NVMeOptimizer(torch.optim.Optimizer):
@@ -42,8 +43,9 @@ class NVMeOptimizer(torch.optim.Optimizer):
         self.offloader = None
         self.is_on_nvme: Dict[Parameter, bool] = {}
         self.offloaded_numel: int = 0
-        self.total_numel: int = self._get_numel()
-        self.can_offload_numel = math.floor(self.total_numel * self.nvme_offload_fraction)
+        # As params may not be materialized yet, these attributes are initialized at the first step
+        self.total_numel: Optional[int] = None
+        self.can_offload_numel: Optional[int] = None

         self.prefetch_params: List[Parameter] = []
         self.param_to_prefetch_idx: Dict[Parameter, int] = {}
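The __init__ hunk above is a simple deferral: any bookkeeping derived from parameter sizes is left as None at construction time, because under lazy init the parameters may not have real storage yet. A generic sketch of the same pattern (standalone illustration, not ColossalAI code):

    import math
    from typing import Iterable, Optional

    import torch

    class OffloadBudget:
        """Toy example: defer size-dependent bookkeeping until parameters are materialized."""

        def __init__(self, params: Iterable[torch.nn.Parameter], offload_fraction: float):
            self.params = list(params)
            self.offload_fraction = offload_fraction
            # Sizes are unknown (or wrong) for lazy/meta parameters, so do not call numel() here.
            self.total_numel: Optional[int] = None
            self.can_offload_numel: Optional[int] = None

        def lazy_init(self) -> None:
            # Call this once real storage exists, e.g. at the first optimizer step.
            if self.total_numel is None:
                self.total_numel = sum(p.numel() for p in self.params)
                self.can_offload_numel = math.floor(self.total_numel * self.offload_fraction)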
@@ -77,6 +79,9 @@ class NVMeOptimizer(torch.optim.Optimizer):
                 self.prefetch_params.append(p)

     def _pre_step(self, *state_keys: str) -> None:
+        if self.total_numel is None:
+            self.total_numel = self._get_numel()
+            self.can_offload_numel = math.floor(self.total_numel * self.nvme_offload_fraction)
         self._setup_prefetch_params()
         if self.offloader is None or len(self.prefetch_params) == 0:
             return
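_pre_step is the natural place to resolve the deferral: by the time a concrete optimizer takes its first step, the parameters must be materialized, so _get_numel() returns real sizes. A hypothetical subclass sketches where that first call happens; the base-class constructor arguments, the 'momentum' state key, and _post_step are assumptions for illustration, not the actual ColossalAI optimizers.

    class ToySGD(NVMeOptimizer):
        def __init__(self, params, lr: float = 0.01, nvme_offload_fraction: float = 0.5):
            # Assumed base-class signature; shown only to make the sketch self-contained.
            super().__init__(params, {'lr': lr}, nvme_offload_fraction=nvme_offload_fraction)

        @torch.no_grad()
        def step(self, closure=None):
            # The first call computes total_numel / can_offload_numel lazily, then prefetches NVMe state.
            self._pre_step('momentum')
            for group in self.param_groups:
                for p in group['params']:
                    if p.grad is not None:
                        p.add_(p.grad, alpha=-group['lr'])
            self._post_step()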