[Feature] qlora support (#5586)

* [feature] qlora support

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* qlora follow-up commit

* migrate quantization folder to colossalai/

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor fixes

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
linsj20 2024-04-17 15:03:31 +08:00 committed by GitHub
parent cabc1286ca
commit 52a2dded36
51 changed files with 1031 additions and 579 deletions

LICENSE

@@ -527,3 +527,20 @@ Copyright 2021- HPC-AI Technology Inc. All rights reserved.
 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 SOFTWARE.
+
+
+   ---------------- LICENSE FOR Hugging Face accelerate ----------------
+
+   Copyright 2021 The HuggingFace Team
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.


@@ -78,7 +78,9 @@ def main(args):
     elif args.strategy == "colossalai_gemini":
         strategy = GeminiStrategy(placement_policy="static", initial_scale=2**5)
     elif args.strategy == "colossalai_gemini_cpu":
-        strategy = GeminiStrategy(placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5)
+        strategy = GeminiStrategy(
+            placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5
+        )
     elif args.strategy == "colossalai_zero2":
         strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
     elif args.strategy == "colossalai_zero2_cpu":


@@ -30,4 +30,3 @@ class Actor(LoRAModule):
         """Returns model output."""
         output = self.model(input_ids, attention_mask=attention_mask, **model_kwargs)
         return output
-


@@ -75,7 +75,9 @@ def get_strategy_from_args(strategy: str):
     elif strategy == "colossalai_zero2":
         strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
     elif strategy == "colossalai_gemini_cpu":
-        strategy_ = GeminiStrategy(placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5)
+        strategy_ = GeminiStrategy(
+            placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5
+        )
     elif strategy == "colossalai_zero2_cpu":
         strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
     else:


@@ -101,16 +101,17 @@ class DDPStrategy(Strategy):
         model_path = os.path.join(path, "pytorch_model.bin")
         self.save_model(model, model_path, shard=shard)

         def _replace_keys(model_path: str, replace_fn: Callable):
             state_dict = torch.load(model_path, map_location="cpu")
             state_dict = {replace_fn(k): v for k, v in state_dict.items()}
             torch.save(state_dict, model_path)

         # FIXME: save_model would add "model." prefix to keys of pytorch_model.bin
         # HACK: rename keys of pytorch_model.bin
         if dist.get_rank() == 0:
             _replace_keys(model_path, lambda k: k.replace("model.", "", 1))

     def get_model_state_dict_shard(self, model: nn.Module, **config):
         # TODO: implement sharding on naive strategy
         model = self.unwrap_model(model)


@@ -24,7 +24,9 @@ def main(args):
     if args.strategy == "ddp":
         strategy = DDPStrategy()
     elif args.strategy == "colossalai_gemini":
-        strategy = GeminiStrategy(placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5)
+        strategy = GeminiStrategy(
+            placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5
+        )
     elif args.strategy == "colossalai_zero2":
         strategy = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
     else:


@@ -1,20 +1,20 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-

-import numpy as np
 import os
 import random
 from dataclasses import dataclass
-from typing import Dict, List, Union, Sequence, Optional, Iterator, Callable
+from typing import Callable, Dict, Iterator, List, Optional, Sequence, Union

+import numpy as np
 import torch
-from datasets import dataset_dict, load_from_disk
+import torch.nn.functional as F
 from datasets import Dataset as HFDataset
+from datasets import dataset_dict, load_from_disk
 from torch.distributed import ProcessGroup
 from torch.distributed.distributed_c10d import _get_default_group
-from torch.utils.data import ConcatDataset, Dataset, DataLoader, DistributedSampler
+from torch.utils.data import ConcatDataset, DataLoader, Dataset, DistributedSampler
 from transformers.tokenization_utils import PreTrainedTokenizer
-import torch.nn.functional as F

 DatasetType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
 PathType = Union[str, os.PathLike]


@@ -7,9 +7,9 @@ Splicing multiple pre-tokenized sequence data points
 import random
 import warnings
 from copy import deepcopy
-from datasets import dataset_dict
-from typing import Any, Callable, Dict, Iterable, List, Union, Tuple
+from typing import Any, Callable, Dict, Iterable, List, Tuple, Union

+from datasets import dataset_dict
 from torch.utils.data import ConcatDataset, Dataset, IterableDataset
 from transformers.models.llama.tokenization_llama import LlamaTokenizer
 from transformers.tokenization_utils import PreTrainedTokenizer
@@ -169,12 +169,7 @@ class ClosedToConstantLengthSplicedDataset(IterableDataset):
                     spliced_labels.extend(seq_labels)
         # For residual spliced data point at the end of the data set
         if self.infinite is False and more_data_points is False and len(spliced_input_ids) > 0:
-            examples.append(
-                {
-                    self.input_ids_field: spliced_input_ids,
-                    self.labels_field: spliced_labels
-                }
-            )
+            examples.append({self.input_ids_field: spliced_input_ids, self.labels_field: spliced_labels})
         if self.shuffle:
             random.shuffle(examples)
         for spliced_data_point in examples:


@@ -8,11 +8,10 @@ import argparse

 import numpy as np
 import torch
-from transformers import LlamaTokenizer, LlamaForCausalLM
+from transformers import LlamaForCausalLM, LlamaTokenizer

 from colossalai.logging import get_dist_logger

 logger = get_dist_logger()


@@ -6,12 +6,12 @@ Initialize new tokenizer for continual pre-training
 """

 import argparse
-import os
 import json
+import os
 from typing import List, Union

-from transformers.models.llama.tokenization_llama import LlamaTokenizer
 from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model
+from transformers.models.llama.tokenization_llama import LlamaTokenizer

 from colossalai.logging import get_dist_logger


@@ -10,8 +10,8 @@ import os
 from typing import Any, Dict, Tuple, Union

 import torch
-from torch.optim.optimizer import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler
+from torch.optim.optimizer import Optimizer

 from colossalai.booster import Booster
 from colossalai.cluster import DistCoordinator


@@ -11,14 +11,14 @@ import os
 import time
 from multiprocessing import cpu_count

+from colossal_llama2.dataset.spliced_and_tokenized_dataset import (
+    ClosedToConstantLengthSplicedDataset,
+    supervised_tokenize,
+)
 from datasets import dataset_dict, load_dataset
 from transformers.models.llama.tokenization_llama import LlamaTokenizer

 from colossalai.logging import get_dist_logger
-from colossal_llama2.dataset.spliced_and_tokenized_dataset import (
-    supervised_tokenize,
-    ClosedToConstantLengthSplicedDataset,
-)

 logger = get_dist_logger()
@@ -149,5 +149,5 @@ def main():
     spliced_dataset.save_to_disk(dataset_path=output_arrow_path, num_proc=min(len(spliced_dataset), cpu_count()))


-if __name__ == '__main__':
+if __name__ == "__main__":
     main()


@@ -12,4 +12,3 @@ flash-attn>=2.0.0,<=2.0.5
 tqdm
 sentencepiece==0.1.99
 protobuf<=3.20.0
-


@@ -4,42 +4,36 @@
 Continual Pre-training of LLaMA-2 developed by Colossal-AI Team
 """

-import json
 import argparse
+import json
 import os
 import resource
 from contextlib import nullcontext
-from tqdm import tqdm

 import torch
 import torch.distributed as dist
+from colossal_llama2.dataset.loader import (
+    DataCollatorForSupervisedDataset,
+    StatefulDistributedSampler,
+    load_tokenized_dataset,
+    setup_distributed_dataloader,
+)
+from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
+from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
+from colossal_llama2.utils.froze import freeze_non_embeds_parameters
 from torch.utils.tensorboard import SummaryWriter
-from transformers import LlamaTokenizer, LlamaForCausalLM, LlamaConfig
+from tqdm import tqdm
+from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer

 import colossalai
 from colossalai.booster import Booster
-from colossalai.booster.plugin import (
-    GeminiPlugin,
-    LowLevelZeroPlugin,
-    HybridParallelPlugin,
-)
+from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
 from colossalai.cluster import DistCoordinator
 from colossalai.lazy import LazyInitContext
 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.utils import get_current_device
-from colossal_llama2.dataset.loader import (
-    load_tokenized_dataset,
-    setup_distributed_dataloader,
-    DataCollatorForSupervisedDataset,
-    StatefulDistributedSampler,
-)
-from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
-from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
-from colossal_llama2.utils.froze import freeze_non_embeds_parameters


 def get_model_numel(model: torch.nn.Module) -> int:
     return sum(p.numel() for p in model.parameters())
@@ -372,9 +366,7 @@ def main() -> None:
         # Final save.
         coordinator.print_on_master("Start saving final model checkpoint")
         booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
-        coordinator.print_on_master(
-            f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}"
-        )
+        coordinator.print_on_master(f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}")

         coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")


@@ -19,6 +19,7 @@ except ImportError:
 import colossalai.interface.pretrained as pretrained_utils
 from colossalai.checkpoint_io import GeneralCheckpointIO
 from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.quantization import BnbQuantizationConfig

 from .accelerator import Accelerator
 from .mixed_precision import MixedPrecision, mixed_precision_factory
@@ -230,7 +231,12 @@ class Booster:
         return self.plugin.no_sync(model, optimizer)

     def enable_lora(
-        self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: "peft.LoraConfig" = None
+        self,
+        model: nn.Module,
+        pretrained_dir: Optional[str] = None,
+        lora_config: "peft.LoraConfig" = None,
+        bnb_quantization_config: Optional[BnbQuantizationConfig] = None,
+        quantize=False,
     ) -> nn.Module:
         """
         Wrap the passed in model with LoRA modules for training. If pretrained directory is provided, lora configs and weights are loaded from that directory.
@@ -259,7 +265,20 @@ class Booster:
             assert (
                 pretrained_dir is not None
             ), "Please provide pretrained directory path if not passing in lora configuration."
-        return self.plugin.enable_lora(model, pretrained_dir, lora_config)
+        if quantize is True:
+            if bnb_quantization_config is not None:
+                warnings.warn(
+                    "User defined BnbQuantizationConfig is not fully tested in ColossalAI. Use it at your own risk."
+                )
+            else:
+                bnb_quantization_config = BnbQuantizationConfig(
+                    load_in_4bit=True,
+                    bnb_4bit_compute_dtype=torch.bfloat16,
+                    bnb_4bit_use_double_quant=True,
+                    bnb_4bit_quant_type="nf4",
+                )
+        return self.plugin.enable_lora(model, pretrained_dir, lora_config, bnb_quantization_config)

     def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True) -> None:
         """Load model from checkpoint.


@@ -1,11 +1,11 @@
-import logging
-import warnings
 import enum
+import logging
 import os
+import warnings
 from functools import partial
 from pathlib import Path
 from types import MethodType
-from typing import Callable, Dict, Iterator, List, Optional, Tuple, Dict
+from typing import Callable, Dict, Iterator, List, Optional, Tuple

 import torch
 import torch.nn as nn
@@ -27,6 +27,7 @@ from colossalai.checkpoint_io.utils import (
     sharded_optimizer_loading_epilogue,
 )
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
+from colossalai.quantization import BnbQuantizationConfig, quantize_model
 from colossalai.utils import get_current_device
 from colossalai.zero import LowLevelZeroOptimizer
@@ -44,6 +45,7 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):

 SUPPORTED_PRECISION = ["fp16", "bf16", "fp32"]


 class OptimizerParamCheckState(enum.Enum):
     ORIGIN_PARAM_FINDED = 0
     ORIGIN_PARAM_NOT_FIND = -1
@@ -221,6 +223,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
             logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
             return
         from peft import PeftModel

         assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
         peft_model = model.unwrap()
         assert isinstance(
@@ -331,18 +334,25 @@ class LowLevelZeroPlugin(DPPluginBase):
     def supported_devices(self) -> List[str]:
         return ["cuda"]

     def support_lora(self) -> bool:
         return True

     def enable_lora(
-        self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
+        self,
+        model: nn.Module,
+        pretrained_dir: Optional[str] = None,
+        lora_config: Optional[Dict] = None,
+        bnb_quantization_config: Optional[BnbQuantizationConfig] = None,
     ) -> nn.Module:
         from peft import PeftModel, get_peft_model

         assert not isinstance(model, LowLevelZeroModel), "Lora should be enabled before boosting the model."
         self.lora_enabled = True
         warnings.warn("You have enabled LoRa training. Please check the hyperparameters such as lr")

+        if bnb_quantization_config is not None:
+            model = quantize_model(model, bnb_quantization_config)
+
         if pretrained_dir is None:
             peft_model = get_peft_model(model, lora_config)
         else:
@@ -352,7 +362,7 @@ class LowLevelZeroPlugin(DPPluginBase):
     def get_param_group_id(self, optimizer: Optimizer, origin_param: Parameter):
         origin_param_id = id(origin_param)
         for group_id, param_group in enumerate(optimizer.param_groups):
-            for p in param_group['params']:
+            for p in param_group["params"]:
                 if id(p) == origin_param_id:
                     return group_id
         return -1
@@ -362,7 +372,7 @@ class LowLevelZeroPlugin(DPPluginBase):
         lora_param_id = id(lora_param)
         target_group_id = None
         for group_id, param_group in enumerate(optimizer.param_groups):
-            for p in param_group['params']:
+            for p in param_group["params"]:
                 if id(p) == lora_param_id:
                     # check if the lora parameter exists.
                     return target_group_id, OptimizerParamCheckState.LORA_PARM_EXISTED
@@ -380,16 +390,22 @@ class LowLevelZeroPlugin(DPPluginBase):
             name2param[name] = param

         for name, param in name2param.items():
-            if 'lora_A' in name or 'lora_B' in name:
+            if "lora_A" in name or "lora_B" in name:
                 origin_key = name.replace("lora_A.", "")
                 origin_key = origin_key.replace("lora_B.", "")
                 origin_key = origin_key.replace(f"{model.active_adapter}", "base_layer")
                 origin_param = name2param[origin_key]
                 group_id, check_state = self.get_param_group_id(optimizer, origin_param, param)
                 if check_state == OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND:
-                    warnings.warn("Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups.")
-                elif check_state == OptimizerParamCheckState.ORIGIN_PARAM_FINDED and group_id is not None and group_id >= 0:
-                    optimizer.param_groups[group_id]['params'].append(param)
+                    warnings.warn(
+                        "Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups."
+                    )
+                elif (
+                    check_state == OptimizerParamCheckState.ORIGIN_PARAM_FINDED
+                    and group_id is not None
+                    and group_id >= 0
+                ):
+                    optimizer.param_groups[group_id]["params"].append(param)

     def configure(
         self,
@@ -401,11 +417,13 @@ class LowLevelZeroPlugin(DPPluginBase):
     ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
         if self.lora_enabled:
             from peft import PeftModel
-            assert isinstance(model, PeftModel), "The model should have been wrapped as a PeftModel when self.lora_enabled is True"
+
+            assert isinstance(
+                model, PeftModel
+            ), "The model should have been wrapped as a PeftModel when self.lora_enabled is True"
             if optimizer is not None:
                 self.add_lora_params_to_optimizer(model, optimizer)

         if not isinstance(model, ModelWrapper):
             model = LowLevelZeroModel(model, self.precision)
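For reference, the quantization step the plugins run when a `bnb_quantization_config` is passed can be sketched on its own. The model name below is a placeholder, and this mirrors what `enable_lora` does internally rather than documenting a separate public workflow:

```python
import torch
from transformers import AutoModelForCausalLM

from colossalai.quantization import BnbQuantizationConfig, quantize_model

# Same NF4 double-quantization settings as the default config built by Booster.enable_lora(quantize=True).
bnb_config = BnbQuantizationConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

model = AutoModelForCausalLM.from_pretrained("path/to/base-model")  # placeholder checkpoint
# The plugin swaps the base weights for 4-bit bitsandbytes modules before peft wraps them with LoRA adapters.
model = quantize_model(model, bnb_config)
```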


@@ -9,6 +9,7 @@ from torch.utils.data import DataLoader
 from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.quantization import BnbQuantizationConfig, quantize_model

 from .dp_plugin_base import DPPluginBase
@@ -237,10 +238,17 @@ class TorchDDPPlugin(DPPluginBase):
         return model.module.no_sync()

     def enable_lora(
-        self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
+        self,
+        model: nn.Module,
+        pretrained_dir: Optional[str] = None,
+        lora_config: Optional[Dict] = None,
+        bnb_quantization_config: Optional[BnbQuantizationConfig] = None,
     ) -> nn.Module:
         from peft import PeftModel, get_peft_model

+        if bnb_quantization_config is not None:
+            model = quantize_model(model, bnb_quantization_config)
+
         assert not isinstance(model, TorchDDPModel), "Lora should be enabled before boosting the model."
         if pretrained_dir is None:
             return get_peft_model(model, lora_config)


@@ -18,15 +18,15 @@ from .gptq_op import CaiGPTQLinearOp
HAS_GPTQ_CUDA = False
try:
    from colossalai.kernel.op_builder.gptq import GPTQBuilder

    gptq_cuda = GPTQBuilder().load()
    HAS_GPTQ_CUDA = True
except ImportError:
    warnings.warn("CUDA gptq is not installed")
    HAS_GPTQ_CUDA = False


class CaiQuantLinear(nn.Module):
    def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False):
        super().__init__()
        if bits not in [2, 4, 8]:
@@ -37,23 +37,28 @@ class CaiQuantLinear(nn.Module):
        self.maxq = 2**self.bits - 1
        self.groupsize = groupsize if groupsize != -1 else infeatures

        self.register_buffer("qweight", torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32))
        self.register_buffer(
            "qzeros",
            torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32),
        )
        self.register_buffer(
            "scales", torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16)
        )
        if row_split:
            self.register_buffer(
                "g_idx",
                torch.tensor(
                    [(i + (tp_rank * self.infeatures)) // self.groupsize for i in range(infeatures)], dtype=torch.int32
                ),
            )
        else:
            self.register_buffer(
                "g_idx", torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32)
            )

        if bias:
            self.register_buffer("bias", torch.zeros((outfeatures), dtype=torch.float16))
        else:
            self.bias = None
@@ -66,9 +71,11 @@ class CaiQuantLinear(nn.Module):
        self.row_split = row_split

    def pack(self, linear, scales, zeros, g_idx=None):
        g_idx = (
            g_idx.clone()
            if g_idx is not None
            else torch.tensor([i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32)
        )

        scales = scales.t().contiguous()
        zeros = zeros.t().contiguous()
@@ -79,7 +86,6 @@ class CaiQuantLinear(nn.Module):
        if linear.bias is not None:
            self.bias = linear.bias.clone().half()

        pbits = 32
        ptype = torch.int32
        unsign_type = np.uint32
@@ -88,9 +94,10 @@ class CaiQuantLinear(nn.Module):
        intweight = []
        for idx in range(self.infeatures):
            intweight.append(
                torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx[idx]]) / half_scales[g_idx[idx]]).to(ptype)[
                    :, None
                ]
            )
        intweight = torch.cat(intweight, dim=1)
        intweight = intweight.t().contiguous()
        intweight = intweight.numpy().astype(unsign_type)
@@ -144,13 +151,16 @@ class CaiQuantLinear(nn.Module):
            torch.tensor(
                [(i + (self.tp_rank * self.infeatures)) // self.groupsize for i in range(self.infeatures)],
                dtype=torch.int32,
                device=self.g_idx.device,
            ),
        ):
            self.g_idx = None
        elif torch.equal(
            self.g_idx,
            torch.tensor(
                [i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32, device=self.g_idx.device
            ),
        ):
            self.g_idx = None

        if self.g_idx is not None:
@@ -165,7 +175,6 @@ class CaiQuantLinear(nn.Module):
        outshape = x.shape[:-1] + (self.outfeatures,)

        if HAS_GPTQ_CUDA and self.bits == 4:
            if self.q4 is None:
                self.init_q4()
@@ -191,7 +200,6 @@ class CaiQuantLinear(nn.Module):
def split_column_copy(gptq_linear, cai_linear, tp_size=1, tp_rank=0, split_num=1):
    qweights = gptq_linear.qweight.split(gptq_linear.out_features // split_num, dim=-1)
    qzeros = gptq_linear.qzeros.split(gptq_linear.out_features // (32 // cai_linear.bits) // split_num, dim=-1)
    scales = gptq_linear.scales.split(gptq_linear.out_features // split_num, dim=-1)
@@ -203,24 +211,24 @@ def split_column_copy(gptq_linear, cai_linear, tp_size=1, tp_rank=0, split_num=1
    zero_split_block = cai_linear.outfeatures // (32 // cai_linear.bits) // split_num

    for i in range(split_num):
        cai_linear.qweight[:, i * cai_split_out_features : (i + 1) * cai_split_out_features] = qweights[i][
            :, tp_rank * cai_split_out_features : (tp_rank + 1) * cai_split_out_features
        ]
        cai_linear.qzeros[:, i * zero_split_block : (i + 1) * zero_split_block] = qzeros[i][
            :, tp_rank * zero_split_block : (tp_rank + 1) * zero_split_block
        ]
        cai_linear.scales[:, i * cai_split_out_features : (i + 1) * cai_split_out_features] = scales[i][
            :, tp_rank * cai_split_out_features : (tp_rank + 1) * cai_split_out_features
        ]
        if cai_linear.bias is not None:
            cai_linear.bias[i * cai_split_out_features : (i + 1) * cai_split_out_features] = bias[i][
                tp_rank * cai_split_out_features : (tp_rank + 1) * cai_split_out_features
            ]

    cai_linear.g_idx.copy_(g_idx)


def split_row_copy(gptq_linear, cai_linear, tp_rank=0, split_num=1):
    qweights = gptq_linear.qweight.split(gptq_linear.in_features // split_num, dim=0)
    qzeros = gptq_linear.qzeros.split(gptq_linear.in_features // split_num, dim=0)
    scales = gptq_linear.scales.split(gptq_linear.in_features // split_num, dim=0)
@@ -231,47 +239,40 @@ def split_row_copy(gptq_linear, cai_linear, tp_rank=0, split_num=1):
    idx_split_features = cai_linear.infeatures // split_num

    for i in range(split_num):
        cai_linear.qweight[i * cai_split_in_features : (i + 1) * cai_split_in_features, :] = qweights[i][
            tp_rank * cai_split_in_features : (tp_rank + 1) * cai_split_in_features, :
        ]
        cai_linear.qzeros[i * zero_split_block : (i + 1) * zero_split_block, :] = qzeros[i][
            tp_rank * zero_split_block : (tp_rank + 1) * zero_split_block, :
        ]
        cai_linear.scales[i * zero_split_block : (i + 1) * zero_split_block, :] = scales[i][
            tp_rank * zero_split_block : (tp_rank + 1) * zero_split_block, :
        ]
        cai_linear.g_idx[i * idx_split_features : (i + 1) * idx_split_features] = g_idxs[i][
            tp_rank * idx_split_features : (tp_rank + 1) * idx_split_features
        ]
    if cai_linear.bias is not None:
        cai_linear.bias.copy_(gptq_linear.bias)


class RowCaiQuantLinear(CaiQuantLinear, ParallelModule):
    def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False):
        super().__init__(
            bits, groupsize, infeatures, outfeatures, bias, tp_size=tp_size, tp_rank=tp_rank, row_split=row_split
        )
        self.process_group = None

    @staticmethod
    def from_native_module(
        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
    ) -> ParallelModule:
        LazyInitContext.materialize(module)
        # get the attributes
        in_features = module.in_features

        # ensure only one process group is passed
        if isinstance(process_group, (list, tuple)):
            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
            process_group = process_group[0]

        tp_size = dist.get_world_size(process_group)
@@ -282,15 +283,18 @@ class RowCaiQuantLinear(CaiQuantLinear, ParallelModule):
        if in_features % tp_size != 0:
            raise ValueError(
                f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!"
            )
        linear_1d = RowCaiQuantLinear(
            module.bits,
            module.group_size,
            module.in_features // tp_size,
            module.out_features,
            module.bias is not None,
            tp_size=tp_size,
            tp_rank=tp_rank,
            row_split=True,
        )
        linear_1d.process_group = process_group

        split_row_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
@@ -306,30 +310,23 @@ class RowCaiQuantLinear(CaiQuantLinear, ParallelModule):
class ColCaiQuantLinear(CaiQuantLinear, ParallelModule):
    def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False):
        super().__init__(
            bits, groupsize, infeatures, outfeatures, bias, tp_size=tp_size, tp_rank=tp_rank, row_split=row_split
        )
        self.process_group = None

    @staticmethod
    def from_native_module(
        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
    ) -> ParallelModule:
        LazyInitContext.materialize(module)
        # get the attributes
        in_features = module.in_features

        # ensure only one process group is passed
        if isinstance(process_group, (list, tuple)):
            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
            process_group = process_group[0]

        tp_size = dist.get_world_size(process_group)
@@ -340,14 +337,17 @@ class ColCaiQuantLinear(CaiQuantLinear, ParallelModule):
        if in_features % tp_size != 0:
            raise ValueError(
                f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!"
            )
        linear_1d = ColCaiQuantLinear(
            module.bits,
            module.group_size,
            module.in_features,
            module.out_features // tp_size,
            module.bias is not None,
            tp_size=tp_size,
            tp_rank=tp_rank,
        )
        linear_1d.process_group = process_group

        split_column_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)


@@ -5,6 +5,7 @@ import torch

 from .kvcache_manager import MemoryManager


 # adapted from: lightllm/server/router/model_infer/infer_batch.py
 @dataclass
 class BatchInferState:


@@ -19,8 +19,11 @@ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
 from ._utils import copy_kv_to_mem_cache

 try:
-    from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import context_attention_fwd as lightllm_llama2_context_attention_fwd
     from lightllm.models.chatglm2.triton_kernel.rotary_emb import rotary_emb_fwd as chatglm2_rotary_emb_fwd
+    from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import (
+        context_attention_fwd as lightllm_llama2_context_attention_fwd,
+    )

     HAS_LIGHTLLM_KERNEL = True
 except:
     print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm")


@@ -4,7 +4,6 @@ import torch
 from torch.nn import LayerNorm

 import colossalai.shardformer.layer as col_nn
-from colossalai.shardformer.modeling.bloom import build_bloom_alibi_tensor_fn
 from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
 from colossalai.shardformer.policies.bloom import BloomForCausalLMPolicy
@@ -40,33 +39,36 @@ class BloomModelInferPolicy(BloomForCausalLMPolicy):
        policy = super().module_policy()
        if self.shard_config.inference_gptq:
            from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear

            policy[BloomBlock] = ModulePolicyDescription(
                attribute_replacement={
                    "self_attention.hidden_size": self.model.config.hidden_size
                    // self.shard_config.tensor_parallel_size,
                    "self_attention.split_size": self.model.config.hidden_size
                    // self.shard_config.tensor_parallel_size,
                    "self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size,
                },
                sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix="self_attention.query_key_value",
                        target_module=ColCaiQuantLinear,
                        kwargs={"split_num": 3},
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attention.dense", target_module=RowCaiQuantLinear, kwargs={"split_num": 1}
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attention.attention_dropout",
                        target_module=col_nn.DropoutForParallelInput,
                    ),
                    SubModuleReplacementDescription(
                        suffix="mlp.dense_h_to_4h", target_module=ColCaiQuantLinear, kwargs={"split_num": 1}
                    ),
                    SubModuleReplacementDescription(
                        suffix="mlp.dense_4h_to_h", target_module=RowCaiQuantLinear, kwargs={"split_num": 1}
                    ),
                ],
            )
            # NOTE set inference mode to shard config
            self.shard_config._infer()


@@ -13,6 +13,7 @@ from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forw
 try:
     from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward as lightllm_rmsnorm_forward

     HAS_TRITON_RMSNORM = True
 except:
     print("you should install triton from https://github.com/openai/triton")
@@ -21,6 +22,7 @@ except:

 def get_triton_rmsnorm_forward():
     if HAS_TRITON_RMSNORM:

         def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
             return lightllm_rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon)


@ -1,27 +1,29 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama // Adapted from turboderp exllama: https://github.com/turboderp/exllama
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h> #include <c10/cuda/CUDAGuard.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include <cstdint> #include <cstdint>
#include <cstdio> #include <cstdio>
#include "util.cuh"
#include "tuning.h"
#include "cuda_buffers.cuh"
#include "q4_matrix.cuh"
#include "q4_matmul.cuh"
#include "column_remap.cuh" #include "column_remap.cuh"
#include "cuda_buffers.cuh"
#include "q4_matmul.cuh"
#include "q4_matrix.cuh"
#include "tuning.h"
#include "util.cuh"
// Check CUDA return code. We don't want to include Torch headers in the .cu files because parsing them adds almost a // Check CUDA return code. We don't want to include Torch headers in the .cu
// minute to the compile time on a 12900K. Also passing exceptions back to Python is super tricky, so in place of // files because parsing them adds almost a minute to the compile time on a
// exceptions, CUDA functions return with a cudaError_t which we can parse and dump to the console. // 12900K. Also passing exceptions back to Python is super tricky, so in place
// of exceptions, CUDA functions return with a cudaError_t which we can parse
// and dump to the console.
void check_cuda(cudaError_t ret) void check_cuda(cudaError_t ret) {
{ switch (ret) {
switch (ret)
{
case cudaSuccess: case cudaSuccess:
break; break;
@ -31,9 +33,9 @@ void check_cuda(cudaError_t ret)
break; break;
default: default:
printf(" **** CUDA error\n"); \ printf(" **** CUDA error\n");
printf(" **** %s\n", cudaGetErrorString(ret)); \ printf(" **** %s\n", cudaGetErrorString(ret));
TORCH_CHECK(false, "CUDA error"); \ TORCH_CHECK(false, "CUDA error");
break; break;
} }
} }
@ -42,12 +44,25 @@ void check_cuda(cudaError_t ret)
#define STRINGIFY_(__x) #__x #define STRINGIFY_(__x) #__x
#define STRINGIFY(__x) STRINGIFY_(__x) #define STRINGIFY(__x) STRINGIFY_(__x)
#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype) #define TORCH_CHECK_DTYPE(__x, __dtype) \
#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, \
#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes") #__x " is incorrect datatype, must be " #__dtype)
#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes") #define TORCH_CHECK_DTYPE_OPT(__x, __dtype) \
#define TORCH_CHECK_SHAPE_MOD(__x, __dim_x, __mod) TORCH_CHECK((__x).size(__dim_x) % __mod == 0, #__x ".shape[" STRINGIFY(__dim_x) "] must be a multiple of " STRINGIFY(__mod)) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, \
#define TORCH_CHECK_BUFFER_SIZE(__buffer, __minimum_size) TORCH_CHECK((__buffer).numel() >= __minimum_size, #__buffer " is too small") #__x " is incorrect datatype, must be " #__dtype)
#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) \
TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, \
#__x " and " #__y " have incompatible shapes")
#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) \
TORCH_CHECK((__x).device().is_meta() || \
(__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, \
#__x " and " #__y " have incompatible shapes")
#define TORCH_CHECK_SHAPE_MOD(__x, __dim_x, __mod) \
TORCH_CHECK((__x).size(__dim_x) % __mod == 0, \
#__x ".shape[" STRINGIFY( \
__dim_x) "] must be a multiple of " STRINGIFY(__mod))
#define TORCH_CHECK_BUFFER_SIZE(__buffer, __minimum_size) \
TORCH_CHECK((__buffer).numel() >= __minimum_size, #__buffer " is too small")
#define TORCH_CHECK_DEVICE_INDEX(__index) \ #define TORCH_CHECK_DEVICE_INDEX(__index) \
do { \ do { \
@ -66,75 +81,49 @@ do { \
TORCH_CHECK_SHAPES_OPT(__x_map, 0, __w, 0, 8); \ TORCH_CHECK_SHAPES_OPT(__x_map, 0, __w, 0, 8); \
} while (0) } while (0)
int get_groupsize(torch::Tensor w, torch::Tensor w_zeros) int get_groupsize(torch::Tensor w, torch::Tensor w_zeros) {
{
int groupsize = w.size(0) * 8 / w_zeros.size(0); int groupsize = w.size(0) * 8 / w_zeros.size(0);
TORCH_CHECK(groupsize * w_zeros.size(0) == w.size(0) * 8, "w.shape[-2] must be a multiple of zeros.shape[-2]") TORCH_CHECK(groupsize * w_zeros.size(0) == w.size(0) * 8,
"w.shape[-2] must be a multiple of zeros.shape[-2]")
return groupsize; return groupsize;
} }
// Tuning parameters // Tuning parameters
ExLlamaTuning tuningParams; ExLlamaTuning tuningParams;
void set_tuning_params void set_tuning_params(int matmul_recons_thd, bool matmul_fused_remap,
( bool matmul_no_half2) {
int matmul_recons_thd,
bool matmul_fused_remap,
bool matmul_no_half2
)
{
tuningParams.matmul_recons_thd = matmul_recons_thd; tuningParams.matmul_recons_thd = matmul_recons_thd;
tuningParams.matmul_fused_remap = matmul_fused_remap; tuningParams.matmul_fused_remap = matmul_fused_remap;
tuningParams.matmul_no_half2 = matmul_no_half2; tuningParams.matmul_no_half2 = matmul_no_half2;
} }
// Release all unmanaged objects allocated by the extension // Release all unmanaged objects allocated by the extension
void cleanup() void cleanup() {
{
cleanup_buffers_cuda(); cleanup_buffers_cuda();
g_q4_free_matrices(); g_q4_free_matrices();
} }
// Prepare buffers for forward pass // Prepare buffers for forward pass
void prepare_buffers void prepare_buffers(torch::Device device, torch::Tensor temp_state,
( torch::Tensor temp_dq) {
torch::Device device,
torch::Tensor temp_state,
torch::Tensor temp_dq
)
{
int device_index = device.index(); int device_index = device.index();
TORCH_CHECK_DEVICE_INDEX(device_index); TORCH_CHECK_DEVICE_INDEX(device_index);
const at::cuda::OptionalCUDAGuard device_guard(device); const at::cuda::OptionalCUDAGuard device_guard(device);
prepare_buffers_cuda prepare_buffers_cuda(device_index,
(
device_index,
// buffer size used for sanity checks // buffer size used for sanity checks
temp_state.numel(), temp_state.numel(), (half*)temp_state.data_ptr(),
      (half*)temp_dq.data_ptr());
}

// Create Q4Matrix, return handle
uintptr_t make_q4(torch::Tensor qweight, torch::Tensor qzeros,
                  torch::Tensor scales, torch::Tensor g_idx, int device) {
  TORCH_CHECK_DTYPE(qweight, kInt);
  TORCH_CHECK_DTYPE(qzeros, kInt);
  TORCH_CHECK_DTYPE(scales, kHalf);
@ -147,34 +136,22 @@ uintptr_t make_q4
  int height = qweight.size(0) * 8;
  int groups = qzeros.size(0);

  Q4Matrix* m = new Q4Matrix(
      height, width, groups,
      (uint32_t*)qweight.data_ptr(), (uint32_t*)qzeros.data_ptr(),
      (half*)scales.data_ptr(),
      g_idx.device().is_meta() ? NULL : (uint32_t*)g_idx.data_ptr(),
      device);

  g_q4_keep_matrix(m);
  return reinterpret_cast<uintptr_t>(m);
}

// Matmul half @ quant -> half
void q4_matmul(torch::Tensor x, uintptr_t w, torch::Tensor out) {
  Q4Matrix* wm = reinterpret_cast<Q4Matrix*>(w);

  TORCH_CHECK_DTYPE(x, kHalf);
@ -186,41 +163,20 @@ void q4_matmul
  int x_height = x.size(0);

  if (tuningParams.matmul_recons_thd == 0 ||
      x_height < tuningParams.matmul_recons_thd) {
    q4_matmul_cuda(&tuningParams, (half*)x.data_ptr(), x_height, wm,
                   (half*)out.data_ptr());
  } else {
    q4_matmul_recons_cuda(&tuningParams, (half*)x.data_ptr(), x_height, wm,
                          (half*)out.data_ptr(),
                          at::cuda::getCurrentCUDABlasHandle());
  }
}

// Remap columns in half tensor
void column_remap(torch::Tensor x, torch::Tensor x_new, torch::Tensor x_map) {
  TORCH_CHECK_DTYPE(x, kHalf);
  TORCH_CHECK_DTYPE(x_new, kHalf);
  TORCH_CHECK_DTYPE(x_map, kInt);
@ -233,19 +189,11 @@ void column_remap
  const at::cuda::OptionalCUDAGuard device_guard(device_of(x));

  column_remap_cuda((half*)x.data_ptr(), (half*)x_new.data_ptr(), height, width,
                    (uint32_t*)x_map.data_ptr());
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("set_tuning_params", &set_tuning_params, "set_tuning_params");
  m.def("prepare_buffers", &prepare_buffers, "prepare_buffers");
  m.def("cleanup", &cleanup, "cleanup");

View File

@ -10,7 +10,6 @@ except ImportError:
print("please install triton from https://github.com/openai/triton") print("please install triton from https://github.com/openai/triton")
if HAS_TRITON: if HAS_TRITON:
# adapted from https://github.com/ModelTC/lightllm/blob/5c559dd7981ed67679a08a1e09a88fb4c1550b3a/lightllm/common/triton_kernel/destindex_copy_kv.py # adapted from https://github.com/ModelTC/lightllm/blob/5c559dd7981ed67679a08a1e09a88fb4c1550b3a/lightllm/common/triton_kernel/destindex_copy_kv.py
@triton.jit @triton.jit
def _fwd_copy_kv_cache_dest( def _fwd_copy_kv_cache_dest(

View File

@ -13,6 +13,9 @@ except ImportError:
print("please install triton from https://github.com/openai/triton") print("please install triton from https://github.com/openai/triton")
try: try:
from lightllm.models.bloom.triton_kernel.token_attention_nopad_att1 import (
token_att_fwd as lightllm_bloom_token_att_fwd,
)
from lightllm.models.llama2.triton_kernel.token_attention_nopad_att1 import ( from lightllm.models.llama2.triton_kernel.token_attention_nopad_att1 import (
token_att_fwd as lightllm_llama2_token_att_fwd, token_att_fwd as lightllm_llama2_token_att_fwd,
) )
@ -22,11 +25,15 @@ try:
from lightllm.models.llama2.triton_kernel.token_attention_nopad_softmax import ( from lightllm.models.llama2.triton_kernel.token_attention_nopad_softmax import (
token_softmax_fwd as lightllm_llama2_token_softmax_fwd, token_softmax_fwd as lightllm_llama2_token_softmax_fwd,
) )
from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import (
from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import token_att_fwd2 as lightllm_llama_token_att_fw2 token_att_fwd as lightllm_llama_token_att_fwd,
from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd as lightllm_llama_token_att_fwd )
from lightllm.models.llama.triton_kernel.token_attention_nopad_softmax import token_softmax_fwd as lightllm_llama_token_softmax_fwd from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import (
from lightllm.models.bloom.triton_kernel.token_attention_nopad_att1 import token_att_fwd as lightllm_bloom_token_att_fwd token_att_fwd2 as lightllm_llama_token_att_fw2,
)
from lightllm.models.llama.triton_kernel.token_attention_nopad_softmax import (
token_softmax_fwd as lightllm_llama_token_softmax_fwd,
)
HAS_TRITON_TOKEN_ATTENTION = True HAS_TRITON_TOKEN_ATTENTION = True
except ImportError: except ImportError:

View File

@ -44,8 +44,8 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -
    return unpickle


def check_for_nccl_backend(group):
    pg = group or c10d._get_default_group()
    # Gate PG wrapper check on Gloo availability.
    if c10d._GLOO_AVAILABLE:
@ -54,10 +54,8 @@ def check_for_nccl_backend(group):
        while isinstance(pg, c10d._ProcessGroupWrapper):
            pg = pg.wrapped_pg

    return c10d.is_nccl_available() and pg.name() == c10d.Backend.NCCL


def _broadcast_object_list(
    object_list: List[Any], src: int, group: ProcessGroup, device: Optional[Union[torch.device, str, int]] = None

View File

@ -0,0 +1,7 @@
from .bnb import quantize_model
from .bnb_config import BnbQuantizationConfig
__all__ = [
"BnbQuantizationConfig",
"quantize_model",
]
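A minimal usage sketch of these two exports, which the Booster's `enable_lora(..., quantize=True)` path (exercised in the new LoRA test below) builds on. The import path follows this commit's `colossalai/quantization` layout and the checkpoint name is purely illustrative.

import torch
from transformers import AutoModelForCausalLM

from colossalai.quantization import BnbQuantizationConfig, quantize_model

# QLoRA-style 4-bit NF4 config with bf16 compute and nested quantization.
bnb_config = BnbQuantizationConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)

model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b")  # illustrative checkpoint
model = quantize_model(model, bnb_config)  # nn.Linear layers become bnb.nn.Linear4bit on the current GPU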

View File

@ -0,0 +1,321 @@
# adapted from Hugging Face accelerate/utils/bnb.py accelerate/utils/modeling.py
import logging
import torch
import torch.nn as nn
from .bnb_config import BnbQuantizationConfig
try:
import bitsandbytes as bnb
IS_4BIT_BNB_AVAILABLE = bnb.__version__ >= "0.39.0"
IS_8BIT_BNB_AVAILABLE = bnb.__version__ >= "0.37.2"
except ImportError:
pass
logger = logging.getLogger(__name__)
def quantize_model(
model: torch.nn.Module,
bnb_quantization_config: BnbQuantizationConfig,
):
"""
This function will quantize the input model with the settings from `bnb_quantization_config`,
then place the quantized model on the GPU.
Args:
model (`torch.nn.Module`):
Input model, already loaded.
bnb_quantization_config (`BnbQuantizationConfig`):
The bitsandbytes quantization parameters
Returns:
`torch.nn.Module`: The quantized model
"""
load_in_4bit = bnb_quantization_config.load_in_4bit
load_in_8bit = bnb_quantization_config.load_in_8bit
if load_in_8bit and not IS_8BIT_BNB_AVAILABLE:
raise ImportError(
"You have a version of `bitsandbytes` that is not compatible with 8bit quantization,"
" make sure you have the latest version of `bitsandbytes` installed."
)
if load_in_4bit and not IS_4BIT_BNB_AVAILABLE:
raise ValueError(
"You have a version of `bitsandbytes` that is not compatible with 4bit quantization,"
" make sure you have the latest version of `bitsandbytes` installed."
)
# We keep some modules such as the lm_head in their original dtype for numerical stability reasons
if bnb_quantization_config.skip_modules is None:
bnb_quantization_config.skip_modules = get_keys_to_not_convert(model)
modules_to_not_convert = bnb_quantization_config.skip_modules
# We add the modules we want to keep in full precision
if bnb_quantization_config.keep_in_fp32_modules is None:
bnb_quantization_config.keep_in_fp32_modules = []
keep_in_fp32_modules = bnb_quantization_config.keep_in_fp32_modules
# compatibility with peft
model.is_loaded_in_4bit = load_in_4bit
model.is_loaded_in_8bit = load_in_8bit
# assert model_device is cuda
model_device = next(model.parameters()).device
model = replace_with_bnb_layers(model, bnb_quantization_config, modules_to_not_convert=modules_to_not_convert)
# convert param to the right dtype
dtype = bnb_quantization_config.torch_dtype
for name, param in model.state_dict().items():
if any(module_to_keep_in_fp32 in name for module_to_keep_in_fp32 in keep_in_fp32_modules):
param.to(torch.float32)
if param.dtype != torch.float32:
name = name.replace(".weight", "").replace(".bias", "")
param = getattr(model, name, None)
if param is not None:
param.to(torch.float32)
elif torch.is_floating_point(param):
param.to(dtype)
if model_device.type == "cuda":
# move everything to cpu in the first place because we can't do quantization if the weights are already on cuda
model.cuda(torch.cuda.current_device())
torch.cuda.empty_cache()
elif torch.cuda.is_available():
model.to(torch.cuda.current_device())
logger.info(
f"The model device type is {model_device.type}. However, cuda is needed for quantization."
"We move the model to cuda."
)
else:
raise RuntimeError("No GPU found. A GPU is needed for quantization.")
return model
def replace_with_bnb_layers(model, bnb_quantization_config, modules_to_not_convert=None, current_key_name=None):
"""
A helper function to replace all `torch.nn.Linear` modules by `bnb.nn.Linear8bitLt` modules or by `bnb.nn.Linear4bit`
modules from the `bitsandbytes` library. The function runs recursively and replaces all `torch.nn.Linear` modules.
Parameters:
model (`torch.nn.Module`):
Input model or `torch.nn.Module` as the function is run recursively.
modules_to_not_convert (`List[str]`):
Names of the modules not to convert. In practice we keep the `lm_head` in full precision for
numerical stability reasons.
current_key_name (`List[str]`, *optional*):
An array to track the current key of the recursion. This is used to check whether the current key (or part
of it) is in the list of modules not to convert.
"""
if modules_to_not_convert is None:
modules_to_not_convert = []
model, has_been_replaced = _replace_with_bnb_layers(
model, bnb_quantization_config, modules_to_not_convert, current_key_name
)
if not has_been_replaced:
logger.warning(
"You are loading your model in 8bit or 4bit but no linear modules were found in your model."
" this can happen for some architectures such as gpt2 that uses Conv1D instead of Linear layers."
" Please double check your model architecture, or submit an issue on github if you think this is"
" a bug."
)
return model
def _replace_with_bnb_layers(
model,
bnb_quantization_config,
modules_to_not_convert=None,
current_key_name=None,
):
"""
Private method that wraps the recursion for module replacement.
Returns the converted model and a boolean that indicates whether the conversion has been successful or not.
"""
# bitsandbytes will initialize CUDA on import, so it needs to be imported lazily
has_been_replaced = False
for name, module in model.named_children():
if current_key_name is None:
current_key_name = []
current_key_name.append(name)
if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
# Check if the current key is not in the `modules_to_not_convert`
current_key_name_str = ".".join(current_key_name)
proceed = True
for key in modules_to_not_convert:
if (
(key in current_key_name_str) and (key + "." in current_key_name_str)
) or key == current_key_name_str:
proceed = False
break
if proceed:
# Load the bnb module with empty weights and replace the `nn.Linear` module
if bnb_quantization_config.load_in_8bit:
bnb_module = bnb.nn.Linear8bitLt(
module.in_features,
module.out_features,
module.bias is not None,
has_fp16_weights=False,
threshold=bnb_quantization_config.llm_int8_threshold,
)
elif bnb_quantization_config.load_in_4bit:
bnb_module = bnb.nn.Linear4bit(
module.in_features,
module.out_features,
module.bias is not None,
bnb_quantization_config.bnb_4bit_compute_dtype,
compress_statistics=bnb_quantization_config.bnb_4bit_use_double_quant,
quant_type=bnb_quantization_config.bnb_4bit_quant_type,
)
else:
raise ValueError("load_in_8bit and load_in_4bit can't be both False")
bnb_module.weight.data = module.weight.data
bnb_module.weight.skip_zero_check = True
if module.bias is not None:
bnb_module.bias.data = module.bias.data
bnb_module.bias.skip_zero_check = True
bnb_module.requires_grad_(False)
setattr(model, name, bnb_module)
has_been_replaced = True
if len(list(module.children())) > 0:
_, _has_been_replaced = _replace_with_bnb_layers(
module, bnb_quantization_config, modules_to_not_convert, current_key_name
)
has_been_replaced = has_been_replaced | _has_been_replaced
# Remove the last key for recursion
current_key_name.pop(-1)
return model, has_been_replaced
def get_keys_to_not_convert(model):
r"""
A utility function to get the keys of the modules to keep in full precision, if any. For example, for CausalLM
modules we may want to keep the lm_head in full precision for numerical stability reasons. For other architectures,
we want to keep the tied weights of the model. The function will return a list of the keys of the modules to not
convert in int8.
Parameters:
model (`torch.nn.Module`):
Input model
"""
# Create a copy of the model
# with init_empty_weights():
# tied_model = deepcopy(model) # this has 0 cost since it is done inside `init_empty_weights` context manager`
tied_model = model
tied_params = find_tied_parameters(tied_model)
# For compatibility with Accelerate < 0.18
if isinstance(tied_params, dict):
tied_keys = sum(list(tied_params.values()), []) + list(tied_params.keys())
else:
tied_keys = sum(tied_params, [])
has_tied_params = len(tied_keys) > 0
# Check if it is a base model
is_base_model = False
if hasattr(model, "base_model_prefix"):
is_base_model = not hasattr(model, model.base_model_prefix)
# Ignore this for base models (BertModel, GPT2Model, etc.)
if (not has_tied_params) and is_base_model:
return []
# otherwise they have an attached head
list_modules = list(model.named_children())
list_last_module = [list_modules[-1][0]]
# add last module together with tied weights
intersection = set(list_last_module) - set(tied_keys)
list_untouched = list(set(tied_keys)) + list(intersection)
# remove ".weight" from the keys
names_to_remove = [".weight", ".bias"]
filtered_module_names = []
for name in list_untouched:
for name_to_remove in names_to_remove:
if name_to_remove in name:
name = name.replace(name_to_remove, "")
filtered_module_names.append(name)
return filtered_module_names
def find_tied_parameters(model: nn.Module, **kwargs):
"""
Find the tied parameters in a given model.
<Tip warning={true}>
The signature accepts keyword arguments, but they are for the recursive part of this function and you should ignore
them.
</Tip>
Args:
model (`torch.nn.Module`): The model to inspect.
Returns:
List[List[str]]: A list of lists of parameter names being all tied together.
Example:
```py
>>> from collections import OrderedDict
>>> import torch.nn as nn
>>> model = nn.Sequential(OrderedDict([("linear1", nn.Linear(4, 4)), ("linear2", nn.Linear(4, 4))]))
>>> model.linear2.weight = model.linear1.weight
>>> find_tied_parameters(model)
[['linear1.weight', 'linear2.weight']]
```
"""
# Initialize result and named_parameters before recursing.
named_parameters = kwargs.get("named_parameters", None)
prefix = kwargs.get("prefix", "")
result = kwargs.get("result", {})
if named_parameters is None:
named_parameters = {n: p for n, p in model.named_parameters()}
else:
# A tied parameter will not be in the full `named_parameters` seen above but will be in the `named_parameters`
# of the submodule it belongs to. So while recursing we track the names that are not in the initial
# `named_parameters`.
for name, parameter in model.named_parameters():
full_name = name if prefix == "" else f"{prefix}.{name}"
if full_name not in named_parameters:
# When we find one, it has to be one of the existing parameters.
for new_name, new_param in named_parameters.items():
if new_param is parameter:
if new_name not in result:
result[new_name] = []
result[new_name].append(full_name)
# Once we have treated direct parameters, we move to the child modules.
for name, child in model.named_children():
child_name = name if prefix == "" else f"{prefix}.{name}"
find_tied_parameters(child, named_parameters=named_parameters, prefix=child_name, result=result)
return FindTiedParametersResult([sorted([weight] + list(set(tied))) for weight, tied in result.items()])
class FindTiedParametersResult(list):
"""
This is a subclass of list kept for backward compatibility with Transformers. Do not rely on it being anything more
than a plain list, or on the `values` method, as they will be removed in the future.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def values(self):
return sum([x[1:] for x in self], [])
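A small illustration of the tied-parameter helpers above, reusing the toy model from the `find_tied_parameters` docstring. The import path is assumed from this commit's layout, and the exact order of the returned names is not guaranteed.

from collections import OrderedDict

import torch.nn as nn

from colossalai.quantization.bnb import find_tied_parameters, get_keys_to_not_convert

model = nn.Sequential(OrderedDict([("linear1", nn.Linear(4, 4)), ("linear2", nn.Linear(4, 4))]))
model.linear2.weight = model.linear1.weight  # tie the two weights

print(find_tied_parameters(model))
# [['linear1.weight', 'linear2.weight']]

# With tied weights and no `base_model_prefix`, the tied modules and the last
# child module are kept out of quantization, so this returns the names of both
# linears (possibly with duplicates; order not guaranteed).
print(get_keys_to_not_convert(model))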

View File

@ -0,0 +1,113 @@
# adapted from Hugging Face accelerate/utils/dataclasses.py
import warnings
from dataclasses import dataclass, field
from typing import List
import torch
@dataclass
class BnbQuantizationConfig:
"""
A plugin to enable BitsAndBytes 4bit and 8bit quantization
"""
load_in_8bit: bool = field(default=False, metadata={"help": "enable 8bit quantization."})
llm_int8_threshold: float = field(
default=6.0, metadata={"help": "value of the outlier threshold. Only relevant when load_in_8bit=True"}
)
load_in_4bit: bool = field(default=False, metadata={"help": "enable 4bit quantization."})
bnb_4bit_quant_type: str = field(
default="fp4",
metadata={
"help": "set the quantization data type in the `bnb.nn.Linear4Bit` layers. Options are {'fp4','np4'}."
},
)
bnb_4bit_use_double_quant: bool = field(
default=False,
metadata={
"help": "enable nested quantization where the quantization constants from the first quantization are quantized again."
},
)
bnb_4bit_compute_dtype: str = field(
default="fp16",
metadata={
"help": "This sets the computational type which might be different than the input time. For example, inputs might be "
"fp32, but computation can be set to bf16 for speedups. Options are {'fp32','fp16','bf16'}."
},
)
torch_dtype: torch.dtype = field(
default=None,
metadata={
"help": "this sets the dtype of the remaining non quantized layers. `bitsandbytes` library suggests to set the value"
"to `torch.float16` for 8 bit model and use the same dtype as the compute dtype for 4 bit model "
},
)
skip_modules: List[str] = field(
default=None,
metadata={
"help": "an explicit list of the modules that we don't quantize. The dtype of these modules will be `torch_dtype`."
},
)
keep_in_fp32_modules: List[str] = field(
default=None,
metadata={"help": "an explicit list of the modules that we don't quantize. We keep them in `torch.float32`."},
)
def __post_init__(self):
if isinstance(self.bnb_4bit_compute_dtype, str):
if self.bnb_4bit_compute_dtype == "fp32":
self.bnb_4bit_compute_dtype = torch.float32
elif self.bnb_4bit_compute_dtype == "fp16":
self.bnb_4bit_compute_dtype = torch.float16
elif self.bnb_4bit_compute_dtype == "bf16":
self.bnb_4bit_compute_dtype = torch.bfloat16
else:
raise ValueError(
f"bnb_4bit_compute_dtype must be in ['fp32','fp16','bf16'] but found {self.bnb_4bit_compute_dtype}"
)
elif not isinstance(self.bnb_4bit_compute_dtype, torch.dtype):
raise ValueError("bnb_4bit_compute_dtype must be a string or a torch.dtype")
if self.skip_modules is not None and not isinstance(self.skip_modules, list):
raise ValueError("skip_modules must be a list of strings")
if self.keep_in_fp32_modules is not None and not isinstance(self.keep_in_fp32_modules, list):
raise ValueError("keep_in_fp_32_modules must be a list of strings")
if self.load_in_4bit:
self.target_dtype = "int4"
if self.load_in_8bit:
self.target_dtype = torch.int8
if self.load_in_4bit and self.llm_int8_threshold != 6.0:
warnings.warn("llm_int8_threshold can only be used for model loaded in 8bit")
if isinstance(self.torch_dtype, str):
if self.torch_dtype == "fp32":
self.torch_dtype = torch.float32
elif self.torch_dtype == "fp16":
self.torch_dtype = torch.float16
elif self.torch_dtype == "bf16":
self.torch_dtype = torch.bfloat16
else:
raise ValueError(f"torch_dtype must be in ['fp32','fp16','bf16'] but found {self.torch_dtype}")
if self.load_in_8bit and self.torch_dtype is None:
self.torch_dtype = torch.float16
if self.load_in_4bit and self.torch_dtype is None:
self.torch_dtype = self.bnb_4bit_compute_dtype
if not isinstance(self.torch_dtype, torch.dtype):
raise ValueError("torch_dtype must be a torch.dtype")

View File

@ -190,6 +190,7 @@ def calculate_global_norm_from_list(norm_list):
        total_norm += norm**2.0
    return math.sqrt(total_norm)


def sync_tensor(flat_tensor, tensor_list):
    """
    Synchronize the flattened tensor and unflattened tensor list. When

View File

@ -187,6 +187,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
        for param_group in self.optim.param_groups:
            group_params = param_group["params"]
            for param in group_params:
                if not hasattr(param, "skip_zero_check") or param.skip_zero_check is False:
                    assert (
                        param.dtype == self._dtype
                    ), f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`"

View File

@ -1,12 +1,10 @@
import argparse
import os
import time

import torch
from auto_gptq import AutoGPTQForCausalLM
from transformers import BloomTokenizerFast

import colossalai
from colossalai.inference.tensor_parallel.engine import TPInferEngine
@ -14,7 +12,7 @@ from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn

os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"


def print_perf_stats(latency_set, config, bs, warmup=3):
@ -37,7 +35,6 @@ def print_perf_stats(latency_set, config, bs, warmup=3):

def bench_bloom(args):
    pretrained_model_dir = args.path
    quantized_model_dir = args.quantized_path
    max_batch_size = args.batch_size
@ -48,9 +45,9 @@ def bench_bloom(args):
    tokenizer.pad_token = tokenizer.eos_token

    # load quantized model to the first GPU
    model = AutoGPTQForCausalLM.from_quantized(
        quantized_model_dir, device=torch.cuda.current_device(), inject_fused_attention=False
    )

    model = model.half()
@ -60,22 +57,22 @@ def bench_bloom(args):
    generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False)
    input_tokens = {
        "input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device="cuda"),
        "attention_mask": torch.ones((max_batch_size, max_input_len), device="cuda"),
    }

    # init TPInferEngine and shard the original model
    # To benchmark torch original, comment out the line of optimizing model
    shard_config = ShardConfig(
        enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True, inference_gptq=True
    )
    infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)

    # prepare data for generation
    generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False)
    input_tokens = {
        "input_ids": torch.randint(10, 1000, (max_batch_size, max_input_len)),
        "attention_mask": torch.ones((max_batch_size, max_input_len)),
    }
    for t in input_tokens:
        if torch.is_tensor(input_tokens[t]):
@ -99,7 +96,7 @@ def bench_bloom(args):

def check_bloom(rank, world_size, port, args):
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    bench_bloom(args)

@ -111,12 +108,12 @@ def test_bloom(args):

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-p", "--path", type=str, help="Model path", required=True)
    parser.add_argument("-q", "--quantized_path", type=str, help="Model path", required=True)
    parser.add_argument("-tp", "--tp_size", type=int, default=1, help="Tensor parallel size")
    parser.add_argument("-b", "--batch_size", type=int, default=16, help="Maximum batch size")
    parser.add_argument("--input_len", type=int, default=1024, help="Maximum input length")
    parser.add_argument("--output_len", type=int, default=128, help="Maximum output length")

    args = parser.parse_args()

View File

@ -15,3 +15,4 @@ sentencepiece
google
protobuf
peft>=0.7.1
bitsandbytes>=0.39.0

View File

@ -1,4 +1,4 @@
from typing import Callable, Dict, Iterator, List, Tuple, Union

import torch
import torch.distributed as dist

View File

@ -51,7 +51,6 @@ def run_fn(stage, model_fn, data_gen_fn, output_transform_fn, lora_config=None)
        # raise e


@parameterize("stage", [2])
def check_low_level_zero_plugin(stage: int, early_stop: bool = True):
    """check low level zero plugin over model zoo
@ -118,6 +117,7 @@ def check_low_level_zero_lora(stage, model_name, early_stop: bool = True):
        print(f"Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n")
    assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()])


def run_dist(rank, world_size, port, early_stop: bool = True):
    # init dist env
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")

View File

@ -1,10 +1,11 @@
from copy import deepcopy
from typing import Optional

import torch
import torch.distributed as dist
from peft import LoraConfig
from torchvision.models import resnet18
from utils import shared_tempdir

import colossalai
from colossalai.booster import Booster
@ -131,12 +132,15 @@ def run_fn(stage, shard, offload, model_fn, data_gen_fn, output_transform_fn, lo
            # return repr(e)
            raise e


@clear_cache_before_run()
@parameterize("stage", [2])
@parameterize("shard", [True, False])
@parameterize("offload", [False, True])
@parameterize("model_name", ["transformers_llama"])
def check_low_level_zero_lora_checkpointIO(
    stage: int, shard: bool, offload: bool, model_name: str, early_stop: bool = True
):
    passed_models = []
    failed_info = {}  # (model_name, error) pair
@ -166,6 +170,7 @@ def check_low_level_zero_lora_checkpointIO(stage: int, shard: bool, offload: boo
        print(f"Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n")
    assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()])


def run_dist(rank, world_size, port):
    colossalai.launch(config=(dict()), rank=rank, world_size=world_size, port=port, host="localhost")
    check_low_level_zero_checkpointIO()

View File

@ -1,16 +1,8 @@
import pytest
import torch
from packaging import version

try:
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
@ -22,6 +14,7 @@ try:
    from exllama_kernels import prepare_buffers, set_tuning_params

    from colossalai.inference.quant.gptq import CaiQuantLinear

    HAS_AUTO_GPTQ = True
except:
    HAS_AUTO_GPTQ = False
@ -32,13 +25,14 @@ import warnings

HAS_GPTQ_CUDA = False
try:
    from colossalai.kernel.op_builder.gptq import GPTQBuilder

    gptq_cuda = GPTQBuilder().load()
    HAS_GPTQ_CUDA = True
except ImportError:
    warnings.warn("CUDA gptq is not installed")
    HAS_GPTQ_CUDA = False

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")

max_inner_outer_dim = 1
max_input_len = 1
@ -64,9 +58,9 @@ def init_buffer(cai_linear, use_act_order=False):
    max_input_len = 4096
    # The temp_state buffer is required to reorder X in the act-order case.
    # The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
    gptq_temp_state_buffer = torch.zeros(
        (max_input_len, max_inner_outer_dim), dtype=torch.float16, device=torch.cuda.current_device()
    )
    gptq_temp_dq_buffer = torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=torch.cuda.current_device())

    gptq_cuda.prepare_buffers(torch.device(torch.cuda.current_device()), gptq_temp_state_buffer, gptq_temp_dq_buffer)
@ -77,10 +71,11 @@ def init_buffer(cai_linear, use_act_order=False):
    gptq_cuda.set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)


@pytest.mark.skipif(
    not TRITON_CUDA_SUPPORT or not HAS_TRITON or not HAS_AUTO_GPTQ,
    reason="triton requires cuda version to be higher than 11.4 or not install auto-gptq",
)
def test_gptq_linear():
    infeature = 1024
    outfeature = 1024
    group_size = 128
@ -120,7 +115,7 @@ def test_gptq_linear():
    max_input_len = 2048
    buffers = {
        "temp_state": torch.zeros((max_input_len, max_inner_outer_dim), dtype=torch.float16, device=device),
        "temp_dq": torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=device),
    }

    prepare_buffers(device, buffers["temp_state"], buffers["temp_dq"])
@ -146,5 +141,4 @@ def test_gptq_linear():

if __name__ == "__main__":
    test_gptq_linear()

View File

@ -4,6 +4,7 @@ from packaging import version
try:
    from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False

View File

@ -0,0 +1,106 @@
import copy
import os
from itertools import product
import torch
from peft import LoraConfig
from torch import distributed as dist
from torch.optim import AdamW
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.testing import check_state_dict_equal, clear_cache_before_run, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_checkpoint_io.utils import shared_tempdir
@clear_cache_before_run()
def check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type):
model = model_fn()
lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1)
test_plugins = [TorchDDPPlugin(), LowLevelZeroPlugin()]
test_configs = [
{
"lora_config": lora_config,
"quantize": False,
},
{
"lora_config": lora_config,
"quantize": True,
},
]
for plugin, test_config in product(test_plugins, test_configs):
# checkpoint loaded model
model_save = model_fn()
model_load = copy.deepcopy(model_save)
optimizer = AdamW(model_save.parameters(), lr=0.001)
criterion = loss_fn
booster = Booster(plugin=plugin)
model_save = booster.enable_lora(model_save, **test_config)
model_save, optimizer, criterion, _, _ = booster.boost(model_save, optimizer, criterion)
with shared_tempdir() as tempdir:
lora_ckpt_path = os.path.join(tempdir, "ckpt")
booster.save_lora_as_pretrained(model_save, lora_ckpt_path)
dist.barrier()
# The Lora checkpoint should be small in size
checkpoint_size_mb = os.path.getsize(os.path.join(lora_ckpt_path, "adapter_model.bin")) / (1024 * 1024)
assert checkpoint_size_mb < 1
model_load = booster.enable_lora(model_load, pretrained_dir=lora_ckpt_path, **test_config)
model_load, _, _, _, _ = booster.boost(model_load)
check_state_dict_equal(model_save.state_dict(), model_load.state_dict())
# test fwd bwd correctness
test_model = model_load
model_copy = copy.deepcopy(model_load)
data = data_gen_fn()
data = {
k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()
}
output = test_model(**data)
output = output_transform_fn(output)
loss = criterion(output)
booster.backward(loss, optimizer)
optimizer.clip_grad_by_norm(1.0)
optimizer.step()
for (n1, p1), (n2, p2) in zip(test_model.named_parameters(), model_copy.named_parameters()):
if "lora_" in n1:
# lora modules require gradients, thus updated
assert p1.requires_grad
assert not torch.allclose(p1.to(p2.device).to(p2.dtype), p2, atol=5e-3, rtol=5e-3)
else:
if not p1.requires_grad:
torch.testing.assert_close(p1.to(p2.device).to(p2.dtype), p2, atol=5e-3, rtol=5e-3)
def run_lora_test():
sub_model_zoo = model_zoo.get_sub_registry("transformers_llama")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
task_type = None
if name == "transformers_llama_for_casual_lm":
task_type = "CAUSAL_LM"
if name == "transformers_llama_for_sequence_classification":
task_type = "SEQ_CLS"
check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type)
def run_dist(rank, world_size, port):
config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_lora_test()
@rerun_if_address_is_in_use()
def test_torch_ddp_lora():
spawn(run_dist, 2)

View File

@ -1,108 +0,0 @@
import copy
import os
import torch
from peft import LoraConfig
from torch import distributed as dist
from torch.optim import AdamW
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin
from colossalai.testing import (
assert_equal,
assert_not_equal,
check_state_dict_equal,
clear_cache_before_run,
rerun_if_address_is_in_use,
spawn,
)
from tests.kit.model_zoo import model_zoo
from tests.test_checkpoint_io.utils import shared_tempdir
@clear_cache_before_run()
def check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type):
model = model_fn()
lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1)
plugin = TorchDDPPlugin()
booster = Booster(plugin=plugin)
model = booster.enable_lora(model, lora_config=lora_config)
model_copy = copy.deepcopy(model)
optimizer = AdamW(model.parameters(), lr=0.001)
criterion = loss_fn
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
data = data_gen_fn()
data = {k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()}
output = model(**data)
output = output_transform_fn(output)
loss = criterion(output)
booster.backward(loss, optimizer)
optimizer.clip_grad_by_norm(1.0)
optimizer.step()
for (n1, p1), (n2, p2) in zip(model.named_parameters(), model_copy.named_parameters()):
if "lora_" in n1:
# lora modules require gradients, thus updated
assert p1.requires_grad
assert_not_equal(p1.to(p2.device), p2)
else:
if not p1.requires_grad:
assert_equal(p1.to(p2.device), p2)
@clear_cache_before_run()
def check_checkpoint(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type):
plugin = TorchDDPPlugin()
lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1)
model_save = model_fn()
model_load = copy.deepcopy(model_save)
booster = Booster(plugin=plugin)
model_save = booster.enable_lora(model_save, lora_config=lora_config)
model_save, _, _, _, _ = booster.boost(model_save)
with shared_tempdir() as tempdir:
lora_ckpt_path = os.path.join(tempdir, "ckpt")
booster.save_lora_as_pretrained(model_save, lora_ckpt_path)
dist.barrier()
# The Lora checkpoint should be small in size
checkpoint_size_mb = os.path.getsize(os.path.join(lora_ckpt_path, "adapter_model.bin")) / (1024 * 1024)
assert checkpoint_size_mb < 1
model_load = booster.enable_lora(model_load, pretrained_dir=lora_ckpt_path)
model_load, _, _, _, _ = booster.boost(model_load)
check_state_dict_equal(model_save.state_dict(), model_load.state_dict())
def run_lora_test():
sub_model_zoo = model_zoo.get_sub_registry("transformers_llama")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
task_type = None
if name == "transformers_llama_for_casual_lm":
task_type = "CAUSAL_LM"
if name == "transformers_llama_for_sequence_classification":
task_type = "SEQ_CLS"
check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type)
check_checkpoint(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type)
def run_dist(rank, world_size, port):
config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_lora_test()
@rerun_if_address_is_in_use()
def test_torch_ddp_lora():
spawn(run_dist, 2)