[Feature] qlora support (#5586)

* [feature] qlora support

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* qlora follow commit

* migrate qutization folder to colossalai/

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor fixes

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
linsj20 2024-04-17 15:03:31 +08:00 committed by GitHub
parent cabc1286ca
commit 52a2dded36
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
51 changed files with 1031 additions and 579 deletions

17
LICENSE
View File

@ -527,3 +527,20 @@ Copyright 2021- HPC-AI Technology Inc. All rights reserved.
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
---------------- LICENSE FOR Hugging Face accelerate ----------------
Copyright 2021 The HuggingFace Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@ -76,9 +76,11 @@ def main(args):
if args.strategy == "ddp":
strategy = DDPStrategy()
elif args.strategy == "colossalai_gemini":
strategy = GeminiStrategy(placement_policy="static",initial_scale=2**5)
strategy = GeminiStrategy(placement_policy="static", initial_scale=2**5)
elif args.strategy == "colossalai_gemini_cpu":
strategy = GeminiStrategy(placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5)
strategy = GeminiStrategy(
placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5
)
elif args.strategy == "colossalai_zero2":
strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
elif args.strategy == "colossalai_zero2_cpu":

View File

@ -30,4 +30,3 @@ class Actor(LoRAModule):
"""Returns model output."""
output = self.model(input_ids, attention_mask=attention_mask, **model_kwargs)
return output

View File

@ -75,7 +75,9 @@ def get_strategy_from_args(strategy: str):
elif strategy == "colossalai_zero2":
strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
elif strategy == "colossalai_gemini_cpu":
strategy_ = GeminiStrategy(placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5)
strategy_ = GeminiStrategy(
placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5
)
elif strategy == "colossalai_zero2_cpu":
strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
else:

View File

@ -101,16 +101,17 @@ class DDPStrategy(Strategy):
model_path = os.path.join(path, "pytorch_model.bin")
self.save_model(model, model_path, shard=shard)
def _replace_keys(model_path: str, replace_fn: Callable):
state_dict = torch.load(model_path, map_location="cpu")
state_dict = {replace_fn(k): v for k, v in state_dict.items()}
torch.save(state_dict, model_path)
# FIXME: save_model would add "model." prefix to keys of pytorch_model.bin
# HACK: rename keys of pytorch_model.bin
if dist.get_rank() == 0:
_replace_keys(model_path, lambda k: k.replace("model.", "", 1))
def get_model_state_dict_shard(self, model: nn.Module, **config):
# TODO: implement sharding on naive strategy
model = self.unwrap_model(model)

View File

@ -24,7 +24,9 @@ def main(args):
if args.strategy == "ddp":
strategy = DDPStrategy()
elif args.strategy == "colossalai_gemini":
strategy = GeminiStrategy(placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5)
strategy = GeminiStrategy(
placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5
)
elif args.strategy == "colossalai_zero2":
strategy = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
else:

View File

@ -130,8 +130,8 @@ from modelscope import AutoModelForCausalLM, AutoTokenizer, snapshot_download
model_dir = snapshot_download('colossalai/Colossal-LLaMA-2-7b-base', revision='v1.0.1')
tokenizer = AutoTokenizer.from_pretrained(model_dir, device_map="auto", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_dir, device_map="auto", trust_remote_code=True).eval()
generation_kwargs = {"max_new_tokens": 256,
"top_p": 0.95,
generation_kwargs = {"max_new_tokens": 256,
"top_p": 0.95,
"temperature": 0.3
}
input = '离离原上草,'

View File

@ -1,20 +1,20 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import numpy as np
import os
import random
from dataclasses import dataclass
from typing import Dict, List, Union, Sequence, Optional, Iterator, Callable
from typing import Callable, Dict, Iterator, List, Optional, Sequence, Union
import numpy as np
import torch
from datasets import dataset_dict, load_from_disk
import torch.nn.functional as F
from datasets import Dataset as HFDataset
from datasets import dataset_dict, load_from_disk
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import _get_default_group
from torch.utils.data import ConcatDataset, Dataset, DataLoader, DistributedSampler
from torch.utils.data import ConcatDataset, DataLoader, Dataset, DistributedSampler
from transformers.tokenization_utils import PreTrainedTokenizer
import torch.nn.functional as F
DatasetType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
PathType = Union[str, os.PathLike]

View File

@ -7,9 +7,9 @@ Splicing multiple pre-tokenized sequence data points
import random
import warnings
from copy import deepcopy
from datasets import dataset_dict
from typing import Any, Callable, Dict, Iterable, List, Union, Tuple
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
from datasets import dataset_dict
from torch.utils.data import ConcatDataset, Dataset, IterableDataset
from transformers.models.llama.tokenization_llama import LlamaTokenizer
from transformers.tokenization_utils import PreTrainedTokenizer
@ -169,12 +169,7 @@ class ClosedToConstantLengthSplicedDataset(IterableDataset):
spliced_labels.extend(seq_labels)
# For residual spliced data point at the end of the data set
if self.infinite is False and more_data_points is False and len(spliced_input_ids) > 0:
examples.append(
{
self.input_ids_field: spliced_input_ids,
self.labels_field: spliced_labels
}
)
examples.append({self.input_ids_field: spliced_input_ids, self.labels_field: spliced_labels})
if self.shuffle:
random.shuffle(examples)
for spliced_data_point in examples:

View File

@ -8,11 +8,10 @@ import argparse
import numpy as np
import torch
from transformers import LlamaTokenizer, LlamaForCausalLM
from transformers import LlamaForCausalLM, LlamaTokenizer
from colossalai.logging import get_dist_logger
logger = get_dist_logger()

View File

@ -6,12 +6,12 @@ Initialize new tokenizer for continual pre-training
"""
import argparse
import os
import json
import os
from typing import List, Union
from transformers.models.llama.tokenization_llama import LlamaTokenizer
from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model
from transformers.models.llama.tokenization_llama import LlamaTokenizer
from colossalai.logging import get_dist_logger

View File

@ -10,8 +10,8 @@ import os
from typing import Any, Dict, Tuple, Union
import torch
from torch.optim.optimizer import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.optimizer import Optimizer
from colossalai.booster import Booster
from colossalai.cluster import DistCoordinator

View File

@ -242,4 +242,4 @@ To comprehensively assess the performance of the Colossal-LLaMA-2-7B-base model,
## Conclusion
In general, the Colossal-LLaMA-2-7B-base model not only enhances its understanding of English but also exhibits significant improvements in its comprehension of Chinese. It boasts a broad spectrum of general knowledge, encompassing various fields such as food, sports, technology, literature, games, and more. Regarding text generation tasks, the Colossal-LLaMA-2-7B-base model excels in writing performance; however, its ability to generate specific formats like code, emails, tables, etc., needs enhancement due to the scarcity of relevant training data during our training phase. When compared to the Qwen-7b-base model, the Colossal-LLaMA-2-7B-base model outperforms it in answering most English questions and some Chinese questions, as demonstrated in the examples above.
Presently, the Colossal-LLaMA-2-7B-base model already exhibits some capabilities in sentiment analysis, logical reasoning, information extraction, role-play, classification, and rewriting. These capabilities are poised for further improvement in the future as part of our ongoing enhancements.
Presently, the Colossal-LLaMA-2-7B-base model already exhibits some capabilities in sentiment analysis, logical reasoning, information extraction, role-play, classification, and rewriting. These capabilities are poised for further improvement in the future as part of our ongoing enhancements.

View File

@ -1,2 +1,2 @@
hostname1
hostname2
hostname2

View File

@ -11,14 +11,14 @@ import os
import time
from multiprocessing import cpu_count
from colossal_llama2.dataset.spliced_and_tokenized_dataset import (
ClosedToConstantLengthSplicedDataset,
supervised_tokenize,
)
from datasets import dataset_dict, load_dataset
from transformers.models.llama.tokenization_llama import LlamaTokenizer
from colossalai.logging import get_dist_logger
from colossal_llama2.dataset.spliced_and_tokenized_dataset import (
supervised_tokenize,
ClosedToConstantLengthSplicedDataset,
)
logger = get_dist_logger()
@ -149,5 +149,5 @@ def main():
spliced_dataset.save_to_disk(dataset_path=output_arrow_path, num_proc=min(len(spliced_dataset), cpu_count()))
if __name__ == '__main__':
if __name__ == "__main__":
main()

View File

@ -12,4 +12,3 @@ flash-attn>=2.0.0,<=2.0.5
tqdm
sentencepiece==0.1.99
protobuf<=3.20.0

View File

@ -1,45 +1,39 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Continual Pre-training of LLaMA-2 developed by Colossal-AI Team
Continual Pre-training of LLaMA-2 developed by Colossal-AI Team
"""
import json
import argparse
import json
import os
import resource
from contextlib import nullcontext
from tqdm import tqdm
import torch
import torch.distributed as dist
from colossal_llama2.dataset.loader import (
DataCollatorForSupervisedDataset,
StatefulDistributedSampler,
load_tokenized_dataset,
setup_distributed_dataloader,
)
from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
from colossal_llama2.utils.froze import freeze_non_embeds_parameters
from torch.utils.tensorboard import SummaryWriter
from transformers import LlamaTokenizer, LlamaForCausalLM, LlamaConfig
from tqdm import tqdm
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import (
GeminiPlugin,
LowLevelZeroPlugin,
HybridParallelPlugin,
)
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
from colossal_llama2.dataset.loader import (
load_tokenized_dataset,
setup_distributed_dataloader,
DataCollatorForSupervisedDataset,
StatefulDistributedSampler,
)
from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
from colossal_llama2.utils.froze import freeze_non_embeds_parameters
def get_model_numel(model: torch.nn.Module) -> int:
return sum(p.numel() for p in model.parameters())
@ -372,9 +366,7 @@ def main() -> None:
# Final save.
coordinator.print_on_master("Start saving final model checkpoint")
booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
coordinator.print_on_master(
f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}"
)
coordinator.print_on_master(f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}")
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")

View File

@ -1 +1 @@
0.0.1
0.0.1

View File

@ -19,6 +19,7 @@ except ImportError:
import colossalai.interface.pretrained as pretrained_utils
from colossalai.checkpoint_io import GeneralCheckpointIO
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.quantization import BnbQuantizationConfig
from .accelerator import Accelerator
from .mixed_precision import MixedPrecision, mixed_precision_factory
@ -230,7 +231,12 @@ class Booster:
return self.plugin.no_sync(model, optimizer)
def enable_lora(
self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: "peft.LoraConfig" = None
self,
model: nn.Module,
pretrained_dir: Optional[str] = None,
lora_config: "peft.LoraConfig" = None,
bnb_quantization_config: Optional[BnbQuantizationConfig] = None,
quantize=False,
) -> nn.Module:
"""
Wrap the passed in model with LoRA modules for training. If pretrained directory is provided, lora configs and weights are loaded from that directory.
@ -259,7 +265,20 @@ class Booster:
assert (
pretrained_dir is not None
), "Please provide pretrained directory path if not passing in lora configuration."
return self.plugin.enable_lora(model, pretrained_dir, lora_config)
if quantize is True:
if bnb_quantization_config is not None:
warnings.warn(
"User defined BnbQuantizationConfig is not fully tested in ColossalAI. Use it at your own risk."
)
else:
bnb_quantization_config = BnbQuantizationConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
)
return self.plugin.enable_lora(model, pretrained_dir, lora_config, bnb_quantization_config)
def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True) -> None:
"""Load model from checkpoint.

View File

@ -1,11 +1,11 @@
import logging
import warnings
import enum
import logging
import os
import warnings
from functools import partial
from pathlib import Path
from types import MethodType
from typing import Callable, Dict, Iterator, List, Optional, Tuple, Dict
from typing import Callable, Dict, Iterator, List, Optional, Tuple
import torch
import torch.nn as nn
@ -27,6 +27,7 @@ from colossalai.checkpoint_io.utils import (
sharded_optimizer_loading_epilogue,
)
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.quantization import BnbQuantizationConfig, quantize_model
from colossalai.utils import get_current_device
from colossalai.zero import LowLevelZeroOptimizer
@ -44,6 +45,7 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
SUPPORTED_PRECISION = ["fp16", "bf16", "fp32"]
class OptimizerParamCheckState(enum.Enum):
ORIGIN_PARAM_FINDED = 0
ORIGIN_PARAM_NOT_FIND = -1
@ -221,6 +223,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return
from peft import PeftModel
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
peft_model = model.unwrap()
assert isinstance(
@ -331,38 +334,45 @@ class LowLevelZeroPlugin(DPPluginBase):
def supported_devices(self) -> List[str]:
return ["cuda"]
def support_lora(self) -> bool:
return True
def enable_lora(
self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
self,
model: nn.Module,
pretrained_dir: Optional[str] = None,
lora_config: Optional[Dict] = None,
bnb_quantization_config: Optional[BnbQuantizationConfig] = None,
) -> nn.Module:
from peft import PeftModel, get_peft_model
assert not isinstance(model, LowLevelZeroModel), "Lora should be enabled before boosting the model."
self.lora_enabled = True
warnings.warn("You have enabled LoRa training. Please check the hyperparameters such as lr")
if bnb_quantization_config is not None:
model = quantize_model(model, bnb_quantization_config)
if pretrained_dir is None:
peft_model = get_peft_model(model, lora_config)
else:
peft_model = PeftModel.from_pretrained(model, pretrained_dir, is_trainable=True)
return peft_model
def get_param_group_id(self, optimizer: Optimizer, origin_param: Parameter):
origin_param_id = id(origin_param)
for group_id, param_group in enumerate(optimizer.param_groups):
for p in param_group['params']:
for p in param_group["params"]:
if id(p) == origin_param_id:
return group_id
return -1
def get_param_group_id(self, optimizer: Optimizer, origin_param: Parameter, lora_param: Parameter):
origin_param_id = id(origin_param)
lora_param_id = id(lora_param)
target_group_id = None
for group_id, param_group in enumerate(optimizer.param_groups):
for p in param_group['params']:
for p in param_group["params"]:
if id(p) == lora_param_id:
# check if the lora parameter exists.
return target_group_id, OptimizerParamCheckState.LORA_PARM_EXISTED
@ -372,25 +382,31 @@ class LowLevelZeroPlugin(DPPluginBase):
return target_group_id, OptimizerParamCheckState.ORIGIN_PARAM_FINDED
else:
return target_group_id, OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND
def add_lora_params_to_optimizer(self, model, optimizer):
""" add lora parameters to optimizer """
name2param= {}
"""add lora parameters to optimizer"""
name2param = {}
for name, param in model.named_parameters():
name2param[name] = param
for name, param in name2param.items():
if 'lora_A' in name or 'lora_B' in name:
if "lora_A" in name or "lora_B" in name:
origin_key = name.replace("lora_A.", "")
origin_key = origin_key.replace("lora_B.", "")
origin_key = origin_key.replace(f"{model.active_adapter}", "base_layer")
origin_param = name2param[origin_key]
group_id, check_state = self.get_param_group_id(optimizer, origin_param, param)
if check_state == OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND:
warnings.warn("Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups.")
elif check_state == OptimizerParamCheckState.ORIGIN_PARAM_FINDED and group_id is not None and group_id >= 0:
optimizer.param_groups[group_id]['params'].append(param)
warnings.warn(
"Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups."
)
elif (
check_state == OptimizerParamCheckState.ORIGIN_PARAM_FINDED
and group_id is not None
and group_id >= 0
):
optimizer.param_groups[group_id]["params"].append(param)
def configure(
self,
model: nn.Module,
@ -401,11 +417,13 @@ class LowLevelZeroPlugin(DPPluginBase):
) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
if self.lora_enabled:
from peft import PeftModel
assert isinstance(model, PeftModel), "The model should have been wrapped as a PeftModel when self.lora_enabled is True"
assert isinstance(
model, PeftModel
), "The model should have been wrapped as a PeftModel when self.lora_enabled is True"
if optimizer is not None:
self.add_lora_params_to_optimizer(model, optimizer)
if not isinstance(model, ModelWrapper):
model = LowLevelZeroModel(model, self.precision)

View File

@ -9,6 +9,7 @@ from torch.utils.data import DataLoader
from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.quantization import BnbQuantizationConfig, quantize_model
from .dp_plugin_base import DPPluginBase
@ -237,10 +238,17 @@ class TorchDDPPlugin(DPPluginBase):
return model.module.no_sync()
def enable_lora(
self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
self,
model: nn.Module,
pretrained_dir: Optional[str] = None,
lora_config: Optional[Dict] = None,
bnb_quantization_config: Optional[BnbQuantizationConfig] = None,
) -> nn.Module:
from peft import PeftModel, get_peft_model
if bnb_quantization_config is not None:
model = quantize_model(model, bnb_quantization_config)
assert not isinstance(model, TorchDDPModel), "Lora should be enabled before boosting the model."
if pretrained_dir is None:
return get_peft_model(model, lora_config)

View File

@ -64,7 +64,7 @@ vllm
flash-attention
# install lightllm since we depend on lightllm triton kernels
git clone https://github.com/ModelTC/lightllm
git clone https://github.com/ModelTC/lightllm
git checkout 28c1267cfca536b7b4f28e921e03de735b003039
cd lightllm
pip3 install -e .
@ -84,7 +84,7 @@ cd /path/to/CollossalAI
pip install -e .
# install lightllm
git clone https://github.com/ModelTC/lightllm
git clone https://github.com/ModelTC/lightllm
git checkout 28c1267cfca536b7b4f28e921e03de735b003039
cd lightllm
pip3 install -e .

View File

@ -18,15 +18,15 @@ from .gptq_op import CaiGPTQLinearOp
HAS_GPTQ_CUDA = False
try:
from colossalai.kernel.op_builder.gptq import GPTQBuilder
gptq_cuda = GPTQBuilder().load()
HAS_GPTQ_CUDA = True
except ImportError:
warnings.warn('CUDA gptq is not installed')
warnings.warn("CUDA gptq is not installed")
HAS_GPTQ_CUDA = False
class CaiQuantLinear(nn.Module):
def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False):
super().__init__()
if bits not in [2, 4, 8]:
@ -37,23 +37,28 @@ class CaiQuantLinear(nn.Module):
self.maxq = 2**self.bits - 1
self.groupsize = groupsize if groupsize != -1 else infeatures
self.register_buffer('qweight', torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32))
self.register_buffer("qweight", torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32))
self.register_buffer(
'qzeros',
torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32))
self.register_buffer('scales',
torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16))
"qzeros",
torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32),
)
self.register_buffer(
"scales", torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16)
)
if row_split:
self.register_buffer(
'g_idx',
torch.tensor([(i + (tp_rank * self.infeatures)) // self.groupsize for i in range(infeatures)],
dtype=torch.int32))
"g_idx",
torch.tensor(
[(i + (tp_rank * self.infeatures)) // self.groupsize for i in range(infeatures)], dtype=torch.int32
),
)
else:
self.register_buffer('g_idx',
torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32))
self.register_buffer(
"g_idx", torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32)
)
if bias:
self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
self.register_buffer("bias", torch.zeros((outfeatures), dtype=torch.float16))
else:
self.bias = None
@ -66,9 +71,11 @@ class CaiQuantLinear(nn.Module):
self.row_split = row_split
def pack(self, linear, scales, zeros, g_idx=None):
g_idx = g_idx.clone() if g_idx is not None else torch.tensor(
[i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32)
g_idx = (
g_idx.clone()
if g_idx is not None
else torch.tensor([i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32)
)
scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
@ -79,7 +86,6 @@ class CaiQuantLinear(nn.Module):
if linear.bias is not None:
self.bias = linear.bias.clone().half()
wn = 8
pbits = 32
ptype = torch.int32
unsign_type = np.uint32
@ -88,9 +94,10 @@ class CaiQuantLinear(nn.Module):
intweight = []
for idx in range(self.infeatures):
intweight.append(
torch.round(
(linear.weight.data[:, idx] + scale_zeros[g_idx[idx]]) / half_scales[g_idx[idx]]).to(ptype)[:,
None])
torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx[idx]]) / half_scales[g_idx[idx]]).to(ptype)[
:, None
]
)
intweight = torch.cat(intweight, dim=1)
intweight = intweight.t().contiguous()
intweight = intweight.numpy().astype(unsign_type)
@ -109,7 +116,7 @@ class CaiQuantLinear(nn.Module):
raise NotImplementedError("Only 2,4,8 bits are supported.")
qweight = qweight.astype(sign_type)
qweight1 = torch.from_numpy(qweight)
qweight1 = qweight1.contiguous() #.to("cuda")
qweight1 = qweight1.contiguous() # .to("cuda")
self.qweight.data.copy_(qweight1)
qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // pbits * self.bits), dtype=unsign_type)
@ -140,17 +147,20 @@ class CaiQuantLinear(nn.Module):
self.q4_width = self.qweight.shape[1]
if self.g_idx is not None:
if self.row_split and torch.equal(
self.g_idx,
torch.tensor(
[(i + (self.tp_rank * self.infeatures)) // self.groupsize for i in range(self.infeatures)],
dtype=torch.int32,
device=self.g_idx.device)):
self.g_idx,
torch.tensor(
[(i + (self.tp_rank * self.infeatures)) // self.groupsize for i in range(self.infeatures)],
dtype=torch.int32,
device=self.g_idx.device,
),
):
self.g_idx = None
elif torch.equal(
self.g_idx,
torch.tensor([i // self.groupsize for i in range(self.infeatures)],
dtype=torch.int32,
device=self.g_idx.device)):
self.g_idx,
torch.tensor(
[i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32, device=self.g_idx.device
),
):
self.g_idx = None
if self.g_idx is not None:
@ -165,7 +175,6 @@ class CaiQuantLinear(nn.Module):
outshape = x.shape[:-1] + (self.outfeatures,)
if HAS_GPTQ_CUDA and self.bits == 4:
if self.q4 is None:
self.init_q4()
@ -191,7 +200,6 @@ class CaiQuantLinear(nn.Module):
def split_column_copy(gptq_linear, cai_linear, tp_size=1, tp_rank=0, split_num=1):
qweights = gptq_linear.qweight.split(gptq_linear.out_features // split_num, dim=-1)
qzeros = gptq_linear.qzeros.split(gptq_linear.out_features // (32 // cai_linear.bits) // split_num, dim=-1)
scales = gptq_linear.scales.split(gptq_linear.out_features // split_num, dim=-1)
@ -203,24 +211,24 @@ def split_column_copy(gptq_linear, cai_linear, tp_size=1, tp_rank=0, split_num=1
zero_split_block = cai_linear.outfeatures // (32 // cai_linear.bits) // split_num
for i in range(split_num):
cai_linear.qweight[:, i * cai_split_out_features:(i + 1) *
cai_split_out_features] = qweights[i][:, tp_rank * cai_split_out_features:(tp_rank + 1) *
cai_split_out_features]
cai_linear.qzeros[:, i * zero_split_block:(i + 1) *
zero_split_block] = qzeros[i][:, tp_rank * zero_split_block:(tp_rank + 1) * zero_split_block]
cai_linear.scales[:, i * cai_split_out_features:(i + 1) *
cai_split_out_features] = scales[i][:, tp_rank * cai_split_out_features:(tp_rank + 1) *
cai_split_out_features]
cai_linear.qweight[:, i * cai_split_out_features : (i + 1) * cai_split_out_features] = qweights[i][
:, tp_rank * cai_split_out_features : (tp_rank + 1) * cai_split_out_features
]
cai_linear.qzeros[:, i * zero_split_block : (i + 1) * zero_split_block] = qzeros[i][
:, tp_rank * zero_split_block : (tp_rank + 1) * zero_split_block
]
cai_linear.scales[:, i * cai_split_out_features : (i + 1) * cai_split_out_features] = scales[i][
:, tp_rank * cai_split_out_features : (tp_rank + 1) * cai_split_out_features
]
if cai_linear.bias is not None:
cai_linear.bias[i * cai_split_out_features:(i + 1) *
cai_split_out_features] = bias[i][tp_rank * cai_split_out_features:(tp_rank + 1) *
cai_split_out_features]
cai_linear.bias[i * cai_split_out_features : (i + 1) * cai_split_out_features] = bias[i][
tp_rank * cai_split_out_features : (tp_rank + 1) * cai_split_out_features
]
cai_linear.g_idx.copy_(g_idx)
def split_row_copy(gptq_linear, cai_linear, tp_rank=0, split_num=1):
qweights = gptq_linear.qweight.split(gptq_linear.in_features // split_num, dim=0)
qzeros = gptq_linear.qzeros.split(gptq_linear.in_features // split_num, dim=0)
scales = gptq_linear.scales.split(gptq_linear.in_features // split_num, dim=0)
@ -231,47 +239,40 @@ def split_row_copy(gptq_linear, cai_linear, tp_rank=0, split_num=1):
idx_split_features = cai_linear.infeatures // split_num
for i in range(split_num):
cai_linear.qweight[i * cai_split_in_features:(i + 1) *
cai_split_in_features, :] = qweights[i][tp_rank * cai_split_in_features:(tp_rank + 1) *
cai_split_in_features, :]
cai_linear.qzeros[i * zero_split_block:(i + 1) *
zero_split_block, :] = qzeros[i][tp_rank * zero_split_block:(tp_rank + 1) *
zero_split_block, :]
cai_linear.scales[i * zero_split_block:(i + 1) *
zero_split_block, :] = scales[i][tp_rank * zero_split_block:(tp_rank + 1) *
zero_split_block, :]
cai_linear.g_idx[i * idx_split_features:(i + 1) *
idx_split_features] = g_idxs[i][tp_rank * idx_split_features:(tp_rank + 1) *
idx_split_features]
cai_linear.qweight[i * cai_split_in_features : (i + 1) * cai_split_in_features, :] = qweights[i][
tp_rank * cai_split_in_features : (tp_rank + 1) * cai_split_in_features, :
]
cai_linear.qzeros[i * zero_split_block : (i + 1) * zero_split_block, :] = qzeros[i][
tp_rank * zero_split_block : (tp_rank + 1) * zero_split_block, :
]
cai_linear.scales[i * zero_split_block : (i + 1) * zero_split_block, :] = scales[i][
tp_rank * zero_split_block : (tp_rank + 1) * zero_split_block, :
]
cai_linear.g_idx[i * idx_split_features : (i + 1) * idx_split_features] = g_idxs[i][
tp_rank * idx_split_features : (tp_rank + 1) * idx_split_features
]
if cai_linear.bias is not None:
cai_linear.bias.copy_(gptq_linear.bias)
class RowCaiQuantLinear(CaiQuantLinear, ParallelModule):
def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False):
super().__init__(bits,
groupsize,
infeatures,
outfeatures,
bias,
tp_size=tp_size,
tp_rank=tp_rank,
row_split=row_split)
super().__init__(
bits, groupsize, infeatures, outfeatures, bias, tp_size=tp_size, tp_rank=tp_rank, row_split=row_split
)
self.process_group = None
@staticmethod
def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
**kwargs) -> ParallelModule:
def from_native_module(
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
) -> ParallelModule:
LazyInitContext.materialize(module)
# get the attributes
in_features = module.in_features
# ensure only one process group is passed
if isinstance(process_group, (list, tuple)):
assert len(process_group) == 1, \
f'Expected only one process group, got {len(process_group)}.'
assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
process_group = process_group[0]
tp_size = dist.get_world_size(process_group)
@ -282,15 +283,18 @@ class RowCaiQuantLinear(CaiQuantLinear, ParallelModule):
if in_features % tp_size != 0:
raise ValueError(
f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!")
linear_1d = RowCaiQuantLinear(module.bits,
module.group_size,
module.in_features // tp_size,
module.out_features,
module.bias is not None,
tp_size=tp_size,
tp_rank=tp_rank,
row_split=True)
f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!"
)
linear_1d = RowCaiQuantLinear(
module.bits,
module.group_size,
module.in_features // tp_size,
module.out_features,
module.bias is not None,
tp_size=tp_size,
tp_rank=tp_rank,
row_split=True,
)
linear_1d.process_group = process_group
split_row_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
@ -306,30 +310,23 @@ class RowCaiQuantLinear(CaiQuantLinear, ParallelModule):
class ColCaiQuantLinear(CaiQuantLinear, ParallelModule):
def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False):
super().__init__(bits,
groupsize,
infeatures,
outfeatures,
bias,
tp_size=tp_size,
tp_rank=tp_rank,
row_split=row_split)
super().__init__(
bits, groupsize, infeatures, outfeatures, bias, tp_size=tp_size, tp_rank=tp_rank, row_split=row_split
)
self.process_group = None
@staticmethod
def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
**kwargs) -> ParallelModule:
def from_native_module(
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
) -> ParallelModule:
LazyInitContext.materialize(module)
# get the attributes
in_features = module.in_features
# ensure only one process group is passed
if isinstance(process_group, (list, tuple)):
assert len(process_group) == 1, \
f'Expected only one process group, got {len(process_group)}.'
assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
process_group = process_group[0]
tp_size = dist.get_world_size(process_group)
@ -340,14 +337,17 @@ class ColCaiQuantLinear(CaiQuantLinear, ParallelModule):
if in_features % tp_size != 0:
raise ValueError(
f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!")
linear_1d = ColCaiQuantLinear(module.bits,
module.group_size,
module.in_features,
module.out_features // tp_size,
module.bias is not None,
tp_size=tp_size,
tp_rank=tp_rank)
f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!"
)
linear_1d = ColCaiQuantLinear(
module.bits,
module.group_size,
module.in_features,
module.out_features // tp_size,
module.bias is not None,
tp_size=tp_size,
tp_rank=tp_rank,
)
linear_1d.process_group = process_group
split_column_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)

View File

@ -5,6 +5,7 @@ import torch
from .kvcache_manager import MemoryManager
# adapted from: lightllm/server/router/model_infer/infer_batch.py
@dataclass
class BatchInferState:

View File

@ -19,8 +19,11 @@ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
from ._utils import copy_kv_to_mem_cache
try:
from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import context_attention_fwd as lightllm_llama2_context_attention_fwd
from lightllm.models.chatglm2.triton_kernel.rotary_emb import rotary_emb_fwd as chatglm2_rotary_emb_fwd
from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import (
context_attention_fwd as lightllm_llama2_context_attention_fwd,
)
HAS_LIGHTLLM_KERNEL = True
except:
print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm")

View File

@ -4,7 +4,6 @@ import torch
from torch.nn import LayerNorm
import colossalai.shardformer.layer as col_nn
from colossalai.shardformer.modeling.bloom import build_bloom_alibi_tensor_fn
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
from colossalai.shardformer.policies.bloom import BloomForCausalLMPolicy
@ -40,33 +39,36 @@ class BloomModelInferPolicy(BloomForCausalLMPolicy):
policy = super().module_policy()
if self.shard_config.inference_gptq:
from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear
policy[BloomBlock] = ModulePolicyDescription(attribute_replacement={
"self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attention.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size,
},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attention.query_key_value",
target_module=ColCaiQuantLinear,
kwargs={'split_num': 3}),
SubModuleReplacementDescription(
suffix="self_attention.dense",
target_module=RowCaiQuantLinear,
kwargs={'split_num': 1}),
SubModuleReplacementDescription(
suffix="self_attention.attention_dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="mlp.dense_h_to_4h",
target_module=ColCaiQuantLinear,
kwargs={'split_num': 1}),
SubModuleReplacementDescription(
suffix="mlp.dense_4h_to_h",
target_module=RowCaiQuantLinear,
kwargs={'split_num': 1}),
])
policy[BloomBlock] = ModulePolicyDescription(
attribute_replacement={
"self_attention.hidden_size": self.model.config.hidden_size
// self.shard_config.tensor_parallel_size,
"self_attention.split_size": self.model.config.hidden_size
// self.shard_config.tensor_parallel_size,
"self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size,
},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attention.query_key_value",
target_module=ColCaiQuantLinear,
kwargs={"split_num": 3},
),
SubModuleReplacementDescription(
suffix="self_attention.dense", target_module=RowCaiQuantLinear, kwargs={"split_num": 1}
),
SubModuleReplacementDescription(
suffix="self_attention.attention_dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="mlp.dense_h_to_4h", target_module=ColCaiQuantLinear, kwargs={"split_num": 1}
),
SubModuleReplacementDescription(
suffix="mlp.dense_4h_to_h", target_module=RowCaiQuantLinear, kwargs={"split_num": 1}
),
],
)
# NOTE set inference mode to shard config
self.shard_config._infer()

