Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-06-27 15:57:16 +00:00

[Feature] qlora support (#5586)

* [feature] qlora support
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* qlora follow commit
* migrate quantization folder to colossalai/
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* minor fixes

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

This commit is contained in:
parent cabc1286ca
commit 52a2dded36

Changed files include LICENSE (17 lines changed).
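The core of this change is a new `quantize` flag and a `bnb_quantization_config` argument on `Booster.enable_lora`, with the TorchDDP and LowLevelZero plugins quantizing the base model via bitsandbytes before wrapping it with PEFT LoRA modules. The sketch below shows one way the new API could be used; the checkpoint name, LoRA hyperparameters, plugin choice, and launch call are illustrative assumptions rather than part of this commit.

import torch
from peft import LoraConfig
from transformers import AutoModelForCausalLM

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

# Distributed init; exact launch arguments depend on your ColossalAI version.
colossalai.launch_from_torch(config={})

model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b")  # hypothetical checkpoint
lora_config = LoraConfig(task_type="CAUSAL_LM", r=8, lora_alpha=32, lora_dropout=0.1)

booster = Booster(plugin=TorchDDPPlugin())
# quantize=True with no bnb_quantization_config falls back to the NF4 defaults
# constructed inside Booster.enable_lora (4-bit, bfloat16 compute, double quantization).
model = booster.enable_lora(model, lora_config=lora_config, quantize=True)

# Only trainable (LoRA) parameters go to the optimizer.
optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)
# LoRA/quantization must be enabled before boosting the model.
model, optimizer, *_ = booster.boost(model, optimizer)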
@@ -527,3 +527,20 @@ Copyright 2021- HPC-AI Technology Inc. All rights reserved.
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

---------------- LICENSE FOR Hugging Face accelerate ----------------

Copyright 2021 The HuggingFace Team

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
@ -76,9 +76,11 @@ def main(args):
|
||||
if args.strategy == "ddp":
|
||||
strategy = DDPStrategy()
|
||||
elif args.strategy == "colossalai_gemini":
|
||||
strategy = GeminiStrategy(placement_policy="static",initial_scale=2**5)
|
||||
strategy = GeminiStrategy(placement_policy="static", initial_scale=2**5)
|
||||
elif args.strategy == "colossalai_gemini_cpu":
|
||||
strategy = GeminiStrategy(placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5)
|
||||
strategy = GeminiStrategy(
|
||||
placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5
|
||||
)
|
||||
elif args.strategy == "colossalai_zero2":
|
||||
strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
|
||||
elif args.strategy == "colossalai_zero2_cpu":
|
||||
|
@ -30,4 +30,3 @@ class Actor(LoRAModule):
|
||||
"""Returns model output."""
|
||||
output = self.model(input_ids, attention_mask=attention_mask, **model_kwargs)
|
||||
return output
|
||||
|
||||
|
@ -75,7 +75,9 @@ def get_strategy_from_args(strategy: str):
|
||||
elif strategy == "colossalai_zero2":
|
||||
strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
|
||||
elif strategy == "colossalai_gemini_cpu":
|
||||
strategy_ = GeminiStrategy(placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5)
|
||||
strategy_ = GeminiStrategy(
|
||||
placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5
|
||||
)
|
||||
elif strategy == "colossalai_zero2_cpu":
|
||||
strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
|
||||
else:
|
||||
|
@ -101,16 +101,17 @@ class DDPStrategy(Strategy):
|
||||
|
||||
model_path = os.path.join(path, "pytorch_model.bin")
|
||||
self.save_model(model, model_path, shard=shard)
|
||||
|
||||
def _replace_keys(model_path: str, replace_fn: Callable):
|
||||
state_dict = torch.load(model_path, map_location="cpu")
|
||||
state_dict = {replace_fn(k): v for k, v in state_dict.items()}
|
||||
torch.save(state_dict, model_path)
|
||||
|
||||
# FIXME: save_model would add "model." prefix to keys of pytorch_model.bin
|
||||
# HACK: rename keys of pytorch_model.bin
|
||||
if dist.get_rank() == 0:
|
||||
_replace_keys(model_path, lambda k: k.replace("model.", "", 1))
|
||||
|
||||
|
||||
def get_model_state_dict_shard(self, model: nn.Module, **config):
|
||||
# TODO: implement sharding on naive strategy
|
||||
model = self.unwrap_model(model)
|
||||
|
@ -24,7 +24,9 @@ def main(args):
|
||||
if args.strategy == "ddp":
|
||||
strategy = DDPStrategy()
|
||||
elif args.strategy == "colossalai_gemini":
|
||||
strategy = GeminiStrategy(placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5)
|
||||
strategy = GeminiStrategy(
|
||||
placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5
|
||||
)
|
||||
elif args.strategy == "colossalai_zero2":
|
||||
strategy = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
|
||||
else:
|
||||
|
@ -130,8 +130,8 @@ from modelscope import AutoModelForCausalLM, AutoTokenizer, snapshot_download
|
||||
model_dir = snapshot_download('colossalai/Colossal-LLaMA-2-7b-base', revision='v1.0.1')
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_dir, device_map="auto", trust_remote_code=True)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_dir, device_map="auto", trust_remote_code=True).eval()
|
||||
generation_kwargs = {"max_new_tokens": 256,
|
||||
"top_p": 0.95,
|
||||
generation_kwargs = {"max_new_tokens": 256,
|
||||
"top_p": 0.95,
|
||||
"temperature": 0.3
|
||||
}
|
||||
input = '离离原上草,'
|
||||
|
@ -1,20 +1,20 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import numpy as np
|
||||
import os
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Union, Sequence, Optional, Iterator, Callable
|
||||
from typing import Callable, Dict, Iterator, List, Optional, Sequence, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import dataset_dict, load_from_disk
|
||||
import torch.nn.functional as F
|
||||
from datasets import Dataset as HFDataset
|
||||
from datasets import dataset_dict, load_from_disk
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.distributed.distributed_c10d import _get_default_group
|
||||
from torch.utils.data import ConcatDataset, Dataset, DataLoader, DistributedSampler
|
||||
from torch.utils.data import ConcatDataset, DataLoader, Dataset, DistributedSampler
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
import torch.nn.functional as F
|
||||
|
||||
DatasetType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
|
||||
PathType = Union[str, os.PathLike]
|
||||
|
@ -7,9 +7,9 @@ Splicing multiple pre-tokenized sequence data points
|
||||
import random
|
||||
import warnings
|
||||
from copy import deepcopy
|
||||
from datasets import dataset_dict
|
||||
from typing import Any, Callable, Dict, Iterable, List, Union, Tuple
|
||||
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
|
||||
|
||||
from datasets import dataset_dict
|
||||
from torch.utils.data import ConcatDataset, Dataset, IterableDataset
|
||||
from transformers.models.llama.tokenization_llama import LlamaTokenizer
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
@ -169,12 +169,7 @@ class ClosedToConstantLengthSplicedDataset(IterableDataset):
|
||||
spliced_labels.extend(seq_labels)
|
||||
# For residual spliced data point at the end of the data set
|
||||
if self.infinite is False and more_data_points is False and len(spliced_input_ids) > 0:
|
||||
examples.append(
|
||||
{
|
||||
self.input_ids_field: spliced_input_ids,
|
||||
self.labels_field: spliced_labels
|
||||
}
|
||||
)
|
||||
examples.append({self.input_ids_field: spliced_input_ids, self.labels_field: spliced_labels})
|
||||
if self.shuffle:
|
||||
random.shuffle(examples)
|
||||
for spliced_data_point in examples:
|
||||
|
@ -8,11 +8,10 @@ import argparse
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import LlamaTokenizer, LlamaForCausalLM
|
||||
from transformers import LlamaForCausalLM, LlamaTokenizer
|
||||
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
|
||||
logger = get_dist_logger()
|
||||
|
||||
|
||||
|
@ -6,12 +6,12 @@ Initialize new tokenizer for continual pre-training
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import json
|
||||
import os
|
||||
from typing import List, Union
|
||||
|
||||
from transformers.models.llama.tokenization_llama import LlamaTokenizer
|
||||
from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model
|
||||
from transformers.models.llama.tokenization_llama import LlamaTokenizer
|
||||
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
|
@ -10,8 +10,8 @@ import os
|
||||
from typing import Any, Dict, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch.optim.optimizer import Optimizer
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
from torch.optim.optimizer import Optimizer
|
||||
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.cluster import DistCoordinator
|
||||
|
@@ -242,4 +242,4 @@ To comprehensively assess the performance of the Colossal-LLaMA-2-7B-base model,
## Conclusion
In general, the Colossal-LLaMA-2-7B-base model not only enhances its understanding of English but also exhibits significant improvements in its comprehension of Chinese. It boasts a broad spectrum of general knowledge, encompassing various fields such as food, sports, technology, literature, games, and more. Regarding text generation tasks, the Colossal-LLaMA-2-7B-base model excels in writing performance; however, its ability to generate specific formats like code, emails, tables, etc., needs enhancement due to the scarcity of relevant training data during our training phase. When compared to the Qwen-7b-base model, the Colossal-LLaMA-2-7B-base model outperforms it in answering most English questions and some Chinese questions, as demonstrated in the examples above.

Presently, the Colossal-LLaMA-2-7B-base model already exhibits some capabilities in sentiment analysis, logical reasoning, information extraction, role-play, classification, and rewriting. These capabilities are poised for further improvement in the future as part of our ongoing enhancements.
Presently, the Colossal-LLaMA-2-7B-base model already exhibits some capabilities in sentiment analysis, logical reasoning, information extraction, role-play, classification, and rewriting. These capabilities are poised for further improvement in the future as part of our ongoing enhancements.
@@ -1,2 +1,2 @@
hostname1
hostname2
hostname2
@ -11,14 +11,14 @@ import os
|
||||
import time
|
||||
from multiprocessing import cpu_count
|
||||
|
||||
from colossal_llama2.dataset.spliced_and_tokenized_dataset import (
|
||||
ClosedToConstantLengthSplicedDataset,
|
||||
supervised_tokenize,
|
||||
)
|
||||
from datasets import dataset_dict, load_dataset
|
||||
from transformers.models.llama.tokenization_llama import LlamaTokenizer
|
||||
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossal_llama2.dataset.spliced_and_tokenized_dataset import (
|
||||
supervised_tokenize,
|
||||
ClosedToConstantLengthSplicedDataset,
|
||||
)
|
||||
|
||||
logger = get_dist_logger()
|
||||
|
||||
@ -149,5 +149,5 @@ def main():
|
||||
spliced_dataset.save_to_disk(dataset_path=output_arrow_path, num_proc=min(len(spliced_dataset), cpu_count()))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
@ -12,4 +12,3 @@ flash-attn>=2.0.0,<=2.0.5
|
||||
tqdm
|
||||
sentencepiece==0.1.99
|
||||
protobuf<=3.20.0
|
||||
|
||||
|
@ -1,45 +1,39 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Continual Pre-training of LLaMA-2 developed by Colossal-AI Team
|
||||
Continual Pre-training of LLaMA-2 developed by Colossal-AI Team
|
||||
"""
|
||||
|
||||
import json
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import resource
|
||||
from contextlib import nullcontext
|
||||
from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from colossal_llama2.dataset.loader import (
|
||||
DataCollatorForSupervisedDataset,
|
||||
StatefulDistributedSampler,
|
||||
load_tokenized_dataset,
|
||||
setup_distributed_dataloader,
|
||||
)
|
||||
from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
|
||||
from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
|
||||
from colossal_llama2.utils.froze import freeze_non_embeds_parameters
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from transformers import LlamaTokenizer, LlamaForCausalLM, LlamaConfig
|
||||
from tqdm import tqdm
|
||||
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import (
|
||||
GeminiPlugin,
|
||||
LowLevelZeroPlugin,
|
||||
HybridParallelPlugin,
|
||||
)
|
||||
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
from colossal_llama2.dataset.loader import (
|
||||
load_tokenized_dataset,
|
||||
setup_distributed_dataloader,
|
||||
DataCollatorForSupervisedDataset,
|
||||
StatefulDistributedSampler,
|
||||
)
|
||||
|
||||
from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
|
||||
from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
|
||||
from colossal_llama2.utils.froze import freeze_non_embeds_parameters
|
||||
|
||||
|
||||
def get_model_numel(model: torch.nn.Module) -> int:
|
||||
return sum(p.numel() for p in model.parameters())
|
||||
@ -372,9 +366,7 @@ def main() -> None:
|
||||
# Final save.
|
||||
coordinator.print_on_master("Start saving final model checkpoint")
|
||||
booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
|
||||
coordinator.print_on_master(
|
||||
f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}"
|
||||
)
|
||||
coordinator.print_on_master(f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}")
|
||||
|
||||
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
|
||||
|
||||
|
@ -1 +1 @@
|
||||
0.0.1
|
||||
0.0.1
|
||||
|
@@ -19,6 +19,7 @@ except ImportError:
import colossalai.interface.pretrained as pretrained_utils
from colossalai.checkpoint_io import GeneralCheckpointIO
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.quantization import BnbQuantizationConfig

from .accelerator import Accelerator
from .mixed_precision import MixedPrecision, mixed_precision_factory
@@ -230,7 +231,12 @@ class Booster:
        return self.plugin.no_sync(model, optimizer)

    def enable_lora(
        self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: "peft.LoraConfig" = None
        self,
        model: nn.Module,
        pretrained_dir: Optional[str] = None,
        lora_config: "peft.LoraConfig" = None,
        bnb_quantization_config: Optional[BnbQuantizationConfig] = None,
        quantize=False,
    ) -> nn.Module:
        """
        Wrap the passed in model with LoRA modules for training. If pretrained directory is provided, lora configs and weights are loaded from that directory.
@@ -259,7 +265,20 @@ class Booster:
            assert (
                pretrained_dir is not None
            ), "Please provide pretrained directory path if not passing in lora configuration."
        return self.plugin.enable_lora(model, pretrained_dir, lora_config)
        if quantize is True:
            if bnb_quantization_config is not None:
                warnings.warn(
                    "User defined BnbQuantizationConfig is not fully tested in ColossalAI. Use it at your own risk."
                )
            else:
                bnb_quantization_config = BnbQuantizationConfig(
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=torch.bfloat16,
                    bnb_4bit_use_double_quant=True,
                    bnb_4bit_quant_type="nf4",
                )

        return self.plugin.enable_lora(model, pretrained_dir, lora_config, bnb_quantization_config)

    def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True) -> None:
        """Load model from checkpoint.
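When `quantize=True` and no config is supplied, `enable_lora` falls back to the NF4 defaults constructed above; a caller may also pass an explicit `BnbQuantizationConfig`, in which case ColossalAI warns that user-defined configs are not fully tested. A minimal sketch that simply mirrors those defaults, reusing `booster`, `model`, and `lora_config` from the earlier example:

import torch
from colossalai.quantization import BnbQuantizationConfig

bnb_config = BnbQuantizationConfig(
    load_in_4bit=True,                      # store base weights in 4 bit
    bnb_4bit_compute_dtype=torch.bfloat16,  # run matmuls in bf16
    bnb_4bit_use_double_quant=True,         # also quantize the quantization constants
    bnb_4bit_quant_type="nf4",              # NormalFloat4 quantization
)
model = booster.enable_lora(
    model, lora_config=lora_config, bnb_quantization_config=bnb_config, quantize=True
)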
@@ -1,11 +1,11 @@
import logging
import warnings
import enum
import logging
import os
import warnings
from functools import partial
from pathlib import Path
from types import MethodType
from typing import Callable, Dict, Iterator, List, Optional, Tuple, Dict
from typing import Callable, Dict, Iterator, List, Optional, Tuple

import torch
import torch.nn as nn
@@ -27,6 +27,7 @@ from colossalai.checkpoint_io.utils import (
    sharded_optimizer_loading_epilogue,
)
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.quantization import BnbQuantizationConfig, quantize_model
from colossalai.utils import get_current_device
from colossalai.zero import LowLevelZeroOptimizer

@@ -44,6 +45,7 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):

SUPPORTED_PRECISION = ["fp16", "bf16", "fp32"]


class OptimizerParamCheckState(enum.Enum):
    ORIGIN_PARAM_FINDED = 0
    ORIGIN_PARAM_NOT_FIND = -1
@@ -221,6 +223,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
            logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
            return
        from peft import PeftModel

        assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
        peft_model = model.unwrap()
        assert isinstance(
@@ -331,38 +334,45 @@ class LowLevelZeroPlugin(DPPluginBase):
    def supported_devices(self) -> List[str]:
        return ["cuda"]

    def support_lora(self) -> bool:
        return True

    def enable_lora(
        self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
        self,
        model: nn.Module,
        pretrained_dir: Optional[str] = None,
        lora_config: Optional[Dict] = None,
        bnb_quantization_config: Optional[BnbQuantizationConfig] = None,
    ) -> nn.Module:
        from peft import PeftModel, get_peft_model

        assert not isinstance(model, LowLevelZeroModel), "Lora should be enabled before boosting the model."
        self.lora_enabled = True
        warnings.warn("You have enabled LoRa training. Please check the hyperparameters such as lr")

        if bnb_quantization_config is not None:
            model = quantize_model(model, bnb_quantization_config)

        if pretrained_dir is None:
            peft_model = get_peft_model(model, lora_config)
        else:
            peft_model = PeftModel.from_pretrained(model, pretrained_dir, is_trainable=True)
        return peft_model

    def get_param_group_id(self, optimizer: Optimizer, origin_param: Parameter):
        origin_param_id = id(origin_param)
        for group_id, param_group in enumerate(optimizer.param_groups):
            for p in param_group['params']:
            for p in param_group["params"]:
                if id(p) == origin_param_id:
                    return group_id
        return -1

    def get_param_group_id(self, optimizer: Optimizer, origin_param: Parameter, lora_param: Parameter):
        origin_param_id = id(origin_param)
        lora_param_id = id(lora_param)
        target_group_id = None
        for group_id, param_group in enumerate(optimizer.param_groups):
            for p in param_group['params']:
            for p in param_group["params"]:
                if id(p) == lora_param_id:
                    # check if the lora parameter exists.
                    return target_group_id, OptimizerParamCheckState.LORA_PARM_EXISTED
@@ -372,25 +382,31 @@ class LowLevelZeroPlugin(DPPluginBase):
            return target_group_id, OptimizerParamCheckState.ORIGIN_PARAM_FINDED
        else:
            return target_group_id, OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND

    def add_lora_params_to_optimizer(self, model, optimizer):
        """ add lora parameters to optimizer """
        name2param= {}
        """add lora parameters to optimizer"""
        name2param = {}
        for name, param in model.named_parameters():
            name2param[name] = param

        for name, param in name2param.items():
            if 'lora_A' in name or 'lora_B' in name:
            if "lora_A" in name or "lora_B" in name:
                origin_key = name.replace("lora_A.", "")
                origin_key = origin_key.replace("lora_B.", "")
                origin_key = origin_key.replace(f"{model.active_adapter}", "base_layer")
                origin_param = name2param[origin_key]
                group_id, check_state = self.get_param_group_id(optimizer, origin_param, param)
                if check_state == OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND:
                    warnings.warn("Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups.")
                elif check_state == OptimizerParamCheckState.ORIGIN_PARAM_FINDED and group_id is not None and group_id >= 0:
                    optimizer.param_groups[group_id]['params'].append(param)
                    warnings.warn(
                        "Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups."
                    )
                elif (
                    check_state == OptimizerParamCheckState.ORIGIN_PARAM_FINDED
                    and group_id is not None
                    and group_id >= 0
                ):
                    optimizer.param_groups[group_id]["params"].append(param)

    def configure(
        self,
        model: nn.Module,
@@ -401,11 +417,13 @@ class LowLevelZeroPlugin(DPPluginBase):
    ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
        if self.lora_enabled:
            from peft import PeftModel
            assert isinstance(model, PeftModel), "The model should have been wrapped as a PeftModel when self.lora_enabled is True"

            assert isinstance(
                model, PeftModel
            ), "The model should have been wrapped as a PeftModel when self.lora_enabled is True"
            if optimizer is not None:
                self.add_lora_params_to_optimizer(model, optimizer)

        if not isinstance(model, ModelWrapper):
            model = LowLevelZeroModel(model, self.precision)
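The `add_lora_params_to_optimizer` helper above maps each PEFT LoRA weight name back to the frozen base weight it wraps, then appends the adapter parameter to that base weight's optimizer param group. An illustrative walk-through of the key mapping (not taken from the diff), assuming PEFT's default adapter name:

# Hypothetical parameter name produced by peft.get_peft_model on a LLaMA-style model.
lora_name = "base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight"
active_adapter = "default"  # model.active_adapter in the plugin code

origin_key = lora_name.replace("lora_A.", "").replace("lora_B.", "")
origin_key = origin_key.replace(active_adapter, "base_layer")
print(origin_key)
# base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight
# name2param[origin_key] is the frozen base weight already registered with the
# optimizer, so the LoRA parameter joins the same param group.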
@@ -9,6 +9,7 @@ from torch.utils.data import DataLoader
from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.quantization import BnbQuantizationConfig, quantize_model

from .dp_plugin_base import DPPluginBase

@@ -237,10 +238,17 @@ class TorchDDPPlugin(DPPluginBase):
        return model.module.no_sync()

    def enable_lora(
        self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
        self,
        model: nn.Module,
        pretrained_dir: Optional[str] = None,
        lora_config: Optional[Dict] = None,
        bnb_quantization_config: Optional[BnbQuantizationConfig] = None,
    ) -> nn.Module:
        from peft import PeftModel, get_peft_model

        if bnb_quantization_config is not None:
            model = quantize_model(model, bnb_quantization_config)

        assert not isinstance(model, TorchDDPModel), "Lora should be enabled before boosting the model."
        if pretrained_dir is None:
            return get_peft_model(model, lora_config)
@ -64,7 +64,7 @@ vllm
|
||||
flash-attention
|
||||
|
||||
# install lightllm since we depend on lightllm triton kernels
|
||||
git clone https://github.com/ModelTC/lightllm
|
||||
git clone https://github.com/ModelTC/lightllm
|
||||
git checkout 28c1267cfca536b7b4f28e921e03de735b003039
|
||||
cd lightllm
|
||||
pip3 install -e .
|
||||
@ -84,7 +84,7 @@ cd /path/to/CollossalAI
|
||||
pip install -e .
|
||||
|
||||
# install lightllm
|
||||
git clone https://github.com/ModelTC/lightllm
|
||||
git clone https://github.com/ModelTC/lightllm
|
||||
git checkout 28c1267cfca536b7b4f28e921e03de735b003039
|
||||
cd lightllm
|
||||
pip3 install -e .
|
||||
|
@ -18,15 +18,15 @@ from .gptq_op import CaiGPTQLinearOp
|
||||
HAS_GPTQ_CUDA = False
|
||||
try:
|
||||
from colossalai.kernel.op_builder.gptq import GPTQBuilder
|
||||
|
||||
gptq_cuda = GPTQBuilder().load()
|
||||
HAS_GPTQ_CUDA = True
|
||||
except ImportError:
|
||||
warnings.warn('CUDA gptq is not installed')
|
||||
warnings.warn("CUDA gptq is not installed")
|
||||
HAS_GPTQ_CUDA = False
|
||||
|
||||
|
||||
class CaiQuantLinear(nn.Module):
|
||||
|
||||
def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False):
|
||||
super().__init__()
|
||||
if bits not in [2, 4, 8]:
|
||||
@ -37,23 +37,28 @@ class CaiQuantLinear(nn.Module):
|
||||
self.maxq = 2**self.bits - 1
|
||||
self.groupsize = groupsize if groupsize != -1 else infeatures
|
||||
|
||||
self.register_buffer('qweight', torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32))
|
||||
self.register_buffer("qweight", torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32))
|
||||
self.register_buffer(
|
||||
'qzeros',
|
||||
torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32))
|
||||
self.register_buffer('scales',
|
||||
torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16))
|
||||
"qzeros",
|
||||
torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32),
|
||||
)
|
||||
self.register_buffer(
|
||||
"scales", torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16)
|
||||
)
|
||||
if row_split:
|
||||
self.register_buffer(
|
||||
'g_idx',
|
||||
torch.tensor([(i + (tp_rank * self.infeatures)) // self.groupsize for i in range(infeatures)],
|
||||
dtype=torch.int32))
|
||||
"g_idx",
|
||||
torch.tensor(
|
||||
[(i + (tp_rank * self.infeatures)) // self.groupsize for i in range(infeatures)], dtype=torch.int32
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.register_buffer('g_idx',
|
||||
torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32))
|
||||
self.register_buffer(
|
||||
"g_idx", torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32)
|
||||
)
|
||||
|
||||
if bias:
|
||||
self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
|
||||
self.register_buffer("bias", torch.zeros((outfeatures), dtype=torch.float16))
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
@ -66,9 +71,11 @@ class CaiQuantLinear(nn.Module):
|
||||
self.row_split = row_split
|
||||
|
||||
def pack(self, linear, scales, zeros, g_idx=None):
|
||||
|
||||
g_idx = g_idx.clone() if g_idx is not None else torch.tensor(
|
||||
[i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32)
|
||||
g_idx = (
|
||||
g_idx.clone()
|
||||
if g_idx is not None
|
||||
else torch.tensor([i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32)
|
||||
)
|
||||
|
||||
scales = scales.t().contiguous()
|
||||
zeros = zeros.t().contiguous()
|
||||
@ -79,7 +86,6 @@ class CaiQuantLinear(nn.Module):
|
||||
if linear.bias is not None:
|
||||
self.bias = linear.bias.clone().half()
|
||||
|
||||
wn = 8
|
||||
pbits = 32
|
||||
ptype = torch.int32
|
||||
unsign_type = np.uint32
|
||||
@ -88,9 +94,10 @@ class CaiQuantLinear(nn.Module):
|
||||
intweight = []
|
||||
for idx in range(self.infeatures):
|
||||
intweight.append(
|
||||
torch.round(
|
||||
(linear.weight.data[:, idx] + scale_zeros[g_idx[idx]]) / half_scales[g_idx[idx]]).to(ptype)[:,
|
||||
None])
|
||||
torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx[idx]]) / half_scales[g_idx[idx]]).to(ptype)[
|
||||
:, None
|
||||
]
|
||||
)
|
||||
intweight = torch.cat(intweight, dim=1)
|
||||
intweight = intweight.t().contiguous()
|
||||
intweight = intweight.numpy().astype(unsign_type)
|
||||
@ -109,7 +116,7 @@ class CaiQuantLinear(nn.Module):
|
||||
raise NotImplementedError("Only 2,4,8 bits are supported.")
|
||||
qweight = qweight.astype(sign_type)
|
||||
qweight1 = torch.from_numpy(qweight)
|
||||
qweight1 = qweight1.contiguous() #.to("cuda")
|
||||
qweight1 = qweight1.contiguous() # .to("cuda")
|
||||
self.qweight.data.copy_(qweight1)
|
||||
|
||||
qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // pbits * self.bits), dtype=unsign_type)
|
||||
@ -140,17 +147,20 @@ class CaiQuantLinear(nn.Module):
|
||||
self.q4_width = self.qweight.shape[1]
|
||||
if self.g_idx is not None:
|
||||
if self.row_split and torch.equal(
|
||||
self.g_idx,
|
||||
torch.tensor(
|
||||
[(i + (self.tp_rank * self.infeatures)) // self.groupsize for i in range(self.infeatures)],
|
||||
dtype=torch.int32,
|
||||
device=self.g_idx.device)):
|
||||
self.g_idx,
|
||||
torch.tensor(
|
||||
[(i + (self.tp_rank * self.infeatures)) // self.groupsize for i in range(self.infeatures)],
|
||||
dtype=torch.int32,
|
||||
device=self.g_idx.device,
|
||||
),
|
||||
):
|
||||
self.g_idx = None
|
||||
elif torch.equal(
|
||||
self.g_idx,
|
||||
torch.tensor([i // self.groupsize for i in range(self.infeatures)],
|
||||
dtype=torch.int32,
|
||||
device=self.g_idx.device)):
|
||||
self.g_idx,
|
||||
torch.tensor(
|
||||
[i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32, device=self.g_idx.device
|
||||
),
|
||||
):
|
||||
self.g_idx = None
|
||||
|
||||
if self.g_idx is not None:
|
||||
@ -165,7 +175,6 @@ class CaiQuantLinear(nn.Module):
|
||||
outshape = x.shape[:-1] + (self.outfeatures,)
|
||||
|
||||
if HAS_GPTQ_CUDA and self.bits == 4:
|
||||
|
||||
if self.q4 is None:
|
||||
self.init_q4()
|
||||
|
||||
@ -191,7 +200,6 @@ class CaiQuantLinear(nn.Module):
|
||||
|
||||
|
||||
def split_column_copy(gptq_linear, cai_linear, tp_size=1, tp_rank=0, split_num=1):
|
||||
|
||||
qweights = gptq_linear.qweight.split(gptq_linear.out_features // split_num, dim=-1)
|
||||
qzeros = gptq_linear.qzeros.split(gptq_linear.out_features // (32 // cai_linear.bits) // split_num, dim=-1)
|
||||
scales = gptq_linear.scales.split(gptq_linear.out_features // split_num, dim=-1)
|
||||
@ -203,24 +211,24 @@ def split_column_copy(gptq_linear, cai_linear, tp_size=1, tp_rank=0, split_num=1
|
||||
zero_split_block = cai_linear.outfeatures // (32 // cai_linear.bits) // split_num
|
||||
|
||||
for i in range(split_num):
|
||||
cai_linear.qweight[:, i * cai_split_out_features:(i + 1) *
|
||||
cai_split_out_features] = qweights[i][:, tp_rank * cai_split_out_features:(tp_rank + 1) *
|
||||
cai_split_out_features]
|
||||
cai_linear.qzeros[:, i * zero_split_block:(i + 1) *
|
||||
zero_split_block] = qzeros[i][:, tp_rank * zero_split_block:(tp_rank + 1) * zero_split_block]
|
||||
cai_linear.scales[:, i * cai_split_out_features:(i + 1) *
|
||||
cai_split_out_features] = scales[i][:, tp_rank * cai_split_out_features:(tp_rank + 1) *
|
||||
cai_split_out_features]
|
||||
cai_linear.qweight[:, i * cai_split_out_features : (i + 1) * cai_split_out_features] = qweights[i][
|
||||
:, tp_rank * cai_split_out_features : (tp_rank + 1) * cai_split_out_features
|
||||
]
|
||||
cai_linear.qzeros[:, i * zero_split_block : (i + 1) * zero_split_block] = qzeros[i][
|
||||
:, tp_rank * zero_split_block : (tp_rank + 1) * zero_split_block
|
||||
]
|
||||
cai_linear.scales[:, i * cai_split_out_features : (i + 1) * cai_split_out_features] = scales[i][
|
||||
:, tp_rank * cai_split_out_features : (tp_rank + 1) * cai_split_out_features
|
||||
]
|
||||
if cai_linear.bias is not None:
|
||||
cai_linear.bias[i * cai_split_out_features:(i + 1) *
|
||||
cai_split_out_features] = bias[i][tp_rank * cai_split_out_features:(tp_rank + 1) *
|
||||
cai_split_out_features]
|
||||
cai_linear.bias[i * cai_split_out_features : (i + 1) * cai_split_out_features] = bias[i][
|
||||
tp_rank * cai_split_out_features : (tp_rank + 1) * cai_split_out_features
|
||||
]
|
||||
|
||||
cai_linear.g_idx.copy_(g_idx)
|
||||
|
||||
|
||||
def split_row_copy(gptq_linear, cai_linear, tp_rank=0, split_num=1):
|
||||
|
||||
qweights = gptq_linear.qweight.split(gptq_linear.in_features // split_num, dim=0)
|
||||
qzeros = gptq_linear.qzeros.split(gptq_linear.in_features // split_num, dim=0)
|
||||
scales = gptq_linear.scales.split(gptq_linear.in_features // split_num, dim=0)
|
||||
@ -231,47 +239,40 @@ def split_row_copy(gptq_linear, cai_linear, tp_rank=0, split_num=1):
|
||||
idx_split_features = cai_linear.infeatures // split_num
|
||||
|
||||
for i in range(split_num):
|
||||
cai_linear.qweight[i * cai_split_in_features:(i + 1) *
|
||||
cai_split_in_features, :] = qweights[i][tp_rank * cai_split_in_features:(tp_rank + 1) *
|
||||
cai_split_in_features, :]
|
||||
cai_linear.qzeros[i * zero_split_block:(i + 1) *
|
||||
zero_split_block, :] = qzeros[i][tp_rank * zero_split_block:(tp_rank + 1) *
|
||||
zero_split_block, :]
|
||||
cai_linear.scales[i * zero_split_block:(i + 1) *
|
||||
zero_split_block, :] = scales[i][tp_rank * zero_split_block:(tp_rank + 1) *
|
||||
zero_split_block, :]
|
||||
cai_linear.g_idx[i * idx_split_features:(i + 1) *
|
||||
idx_split_features] = g_idxs[i][tp_rank * idx_split_features:(tp_rank + 1) *
|
||||
idx_split_features]
|
||||
cai_linear.qweight[i * cai_split_in_features : (i + 1) * cai_split_in_features, :] = qweights[i][
|
||||
tp_rank * cai_split_in_features : (tp_rank + 1) * cai_split_in_features, :
|
||||
]
|
||||
cai_linear.qzeros[i * zero_split_block : (i + 1) * zero_split_block, :] = qzeros[i][
|
||||
tp_rank * zero_split_block : (tp_rank + 1) * zero_split_block, :
|
||||
]
|
||||
cai_linear.scales[i * zero_split_block : (i + 1) * zero_split_block, :] = scales[i][
|
||||
tp_rank * zero_split_block : (tp_rank + 1) * zero_split_block, :
|
||||
]
|
||||
cai_linear.g_idx[i * idx_split_features : (i + 1) * idx_split_features] = g_idxs[i][
|
||||
tp_rank * idx_split_features : (tp_rank + 1) * idx_split_features
|
||||
]
|
||||
if cai_linear.bias is not None:
|
||||
cai_linear.bias.copy_(gptq_linear.bias)
|
||||
|
||||
|
||||
class RowCaiQuantLinear(CaiQuantLinear, ParallelModule):
|
||||
|
||||
def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False):
|
||||
|
||||
super().__init__(bits,
|
||||
groupsize,
|
||||
infeatures,
|
||||
outfeatures,
|
||||
bias,
|
||||
tp_size=tp_size,
|
||||
tp_rank=tp_rank,
|
||||
row_split=row_split)
|
||||
super().__init__(
|
||||
bits, groupsize, infeatures, outfeatures, bias, tp_size=tp_size, tp_rank=tp_rank, row_split=row_split
|
||||
)
|
||||
self.process_group = None
|
||||
|
||||
@staticmethod
|
||||
def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
|
||||
**kwargs) -> ParallelModule:
|
||||
def from_native_module(
|
||||
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
|
||||
) -> ParallelModule:
|
||||
LazyInitContext.materialize(module)
|
||||
# get the attributes
|
||||
in_features = module.in_features
|
||||
|
||||
# ensure only one process group is passed
|
||||
if isinstance(process_group, (list, tuple)):
|
||||
assert len(process_group) == 1, \
|
||||
f'Expected only one process group, got {len(process_group)}.'
|
||||
assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
|
||||
process_group = process_group[0]
|
||||
|
||||
tp_size = dist.get_world_size(process_group)
|
||||
@ -282,15 +283,18 @@ class RowCaiQuantLinear(CaiQuantLinear, ParallelModule):
|
||||
|
||||
if in_features % tp_size != 0:
|
||||
raise ValueError(
|
||||
f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!")
|
||||
linear_1d = RowCaiQuantLinear(module.bits,
|
||||
module.group_size,
|
||||
module.in_features // tp_size,
|
||||
module.out_features,
|
||||
module.bias is not None,
|
||||
tp_size=tp_size,
|
||||
tp_rank=tp_rank,
|
||||
row_split=True)
|
||||
f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!"
|
||||
)
|
||||
linear_1d = RowCaiQuantLinear(
|
||||
module.bits,
|
||||
module.group_size,
|
||||
module.in_features // tp_size,
|
||||
module.out_features,
|
||||
module.bias is not None,
|
||||
tp_size=tp_size,
|
||||
tp_rank=tp_rank,
|
||||
row_split=True,
|
||||
)
|
||||
linear_1d.process_group = process_group
|
||||
|
||||
split_row_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
|
||||
@ -306,30 +310,23 @@ class RowCaiQuantLinear(CaiQuantLinear, ParallelModule):
|
||||
|
||||
|
||||
class ColCaiQuantLinear(CaiQuantLinear, ParallelModule):
|
||||
|
||||
def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False):
|
||||
|
||||
super().__init__(bits,
|
||||
groupsize,
|
||||
infeatures,
|
||||
outfeatures,
|
||||
bias,
|
||||
tp_size=tp_size,
|
||||
tp_rank=tp_rank,
|
||||
row_split=row_split)
|
||||
super().__init__(
|
||||
bits, groupsize, infeatures, outfeatures, bias, tp_size=tp_size, tp_rank=tp_rank, row_split=row_split
|
||||
)
|
||||
self.process_group = None
|
||||
|
||||
@staticmethod
|
||||
def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
|
||||
**kwargs) -> ParallelModule:
|
||||
def from_native_module(
|
||||
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
|
||||
) -> ParallelModule:
|
||||
LazyInitContext.materialize(module)
|
||||
# get the attributes
|
||||
in_features = module.in_features
|
||||
|
||||
# ensure only one process group is passed
|
||||
if isinstance(process_group, (list, tuple)):
|
||||
assert len(process_group) == 1, \
|
||||
f'Expected only one process group, got {len(process_group)}.'
|
||||
assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
|
||||
process_group = process_group[0]
|
||||
|
||||
tp_size = dist.get_world_size(process_group)
|
||||
@ -340,14 +337,17 @@ class ColCaiQuantLinear(CaiQuantLinear, ParallelModule):
|
||||
|
||||
if in_features % tp_size != 0:
|
||||
raise ValueError(
|
||||
f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!")
|
||||
linear_1d = ColCaiQuantLinear(module.bits,
|
||||
module.group_size,
|
||||
module.in_features,
|
||||
module.out_features // tp_size,
|
||||
module.bias is not None,
|
||||
tp_size=tp_size,
|
||||
tp_rank=tp_rank)
|
||||
f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!"
|
||||
)
|
||||
linear_1d = ColCaiQuantLinear(
|
||||
module.bits,
|
||||
module.group_size,
|
||||
module.in_features,
|
||||
module.out_features // tp_size,
|
||||
module.bias is not None,
|
||||
tp_size=tp_size,
|
||||
tp_rank=tp_rank,
|
||||
)
|
||||
linear_1d.process_group = process_group
|
||||
|
||||
split_column_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
|
||||
|
@ -5,6 +5,7 @@ import torch
|
||||
|
||||
from .kvcache_manager import MemoryManager
|
||||
|
||||
|
||||
# adapted from: lightllm/server/router/model_infer/infer_batch.py
|
||||
@dataclass
|
||||
class BatchInferState:
|
||||
|
@ -19,8 +19,11 @@ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
|
||||
from ._utils import copy_kv_to_mem_cache
|
||||
|
||||
try:
|
||||
from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import context_attention_fwd as lightllm_llama2_context_attention_fwd
|
||||
from lightllm.models.chatglm2.triton_kernel.rotary_emb import rotary_emb_fwd as chatglm2_rotary_emb_fwd
|
||||
from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import (
|
||||
context_attention_fwd as lightllm_llama2_context_attention_fwd,
|
||||
)
|
||||
|
||||
HAS_LIGHTLLM_KERNEL = True
|
||||
except:
|
||||
print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm")
|
||||
|
@ -4,7 +4,6 @@ import torch
|
||||
from torch.nn import LayerNorm
|
||||
|
||||
import colossalai.shardformer.layer as col_nn
|
||||
from colossalai.shardformer.modeling.bloom import build_bloom_alibi_tensor_fn
|
||||
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
|
||||
from colossalai.shardformer.policies.bloom import BloomForCausalLMPolicy
|
||||
|
||||
@ -40,33 +39,36 @@ class BloomModelInferPolicy(BloomForCausalLMPolicy):
|
||||
policy = super().module_policy()
|
||||
if self.shard_config.inference_gptq:
|
||||
from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear
|
||||
policy[BloomBlock] = ModulePolicyDescription(attribute_replacement={
|
||||
"self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
|
||||
"self_attention.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
|
||||
"self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size,
|
||||
},
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attention.query_key_value",
|
||||
target_module=ColCaiQuantLinear,
|
||||
kwargs={'split_num': 3}),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attention.dense",
|
||||
target_module=RowCaiQuantLinear,
|
||||
kwargs={'split_num': 1}),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attention.attention_dropout",
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.dense_h_to_4h",
|
||||
target_module=ColCaiQuantLinear,
|
||||
kwargs={'split_num': 1}),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.dense_4h_to_h",
|
||||
target_module=RowCaiQuantLinear,
|
||||
kwargs={'split_num': 1}),
|
||||
])
|
||||
|
||||
policy[BloomBlock] = ModulePolicyDescription(
|
||||
attribute_replacement={
|
||||
"self_attention.hidden_size": self.model.config.hidden_size
|
||||
// self.shard_config.tensor_parallel_size,
|
||||
"self_attention.split_size": self.model.config.hidden_size
|
||||
// self.shard_config.tensor_parallel_size,
|
||||
"self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size,
|
||||
},
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attention.query_key_value",
|
||||
target_module=ColCaiQuantLinear,
|
||||
kwargs={"split_num": 3},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attention.dense", target_module=RowCaiQuantLinear, kwargs={"split_num": 1}
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attention.attention_dropout",
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.dense_h_to_4h", target_module=ColCaiQuantLinear, kwargs={"split_num": 1}
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.dense_4h_to_h", target_module=RowCaiQuantLinear, kwargs={"split_num": 1}
|
||||
),
|
||||
],
|
||||
)
|
||||
# NOTE set inference mode to shard config
|
||||
self.shard_config._infer()
|
||||
|
||||
|
@ -13,6 +13,7 @@ from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forw
|
||||
|
||||
try:
|
||||
from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward as lightllm_rmsnorm_forward
|
||||
|
||||
HAS_TRITON_RMSNORM = True
|
||||
except:
|
||||
print("you should install triton from https://github.com/openai/triton")
|
||||
@ -21,6 +22,7 @@ except:
|
||||
|
||||
def get_triton_rmsnorm_forward():
|
||||
if HAS_TRITON_RMSNORM:
|
||||
|
||||
def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
|
||||
return lightllm_rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon)
|
||||
|
||||
|
@ -1,254 +1,202 @@
|
||||
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#include "util.cuh"
|
||||
#include "tuning.h"
|
||||
#include "cuda_buffers.cuh"
|
||||
#include "q4_matrix.cuh"
|
||||
#include "q4_matmul.cuh"
|
||||
|
||||
#include "column_remap.cuh"
|
||||
#include "cuda_buffers.cuh"
|
||||
#include "q4_matmul.cuh"
|
||||
#include "q4_matrix.cuh"
|
||||
#include "tuning.h"
|
||||
#include "util.cuh"
|
||||
|
||||
// Check CUDA return code. We don't want to include Torch headers in the .cu files because parsing them adds almost a
|
||||
// minute to the compile time on a 12900K. Also passing exceptions back to Python is super tricky, so in place of
|
||||
// exceptions, CUDA functions return with a cudaError_t which we can parse and dump to the console.
|
||||
// Check CUDA return code. We don't want to include Torch headers in the .cu
|
||||
// files because parsing them adds almost a minute to the compile time on a
|
||||
// 12900K. Also passing exceptions back to Python is super tricky, so in place
|
||||
// of exceptions, CUDA functions return with a cudaError_t which we can parse
|
||||
// and dump to the console.
|
||||
|
||||
void check_cuda(cudaError_t ret)
|
||||
{
|
||||
switch (ret)
|
||||
{
|
||||
case cudaSuccess:
|
||||
break;
|
||||
void check_cuda(cudaError_t ret) {
|
||||
switch (ret) {
|
||||
case cudaSuccess:
|
||||
break;
|
||||
|
||||
case cudaUnspecified:
|
||||
printf(" **** Unspecified error\n");
|
||||
TORCH_CHECK(false, "CUDA error");
|
||||
break;
|
||||
case cudaUnspecified:
|
||||
printf(" **** Unspecified error\n");
|
||||
TORCH_CHECK(false, "CUDA error");
|
||||
break;
|
||||
|
||||
default:
|
||||
printf(" **** CUDA error\n"); \
|
||||
printf(" **** %s\n", cudaGetErrorString(ret)); \
|
||||
TORCH_CHECK(false, "CUDA error"); \
|
||||
break;
|
||||
}
|
||||
default:
|
||||
printf(" **** CUDA error\n");
|
||||
printf(" **** %s\n", cudaGetErrorString(ret));
|
||||
TORCH_CHECK(false, "CUDA error");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Some decluttering macros
|
||||
|
||||
#define STRINGIFY_(__x) #__x
|
||||
#define STRINGIFY(__x) STRINGIFY_(__x)
|
||||
#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
|
||||
#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
|
||||
#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
|
||||
#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
|
||||
#define TORCH_CHECK_SHAPE_MOD(__x, __dim_x, __mod) TORCH_CHECK((__x).size(__dim_x) % __mod == 0, #__x ".shape[" STRINGIFY(__dim_x) "] must be a multiple of " STRINGIFY(__mod))
|
||||
#define TORCH_CHECK_BUFFER_SIZE(__buffer, __minimum_size) TORCH_CHECK((__buffer).numel() >= __minimum_size, #__buffer " is too small")
|
||||
#define TORCH_CHECK_DTYPE(__x, __dtype) \
|
||||
TORCH_CHECK((__x).dtype() == torch::__dtype, \
|
||||
#__x " is incorrect datatype, must be " #__dtype)
|
||||
#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) \
|
||||
TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, \
|
||||
#__x " is incorrect datatype, must be " #__dtype)
|
||||
#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) \
|
||||
TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, \
|
||||
#__x " and " #__y " have incompatible shapes")
|
||||
#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) \
|
||||
TORCH_CHECK((__x).device().is_meta() || \
|
||||
(__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, \
|
||||
#__x " and " #__y " have incompatible shapes")
|
||||
#define TORCH_CHECK_SHAPE_MOD(__x, __dim_x, __mod) \
|
||||
TORCH_CHECK((__x).size(__dim_x) % __mod == 0, \
|
||||
#__x ".shape[" STRINGIFY( \
|
||||
__dim_x) "] must be a multiple of " STRINGIFY(__mod))
|
||||
#define TORCH_CHECK_BUFFER_SIZE(__buffer, __minimum_size) \
|
||||
TORCH_CHECK((__buffer).numel() >= __minimum_size, #__buffer " is too small")
|
||||
|
||||
#define TORCH_CHECK_DEVICE_INDEX(__index) \
|
||||
do { \
|
||||
TORCH_CHECK(__index >= 0, "no device index"); \
|
||||
#define TORCH_CHECK_DEVICE_INDEX(__index) \
|
||||
do { \
|
||||
TORCH_CHECK(__index >= 0, "no device index"); \
|
||||
TORCH_CHECK(__index < CUDA_MAX_DEVICES, "invalid device index"); \
|
||||
} while(0)
|
||||
} while (0)
|
||||
|
||||
#define TORCH_CHECK_QUANT(__w, __w_scales, __w_zeros, __seq_g_idx, __x_map) \
|
||||
do { \
|
||||
TORCH_CHECK_DTYPE(__w, kInt); \
|
||||
TORCH_CHECK_DTYPE(__w_scales, kHalf); \
|
||||
TORCH_CHECK_DTYPE(__w_zeros, kInt); \
|
||||
TORCH_CHECK_DTYPE_OPT(__seq_g_idx, kShort); \
|
||||
TORCH_CHECK_DTYPE_OPT(__x_map, kInt); \
|
||||
TORCH_CHECK_SHAPES_OPT(__seq_g_idx, 0, __w, 0, 2 * 8); \
|
||||
TORCH_CHECK_SHAPES_OPT(__x_map, 0, __w, 0, 8); \
|
||||
} while(0)
|
||||
do { \
|
||||
TORCH_CHECK_DTYPE(__w, kInt); \
|
||||
TORCH_CHECK_DTYPE(__w_scales, kHalf); \
|
||||
TORCH_CHECK_DTYPE(__w_zeros, kInt); \
|
||||
TORCH_CHECK_DTYPE_OPT(__seq_g_idx, kShort); \
|
||||
TORCH_CHECK_DTYPE_OPT(__x_map, kInt); \
|
||||
TORCH_CHECK_SHAPES_OPT(__seq_g_idx, 0, __w, 0, 2 * 8); \
|
||||
TORCH_CHECK_SHAPES_OPT(__x_map, 0, __w, 0, 8); \
|
||||
} while (0)
|
||||
|
||||
int get_groupsize(torch::Tensor w, torch::Tensor w_zeros)
|
||||
{
|
||||
int groupsize = w.size(0) * 8 / w_zeros.size(0);
|
||||
TORCH_CHECK(groupsize * w_zeros.size(0) == w.size(0) * 8, "w.shape[-2] must be a multiple of zeros.shape[-2]")
|
||||
return groupsize;
|
||||
int get_groupsize(torch::Tensor w, torch::Tensor w_zeros) {
|
||||
int groupsize = w.size(0) * 8 / w_zeros.size(0);
|
||||
TORCH_CHECK(groupsize * w_zeros.size(0) == w.size(0) * 8,
|
||||
"w.shape[-2] must be a multiple of zeros.shape[-2]")
|
||||
return groupsize;
|
||||
}
|
||||
|
||||
|
||||
// Tuning parameters
|
||||
|
||||
ExLlamaTuning tuningParams;
|
||||
|
||||
void set_tuning_params
|
||||
(
|
||||
int matmul_recons_thd,
|
||||
bool matmul_fused_remap,
|
||||
bool matmul_no_half2
|
||||
)
|
||||
{
|
||||
tuningParams.matmul_recons_thd = matmul_recons_thd;
|
||||
tuningParams.matmul_fused_remap = matmul_fused_remap;
|
||||
tuningParams.matmul_no_half2 = matmul_no_half2;
|
||||
void set_tuning_params(int matmul_recons_thd, bool matmul_fused_remap,
|
||||
bool matmul_no_half2) {
|
||||
tuningParams.matmul_recons_thd = matmul_recons_thd;
|
||||
tuningParams.matmul_fused_remap = matmul_fused_remap;
|
||||
tuningParams.matmul_no_half2 = matmul_no_half2;
|
||||
}
|
||||
|
||||
|
||||
// Release all unmanaged objects allocated by the extension
|
||||
|
||||
void cleanup()
|
||||
{
|
||||
cleanup_buffers_cuda();
|
||||
g_q4_free_matrices();
|
||||
void cleanup() {
|
||||
cleanup_buffers_cuda();
|
||||
g_q4_free_matrices();
|
||||
}
|
||||
|
||||
|
||||
// Prepare buffers for forward pass
|
||||
|
||||
void prepare_buffers
|
||||
(
|
||||
torch::Device device,
|
||||
torch::Tensor temp_state,
|
||||
torch::Tensor temp_dq
|
||||
)
|
||||
{
|
||||
int device_index = device.index();
|
||||
TORCH_CHECK_DEVICE_INDEX(device_index);
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device);
|
||||
void prepare_buffers(torch::Device device, torch::Tensor temp_state,
|
||||
torch::Tensor temp_dq) {
|
||||
int device_index = device.index();
|
||||
TORCH_CHECK_DEVICE_INDEX(device_index);
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device);
|
||||
|
||||
prepare_buffers_cuda
|
||||
(
|
||||
device_index,
|
||||
// buffer size used for sanity checks
|
||||
temp_state.numel(),
|
||||
(half*) temp_state.data_ptr(),
|
||||
(half*) temp_dq.data_ptr()
|
||||
);
|
||||
prepare_buffers_cuda(device_index,
|
||||
// buffer size used for sanity checks
|
||||
temp_state.numel(), (half*)temp_state.data_ptr(),
|
||||
(half*)temp_dq.data_ptr());
|
||||
}
|
||||
|
||||
|
||||
// Create Q4Matrix, return handle
|
||||
|
||||
uintptr_t make_q4
|
||||
(
|
||||
torch::Tensor qweight,
|
||||
torch::Tensor qzeros,
|
||||
torch::Tensor scales,
|
||||
torch::Tensor g_idx,
|
||||
int device
|
||||
)
|
||||
{
|
||||
TORCH_CHECK_DTYPE(qweight, kInt);
|
||||
TORCH_CHECK_DTYPE(qzeros, kInt);
|
||||
TORCH_CHECK_DTYPE(scales, kHalf);
|
||||
TORCH_CHECK_DTYPE_OPT(g_idx, kInt);
|
||||
TORCH_CHECK_SHAPES(qweight, 1, qzeros, 1, 8);
|
||||
TORCH_CHECK_SHAPES(scales, 1, qweight, 1, 1);
|
||||
TORCH_CHECK_SHAPES(qzeros, 0, scales, 0, 1);
|
||||
uintptr_t make_q4(torch::Tensor qweight, torch::Tensor qzeros,
|
||||
torch::Tensor scales, torch::Tensor g_idx, int device) {
|
||||
TORCH_CHECK_DTYPE(qweight, kInt);
|
||||
TORCH_CHECK_DTYPE(qzeros, kInt);
|
||||
TORCH_CHECK_DTYPE(scales, kHalf);
|
||||
TORCH_CHECK_DTYPE_OPT(g_idx, kInt);
|
||||
TORCH_CHECK_SHAPES(qweight, 1, qzeros, 1, 8);
|
||||
TORCH_CHECK_SHAPES(scales, 1, qweight, 1, 1);
|
||||
TORCH_CHECK_SHAPES(qzeros, 0, scales, 0, 1);
|
||||
|
||||
int width = qweight.size(1);
|
||||
int height = qweight.size(0) * 8;
|
||||
int groups = qzeros.size(0);
|
||||
int width = qweight.size(1);
|
||||
int height = qweight.size(0) * 8;
|
||||
int groups = qzeros.size(0);
|
||||
|
||||
Q4Matrix* m = new Q4Matrix
|
||||
(
|
||||
height,
|
||||
width,
|
||||
groups,
|
||||
Q4Matrix* m = new Q4Matrix(
|
||||
height, width, groups,
|
||||
|
||||
(uint32_t*) qweight.data_ptr(),
|
||||
(uint32_t*) qzeros.data_ptr(),
|
||||
(half*) scales.data_ptr(),
|
||||
g_idx.device().is_meta() ? NULL : (uint32_t*) g_idx.data_ptr(),
|
||||
(uint32_t*)qweight.data_ptr(), (uint32_t*)qzeros.data_ptr(),
|
||||
(half*)scales.data_ptr(),
|
||||
g_idx.device().is_meta() ? NULL : (uint32_t*)g_idx.data_ptr(),
|
||||
|
||||
device
|
||||
);
|
||||
device);
|
||||
|
||||
g_q4_keep_matrix(m);
|
||||
return reinterpret_cast<uintptr_t> (m);
|
||||
g_q4_keep_matrix(m);
|
||||
return reinterpret_cast<uintptr_t>(m);
|
||||
}
|
||||
|
||||
|
||||
// Matmul half @ quant -> half
|
||||
|
||||
void q4_matmul
|
||||
(
|
||||
torch::Tensor x,
|
||||
uintptr_t w,
|
||||
torch::Tensor out
|
||||
)
|
||||
{
|
||||
Q4Matrix* wm = reinterpret_cast<Q4Matrix*> (w);
|
||||
void q4_matmul(torch::Tensor x, uintptr_t w, torch::Tensor out) {
|
||||
Q4Matrix* wm = reinterpret_cast<Q4Matrix*>(w);
|
||||
|
||||
TORCH_CHECK_DTYPE(x, kHalf);
|
||||
TORCH_CHECK_DTYPE(out, kHalf);
|
||||
TORCH_CHECK_SHAPES(x, 0, out, 0, 1);
|
||||
TORCH_CHECK(wm->height == x.size(-1), "x and w have incompatible shapes")
|
||||
TORCH_CHECK_DTYPE(x, kHalf);
|
||||
TORCH_CHECK_DTYPE(out, kHalf);
|
||||
TORCH_CHECK_SHAPES(x, 0, out, 0, 1);
|
||||
TORCH_CHECK(wm->height == x.size(-1), "x and w have incompatible shapes")
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
||||
|
||||
int x_height = x.size(0);
|
||||
int x_height = x.size(0);
|
||||
|
||||
if (tuningParams.matmul_recons_thd == 0 || x_height < tuningParams.matmul_recons_thd)
|
||||
{
|
||||
q4_matmul_cuda
|
||||
(
|
||||
&tuningParams,
|
||||
(half*) x.data_ptr(),
|
||||
x_height,
|
||||
wm,
|
||||
(half*) out.data_ptr()
|
||||
);
|
||||
}
|
||||
else
|
||||
{
|
||||
q4_matmul_recons_cuda
|
||||
(
|
||||
&tuningParams,
|
||||
(half*) x.data_ptr(),
|
||||
x_height,
|
||||
wm,
|
||||
(half*) out.data_ptr(),
|
||||
at::cuda::getCurrentCUDABlasHandle()
|
||||
);
|
||||
}
|
||||
if (tuningParams.matmul_recons_thd == 0 ||
|
||||
x_height < tuningParams.matmul_recons_thd) {
|
||||
q4_matmul_cuda(&tuningParams, (half*)x.data_ptr(), x_height, wm,
|
||||
(half*)out.data_ptr());
|
||||
} else {
|
||||
q4_matmul_recons_cuda(&tuningParams, (half*)x.data_ptr(), x_height, wm,
|
||||
(half*)out.data_ptr(),
|
||||
at::cuda::getCurrentCUDABlasHandle());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Remap columns in half tensor

void column_remap(torch::Tensor x, torch::Tensor x_new, torch::Tensor x_map) {
  TORCH_CHECK_DTYPE(x, kHalf);
  TORCH_CHECK_DTYPE(x_new, kHalf);
  TORCH_CHECK_DTYPE(x_map, kInt);
  TORCH_CHECK_SHAPES(x_map, 0, x, 1, 1);

  int height = x.size(0);
  int width = x.size(1);

  TORCH_CHECK_BUFFER_SIZE(x_new, height * width);

  const at::cuda::OptionalCUDAGuard device_guard(device_of(x));

  column_remap_cuda((half*)x.data_ptr(), (half*)x_new.data_ptr(), height, width,
                    (uint32_t*)x_map.data_ptr());
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("set_tuning_params", &set_tuning_params, "set_tuning_params");
  m.def("prepare_buffers", &prepare_buffers, "prepare_buffers");
  m.def("cleanup", &cleanup, "cleanup");
  m.def("make_q4", &make_q4, "make_q4");
  m.def("q4_matmul", &q4_matmul, "q4_matmul");
}
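For orientation, here is a minimal sketch (not part of the diff) of how the bindings above might be driven from Python. It assumes the extension is importable as `exllama_kernels` (the name used by the GPTQ test further down in this change) and that the GPTQ-packed tensors already live on the current GPU; buffer setup via `prepare_buffers`/`set_tuning_params` is omitted here.

```python
# Hypothetical driver for the bindings above; qweight, qzeros, scales, g_idx
# are assumed to be GPTQ-packed tensors already on the current CUDA device.
import torch
import exllama_kernels


def q4_linear(x, qweight, qzeros, scales, g_idx):
    # make_q4 wraps the packed weight in a C++ Q4Matrix and returns an opaque handle.
    w_handle = exllama_kernels.make_q4(
        qweight, qzeros, scales, g_idx, torch.cuda.current_device()
    )
    # Output width equals qweight.size(1); input features equal qweight.size(0) * 8.
    out = torch.empty(
        (x.shape[0], qweight.shape[1]), dtype=torch.float16, device=x.device
    )
    # q4_matmul computes out = x @ dequant(w) in fp16.
    exllama_kernels.q4_matmul(x, w_handle, out)
    return out
```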
@ -184,7 +184,7 @@ __global__ void reconstruct_kernel
  int column = RECONS_THREADS_X * blockIdx.x + threadIdx.x;
  int row = (RECONS_THREADS_Y * blockIdx.y + threadIdx.y) * 8;
  if (column >= width) return;

  // Views

  MatrixView_q4_column w_(w, height, width);
@ -50,4 +50,4 @@ private:
void g_q4_keep_matrix(Q4Matrix* m);
void g_q4_free_matrices();

#endif
@ -238,5 +238,5 @@ if HAS_TRITON:
            num_warps=num_warps,
            num_stages=1,
        )

        return
@ -10,7 +10,6 @@ except ImportError:
    print("please install triton from https://github.com/openai/triton")

if HAS_TRITON:
    # adapted from https://github.com/ModelTC/lightllm/blob/5c559dd7981ed67679a08a1e09a88fb4c1550b3a/lightllm/common/triton_kernel/destindex_copy_kv.py
    @triton.jit
    def _fwd_copy_kv_cache_dest(
@ -13,6 +13,9 @@ except ImportError:
    print("please install triton from https://github.com/openai/triton")

try:
    from lightllm.models.bloom.triton_kernel.token_attention_nopad_att1 import (
        token_att_fwd as lightllm_bloom_token_att_fwd,
    )
    from lightllm.models.llama2.triton_kernel.token_attention_nopad_att1 import (
        token_att_fwd as lightllm_llama2_token_att_fwd,
    )
@ -22,11 +25,15 @@ try:
    from lightllm.models.llama2.triton_kernel.token_attention_nopad_softmax import (
        token_softmax_fwd as lightllm_llama2_token_softmax_fwd,
    )
    from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import (
        token_att_fwd as lightllm_llama_token_att_fwd,
    )
    from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import (
        token_att_fwd2 as lightllm_llama_token_att_fw2,
    )
    from lightllm.models.llama.triton_kernel.token_attention_nopad_softmax import (
        token_softmax_fwd as lightllm_llama_token_softmax_fwd,
    )

    HAS_TRITON_TOKEN_ATTENTION = True
except ImportError:
@ -44,8 +44,8 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -
    return unpickle


def check_for_nccl_backend(group):
    pg = group or c10d._get_default_group()
    # Gate PG wrapper check on Gloo availability.
    if c10d._GLOO_AVAILABLE:
@ -54,10 +54,8 @@ def check_for_nccl_backend(group):
        while isinstance(pg, c10d._ProcessGroupWrapper):
            pg = pg.wrapped_pg

    return c10d.is_nccl_available() and pg.name() == c10d.Backend.NCCL


def _broadcast_object_list(
    object_list: List[Any], src: int, group: ProcessGroup, device: Optional[Union[torch.device, str, int]] = None
7
colossalai/quantization/__init__.py
Normal file
@ -0,0 +1,7 @@
from .bnb import quantize_model
from .bnb_config import BnbQuantizationConfig

__all__ = [
    "BnbQuantizationConfig",
    "quantize_model",
]
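These two exports are the whole public surface of the new package. A minimal usage sketch (illustrative only, assuming a CUDA device and `bitsandbytes>=0.39.0` are available; the toy model is a placeholder):

```python
# Hypothetical toy model; any nn.Module with nn.Linear children works the same way.
import torch.nn as nn
from colossalai.quantization import BnbQuantizationConfig, quantize_model

model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024))
config = BnbQuantizationConfig(load_in_4bit=True, bnb_4bit_compute_dtype="fp16")
# Replaces eligible nn.Linear layers with bnb.nn.Linear4bit and moves the model to the GPU;
# modules in the skip list (by default the last module and any tied weights) stay in full precision.
model = quantize_model(model, config)
```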
321
colossalai/quantization/bnb.py
Normal file
@ -0,0 +1,321 @@
# adapted from Hugging Face accelerate/utils/bnb.py accelerate/utils/modeling.py
|
||||
|
||||
import logging
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .bnb_config import BnbQuantizationConfig
|
||||
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
|
||||
IS_4BIT_BNB_AVAILABLE = bnb.__version__ >= "0.39.0"
|
||||
IS_8BIT_BNB_AVAILABLE = bnb.__version__ >= "0.37.2"
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def quantize_model(
|
||||
model: torch.nn.Module,
|
||||
bnb_quantization_config: BnbQuantizationConfig,
|
||||
):
|
||||
"""
|
||||
This function will quantize the input loaded model with the associated config passed in `bnb_quantization_config`.
|
||||
We will quantize the model and put the model on the GPU.
|
||||
|
||||
Args:
|
||||
model (`torch.nn.Module`):
|
||||
Input model. The model already loaded
|
||||
bnb_quantization_config (`BnbQuantizationConfig`):
|
||||
The bitsandbytes quantization parameters
|
||||
|
||||
Returns:
|
||||
`torch.nn.Module`: The quantized model
|
||||
"""
|
||||
|
||||
load_in_4bit = bnb_quantization_config.load_in_4bit
|
||||
load_in_8bit = bnb_quantization_config.load_in_8bit
|
||||
|
||||
if load_in_8bit and not IS_8BIT_BNB_AVAILABLE:
|
||||
raise ImportError(
|
||||
"You have a version of `bitsandbytes` that is not compatible with 8bit quantization,"
|
||||
" make sure you have the latest version of `bitsandbytes` installed."
|
||||
)
|
||||
if load_in_4bit and not IS_4BIT_BNB_AVAILABLE:
|
||||
raise ValueError(
|
||||
"You have a version of `bitsandbytes` that is not compatible with 4bit quantization,"
|
||||
"make sure you have the latest version of `bitsandbytes` installed."
|
||||
)
|
||||
|
||||
# We keep some modules such as the lm_head in their original dtype for numerical stability reasons
|
||||
if bnb_quantization_config.skip_modules is None:
|
||||
bnb_quantization_config.skip_modules = get_keys_to_not_convert(model)
|
||||
|
||||
modules_to_not_convert = bnb_quantization_config.skip_modules
|
||||
|
||||
# We add the modules we want to keep in full precision
|
||||
if bnb_quantization_config.keep_in_fp32_modules is None:
|
||||
bnb_quantization_config.keep_in_fp32_modules = []
|
||||
keep_in_fp32_modules = bnb_quantization_config.keep_in_fp32_modules
|
||||
|
||||
# compatibility with peft
|
||||
model.is_loaded_in_4bit = load_in_4bit
|
||||
model.is_loaded_in_8bit = load_in_8bit
|
||||
|
||||
# assert model_device is cuda
|
||||
model_device = next(model.parameters()).device
|
||||
|
||||
model = replace_with_bnb_layers(model, bnb_quantization_config, modules_to_not_convert=modules_to_not_convert)
|
||||
|
||||
# convert param to the right dtype
|
||||
dtype = bnb_quantization_config.torch_dtype
|
||||
for name, param in model.state_dict().items():
|
||||
if any(module_to_keep_in_fp32 in name for module_to_keep_in_fp32 in keep_in_fp32_modules):
|
||||
param.to(torch.float32)
|
||||
if param.dtype != torch.float32:
|
||||
name = name.replace(".weight", "").replace(".bias", "")
|
||||
param = getattr(model, name, None)
|
||||
if param is not None:
|
||||
param.to(torch.float32)
|
||||
elif torch.is_floating_point(param):
|
||||
param.to(dtype)
|
||||
if model_device.type == "cuda":
|
||||
# move everything to cpu in the first place because we can't do quantization if the weights are already on cuda
|
||||
model.cuda(torch.cuda.current_device())
|
||||
torch.cuda.empty_cache()
|
||||
elif torch.cuda.is_available():
|
||||
model.to(torch.cuda.current_device())
|
||||
logger.info(
|
||||
f"The model device type is {model_device.type}. However, cuda is needed for quantization."
|
||||
"We move the model to cuda."
|
||||
)
|
||||
else:
|
||||
raise RuntimeError("No GPU found. A GPU is needed for quantization.")
|
||||
return model
|
||||
|
||||
|
||||
def replace_with_bnb_layers(model, bnb_quantization_config, modules_to_not_convert=None, current_key_name=None):
|
||||
"""
|
||||
A helper function to replace all `torch.nn.Linear` modules by `bnb.nn.Linear8bit` modules or by `bnb.nn.Linear4bit`
|
||||
modules from the `bitsandbytes`library. The function will be run recursively and replace `torch.nn.Linear` modules.
|
||||
|
||||
Parameters:
|
||||
model (`torch.nn.Module`):
|
||||
Input model or `torch.nn.Module` as the function is run recursively.
|
||||
modules_to_not_convert (`List[str]`):
|
||||
Names of the modules to not quantize convert. In practice we keep the `lm_head` in full precision for
|
||||
numerical stability reasons.
|
||||
current_key_name (`List[str]`, *optional*):
|
||||
An array to track the current key of the recursion. This is used to check whether the current key (part of
|
||||
it) is not in the list of modules to not convert.
|
||||
"""
|
||||
|
||||
if modules_to_not_convert is None:
|
||||
modules_to_not_convert = []
|
||||
|
||||
model, has_been_replaced = _replace_with_bnb_layers(
|
||||
model, bnb_quantization_config, modules_to_not_convert, current_key_name
|
||||
)
|
||||
if not has_been_replaced:
|
||||
logger.warning(
|
||||
"You are loading your model in 8bit or 4bit but no linear modules were found in your model."
|
||||
" this can happen for some architectures such as gpt2 that uses Conv1D instead of Linear layers."
|
||||
" Please double check your model architecture, or submit an issue on github if you think this is"
|
||||
" a bug."
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
def _replace_with_bnb_layers(
|
||||
model,
|
||||
bnb_quantization_config,
|
||||
modules_to_not_convert=None,
|
||||
current_key_name=None,
|
||||
):
|
||||
"""
|
||||
Private method that wraps the recursion for module replacement.
|
||||
|
||||
Returns the converted model and a boolean that indicates if the conversion has been successfull or not.
|
||||
"""
|
||||
# bitsandbytes will initialize CUDA on import, so it needs to be imported lazily
|
||||
|
||||
has_been_replaced = False
|
||||
for name, module in model.named_children():
|
||||
if current_key_name is None:
|
||||
current_key_name = []
|
||||
current_key_name.append(name)
|
||||
if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
|
||||
# Check if the current key is not in the `modules_to_not_convert`
|
||||
current_key_name_str = ".".join(current_key_name)
|
||||
proceed = True
|
||||
for key in modules_to_not_convert:
|
||||
if (
|
||||
(key in current_key_name_str) and (key + "." in current_key_name_str)
|
||||
) or key == current_key_name_str:
|
||||
proceed = False
|
||||
break
|
||||
if proceed:
|
||||
# Load bnb module with empty weight and replace ``nn.Linear` module
|
||||
if bnb_quantization_config.load_in_8bit:
|
||||
bnb_module = bnb.nn.Linear8bitLt(
|
||||
module.in_features,
|
||||
module.out_features,
|
||||
module.bias is not None,
|
||||
has_fp16_weights=False,
|
||||
threshold=bnb_quantization_config.llm_int8_threshold,
|
||||
)
|
||||
elif bnb_quantization_config.load_in_4bit:
|
||||
bnb_module = bnb.nn.Linear4bit(
|
||||
module.in_features,
|
||||
module.out_features,
|
||||
module.bias is not None,
|
||||
bnb_quantization_config.bnb_4bit_compute_dtype,
|
||||
compress_statistics=bnb_quantization_config.bnb_4bit_use_double_quant,
|
||||
quant_type=bnb_quantization_config.bnb_4bit_quant_type,
|
||||
)
|
||||
else:
|
||||
raise ValueError("load_in_8bit and load_in_4bit can't be both False")
|
||||
bnb_module.weight.data = module.weight.data
|
||||
bnb_module.weight.skip_zero_check = True
|
||||
if module.bias is not None:
|
||||
bnb_module.bias.data = module.bias.data
|
||||
bnb_module.bias.skip_zero_check = True
|
||||
bnb_module.requires_grad_(False)
|
||||
setattr(model, name, bnb_module)
|
||||
has_been_replaced = True
|
||||
if len(list(module.children())) > 0:
|
||||
_, _has_been_replaced = _replace_with_bnb_layers(
|
||||
module, bnb_quantization_config, modules_to_not_convert, current_key_name
|
||||
)
|
||||
has_been_replaced = has_been_replaced | _has_been_replaced
|
||||
# Remove the last key for recursion
|
||||
current_key_name.pop(-1)
|
||||
return model, has_been_replaced
|
||||
|
||||
|
||||
def get_keys_to_not_convert(model):
|
||||
r"""
|
||||
An utility function to get the key of the module to keep in full precision if any For example for CausalLM modules
|
||||
we may want to keep the lm_head in full precision for numerical stability reasons. For other architectures, we want
|
||||
to keep the tied weights of the model. The function will return a list of the keys of the modules to not convert in
|
||||
int8.
|
||||
|
||||
Parameters:
|
||||
model (`torch.nn.Module`):
|
||||
Input model
|
||||
"""
|
||||
# Create a copy of the model
|
||||
# with init_empty_weights():
|
||||
# tied_model = deepcopy(model) # this has 0 cost since it is done inside `init_empty_weights` context manager`
|
||||
tied_model = model
|
||||
|
||||
tied_params = find_tied_parameters(tied_model)
|
||||
# For compatibility with Accelerate < 0.18
|
||||
if isinstance(tied_params, dict):
|
||||
tied_keys = sum(list(tied_params.values()), []) + list(tied_params.keys())
|
||||
else:
|
||||
tied_keys = sum(tied_params, [])
|
||||
has_tied_params = len(tied_keys) > 0
|
||||
|
||||
# Check if it is a base model
|
||||
is_base_model = False
|
||||
if hasattr(model, "base_model_prefix"):
|
||||
is_base_model = not hasattr(model, model.base_model_prefix)
|
||||
|
||||
# Ignore this for base models (BertModel, GPT2Model, etc.)
|
||||
if (not has_tied_params) and is_base_model:
|
||||
return []
|
||||
|
||||
# otherwise they have an attached head
|
||||
list_modules = list(model.named_children())
|
||||
list_last_module = [list_modules[-1][0]]
|
||||
|
||||
# add last module together with tied weights
|
||||
intersection = set(list_last_module) - set(tied_keys)
|
||||
list_untouched = list(set(tied_keys)) + list(intersection)
|
||||
|
||||
# remove ".weight" from the keys
|
||||
names_to_remove = [".weight", ".bias"]
|
||||
filtered_module_names = []
|
||||
for name in list_untouched:
|
||||
for name_to_remove in names_to_remove:
|
||||
if name_to_remove in name:
|
||||
name = name.replace(name_to_remove, "")
|
||||
filtered_module_names.append(name)
|
||||
|
||||
return filtered_module_names
|
||||
|
||||
|
||||
def find_tied_parameters(model: nn.Module, **kwargs):
|
||||
"""
|
||||
Find the tied parameters in a given model.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
The signature accepts keyword arguments, but they are for the recursive part of this function and you should ignore
|
||||
them.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
model (`torch.nn.Module`): The model to inspect.
|
||||
|
||||
Returns:
|
||||
List[List[str]]: A list of lists of parameter names being all tied together.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
>>> from collections import OrderedDict
|
||||
>>> import torch.nn as nn
|
||||
|
||||
>>> model = nn.Sequential(OrderedDict([("linear1", nn.Linear(4, 4)), ("linear2", nn.Linear(4, 4))]))
|
||||
>>> model.linear2.weight = model.linear1.weight
|
||||
>>> find_tied_parameters(model)
|
||||
[['linear1.weight', 'linear2.weight']]
|
||||
```
|
||||
"""
|
||||
# Initialize result and named_parameters before recursing.
|
||||
named_parameters = kwargs.get("named_parameters", None)
|
||||
prefix = kwargs.get("prefix", "")
|
||||
result = kwargs.get("result", {})
|
||||
|
||||
if named_parameters is None:
|
||||
named_parameters = {n: p for n, p in model.named_parameters()}
|
||||
else:
|
||||
# A tied parameter will not be in the full `named_parameters` seen above but will be in the `named_parameters`
|
||||
# of the submodule it belongs to. So while recursing we track the names that are not in the initial
|
||||
# `named_parameters`.
|
||||
for name, parameter in model.named_parameters():
|
||||
full_name = name if prefix == "" else f"{prefix}.{name}"
|
||||
if full_name not in named_parameters:
|
||||
# When we find one, it has to be one of the existing parameters.
|
||||
for new_name, new_param in named_parameters.items():
|
||||
if new_param is parameter:
|
||||
if new_name not in result:
|
||||
result[new_name] = []
|
||||
result[new_name].append(full_name)
|
||||
|
||||
# Once we have treated direct parameters, we move to the child modules.
|
||||
for name, child in model.named_children():
|
||||
child_name = name if prefix == "" else f"{prefix}.{name}"
|
||||
find_tied_parameters(child, named_parameters=named_parameters, prefix=child_name, result=result)
|
||||
|
||||
return FindTiedParametersResult([sorted([weight] + list(set(tied))) for weight, tied in result.items()])
|
||||
|
||||
|
||||
class FindTiedParametersResult(list):
|
||||
"""
|
||||
This is a subclass of a list to handle backward compatibility for Transformers. Do not rely on the fact this is not
|
||||
a list or on the `values` method as in the future this will be removed.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def values(self):
|
||||
return sum([x[1:] for x in self], [])
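To make the tied-weight detection and the skip-module heuristic above concrete, here is a small illustration (a sketch only, using a toy module rather than a real transformer; the helpers are imported from the new module listed above):

```python
# Toy module with tied embedding / output weights; not part of the patch itself.
import torch.nn as nn
from colossalai.quantization.bnb import find_tied_parameters, get_keys_to_not_convert


class TinyLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(100, 16)
        self.body = nn.Linear(16, 16)
        self.lm_head = nn.Linear(16, 100, bias=False)
        self.lm_head.weight = self.embed.weight  # weight tying

    def forward(self, x):
        return self.lm_head(self.body(self.embed(x)))


model = TinyLM()
print(find_tied_parameters(model))     # [['embed.weight', 'lm_head.weight']]
# The skip list contains the tied modules plus the last child, with ".weight"/".bias"
# stripped; ordering (and possible repeats) depends on set iteration order.
print(get_keys_to_not_convert(model))  # e.g. ['embed', 'lm_head', 'lm_head']
```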
113
colossalai/quantization/bnb_config.py
Normal file
@ -0,0 +1,113 @@
# adapted from Hugging Face accelerate/utils/dataclasses.py
|
||||
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@dataclass
|
||||
class BnbQuantizationConfig:
|
||||
"""
|
||||
A plugin to enable BitsAndBytes 4bit and 8bit quantization
|
||||
"""
|
||||
|
||||
load_in_8bit: bool = field(default=False, metadata={"help": "enable 8bit quantization."})
|
||||
|
||||
llm_int8_threshold: float = field(
|
||||
default=6.0, metadata={"help": "value of the outliner threshold. only relevant when load_in_8bit=True"}
|
||||
)
|
||||
|
||||
load_in_4bit: bool = field(default=False, metadata={"help": "enable 4bit quantization."})
|
||||
|
||||
bnb_4bit_quant_type: str = field(
|
||||
default="fp4",
|
||||
metadata={
|
||||
"help": "set the quantization data type in the `bnb.nn.Linear4Bit` layers. Options are {'fp4','np4'}."
|
||||
},
|
||||
)
|
||||
|
||||
bnb_4bit_use_double_quant: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "enable nested quantization where the quantization constants from the first quantization are quantized again."
|
||||
},
|
||||
)
|
||||
|
||||
bnb_4bit_compute_dtype: str = field(
|
||||
default="fp16",
|
||||
metadata={
|
||||
"help": "This sets the computational type which might be different than the input time. For example, inputs might be "
|
||||
"fp32, but computation can be set to bf16 for speedups. Options are {'fp32','fp16','bf16'}."
|
||||
},
|
||||
)
|
||||
|
||||
torch_dtype: torch.dtype = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "this sets the dtype of the remaining non quantized layers. `bitsandbytes` library suggests to set the value"
|
||||
"to `torch.float16` for 8 bit model and use the same dtype as the compute dtype for 4 bit model "
|
||||
},
|
||||
)
|
||||
|
||||
skip_modules: List[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "an explicit list of the modules that we don't quantize. The dtype of these modules will be `torch_dtype`."
|
||||
},
|
||||
)
|
||||
|
||||
keep_in_fp32_modules: List[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "an explicit list of the modules that we don't quantize. We keep them in `torch.float32`."},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if isinstance(self.bnb_4bit_compute_dtype, str):
|
||||
if self.bnb_4bit_compute_dtype == "fp32":
|
||||
self.bnb_4bit_compute_dtype = torch.float32
|
||||
elif self.bnb_4bit_compute_dtype == "fp16":
|
||||
self.bnb_4bit_compute_dtype = torch.float16
|
||||
elif self.bnb_4bit_compute_dtype == "bf16":
|
||||
self.bnb_4bit_compute_dtype = torch.bfloat16
|
||||
else:
|
||||
raise ValueError(
|
||||
f"bnb_4bit_compute_dtype must be in ['fp32','fp16','bf16'] but found {self.bnb_4bit_compute_dtype}"
|
||||
)
|
||||
elif not isinstance(self.bnb_4bit_compute_dtype, torch.dtype):
|
||||
raise ValueError("bnb_4bit_compute_dtype must be a string or a torch.dtype")
|
||||
|
||||
if self.skip_modules is not None and not isinstance(self.skip_modules, list):
|
||||
raise ValueError("skip_modules must be a list of strings")
|
||||
|
||||
if self.keep_in_fp32_modules is not None and not isinstance(self.keep_in_fp32_modules, list):
|
||||
raise ValueError("keep_in_fp_32_modules must be a list of strings")
|
||||
|
||||
if self.load_in_4bit:
|
||||
self.target_dtype = "int4"
|
||||
|
||||
if self.load_in_8bit:
|
||||
self.target_dtype = torch.int8
|
||||
|
||||
if self.load_in_4bit and self.llm_int8_threshold != 6.0:
|
||||
warnings.warn("llm_int8_threshold can only be used for model loaded in 8bit")
|
||||
|
||||
if isinstance(self.torch_dtype, str):
|
||||
if self.torch_dtype == "fp32":
|
||||
self.torch_dtype = torch.float32
|
||||
elif self.torch_dtype == "fp16":
|
||||
self.torch_dtype = torch.float16
|
||||
elif self.torch_dtype == "bf16":
|
||||
self.torch_dtype = torch.bfloat16
|
||||
else:
|
||||
raise ValueError(f"torch_dtype must be in ['fp32','fp16','bf16'] but found {self.torch_dtype}")
|
||||
|
||||
if self.load_in_8bit and self.torch_dtype is None:
|
||||
self.torch_dtype = torch.float16
|
||||
|
||||
if self.load_in_4bit and self.torch_dtype is None:
|
||||
self.torch_dtype = self.bnb_4bit_compute_dtype
|
||||
|
||||
if not isinstance(self.torch_dtype, torch.dtype):
|
||||
raise ValueError("torch_dtype must be a torch.dtype")
|
@ -190,6 +190,7 @@ def calculate_global_norm_from_list(norm_list):
        total_norm += norm**2.0
    return math.sqrt(total_norm)


def sync_tensor(flat_tensor, tensor_list):
    """
    Synchronize the flattened tensor and unflattened tensor list. When
@ -187,9 +187,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
        for param_group in self.optim.param_groups:
            group_params = param_group["params"]
            for param in group_params:
                assert (
                    param.dtype == self._dtype
                ), f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`"
                if not hasattr(param, "skip_zero_check") or param.skip_zero_check is False:
                    assert (
                        param.dtype == self._dtype
                    ), f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`"

    def _create_master_param_current_rank(self, param_list):
        # split each param evenly by world size
@ -149,7 +149,7 @@ model, optimizer, _criterion, _, lr_scheduler = booster.boost(

## Training GPT-2 using hybrid parallelism

In the previous tutorial, we've explained how to inject various parallelism features into the model and its training components using the Booster and `HybridParallelPlugin`. Now we can start model training.
Define a training function. When pipeline parallelism is used, you need to call `booster.execute_pipeline` to schedule the stages of model training.
```python
def train_epoch(
@ -204,4 +204,4 @@ Training the gpt-2 model
for epoch in range(NUM_EPOCHS):
    train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator)
```
<!-- doc-test-command: torchrun --standalone --nproc_per_node=1 train_gpt_using_hybrid_parallelism.py -->
@ -201,4 +201,4 @@ def train_epoch(
for epoch in range(NUM_EPOCHS):
    train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator)
```
<!-- doc-test-command: torchrun --standalone --nproc_per_node=1 train_gpt_using_hybrid_parallelism.py -->
@ -220,7 +220,7 @@ model, optimizer, _criterion, train_dataloader, lr_scheduler = booster.boost(
)
```
## Training ViT with hybrid parallelism
Finally, we can train the model with the hybrid parallel strategy. We first define a training function that describes the training process. Note that when pipeline parallelism is used, you need to call `booster.execute_pipeline` to run the model's training; it invokes the `scheduler` to manage the model's forward and backward passes.
```python
def run_forward_backward(
    model: nn.Module,
@ -1,12 +1,10 @@
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
|
||||
import torch
|
||||
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
|
||||
from auto_gptq.nn_modules.qlinear import GeneralQuantLinear
|
||||
from transformers import AutoTokenizer, BloomForCausalLM, BloomTokenizerFast, LlamaForCausalLM, LlamaTokenizer
|
||||
from auto_gptq import AutoGPTQForCausalLM
|
||||
from transformers import BloomTokenizerFast
|
||||
|
||||
import colossalai
|
||||
from colossalai.inference.tensor_parallel.engine import TPInferEngine
|
||||
@ -14,7 +12,7 @@ from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.shardformer import ShardConfig
|
||||
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
|
||||
|
||||
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
|
||||
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
|
||||
|
||||
|
||||
def print_perf_stats(latency_set, config, bs, warmup=3):
|
||||
@ -28,7 +26,7 @@ def print_perf_stats(latency_set, config, bs, warmup=3):
|
||||
avg = sum(latency_set) / count
|
||||
num_layers = getattr(config, "num_layers", config.num_hidden_layers)
|
||||
num_parameters = num_layers * config.hidden_size * config.hidden_size * 12
|
||||
num_bytes = 2 # float16
|
||||
num_bytes = 2 # float16
|
||||
|
||||
print("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000))
|
||||
print("Avg BW: {0:8.2f} GB/s".format(1 / avg * num_parameters * num_bytes / 1e9))
|
||||
@ -37,7 +35,6 @@ def print_perf_stats(latency_set, config, bs, warmup=3):
|
||||
|
||||
|
||||
def bench_bloom(args):
|
||||
|
||||
pretrained_model_dir = args.path
|
||||
quantized_model_dir = args.quantized_path
|
||||
max_batch_size = args.batch_size
|
||||
@ -48,9 +45,9 @@ def bench_bloom(args):
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
# load quantized model to the first GPU
|
||||
model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir,
|
||||
device=torch.cuda.current_device(),
|
||||
inject_fused_attention=False)
|
||||
model = AutoGPTQForCausalLM.from_quantized(
|
||||
quantized_model_dir, device=torch.cuda.current_device(), inject_fused_attention=False
|
||||
)
|
||||
|
||||
model = model.half()
|
||||
|
||||
@ -60,22 +57,22 @@ def bench_bloom(args):
|
||||
generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False)
|
||||
|
||||
input_tokens = {
|
||||
"input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device='cuda'),
|
||||
"attention_mask": torch.ones((max_batch_size, max_input_len), device='cuda')
|
||||
"input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device="cuda"),
|
||||
"attention_mask": torch.ones((max_batch_size, max_input_len), device="cuda"),
|
||||
}
|
||||
|
||||
# init TPInferEngine and shard the original model
|
||||
# To benchmark torch original, comment out the line of optimizing model
|
||||
shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False,
|
||||
inference_only=True,
|
||||
inference_gptq=True)
|
||||
shard_config = ShardConfig(
|
||||
enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True, inference_gptq=True
|
||||
)
|
||||
infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
|
||||
|
||||
# prepare data for generation
|
||||
generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False)
|
||||
input_tokens = {
|
||||
"input_ids": torch.randint(10, 1000, (max_batch_size, max_input_len)),
|
||||
"attention_mask": torch.ones((max_batch_size, max_input_len))
|
||||
"attention_mask": torch.ones((max_batch_size, max_input_len)),
|
||||
}
|
||||
for t in input_tokens:
|
||||
if torch.is_tensor(input_tokens[t]):
|
||||
@ -99,7 +96,7 @@ def bench_bloom(args):
|
||||
|
||||
def check_bloom(rank, world_size, port, args):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||
bench_bloom(args)
|
||||
|
||||
|
||||
@ -111,12 +108,12 @@ def test_bloom(args):
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-p', '--path', type=str, help='Model path', required=True)
|
||||
parser.add_argument('-q', '--quantized_path', type=str, help='Model path', required=True)
|
||||
parser.add_argument('-tp', '--tp_size', type=int, default=1, help='Tensor parallel size')
|
||||
parser.add_argument('-b', '--batch_size', type=int, default=16, help='Maximum batch size')
|
||||
parser.add_argument('--input_len', type=int, default=1024, help='Maximum input length')
|
||||
parser.add_argument('--output_len', type=int, default=128, help='Maximum output length')
|
||||
parser.add_argument("-p", "--path", type=str, help="Model path", required=True)
|
||||
parser.add_argument("-q", "--quantized_path", type=str, help="Model path", required=True)
|
||||
parser.add_argument("-tp", "--tp_size", type=int, default=1, help="Tensor parallel size")
|
||||
parser.add_argument("-b", "--batch_size", type=int, default=16, help="Maximum batch size")
|
||||
parser.add_argument("--input_len", type=int, default=1024, help="Maximum input length")
|
||||
parser.add_argument("--output_len", type=int, default=128, help="Maximum output length")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
@ -14,4 +14,5 @@ einops
sentencepiece
google
protobuf
peft>=0.7.1
bitsandbytes>=0.39.0
@ -1,4 +1,4 @@
from typing import Callable, Dict, Iterator, List, Tuple, Union

import torch
import torch.distributed as dist
@ -51,7 +51,6 @@ def run_fn(stage, model_fn, data_gen_fn, output_transform_fn, lora_config=None)
|
||||
# raise e
|
||||
|
||||
|
||||
|
||||
@parameterize("stage", [2])
|
||||
def check_low_level_zero_plugin(stage: int, early_stop: bool = True):
|
||||
"""check low level zero plugin over model zoo
|
||||
@ -118,6 +117,7 @@ def check_low_level_zero_lora(stage, model_name, early_stop: bool = True):
|
||||
print(f"Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n")
|
||||
assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()])
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port, early_stop: bool = True):
|
||||
# init dist env
|
||||
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")
|
||||
|
@ -1,10 +1,11 @@
|
||||
from copy import deepcopy
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from peft import LoraConfig
|
||||
from torchvision.models import resnet18
|
||||
from utils import shared_tempdir
|
||||
from typing import Optional
|
||||
from peft import LoraConfig
|
||||
from copy import deepcopy
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster import Booster
|
||||
@ -131,12 +132,15 @@ def run_fn(stage, shard, offload, model_fn, data_gen_fn, output_transform_fn, lo
|
||||
# return repr(e)
|
||||
raise e
|
||||
|
||||
|
||||
@clear_cache_before_run()
|
||||
@parameterize("stage", [2])
|
||||
@parameterize("shard", [True, False])
|
||||
@parameterize("offload", [False, True])
|
||||
@parameterize("model_name", ["transformers_llama"])
|
||||
def check_low_level_zero_lora_checkpointIO(stage: int, shard: bool, offload: bool, model_name: str, early_stop: bool = True):
|
||||
def check_low_level_zero_lora_checkpointIO(
|
||||
stage: int, shard: bool, offload: bool, model_name: str, early_stop: bool = True
|
||||
):
|
||||
passed_models = []
|
||||
failed_info = {} # (model_name, error) pair
|
||||
|
||||
@ -166,6 +170,7 @@ def check_low_level_zero_lora_checkpointIO(stage: int, shard: bool, offload: boo
|
||||
print(f"Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n")
|
||||
assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()])
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config=(dict()), rank=rank, world_size=world_size, port=port, host="localhost")
|
||||
check_low_level_zero_checkpointIO()
|
||||
|
@ -1,16 +1,8 @@
|
||||
import math
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import transformers
|
||||
from packaging import version
|
||||
|
||||
try:
|
||||
import triton
|
||||
import triton.language as tl
|
||||
HAS_TRITON = True
|
||||
except ImportError:
|
||||
HAS_TRITON = False
|
||||
@ -22,6 +14,7 @@ try:
|
||||
from exllama_kernels import prepare_buffers, set_tuning_params
|
||||
|
||||
from colossalai.inference.quant.gptq import CaiQuantLinear
|
||||
|
||||
HAS_AUTO_GPTQ = True
|
||||
except:
|
||||
HAS_AUTO_GPTQ = False
|
||||
@ -32,13 +25,14 @@ import warnings
|
||||
HAS_GPTQ_CUDA = False
|
||||
try:
|
||||
from colossalai.kernel.op_builder.gptq import GPTQBuilder
|
||||
|
||||
gptq_cuda = GPTQBuilder().load()
|
||||
HAS_GPTQ_CUDA = True
|
||||
except ImportError:
|
||||
warnings.warn('CUDA gptq is not installed')
|
||||
warnings.warn("CUDA gptq is not installed")
|
||||
HAS_GPTQ_CUDA = False
|
||||
|
||||
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')
|
||||
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
|
||||
|
||||
max_inner_outer_dim = 1
|
||||
max_input_len = 1
|
||||
@ -64,9 +58,9 @@ def init_buffer(cai_linear, use_act_order=False):
|
||||
max_input_len = 4096
|
||||
# The temp_state buffer is required to reorder X in the act-order case.
|
||||
# The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
|
||||
gptq_temp_state_buffer = torch.zeros((max_input_len, max_inner_outer_dim),
|
||||
dtype=torch.float16,
|
||||
device=torch.cuda.current_device())
|
||||
gptq_temp_state_buffer = torch.zeros(
|
||||
(max_input_len, max_inner_outer_dim), dtype=torch.float16, device=torch.cuda.current_device()
|
||||
)
|
||||
gptq_temp_dq_buffer = torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=torch.cuda.current_device())
|
||||
|
||||
gptq_cuda.prepare_buffers(torch.device(torch.cuda.current_device()), gptq_temp_state_buffer, gptq_temp_dq_buffer)
|
||||
@ -77,10 +71,11 @@ def init_buffer(cai_linear, use_act_order=False):
|
||||
gptq_cuda.set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON or not HAS_AUTO_GPTQ,
|
||||
reason="triton requires cuda version to be higher than 11.4 or not install auto-gptq")
|
||||
@pytest.mark.skipif(
|
||||
not TRITON_CUDA_SUPPORT or not HAS_TRITON or not HAS_AUTO_GPTQ,
|
||||
reason="triton requires cuda version to be higher than 11.4 or not install auto-gptq",
|
||||
)
|
||||
def test_gptq_linear():
|
||||
|
||||
infeature = 1024
|
||||
outfeature = 1024
|
||||
group_size = 128
|
||||
@ -120,7 +115,7 @@ def test_gptq_linear():
|
||||
max_input_len = 2048
|
||||
buffers = {
|
||||
"temp_state": torch.zeros((max_input_len, max_inner_outer_dim), dtype=torch.float16, device=device),
|
||||
"temp_dq": torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=device)
|
||||
"temp_dq": torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=device),
|
||||
}
|
||||
|
||||
prepare_buffers(device, buffers["temp_state"], buffers["temp_dq"])
|
||||
@ -146,5 +141,4 @@ def test_gptq_linear():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
test_gptq_linear()
|
||||
|
@ -4,6 +4,7 @@ from packaging import version
|
||||
|
||||
try:
|
||||
from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd
|
||||
|
||||
HAS_TRITON = True
|
||||
except ImportError:
|
||||
HAS_TRITON = False
|
||||
|
106
tests/test_lora/test_lora.py
Normal file
@ -0,0 +1,106 @@
import copy
|
||||
import os
|
||||
from itertools import product
|
||||
|
||||
import torch
|
||||
from peft import LoraConfig
|
||||
from torch import distributed as dist
|
||||
from torch.optim import AdamW
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import LowLevelZeroPlugin, TorchDDPPlugin
|
||||
from colossalai.testing import check_state_dict_equal, clear_cache_before_run, rerun_if_address_is_in_use, spawn
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.test_checkpoint_io.utils import shared_tempdir
|
||||
|
||||
|
||||
@clear_cache_before_run()
|
||||
def check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type):
|
||||
model = model_fn()
|
||||
lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1)
|
||||
|
||||
test_plugins = [TorchDDPPlugin(), LowLevelZeroPlugin()]
|
||||
test_configs = [
|
||||
{
|
||||
"lora_config": lora_config,
|
||||
"quantize": False,
|
||||
},
|
||||
{
|
||||
"lora_config": lora_config,
|
||||
"quantize": True,
|
||||
},
|
||||
]
|
||||
for plugin, test_config in product(test_plugins, test_configs):
|
||||
# checkpoint loaded model
|
||||
model_save = model_fn()
|
||||
model_load = copy.deepcopy(model_save)
|
||||
|
||||
optimizer = AdamW(model.parameters(), lr=0.001)
|
||||
criterion = loss_fn
|
||||
|
||||
booster = Booster(plugin=plugin)
|
||||
model_save = booster.enable_lora(model_save, **test_config)
|
||||
model_save, optimizer, criterion, _, _ = booster.boost(model_save, optimizer, criterion)
|
||||
|
||||
with shared_tempdir() as tempdir:
|
||||
lora_ckpt_path = os.path.join(tempdir, "ckpt")
|
||||
booster.save_lora_as_pretrained(model_save, lora_ckpt_path)
|
||||
dist.barrier()
|
||||
|
||||
# The Lora checkpoint should be small in size
|
||||
checkpoint_size_mb = os.path.getsize(os.path.join(lora_ckpt_path, "adapter_model.bin")) / (1024 * 1024)
|
||||
assert checkpoint_size_mb < 1
|
||||
|
||||
model_load = booster.enable_lora(model_load, pretrained_dir=lora_ckpt_path, **test_config)
|
||||
model_load, _, _, _, _ = booster.boost(model_load)
|
||||
|
||||
check_state_dict_equal(model_save.state_dict(), model_load.state_dict())
|
||||
|
||||
# test fwd bwd correctness
|
||||
test_model = model_load
|
||||
model_copy = copy.deepcopy(model_load)
|
||||
|
||||
data = data_gen_fn()
|
||||
data = {
|
||||
k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()
|
||||
}
|
||||
|
||||
output = test_model(**data)
|
||||
output = output_transform_fn(output)
|
||||
loss = criterion(output)
|
||||
|
||||
booster.backward(loss, optimizer)
|
||||
optimizer.clip_grad_by_norm(1.0)
|
||||
optimizer.step()
|
||||
|
||||
for (n1, p1), (n2, p2) in zip(test_model.named_parameters(), model_copy.named_parameters()):
|
||||
if "lora_" in n1:
|
||||
# lora modules require gradients, thus updated
|
||||
assert p1.requires_grad
|
||||
assert not torch.allclose(p1.to(p2.device).to(p2.dtype), p2, atol=5e-3, rtol=5e-3)
|
||||
else:
|
||||
if not p1.requires_grad:
|
||||
torch.testing.assert_close(p1.to(p2.device).to(p2.dtype), p2, atol=5e-3, rtol=5e-3)
|
||||
|
||||
|
||||
def run_lora_test():
|
||||
sub_model_zoo = model_zoo.get_sub_registry("transformers_llama")
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
task_type = None
|
||||
if name == "transformers_llama_for_casual_lm":
|
||||
task_type = "CAUSAL_LM"
|
||||
if name == "transformers_llama_for_sequence_classification":
|
||||
task_type = "SEQ_CLS"
|
||||
check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
config = {}
|
||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||
run_lora_test()
|
||||
|
||||
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_torch_ddp_lora():
|
||||
spawn(run_dist, 2)
|
@ -1,108 +0,0 @@
|
||||
import copy
|
||||
import os
|
||||
|
||||
import torch
|
||||
from peft import LoraConfig
|
||||
from torch import distributed as dist
|
||||
from torch.optim import AdamW
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import TorchDDPPlugin
|
||||
from colossalai.testing import (
|
||||
assert_equal,
|
||||
assert_not_equal,
|
||||
check_state_dict_equal,
|
||||
clear_cache_before_run,
|
||||
rerun_if_address_is_in_use,
|
||||
spawn,
|
||||
)
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.test_checkpoint_io.utils import shared_tempdir
|
||||
|
||||
|
||||
@clear_cache_before_run()
|
||||
def check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type):
|
||||
model = model_fn()
|
||||
lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1)
|
||||
|
||||
plugin = TorchDDPPlugin()
|
||||
booster = Booster(plugin=plugin)
|
||||
|
||||
model = booster.enable_lora(model, lora_config=lora_config)
|
||||
model_copy = copy.deepcopy(model)
|
||||
|
||||
optimizer = AdamW(model.parameters(), lr=0.001)
|
||||
criterion = loss_fn
|
||||
|
||||
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
|
||||
|
||||
data = data_gen_fn()
|
||||
data = {k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()}
|
||||
|
||||
output = model(**data)
|
||||
output = output_transform_fn(output)
|
||||
loss = criterion(output)
|
||||
|
||||
booster.backward(loss, optimizer)
|
||||
optimizer.clip_grad_by_norm(1.0)
|
||||
optimizer.step()
|
||||
|
||||
for (n1, p1), (n2, p2) in zip(model.named_parameters(), model_copy.named_parameters()):
|
||||
if "lora_" in n1:
|
||||
# lora modules require gradients, thus updated
|
||||
assert p1.requires_grad
|
||||
assert_not_equal(p1.to(p2.device), p2)
|
||||
else:
|
||||
if not p1.requires_grad:
|
||||
assert_equal(p1.to(p2.device), p2)
|
||||
|
||||
|
||||
@clear_cache_before_run()
|
||||
def check_checkpoint(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type):
|
||||
plugin = TorchDDPPlugin()
|
||||
lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1)
|
||||
|
||||
model_save = model_fn()
|
||||
model_load = copy.deepcopy(model_save)
|
||||
|
||||
booster = Booster(plugin=plugin)
|
||||
model_save = booster.enable_lora(model_save, lora_config=lora_config)
|
||||
model_save, _, _, _, _ = booster.boost(model_save)
|
||||
|
||||
with shared_tempdir() as tempdir:
|
||||
lora_ckpt_path = os.path.join(tempdir, "ckpt")
|
||||
booster.save_lora_as_pretrained(model_save, lora_ckpt_path)
|
||||
dist.barrier()
|
||||
|
||||
# The Lora checkpoint should be small in size
|
||||
checkpoint_size_mb = os.path.getsize(os.path.join(lora_ckpt_path, "adapter_model.bin")) / (1024 * 1024)
|
||||
assert checkpoint_size_mb < 1
|
||||
|
||||
model_load = booster.enable_lora(model_load, pretrained_dir=lora_ckpt_path)
|
||||
model_load, _, _, _, _ = booster.boost(model_load)
|
||||
|
||||
check_state_dict_equal(model_save.state_dict(), model_load.state_dict())
|
||||
|
||||
|
||||
def run_lora_test():
|
||||
sub_model_zoo = model_zoo.get_sub_registry("transformers_llama")
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
task_type = None
|
||||
if name == "transformers_llama_for_casual_lm":
|
||||
task_type = "CAUSAL_LM"
|
||||
if name == "transformers_llama_for_sequence_classification":
|
||||
task_type = "SEQ_CLS"
|
||||
check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type)
|
||||
check_checkpoint(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
config = {}
|
||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||
run_lora_test()
|
||||
|
||||
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_torch_ddp_lora():
|
||||
spawn(run_dist, 2)
|