[Feature] qlora support (#5586)

* [feature] qlora support

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* qlora follow-up commit

* migrate quantization folder to colossalai/

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor fixes

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
linsj20
2024-04-17 15:03:31 +08:00
committed by GitHub
parent cabc1286ca
commit 52a2dded36
51 changed files with 1031 additions and 579 deletions

View File

@@ -1,10 +1,11 @@
from copy import deepcopy
from typing import Optional
import torch
import torch.distributed as dist
from peft import LoraConfig
from torchvision.models import resnet18
from utils import shared_tempdir
from typing import Optional
from peft import LoraConfig
from copy import deepcopy
import colossalai
from colossalai.booster import Booster
@@ -131,12 +132,15 @@ def run_fn(stage, shard, offload, model_fn, data_gen_fn, output_transform_fn, lo
# return repr(e)
raise e
@clear_cache_before_run()
@parameterize("stage", [2])
@parameterize("shard", [True, False])
@parameterize("offload", [False, True])
@parameterize("model_name", ["transformers_llama"])
def check_low_level_zero_lora_checkpointIO(stage: int, shard: bool, offload: bool, model_name: str, early_stop: bool = True):
def check_low_level_zero_lora_checkpointIO(
stage: int, shard: bool, offload: bool, model_name: str, early_stop: bool = True
):
passed_models = []
failed_info = {} # (model_name, error) pair
@@ -166,6 +170,7 @@ def check_low_level_zero_lora_checkpointIO(stage: int, shard: bool, offload: boo
print(f"Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n")
assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()])
def run_dist(rank, world_size, port):
colossalai.launch(config=(dict()), rank=rank, world_size=world_size, port=port, host="localhost")
check_low_level_zero_checkpointIO()