View File

@ -13,6 +13,7 @@ from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forw
try:
from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward as lightllm_rmsnorm_forward
HAS_TRITON_RMSNORM = True
except:
print("you should install triton from https://github.com/openai/triton")
@ -21,6 +22,7 @@ except:
def get_triton_rmsnorm_forward():
if HAS_TRITON_RMSNORM:
def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
return lightllm_rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon)

View File

@ -1,254 +1,202 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include <cstdint>
#include <cstdio>
#include "util.cuh"
#include "tuning.h"
#include "cuda_buffers.cuh"
#include "q4_matrix.cuh"
#include "q4_matmul.cuh"
#include "column_remap.cuh"
#include "cuda_buffers.cuh"
#include "q4_matmul.cuh"
#include "q4_matrix.cuh"
#include "tuning.h"
#include "util.cuh"
// Check CUDA return code. We don't want to include Torch headers in the .cu files because parsing them adds almost a
// minute to the compile time on a 12900K. Also passing exceptions back to Python is super tricky, so in place of
// exceptions, CUDA functions return with a cudaError_t which we can parse and dump to the console.
// Check CUDA return code. We don't want to include Torch headers in the .cu
// files because parsing them adds almost a minute to the compile time on a
// 12900K. Also passing exceptions back to Python is super tricky, so in place
// of exceptions, CUDA functions return with a cudaError_t which we can parse
// and dump to the console.
void check_cuda(cudaError_t ret)
{
switch (ret)
{
case cudaSuccess:
break;
void check_cuda(cudaError_t ret) {
switch (ret) {
case cudaSuccess:
break;
case cudaUnspecified:
printf(" **** Unspecified error\n");
TORCH_CHECK(false, "CUDA error");
break;
case cudaUnspecified:
printf(" **** Unspecified error\n");
TORCH_CHECK(false, "CUDA error");
break;
default:
printf(" **** CUDA error\n"); \
printf(" **** %s\n", cudaGetErrorString(ret)); \
TORCH_CHECK(false, "CUDA error"); \
break;
}
default:
printf(" **** CUDA error\n");
printf(" **** %s\n", cudaGetErrorString(ret));
TORCH_CHECK(false, "CUDA error");
break;
}
}
// Some decluttering macros
#define STRINGIFY_(__x) #__x
#define STRINGIFY(__x) STRINGIFY_(__x)
#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
#define TORCH_CHECK_SHAPE_MOD(__x, __dim_x, __mod) TORCH_CHECK((__x).size(__dim_x) % __mod == 0, #__x ".shape[" STRINGIFY(__dim_x) "] must be a multiple of " STRINGIFY(__mod))
#define TORCH_CHECK_BUFFER_SIZE(__buffer, __minimum_size) TORCH_CHECK((__buffer).numel() >= __minimum_size, #__buffer " is too small")
#define TORCH_CHECK_DTYPE(__x, __dtype) \
TORCH_CHECK((__x).dtype() == torch::__dtype, \
#__x " is incorrect datatype, must be " #__dtype)
#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) \
TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, \
#__x " is incorrect datatype, must be " #__dtype)
#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) \
TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, \
#__x " and " #__y " have incompatible shapes")
#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) \
TORCH_CHECK((__x).device().is_meta() || \
(__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, \
#__x " and " #__y " have incompatible shapes")
#define TORCH_CHECK_SHAPE_MOD(__x, __dim_x, __mod) \
TORCH_CHECK((__x).size(__dim_x) % __mod == 0, \
#__x ".shape[" STRINGIFY( \
__dim_x) "] must be a multiple of " STRINGIFY(__mod))
#define TORCH_CHECK_BUFFER_SIZE(__buffer, __minimum_size) \
TORCH_CHECK((__buffer).numel() >= __minimum_size, #__buffer " is too small")
#define TORCH_CHECK_DEVICE_INDEX(__index) \
do { \
TORCH_CHECK(__index >= 0, "no device index"); \
#define TORCH_CHECK_DEVICE_INDEX(__index) \
do { \
TORCH_CHECK(__index >= 0, "no device index"); \
TORCH_CHECK(__index < CUDA_MAX_DEVICES, "invalid device index"); \
} while(0)
} while (0)
#define TORCH_CHECK_QUANT(__w, __w_scales, __w_zeros, __seq_g_idx, __x_map) \
do { \
TORCH_CHECK_DTYPE(__w, kInt); \
TORCH_CHECK_DTYPE(__w_scales, kHalf); \
TORCH_CHECK_DTYPE(__w_zeros, kInt); \
TORCH_CHECK_DTYPE_OPT(__seq_g_idx, kShort); \
TORCH_CHECK_DTYPE_OPT(__x_map, kInt); \
TORCH_CHECK_SHAPES_OPT(__seq_g_idx, 0, __w, 0, 2 * 8); \
TORCH_CHECK_SHAPES_OPT(__x_map, 0, __w, 0, 8); \
} while(0)
do { \
TORCH_CHECK_DTYPE(__w, kInt); \
TORCH_CHECK_DTYPE(__w_scales, kHalf); \
TORCH_CHECK_DTYPE(__w_zeros, kInt); \
TORCH_CHECK_DTYPE_OPT(__seq_g_idx, kShort); \
TORCH_CHECK_DTYPE_OPT(__x_map, kInt); \
TORCH_CHECK_SHAPES_OPT(__seq_g_idx, 0, __w, 0, 2 * 8); \
TORCH_CHECK_SHAPES_OPT(__x_map, 0, __w, 0, 8); \
} while (0)
int get_groupsize(torch::Tensor w, torch::Tensor w_zeros)
{
int groupsize = w.size(0) * 8 / w_zeros.size(0);
TORCH_CHECK(groupsize * w_zeros.size(0) == w.size(0) * 8, "w.shape[-2] must be a multiple of zeros.shape[-2]")
return groupsize;
int get_groupsize(torch::Tensor w, torch::Tensor w_zeros) {
int groupsize = w.size(0) * 8 / w_zeros.size(0);
TORCH_CHECK(groupsize * w_zeros.size(0) == w.size(0) * 8,
"w.shape[-2] must be a multiple of zeros.shape[-2]")
return groupsize;
}
// Tuning parameters
ExLlamaTuning tuningParams;
void set_tuning_params
(
int matmul_recons_thd,
bool matmul_fused_remap,
bool matmul_no_half2
)
{
tuningParams.matmul_recons_thd = matmul_recons_thd;
tuningParams.matmul_fused_remap = matmul_fused_remap;
tuningParams.matmul_no_half2 = matmul_no_half2;
void set_tuning_params(int matmul_recons_thd, bool matmul_fused_remap,
bool matmul_no_half2) {
tuningParams.matmul_recons_thd = matmul_recons_thd;
tuningParams.matmul_fused_remap = matmul_fused_remap;
tuningParams.matmul_no_half2 = matmul_no_half2;
}
// Release all unmanaged objects allocated by the extension
void cleanup()
{
cleanup_buffers_cuda();
g_q4_free_matrices();
void cleanup() {
cleanup_buffers_cuda();
g_q4_free_matrices();
}
// Prepare buffers for forward pass
void prepare_buffers
(
torch::Device device,
torch::Tensor temp_state,
torch::Tensor temp_dq
)
{
int device_index = device.index();
TORCH_CHECK_DEVICE_INDEX(device_index);
const at::cuda::OptionalCUDAGuard device_guard(device);
void prepare_buffers(torch::Device device, torch::Tensor temp_state,
torch::Tensor temp_dq) {
int device_index = device.index();
TORCH_CHECK_DEVICE_INDEX(device_index);
const at::cuda::OptionalCUDAGuard device_guard(device);
prepare_buffers_cuda
(
device_index,
// buffer size used for sanity checks
temp_state.numel(),
(half*) temp_state.data_ptr(),
(half*) temp_dq.data_ptr()
);
prepare_buffers_cuda(device_index,
// buffer size used for sanity checks
temp_state.numel(), (half*)temp_state.data_ptr(),
(half*)temp_dq.data_ptr());
}
// Create Q4Matrix, return handle
uintptr_t make_q4
(
torch::Tensor qweight,
torch::Tensor qzeros,
torch::Tensor scales,
torch::Tensor g_idx,
int device
)
{
TORCH_CHECK_DTYPE(qweight, kInt);
TORCH_CHECK_DTYPE(qzeros, kInt);
TORCH_CHECK_DTYPE(scales, kHalf);
TORCH_CHECK_DTYPE_OPT(g_idx, kInt);
TORCH_CHECK_SHAPES(qweight, 1, qzeros, 1, 8);
TORCH_CHECK_SHAPES(scales, 1, qweight, 1, 1);
TORCH_CHECK_SHAPES(qzeros, 0, scales, 0, 1);
uintptr_t make_q4(torch::Tensor qweight, torch::Tensor qzeros,
torch::Tensor scales, torch::Tensor g_idx, int device) {
TORCH_CHECK_DTYPE(qweight, kInt);
TORCH_CHECK_DTYPE(qzeros, kInt);
TORCH_CHECK_DTYPE(scales, kHalf);
TORCH_CHECK_DTYPE_OPT(g_idx, kInt);
TORCH_CHECK_SHAPES(qweight, 1, qzeros, 1, 8);
TORCH_CHECK_SHAPES(scales, 1, qweight, 1, 1);
TORCH_CHECK_SHAPES(qzeros, 0, scales, 0, 1);
int width = qweight.size(1);
int height = qweight.size(0) * 8;
int groups = qzeros.size(0);
int width = qweight.size(1);
int height = qweight.size(0) * 8;
int groups = qzeros.size(0);
Q4Matrix* m = new Q4Matrix
(
height,
width,
groups,
Q4Matrix* m = new Q4Matrix(
height, width, groups,
(uint32_t*) qweight.data_ptr(),
(uint32_t*) qzeros.data_ptr(),
(half*) scales.data_ptr(),
g_idx.device().is_meta() ? NULL : (uint32_t*) g_idx.data_ptr(),
(uint32_t*)qweight.data_ptr(), (uint32_t*)qzeros.data_ptr(),
(half*)scales.data_ptr(),
g_idx.device().is_meta() ? NULL : (uint32_t*)g_idx.data_ptr(),
device
);
device);
g_q4_keep_matrix(m);
return reinterpret_cast<uintptr_t> (m);
g_q4_keep_matrix(m);
return reinterpret_cast<uintptr_t>(m);
}
// Matmul half @ quant -> half
void q4_matmul
(
torch::Tensor x,
uintptr_t w,
torch::Tensor out
)
{
Q4Matrix* wm = reinterpret_cast<Q4Matrix*> (w);
void q4_matmul(torch::Tensor x, uintptr_t w, torch::Tensor out) {
Q4Matrix* wm = reinterpret_cast<Q4Matrix*>(w);
TORCH_CHECK_DTYPE(x, kHalf);
TORCH_CHECK_DTYPE(out, kHalf);
TORCH_CHECK_SHAPES(x, 0, out, 0, 1);
TORCH_CHECK(wm->height == x.size(-1), "x and w have incompatible shapes")
TORCH_CHECK_DTYPE(x, kHalf);
TORCH_CHECK_DTYPE(out, kHalf);
TORCH_CHECK_SHAPES(x, 0, out, 0, 1);
TORCH_CHECK(wm->height == x.size(-1), "x and w have incompatible shapes")
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
int x_height = x.size(0);
int x_height = x.size(0);
if (tuningParams.matmul_recons_thd == 0 || x_height < tuningParams.matmul_recons_thd)
{
q4_matmul_cuda
(
&tuningParams,
(half*) x.data_ptr(),
x_height,
wm,
(half*) out.data_ptr()
);
}
else
{
q4_matmul_recons_cuda
(
&tuningParams,
(half*) x.data_ptr(),
x_height,
wm,
(half*) out.data_ptr(),
at::cuda::getCurrentCUDABlasHandle()
);
}
if (tuningParams.matmul_recons_thd == 0 ||
x_height < tuningParams.matmul_recons_thd) {
q4_matmul_cuda(&tuningParams, (half*)x.data_ptr(), x_height, wm,
(half*)out.data_ptr());
} else {
q4_matmul_recons_cuda(&tuningParams, (half*)x.data_ptr(), x_height, wm,
(half*)out.data_ptr(),
at::cuda::getCurrentCUDABlasHandle());
}
}
// Remap columns in half tensor
void column_remap
(
torch::Tensor x,
torch::Tensor x_new,
torch::Tensor x_map
)
{
TORCH_CHECK_DTYPE(x, kHalf);
TORCH_CHECK_DTYPE(x_new, kHalf);
TORCH_CHECK_DTYPE(x_map, kInt);
TORCH_CHECK_SHAPES(x_map, 0, x, 1, 1);
void column_remap(torch::Tensor x, torch::Tensor x_new, torch::Tensor x_map) {
TORCH_CHECK_DTYPE(x, kHalf);
TORCH_CHECK_DTYPE(x_new, kHalf);
TORCH_CHECK_DTYPE(x_map, kInt);
TORCH_CHECK_SHAPES(x_map, 0, x, 1, 1);
int height = x.size(0);
int width = x.size(1);
int height = x.size(0);
int width = x.size(1);
TORCH_CHECK_BUFFER_SIZE(x_new, height * width);
TORCH_CHECK_BUFFER_SIZE(x_new, height * width);
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
column_remap_cuda
(
(half*) x.data_ptr(),
(half*) x_new.data_ptr(),
height,
width,
(uint32_t*) x_map.data_ptr()
);
column_remap_cuda((half*)x.data_ptr(), (half*)x_new.data_ptr(), height, width,
(uint32_t*)x_map.data_ptr());
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("set_tuning_params", &set_tuning_params, "set_tuning_params");
m.def("prepare_buffers", &prepare_buffers, "prepare_buffers");
m.def("cleanup", &cleanup, "cleanup");
m.def("make_q4", &make_q4, "make_q4");
m.def("q4_matmul", &q4_matmul, "q4_matmul");
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("set_tuning_params", &set_tuning_params, "set_tuning_params");
m.def("prepare_buffers", &prepare_buffers, "prepare_buffers");
m.def("cleanup", &cleanup, "cleanup");
m.def("make_q4", &make_q4, "make_q4");
m.def("q4_matmul", &q4_matmul, "q4_matmul");
}

View File

@ -184,7 +184,7 @@ __global__ void reconstruct_kernel
int column = RECONS_THREADS_X * blockIdx.x + threadIdx.x;
int row = (RECONS_THREADS_Y * blockIdx.y + threadIdx.y) * 8;
if (column >= width) return;
// Views
MatrixView_q4_column w_(w, height, width);

View File

@ -50,4 +50,4 @@ private:
void g_q4_keep_matrix(Q4Matrix* m);
void g_q4_free_matrices();
#endif
#endif

View File

@ -238,5 +238,5 @@ if HAS_TRITON:
num_warps=num_warps,
num_stages=1,
)
return
return

View File

@ -10,7 +10,6 @@ except ImportError:
print("please install triton from https://github.com/openai/triton")
if HAS_TRITON:
# adapted from https://github.com/ModelTC/lightllm/blob/5c559dd7981ed67679a08a1e09a88fb4c1550b3a/lightllm/common/triton_kernel/destindex_copy_kv.py
@triton.jit
def _fwd_copy_kv_cache_dest(

View File

@ -13,6 +13,9 @@ except ImportError:
print("please install triton from https://github.com/openai/triton")
try:
from lightllm.models.bloom.triton_kernel.token_attention_nopad_att1 import (
token_att_fwd as lightllm_bloom_token_att_fwd,
)
from lightllm.models.llama2.triton_kernel.token_attention_nopad_att1 import (
token_att_fwd as lightllm_llama2_token_att_fwd,
)
@ -22,11 +25,15 @@ try:
from lightllm.models.llama2.triton_kernel.token_attention_nopad_softmax import (
token_softmax_fwd as lightllm_llama2_token_softmax_fwd,
)
from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import token_att_fwd2 as lightllm_llama_token_att_fw2
from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd as lightllm_llama_token_att_fwd
from lightllm.models.llama.triton_kernel.token_attention_nopad_softmax import token_softmax_fwd as lightllm_llama_token_softmax_fwd
from lightllm.models.bloom.triton_kernel.token_attention_nopad_att1 import token_att_fwd as lightllm_bloom_token_att_fwd
from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import (
token_att_fwd as lightllm_llama_token_att_fwd,
)
from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import (
token_att_fwd2 as lightllm_llama_token_att_fw2,
)
from lightllm.models.llama.triton_kernel.token_attention_nopad_softmax import (
token_softmax_fwd as lightllm_llama_token_softmax_fwd,
)
HAS_TRITON_TOKEN_ATTENTION = True
except ImportError:

View File

@ -44,8 +44,8 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -
return unpickle
def check_for_nccl_backend(group):
def check_for_nccl_backend(group):
pg = group or c10d._get_default_group()
# Gate PG wrapper check on Gloo availability.
if c10d._GLOO_AVAILABLE:
@ -54,10 +54,8 @@ def check_for_nccl_backend(group):
while isinstance(pg, c10d._ProcessGroupWrapper):
pg = pg.wrapped_pg
return (
c10d.is_nccl_available() and
pg.name() == c10d.Backend.NCCL
)
return c10d.is_nccl_available() and pg.name() == c10d.Backend.NCCL
def _broadcast_object_list(
object_list: List[Any], src: int, group: ProcessGroup, device: Optional[Union[torch.device, str, int]] = None

View File

@ -0,0 +1,7 @@
from .bnb import quantize_model
from .bnb_config import BnbQuantizationConfig
__all__ = [
"BnbQuantizationConfig",
"quantize_model",
]

View File

@ -0,0 +1,321 @@
# adapted from Hugging Face accelerate/utils/bnb.py accelerate/utils/modeling.py
import logging
import torch
import torch.nn as nn
from .bnb_config import BnbQuantizationConfig
try:
import bitsandbytes as bnb
IS_4BIT_BNB_AVAILABLE = bnb.__version__ >= "0.39.0"
IS_8BIT_BNB_AVAILABLE = bnb.__version__ >= "0.37.2"
except ImportError:
pass
logger = logging.getLogger(__name__)
def quantize_model(
model: torch.nn.Module,
bnb_quantization_config: BnbQuantizationConfig,
):
"""
This function will quantize the input loaded model with the associated config passed in `bnb_quantization_config`.
We will quantize the model and put the model on the GPU.
Args:
model (`torch.nn.Module`):
Input model. The model already loaded
bnb_quantization_config (`BnbQuantizationConfig`):
The bitsandbytes quantization parameters
Returns:
`torch.nn.Module`: The quantized model
"""
load_in_4bit = bnb_quantization_config.load_in_4bit
load_in_8bit = bnb_quantization_config.load_in_8bit
if load_in_8bit and not IS_8BIT_BNB_AVAILABLE:
raise ImportError(
"You have a version of `bitsandbytes` that is not compatible with 8bit quantization,"
" make sure you have the latest version of `bitsandbytes` installed."
)
if load_in_4bit and not IS_4BIT_BNB_AVAILABLE:
raise ValueError(
"You have a version of `bitsandbytes` that is not compatible with 4bit quantization,"
"make sure you have the latest version of `bitsandbytes` installed."
)
# We keep some modules such as the lm_head in their original dtype for numerical stability reasons
if bnb_quantization_config.skip_modules is None:
bnb_quantization_config.skip_modules = get_keys_to_not_convert(model)
modules_to_not_convert = bnb_quantization_config.skip_modules
# We add the modules we want to keep in full precision
if bnb_quantization_config.keep_in_fp32_modules is None:
bnb_quantization_config.keep_in_fp32_modules = []
keep_in_fp32_modules = bnb_quantization_config.keep_in_fp32_modules
# compatibility with peft
model.is_loaded_in_4bit = load_in_4bit
model.is_loaded_in_8bit = load_in_8bit
# assert model_device is cuda
model_device = next(model.parameters()).device
model = replace_with_bnb_layers(model, bnb_quantization_config, modules_to_not_convert=modules_to_not_convert)
# convert param to the right dtype
dtype = bnb_quantization_config.torch_dtype
for name, param in model.state_dict().items():
if any(module_to_keep_in_fp32 in name for module_to_keep_in_fp32 in keep_in_fp32_modules):
param.to(torch.float32)
if param.dtype != torch.float32:
name = name.replace(".weight", "").replace(".bias", "")
param = getattr(model, name, None)
if param is not None:
param.to(torch.float32)
elif torch.is_floating_point(param):
param.to(dtype)
if model_device.type == "cuda":
# move everything to cpu in the first place because we can't do quantization if the weights are already on cuda
model.cuda(torch.cuda.current_device())
torch.cuda.empty_cache()
elif torch.cuda.is_available():
model.to(torch.cuda.current_device())
logger.info(
f"The model device type is {model_device.type}. However, cuda is needed for quantization."
"We move the model to cuda."
)
else:
raise RuntimeError("No GPU found. A GPU is needed for quantization.")
return model
def replace_with_bnb_layers(model, bnb_quantization_config, modules_to_not_convert=None, current_key_name=None):
"""
A helper function to replace all `torch.nn.Linear` modules by `bnb.nn.Linear8bit` modules or by `bnb.nn.Linear4bit`
modules from the `bitsandbytes`library. The function will be run recursively and replace `torch.nn.Linear` modules.
Parameters:
model (`torch.nn.Module`):
Input model or `torch.nn.Module` as the function is run recursively.
modules_to_not_convert (`List[str]`):
Names of the modules to not quantize convert. In practice we keep the `lm_head` in full precision for
numerical stability reasons.
current_key_name (`List[str]`, *optional*):
An array to track the current key of the recursion. This is used to check whether the current key (part of
it) is not in the list of modules to not convert.
"""
if modules_to_not_convert is None:
modules_to_not_convert = []
model, has_been_replaced = _replace_with_bnb_layers(
model, bnb_quantization_config, modules_to_not_convert, current_key_name
)
if not has_been_replaced:
logger.warning(
"You are loading your model in 8bit or 4bit but no linear modules were found in your model."
" this can happen for some architectures such as gpt2 that uses Conv1D instead of Linear layers."
" Please double check your model architecture, or submit an issue on github if you think this is"
" a bug."
)
return model
def _replace_with_bnb_layers(
model,
bnb_quantization_config,
modules_to_not_convert=None,
current_key_name=None,
):
"""
Private method that wraps the recursion for module replacement.
Returns the converted model and a boolean that indicates if the conversion has been successfull or not.
"""
# bitsandbytes will initialize CUDA on import, so it needs to be imported lazily
has_been_replaced = False
for name, module in model.named_children():
if current_key_name is None:
current_key_name = []
current_key_name.append(name)
if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
# Check if the current key is not in the `modules_to_not_convert`
current_key_name_str = ".".join(current_key_name)
proceed = True
for key in modules_to_not_convert:
if (
(key in current_key_name_str) and (key + "." in current_key_name_str)
) or key == current_key_name_str:
proceed = False
break
if proceed:
# Load bnb module with empty weight and replace ``nn.Linear` module
if bnb_quantization_config.load_in_8bit:
bnb_module = bnb.nn.Linear8bitLt(
module.in_features,
module.out_features,
module.bias is not None,
has_fp16_weights=False,
threshold=bnb_quantization_config.llm_int8_threshold,
)
elif bnb_quantization_config.load_in_4bit:
bnb_module = bnb.nn.Linear4bit(
module.in_features,
module.out_features,
module.bias is not None,
bnb_quantization_config.bnb_4bit_compute_dtype,
compress_statistics=bnb_quantization_config.bnb_4bit_use_double_quant,
quant_type=bnb_quantization_config.bnb_4bit_quant_type,
)
else:
raise ValueError("load_in_8bit and load_in_4bit can't be both False")
bnb_module.weight.data = module.weight.data
bnb_module.weight.skip_zero_check = True
if module.bias is not None:
bnb_module.bias.data = module.bias.data
bnb_module.bias.skip_zero_check = True
bnb_module.requires_grad_(False)
setattr(model, name, bnb_module)
has_been_replaced = True
if len(list(module.children())) > 0:
_, _has_been_replaced = _replace_with_bnb_layers(
module, bnb_quantization_config, modules_to_not_convert, current_key_name
)
has_been_replaced = has_been_replaced | _has_been_replaced
# Remove the last key for recursion
current_key_name.pop(-1)
return model, has_been_replaced
def get_keys_to_not_convert(model):
r"""
An utility function to get the key of the module to keep in full precision if any For example for CausalLM modules
we may want to keep the lm_head in full precision for numerical stability reasons. For other architectures, we want
to keep the tied weights of the model. The function will return a list of the keys of the modules to not convert in
int8.
Parameters:
model (`torch.nn.Module`):
Input model
"""
# Create a copy of the model
# with init_empty_weights():
# tied_model = deepcopy(model) # this has 0 cost since it is done inside `init_empty_weights` context manager`
tied_model = model
tied_params = find_tied_parameters(tied_model)
# For compatibility with Accelerate < 0.18
if isinstance(tied_params, dict):
tied_keys = sum(list(tied_params.values()), []) + list(tied_params.keys())
else:
tied_keys = sum(tied_params, [])
has_tied_params = len(tied_keys) > 0
# Check if it is a base model
is_base_model = False
if hasattr(model, "base_model_prefix"):
is_base_model = not hasattr(model, model.base_model_prefix)
# Ignore this for base models (BertModel, GPT2Model, etc.)
if (not has_tied_params) and is_base_model:
return []
# otherwise they have an attached head
list_modules = list(model.named_children())
list_last_module = [list_modules[-1][0]]
# add last module together with tied weights
intersection = set(list_last_module) - set(tied_keys)
list_untouched = list(set(tied_keys)) + list(intersection)
# remove ".weight" from the keys
names_to_remove = [".weight", ".bias"]
filtered_module_names = []
for name in list_untouched:
for name_to_remove in names_to_remove:
if name_to_remove in name:
name = name.replace(name_to_remove, "")
filtered_module_names.append(name)
return filtered_module_names
def find_tied_parameters(model: nn.Module, **kwargs):
"""
Find the tied parameters in a given model.
<Tip warning={true}>
The signature accepts keyword arguments, but they are for the recursive part of this function and you should ignore
them.
</Tip>
Args:
model (`torch.nn.Module`): The model to inspect.
Returns:
List[List[str]]: A list of lists of parameter names being all tied together.
Example:
```py
>>> from collections import OrderedDict
>>> import torch.nn as nn
>>> model = nn.Sequential(OrderedDict([("linear1", nn.Linear(4, 4)), ("linear2", nn.Linear(4, 4))]))
>>> model.linear2.weight = model.linear1.weight
>>> find_tied_parameters(model)
[['linear1.weight', 'linear2.weight']]
```
"""
# Initialize result and named_parameters before recursing.
named_parameters = kwargs.get("named_parameters", None)
prefix = kwargs.get("prefix", "")
result = kwargs.get("result", {})
if named_parameters is None:
named_parameters = {n: p for n, p in model.named_parameters()}
else:
# A tied parameter will not be in the full `named_parameters` seen above but will be in the `named_parameters`
# of the submodule it belongs to. So while recursing we track the names that are not in the initial
# `named_parameters`.
for name, parameter in model.named_parameters():
full_name = name if prefix == "" else f"{prefix}.{name}"
if full_name not in named_parameters:
# When we find one, it has to be one of the existing parameters.
for new_name, new_param in named_parameters.items():
if new_param is parameter:
if new_name not in result:
result[new_name] = []
result[new_name].append(full_name)
# Once we have treated direct parameters, we move to the child modules.
for name, child in model.named_children():
child_name = name if prefix == "" else f"{prefix}.{name}"
find_tied_parameters(child, named_parameters=named_parameters, prefix=child_name, result=result)
return FindTiedParametersResult([sorted([weight] + list(set(tied))) for weight, tied in result.items()])
class FindTiedParametersResult(list):
"""
This is a subclass of a list to handle backward compatibility for Transformers. Do not rely on the fact this is not
a list or on the `values` method as in the future this will be removed.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def values(self):
return sum([x[1:] for x in self], [])

View File

@ -0,0 +1,113 @@
# adapted from Hugging Face accelerate/utils/dataclasses.py
import warnings
from dataclasses import dataclass, field
from typing import List
import torch
@dataclass
class BnbQuantizationConfig:
"""
A plugin to enable BitsAndBytes 4bit and 8bit quantization
"""
load_in_8bit: bool = field(default=False, metadata={"help": "enable 8bit quantization."})
llm_int8_threshold: float = field(
default=6.0, metadata={"help": "value of the outliner threshold. only relevant when load_in_8bit=True"}
)
load_in_4bit: bool = field(default=False, metadata={"help": "enable 4bit quantization."})
bnb_4bit_quant_type: str = field(
default="fp4",
metadata={
"help": "set the quantization data type in the `bnb.nn.Linear4Bit` layers. Options are {'fp4','np4'}."
},
)
bnb_4bit_use_double_quant: bool = field(
default=False,
metadata={
"help": "enable nested quantization where the quantization constants from the first quantization are quantized again."
},
)
bnb_4bit_compute_dtype: bool = field(
default="fp16",
metadata={
"help": "This sets the computational type which might be different than the input time. For example, inputs might be "
"fp32, but computation can be set to bf16 for speedups. Options are {'fp32','fp16','bf16'}."
},
)
torch_dtype: torch.dtype = field(
default=None,
metadata={
"help": "this sets the dtype of the remaining non quantized layers. `bitsandbytes` library suggests to set the value"
"to `torch.float16` for 8 bit model and use the same dtype as the compute dtype for 4 bit model "
},
)
skip_modules: List[str] = field(
default=None,
metadata={
"help": "an explicit list of the modules that we don't quantize. The dtype of these modules will be `torch_dtype`."
},
)
keep_in_fp32_modules: List[str] = field(
default=None,
metadata={"help": "an explicit list of the modules that we don't quantize. We keep them in `torch.float32`."},
)
def __post_init__(self):
if isinstance(self.bnb_4bit_compute_dtype, str):
if self.bnb_4bit_compute_dtype == "fp32":
self.bnb_4bit_compute_dtype = torch.float32
elif self.bnb_4bit_compute_dtype == "fp16":
self.bnb_4bit_compute_dtype = torch.float16
elif self.bnb_4bit_compute_dtype == "bf16":
self.bnb_4bit_compute_dtype = torch.bfloat16
else:
raise ValueError(
f"bnb_4bit_compute_dtype must be in ['fp32','fp16','bf16'] but found {self.bnb_4bit_compute_dtype}"
)
elif not isinstance(self.bnb_4bit_compute_dtype, torch.dtype):
raise ValueError("bnb_4bit_compute_dtype must be a string or a torch.dtype")
if self.skip_modules is not None and not isinstance(self.skip_modules, list):
raise ValueError("skip_modules must be a list of strings")
if self.keep_in_fp32_modules is not None and not isinstance(self.keep_in_fp32_modules, list):
raise ValueError("keep_in_fp_32_modules must be a list of strings")
if self.load_in_4bit:
self.target_dtype = "int4"
if self.load_in_8bit:
self.target_dtype = torch.int8
if self.load_in_4bit and self.llm_int8_threshold != 6.0:
warnings.warn("llm_int8_threshold can only be used for model loaded in 8bit")
if isinstance(self.torch_dtype, str):
if self.torch_dtype == "fp32":
self.torch_dtype = torch.float32
elif self.torch_dtype == "fp16":
self.torch_dtype = torch.float16
elif self.torch_dtype == "bf16":
self.torch_dtype = torch.bfloat16
else:
raise ValueError(f"torch_dtype must be in ['fp32','fp16','bf16'] but found {self.torch_dtype}")
if self.load_in_8bit and self.torch_dtype is None:
self.torch_dtype = torch.float16
if self.load_in_4bit and self.torch_dtype is None:
self.torch_dtype = self.bnb_4bit_compute_dtype
if not isinstance(self.torch_dtype, torch.dtype):
raise ValueError("torch_dtype must be a torch.dtype")

View File

@ -190,6 +190,7 @@ def calculate_global_norm_from_list(norm_list):
total_norm += norm**2.0
return math.sqrt(total_norm)
def sync_tensor(flat_tensor, tensor_list):
"""
Synchronize the flattened tensor and unflattened tensor list. When

View File

@ -187,9 +187,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
for param_group in self.optim.param_groups:
group_params = param_group["params"]
for param in group_params:
assert (
param.dtype == self._dtype
), f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`"
if not hasattr(param, "skip_zero_check") or param.skip_zero_check is False:
assert (
param.dtype == self._dtype
), f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`"
def _create_master_param_current_rank(self, param_list):
# split each param evenly by world size

View File

@ -149,7 +149,7 @@ model, optimizer, _criterion, _, lr_scheduler = booster.boost(
## Training GPT-2 using hybrid parallelism
In the previous tutorial, We've explained how to inject various parallelism features into the model and its training components using the Booster and `HybridParallelPlugin`. Now we can start model training.
In the previous tutorial, We've explained how to inject various parallelism features into the model and its training components using the Booster and `HybridParallelPlugin`. Now we can start model training.
Define a training function. When pipeline parallelism is used, you need to call `booster.execute_pipeline` to schedule the stages of model training.
```python
def train_epoch(
@ -204,4 +204,4 @@ Training the gpt-2 model
for epoch in range(NUM_EPOCHS):
train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator)
```
<!-- doc-test-command: torchrun --standalone --nproc_per_node=1 train_gpt_using_hybrid_parallelism.py -->
<!-- doc-test-command: torchrun --standalone --nproc_per_node=1 train_gpt_using_hybrid_parallelism.py -->

View File

@ -201,4 +201,4 @@ def train_epoch(
for epoch in range(NUM_EPOCHS):
train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator)
```
<!-- doc-test-command: torchrun --standalone --nproc_per_node=1 train_gpt_using_hybrid_parallelism.py -->
<!-- doc-test-command: torchrun --standalone --nproc_per_node=1 train_gpt_using_hybrid_parallelism.py -->

View File

@ -220,7 +220,7 @@ model, optimizer, _criterion, train_dataloader, lr_scheduler = booster.boost(
)
```
## 使用混合并行训练 ViT
最后就可以使用混合并行策略来训练模型了。我们先定义一个训练函数,描述训练过程。需要注意的是,如果使用了管道并行策略,需要调用`booster.execute_pipeline`来执行模型的训练,它会调用`scheduler`管理模型的前后向操作。
最后就可以使用混合并行策略来训练模型了。我们先定义一个训练函数,描述训练过程。需要注意的是,如果使用了管道并行策略,需要调用`booster.execute_pipeline`来执行模型的训练,它会调用`scheduler`管理模型的前后向操作。
```python
def run_forward_backward(
model: nn.Module,

View File

@ -1,12 +1,10 @@
import argparse
import logging
import os
import time
import torch
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
from auto_gptq.nn_modules.qlinear import GeneralQuantLinear
from transformers import AutoTokenizer, BloomForCausalLM, BloomTokenizerFast, LlamaForCausalLM, LlamaTokenizer
from auto_gptq import AutoGPTQForCausalLM
from transformers import BloomTokenizerFast
import colossalai
from colossalai.inference.tensor_parallel.engine import TPInferEngine
@ -14,7 +12,7 @@ from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
def print_perf_stats(latency_set, config, bs, warmup=3):
@ -28,7 +26,7 @@ def print_perf_stats(latency_set, config, bs, warmup=3):
avg = sum(latency_set) / count
num_layers = getattr(config, "num_layers", config.num_hidden_layers)
num_parameters = num_layers * config.hidden_size * config.hidden_size * 12
num_bytes = 2 # float16
num_bytes = 2 # float16
print("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000))
print("Avg BW: {0:8.2f} GB/s".format(1 / avg * num_parameters * num_bytes / 1e9))
@ -37,7 +35,6 @@ def print_perf_stats(latency_set, config, bs, warmup=3):
def bench_bloom(args):
pretrained_model_dir = args.path
quantized_model_dir = args.quantized_path
max_batch_size = args.batch_size
@ -48,9 +45,9 @@ def bench_bloom(args):
tokenizer.pad_token = tokenizer.eos_token
# load quantized model to the first GPU
model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir,
device=torch.cuda.current_device(),
inject_fused_attention=False)
model = AutoGPTQForCausalLM.from_quantized(
quantized_model_dir, device=torch.cuda.current_device(), inject_fused_attention=False
)
model = model.half()
@ -60,22 +57,22 @@ def bench_bloom(args):
generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False)
input_tokens = {
"input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device='cuda'),
"attention_mask": torch.ones((max_batch_size, max_input_len), device='cuda')
"input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device="cuda"),
"attention_mask": torch.ones((max_batch_size, max_input_len), device="cuda"),
}
# init TPInferEngine and shard the original model
# To benchmark torch original, comment out the line of optimizing model
shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False,
inference_only=True,
inference_gptq=True)
shard_config = ShardConfig(
enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True, inference_gptq=True
)
infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
# prepare data for generation
generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False)
input_tokens = {
"input_ids": torch.randint(10, 1000, (max_batch_size, max_input_len)),
"attention_mask": torch.ones((max_batch_size, max_input_len))
"attention_mask": torch.ones((max_batch_size, max_input_len)),
}
for t in input_tokens:
if torch.is_tensor(input_tokens[t]):
@ -99,7 +96,7 @@ def bench_bloom(args):
def check_bloom(rank, world_size, port, args):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
bench_bloom(args)
@ -111,12 +108,12 @@ def test_bloom(args):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('-p', '--path', type=str, help='Model path', required=True)
parser.add_argument('-q', '--quantized_path', type=str, help='Model path', required=True)
parser.add_argument('-tp', '--tp_size', type=int, default=1, help='Tensor parallel size')
parser.add_argument('-b', '--batch_size', type=int, default=16, help='Maximum batch size')
parser.add_argument('--input_len', type=int, default=1024, help='Maximum input length')
parser.add_argument('--output_len', type=int, default=128, help='Maximum output length')
parser.add_argument("-p", "--path", type=str, help="Model path", required=True)
parser.add_argument("-q", "--quantized_path", type=str, help="Model path", required=True)
parser.add_argument("-tp", "--tp_size", type=int, default=1, help="Tensor parallel size")
parser.add_argument("-b", "--batch_size", type=int, default=16, help="Maximum batch size")
parser.add_argument("--input_len", type=int, default=1024, help="Maximum input length")
parser.add_argument("--output_len", type=int, default=128, help="Maximum output length")
args = parser.parse_args()

View File

@ -14,4 +14,5 @@ einops
sentencepiece
google
protobuf
peft>=0.7.1
peft>=0.7.1
bitsandbytes>=0.39.0

View File

@ -1,4 +1,4 @@
from typing import Callable, Iterator, List, Tuple, Union, Dict
from typing import Callable, Dict, Iterator, List, Tuple, Union
import torch
import torch.distributed as dist

View File

@ -51,7 +51,6 @@ def run_fn(stage, model_fn, data_gen_fn, output_transform_fn, lora_config=None)
# raise e
@parameterize("stage", [2])
def check_low_level_zero_plugin(stage: int, early_stop: bool = True):
"""check low level zero plugin over model zoo
@ -118,6 +117,7 @@ def check_low_level_zero_lora(stage, model_name, early_stop: bool = True):
print(f"Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n")
assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()])
def run_dist(rank, world_size, port, early_stop: bool = True):
# init dist env
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")

View File

@ -1,10 +1,11 @@
from copy import deepcopy
from typing import Optional
import torch
import torch.distributed as dist
from peft import LoraConfig
from torchvision.models import resnet18
from utils import shared_tempdir
from typing import Optional
from peft import LoraConfig
from copy import deepcopy
import colossalai
from colossalai.booster import Booster
@ -131,12 +132,15 @@ def run_fn(stage, shard, offload, model_fn, data_gen_fn, output_transform_fn, lo
# return repr(e)
raise e
@clear_cache_before_run()
@parameterize("stage", [2])
@parameterize("shard", [True, False])
@parameterize("offload", [False, True])
@parameterize("model_name", ["transformers_llama"])
def check_low_level_zero_lora_checkpointIO(stage: int, shard: bool, offload: bool, model_name: str, early_stop: bool = True):
def check_low_level_zero_lora_checkpointIO(
stage: int, shard: bool, offload: bool, model_name: str, early_stop: bool = True
):
passed_models = []
failed_info = {} # (model_name, error) pair
@ -166,6 +170,7 @@ def check_low_level_zero_lora_checkpointIO(stage: int, shard: bool, offload: boo
print(f"Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n")
assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()])
def run_dist(rank, world_size, port):
colossalai.launch(config=(dict()), rank=rank, world_size=world_size, port=port, host="localhost")
check_low_level_zero_checkpointIO()

View File

@ -1,16 +1,8 @@
import math
import time
import numpy as np
import pytest
import torch
import torch.nn as nn
import transformers
from packaging import version
try:
import triton
import triton.language as tl
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
@ -22,6 +14,7 @@ try:
from exllama_kernels import prepare_buffers, set_tuning_params
from colossalai.inference.quant.gptq import CaiQuantLinear
HAS_AUTO_GPTQ = True
except:
HAS_AUTO_GPTQ = False
@ -32,13 +25,14 @@ import warnings
HAS_GPTQ_CUDA = False
try:
from colossalai.kernel.op_builder.gptq import GPTQBuilder
gptq_cuda = GPTQBuilder().load()
HAS_GPTQ_CUDA = True
except ImportError:
warnings.warn('CUDA gptq is not installed')
warnings.warn("CUDA gptq is not installed")
HAS_GPTQ_CUDA = False
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
max_inner_outer_dim = 1
max_input_len = 1
@ -64,9 +58,9 @@ def init_buffer(cai_linear, use_act_order=False):
max_input_len = 4096
# The temp_state buffer is required to reorder X in the act-order case.
# The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
gptq_temp_state_buffer = torch.zeros((max_input_len, max_inner_outer_dim),
dtype=torch.float16,
device=torch.cuda.current_device())
gptq_temp_state_buffer = torch.zeros(
(max_input_len, max_inner_outer_dim), dtype=torch.float16, device=torch.cuda.current_device()
)
gptq_temp_dq_buffer = torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=torch.cuda.current_device())
gptq_cuda.prepare_buffers(torch.device(torch.cuda.current_device()), gptq_temp_state_buffer, gptq_temp_dq_buffer)
@ -77,10 +71,11 @@ def init_buffer(cai_linear, use_act_order=False):
gptq_cuda.set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON or not HAS_AUTO_GPTQ,
reason="triton requires cuda version to be higher than 11.4 or not install auto-gptq")
@pytest.mark.skipif(
not TRITON_CUDA_SUPPORT or not HAS_TRITON or not HAS_AUTO_GPTQ,
reason="triton requires cuda version to be higher than 11.4 or not install auto-gptq",
)
def test_gptq_linear():
infeature = 1024
outfeature = 1024
group_size = 128
@ -120,7 +115,7 @@ def test_gptq_linear():
max_input_len = 2048
buffers = {
"temp_state": torch.zeros((max_input_len, max_inner_outer_dim), dtype=torch.float16, device=device),
"temp_dq": torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=device)
"temp_dq": torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=device),
}
prepare_buffers(device, buffers["temp_state"], buffers["temp_dq"])
@ -146,5 +141,4 @@ def test_gptq_linear():
if __name__ == "__main__":
test_gptq_linear()

View File

@ -4,6 +4,7 @@ from packaging import version
try:
from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd
HAS_TRITON = True
except ImportError:
HAS_TRITON = False

View File

@ -0,0 +1,106 @@
import copy
import os
from itertools import product
import torch
from peft import LoraConfig
from torch import distributed as dist
from torch.optim import AdamW
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.testing import check_state_dict_equal, clear_cache_before_run, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_checkpoint_io.utils import shared_tempdir
@clear_cache_before_run()
def check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type):
model = model_fn()
lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1)
test_plugins = [TorchDDPPlugin(), LowLevelZeroPlugin()]
test_configs = [
{
"lora_config": lora_config,
"quantize": False,
},
{
"lora_config": lora_config,
"quantize": True,
},
]
for plugin, test_config in product(test_plugins, test_configs):
# checkpoint loaded model
model_save = model_fn()
model_load = copy.deepcopy(model_save)
optimizer = AdamW(model.parameters(), lr=0.001)
criterion = loss_fn
booster = Booster(plugin=plugin)
model_save = booster.enable_lora(model_save, **test_config)
model_save, optimizer, criterion, _, _ = booster.boost(model_save, optimizer, criterion)
with shared_tempdir() as tempdir:
lora_ckpt_path = os.path.join(tempdir, "ckpt")
booster.save_lora_as_pretrained(model_save, lora_ckpt_path)
dist.barrier()
# The Lora checkpoint should be small in size
checkpoint_size_mb = os.path.getsize(os.path.join(lora_ckpt_path, "adapter_model.bin")) / (1024 * 1024)
assert checkpoint_size_mb < 1
model_load = booster.enable_lora(model_load, pretrained_dir=lora_ckpt_path, **test_config)
model_load, _, _, _, _ = booster.boost(model_load)
check_state_dict_equal(model_save.state_dict(), model_load.state_dict())
# test fwd bwd correctness
test_model = model_load
model_copy = copy.deepcopy(model_load)
data = data_gen_fn()
data = {
k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()
}
output = test_model(**data)
output = output_transform_fn(output)
loss = criterion(output)
booster.backward(loss, optimizer)
optimizer.clip_grad_by_norm(1.0)
optimizer.step()
for (n1, p1), (n2, p2) in zip(test_model.named_parameters(), model_copy.named_parameters()):
if "lora_" in n1:
# lora modules require gradients, thus updated
assert p1.requires_grad
assert not torch.testing.assert_close(p1.to(p2.device).to(p2.dtype), p2, atol=5e-3, rtol=5e-3)
else:
if not p1.requires_grad:
torch.testing.assert_close(p1.to(p2.device).to(p2.dtype), p2, atol=5e-3, rtol=5e-3)
def run_lora_test():
sub_model_zoo = model_zoo.get_sub_registry("transformers_llama")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
task_type = None
if name == "transformers_llama_for_casual_lm":
task_type = "CAUSAL_LM"
if name == "transformers_llama_for_sequence_classification":
task_type = "SEQ_CLS"
check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type)
def run_dist(rank, world_size, port):
config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_lora_test()
@rerun_if_address_is_in_use()
def test_torch_ddp_lora():
spawn(run_dist, 2)

View File

@ -1,108 +0,0 @@
import copy
import os
import torch
from peft import LoraConfig
from torch import distributed as dist
from torch.optim import AdamW
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin
from colossalai.testing import (
assert_equal,
assert_not_equal,
check_state_dict_equal,
clear_cache_before_run,
rerun_if_address_is_in_use,
spawn,
)
from tests.kit.model_zoo import model_zoo
from tests.test_checkpoint_io.utils import shared_tempdir
@clear_cache_before_run()
def check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type):
model = model_fn()
lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1)
plugin = TorchDDPPlugin()
booster = Booster(plugin=plugin)
model = booster.enable_lora(model, lora_config=lora_config)
model_copy = copy.deepcopy(model)
optimizer = AdamW(model.parameters(), lr=0.001)
criterion = loss_fn
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
data = data_gen_fn()
data = {k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()}
output = model(**data)
output = output_transform_fn(output)
loss = criterion(output)
booster.backward(loss, optimizer)
optimizer.clip_grad_by_norm(1.0)
optimizer.step()
for (n1, p1), (n2, p2) in zip(model.named_parameters(), model_copy.named_parameters()):
if "lora_" in n1:
# lora modules require gradients, thus updated
assert p1.requires_grad
assert_not_equal(p1.to(p2.device), p2)
else:
if not p1.requires_grad:
assert_equal(p1.to(p2.device), p2)
@clear_cache_before_run()
def check_checkpoint(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type):
plugin = TorchDDPPlugin()
lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1)
model_save = model_fn()
model_load = copy.deepcopy(model_save)
booster = Booster(plugin=plugin)
model_save = booster.enable_lora(model_save, lora_config=lora_config)
model_save, _, _, _, _ = booster.boost(model_save)
with shared_tempdir() as tempdir:
lora_ckpt_path = os.path.join(tempdir, "ckpt")
booster.save_lora_as_pretrained(model_save, lora_ckpt_path)
dist.barrier()
# The Lora checkpoint should be small in size
checkpoint_size_mb = os.path.getsize(os.path.join(lora_ckpt_path, "adapter_model.bin")) / (1024 * 1024)
assert checkpoint_size_mb < 1
model_load = booster.enable_lora(model_load, pretrained_dir=lora_ckpt_path)
model_load, _, _, _, _ = booster.boost(model_load)
check_state_dict_equal(model_save.state_dict(), model_load.state_dict())
def run_lora_test():
sub_model_zoo = model_zoo.get_sub_registry("transformers_llama")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
task_type = None
if name == "transformers_llama_for_casual_lm":
task_type = "CAUSAL_LM"
if name == "transformers_llama_for_sequence_classification":
task_type = "SEQ_CLS"
check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type)
check_checkpoint(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type)
def run_dist(rank, world_size, port):
config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_lora_test()
@rerun_if_address_is_in_use()
def test_torch_ddp_lora():
spawn(run_dist, 2)