[LowLevelZero] low level zero support lora (#5153)

* low level zero support lora

* add checkpoint test

* fix

* test ci

* Update low_level_zero_plugin.py

* fix naming

flybird11111
2023-12-21 17:01:01 +08:00
committed by Hongxin Liu
parent 14b0d4c7e5
commit 8954a0c2e2
8 changed files with 264 additions and 8 deletions


@@ -1,4 +1,4 @@
-from typing import Callable, Iterator, List, Tuple, Union
+from typing import Callable, Dict, Iterator, List, Tuple, Union
 import torch
 import torch.distributed as dist
@@ -51,6 +51,12 @@ class DPPluginWrapper(DPPluginBase):
     def no_sync(self, model: nn.Module) -> Iterator[None]:
         pass
 
+    def enable_lora(self, model: nn.Module, pretrained_dir: str, lora_config: Dict) -> nn.Module:
+        pass
+
+    def support_lora(self) -> bool:
+        pass
+
 
 def check_dataloader_sharding():
     plugin = DPPluginWrapper()
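
The hunk above stubs out the two hooks this PR adds to the plugin interface: support_lora() reports whether a plugin can host LoRA adapters, and enable_lora() attaches them to the model. The sketch below is a hypothetical usage example, not code from this PR; it assumes peft's LoraConfig, a Booster.enable_lora entry point that forwards to the plugin hook, and placeholder model and config values.

# Minimal sketch, assuming colossalai and peft are installed and the script
# is launched with torchrun; the model and LoraConfig values are placeholders.
import colossalai
import torch
from peft import LoraConfig

from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

colossalai.launch_from_torch(config={})  # set up the distributed environment

plugin = LowLevelZeroPlugin(stage=2, precision="fp16")
booster = Booster(plugin=plugin)

model = torch.nn.Sequential(torch.nn.Linear(8, 8))  # placeholder model

if plugin.support_lora():
    # target_modules=["0"] matches the Sequential's only submodule; a real
    # script would target attention/projection layers by name instead.
    lora_config = LoraConfig(r=8, lora_alpha=16, target_modules=["0"])
    model = booster.enable_lora(model, lora_config=lora_config)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)

Enabling LoRA before boosting keeps only the small adapter weights trainable, so the optimizer states that LowLevelZeroPlugin shards across data-parallel ranks stay correspondingly small.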