diff --git a/LICENSE b/LICENSE index b3eb43520..e32e96773 100644 --- a/LICENSE +++ b/LICENSE @@ -527,3 +527,20 @@ Copyright 2021- HPC-AI Technology Inc. All rights reserved. LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + + + ---------------- LICENSE FOR Hugging Face accelerate ---------------- + + Copyright 2021 The HuggingFace Team + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py b/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py index 0d0e2a7d3..47ad29f46 100644 --- a/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py +++ b/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py @@ -76,9 +76,11 @@ def main(args): if args.strategy == "ddp": strategy = DDPStrategy() elif args.strategy == "colossalai_gemini": - strategy = GeminiStrategy(placement_policy="static",initial_scale=2**5) + strategy = GeminiStrategy(placement_policy="static", initial_scale=2**5) elif args.strategy == "colossalai_gemini_cpu": - strategy = GeminiStrategy(placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5) + strategy = GeminiStrategy( + placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5 + ) elif args.strategy == "colossalai_zero2": strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda") elif args.strategy == "colossalai_zero2_cpu": diff --git a/applications/Chat/coati/models/base/actor.py b/applications/Chat/coati/models/base/actor.py index 8b2b81ed0..0634631df 100644 --- a/applications/Chat/coati/models/base/actor.py +++ b/applications/Chat/coati/models/base/actor.py @@ -30,4 +30,3 @@ class Actor(LoRAModule): """Returns model output.""" output = self.model(input_ids, attention_mask=attention_mask, **model_kwargs) return output - diff --git a/applications/Chat/coati/ray/utils.py b/applications/Chat/coati/ray/utils.py index b88140c0e..4882f00b7 100644 --- a/applications/Chat/coati/ray/utils.py +++ b/applications/Chat/coati/ray/utils.py @@ -75,7 +75,9 @@ def get_strategy_from_args(strategy: str): elif strategy == "colossalai_zero2": strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cuda") elif strategy == "colossalai_gemini_cpu": - strategy_ = GeminiStrategy(placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5) + strategy_ = GeminiStrategy( + placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5 + ) elif strategy == "colossalai_zero2_cpu": strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cpu") else: diff --git a/applications/Chat/coati/trainer/strategies/ddp.py b/applications/Chat/coati/trainer/strategies/ddp.py index f2a44aeb0..3d3d44ca5 100644 --- a/applications/Chat/coati/trainer/strategies/ddp.py +++ b/applications/Chat/coati/trainer/strategies/ddp.py @@ -101,16 +101,17 @@ class DDPStrategy(Strategy): model_path = os.path.join(path, "pytorch_model.bin") self.save_model(model, model_path, shard=shard) + def _replace_keys(model_path: str, replace_fn: Callable): state_dict = torch.load(model_path, map_location="cpu") state_dict = {replace_fn(k): v for k, v in state_dict.items()} torch.save(state_dict, model_path) + # FIXME: save_model would add "model." prefix to keys of pytorch_model.bin # HACK: rename keys of pytorch_model.bin if dist.get_rank() == 0: _replace_keys(model_path, lambda k: k.replace("model.", "", 1)) - def get_model_state_dict_shard(self, model: nn.Module, **config): # TODO: implement sharding on naive strategy model = self.unwrap_model(model) diff --git a/applications/Chat/examples/community/peft/train_peft_prompts.py b/applications/Chat/examples/community/peft/train_peft_prompts.py index 99a024f14..0b174297a 100644 --- a/applications/Chat/examples/community/peft/train_peft_prompts.py +++ b/applications/Chat/examples/community/peft/train_peft_prompts.py @@ -24,7 +24,9 @@ def main(args): if args.strategy == "ddp": strategy = DDPStrategy() elif args.strategy == "colossalai_gemini": - strategy = GeminiStrategy(placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5) + strategy = GeminiStrategy( + placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5 + ) elif args.strategy == "colossalai_zero2": strategy = LowLevelZeroStrategy(stage=2, placement_policy="cpu") else: diff --git a/applications/Colossal-LLaMA-2/README.md b/applications/Colossal-LLaMA-2/README.md index ae2e0c6bb..ac7593d98 100644 --- a/applications/Colossal-LLaMA-2/README.md +++ b/applications/Colossal-LLaMA-2/README.md @@ -130,8 +130,8 @@ from modelscope import AutoModelForCausalLM, AutoTokenizer, snapshot_download model_dir = snapshot_download('colossalai/Colossal-LLaMA-2-7b-base', revision='v1.0.1') tokenizer = AutoTokenizer.from_pretrained(model_dir, device_map="auto", trust_remote_code=True) model = AutoModelForCausalLM.from_pretrained(model_dir, device_map="auto", trust_remote_code=True).eval() -generation_kwargs = {"max_new_tokens": 256, - "top_p": 0.95, +generation_kwargs = {"max_new_tokens": 256, + "top_p": 0.95, "temperature": 0.3 } input = '离离原上草,' diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/dataset/loader.py b/applications/Colossal-LLaMA-2/colossal_llama2/dataset/loader.py index a2cfb2ef6..bdf346203 100644 --- a/applications/Colossal-LLaMA-2/colossal_llama2/dataset/loader.py +++ b/applications/Colossal-LLaMA-2/colossal_llama2/dataset/loader.py @@ -1,20 +1,20 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -import numpy as np import os import random from dataclasses import dataclass -from typing import Dict, List, Union, Sequence, Optional, Iterator, Callable +from typing import Callable, Dict, Iterator, List, Optional, Sequence, Union +import numpy as np import torch -from datasets import dataset_dict, load_from_disk +import torch.nn.functional as F from datasets import Dataset as HFDataset +from datasets import dataset_dict, load_from_disk from torch.distributed import ProcessGroup from torch.distributed.distributed_c10d import _get_default_group -from torch.utils.data import ConcatDataset, Dataset, DataLoader, DistributedSampler +from torch.utils.data import ConcatDataset, DataLoader, Dataset, DistributedSampler from transformers.tokenization_utils import PreTrainedTokenizer -import torch.nn.functional as F DatasetType = Union[Dataset, ConcatDataset, dataset_dict.Dataset] PathType = Union[str, os.PathLike] diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/dataset/spliced_and_tokenized_dataset.py b/applications/Colossal-LLaMA-2/colossal_llama2/dataset/spliced_and_tokenized_dataset.py index 0c21f325a..99f97aeed 100644 --- a/applications/Colossal-LLaMA-2/colossal_llama2/dataset/spliced_and_tokenized_dataset.py +++ b/applications/Colossal-LLaMA-2/colossal_llama2/dataset/spliced_and_tokenized_dataset.py @@ -7,9 +7,9 @@ Splicing multiple pre-tokenized sequence data points import random import warnings from copy import deepcopy -from datasets import dataset_dict -from typing import Any, Callable, Dict, Iterable, List, Union, Tuple +from typing import Any, Callable, Dict, Iterable, List, Tuple, Union +from datasets import dataset_dict from torch.utils.data import ConcatDataset, Dataset, IterableDataset from transformers.models.llama.tokenization_llama import LlamaTokenizer from transformers.tokenization_utils import PreTrainedTokenizer @@ -169,12 +169,7 @@ class ClosedToConstantLengthSplicedDataset(IterableDataset): spliced_labels.extend(seq_labels) # For residual spliced data point at the end of the data set if self.infinite is False and more_data_points is False and len(spliced_input_ids) > 0: - examples.append( - { - self.input_ids_field: spliced_input_ids, - self.labels_field: spliced_labels - } - ) + examples.append({self.input_ids_field: spliced_input_ids, self.labels_field: spliced_labels}) if self.shuffle: random.shuffle(examples) for spliced_data_point in examples: diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/model/init_model.py b/applications/Colossal-LLaMA-2/colossal_llama2/model/init_model.py index 67e487f43..f61291f35 100644 --- a/applications/Colossal-LLaMA-2/colossal_llama2/model/init_model.py +++ b/applications/Colossal-LLaMA-2/colossal_llama2/model/init_model.py @@ -8,11 +8,10 @@ import argparse import numpy as np import torch -from transformers import LlamaTokenizer, LlamaForCausalLM +from transformers import LlamaForCausalLM, LlamaTokenizer from colossalai.logging import get_dist_logger - logger = get_dist_logger() diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/tokenizer/init_tokenizer.py b/applications/Colossal-LLaMA-2/colossal_llama2/tokenizer/init_tokenizer.py index 43297633d..439135503 100644 --- a/applications/Colossal-LLaMA-2/colossal_llama2/tokenizer/init_tokenizer.py +++ b/applications/Colossal-LLaMA-2/colossal_llama2/tokenizer/init_tokenizer.py @@ -6,12 +6,12 @@ Initialize new tokenizer for continual pre-training """ import argparse -import os import json +import os from typing import List, Union -from transformers.models.llama.tokenization_llama import LlamaTokenizer from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model +from transformers.models.llama.tokenization_llama import LlamaTokenizer from colossalai.logging import get_dist_logger diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/utils/ckpt_io.py b/applications/Colossal-LLaMA-2/colossal_llama2/utils/ckpt_io.py index 85decf37d..05342ce41 100644 --- a/applications/Colossal-LLaMA-2/colossal_llama2/utils/ckpt_io.py +++ b/applications/Colossal-LLaMA-2/colossal_llama2/utils/ckpt_io.py @@ -10,8 +10,8 @@ import os from typing import Any, Dict, Tuple, Union import torch -from torch.optim.optimizer import Optimizer from torch.optim.lr_scheduler import _LRScheduler +from torch.optim.optimizer import Optimizer from colossalai.booster import Booster from colossalai.cluster import DistCoordinator diff --git a/applications/Colossal-LLaMA-2/docs/example.md b/applications/Colossal-LLaMA-2/docs/example.md index d889ab416..d833d2821 100644 --- a/applications/Colossal-LLaMA-2/docs/example.md +++ b/applications/Colossal-LLaMA-2/docs/example.md @@ -242,4 +242,4 @@ To comprehensively assess the performance of the Colossal-LLaMA-2-7B-base model, ## Conclusion In general, the Colossal-LLaMA-2-7B-base model not only enhances its understanding of English but also exhibits significant improvements in its comprehension of Chinese. It boasts a broad spectrum of general knowledge, encompassing various fields such as food, sports, technology, literature, games, and more. Regarding text generation tasks, the Colossal-LLaMA-2-7B-base model excels in writing performance; however, its ability to generate specific formats like code, emails, tables, etc., needs enhancement due to the scarcity of relevant training data during our training phase. When compared to the Qwen-7b-base model, the Colossal-LLaMA-2-7B-base model outperforms it in answering most English questions and some Chinese questions, as demonstrated in the examples above. -Presently, the Colossal-LLaMA-2-7B-base model already exhibits some capabilities in sentiment analysis, logical reasoning, information extraction, role-play, classification, and rewriting. These capabilities are poised for further improvement in the future as part of our ongoing enhancements. \ No newline at end of file +Presently, the Colossal-LLaMA-2-7B-base model already exhibits some capabilities in sentiment analysis, logical reasoning, information extraction, role-play, classification, and rewriting. These capabilities are poised for further improvement in the future as part of our ongoing enhancements. diff --git a/applications/Colossal-LLaMA-2/hostfile.example b/applications/Colossal-LLaMA-2/hostfile.example index 82948648c..cfaaa0ef5 100644 --- a/applications/Colossal-LLaMA-2/hostfile.example +++ b/applications/Colossal-LLaMA-2/hostfile.example @@ -1,2 +1,2 @@ hostname1 -hostname2 \ No newline at end of file +hostname2 diff --git a/applications/Colossal-LLaMA-2/prepare_pretrain_dataset.py b/applications/Colossal-LLaMA-2/prepare_pretrain_dataset.py index a519232f6..69bc8b5dd 100644 --- a/applications/Colossal-LLaMA-2/prepare_pretrain_dataset.py +++ b/applications/Colossal-LLaMA-2/prepare_pretrain_dataset.py @@ -11,14 +11,14 @@ import os import time from multiprocessing import cpu_count +from colossal_llama2.dataset.spliced_and_tokenized_dataset import ( + ClosedToConstantLengthSplicedDataset, + supervised_tokenize, +) from datasets import dataset_dict, load_dataset from transformers.models.llama.tokenization_llama import LlamaTokenizer from colossalai.logging import get_dist_logger -from colossal_llama2.dataset.spliced_and_tokenized_dataset import ( - supervised_tokenize, - ClosedToConstantLengthSplicedDataset, -) logger = get_dist_logger() @@ -149,5 +149,5 @@ def main(): spliced_dataset.save_to_disk(dataset_path=output_arrow_path, num_proc=min(len(spliced_dataset), cpu_count())) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/applications/Colossal-LLaMA-2/requirements.txt b/applications/Colossal-LLaMA-2/requirements.txt index d8afee768..b75eba994 100644 --- a/applications/Colossal-LLaMA-2/requirements.txt +++ b/applications/Colossal-LLaMA-2/requirements.txt @@ -12,4 +12,3 @@ flash-attn>=2.0.0,<=2.0.5 tqdm sentencepiece==0.1.99 protobuf<=3.20.0 - diff --git a/applications/Colossal-LLaMA-2/train.py b/applications/Colossal-LLaMA-2/train.py index 41b4ef031..8b2e6898d 100644 --- a/applications/Colossal-LLaMA-2/train.py +++ b/applications/Colossal-LLaMA-2/train.py @@ -1,45 +1,39 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- """ -Continual Pre-training of LLaMA-2 developed by Colossal-AI Team +Continual Pre-training of LLaMA-2 developed by Colossal-AI Team """ -import json import argparse +import json import os import resource from contextlib import nullcontext -from tqdm import tqdm import torch import torch.distributed as dist +from colossal_llama2.dataset.loader import ( + DataCollatorForSupervisedDataset, + StatefulDistributedSampler, + load_tokenized_dataset, + setup_distributed_dataloader, +) +from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint +from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention +from colossal_llama2.utils.froze import freeze_non_embeds_parameters from torch.utils.tensorboard import SummaryWriter -from transformers import LlamaTokenizer, LlamaForCausalLM, LlamaConfig +from tqdm import tqdm +from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer import colossalai from colossalai.booster import Booster -from colossalai.booster.plugin import ( - GeminiPlugin, - LowLevelZeroPlugin, - HybridParallelPlugin, -) +from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin from colossalai.cluster import DistCoordinator from colossalai.lazy import LazyInitContext from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR from colossalai.nn.optimizer import HybridAdam from colossalai.utils import get_current_device -from colossal_llama2.dataset.loader import ( - load_tokenized_dataset, - setup_distributed_dataloader, - DataCollatorForSupervisedDataset, - StatefulDistributedSampler, -) - -from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention -from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint -from colossal_llama2.utils.froze import freeze_non_embeds_parameters - def get_model_numel(model: torch.nn.Module) -> int: return sum(p.numel() for p in model.parameters()) @@ -372,9 +366,7 @@ def main() -> None: # Final save. coordinator.print_on_master("Start saving final model checkpoint") booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True) - coordinator.print_on_master( - f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}" - ) + coordinator.print_on_master(f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}") coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") diff --git a/applications/Colossal-LLaMA-2/version.txt b/applications/Colossal-LLaMA-2/version.txt index 8a9ecc2ea..8acdd82b7 100644 --- a/applications/Colossal-LLaMA-2/version.txt +++ b/applications/Colossal-LLaMA-2/version.txt @@ -1 +1 @@ -0.0.1 \ No newline at end of file +0.0.1 diff --git a/colossalai/booster/booster.py b/colossalai/booster/booster.py index c2a724084..56d8a0935 100644 --- a/colossalai/booster/booster.py +++ b/colossalai/booster/booster.py @@ -19,6 +19,7 @@ except ImportError: import colossalai.interface.pretrained as pretrained_utils from colossalai.checkpoint_io import GeneralCheckpointIO from colossalai.interface import ModelWrapper, OptimizerWrapper +from colossalai.quantization import BnbQuantizationConfig from .accelerator import Accelerator from .mixed_precision import MixedPrecision, mixed_precision_factory @@ -230,7 +231,12 @@ class Booster: return self.plugin.no_sync(model, optimizer) def enable_lora( - self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: "peft.LoraConfig" = None + self, + model: nn.Module, + pretrained_dir: Optional[str] = None, + lora_config: "peft.LoraConfig" = None, + bnb_quantization_config: Optional[BnbQuantizationConfig] = None, + quantize=False, ) -> nn.Module: """ Wrap the passed in model with LoRA modules for training. If pretrained directory is provided, lora configs and weights are loaded from that directory. @@ -259,7 +265,20 @@ class Booster: assert ( pretrained_dir is not None ), "Please provide pretrained directory path if not passing in lora configuration." - return self.plugin.enable_lora(model, pretrained_dir, lora_config) + if quantize is True: + if bnb_quantization_config is not None: + warnings.warn( + "User defined BnbQuantizationConfig is not fully tested in ColossalAI. Use it at your own risk." + ) + else: + bnb_quantization_config = BnbQuantizationConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=torch.bfloat16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + ) + + return self.plugin.enable_lora(model, pretrained_dir, lora_config, bnb_quantization_config) def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True) -> None: """Load model from checkpoint. diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index b2087af68..795dc3973 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -1,11 +1,11 @@ -import logging -import warnings import enum +import logging import os +import warnings from functools import partial from pathlib import Path from types import MethodType -from typing import Callable, Dict, Iterator, List, Optional, Tuple, Dict +from typing import Callable, Dict, Iterator, List, Optional, Tuple import torch import torch.nn as nn @@ -27,6 +27,7 @@ from colossalai.checkpoint_io.utils import ( sharded_optimizer_loading_epilogue, ) from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper +from colossalai.quantization import BnbQuantizationConfig, quantize_model from colossalai.utils import get_current_device from colossalai.zero import LowLevelZeroOptimizer @@ -44,6 +45,7 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16): SUPPORTED_PRECISION = ["fp16", "bf16", "fp32"] + class OptimizerParamCheckState(enum.Enum): ORIGIN_PARAM_FINDED = 0 ORIGIN_PARAM_NOT_FIND = -1 @@ -221,6 +223,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO): logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") return from peft import PeftModel + assert isinstance(model, ModelWrapper), "Please boost the model before saving!" peft_model = model.unwrap() assert isinstance( @@ -331,38 +334,45 @@ class LowLevelZeroPlugin(DPPluginBase): def supported_devices(self) -> List[str]: return ["cuda"] - def support_lora(self) -> bool: return True def enable_lora( - self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None + self, + model: nn.Module, + pretrained_dir: Optional[str] = None, + lora_config: Optional[Dict] = None, + bnb_quantization_config: Optional[BnbQuantizationConfig] = None, ) -> nn.Module: from peft import PeftModel, get_peft_model + assert not isinstance(model, LowLevelZeroModel), "Lora should be enabled before boosting the model." self.lora_enabled = True warnings.warn("You have enabled LoRa training. Please check the hyperparameters such as lr") + if bnb_quantization_config is not None: + model = quantize_model(model, bnb_quantization_config) + if pretrained_dir is None: peft_model = get_peft_model(model, lora_config) else: peft_model = PeftModel.from_pretrained(model, pretrained_dir, is_trainable=True) return peft_model - + def get_param_group_id(self, optimizer: Optimizer, origin_param: Parameter): origin_param_id = id(origin_param) for group_id, param_group in enumerate(optimizer.param_groups): - for p in param_group['params']: + for p in param_group["params"]: if id(p) == origin_param_id: return group_id return -1 - + def get_param_group_id(self, optimizer: Optimizer, origin_param: Parameter, lora_param: Parameter): origin_param_id = id(origin_param) lora_param_id = id(lora_param) target_group_id = None for group_id, param_group in enumerate(optimizer.param_groups): - for p in param_group['params']: + for p in param_group["params"]: if id(p) == lora_param_id: # check if the lora parameter exists. return target_group_id, OptimizerParamCheckState.LORA_PARM_EXISTED @@ -372,25 +382,31 @@ class LowLevelZeroPlugin(DPPluginBase): return target_group_id, OptimizerParamCheckState.ORIGIN_PARAM_FINDED else: return target_group_id, OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND - + def add_lora_params_to_optimizer(self, model, optimizer): - """ add lora parameters to optimizer """ - name2param= {} + """add lora parameters to optimizer""" + name2param = {} for name, param in model.named_parameters(): name2param[name] = param for name, param in name2param.items(): - if 'lora_A' in name or 'lora_B' in name: + if "lora_A" in name or "lora_B" in name: origin_key = name.replace("lora_A.", "") origin_key = origin_key.replace("lora_B.", "") origin_key = origin_key.replace(f"{model.active_adapter}", "base_layer") origin_param = name2param[origin_key] group_id, check_state = self.get_param_group_id(optimizer, origin_param, param) if check_state == OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND: - warnings.warn("Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups.") - elif check_state == OptimizerParamCheckState.ORIGIN_PARAM_FINDED and group_id is not None and group_id >= 0: - optimizer.param_groups[group_id]['params'].append(param) - + warnings.warn( + "Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups." + ) + elif ( + check_state == OptimizerParamCheckState.ORIGIN_PARAM_FINDED + and group_id is not None + and group_id >= 0 + ): + optimizer.param_groups[group_id]["params"].append(param) + def configure( self, model: nn.Module, @@ -401,11 +417,13 @@ class LowLevelZeroPlugin(DPPluginBase): ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: if self.lora_enabled: from peft import PeftModel - assert isinstance(model, PeftModel), "The model should have been wrapped as a PeftModel when self.lora_enabled is True" + + assert isinstance( + model, PeftModel + ), "The model should have been wrapped as a PeftModel when self.lora_enabled is True" if optimizer is not None: self.add_lora_params_to_optimizer(model, optimizer) - if not isinstance(model, ModelWrapper): model = LowLevelZeroModel(model, self.precision) diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py index 9ba520de2..482cc4e98 100644 --- a/colossalai/booster/plugin/torch_ddp_plugin.py +++ b/colossalai/booster/plugin/torch_ddp_plugin.py @@ -9,6 +9,7 @@ from torch.utils.data import DataLoader from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO from colossalai.cluster import DistCoordinator from colossalai.interface import ModelWrapper, OptimizerWrapper +from colossalai.quantization import BnbQuantizationConfig, quantize_model from .dp_plugin_base import DPPluginBase @@ -237,10 +238,17 @@ class TorchDDPPlugin(DPPluginBase): return model.module.no_sync() def enable_lora( - self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None + self, + model: nn.Module, + pretrained_dir: Optional[str] = None, + lora_config: Optional[Dict] = None, + bnb_quantization_config: Optional[BnbQuantizationConfig] = None, ) -> nn.Module: from peft import PeftModel, get_peft_model + if bnb_quantization_config is not None: + model = quantize_model(model, bnb_quantization_config) + assert not isinstance(model, TorchDDPModel), "Lora should be enabled before boosting the model." if pretrained_dir is None: return get_peft_model(model, lora_config) diff --git a/colossalai/inference/README.md b/colossalai/inference/README.md index d0c281e05..9cddc8318 100644 --- a/colossalai/inference/README.md +++ b/colossalai/inference/README.md @@ -64,7 +64,7 @@ vllm flash-attention # install lightllm since we depend on lightllm triton kernels -git clone https://github.com/ModelTC/lightllm +git clone https://github.com/ModelTC/lightllm git checkout 28c1267cfca536b7b4f28e921e03de735b003039 cd lightllm pip3 install -e . @@ -84,7 +84,7 @@ cd /path/to/CollossalAI pip install -e . # install lightllm -git clone https://github.com/ModelTC/lightllm +git clone https://github.com/ModelTC/lightllm git checkout 28c1267cfca536b7b4f28e921e03de735b003039 cd lightllm pip3 install -e . diff --git a/colossalai/inference/quant/gptq/cai_gptq/cai_quant_linear.py b/colossalai/inference/quant/gptq/cai_gptq/cai_quant_linear.py index ca12c34ed..36339ac88 100644 --- a/colossalai/inference/quant/gptq/cai_gptq/cai_quant_linear.py +++ b/colossalai/inference/quant/gptq/cai_gptq/cai_quant_linear.py @@ -18,15 +18,15 @@ from .gptq_op import CaiGPTQLinearOp HAS_GPTQ_CUDA = False try: from colossalai.kernel.op_builder.gptq import GPTQBuilder + gptq_cuda = GPTQBuilder().load() HAS_GPTQ_CUDA = True except ImportError: - warnings.warn('CUDA gptq is not installed') + warnings.warn("CUDA gptq is not installed") HAS_GPTQ_CUDA = False class CaiQuantLinear(nn.Module): - def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False): super().__init__() if bits not in [2, 4, 8]: @@ -37,23 +37,28 @@ class CaiQuantLinear(nn.Module): self.maxq = 2**self.bits - 1 self.groupsize = groupsize if groupsize != -1 else infeatures - self.register_buffer('qweight', torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32)) + self.register_buffer("qweight", torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32)) self.register_buffer( - 'qzeros', - torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32)) - self.register_buffer('scales', - torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16)) + "qzeros", + torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32), + ) + self.register_buffer( + "scales", torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16) + ) if row_split: self.register_buffer( - 'g_idx', - torch.tensor([(i + (tp_rank * self.infeatures)) // self.groupsize for i in range(infeatures)], - dtype=torch.int32)) + "g_idx", + torch.tensor( + [(i + (tp_rank * self.infeatures)) // self.groupsize for i in range(infeatures)], dtype=torch.int32 + ), + ) else: - self.register_buffer('g_idx', - torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32)) + self.register_buffer( + "g_idx", torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32) + ) if bias: - self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16)) + self.register_buffer("bias", torch.zeros((outfeatures), dtype=torch.float16)) else: self.bias = None @@ -66,9 +71,11 @@ class CaiQuantLinear(nn.Module): self.row_split = row_split def pack(self, linear, scales, zeros, g_idx=None): - - g_idx = g_idx.clone() if g_idx is not None else torch.tensor( - [i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32) + g_idx = ( + g_idx.clone() + if g_idx is not None + else torch.tensor([i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32) + ) scales = scales.t().contiguous() zeros = zeros.t().contiguous() @@ -79,7 +86,6 @@ class CaiQuantLinear(nn.Module): if linear.bias is not None: self.bias = linear.bias.clone().half() - wn = 8 pbits = 32 ptype = torch.int32 unsign_type = np.uint32 @@ -88,9 +94,10 @@ class CaiQuantLinear(nn.Module): intweight = [] for idx in range(self.infeatures): intweight.append( - torch.round( - (linear.weight.data[:, idx] + scale_zeros[g_idx[idx]]) / half_scales[g_idx[idx]]).to(ptype)[:, - None]) + torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx[idx]]) / half_scales[g_idx[idx]]).to(ptype)[ + :, None + ] + ) intweight = torch.cat(intweight, dim=1) intweight = intweight.t().contiguous() intweight = intweight.numpy().astype(unsign_type) @@ -109,7 +116,7 @@ class CaiQuantLinear(nn.Module): raise NotImplementedError("Only 2,4,8 bits are supported.") qweight = qweight.astype(sign_type) qweight1 = torch.from_numpy(qweight) - qweight1 = qweight1.contiguous() #.to("cuda") + qweight1 = qweight1.contiguous() # .to("cuda") self.qweight.data.copy_(qweight1) qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // pbits * self.bits), dtype=unsign_type) @@ -140,17 +147,20 @@ class CaiQuantLinear(nn.Module): self.q4_width = self.qweight.shape[1] if self.g_idx is not None: if self.row_split and torch.equal( - self.g_idx, - torch.tensor( - [(i + (self.tp_rank * self.infeatures)) // self.groupsize for i in range(self.infeatures)], - dtype=torch.int32, - device=self.g_idx.device)): + self.g_idx, + torch.tensor( + [(i + (self.tp_rank * self.infeatures)) // self.groupsize for i in range(self.infeatures)], + dtype=torch.int32, + device=self.g_idx.device, + ), + ): self.g_idx = None elif torch.equal( - self.g_idx, - torch.tensor([i // self.groupsize for i in range(self.infeatures)], - dtype=torch.int32, - device=self.g_idx.device)): + self.g_idx, + torch.tensor( + [i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32, device=self.g_idx.device + ), + ): self.g_idx = None if self.g_idx is not None: @@ -165,7 +175,6 @@ class CaiQuantLinear(nn.Module): outshape = x.shape[:-1] + (self.outfeatures,) if HAS_GPTQ_CUDA and self.bits == 4: - if self.q4 is None: self.init_q4() @@ -191,7 +200,6 @@ class CaiQuantLinear(nn.Module): def split_column_copy(gptq_linear, cai_linear, tp_size=1, tp_rank=0, split_num=1): - qweights = gptq_linear.qweight.split(gptq_linear.out_features // split_num, dim=-1) qzeros = gptq_linear.qzeros.split(gptq_linear.out_features // (32 // cai_linear.bits) // split_num, dim=-1) scales = gptq_linear.scales.split(gptq_linear.out_features // split_num, dim=-1) @@ -203,24 +211,24 @@ def split_column_copy(gptq_linear, cai_linear, tp_size=1, tp_rank=0, split_num=1 zero_split_block = cai_linear.outfeatures // (32 // cai_linear.bits) // split_num for i in range(split_num): - cai_linear.qweight[:, i * cai_split_out_features:(i + 1) * - cai_split_out_features] = qweights[i][:, tp_rank * cai_split_out_features:(tp_rank + 1) * - cai_split_out_features] - cai_linear.qzeros[:, i * zero_split_block:(i + 1) * - zero_split_block] = qzeros[i][:, tp_rank * zero_split_block:(tp_rank + 1) * zero_split_block] - cai_linear.scales[:, i * cai_split_out_features:(i + 1) * - cai_split_out_features] = scales[i][:, tp_rank * cai_split_out_features:(tp_rank + 1) * - cai_split_out_features] + cai_linear.qweight[:, i * cai_split_out_features : (i + 1) * cai_split_out_features] = qweights[i][ + :, tp_rank * cai_split_out_features : (tp_rank + 1) * cai_split_out_features + ] + cai_linear.qzeros[:, i * zero_split_block : (i + 1) * zero_split_block] = qzeros[i][ + :, tp_rank * zero_split_block : (tp_rank + 1) * zero_split_block + ] + cai_linear.scales[:, i * cai_split_out_features : (i + 1) * cai_split_out_features] = scales[i][ + :, tp_rank * cai_split_out_features : (tp_rank + 1) * cai_split_out_features + ] if cai_linear.bias is not None: - cai_linear.bias[i * cai_split_out_features:(i + 1) * - cai_split_out_features] = bias[i][tp_rank * cai_split_out_features:(tp_rank + 1) * - cai_split_out_features] + cai_linear.bias[i * cai_split_out_features : (i + 1) * cai_split_out_features] = bias[i][ + tp_rank * cai_split_out_features : (tp_rank + 1) * cai_split_out_features + ] cai_linear.g_idx.copy_(g_idx) def split_row_copy(gptq_linear, cai_linear, tp_rank=0, split_num=1): - qweights = gptq_linear.qweight.split(gptq_linear.in_features // split_num, dim=0) qzeros = gptq_linear.qzeros.split(gptq_linear.in_features // split_num, dim=0) scales = gptq_linear.scales.split(gptq_linear.in_features // split_num, dim=0) @@ -231,47 +239,40 @@ def split_row_copy(gptq_linear, cai_linear, tp_rank=0, split_num=1): idx_split_features = cai_linear.infeatures // split_num for i in range(split_num): - cai_linear.qweight[i * cai_split_in_features:(i + 1) * - cai_split_in_features, :] = qweights[i][tp_rank * cai_split_in_features:(tp_rank + 1) * - cai_split_in_features, :] - cai_linear.qzeros[i * zero_split_block:(i + 1) * - zero_split_block, :] = qzeros[i][tp_rank * zero_split_block:(tp_rank + 1) * - zero_split_block, :] - cai_linear.scales[i * zero_split_block:(i + 1) * - zero_split_block, :] = scales[i][tp_rank * zero_split_block:(tp_rank + 1) * - zero_split_block, :] - cai_linear.g_idx[i * idx_split_features:(i + 1) * - idx_split_features] = g_idxs[i][tp_rank * idx_split_features:(tp_rank + 1) * - idx_split_features] + cai_linear.qweight[i * cai_split_in_features : (i + 1) * cai_split_in_features, :] = qweights[i][ + tp_rank * cai_split_in_features : (tp_rank + 1) * cai_split_in_features, : + ] + cai_linear.qzeros[i * zero_split_block : (i + 1) * zero_split_block, :] = qzeros[i][ + tp_rank * zero_split_block : (tp_rank + 1) * zero_split_block, : + ] + cai_linear.scales[i * zero_split_block : (i + 1) * zero_split_block, :] = scales[i][ + tp_rank * zero_split_block : (tp_rank + 1) * zero_split_block, : + ] + cai_linear.g_idx[i * idx_split_features : (i + 1) * idx_split_features] = g_idxs[i][ + tp_rank * idx_split_features : (tp_rank + 1) * idx_split_features + ] if cai_linear.bias is not None: cai_linear.bias.copy_(gptq_linear.bias) class RowCaiQuantLinear(CaiQuantLinear, ParallelModule): - def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False): - - super().__init__(bits, - groupsize, - infeatures, - outfeatures, - bias, - tp_size=tp_size, - tp_rank=tp_rank, - row_split=row_split) + super().__init__( + bits, groupsize, infeatures, outfeatures, bias, tp_size=tp_size, tp_rank=tp_rank, row_split=row_split + ) self.process_group = None @staticmethod - def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, - **kwargs) -> ParallelModule: + def from_native_module( + module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs + ) -> ParallelModule: LazyInitContext.materialize(module) # get the attributes in_features = module.in_features # ensure only one process group is passed if isinstance(process_group, (list, tuple)): - assert len(process_group) == 1, \ - f'Expected only one process group, got {len(process_group)}.' + assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}." process_group = process_group[0] tp_size = dist.get_world_size(process_group) @@ -282,15 +283,18 @@ class RowCaiQuantLinear(CaiQuantLinear, ParallelModule): if in_features % tp_size != 0: raise ValueError( - f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!") - linear_1d = RowCaiQuantLinear(module.bits, - module.group_size, - module.in_features // tp_size, - module.out_features, - module.bias is not None, - tp_size=tp_size, - tp_rank=tp_rank, - row_split=True) + f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!" + ) + linear_1d = RowCaiQuantLinear( + module.bits, + module.group_size, + module.in_features // tp_size, + module.out_features, + module.bias is not None, + tp_size=tp_size, + tp_rank=tp_rank, + row_split=True, + ) linear_1d.process_group = process_group split_row_copy(module, linear_1d, tp_rank=tp_rank, **kwargs) @@ -306,30 +310,23 @@ class RowCaiQuantLinear(CaiQuantLinear, ParallelModule): class ColCaiQuantLinear(CaiQuantLinear, ParallelModule): - def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False): - - super().__init__(bits, - groupsize, - infeatures, - outfeatures, - bias, - tp_size=tp_size, - tp_rank=tp_rank, - row_split=row_split) + super().__init__( + bits, groupsize, infeatures, outfeatures, bias, tp_size=tp_size, tp_rank=tp_rank, row_split=row_split + ) self.process_group = None @staticmethod - def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, - **kwargs) -> ParallelModule: + def from_native_module( + module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs + ) -> ParallelModule: LazyInitContext.materialize(module) # get the attributes in_features = module.in_features # ensure only one process group is passed if isinstance(process_group, (list, tuple)): - assert len(process_group) == 1, \ - f'Expected only one process group, got {len(process_group)}.' + assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}." process_group = process_group[0] tp_size = dist.get_world_size(process_group) @@ -340,14 +337,17 @@ class ColCaiQuantLinear(CaiQuantLinear, ParallelModule): if in_features % tp_size != 0: raise ValueError( - f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!") - linear_1d = ColCaiQuantLinear(module.bits, - module.group_size, - module.in_features, - module.out_features // tp_size, - module.bias is not None, - tp_size=tp_size, - tp_rank=tp_rank) + f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!" + ) + linear_1d = ColCaiQuantLinear( + module.bits, + module.group_size, + module.in_features, + module.out_features // tp_size, + module.bias is not None, + tp_size=tp_size, + tp_rank=tp_rank, + ) linear_1d.process_group = process_group split_column_copy(module, linear_1d, tp_rank=tp_rank, **kwargs) diff --git a/colossalai/inference/tensor_parallel/batch_infer_state.py b/colossalai/inference/tensor_parallel/batch_infer_state.py index de150311c..33e79ec04 100644 --- a/colossalai/inference/tensor_parallel/batch_infer_state.py +++ b/colossalai/inference/tensor_parallel/batch_infer_state.py @@ -5,6 +5,7 @@ import torch from .kvcache_manager import MemoryManager + # adapted from: lightllm/server/router/model_infer/infer_batch.py @dataclass class BatchInferState: diff --git a/colossalai/inference/tensor_parallel/modeling/chatglm2.py b/colossalai/inference/tensor_parallel/modeling/chatglm2.py index b8274d3c6..b21ec9b8e 100644 --- a/colossalai/inference/tensor_parallel/modeling/chatglm2.py +++ b/colossalai/inference/tensor_parallel/modeling/chatglm2.py @@ -19,8 +19,11 @@ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ( from ._utils import copy_kv_to_mem_cache try: - from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import context_attention_fwd as lightllm_llama2_context_attention_fwd from lightllm.models.chatglm2.triton_kernel.rotary_emb import rotary_emb_fwd as chatglm2_rotary_emb_fwd + from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import ( + context_attention_fwd as lightllm_llama2_context_attention_fwd, + ) + HAS_LIGHTLLM_KERNEL = True except: print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm") diff --git a/colossalai/inference/tensor_parallel/policies/bloom.py b/colossalai/inference/tensor_parallel/policies/bloom.py index 3d6df2097..fba83a081 100644 --- a/colossalai/inference/tensor_parallel/policies/bloom.py +++ b/colossalai/inference/tensor_parallel/policies/bloom.py @@ -4,7 +4,6 @@ import torch from torch.nn import LayerNorm import colossalai.shardformer.layer as col_nn -from colossalai.shardformer.modeling.bloom import build_bloom_alibi_tensor_fn from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription from colossalai.shardformer.policies.bloom import BloomForCausalLMPolicy @@ -40,33 +39,36 @@ class BloomModelInferPolicy(BloomForCausalLMPolicy): policy = super().module_policy() if self.shard_config.inference_gptq: from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear - policy[BloomBlock] = ModulePolicyDescription(attribute_replacement={ - "self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "self_attention.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size, - }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="self_attention.query_key_value", - target_module=ColCaiQuantLinear, - kwargs={'split_num': 3}), - SubModuleReplacementDescription( - suffix="self_attention.dense", - target_module=RowCaiQuantLinear, - kwargs={'split_num': 1}), - SubModuleReplacementDescription( - suffix="self_attention.attention_dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="mlp.dense_h_to_4h", - target_module=ColCaiQuantLinear, - kwargs={'split_num': 1}), - SubModuleReplacementDescription( - suffix="mlp.dense_4h_to_h", - target_module=RowCaiQuantLinear, - kwargs={'split_num': 1}), - ]) + + policy[BloomBlock] = ModulePolicyDescription( + attribute_replacement={ + "self_attention.hidden_size": self.model.config.hidden_size + // self.shard_config.tensor_parallel_size, + "self_attention.split_size": self.model.config.hidden_size + // self.shard_config.tensor_parallel_size, + "self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attention.query_key_value", + target_module=ColCaiQuantLinear, + kwargs={"split_num": 3}, + ), + SubModuleReplacementDescription( + suffix="self_attention.dense", target_module=RowCaiQuantLinear, kwargs={"split_num": 1} + ), + SubModuleReplacementDescription( + suffix="self_attention.attention_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="mlp.dense_h_to_4h", target_module=ColCaiQuantLinear, kwargs={"split_num": 1} + ), + SubModuleReplacementDescription( + suffix="mlp.dense_4h_to_h", target_module=RowCaiQuantLinear, kwargs={"split_num": 1} + ), + ], + ) # NOTE set inference mode to shard config self.shard_config._infer() diff --git a/colossalai/inference/tensor_parallel/policies/llama.py b/colossalai/inference/tensor_parallel/policies/llama.py index 7e163efe0..da666fbe3 100644 --- a/colossalai/inference/tensor_parallel/policies/llama.py +++ b/colossalai/inference/tensor_parallel/policies/llama.py @@ -13,6 +13,7 @@ from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forw try: from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward as lightllm_rmsnorm_forward + HAS_TRITON_RMSNORM = True except: print("you should install triton from https://github.com/openai/triton") @@ -21,6 +22,7 @@ except: def get_triton_rmsnorm_forward(): if HAS_TRITON_RMSNORM: + def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor): return lightllm_rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon) diff --git a/colossalai/kernel/cuda_native/csrc/gptq/linear_gptq.cpp b/colossalai/kernel/cuda_native/csrc/gptq/linear_gptq.cpp index bcc0e4390..8f17723cb 100644 --- a/colossalai/kernel/cuda_native/csrc/gptq/linear_gptq.cpp +++ b/colossalai/kernel/cuda_native/csrc/gptq/linear_gptq.cpp @@ -1,254 +1,202 @@ // Adapted from turboderp exllama: https://github.com/turboderp/exllama -#include -#include #include -#include +#include #include +#include +#include + #include #include -#include "util.cuh" -#include "tuning.h" -#include "cuda_buffers.cuh" -#include "q4_matrix.cuh" -#include "q4_matmul.cuh" + #include "column_remap.cuh" +#include "cuda_buffers.cuh" +#include "q4_matmul.cuh" +#include "q4_matrix.cuh" +#include "tuning.h" +#include "util.cuh" -// Check CUDA return code. We don't want to include Torch headers in the .cu files because parsing them adds almost a -// minute to the compile time on a 12900K. Also passing exceptions back to Python is super tricky, so in place of -// exceptions, CUDA functions return with a cudaError_t which we can parse and dump to the console. +// Check CUDA return code. We don't want to include Torch headers in the .cu +// files because parsing them adds almost a minute to the compile time on a +// 12900K. Also passing exceptions back to Python is super tricky, so in place +// of exceptions, CUDA functions return with a cudaError_t which we can parse +// and dump to the console. -void check_cuda(cudaError_t ret) -{ - switch (ret) - { - case cudaSuccess: - break; +void check_cuda(cudaError_t ret) { + switch (ret) { + case cudaSuccess: + break; - case cudaUnspecified: - printf(" **** Unspecified error\n"); - TORCH_CHECK(false, "CUDA error"); - break; + case cudaUnspecified: + printf(" **** Unspecified error\n"); + TORCH_CHECK(false, "CUDA error"); + break; - default: - printf(" **** CUDA error\n"); \ - printf(" **** %s\n", cudaGetErrorString(ret)); \ - TORCH_CHECK(false, "CUDA error"); \ - break; - } + default: + printf(" **** CUDA error\n"); + printf(" **** %s\n", cudaGetErrorString(ret)); + TORCH_CHECK(false, "CUDA error"); + break; + } } // Some decluttering macros #define STRINGIFY_(__x) #__x #define STRINGIFY(__x) STRINGIFY_(__x) -#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype) -#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype) -#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes") -#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes") -#define TORCH_CHECK_SHAPE_MOD(__x, __dim_x, __mod) TORCH_CHECK((__x).size(__dim_x) % __mod == 0, #__x ".shape[" STRINGIFY(__dim_x) "] must be a multiple of " STRINGIFY(__mod)) -#define TORCH_CHECK_BUFFER_SIZE(__buffer, __minimum_size) TORCH_CHECK((__buffer).numel() >= __minimum_size, #__buffer " is too small") +#define TORCH_CHECK_DTYPE(__x, __dtype) \ + TORCH_CHECK((__x).dtype() == torch::__dtype, \ + #__x " is incorrect datatype, must be " #__dtype) +#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) \ + TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, \ + #__x " is incorrect datatype, must be " #__dtype) +#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) \ + TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, \ + #__x " and " #__y " have incompatible shapes") +#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) \ + TORCH_CHECK((__x).device().is_meta() || \ + (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, \ + #__x " and " #__y " have incompatible shapes") +#define TORCH_CHECK_SHAPE_MOD(__x, __dim_x, __mod) \ + TORCH_CHECK((__x).size(__dim_x) % __mod == 0, \ + #__x ".shape[" STRINGIFY( \ + __dim_x) "] must be a multiple of " STRINGIFY(__mod)) +#define TORCH_CHECK_BUFFER_SIZE(__buffer, __minimum_size) \ + TORCH_CHECK((__buffer).numel() >= __minimum_size, #__buffer " is too small") -#define TORCH_CHECK_DEVICE_INDEX(__index) \ -do { \ - TORCH_CHECK(__index >= 0, "no device index"); \ +#define TORCH_CHECK_DEVICE_INDEX(__index) \ + do { \ + TORCH_CHECK(__index >= 0, "no device index"); \ TORCH_CHECK(__index < CUDA_MAX_DEVICES, "invalid device index"); \ -} while(0) + } while (0) #define TORCH_CHECK_QUANT(__w, __w_scales, __w_zeros, __seq_g_idx, __x_map) \ -do { \ - TORCH_CHECK_DTYPE(__w, kInt); \ - TORCH_CHECK_DTYPE(__w_scales, kHalf); \ - TORCH_CHECK_DTYPE(__w_zeros, kInt); \ - TORCH_CHECK_DTYPE_OPT(__seq_g_idx, kShort); \ - TORCH_CHECK_DTYPE_OPT(__x_map, kInt); \ - TORCH_CHECK_SHAPES_OPT(__seq_g_idx, 0, __w, 0, 2 * 8); \ - TORCH_CHECK_SHAPES_OPT(__x_map, 0, __w, 0, 8); \ -} while(0) + do { \ + TORCH_CHECK_DTYPE(__w, kInt); \ + TORCH_CHECK_DTYPE(__w_scales, kHalf); \ + TORCH_CHECK_DTYPE(__w_zeros, kInt); \ + TORCH_CHECK_DTYPE_OPT(__seq_g_idx, kShort); \ + TORCH_CHECK_DTYPE_OPT(__x_map, kInt); \ + TORCH_CHECK_SHAPES_OPT(__seq_g_idx, 0, __w, 0, 2 * 8); \ + TORCH_CHECK_SHAPES_OPT(__x_map, 0, __w, 0, 8); \ + } while (0) -int get_groupsize(torch::Tensor w, torch::Tensor w_zeros) -{ - int groupsize = w.size(0) * 8 / w_zeros.size(0); - TORCH_CHECK(groupsize * w_zeros.size(0) == w.size(0) * 8, "w.shape[-2] must be a multiple of zeros.shape[-2]") - return groupsize; +int get_groupsize(torch::Tensor w, torch::Tensor w_zeros) { + int groupsize = w.size(0) * 8 / w_zeros.size(0); + TORCH_CHECK(groupsize * w_zeros.size(0) == w.size(0) * 8, + "w.shape[-2] must be a multiple of zeros.shape[-2]") + return groupsize; } - // Tuning parameters ExLlamaTuning tuningParams; -void set_tuning_params -( - int matmul_recons_thd, - bool matmul_fused_remap, - bool matmul_no_half2 -) -{ - tuningParams.matmul_recons_thd = matmul_recons_thd; - tuningParams.matmul_fused_remap = matmul_fused_remap; - tuningParams.matmul_no_half2 = matmul_no_half2; +void set_tuning_params(int matmul_recons_thd, bool matmul_fused_remap, + bool matmul_no_half2) { + tuningParams.matmul_recons_thd = matmul_recons_thd; + tuningParams.matmul_fused_remap = matmul_fused_remap; + tuningParams.matmul_no_half2 = matmul_no_half2; } - // Release all unmanaged objects allocated by the extension -void cleanup() -{ - cleanup_buffers_cuda(); - g_q4_free_matrices(); +void cleanup() { + cleanup_buffers_cuda(); + g_q4_free_matrices(); } - // Prepare buffers for forward pass -void prepare_buffers -( - torch::Device device, - torch::Tensor temp_state, - torch::Tensor temp_dq -) -{ - int device_index = device.index(); - TORCH_CHECK_DEVICE_INDEX(device_index); - const at::cuda::OptionalCUDAGuard device_guard(device); +void prepare_buffers(torch::Device device, torch::Tensor temp_state, + torch::Tensor temp_dq) { + int device_index = device.index(); + TORCH_CHECK_DEVICE_INDEX(device_index); + const at::cuda::OptionalCUDAGuard device_guard(device); - prepare_buffers_cuda - ( - device_index, - // buffer size used for sanity checks - temp_state.numel(), - (half*) temp_state.data_ptr(), - (half*) temp_dq.data_ptr() - ); + prepare_buffers_cuda(device_index, + // buffer size used for sanity checks + temp_state.numel(), (half*)temp_state.data_ptr(), + (half*)temp_dq.data_ptr()); } - // Create Q4Matrix, return handle -uintptr_t make_q4 -( - torch::Tensor qweight, - torch::Tensor qzeros, - torch::Tensor scales, - torch::Tensor g_idx, - int device -) -{ - TORCH_CHECK_DTYPE(qweight, kInt); - TORCH_CHECK_DTYPE(qzeros, kInt); - TORCH_CHECK_DTYPE(scales, kHalf); - TORCH_CHECK_DTYPE_OPT(g_idx, kInt); - TORCH_CHECK_SHAPES(qweight, 1, qzeros, 1, 8); - TORCH_CHECK_SHAPES(scales, 1, qweight, 1, 1); - TORCH_CHECK_SHAPES(qzeros, 0, scales, 0, 1); +uintptr_t make_q4(torch::Tensor qweight, torch::Tensor qzeros, + torch::Tensor scales, torch::Tensor g_idx, int device) { + TORCH_CHECK_DTYPE(qweight, kInt); + TORCH_CHECK_DTYPE(qzeros, kInt); + TORCH_CHECK_DTYPE(scales, kHalf); + TORCH_CHECK_DTYPE_OPT(g_idx, kInt); + TORCH_CHECK_SHAPES(qweight, 1, qzeros, 1, 8); + TORCH_CHECK_SHAPES(scales, 1, qweight, 1, 1); + TORCH_CHECK_SHAPES(qzeros, 0, scales, 0, 1); - int width = qweight.size(1); - int height = qweight.size(0) * 8; - int groups = qzeros.size(0); + int width = qweight.size(1); + int height = qweight.size(0) * 8; + int groups = qzeros.size(0); - Q4Matrix* m = new Q4Matrix - ( - height, - width, - groups, + Q4Matrix* m = new Q4Matrix( + height, width, groups, - (uint32_t*) qweight.data_ptr(), - (uint32_t*) qzeros.data_ptr(), - (half*) scales.data_ptr(), - g_idx.device().is_meta() ? NULL : (uint32_t*) g_idx.data_ptr(), + (uint32_t*)qweight.data_ptr(), (uint32_t*)qzeros.data_ptr(), + (half*)scales.data_ptr(), + g_idx.device().is_meta() ? NULL : (uint32_t*)g_idx.data_ptr(), - device - ); + device); - g_q4_keep_matrix(m); - return reinterpret_cast (m); + g_q4_keep_matrix(m); + return reinterpret_cast(m); } - // Matmul half @ quant -> half -void q4_matmul -( - torch::Tensor x, - uintptr_t w, - torch::Tensor out -) -{ - Q4Matrix* wm = reinterpret_cast (w); +void q4_matmul(torch::Tensor x, uintptr_t w, torch::Tensor out) { + Q4Matrix* wm = reinterpret_cast(w); - TORCH_CHECK_DTYPE(x, kHalf); - TORCH_CHECK_DTYPE(out, kHalf); - TORCH_CHECK_SHAPES(x, 0, out, 0, 1); - TORCH_CHECK(wm->height == x.size(-1), "x and w have incompatible shapes") + TORCH_CHECK_DTYPE(x, kHalf); + TORCH_CHECK_DTYPE(out, kHalf); + TORCH_CHECK_SHAPES(x, 0, out, 0, 1); + TORCH_CHECK(wm->height == x.size(-1), "x and w have incompatible shapes") - const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); - int x_height = x.size(0); + int x_height = x.size(0); - if (tuningParams.matmul_recons_thd == 0 || x_height < tuningParams.matmul_recons_thd) - { - q4_matmul_cuda - ( - &tuningParams, - (half*) x.data_ptr(), - x_height, - wm, - (half*) out.data_ptr() - ); - } - else - { - q4_matmul_recons_cuda - ( - &tuningParams, - (half*) x.data_ptr(), - x_height, - wm, - (half*) out.data_ptr(), - at::cuda::getCurrentCUDABlasHandle() - ); - } + if (tuningParams.matmul_recons_thd == 0 || + x_height < tuningParams.matmul_recons_thd) { + q4_matmul_cuda(&tuningParams, (half*)x.data_ptr(), x_height, wm, + (half*)out.data_ptr()); + } else { + q4_matmul_recons_cuda(&tuningParams, (half*)x.data_ptr(), x_height, wm, + (half*)out.data_ptr(), + at::cuda::getCurrentCUDABlasHandle()); + } } - // Remap columns in half tensor -void column_remap -( - torch::Tensor x, - torch::Tensor x_new, - torch::Tensor x_map -) -{ - TORCH_CHECK_DTYPE(x, kHalf); - TORCH_CHECK_DTYPE(x_new, kHalf); - TORCH_CHECK_DTYPE(x_map, kInt); - TORCH_CHECK_SHAPES(x_map, 0, x, 1, 1); +void column_remap(torch::Tensor x, torch::Tensor x_new, torch::Tensor x_map) { + TORCH_CHECK_DTYPE(x, kHalf); + TORCH_CHECK_DTYPE(x_new, kHalf); + TORCH_CHECK_DTYPE(x_map, kInt); + TORCH_CHECK_SHAPES(x_map, 0, x, 1, 1); - int height = x.size(0); - int width = x.size(1); + int height = x.size(0); + int width = x.size(1); - TORCH_CHECK_BUFFER_SIZE(x_new, height * width); + TORCH_CHECK_BUFFER_SIZE(x_new, height * width); - const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); - column_remap_cuda - ( - (half*) x.data_ptr(), - (half*) x_new.data_ptr(), - height, - width, - (uint32_t*) x_map.data_ptr() - ); + column_remap_cuda((half*)x.data_ptr(), (half*)x_new.data_ptr(), height, width, + (uint32_t*)x_map.data_ptr()); } - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) -{ - m.def("set_tuning_params", &set_tuning_params, "set_tuning_params"); - m.def("prepare_buffers", &prepare_buffers, "prepare_buffers"); - m.def("cleanup", &cleanup, "cleanup"); - m.def("make_q4", &make_q4, "make_q4"); - m.def("q4_matmul", &q4_matmul, "q4_matmul"); +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("set_tuning_params", &set_tuning_params, "set_tuning_params"); + m.def("prepare_buffers", &prepare_buffers, "prepare_buffers"); + m.def("cleanup", &cleanup, "cleanup"); + m.def("make_q4", &make_q4, "make_q4"); + m.def("q4_matmul", &q4_matmul, "q4_matmul"); } diff --git a/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cu b/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cu index 9c61143f5..bd595ee6f 100644 --- a/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cu +++ b/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cu @@ -184,7 +184,7 @@ __global__ void reconstruct_kernel int column = RECONS_THREADS_X * blockIdx.x + threadIdx.x; int row = (RECONS_THREADS_Y * blockIdx.y + threadIdx.y) * 8; if (column >= width) return; - + // Views MatrixView_q4_column w_(w, height, width); diff --git a/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cuh b/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cuh index 50cb72a41..49431dc95 100644 --- a/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cuh +++ b/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cuh @@ -50,4 +50,4 @@ private: void g_q4_keep_matrix(Q4Matrix* m); void g_q4_free_matrices(); -#endif \ No newline at end of file +#endif diff --git a/colossalai/kernel/triton/context_attention.py b/colossalai/kernel/triton/context_attention.py index 1b4f6e44b..f5847af8b 100644 --- a/colossalai/kernel/triton/context_attention.py +++ b/colossalai/kernel/triton/context_attention.py @@ -238,5 +238,5 @@ if HAS_TRITON: num_warps=num_warps, num_stages=1, ) - - return \ No newline at end of file + + return diff --git a/colossalai/kernel/triton/copy_kv_cache_dest.py b/colossalai/kernel/triton/copy_kv_cache_dest.py index 0ce6b09e5..823d21573 100644 --- a/colossalai/kernel/triton/copy_kv_cache_dest.py +++ b/colossalai/kernel/triton/copy_kv_cache_dest.py @@ -10,7 +10,6 @@ except ImportError: print("please install triton from https://github.com/openai/triton") if HAS_TRITON: - # adapted from https://github.com/ModelTC/lightllm/blob/5c559dd7981ed67679a08a1e09a88fb4c1550b3a/lightllm/common/triton_kernel/destindex_copy_kv.py @triton.jit def _fwd_copy_kv_cache_dest( diff --git a/colossalai/kernel/triton/token_attention_kernel.py b/colossalai/kernel/triton/token_attention_kernel.py index 8dc919bad..fce90e453 100644 --- a/colossalai/kernel/triton/token_attention_kernel.py +++ b/colossalai/kernel/triton/token_attention_kernel.py @@ -13,6 +13,9 @@ except ImportError: print("please install triton from https://github.com/openai/triton") try: + from lightllm.models.bloom.triton_kernel.token_attention_nopad_att1 import ( + token_att_fwd as lightllm_bloom_token_att_fwd, + ) from lightllm.models.llama2.triton_kernel.token_attention_nopad_att1 import ( token_att_fwd as lightllm_llama2_token_att_fwd, ) @@ -22,11 +25,15 @@ try: from lightllm.models.llama2.triton_kernel.token_attention_nopad_softmax import ( token_softmax_fwd as lightllm_llama2_token_softmax_fwd, ) - - from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import token_att_fwd2 as lightllm_llama_token_att_fw2 - from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd as lightllm_llama_token_att_fwd - from lightllm.models.llama.triton_kernel.token_attention_nopad_softmax import token_softmax_fwd as lightllm_llama_token_softmax_fwd - from lightllm.models.bloom.triton_kernel.token_attention_nopad_att1 import token_att_fwd as lightllm_bloom_token_att_fwd + from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import ( + token_att_fwd as lightllm_llama_token_att_fwd, + ) + from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import ( + token_att_fwd2 as lightllm_llama_token_att_fw2, + ) + from lightllm.models.llama.triton_kernel.token_attention_nopad_softmax import ( + token_softmax_fwd as lightllm_llama_token_softmax_fwd, + ) HAS_TRITON_TOKEN_ATTENTION = True except ImportError: diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py index 29a102be0..560ae7b05 100644 --- a/colossalai/pipeline/p2p.py +++ b/colossalai/pipeline/p2p.py @@ -44,8 +44,8 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) - return unpickle -def check_for_nccl_backend(group): +def check_for_nccl_backend(group): pg = group or c10d._get_default_group() # Gate PG wrapper check on Gloo availability. if c10d._GLOO_AVAILABLE: @@ -54,10 +54,8 @@ def check_for_nccl_backend(group): while isinstance(pg, c10d._ProcessGroupWrapper): pg = pg.wrapped_pg - return ( - c10d.is_nccl_available() and - pg.name() == c10d.Backend.NCCL - ) + return c10d.is_nccl_available() and pg.name() == c10d.Backend.NCCL + def _broadcast_object_list( object_list: List[Any], src: int, group: ProcessGroup, device: Optional[Union[torch.device, str, int]] = None diff --git a/colossalai/quantization/__init__.py b/colossalai/quantization/__init__.py new file mode 100644 index 000000000..e9707b479 --- /dev/null +++ b/colossalai/quantization/__init__.py @@ -0,0 +1,7 @@ +from .bnb import quantize_model +from .bnb_config import BnbQuantizationConfig + +__all__ = [ + "BnbQuantizationConfig", + "quantize_model", +] diff --git a/colossalai/quantization/bnb.py b/colossalai/quantization/bnb.py new file mode 100644 index 000000000..fa214116a --- /dev/null +++ b/colossalai/quantization/bnb.py @@ -0,0 +1,321 @@ +# adapted from Hugging Face accelerate/utils/bnb.py accelerate/utils/modeling.py + +import logging + +import torch +import torch.nn as nn + +from .bnb_config import BnbQuantizationConfig + +try: + import bitsandbytes as bnb + + IS_4BIT_BNB_AVAILABLE = bnb.__version__ >= "0.39.0" + IS_8BIT_BNB_AVAILABLE = bnb.__version__ >= "0.37.2" +except ImportError: + pass + + +logger = logging.getLogger(__name__) + + +def quantize_model( + model: torch.nn.Module, + bnb_quantization_config: BnbQuantizationConfig, +): + """ + This function will quantize the input loaded model with the associated config passed in `bnb_quantization_config`. + We will quantize the model and put the model on the GPU. + + Args: + model (`torch.nn.Module`): + Input model. The model already loaded + bnb_quantization_config (`BnbQuantizationConfig`): + The bitsandbytes quantization parameters + + Returns: + `torch.nn.Module`: The quantized model + """ + + load_in_4bit = bnb_quantization_config.load_in_4bit + load_in_8bit = bnb_quantization_config.load_in_8bit + + if load_in_8bit and not IS_8BIT_BNB_AVAILABLE: + raise ImportError( + "You have a version of `bitsandbytes` that is not compatible with 8bit quantization," + " make sure you have the latest version of `bitsandbytes` installed." + ) + if load_in_4bit and not IS_4BIT_BNB_AVAILABLE: + raise ValueError( + "You have a version of `bitsandbytes` that is not compatible with 4bit quantization," + "make sure you have the latest version of `bitsandbytes` installed." + ) + + # We keep some modules such as the lm_head in their original dtype for numerical stability reasons + if bnb_quantization_config.skip_modules is None: + bnb_quantization_config.skip_modules = get_keys_to_not_convert(model) + + modules_to_not_convert = bnb_quantization_config.skip_modules + + # We add the modules we want to keep in full precision + if bnb_quantization_config.keep_in_fp32_modules is None: + bnb_quantization_config.keep_in_fp32_modules = [] + keep_in_fp32_modules = bnb_quantization_config.keep_in_fp32_modules + + # compatibility with peft + model.is_loaded_in_4bit = load_in_4bit + model.is_loaded_in_8bit = load_in_8bit + + # assert model_device is cuda + model_device = next(model.parameters()).device + + model = replace_with_bnb_layers(model, bnb_quantization_config, modules_to_not_convert=modules_to_not_convert) + + # convert param to the right dtype + dtype = bnb_quantization_config.torch_dtype + for name, param in model.state_dict().items(): + if any(module_to_keep_in_fp32 in name for module_to_keep_in_fp32 in keep_in_fp32_modules): + param.to(torch.float32) + if param.dtype != torch.float32: + name = name.replace(".weight", "").replace(".bias", "") + param = getattr(model, name, None) + if param is not None: + param.to(torch.float32) + elif torch.is_floating_point(param): + param.to(dtype) + if model_device.type == "cuda": + # move everything to cpu in the first place because we can't do quantization if the weights are already on cuda + model.cuda(torch.cuda.current_device()) + torch.cuda.empty_cache() + elif torch.cuda.is_available(): + model.to(torch.cuda.current_device()) + logger.info( + f"The model device type is {model_device.type}. However, cuda is needed for quantization." + "We move the model to cuda." + ) + else: + raise RuntimeError("No GPU found. A GPU is needed for quantization.") + return model + + +def replace_with_bnb_layers(model, bnb_quantization_config, modules_to_not_convert=None, current_key_name=None): + """ + A helper function to replace all `torch.nn.Linear` modules by `bnb.nn.Linear8bit` modules or by `bnb.nn.Linear4bit` + modules from the `bitsandbytes`library. The function will be run recursively and replace `torch.nn.Linear` modules. + + Parameters: + model (`torch.nn.Module`): + Input model or `torch.nn.Module` as the function is run recursively. + modules_to_not_convert (`List[str]`): + Names of the modules to not quantize convert. In practice we keep the `lm_head` in full precision for + numerical stability reasons. + current_key_name (`List[str]`, *optional*): + An array to track the current key of the recursion. This is used to check whether the current key (part of + it) is not in the list of modules to not convert. + """ + + if modules_to_not_convert is None: + modules_to_not_convert = [] + + model, has_been_replaced = _replace_with_bnb_layers( + model, bnb_quantization_config, modules_to_not_convert, current_key_name + ) + if not has_been_replaced: + logger.warning( + "You are loading your model in 8bit or 4bit but no linear modules were found in your model." + " this can happen for some architectures such as gpt2 that uses Conv1D instead of Linear layers." + " Please double check your model architecture, or submit an issue on github if you think this is" + " a bug." + ) + return model + + +def _replace_with_bnb_layers( + model, + bnb_quantization_config, + modules_to_not_convert=None, + current_key_name=None, +): + """ + Private method that wraps the recursion for module replacement. + + Returns the converted model and a boolean that indicates if the conversion has been successfull or not. + """ + # bitsandbytes will initialize CUDA on import, so it needs to be imported lazily + + has_been_replaced = False + for name, module in model.named_children(): + if current_key_name is None: + current_key_name = [] + current_key_name.append(name) + if isinstance(module, nn.Linear) and name not in modules_to_not_convert: + # Check if the current key is not in the `modules_to_not_convert` + current_key_name_str = ".".join(current_key_name) + proceed = True + for key in modules_to_not_convert: + if ( + (key in current_key_name_str) and (key + "." in current_key_name_str) + ) or key == current_key_name_str: + proceed = False + break + if proceed: + # Load bnb module with empty weight and replace ``nn.Linear` module + if bnb_quantization_config.load_in_8bit: + bnb_module = bnb.nn.Linear8bitLt( + module.in_features, + module.out_features, + module.bias is not None, + has_fp16_weights=False, + threshold=bnb_quantization_config.llm_int8_threshold, + ) + elif bnb_quantization_config.load_in_4bit: + bnb_module = bnb.nn.Linear4bit( + module.in_features, + module.out_features, + module.bias is not None, + bnb_quantization_config.bnb_4bit_compute_dtype, + compress_statistics=bnb_quantization_config.bnb_4bit_use_double_quant, + quant_type=bnb_quantization_config.bnb_4bit_quant_type, + ) + else: + raise ValueError("load_in_8bit and load_in_4bit can't be both False") + bnb_module.weight.data = module.weight.data + bnb_module.weight.skip_zero_check = True + if module.bias is not None: + bnb_module.bias.data = module.bias.data + bnb_module.bias.skip_zero_check = True + bnb_module.requires_grad_(False) + setattr(model, name, bnb_module) + has_been_replaced = True + if len(list(module.children())) > 0: + _, _has_been_replaced = _replace_with_bnb_layers( + module, bnb_quantization_config, modules_to_not_convert, current_key_name + ) + has_been_replaced = has_been_replaced | _has_been_replaced + # Remove the last key for recursion + current_key_name.pop(-1) + return model, has_been_replaced + + +def get_keys_to_not_convert(model): + r""" + An utility function to get the key of the module to keep in full precision if any For example for CausalLM modules + we may want to keep the lm_head in full precision for numerical stability reasons. For other architectures, we want + to keep the tied weights of the model. The function will return a list of the keys of the modules to not convert in + int8. + + Parameters: + model (`torch.nn.Module`): + Input model + """ + # Create a copy of the model + # with init_empty_weights(): + # tied_model = deepcopy(model) # this has 0 cost since it is done inside `init_empty_weights` context manager` + tied_model = model + + tied_params = find_tied_parameters(tied_model) + # For compatibility with Accelerate < 0.18 + if isinstance(tied_params, dict): + tied_keys = sum(list(tied_params.values()), []) + list(tied_params.keys()) + else: + tied_keys = sum(tied_params, []) + has_tied_params = len(tied_keys) > 0 + + # Check if it is a base model + is_base_model = False + if hasattr(model, "base_model_prefix"): + is_base_model = not hasattr(model, model.base_model_prefix) + + # Ignore this for base models (BertModel, GPT2Model, etc.) + if (not has_tied_params) and is_base_model: + return [] + + # otherwise they have an attached head + list_modules = list(model.named_children()) + list_last_module = [list_modules[-1][0]] + + # add last module together with tied weights + intersection = set(list_last_module) - set(tied_keys) + list_untouched = list(set(tied_keys)) + list(intersection) + + # remove ".weight" from the keys + names_to_remove = [".weight", ".bias"] + filtered_module_names = [] + for name in list_untouched: + for name_to_remove in names_to_remove: + if name_to_remove in name: + name = name.replace(name_to_remove, "") + filtered_module_names.append(name) + + return filtered_module_names + + +def find_tied_parameters(model: nn.Module, **kwargs): + """ + Find the tied parameters in a given model. + + + + The signature accepts keyword arguments, but they are for the recursive part of this function and you should ignore + them. + + + + Args: + model (`torch.nn.Module`): The model to inspect. + + Returns: + List[List[str]]: A list of lists of parameter names being all tied together. + + Example: + + ```py + >>> from collections import OrderedDict + >>> import torch.nn as nn + + >>> model = nn.Sequential(OrderedDict([("linear1", nn.Linear(4, 4)), ("linear2", nn.Linear(4, 4))])) + >>> model.linear2.weight = model.linear1.weight + >>> find_tied_parameters(model) + [['linear1.weight', 'linear2.weight']] + ``` + """ + # Initialize result and named_parameters before recursing. + named_parameters = kwargs.get("named_parameters", None) + prefix = kwargs.get("prefix", "") + result = kwargs.get("result", {}) + + if named_parameters is None: + named_parameters = {n: p for n, p in model.named_parameters()} + else: + # A tied parameter will not be in the full `named_parameters` seen above but will be in the `named_parameters` + # of the submodule it belongs to. So while recursing we track the names that are not in the initial + # `named_parameters`. + for name, parameter in model.named_parameters(): + full_name = name if prefix == "" else f"{prefix}.{name}" + if full_name not in named_parameters: + # When we find one, it has to be one of the existing parameters. + for new_name, new_param in named_parameters.items(): + if new_param is parameter: + if new_name not in result: + result[new_name] = [] + result[new_name].append(full_name) + + # Once we have treated direct parameters, we move to the child modules. + for name, child in model.named_children(): + child_name = name if prefix == "" else f"{prefix}.{name}" + find_tied_parameters(child, named_parameters=named_parameters, prefix=child_name, result=result) + + return FindTiedParametersResult([sorted([weight] + list(set(tied))) for weight, tied in result.items()]) + + +class FindTiedParametersResult(list): + """ + This is a subclass of a list to handle backward compatibility for Transformers. Do not rely on the fact this is not + a list or on the `values` method as in the future this will be removed. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def values(self): + return sum([x[1:] for x in self], []) diff --git a/colossalai/quantization/bnb_config.py b/colossalai/quantization/bnb_config.py new file mode 100644 index 000000000..98a30211b --- /dev/null +++ b/colossalai/quantization/bnb_config.py @@ -0,0 +1,113 @@ +# adapted from Hugging Face accelerate/utils/dataclasses.py + +import warnings +from dataclasses import dataclass, field +from typing import List + +import torch + + +@dataclass +class BnbQuantizationConfig: + """ + A plugin to enable BitsAndBytes 4bit and 8bit quantization + """ + + load_in_8bit: bool = field(default=False, metadata={"help": "enable 8bit quantization."}) + + llm_int8_threshold: float = field( + default=6.0, metadata={"help": "value of the outliner threshold. only relevant when load_in_8bit=True"} + ) + + load_in_4bit: bool = field(default=False, metadata={"help": "enable 4bit quantization."}) + + bnb_4bit_quant_type: str = field( + default="fp4", + metadata={ + "help": "set the quantization data type in the `bnb.nn.Linear4Bit` layers. Options are {'fp4','np4'}." + }, + ) + + bnb_4bit_use_double_quant: bool = field( + default=False, + metadata={ + "help": "enable nested quantization where the quantization constants from the first quantization are quantized again." + }, + ) + + bnb_4bit_compute_dtype: bool = field( + default="fp16", + metadata={ + "help": "This sets the computational type which might be different than the input time. For example, inputs might be " + "fp32, but computation can be set to bf16 for speedups. Options are {'fp32','fp16','bf16'}." + }, + ) + + torch_dtype: torch.dtype = field( + default=None, + metadata={ + "help": "this sets the dtype of the remaining non quantized layers. `bitsandbytes` library suggests to set the value" + "to `torch.float16` for 8 bit model and use the same dtype as the compute dtype for 4 bit model " + }, + ) + + skip_modules: List[str] = field( + default=None, + metadata={ + "help": "an explicit list of the modules that we don't quantize. The dtype of these modules will be `torch_dtype`." + }, + ) + + keep_in_fp32_modules: List[str] = field( + default=None, + metadata={"help": "an explicit list of the modules that we don't quantize. We keep them in `torch.float32`."}, + ) + + def __post_init__(self): + if isinstance(self.bnb_4bit_compute_dtype, str): + if self.bnb_4bit_compute_dtype == "fp32": + self.bnb_4bit_compute_dtype = torch.float32 + elif self.bnb_4bit_compute_dtype == "fp16": + self.bnb_4bit_compute_dtype = torch.float16 + elif self.bnb_4bit_compute_dtype == "bf16": + self.bnb_4bit_compute_dtype = torch.bfloat16 + else: + raise ValueError( + f"bnb_4bit_compute_dtype must be in ['fp32','fp16','bf16'] but found {self.bnb_4bit_compute_dtype}" + ) + elif not isinstance(self.bnb_4bit_compute_dtype, torch.dtype): + raise ValueError("bnb_4bit_compute_dtype must be a string or a torch.dtype") + + if self.skip_modules is not None and not isinstance(self.skip_modules, list): + raise ValueError("skip_modules must be a list of strings") + + if self.keep_in_fp32_modules is not None and not isinstance(self.keep_in_fp32_modules, list): + raise ValueError("keep_in_fp_32_modules must be a list of strings") + + if self.load_in_4bit: + self.target_dtype = "int4" + + if self.load_in_8bit: + self.target_dtype = torch.int8 + + if self.load_in_4bit and self.llm_int8_threshold != 6.0: + warnings.warn("llm_int8_threshold can only be used for model loaded in 8bit") + + if isinstance(self.torch_dtype, str): + if self.torch_dtype == "fp32": + self.torch_dtype = torch.float32 + elif self.torch_dtype == "fp16": + self.torch_dtype = torch.float16 + elif self.torch_dtype == "bf16": + self.torch_dtype = torch.bfloat16 + else: + raise ValueError(f"torch_dtype must be in ['fp32','fp16','bf16'] but found {self.torch_dtype}") + + if self.load_in_8bit and self.torch_dtype is None: + self.torch_dtype = torch.float16 + + if self.load_in_4bit and self.torch_dtype is None: + self.torch_dtype = self.bnb_4bit_compute_dtype + + if not isinstance(self.torch_dtype, torch.dtype): + raise ValueError("torch_dtype must be a torch.dtype") diff --git a/colossalai/zero/low_level/_utils.py b/colossalai/zero/low_level/_utils.py index de08ecf3d..5ab703f09 100644 --- a/colossalai/zero/low_level/_utils.py +++ b/colossalai/zero/low_level/_utils.py @@ -190,6 +190,7 @@ def calculate_global_norm_from_list(norm_list): total_norm += norm**2.0 return math.sqrt(total_norm) + def sync_tensor(flat_tensor, tensor_list): """ Synchronize the flattened tensor and unflattened tensor list. When diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index e6974a676..7557d914d 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -187,9 +187,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper): for param_group in self.optim.param_groups: group_params = param_group["params"] for param in group_params: - assert ( - param.dtype == self._dtype - ), f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`" + if not hasattr(param, "skip_zero_check") or param.skip_zero_check is False: + assert ( + param.dtype == self._dtype + ), f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`" def _create_master_param_current_rank(self, param_list): # split each param evenly by world size diff --git a/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md b/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md index 7a0e3b1a0..b043bfdef 100644 --- a/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md +++ b/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md @@ -149,7 +149,7 @@ model, optimizer, _criterion, _, lr_scheduler = booster.boost( ## Training GPT-2 using hybrid parallelism -In the previous tutorial, We've explained how to inject various parallelism features into the model and its training components using the Booster and `HybridParallelPlugin`. Now we can start model training. +In the previous tutorial, We've explained how to inject various parallelism features into the model and its training components using the Booster and `HybridParallelPlugin`. Now we can start model training. Define a training function. When pipeline parallelism is used, you need to call `booster.execute_pipeline` to schedule the stages of model training. ```python def train_epoch( @@ -204,4 +204,4 @@ Training the gpt-2 model for epoch in range(NUM_EPOCHS): train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator) ``` - \ No newline at end of file + diff --git a/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md b/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md index 117406980..ae9cbbabd 100644 --- a/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md +++ b/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md @@ -201,4 +201,4 @@ def train_epoch( for epoch in range(NUM_EPOCHS): train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator) ``` - \ No newline at end of file + diff --git a/docs/source/zh-Hans/advanced_tutorials/train_vit_with_hybrid_parallelism.md b/docs/source/zh-Hans/advanced_tutorials/train_vit_with_hybrid_parallelism.md index 3de41601a..c25b109e7 100644 --- a/docs/source/zh-Hans/advanced_tutorials/train_vit_with_hybrid_parallelism.md +++ b/docs/source/zh-Hans/advanced_tutorials/train_vit_with_hybrid_parallelism.md @@ -220,7 +220,7 @@ model, optimizer, _criterion, train_dataloader, lr_scheduler = booster.boost( ) ``` ## 使用混合并行训练 ViT -最后就可以使用混合并行策略来训练模型了。我们先定义一个训练函数,描述训练过程。需要注意的是,如果使用了管道并行策略,需要调用`booster.execute_pipeline`来执行模型的训练,它会调用`scheduler`管理模型的前后向操作。 +最后就可以使用混合并行策略来训练模型了。我们先定义一个训练函数,描述训练过程。需要注意的是,如果使用了管道并行策略,需要调用`booster.execute_pipeline`来执行模型的训练,它会调用`scheduler`管理模型的前后向操作。 ```python def run_forward_backward( model: nn.Module, diff --git a/examples/inference/gptq_bloom.py b/examples/inference/gptq_bloom.py index 43e118cc0..f5413e316 100644 --- a/examples/inference/gptq_bloom.py +++ b/examples/inference/gptq_bloom.py @@ -1,12 +1,10 @@ import argparse -import logging import os import time import torch -from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig -from auto_gptq.nn_modules.qlinear import GeneralQuantLinear -from transformers import AutoTokenizer, BloomForCausalLM, BloomTokenizerFast, LlamaForCausalLM, LlamaTokenizer +from auto_gptq import AutoGPTQForCausalLM +from transformers import BloomTokenizerFast import colossalai from colossalai.inference.tensor_parallel.engine import TPInferEngine @@ -14,7 +12,7 @@ from colossalai.logging import disable_existing_loggers from colossalai.shardformer import ShardConfig from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn -os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' +os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" def print_perf_stats(latency_set, config, bs, warmup=3): @@ -28,7 +26,7 @@ def print_perf_stats(latency_set, config, bs, warmup=3): avg = sum(latency_set) / count num_layers = getattr(config, "num_layers", config.num_hidden_layers) num_parameters = num_layers * config.hidden_size * config.hidden_size * 12 - num_bytes = 2 # float16 + num_bytes = 2 # float16 print("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000)) print("Avg BW: {0:8.2f} GB/s".format(1 / avg * num_parameters * num_bytes / 1e9)) @@ -37,7 +35,6 @@ def print_perf_stats(latency_set, config, bs, warmup=3): def bench_bloom(args): - pretrained_model_dir = args.path quantized_model_dir = args.quantized_path max_batch_size = args.batch_size @@ -48,9 +45,9 @@ def bench_bloom(args): tokenizer.pad_token = tokenizer.eos_token # load quantized model to the first GPU - model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, - device=torch.cuda.current_device(), - inject_fused_attention=False) + model = AutoGPTQForCausalLM.from_quantized( + quantized_model_dir, device=torch.cuda.current_device(), inject_fused_attention=False + ) model = model.half() @@ -60,22 +57,22 @@ def bench_bloom(args): generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False) input_tokens = { - "input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device='cuda'), - "attention_mask": torch.ones((max_batch_size, max_input_len), device='cuda') + "input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device="cuda"), + "attention_mask": torch.ones((max_batch_size, max_input_len), device="cuda"), } # init TPInferEngine and shard the original model # To benchmark torch original, comment out the line of optimizing model - shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, - inference_only=True, - inference_gptq=True) + shard_config = ShardConfig( + enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True, inference_gptq=True + ) infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len) # prepare data for generation generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False) input_tokens = { "input_ids": torch.randint(10, 1000, (max_batch_size, max_input_len)), - "attention_mask": torch.ones((max_batch_size, max_input_len)) + "attention_mask": torch.ones((max_batch_size, max_input_len)), } for t in input_tokens: if torch.is_tensor(input_tokens[t]): @@ -99,7 +96,7 @@ def bench_bloom(args): def check_bloom(rank, world_size, port, args): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") bench_bloom(args) @@ -111,12 +108,12 @@ def test_bloom(args): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('-p', '--path', type=str, help='Model path', required=True) - parser.add_argument('-q', '--quantized_path', type=str, help='Model path', required=True) - parser.add_argument('-tp', '--tp_size', type=int, default=1, help='Tensor parallel size') - parser.add_argument('-b', '--batch_size', type=int, default=16, help='Maximum batch size') - parser.add_argument('--input_len', type=int, default=1024, help='Maximum input length') - parser.add_argument('--output_len', type=int, default=128, help='Maximum output length') + parser.add_argument("-p", "--path", type=str, help="Model path", required=True) + parser.add_argument("-q", "--quantized_path", type=str, help="Model path", required=True) + parser.add_argument("-tp", "--tp_size", type=int, default=1, help="Tensor parallel size") + parser.add_argument("-b", "--batch_size", type=int, default=16, help="Maximum batch size") + parser.add_argument("--input_len", type=int, default=1024, help="Maximum input length") + parser.add_argument("--output_len", type=int, default=128, help="Maximum output length") args = parser.parse_args() diff --git a/requirements/requirements.txt b/requirements/requirements.txt index db9c9908c..e3dfaf6e5 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -14,4 +14,5 @@ einops sentencepiece google protobuf -peft>=0.7.1 \ No newline at end of file +peft>=0.7.1 +bitsandbytes>=0.39.0 diff --git a/tests/test_booster/test_plugin/test_dp_plugin_base.py b/tests/test_booster/test_plugin/test_dp_plugin_base.py index eabe69ed3..fceb623fe 100644 --- a/tests/test_booster/test_plugin/test_dp_plugin_base.py +++ b/tests/test_booster/test_plugin/test_dp_plugin_base.py @@ -1,4 +1,4 @@ -from typing import Callable, Iterator, List, Tuple, Union, Dict +from typing import Callable, Dict, Iterator, List, Tuple, Union import torch import torch.distributed as dist diff --git a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py index 9ad39d089..302069209 100644 --- a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py +++ b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py @@ -51,7 +51,6 @@ def run_fn(stage, model_fn, data_gen_fn, output_transform_fn, lora_config=None) # raise e - @parameterize("stage", [2]) def check_low_level_zero_plugin(stage: int, early_stop: bool = True): """check low level zero plugin over model zoo @@ -118,6 +117,7 @@ def check_low_level_zero_lora(stage, model_name, early_stop: bool = True): print(f"Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n") assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()]) + def run_dist(rank, world_size, port, early_stop: bool = True): # init dist env colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") diff --git a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py index ed5aa7dbd..4073cae0c 100644 --- a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py @@ -1,10 +1,11 @@ +from copy import deepcopy +from typing import Optional + import torch import torch.distributed as dist +from peft import LoraConfig from torchvision.models import resnet18 from utils import shared_tempdir -from typing import Optional -from peft import LoraConfig -from copy import deepcopy import colossalai from colossalai.booster import Booster @@ -131,12 +132,15 @@ def run_fn(stage, shard, offload, model_fn, data_gen_fn, output_transform_fn, lo # return repr(e) raise e + @clear_cache_before_run() @parameterize("stage", [2]) @parameterize("shard", [True, False]) @parameterize("offload", [False, True]) @parameterize("model_name", ["transformers_llama"]) -def check_low_level_zero_lora_checkpointIO(stage: int, shard: bool, offload: bool, model_name: str, early_stop: bool = True): +def check_low_level_zero_lora_checkpointIO( + stage: int, shard: bool, offload: bool, model_name: str, early_stop: bool = True +): passed_models = [] failed_info = {} # (model_name, error) pair @@ -166,6 +170,7 @@ def check_low_level_zero_lora_checkpointIO(stage: int, shard: bool, offload: boo print(f"Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n") assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()]) + def run_dist(rank, world_size, port): colossalai.launch(config=(dict()), rank=rank, world_size=world_size, port=port, host="localhost") check_low_level_zero_checkpointIO() diff --git a/tests/test_gptq/test_gptq_linear.py b/tests/test_gptq/test_gptq_linear.py index 9b650aa78..ded70fa43 100644 --- a/tests/test_gptq/test_gptq_linear.py +++ b/tests/test_gptq/test_gptq_linear.py @@ -1,16 +1,8 @@ -import math -import time - -import numpy as np import pytest import torch -import torch.nn as nn -import transformers from packaging import version try: - import triton - import triton.language as tl HAS_TRITON = True except ImportError: HAS_TRITON = False @@ -22,6 +14,7 @@ try: from exllama_kernels import prepare_buffers, set_tuning_params from colossalai.inference.quant.gptq import CaiQuantLinear + HAS_AUTO_GPTQ = True except: HAS_AUTO_GPTQ = False @@ -32,13 +25,14 @@ import warnings HAS_GPTQ_CUDA = False try: from colossalai.kernel.op_builder.gptq import GPTQBuilder + gptq_cuda = GPTQBuilder().load() HAS_GPTQ_CUDA = True except ImportError: - warnings.warn('CUDA gptq is not installed') + warnings.warn("CUDA gptq is not installed") HAS_GPTQ_CUDA = False -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") max_inner_outer_dim = 1 max_input_len = 1 @@ -64,9 +58,9 @@ def init_buffer(cai_linear, use_act_order=False): max_input_len = 4096 # The temp_state buffer is required to reorder X in the act-order case. # The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill. - gptq_temp_state_buffer = torch.zeros((max_input_len, max_inner_outer_dim), - dtype=torch.float16, - device=torch.cuda.current_device()) + gptq_temp_state_buffer = torch.zeros( + (max_input_len, max_inner_outer_dim), dtype=torch.float16, device=torch.cuda.current_device() + ) gptq_temp_dq_buffer = torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=torch.cuda.current_device()) gptq_cuda.prepare_buffers(torch.device(torch.cuda.current_device()), gptq_temp_state_buffer, gptq_temp_dq_buffer) @@ -77,10 +71,11 @@ def init_buffer(cai_linear, use_act_order=False): gptq_cuda.set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2) -@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON or not HAS_AUTO_GPTQ, - reason="triton requires cuda version to be higher than 11.4 or not install auto-gptq") +@pytest.mark.skipif( + not TRITON_CUDA_SUPPORT or not HAS_TRITON or not HAS_AUTO_GPTQ, + reason="triton requires cuda version to be higher than 11.4 or not install auto-gptq", +) def test_gptq_linear(): - infeature = 1024 outfeature = 1024 group_size = 128 @@ -120,7 +115,7 @@ def test_gptq_linear(): max_input_len = 2048 buffers = { "temp_state": torch.zeros((max_input_len, max_inner_outer_dim), dtype=torch.float16, device=device), - "temp_dq": torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=device) + "temp_dq": torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=device), } prepare_buffers(device, buffers["temp_state"], buffers["temp_dq"]) @@ -146,5 +141,4 @@ def test_gptq_linear(): if __name__ == "__main__": - test_gptq_linear() diff --git a/tests/test_infer_ops/triton/test_token_attn_fwd.py b/tests/test_infer_ops/triton/test_token_attn_fwd.py index a7fc3d29b..d3d4d0839 100644 --- a/tests/test_infer_ops/triton/test_token_attn_fwd.py +++ b/tests/test_infer_ops/triton/test_token_attn_fwd.py @@ -4,6 +4,7 @@ from packaging import version try: from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd + HAS_TRITON = True except ImportError: HAS_TRITON = False diff --git a/tests/test_lora/test_lora.py b/tests/test_lora/test_lora.py new file mode 100644 index 000000000..69febff38 --- /dev/null +++ b/tests/test_lora/test_lora.py @@ -0,0 +1,106 @@ +import copy +import os +from itertools import product + +import torch +from peft import LoraConfig +from torch import distributed as dist +from torch.optim import AdamW + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import LowLevelZeroPlugin, TorchDDPPlugin +from colossalai.testing import check_state_dict_equal, clear_cache_before_run, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo +from tests.test_checkpoint_io.utils import shared_tempdir + + +@clear_cache_before_run() +def check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type): + model = model_fn() + lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1) + + test_plugins = [TorchDDPPlugin(), LowLevelZeroPlugin()] + test_configs = [ + { + "lora_config": lora_config, + "quantize": False, + }, + { + "lora_config": lora_config, + "quantize": True, + }, + ] + for plugin, test_config in product(test_plugins, test_configs): + # checkpoint loaded model + model_save = model_fn() + model_load = copy.deepcopy(model_save) + + optimizer = AdamW(model.parameters(), lr=0.001) + criterion = loss_fn + + booster = Booster(plugin=plugin) + model_save = booster.enable_lora(model_save, **test_config) + model_save, optimizer, criterion, _, _ = booster.boost(model_save, optimizer, criterion) + + with shared_tempdir() as tempdir: + lora_ckpt_path = os.path.join(tempdir, "ckpt") + booster.save_lora_as_pretrained(model_save, lora_ckpt_path) + dist.barrier() + + # The Lora checkpoint should be small in size + checkpoint_size_mb = os.path.getsize(os.path.join(lora_ckpt_path, "adapter_model.bin")) / (1024 * 1024) + assert checkpoint_size_mb < 1 + + model_load = booster.enable_lora(model_load, pretrained_dir=lora_ckpt_path, **test_config) + model_load, _, _, _, _ = booster.boost(model_load) + + check_state_dict_equal(model_save.state_dict(), model_load.state_dict()) + + # test fwd bwd correctness + test_model = model_load + model_copy = copy.deepcopy(model_load) + + data = data_gen_fn() + data = { + k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items() + } + + output = test_model(**data) + output = output_transform_fn(output) + loss = criterion(output) + + booster.backward(loss, optimizer) + optimizer.clip_grad_by_norm(1.0) + optimizer.step() + + for (n1, p1), (n2, p2) in zip(test_model.named_parameters(), model_copy.named_parameters()): + if "lora_" in n1: + # lora modules require gradients, thus updated + assert p1.requires_grad + assert not torch.testing.assert_close(p1.to(p2.device).to(p2.dtype), p2, atol=5e-3, rtol=5e-3) + else: + if not p1.requires_grad: + torch.testing.assert_close(p1.to(p2.device).to(p2.dtype), p2, atol=5e-3, rtol=5e-3) + + +def run_lora_test(): + sub_model_zoo = model_zoo.get_sub_registry("transformers_llama") + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + task_type = None + if name == "transformers_llama_for_casual_lm": + task_type = "CAUSAL_LM" + if name == "transformers_llama_for_sequence_classification": + task_type = "SEQ_CLS" + check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type) + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_lora_test() + + +@rerun_if_address_is_in_use() +def test_torch_ddp_lora(): + spawn(run_dist, 2) diff --git a/tests/test_lora/test_torch_ddp_lora.py b/tests/test_lora/test_torch_ddp_lora.py deleted file mode 100644 index b3169bf86..000000000 --- a/tests/test_lora/test_torch_ddp_lora.py +++ /dev/null @@ -1,108 +0,0 @@ -import copy -import os - -import torch -from peft import LoraConfig -from torch import distributed as dist -from torch.optim import AdamW - -import colossalai -from colossalai.booster import Booster -from colossalai.booster.plugin import TorchDDPPlugin -from colossalai.testing import ( - assert_equal, - assert_not_equal, - check_state_dict_equal, - clear_cache_before_run, - rerun_if_address_is_in_use, - spawn, -) -from tests.kit.model_zoo import model_zoo -from tests.test_checkpoint_io.utils import shared_tempdir - - -@clear_cache_before_run() -def check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type): - model = model_fn() - lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1) - - plugin = TorchDDPPlugin() - booster = Booster(plugin=plugin) - - model = booster.enable_lora(model, lora_config=lora_config) - model_copy = copy.deepcopy(model) - - optimizer = AdamW(model.parameters(), lr=0.001) - criterion = loss_fn - - model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) - - data = data_gen_fn() - data = {k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()} - - output = model(**data) - output = output_transform_fn(output) - loss = criterion(output) - - booster.backward(loss, optimizer) - optimizer.clip_grad_by_norm(1.0) - optimizer.step() - - for (n1, p1), (n2, p2) in zip(model.named_parameters(), model_copy.named_parameters()): - if "lora_" in n1: - # lora modules require gradients, thus updated - assert p1.requires_grad - assert_not_equal(p1.to(p2.device), p2) - else: - if not p1.requires_grad: - assert_equal(p1.to(p2.device), p2) - - -@clear_cache_before_run() -def check_checkpoint(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type): - plugin = TorchDDPPlugin() - lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1) - - model_save = model_fn() - model_load = copy.deepcopy(model_save) - - booster = Booster(plugin=plugin) - model_save = booster.enable_lora(model_save, lora_config=lora_config) - model_save, _, _, _, _ = booster.boost(model_save) - - with shared_tempdir() as tempdir: - lora_ckpt_path = os.path.join(tempdir, "ckpt") - booster.save_lora_as_pretrained(model_save, lora_ckpt_path) - dist.barrier() - - # The Lora checkpoint should be small in size - checkpoint_size_mb = os.path.getsize(os.path.join(lora_ckpt_path, "adapter_model.bin")) / (1024 * 1024) - assert checkpoint_size_mb < 1 - - model_load = booster.enable_lora(model_load, pretrained_dir=lora_ckpt_path) - model_load, _, _, _, _ = booster.boost(model_load) - - check_state_dict_equal(model_save.state_dict(), model_load.state_dict()) - - -def run_lora_test(): - sub_model_zoo = model_zoo.get_sub_registry("transformers_llama") - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - task_type = None - if name == "transformers_llama_for_casual_lm": - task_type = "CAUSAL_LM" - if name == "transformers_llama_for_sequence_classification": - task_type = "SEQ_CLS" - check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type) - check_checkpoint(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type) - - -def run_dist(rank, world_size, port): - config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - run_lora_test() - - -@rerun_if_address_is_in_use() -def test_torch_ddp_lora(): - spawn(run_dist, 2)