Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-22 01:48:07 +00:00
[Feature] qlora support (#5586)
* [feature] qlora support
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* qlora follow commit
* migrate quantization folder to colossalai/
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* minor fixes

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
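For context, a minimal usage sketch of the feature this commit adds, distilled from the new tests further down: LoRA adapters are enabled through the Booster API, and passing quantize=True requests the qlora-style path in which the frozen base weights are quantized while the LoRA adapters stay trainable. The model checkpoint name, hyperparameters, and single-process launch arguments here are illustrative assumptions, not part of the commit.

import torch
from peft import LoraConfig
from transformers import AutoModelForCausalLM

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

# illustrative single-process launch; the tests below start multiple ranks via spawn()
colossalai.launch(config={}, rank=0, world_size=1, host="localhost", port=29500, backend="nccl")

model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b")  # hypothetical base checkpoint
lora_config = LoraConfig(task_type="CAUSAL_LM", r=8, lora_alpha=32, lora_dropout=0.1)

booster = Booster(plugin=LowLevelZeroPlugin())
# quantize=True exercises the qlora path added by this PR, as in the test configs below
model = booster.enable_lora(model, lora_config=lora_config, quantize=True)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
model, optimizer, *_ = booster.boost(model, optimizer)

# ... training loop: loss = ...; booster.backward(loss, optimizer); optimizer.step() ...

# adapters are saved separately, so the resulting checkpoint stays small
booster.save_lora_as_pretrained(model, "./lora_ckpt")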
@@ -1,4 +1,4 @@
-from typing import Callable, Iterator, List, Tuple, Union, Dict
+from typing import Callable, Dict, Iterator, List, Tuple, Union

import torch
import torch.distributed as dist
@@ -51,7 +51,6 @@ def run_fn(stage, model_fn, data_gen_fn, output_transform_fn, lora_config=None)
# raise e



@parameterize("stage", [2])
def check_low_level_zero_plugin(stage: int, early_stop: bool = True):
"""check low level zero plugin over model zoo
@@ -118,6 +117,7 @@ def check_low_level_zero_lora(stage, model_name, early_stop: bool = True):
print(f"Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n")
assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()])


def run_dist(rank, world_size, port, early_stop: bool = True):
# init dist env
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")
@@ -1,10 +1,11 @@
+from copy import deepcopy
+from typing import Optional

import torch
import torch.distributed as dist
+from peft import LoraConfig
from torchvision.models import resnet18
from utils import shared_tempdir
-from typing import Optional
-from peft import LoraConfig
-from copy import deepcopy

import colossalai
from colossalai.booster import Booster
@@ -131,12 +132,15 @@ def run_fn(stage, shard, offload, model_fn, data_gen_fn, output_transform_fn, lo
# return repr(e)
raise e


@clear_cache_before_run()
@parameterize("stage", [2])
@parameterize("shard", [True, False])
@parameterize("offload", [False, True])
@parameterize("model_name", ["transformers_llama"])
-def check_low_level_zero_lora_checkpointIO(stage: int, shard: bool, offload: bool, model_name: str, early_stop: bool = True):
+def check_low_level_zero_lora_checkpointIO(
+    stage: int, shard: bool, offload: bool, model_name: str, early_stop: bool = True
+):
passed_models = []
failed_info = {}  # (model_name, error) pair

@@ -166,6 +170,7 @@ def check_low_level_zero_lora_checkpointIO(stage: int, shard: bool, offload: boo
print(f"Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n")
assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()])


def run_dist(rank, world_size, port):
colossalai.launch(config=(dict()), rank=rank, world_size=world_size, port=port, host="localhost")
check_low_level_zero_checkpointIO()
@@ -1,16 +1,8 @@
import math
import time

import numpy as np
import pytest
import torch
import torch.nn as nn
import transformers
from packaging import version

try:
import triton
import triton.language as tl
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
@@ -22,6 +14,7 @@ try:
from exllama_kernels import prepare_buffers, set_tuning_params

from colossalai.inference.quant.gptq import CaiQuantLinear

HAS_AUTO_GPTQ = True
except:
HAS_AUTO_GPTQ = False
@@ -32,13 +25,14 @@ import warnings
HAS_GPTQ_CUDA = False
try:
from colossalai.kernel.op_builder.gptq import GPTQBuilder

gptq_cuda = GPTQBuilder().load()
HAS_GPTQ_CUDA = True
except ImportError:
-warnings.warn('CUDA gptq is not installed')
+warnings.warn("CUDA gptq is not installed")
HAS_GPTQ_CUDA = False

-TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')
+TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")

max_inner_outer_dim = 1
max_input_len = 1
@@ -64,9 +58,9 @@ def init_buffer(cai_linear, use_act_order=False):
max_input_len = 4096
# The temp_state buffer is required to reorder X in the act-order case.
# The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
-gptq_temp_state_buffer = torch.zeros((max_input_len, max_inner_outer_dim),
-                                     dtype=torch.float16,
-                                     device=torch.cuda.current_device())
+gptq_temp_state_buffer = torch.zeros(
+    (max_input_len, max_inner_outer_dim), dtype=torch.float16, device=torch.cuda.current_device()
+)
gptq_temp_dq_buffer = torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=torch.cuda.current_device())

gptq_cuda.prepare_buffers(torch.device(torch.cuda.current_device()), gptq_temp_state_buffer, gptq_temp_dq_buffer)
@@ -77,10 +71,11 @@ def init_buffer(cai_linear, use_act_order=False):
gptq_cuda.set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)


-@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON or not HAS_AUTO_GPTQ,
-                    reason="triton requires cuda version to be higher than 11.4 or not install auto-gptq")
+@pytest.mark.skipif(
+    not TRITON_CUDA_SUPPORT or not HAS_TRITON or not HAS_AUTO_GPTQ,
+    reason="triton requires cuda version to be higher than 11.4 or not install auto-gptq",
+)
def test_gptq_linear():

infeature = 1024
outfeature = 1024
group_size = 128
@@ -120,7 +115,7 @@ def test_gptq_linear():
max_input_len = 2048
buffers = {
"temp_state": torch.zeros((max_input_len, max_inner_outer_dim), dtype=torch.float16, device=device),
-"temp_dq": torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=device)
+"temp_dq": torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=device),
}

prepare_buffers(device, buffers["temp_state"], buffers["temp_dq"])
@@ -146,5 +141,4 @@ def test_gptq_linear():


if __name__ == "__main__":

test_gptq_linear()
@@ -4,6 +4,7 @@ from packaging import version

try:
from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd

HAS_TRITON = True
except ImportError:
HAS_TRITON = False
tests/test_lora/test_lora.py (new file, 106 lines)
@@ -0,0 +1,106 @@
import copy
import os
from itertools import product

import torch
from peft import LoraConfig
from torch import distributed as dist
from torch.optim import AdamW

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.testing import check_state_dict_equal, clear_cache_before_run, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_checkpoint_io.utils import shared_tempdir


@clear_cache_before_run()
def check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type):
    model = model_fn()
    lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1)

    test_plugins = [TorchDDPPlugin(), LowLevelZeroPlugin()]
    test_configs = [
        {
            "lora_config": lora_config,
            "quantize": False,
        },
        {
            "lora_config": lora_config,
            "quantize": True,
        },
    ]
    for plugin, test_config in product(test_plugins, test_configs):
        # checkpoint loaded model
        model_save = model_fn()
        model_load = copy.deepcopy(model_save)

        optimizer = AdamW(model.parameters(), lr=0.001)
        criterion = loss_fn

        booster = Booster(plugin=plugin)
        model_save = booster.enable_lora(model_save, **test_config)
        model_save, optimizer, criterion, _, _ = booster.boost(model_save, optimizer, criterion)

        with shared_tempdir() as tempdir:
            lora_ckpt_path = os.path.join(tempdir, "ckpt")
            booster.save_lora_as_pretrained(model_save, lora_ckpt_path)
            dist.barrier()

            # The Lora checkpoint should be small in size
            checkpoint_size_mb = os.path.getsize(os.path.join(lora_ckpt_path, "adapter_model.bin")) / (1024 * 1024)
            assert checkpoint_size_mb < 1

            model_load = booster.enable_lora(model_load, pretrained_dir=lora_ckpt_path, **test_config)
            model_load, _, _, _, _ = booster.boost(model_load)

            check_state_dict_equal(model_save.state_dict(), model_load.state_dict())

        # test fwd bwd correctness
        test_model = model_load
        model_copy = copy.deepcopy(model_load)

        data = data_gen_fn()
        data = {
            k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()
        }

        output = test_model(**data)
        output = output_transform_fn(output)
        loss = criterion(output)

        booster.backward(loss, optimizer)
        optimizer.clip_grad_by_norm(1.0)
        optimizer.step()

        for (n1, p1), (n2, p2) in zip(test_model.named_parameters(), model_copy.named_parameters()):
            if "lora_" in n1:
                # lora modules require gradients, thus updated
                assert p1.requires_grad
                assert not torch.testing.assert_close(p1.to(p2.device).to(p2.dtype), p2, atol=5e-3, rtol=5e-3)
            else:
                if not p1.requires_grad:
                    torch.testing.assert_close(p1.to(p2.device).to(p2.dtype), p2, atol=5e-3, rtol=5e-3)


def run_lora_test():
    sub_model_zoo = model_zoo.get_sub_registry("transformers_llama")
    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        task_type = None
        if name == "transformers_llama_for_casual_lm":
            task_type = "CAUSAL_LM"
        if name == "transformers_llama_for_sequence_classification":
            task_type = "SEQ_CLS"
        check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type)


def run_dist(rank, world_size, port):
    config = {}
    colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_lora_test()


@rerun_if_address_is_in_use()
def test_torch_ddp_lora():
    spawn(run_dist, 2)
@@ -1,108 +0,0 @@
import copy
import os

import torch
from peft import LoraConfig
from torch import distributed as dist
from torch.optim import AdamW

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin
from colossalai.testing import (
    assert_equal,
    assert_not_equal,
    check_state_dict_equal,
    clear_cache_before_run,
    rerun_if_address_is_in_use,
    spawn,
)
from tests.kit.model_zoo import model_zoo
from tests.test_checkpoint_io.utils import shared_tempdir


@clear_cache_before_run()
def check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type):
    model = model_fn()
    lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1)

    plugin = TorchDDPPlugin()
    booster = Booster(plugin=plugin)

    model = booster.enable_lora(model, lora_config=lora_config)
    model_copy = copy.deepcopy(model)

    optimizer = AdamW(model.parameters(), lr=0.001)
    criterion = loss_fn

    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

    data = data_gen_fn()
    data = {k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()}

    output = model(**data)
    output = output_transform_fn(output)
    loss = criterion(output)

    booster.backward(loss, optimizer)
    optimizer.clip_grad_by_norm(1.0)
    optimizer.step()

    for (n1, p1), (n2, p2) in zip(model.named_parameters(), model_copy.named_parameters()):
        if "lora_" in n1:
            # lora modules require gradients, thus updated
            assert p1.requires_grad
            assert_not_equal(p1.to(p2.device), p2)
        else:
            if not p1.requires_grad:
                assert_equal(p1.to(p2.device), p2)


@clear_cache_before_run()
def check_checkpoint(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type):
    plugin = TorchDDPPlugin()
    lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1)

    model_save = model_fn()
    model_load = copy.deepcopy(model_save)

    booster = Booster(plugin=plugin)
    model_save = booster.enable_lora(model_save, lora_config=lora_config)
    model_save, _, _, _, _ = booster.boost(model_save)

    with shared_tempdir() as tempdir:
        lora_ckpt_path = os.path.join(tempdir, "ckpt")
        booster.save_lora_as_pretrained(model_save, lora_ckpt_path)
        dist.barrier()

        # The Lora checkpoint should be small in size
        checkpoint_size_mb = os.path.getsize(os.path.join(lora_ckpt_path, "adapter_model.bin")) / (1024 * 1024)
        assert checkpoint_size_mb < 1

        model_load = booster.enable_lora(model_load, pretrained_dir=lora_ckpt_path)
        model_load, _, _, _, _ = booster.boost(model_load)

        check_state_dict_equal(model_save.state_dict(), model_load.state_dict())


def run_lora_test():
    sub_model_zoo = model_zoo.get_sub_registry("transformers_llama")
    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        task_type = None
        if name == "transformers_llama_for_casual_lm":
            task_type = "CAUSAL_LM"
        if name == "transformers_llama_for_sequence_classification":
            task_type = "SEQ_CLS"
        check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type)
        check_checkpoint(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type)


def run_dist(rank, world_size, port):
    config = {}
    colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_lora_test()


@rerun_if_address_is_in_use()
def test_torch_ddp_lora():
    spawn(run_dist, 2)