mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-26 04:03:58 +00:00)
[Feature] qlora support (#5586)
* [feature] qlora support
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  (for more information, see https://pre-commit.ci)
* qlora follow-up commit
* migrate quantization folder to colossalai/
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  (for more information, see https://pre-commit.ci)
* minor fixes

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
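In practice, the feature lets a Booster wrap a model with LoRA adapters on top of a quantized base model. A minimal sketch of that flow, assuming Booster.enable_lora gains a quantize switch for QLoRA (an assumption; the exact flag name is not visible in this excerpt, and the model name, hyperparameters, and optimizer below are placeholders):

```python
import torch
from peft import LoraConfig
from transformers import AutoModelForCausalLM

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

colossalai.launch_from_torch(config={})  # expects torchrun-style env vars

plugin = LowLevelZeroPlugin(stage=2)  # ZeRO stage 2, as in the test below
booster = Booster(plugin=plugin)

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")  # placeholder
lora_config = LoraConfig(task_type="CAUSAL_LM", r=8, lora_alpha=32, lora_dropout=0.1)

# quantize=True is the QLoRA switch: the base weights are quantized to 4 bits
# while the LoRA adapters stay in trainable precision. (Assumption: flag name
# and behavior inferred from the commit message, not shown in this diff.)
model = booster.enable_lora(model, lora_config=lora_config, quantize=True)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
model, optimizer, *_ = booster.boost(model, optimizer)
```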
@@ -1,10 +1,11 @@
+from copy import deepcopy
+from typing import Optional
+
 import torch
 import torch.distributed as dist
+from peft import LoraConfig
 from torchvision.models import resnet18
 from utils import shared_tempdir
-from typing import Optional
-from peft import LoraConfig
-from copy import deepcopy
 
 import colossalai
 from colossalai.booster import Booster
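The only substantive change in this hunk is the new `from peft import LoraConfig` import; the other moves are import re-sorting from the pre-commit hooks mentioned in the commit message. For reference, a LoraConfig is constructed like this (an illustrative sketch; these hyperparameters are not taken from the diff):

```python
from peft import LoraConfig

# Illustrative adapter configuration for a causal-LM test model;
# the values here are common defaults, not the ones used in this file.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",  # adapt the language-modeling head
    r=8,                    # rank of the low-rank update matrices
    lora_alpha=32,          # scaling factor applied to the update
    lora_dropout=0.1,       # dropout on the adapter input
)
```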
@@ -131,12 +132,15 @@ def run_fn(stage, shard, offload, model_fn, data_gen_fn, output_transform_fn, lo
         # return repr(e)
         raise e
 
+
 @clear_cache_before_run()
 @parameterize("stage", [2])
 @parameterize("shard", [True, False])
 @parameterize("offload", [False, True])
 @parameterize("model_name", ["transformers_llama"])
-def check_low_level_zero_lora_checkpointIO(stage: int, shard: bool, offload: bool, model_name: str, early_stop: bool = True):
+def check_low_level_zero_lora_checkpointIO(
+    stage: int, shard: bool, offload: bool, model_name: str, early_stop: bool = True
+):
     passed_models = []
     failed_info = {}  # (model_name, error) pair
 
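The reflowed signature sits under ColossalAI's stacked @parameterize decorators, which invoke the function once per combination of the listed values. A self-contained sketch of that mechanism (the demo function is hypothetical):

```python
from colossalai.testing import clear_cache_before_run, parameterize

@clear_cache_before_run()              # frees cached CUDA memory before the run
@parameterize("shard", [True, False])
@parameterize("offload", [False, True])
def demo(shard: bool, offload: bool):
    print(f"shard={shard}, offload={offload}")

demo()  # executes all 2 x 2 = 4 combinations in one call
```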
@@ -166,6 +170,7 @@ def check_low_level_zero_lora_checkpointIO(stage: int, shard: bool, offload: boo
     print(f"Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n")
     assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()])
 
+
 def run_dist(rank, world_size, port):
     colossalai.launch(config=(dict()), rank=rank, world_size=world_size, port=port, host="localhost")
     check_low_level_zero_checkpointIO()
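For completeness, run_dist functions like this are normally driven from a pytest entry point via colossalai.testing.spawn, which picks a free port and launches the requested number of ranks; a sketch following that convention (the test wrapper name is illustrative):

```python
from colossalai.testing import rerun_if_address_is_in_use, spawn

@rerun_if_address_is_in_use()  # retry if the rendezvous port is taken
def test_low_level_zero_checkpointIO():
    spawn(run_dist, 2)  # run run_dist on 2 ranks with a free port

if __name__ == "__main__":
    test_low_level_zero_checkpointIO()
```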