Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-29 08:47:17 +00:00)
[Feature] qlora support (#5586)

* [feature] qlora support
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* qlora follow commit
* migrate quantization folder to colossalai/
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* minor fixes

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

This commit is contained in: parent cabc1286ca, commit 52a2dded36
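For orientation, the sketch below shows how the qlora path added by this commit is meant to be driven from user code. It is a minimal, hedged example: the checkpoint path, LoRA hyperparameters, and optimizer settings are placeholders, peft and bitsandbytes are assumed to be installed, and distributed initialization is assumed to have happened already; only enable_lora(..., quantize=True) and the default NF4 configuration come from the diff itself.

# Minimal QLoRA sketch against the API added in this commit (paths and hyperparameters are placeholders).
import torch
from peft import LoraConfig
from transformers import LlamaForCausalLM

from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.nn.optimizer import HybridAdam

# Assumes colossalai / torch.distributed has already been launched.
model = LlamaForCausalLM.from_pretrained("path/to/llama")  # placeholder checkpoint
lora_config = LoraConfig(task_type="CAUSAL_LM", r=8, lora_alpha=32, lora_dropout=0.1)

booster = Booster(plugin=LowLevelZeroPlugin(stage=2, precision="bf16"))
# quantize=True makes the booster fall back to the default NF4 BnbQuantizationConfig added in this commit.
model = booster.enable_lora(model, lora_config=lora_config, quantize=True)

optimizer = HybridAdam(model.parameters(), lr=1e-4)
model, optimizer, *_ = booster.boost(model, optimizer)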
LICENSE (17 additions)

@@ -527,3 +527,20 @@ Copyright 2021- HPC-AI Technology Inc. All rights reserved.
 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 SOFTWARE.
+
+
+   ---------------- LICENSE FOR Hugging Face accelerate ----------------
+
+Copyright 2021 The HuggingFace Team
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
@@ -78,7 +78,9 @@ def main(args):
     elif args.strategy == "colossalai_gemini":
         strategy = GeminiStrategy(placement_policy="static", initial_scale=2**5)
     elif args.strategy == "colossalai_gemini_cpu":
-        strategy = GeminiStrategy(placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5)
+        strategy = GeminiStrategy(
+            placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5
+        )
     elif args.strategy == "colossalai_zero2":
         strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
     elif args.strategy == "colossalai_zero2_cpu":
@@ -30,4 +30,3 @@ class Actor(LoRAModule):
         """Returns model output."""
         output = self.model(input_ids, attention_mask=attention_mask, **model_kwargs)
         return output
-
@@ -75,7 +75,9 @@ def get_strategy_from_args(strategy: str):
    elif strategy == "colossalai_zero2":
        strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
    elif strategy == "colossalai_gemini_cpu":
-        strategy_ = GeminiStrategy(placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5)
+        strategy_ = GeminiStrategy(
+            placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5
+        )
    elif strategy == "colossalai_zero2_cpu":
        strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
    else:
@@ -101,16 +101,17 @@ class DDPStrategy(Strategy):

         model_path = os.path.join(path, "pytorch_model.bin")
         self.save_model(model, model_path, shard=shard)

         def _replace_keys(model_path: str, replace_fn: Callable):
             state_dict = torch.load(model_path, map_location="cpu")
             state_dict = {replace_fn(k): v for k, v in state_dict.items()}
             torch.save(state_dict, model_path)

         # FIXME: save_model would add "model." prefix to keys of pytorch_model.bin
         # HACK: rename keys of pytorch_model.bin
         if dist.get_rank() == 0:
             _replace_keys(model_path, lambda k: k.replace("model.", "", 1))
+

     def get_model_state_dict_shard(self, model: nn.Module, **config):
         # TODO: implement sharding on naive strategy
         model = self.unwrap_model(model)
@@ -24,7 +24,9 @@ def main(args):
     if args.strategy == "ddp":
         strategy = DDPStrategy()
     elif args.strategy == "colossalai_gemini":
-        strategy = GeminiStrategy(placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5)
+        strategy = GeminiStrategy(
+            placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5
+        )
     elif args.strategy == "colossalai_zero2":
         strategy = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
     else:
@@ -1,20 +1,20 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-

-import numpy as np
 import os
 import random
 from dataclasses import dataclass
-from typing import Dict, List, Union, Sequence, Optional, Iterator, Callable
+from typing import Callable, Dict, Iterator, List, Optional, Sequence, Union

+import numpy as np
 import torch
-from datasets import dataset_dict, load_from_disk
+import torch.nn.functional as F
 from datasets import Dataset as HFDataset
+from datasets import dataset_dict, load_from_disk
 from torch.distributed import ProcessGroup
 from torch.distributed.distributed_c10d import _get_default_group
-from torch.utils.data import ConcatDataset, Dataset, DataLoader, DistributedSampler
+from torch.utils.data import ConcatDataset, DataLoader, Dataset, DistributedSampler
 from transformers.tokenization_utils import PreTrainedTokenizer
-import torch.nn.functional as F

 DatasetType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
 PathType = Union[str, os.PathLike]
@@ -7,9 +7,9 @@ Splicing multiple pre-tokenized sequence data points
 import random
 import warnings
 from copy import deepcopy
-from datasets import dataset_dict
-from typing import Any, Callable, Dict, Iterable, List, Union, Tuple
+from typing import Any, Callable, Dict, Iterable, List, Tuple, Union

+from datasets import dataset_dict
 from torch.utils.data import ConcatDataset, Dataset, IterableDataset
 from transformers.models.llama.tokenization_llama import LlamaTokenizer
 from transformers.tokenization_utils import PreTrainedTokenizer
@@ -169,12 +169,7 @@ class ClosedToConstantLengthSplicedDataset(IterableDataset):
                 spliced_labels.extend(seq_labels)
         # For residual spliced data point at the end of the data set
         if self.infinite is False and more_data_points is False and len(spliced_input_ids) > 0:
-            examples.append(
-                {
-                    self.input_ids_field: spliced_input_ids,
-                    self.labels_field: spliced_labels
-                }
-            )
+            examples.append({self.input_ids_field: spliced_input_ids, self.labels_field: spliced_labels})
         if self.shuffle:
             random.shuffle(examples)
         for spliced_data_point in examples:
@@ -8,11 +8,10 @@ import argparse

 import numpy as np
 import torch
-from transformers import LlamaTokenizer, LlamaForCausalLM
+from transformers import LlamaForCausalLM, LlamaTokenizer

 from colossalai.logging import get_dist_logger

-
 logger = get_dist_logger()

@@ -6,12 +6,12 @@ Initialize new tokenizer for continual pre-training
 """

 import argparse
-import os
 import json
+import os
 from typing import List, Union

-from transformers.models.llama.tokenization_llama import LlamaTokenizer
 from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model
+from transformers.models.llama.tokenization_llama import LlamaTokenizer

 from colossalai.logging import get_dist_logger

@@ -10,8 +10,8 @@ import os
 from typing import Any, Dict, Tuple, Union

 import torch
-from torch.optim.optimizer import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler
+from torch.optim.optimizer import Optimizer

 from colossalai.booster import Booster
 from colossalai.cluster import DistCoordinator
@@ -11,14 +11,14 @@ import os
 import time
 from multiprocessing import cpu_count

+from colossal_llama2.dataset.spliced_and_tokenized_dataset import (
+    ClosedToConstantLengthSplicedDataset,
+    supervised_tokenize,
+)
 from datasets import dataset_dict, load_dataset
 from transformers.models.llama.tokenization_llama import LlamaTokenizer

 from colossalai.logging import get_dist_logger
-from colossal_llama2.dataset.spliced_and_tokenized_dataset import (
-    supervised_tokenize,
-    ClosedToConstantLengthSplicedDataset,
-)

 logger = get_dist_logger()

@@ -149,5 +149,5 @@ def main():
     spliced_dataset.save_to_disk(dataset_path=output_arrow_path, num_proc=min(len(spliced_dataset), cpu_count()))


-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
@@ -12,4 +12,3 @@ flash-attn>=2.0.0,<=2.0.5
 tqdm
 sentencepiece==0.1.99
 protobuf<=3.20.0
-
@@ -4,42 +4,36 @@
 Continual Pre-training of LLaMA-2 developed by Colossal-AI Team
 """

-import json
 import argparse
+import json
 import os
 import resource
 from contextlib import nullcontext
-from tqdm import tqdm

 import torch
 import torch.distributed as dist
+from colossal_llama2.dataset.loader import (
+    DataCollatorForSupervisedDataset,
+    StatefulDistributedSampler,
+    load_tokenized_dataset,
+    setup_distributed_dataloader,
+)
+from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
+from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
+from colossal_llama2.utils.froze import freeze_non_embeds_parameters
 from torch.utils.tensorboard import SummaryWriter
-from transformers import LlamaTokenizer, LlamaForCausalLM, LlamaConfig
+from tqdm import tqdm
+from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer

 import colossalai
 from colossalai.booster import Booster
-from colossalai.booster.plugin import (
-    GeminiPlugin,
-    LowLevelZeroPlugin,
-    HybridParallelPlugin,
-)
+from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
 from colossalai.cluster import DistCoordinator
 from colossalai.lazy import LazyInitContext
 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.utils import get_current_device

-from colossal_llama2.dataset.loader import (
-    load_tokenized_dataset,
-    setup_distributed_dataloader,
-    DataCollatorForSupervisedDataset,
-    StatefulDistributedSampler,
-)
-
-from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
-from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
-from colossal_llama2.utils.froze import freeze_non_embeds_parameters
-

 def get_model_numel(model: torch.nn.Module) -> int:
     return sum(p.numel() for p in model.parameters())
@@ -372,9 +366,7 @@ def main() -> None:
     # Final save.
     coordinator.print_on_master("Start saving final model checkpoint")
     booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
-    coordinator.print_on_master(
-        f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}"
-    )
+    coordinator.print_on_master(f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}")

     coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")

@@ -19,6 +19,7 @@ except ImportError:
 import colossalai.interface.pretrained as pretrained_utils
 from colossalai.checkpoint_io import GeneralCheckpointIO
 from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.quantization import BnbQuantizationConfig

 from .accelerator import Accelerator
 from .mixed_precision import MixedPrecision, mixed_precision_factory
@@ -230,7 +231,12 @@ class Booster:
         return self.plugin.no_sync(model, optimizer)

     def enable_lora(
-        self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: "peft.LoraConfig" = None
+        self,
+        model: nn.Module,
+        pretrained_dir: Optional[str] = None,
+        lora_config: "peft.LoraConfig" = None,
+        bnb_quantization_config: Optional[BnbQuantizationConfig] = None,
+        quantize=False,
     ) -> nn.Module:
         """
         Wrap the passed in model with LoRA modules for training. If pretrained directory is provided, lora configs and weights are loaded from that directory.
@@ -259,7 +265,20 @@ class Booster:
         assert (
             pretrained_dir is not None
         ), "Please provide pretrained directory path if not passing in lora configuration."
-        return self.plugin.enable_lora(model, pretrained_dir, lora_config)
+        if quantize is True:
+            if bnb_quantization_config is not None:
+                warnings.warn(
+                    "User defined BnbQuantizationConfig is not fully tested in ColossalAI. Use it at your own risk."
+                )
+            else:
+                bnb_quantization_config = BnbQuantizationConfig(
+                    load_in_4bit=True,
+                    bnb_4bit_compute_dtype=torch.bfloat16,
+                    bnb_4bit_use_double_quant=True,
+                    bnb_4bit_quant_type="nf4",
+                )
+
+        return self.plugin.enable_lora(model, pretrained_dir, lora_config, bnb_quantization_config)

     def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True) -> None:
         """Load model from checkpoint.
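A caller can also pass its own BnbQuantizationConfig instead of relying on the NF4 defaults constructed above. The snippet below is a hedged continuation of the sketch under the commit header (it reuses booster, model, and lora_config from there); the specific quantization options are illustrative, and, as the warning in the diff states, user-defined configs are not fully tested.

import torch
from colossalai.quantization import BnbQuantizationConfig

# Hypothetical custom config: FP4 storage, fp16 compute, no double quantization.
custom_cfg = BnbQuantizationConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=False,
    bnb_4bit_quant_type="fp4",
)

# Passing a user-defined config together with quantize=True triggers the warning shown above.
model = booster.enable_lora(model, lora_config=lora_config, bnb_quantization_config=custom_cfg, quantize=True)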
@@ -1,11 +1,11 @@
-import logging
-import warnings
 import enum
+import logging
 import os
+import warnings
 from functools import partial
 from pathlib import Path
 from types import MethodType
-from typing import Callable, Dict, Iterator, List, Optional, Tuple, Dict
+from typing import Callable, Dict, Iterator, List, Optional, Tuple

 import torch
 import torch.nn as nn
@@ -27,6 +27,7 @@ from colossalai.checkpoint_io.utils import (
     sharded_optimizer_loading_epilogue,
 )
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
+from colossalai.quantization import BnbQuantizationConfig, quantize_model
 from colossalai.utils import get_current_device
 from colossalai.zero import LowLevelZeroOptimizer

@@ -44,6 +45,7 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):

 SUPPORTED_PRECISION = ["fp16", "bf16", "fp32"]

+
 class OptimizerParamCheckState(enum.Enum):
     ORIGIN_PARAM_FINDED = 0
     ORIGIN_PARAM_NOT_FIND = -1
@@ -221,6 +223,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
             logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
             return
         from peft import PeftModel
+
         assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
         peft_model = model.unwrap()
         assert isinstance(
@@ -331,18 +334,25 @@ class LowLevelZeroPlugin(DPPluginBase):
     def supported_devices(self) -> List[str]:
         return ["cuda"]

     def support_lora(self) -> bool:
         return True

     def enable_lora(
-        self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
+        self,
+        model: nn.Module,
+        pretrained_dir: Optional[str] = None,
+        lora_config: Optional[Dict] = None,
+        bnb_quantization_config: Optional[BnbQuantizationConfig] = None,
     ) -> nn.Module:
         from peft import PeftModel, get_peft_model

         assert not isinstance(model, LowLevelZeroModel), "Lora should be enabled before boosting the model."
         self.lora_enabled = True
         warnings.warn("You have enabled LoRa training. Please check the hyperparameters such as lr")

+        if bnb_quantization_config is not None:
+            model = quantize_model(model, bnb_quantization_config)
+
         if pretrained_dir is None:
             peft_model = get_peft_model(model, lora_config)
         else:
@@ -352,7 +362,7 @@ class LowLevelZeroPlugin(DPPluginBase):
     def get_param_group_id(self, optimizer: Optimizer, origin_param: Parameter):
         origin_param_id = id(origin_param)
         for group_id, param_group in enumerate(optimizer.param_groups):
-            for p in param_group['params']:
+            for p in param_group["params"]:
                 if id(p) == origin_param_id:
                     return group_id
         return -1
@@ -362,7 +372,7 @@ class LowLevelZeroPlugin(DPPluginBase):
         lora_param_id = id(lora_param)
         target_group_id = None
         for group_id, param_group in enumerate(optimizer.param_groups):
-            for p in param_group['params']:
+            for p in param_group["params"]:
                 if id(p) == lora_param_id:
                     # check if the lora parameter exists.
                     return target_group_id, OptimizerParamCheckState.LORA_PARM_EXISTED
@@ -380,16 +390,22 @@ class LowLevelZeroPlugin(DPPluginBase):
             name2param[name] = param

         for name, param in name2param.items():
-            if 'lora_A' in name or 'lora_B' in name:
+            if "lora_A" in name or "lora_B" in name:
                 origin_key = name.replace("lora_A.", "")
                 origin_key = origin_key.replace("lora_B.", "")
                 origin_key = origin_key.replace(f"{model.active_adapter}", "base_layer")
                 origin_param = name2param[origin_key]
                 group_id, check_state = self.get_param_group_id(optimizer, origin_param, param)
                 if check_state == OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND:
-                    warnings.warn("Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups.")
-                elif check_state == OptimizerParamCheckState.ORIGIN_PARAM_FINDED and group_id is not None and group_id >= 0:
-                    optimizer.param_groups[group_id]['params'].append(param)
+                    warnings.warn(
+                        "Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups."
+                    )
+                elif (
+                    check_state == OptimizerParamCheckState.ORIGIN_PARAM_FINDED
+                    and group_id is not None
+                    and group_id >= 0
+                ):
+                    optimizer.param_groups[group_id]["params"].append(param)

     def configure(
         self,
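The key rewriting above is easier to see on a concrete name. The transformation it performs looks roughly like this (the parameter name below is illustrative of how peft names LoRA weights, not taken from the diff):

# Illustrative only: mapping a LoRA parameter name back to its base-layer key.
name = "base_model.model.layers.0.self_attn.q_proj.lora_A.default.weight"
active_adapter = "default"  # peft's default adapter name

origin_key = name.replace("lora_A.", "")
origin_key = origin_key.replace("lora_B.", "")
origin_key = origin_key.replace(active_adapter, "base_layer")
# origin_key is now "base_model.model.layers.0.self_attn.q_proj.base_layer.weight",
# which is the key looked up in name2param to find the original parameter's optimizer group.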
@@ -401,11 +417,13 @@ class LowLevelZeroPlugin(DPPluginBase):
     ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
         if self.lora_enabled:
             from peft import PeftModel
-            assert isinstance(model, PeftModel), "The model should have been wrapped as a PeftModel when self.lora_enabled is True"
+
+            assert isinstance(
+                model, PeftModel
+            ), "The model should have been wrapped as a PeftModel when self.lora_enabled is True"
             if optimizer is not None:
                 self.add_lora_params_to_optimizer(model, optimizer)

         if not isinstance(model, ModelWrapper):
             model = LowLevelZeroModel(model, self.precision)

@@ -9,6 +9,7 @@ from torch.utils.data import DataLoader
 from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.quantization import BnbQuantizationConfig, quantize_model

 from .dp_plugin_base import DPPluginBase

@@ -237,10 +238,17 @@ class TorchDDPPlugin(DPPluginBase):
         return model.module.no_sync()

     def enable_lora(
-        self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
+        self,
+        model: nn.Module,
+        pretrained_dir: Optional[str] = None,
+        lora_config: Optional[Dict] = None,
+        bnb_quantization_config: Optional[BnbQuantizationConfig] = None,
     ) -> nn.Module:
         from peft import PeftModel, get_peft_model

+        if bnb_quantization_config is not None:
+            model = quantize_model(model, bnb_quantization_config)
+
         assert not isinstance(model, TorchDDPModel), "Lora should be enabled before boosting the model."
         if pretrained_dir is None:
             return get_peft_model(model, lora_config)
@@ -18,15 +18,15 @@ from .gptq_op import CaiGPTQLinearOp
 HAS_GPTQ_CUDA = False
 try:
     from colossalai.kernel.op_builder.gptq import GPTQBuilder
+
     gptq_cuda = GPTQBuilder().load()
     HAS_GPTQ_CUDA = True
 except ImportError:
-    warnings.warn('CUDA gptq is not installed')
+    warnings.warn("CUDA gptq is not installed")
     HAS_GPTQ_CUDA = False


 class CaiQuantLinear(nn.Module):
-
     def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False):
         super().__init__()
         if bits not in [2, 4, 8]:
@@ -37,23 +37,28 @@ class CaiQuantLinear(nn.Module):
         self.maxq = 2**self.bits - 1
         self.groupsize = groupsize if groupsize != -1 else infeatures

-        self.register_buffer('qweight', torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32))
+        self.register_buffer("qweight", torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32))
         self.register_buffer(
-            'qzeros',
-            torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32))
-        self.register_buffer('scales',
-                             torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16))
+            "qzeros",
+            torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32),
+        )
+        self.register_buffer(
+            "scales", torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16)
+        )
         if row_split:
             self.register_buffer(
-                'g_idx',
-                torch.tensor([(i + (tp_rank * self.infeatures)) // self.groupsize for i in range(infeatures)],
-                             dtype=torch.int32))
+                "g_idx",
+                torch.tensor(
+                    [(i + (tp_rank * self.infeatures)) // self.groupsize for i in range(infeatures)], dtype=torch.int32
+                ),
+            )
         else:
-            self.register_buffer('g_idx',
-                                 torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32))
+            self.register_buffer(
+                "g_idx", torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32)
+            )

         if bias:
-            self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
+            self.register_buffer("bias", torch.zeros((outfeatures), dtype=torch.float16))
         else:
             self.bias = None

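As a quick sanity check on the buffers registered above, the shape arithmetic for a hypothetical 4-bit layer works out as follows (the dimensions are made up for illustration):

import math

bits, groupsize = 4, 128
infeatures, outfeatures = 4096, 11008  # hypothetical layer dimensions

qweight_shape = (infeatures // 32 * bits, outfeatures)  # (512, 11008): eight 4-bit codes packed per int32
qzeros_shape = (math.ceil(infeatures / groupsize), outfeatures // 32 * bits)  # (32, 1376)
scales_shape = (math.ceil(infeatures / groupsize), outfeatures)  # (32, 11008): one fp16 scale per group and column
g_idx_len = infeatures  # one group index per input feature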
@@ -66,9 +71,11 @@ class CaiQuantLinear(nn.Module):
         self.row_split = row_split

     def pack(self, linear, scales, zeros, g_idx=None):
-        g_idx = g_idx.clone() if g_idx is not None else torch.tensor(
-            [i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32)
+        g_idx = (
+            g_idx.clone()
+            if g_idx is not None
+            else torch.tensor([i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32)
+        )

         scales = scales.t().contiguous()
         zeros = zeros.t().contiguous()
@@ -79,7 +86,6 @@ class CaiQuantLinear(nn.Module):
         if linear.bias is not None:
             self.bias = linear.bias.clone().half()

-        wn = 8
         pbits = 32
         ptype = torch.int32
         unsign_type = np.uint32
@@ -88,9 +94,10 @@ class CaiQuantLinear(nn.Module):
         intweight = []
         for idx in range(self.infeatures):
             intweight.append(
-                torch.round(
-                    (linear.weight.data[:, idx] + scale_zeros[g_idx[idx]]) / half_scales[g_idx[idx]]).to(ptype)[:,
-                                                                                                               None])
+                torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx[idx]]) / half_scales[g_idx[idx]]).to(ptype)[
+                    :, None
+                ]
+            )
         intweight = torch.cat(intweight, dim=1)
         intweight = intweight.t().contiguous()
         intweight = intweight.numpy().astype(unsign_type)
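The rounding expression above implements a per-group affine quantization. A tiny numeric example (values invented for illustration, using scale_zeros = scales * zeros as in the surrounding pack logic) shows what it computes:

import torch

w = torch.tensor([0.05, -0.12, 0.31, -0.27])  # one column of weights within a single group
scale = torch.tensor(0.04)                     # half_scales[g] for that group
zero = torch.tensor(8.0)                       # integer zero point for that group
scale_zero = scale * zero                      # scale_zeros[g]

q = torch.round((w + scale_zero) / scale)      # integer codes: tensor([ 9.,  5., 16.,  1.])
w_hat = (q - zero) * scale                     # dequantized approximation of w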
@@ -144,13 +151,16 @@ class CaiQuantLinear(nn.Module):
                 torch.tensor(
                     [(i + (self.tp_rank * self.infeatures)) // self.groupsize for i in range(self.infeatures)],
                     dtype=torch.int32,
-                    device=self.g_idx.device)):
+                    device=self.g_idx.device,
+                ),
+            ):
                 self.g_idx = None
             elif torch.equal(
                 self.g_idx,
-                torch.tensor([i // self.groupsize for i in range(self.infeatures)],
-                             dtype=torch.int32,
-                             device=self.g_idx.device)):
+                torch.tensor(
+                    [i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32, device=self.g_idx.device
+                ),
+            ):
                 self.g_idx = None

             if self.g_idx is not None:
@@ -165,7 +175,6 @@ class CaiQuantLinear(nn.Module):
         outshape = x.shape[:-1] + (self.outfeatures,)

         if HAS_GPTQ_CUDA and self.bits == 4:
-
             if self.q4 is None:
                 self.init_q4()

@@ -191,7 +200,6 @@ class CaiQuantLinear(nn.Module):


 def split_column_copy(gptq_linear, cai_linear, tp_size=1, tp_rank=0, split_num=1):
-
     qweights = gptq_linear.qweight.split(gptq_linear.out_features // split_num, dim=-1)
     qzeros = gptq_linear.qzeros.split(gptq_linear.out_features // (32 // cai_linear.bits) // split_num, dim=-1)
     scales = gptq_linear.scales.split(gptq_linear.out_features // split_num, dim=-1)
@@ -203,24 +211,24 @@ def split_column_copy(gptq_linear, cai_linear, tp_size=1, tp_rank=0, split_num=1
     zero_split_block = cai_linear.outfeatures // (32 // cai_linear.bits) // split_num

     for i in range(split_num):
-        cai_linear.qweight[:, i * cai_split_out_features:(i + 1) *
-                           cai_split_out_features] = qweights[i][:, tp_rank * cai_split_out_features:(tp_rank + 1) *
-                                                                 cai_split_out_features]
-        cai_linear.qzeros[:, i * zero_split_block:(i + 1) *
-                          zero_split_block] = qzeros[i][:, tp_rank * zero_split_block:(tp_rank + 1) * zero_split_block]
-        cai_linear.scales[:, i * cai_split_out_features:(i + 1) *
-                          cai_split_out_features] = scales[i][:, tp_rank * cai_split_out_features:(tp_rank + 1) *
-                                                              cai_split_out_features]
+        cai_linear.qweight[:, i * cai_split_out_features : (i + 1) * cai_split_out_features] = qweights[i][
+            :, tp_rank * cai_split_out_features : (tp_rank + 1) * cai_split_out_features
+        ]
+        cai_linear.qzeros[:, i * zero_split_block : (i + 1) * zero_split_block] = qzeros[i][
+            :, tp_rank * zero_split_block : (tp_rank + 1) * zero_split_block
+        ]
+        cai_linear.scales[:, i * cai_split_out_features : (i + 1) * cai_split_out_features] = scales[i][
+            :, tp_rank * cai_split_out_features : (tp_rank + 1) * cai_split_out_features
+        ]
         if cai_linear.bias is not None:
-            cai_linear.bias[i * cai_split_out_features:(i + 1) *
-                            cai_split_out_features] = bias[i][tp_rank * cai_split_out_features:(tp_rank + 1) *
-                                                              cai_split_out_features]
+            cai_linear.bias[i * cai_split_out_features : (i + 1) * cai_split_out_features] = bias[i][
+                tp_rank * cai_split_out_features : (tp_rank + 1) * cai_split_out_features
+            ]

     cai_linear.g_idx.copy_(g_idx)


 def split_row_copy(gptq_linear, cai_linear, tp_rank=0, split_num=1):
-
     qweights = gptq_linear.qweight.split(gptq_linear.in_features // split_num, dim=0)
     qzeros = gptq_linear.qzeros.split(gptq_linear.in_features // split_num, dim=0)
     scales = gptq_linear.scales.split(gptq_linear.in_features // split_num, dim=0)
@@ -231,47 +239,40 @@ def split_row_copy(gptq_linear, cai_linear, tp_rank=0, split_num=1):
     idx_split_features = cai_linear.infeatures // split_num

     for i in range(split_num):
-        cai_linear.qweight[i * cai_split_in_features:(i + 1) *
-                           cai_split_in_features, :] = qweights[i][tp_rank * cai_split_in_features:(tp_rank + 1) *
-                                                                   cai_split_in_features, :]
-        cai_linear.qzeros[i * zero_split_block:(i + 1) *
-                          zero_split_block, :] = qzeros[i][tp_rank * zero_split_block:(tp_rank + 1) *
-                                                           zero_split_block, :]
-        cai_linear.scales[i * zero_split_block:(i + 1) *
-                          zero_split_block, :] = scales[i][tp_rank * zero_split_block:(tp_rank + 1) *
-                                                           zero_split_block, :]
-        cai_linear.g_idx[i * idx_split_features:(i + 1) *
-                         idx_split_features] = g_idxs[i][tp_rank * idx_split_features:(tp_rank + 1) *
-                                                         idx_split_features]
+        cai_linear.qweight[i * cai_split_in_features : (i + 1) * cai_split_in_features, :] = qweights[i][
+            tp_rank * cai_split_in_features : (tp_rank + 1) * cai_split_in_features, :
+        ]
+        cai_linear.qzeros[i * zero_split_block : (i + 1) * zero_split_block, :] = qzeros[i][
+            tp_rank * zero_split_block : (tp_rank + 1) * zero_split_block, :
+        ]
+        cai_linear.scales[i * zero_split_block : (i + 1) * zero_split_block, :] = scales[i][
+            tp_rank * zero_split_block : (tp_rank + 1) * zero_split_block, :
+        ]
+        cai_linear.g_idx[i * idx_split_features : (i + 1) * idx_split_features] = g_idxs[i][
+            tp_rank * idx_split_features : (tp_rank + 1) * idx_split_features
+        ]
     if cai_linear.bias is not None:
         cai_linear.bias.copy_(gptq_linear.bias)


 class RowCaiQuantLinear(CaiQuantLinear, ParallelModule):
-
     def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False):
-        super().__init__(bits,
-                         groupsize,
-                         infeatures,
-                         outfeatures,
-                         bias,
-                         tp_size=tp_size,
-                         tp_rank=tp_rank,
-                         row_split=row_split)
+        super().__init__(
+            bits, groupsize, infeatures, outfeatures, bias, tp_size=tp_size, tp_rank=tp_rank, row_split=row_split
+        )
         self.process_group = None

     @staticmethod
-    def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
-                           **kwargs) -> ParallelModule:
+    def from_native_module(
+        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+    ) -> ParallelModule:
         LazyInitContext.materialize(module)
         # get the attributes
         in_features = module.in_features

         # ensure only one process group is passed
         if isinstance(process_group, (list, tuple)):
-            assert len(process_group) == 1, \
-                f'Expected only one process group, got {len(process_group)}.'
+            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
             process_group = process_group[0]

         tp_size = dist.get_world_size(process_group)
@@ -282,15 +283,18 @@ class RowCaiQuantLinear(CaiQuantLinear, ParallelModule):

         if in_features % tp_size != 0:
             raise ValueError(
-                f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!")
-        linear_1d = RowCaiQuantLinear(module.bits,
+                f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!"
+            )
+        linear_1d = RowCaiQuantLinear(
+            module.bits,
             module.group_size,
             module.in_features // tp_size,
             module.out_features,
             module.bias is not None,
             tp_size=tp_size,
             tp_rank=tp_rank,
-            row_split=True)
+            row_split=True,
+        )
         linear_1d.process_group = process_group

         split_row_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
@@ -306,30 +310,23 @@


 class ColCaiQuantLinear(CaiQuantLinear, ParallelModule):
-
     def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False):
-        super().__init__(bits,
-                         groupsize,
-                         infeatures,
-                         outfeatures,
-                         bias,
-                         tp_size=tp_size,
-                         tp_rank=tp_rank,
-                         row_split=row_split)
+        super().__init__(
+            bits, groupsize, infeatures, outfeatures, bias, tp_size=tp_size, tp_rank=tp_rank, row_split=row_split
+        )
         self.process_group = None

     @staticmethod
-    def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
-                           **kwargs) -> ParallelModule:
+    def from_native_module(
+        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+    ) -> ParallelModule:
         LazyInitContext.materialize(module)
         # get the attributes
         in_features = module.in_features

         # ensure only one process group is passed
         if isinstance(process_group, (list, tuple)):
-            assert len(process_group) == 1, \
-                f'Expected only one process group, got {len(process_group)}.'
+            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
             process_group = process_group[0]

         tp_size = dist.get_world_size(process_group)
@@ -340,14 +337,17 @@ class ColCaiQuantLinear(CaiQuantLinear, ParallelModule):

         if in_features % tp_size != 0:
             raise ValueError(
-                f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!")
-        linear_1d = ColCaiQuantLinear(module.bits,
+                f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!"
+            )
+        linear_1d = ColCaiQuantLinear(
+            module.bits,
             module.group_size,
             module.in_features,
             module.out_features // tp_size,
             module.bias is not None,
             tp_size=tp_size,
-            tp_rank=tp_rank)
+            tp_rank=tp_rank,
+        )
         linear_1d.process_group = process_group

         split_column_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
@@ -5,6 +5,7 @@ import torch

 from .kvcache_manager import MemoryManager

+
 # adapted from: lightllm/server/router/model_infer/infer_batch.py
 @dataclass
 class BatchInferState:
@@ -19,8 +19,11 @@ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
 from ._utils import copy_kv_to_mem_cache

 try:
-    from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import context_attention_fwd as lightllm_llama2_context_attention_fwd
     from lightllm.models.chatglm2.triton_kernel.rotary_emb import rotary_emb_fwd as chatglm2_rotary_emb_fwd
+    from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import (
+        context_attention_fwd as lightllm_llama2_context_attention_fwd,
+    )

     HAS_LIGHTLLM_KERNEL = True
 except:
     print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm")
@@ -4,7 +4,6 @@ import torch
 from torch.nn import LayerNorm

 import colossalai.shardformer.layer as col_nn
-from colossalai.shardformer.modeling.bloom import build_bloom_alibi_tensor_fn
 from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
 from colossalai.shardformer.policies.bloom import BloomForCausalLMPolicy

@@ -40,33 +39,36 @@ class BloomModelInferPolicy(BloomForCausalLMPolicy):
         policy = super().module_policy()
         if self.shard_config.inference_gptq:
             from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear
-            policy[BloomBlock] = ModulePolicyDescription(attribute_replacement={
-                "self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
-                "self_attention.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
+
+            policy[BloomBlock] = ModulePolicyDescription(
+                attribute_replacement={
+                    "self_attention.hidden_size": self.model.config.hidden_size
+                    // self.shard_config.tensor_parallel_size,
+                    "self_attention.split_size": self.model.config.hidden_size
+                    // self.shard_config.tensor_parallel_size,
                     "self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size,
                 },
                 sub_module_replacement=[
                     SubModuleReplacementDescription(
                         suffix="self_attention.query_key_value",
                         target_module=ColCaiQuantLinear,
-                        kwargs={'split_num': 3}),
+                        kwargs={"split_num": 3},
+                    ),
                     SubModuleReplacementDescription(
-                        suffix="self_attention.dense",
-                        target_module=RowCaiQuantLinear,
-                        kwargs={'split_num': 1}),
+                        suffix="self_attention.dense", target_module=RowCaiQuantLinear, kwargs={"split_num": 1}
+                    ),
                     SubModuleReplacementDescription(
                         suffix="self_attention.attention_dropout",
                         target_module=col_nn.DropoutForParallelInput,
                     ),
                     SubModuleReplacementDescription(
-                        suffix="mlp.dense_h_to_4h",
-                        target_module=ColCaiQuantLinear,
-                        kwargs={'split_num': 1}),
+                        suffix="mlp.dense_h_to_4h", target_module=ColCaiQuantLinear, kwargs={"split_num": 1}
+                    ),
                     SubModuleReplacementDescription(
-                        suffix="mlp.dense_4h_to_h",
-                        target_module=RowCaiQuantLinear,
-                        kwargs={'split_num': 1}),
-                ])
+                        suffix="mlp.dense_4h_to_h", target_module=RowCaiQuantLinear, kwargs={"split_num": 1}
+                    ),
+                ],
+            )
         # NOTE set inference mode to shard config
         self.shard_config._infer()

@@ -13,6 +13,7 @@ from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forw

 try:
     from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward as lightllm_rmsnorm_forward
+
     HAS_TRITON_RMSNORM = True
 except:
     print("you should install triton from https://github.com/openai/triton")
@@ -21,6 +22,7 @@ except:

 def get_triton_rmsnorm_forward():
     if HAS_TRITON_RMSNORM:
+
         def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
             return lightllm_rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon)

@@ -1,27 +1,29 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/extension.h>

#include <cstdint>
#include <cstdio>

#include "column_remap.cuh"
#include "cuda_buffers.cuh"
#include "q4_matmul.cuh"
#include "q4_matrix.cuh"
#include "tuning.h"
#include "util.cuh"

// Check CUDA return code. We don't want to include Torch headers in the .cu
// files because parsing them adds almost a minute to the compile time on a
// 12900K. Also passing exceptions back to Python is super tricky, so in place
// of exceptions, CUDA functions return with a cudaError_t which we can parse
// and dump to the console.

void check_cuda(cudaError_t ret) {
  switch (ret) {
    case cudaSuccess:
      break;

@@ -31,9 +33,9 @@ void check_cuda(cudaError_t ret)
      break;

    default:
      printf(" **** CUDA error\n");
      printf(" **** %s\n", cudaGetErrorString(ret));
      TORCH_CHECK(false, "CUDA error");
      break;
  }
}

@@ -42,12 +44,25 @@ void check_cuda(cudaError_t ret)

#define STRINGIFY_(__x) #__x
#define STRINGIFY(__x) STRINGIFY_(__x)
#define TORCH_CHECK_DTYPE(__x, __dtype) \
  TORCH_CHECK((__x).dtype() == torch::__dtype, \
              #__x " is incorrect datatype, must be " #__dtype)
#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) \
  TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, \
              #__x " is incorrect datatype, must be " #__dtype)
#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) \
  TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, \
              #__x " and " #__y " have incompatible shapes")
#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) \
  TORCH_CHECK((__x).device().is_meta() || \
                  (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, \
              #__x " and " #__y " have incompatible shapes")
#define TORCH_CHECK_SHAPE_MOD(__x, __dim_x, __mod) \
  TORCH_CHECK((__x).size(__dim_x) % __mod == 0, \
              #__x ".shape[" STRINGIFY( \
                  __dim_x) "] must be a multiple of " STRINGIFY(__mod))
#define TORCH_CHECK_BUFFER_SIZE(__buffer, __minimum_size) \
  TORCH_CHECK((__buffer).numel() >= __minimum_size, #__buffer " is too small")

#define TORCH_CHECK_DEVICE_INDEX(__index) \
  do { \

@@ -66,75 +81,49 @@ do { \
    TORCH_CHECK_SHAPES_OPT(__x_map, 0, __w, 0, 8); \
  } while (0)

int get_groupsize(torch::Tensor w, torch::Tensor w_zeros) {
  int groupsize = w.size(0) * 8 / w_zeros.size(0);
  TORCH_CHECK(groupsize * w_zeros.size(0) == w.size(0) * 8,
              "w.shape[-2] must be a multiple of zeros.shape[-2]")
  return groupsize;
}

// Tuning parameters

ExLlamaTuning tuningParams;

void set_tuning_params(int matmul_recons_thd, bool matmul_fused_remap,
                       bool matmul_no_half2) {
  tuningParams.matmul_recons_thd = matmul_recons_thd;
  tuningParams.matmul_fused_remap = matmul_fused_remap;
  tuningParams.matmul_no_half2 = matmul_no_half2;
}

// Release all unmanaged objects allocated by the extension

void cleanup() {
  cleanup_buffers_cuda();
  g_q4_free_matrices();
}

// Prepare buffers for forward pass

void prepare_buffers(torch::Device device, torch::Tensor temp_state,
                     torch::Tensor temp_dq) {
  int device_index = device.index();
  TORCH_CHECK_DEVICE_INDEX(device_index);
  const at::cuda::OptionalCUDAGuard device_guard(device);

  prepare_buffers_cuda(device_index,
                       // buffer size used for sanity checks
                       temp_state.numel(), (half*)temp_state.data_ptr(),
                       (half*)temp_dq.data_ptr());
}

// Create Q4Matrix, return handle

uintptr_t make_q4(torch::Tensor qweight, torch::Tensor qzeros,
                  torch::Tensor scales, torch::Tensor g_idx, int device) {
  TORCH_CHECK_DTYPE(qweight, kInt);
  TORCH_CHECK_DTYPE(qzeros, kInt);
  TORCH_CHECK_DTYPE(scales, kHalf);

@@ -147,34 +136,22 @@ uintptr_t make_q4
  int height = qweight.size(0) * 8;
  int groups = qzeros.size(0);

  Q4Matrix* m = new Q4Matrix(
      height, width, groups,
      (uint32_t*)qweight.data_ptr(), (uint32_t*)qzeros.data_ptr(),
      (half*)scales.data_ptr(),
      g_idx.device().is_meta() ? NULL : (uint32_t*)g_idx.data_ptr(),
      device);

  g_q4_keep_matrix(m);
  return reinterpret_cast<uintptr_t>(m);
}

// Matmul half @ quant -> half

void q4_matmul(torch::Tensor x, uintptr_t w, torch::Tensor out) {
  Q4Matrix* wm = reinterpret_cast<Q4Matrix*>(w);

  TORCH_CHECK_DTYPE(x, kHalf);

@@ -186,41 +163,20 @@ void q4_matmul
  int x_height = x.size(0);

  if (tuningParams.matmul_recons_thd == 0 ||
      x_height < tuningParams.matmul_recons_thd) {
    q4_matmul_cuda(&tuningParams, (half*)x.data_ptr(), x_height, wm,
                   (half*)out.data_ptr());
  } else {
    q4_matmul_recons_cuda(&tuningParams, (half*)x.data_ptr(), x_height, wm,
                          (half*)out.data_ptr(),
                          at::cuda::getCurrentCUDABlasHandle());
  }
}

// Remap columns in half tensor

void column_remap(torch::Tensor x, torch::Tensor x_new, torch::Tensor x_map) {
  TORCH_CHECK_DTYPE(x, kHalf);
  TORCH_CHECK_DTYPE(x_new, kHalf);
  TORCH_CHECK_DTYPE(x_map, kInt);

@@ -233,19 +189,11 @@ void column_remap
  const at::cuda::OptionalCUDAGuard device_guard(device_of(x));

  column_remap_cuda((half*)x.data_ptr(), (half*)x_new.data_ptr(), height, width,
                    (uint32_t*)x_map.data_ptr());
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("set_tuning_params", &set_tuning_params, "set_tuning_params");
  m.def("prepare_buffers", &prepare_buffers, "prepare_buffers");
  m.def("cleanup", &cleanup, "cleanup");
@@ -10,7 +10,6 @@ except ImportError:
    print("please install triton from https://github.com/openai/triton")

if HAS_TRITON:
    # adapted from https://github.com/ModelTC/lightllm/blob/5c559dd7981ed67679a08a1e09a88fb4c1550b3a/lightllm/common/triton_kernel/destindex_copy_kv.py
    @triton.jit
    def _fwd_copy_kv_cache_dest(
@@ -13,6 +13,9 @@ except ImportError:
    print("please install triton from https://github.com/openai/triton")

try:
    from lightllm.models.bloom.triton_kernel.token_attention_nopad_att1 import (
        token_att_fwd as lightllm_bloom_token_att_fwd,
    )
    from lightllm.models.llama2.triton_kernel.token_attention_nopad_att1 import (
        token_att_fwd as lightllm_llama2_token_att_fwd,
    )
@@ -22,11 +25,15 @@ try:
    from lightllm.models.llama2.triton_kernel.token_attention_nopad_softmax import (
        token_softmax_fwd as lightllm_llama2_token_softmax_fwd,
    )
    from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import (
        token_att_fwd as lightllm_llama_token_att_fwd,
    )
    from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import (
        token_att_fwd2 as lightllm_llama_token_att_fw2,
    )
    from lightllm.models.llama.triton_kernel.token_attention_nopad_softmax import (
        token_softmax_fwd as lightllm_llama_token_softmax_fwd,
    )

    HAS_TRITON_TOKEN_ATTENTION = True
except ImportError:
@@ -44,8 +44,8 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -

    return unpickle


def check_for_nccl_backend(group):
    pg = group or c10d._get_default_group()
    # Gate PG wrapper check on Gloo availability.
    if c10d._GLOO_AVAILABLE:
@@ -54,10 +54,8 @@ def check_for_nccl_backend(group):
        while isinstance(pg, c10d._ProcessGroupWrapper):
            pg = pg.wrapped_pg

    return c10d.is_nccl_available() and pg.name() == c10d.Backend.NCCL


def _broadcast_object_list(
    object_list: List[Any], src: int, group: ProcessGroup, device: Optional[Union[torch.device, str, int]] = None
colossalai/quantization/__init__.py (new file, 7 lines)
@@ -0,0 +1,7 @@
from .bnb import quantize_model
from .bnb_config import BnbQuantizationConfig

__all__ = [
    "BnbQuantizationConfig",
    "quantize_model",
]
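For orientation, a minimal sketch of how the new package's public API might be imported by downstream code (illustrative only, mirroring the `__all__` list above):

# illustrative import of the exports declared in colossalai/quantization/__init__.py
from colossalai.quantization import BnbQuantizationConfig, quantize_model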
colossalai/quantization/bnb.py (new file, 321 lines)
@@ -0,0 +1,321 @@
# adapted from Hugging Face accelerate/utils/bnb.py accelerate/utils/modeling.py

import logging

import torch
import torch.nn as nn

from .bnb_config import BnbQuantizationConfig

try:
    import bitsandbytes as bnb

    IS_4BIT_BNB_AVAILABLE = bnb.__version__ >= "0.39.0"
    IS_8BIT_BNB_AVAILABLE = bnb.__version__ >= "0.37.2"
except ImportError:
    pass


logger = logging.getLogger(__name__)


def quantize_model(
    model: torch.nn.Module,
    bnb_quantization_config: BnbQuantizationConfig,
):
    """
    This function will quantize the input loaded model with the associated config passed in `bnb_quantization_config`.
    We will quantize the model and put the model on the GPU.

    Args:
        model (`torch.nn.Module`):
            Input model. The model already loaded
        bnb_quantization_config (`BnbQuantizationConfig`):
            The bitsandbytes quantization parameters

    Returns:
        `torch.nn.Module`: The quantized model
    """

    load_in_4bit = bnb_quantization_config.load_in_4bit
    load_in_8bit = bnb_quantization_config.load_in_8bit

    if load_in_8bit and not IS_8BIT_BNB_AVAILABLE:
        raise ImportError(
            "You have a version of `bitsandbytes` that is not compatible with 8bit quantization,"
            " make sure you have the latest version of `bitsandbytes` installed."
        )
    if load_in_4bit and not IS_4BIT_BNB_AVAILABLE:
        raise ValueError(
            "You have a version of `bitsandbytes` that is not compatible with 4bit quantization,"
            "make sure you have the latest version of `bitsandbytes` installed."
        )

    # We keep some modules such as the lm_head in their original dtype for numerical stability reasons
    if bnb_quantization_config.skip_modules is None:
        bnb_quantization_config.skip_modules = get_keys_to_not_convert(model)

    modules_to_not_convert = bnb_quantization_config.skip_modules

    # We add the modules we want to keep in full precision
    if bnb_quantization_config.keep_in_fp32_modules is None:
        bnb_quantization_config.keep_in_fp32_modules = []
    keep_in_fp32_modules = bnb_quantization_config.keep_in_fp32_modules

    # compatibility with peft
    model.is_loaded_in_4bit = load_in_4bit
    model.is_loaded_in_8bit = load_in_8bit

    # assert model_device is cuda
    model_device = next(model.parameters()).device

    model = replace_with_bnb_layers(model, bnb_quantization_config, modules_to_not_convert=modules_to_not_convert)

    # convert param to the right dtype
    dtype = bnb_quantization_config.torch_dtype
    for name, param in model.state_dict().items():
        if any(module_to_keep_in_fp32 in name for module_to_keep_in_fp32 in keep_in_fp32_modules):
            param.to(torch.float32)
            if param.dtype != torch.float32:
                name = name.replace(".weight", "").replace(".bias", "")
                param = getattr(model, name, None)
                if param is not None:
                    param.to(torch.float32)
        elif torch.is_floating_point(param):
            param.to(dtype)
    if model_device.type == "cuda":
        # move everything to cpu in the first place because we can't do quantization if the weights are already on cuda
        model.cuda(torch.cuda.current_device())
        torch.cuda.empty_cache()
    elif torch.cuda.is_available():
        model.to(torch.cuda.current_device())
        logger.info(
            f"The model device type is {model_device.type}. However, cuda is needed for quantization."
            "We move the model to cuda."
        )
    else:
        raise RuntimeError("No GPU found. A GPU is needed for quantization.")
    return model


def replace_with_bnb_layers(model, bnb_quantization_config, modules_to_not_convert=None, current_key_name=None):
    """
    A helper function to replace all `torch.nn.Linear` modules by `bnb.nn.Linear8bit` modules or by `bnb.nn.Linear4bit`
    modules from the `bitsandbytes`library. The function will be run recursively and replace `torch.nn.Linear` modules.

    Parameters:
        model (`torch.nn.Module`):
            Input model or `torch.nn.Module` as the function is run recursively.
        modules_to_not_convert (`List[str]`):
            Names of the modules to not quantize convert. In practice we keep the `lm_head` in full precision for
            numerical stability reasons.
        current_key_name (`List[str]`, *optional*):
            An array to track the current key of the recursion. This is used to check whether the current key (part of
            it) is not in the list of modules to not convert.
    """

    if modules_to_not_convert is None:
        modules_to_not_convert = []

    model, has_been_replaced = _replace_with_bnb_layers(
        model, bnb_quantization_config, modules_to_not_convert, current_key_name
    )
    if not has_been_replaced:
        logger.warning(
            "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
            " this can happen for some architectures such as gpt2 that uses Conv1D instead of Linear layers."
            " Please double check your model architecture, or submit an issue on github if you think this is"
            " a bug."
        )
    return model


def _replace_with_bnb_layers(
    model,
    bnb_quantization_config,
    modules_to_not_convert=None,
    current_key_name=None,
):
    """
    Private method that wraps the recursion for module replacement.

    Returns the converted model and a boolean that indicates if the conversion has been successfull or not.
    """
    # bitsandbytes will initialize CUDA on import, so it needs to be imported lazily

    has_been_replaced = False
    for name, module in model.named_children():
        if current_key_name is None:
            current_key_name = []
        current_key_name.append(name)
        if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
            # Check if the current key is not in the `modules_to_not_convert`
            current_key_name_str = ".".join(current_key_name)
            proceed = True
            for key in modules_to_not_convert:
                if (
                    (key in current_key_name_str) and (key + "." in current_key_name_str)
                ) or key == current_key_name_str:
                    proceed = False
                    break
            if proceed:
                # Load bnb module with empty weight and replace ``nn.Linear` module
                if bnb_quantization_config.load_in_8bit:
                    bnb_module = bnb.nn.Linear8bitLt(
                        module.in_features,
                        module.out_features,
                        module.bias is not None,
                        has_fp16_weights=False,
                        threshold=bnb_quantization_config.llm_int8_threshold,
                    )
                elif bnb_quantization_config.load_in_4bit:
                    bnb_module = bnb.nn.Linear4bit(
                        module.in_features,
                        module.out_features,
                        module.bias is not None,
                        bnb_quantization_config.bnb_4bit_compute_dtype,
                        compress_statistics=bnb_quantization_config.bnb_4bit_use_double_quant,
                        quant_type=bnb_quantization_config.bnb_4bit_quant_type,
                    )
                else:
                    raise ValueError("load_in_8bit and load_in_4bit can't be both False")
                bnb_module.weight.data = module.weight.data
                bnb_module.weight.skip_zero_check = True
                if module.bias is not None:
                    bnb_module.bias.data = module.bias.data
                    bnb_module.bias.skip_zero_check = True
                bnb_module.requires_grad_(False)
                setattr(model, name, bnb_module)
                has_been_replaced = True
        if len(list(module.children())) > 0:
            _, _has_been_replaced = _replace_with_bnb_layers(
                module, bnb_quantization_config, modules_to_not_convert, current_key_name
            )
            has_been_replaced = has_been_replaced | _has_been_replaced
        # Remove the last key for recursion
        current_key_name.pop(-1)
    return model, has_been_replaced


def get_keys_to_not_convert(model):
    r"""
    An utility function to get the key of the module to keep in full precision if any For example for CausalLM modules
    we may want to keep the lm_head in full precision for numerical stability reasons. For other architectures, we want
    to keep the tied weights of the model. The function will return a list of the keys of the modules to not convert in
    int8.

    Parameters:
        model (`torch.nn.Module`):
            Input model
    """
    # Create a copy of the model
    # with init_empty_weights():
    #     tied_model = deepcopy(model)  # this has 0 cost since it is done inside `init_empty_weights` context manager`
    tied_model = model

    tied_params = find_tied_parameters(tied_model)
    # For compatibility with Accelerate < 0.18
    if isinstance(tied_params, dict):
        tied_keys = sum(list(tied_params.values()), []) + list(tied_params.keys())
    else:
        tied_keys = sum(tied_params, [])
    has_tied_params = len(tied_keys) > 0

    # Check if it is a base model
    is_base_model = False
    if hasattr(model, "base_model_prefix"):
        is_base_model = not hasattr(model, model.base_model_prefix)

    # Ignore this for base models (BertModel, GPT2Model, etc.)
    if (not has_tied_params) and is_base_model:
        return []

    # otherwise they have an attached head
    list_modules = list(model.named_children())
    list_last_module = [list_modules[-1][0]]

    # add last module together with tied weights
    intersection = set(list_last_module) - set(tied_keys)
    list_untouched = list(set(tied_keys)) + list(intersection)

    # remove ".weight" from the keys
    names_to_remove = [".weight", ".bias"]
    filtered_module_names = []
    for name in list_untouched:
        for name_to_remove in names_to_remove:
            if name_to_remove in name:
                name = name.replace(name_to_remove, "")
        filtered_module_names.append(name)

    return filtered_module_names


def find_tied_parameters(model: nn.Module, **kwargs):
    """
    Find the tied parameters in a given model.

    <Tip warning={true}>

    The signature accepts keyword arguments, but they are for the recursive part of this function and you should ignore
    them.

    </Tip>

    Args:
        model (`torch.nn.Module`): The model to inspect.

    Returns:
        List[List[str]]: A list of lists of parameter names being all tied together.

    Example:

    ```py
    >>> from collections import OrderedDict
    >>> import torch.nn as nn

    >>> model = nn.Sequential(OrderedDict([("linear1", nn.Linear(4, 4)), ("linear2", nn.Linear(4, 4))]))
    >>> model.linear2.weight = model.linear1.weight
    >>> find_tied_parameters(model)
    [['linear1.weight', 'linear2.weight']]
    ```
    """
    # Initialize result and named_parameters before recursing.
    named_parameters = kwargs.get("named_parameters", None)
    prefix = kwargs.get("prefix", "")
    result = kwargs.get("result", {})

    if named_parameters is None:
        named_parameters = {n: p for n, p in model.named_parameters()}
    else:
        # A tied parameter will not be in the full `named_parameters` seen above but will be in the `named_parameters`
        # of the submodule it belongs to. So while recursing we track the names that are not in the initial
        # `named_parameters`.
        for name, parameter in model.named_parameters():
            full_name = name if prefix == "" else f"{prefix}.{name}"
            if full_name not in named_parameters:
                # When we find one, it has to be one of the existing parameters.
                for new_name, new_param in named_parameters.items():
                    if new_param is parameter:
                        if new_name not in result:
                            result[new_name] = []
                        result[new_name].append(full_name)

    # Once we have treated direct parameters, we move to the child modules.
    for name, child in model.named_children():
        child_name = name if prefix == "" else f"{prefix}.{name}"
        find_tied_parameters(child, named_parameters=named_parameters, prefix=child_name, result=result)

    return FindTiedParametersResult([sorted([weight] + list(set(tied))) for weight, tied in result.items()])


class FindTiedParametersResult(list):
    """
    This is a subclass of a list to handle backward compatibility for Transformers. Do not rely on the fact this is not
    a list or on the `values` method as in the future this will be removed.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def values(self):
        return sum([x[1:] for x in self], [])
colossalai/quantization/bnb_config.py (new file, 113 lines)
@@ -0,0 +1,113 @@
# adapted from Hugging Face accelerate/utils/dataclasses.py

import warnings
from dataclasses import dataclass, field
from typing import List

import torch


@dataclass
class BnbQuantizationConfig:
    """
    A plugin to enable BitsAndBytes 4bit and 8bit quantization
    """

    load_in_8bit: bool = field(default=False, metadata={"help": "enable 8bit quantization."})

    llm_int8_threshold: float = field(
        default=6.0, metadata={"help": "value of the outliner threshold. only relevant when load_in_8bit=True"}
    )

    load_in_4bit: bool = field(default=False, metadata={"help": "enable 4bit quantization."})

    bnb_4bit_quant_type: str = field(
        default="fp4",
        metadata={
            "help": "set the quantization data type in the `bnb.nn.Linear4Bit` layers. Options are {'fp4','np4'}."
        },
    )

    bnb_4bit_use_double_quant: bool = field(
        default=False,
        metadata={
            "help": "enable nested quantization where the quantization constants from the first quantization are quantized again."
        },
    )

    bnb_4bit_compute_dtype: bool = field(
        default="fp16",
        metadata={
            "help": "This sets the computational type which might be different than the input time. For example, inputs might be "
            "fp32, but computation can be set to bf16 for speedups. Options are {'fp32','fp16','bf16'}."
        },
    )

    torch_dtype: torch.dtype = field(
        default=None,
        metadata={
            "help": "this sets the dtype of the remaining non quantized layers. `bitsandbytes` library suggests to set the value"
            "to `torch.float16` for 8 bit model and use the same dtype as the compute dtype for 4 bit model "
        },
    )

    skip_modules: List[str] = field(
        default=None,
        metadata={
            "help": "an explicit list of the modules that we don't quantize. The dtype of these modules will be `torch_dtype`."
        },
    )

    keep_in_fp32_modules: List[str] = field(
        default=None,
        metadata={"help": "an explicit list of the modules that we don't quantize. We keep them in `torch.float32`."},
    )

    def __post_init__(self):
        if isinstance(self.bnb_4bit_compute_dtype, str):
            if self.bnb_4bit_compute_dtype == "fp32":
                self.bnb_4bit_compute_dtype = torch.float32
            elif self.bnb_4bit_compute_dtype == "fp16":
                self.bnb_4bit_compute_dtype = torch.float16
            elif self.bnb_4bit_compute_dtype == "bf16":
                self.bnb_4bit_compute_dtype = torch.bfloat16
            else:
                raise ValueError(
                    f"bnb_4bit_compute_dtype must be in ['fp32','fp16','bf16'] but found {self.bnb_4bit_compute_dtype}"
                )
        elif not isinstance(self.bnb_4bit_compute_dtype, torch.dtype):
            raise ValueError("bnb_4bit_compute_dtype must be a string or a torch.dtype")

        if self.skip_modules is not None and not isinstance(self.skip_modules, list):
            raise ValueError("skip_modules must be a list of strings")

        if self.keep_in_fp32_modules is not None and not isinstance(self.keep_in_fp32_modules, list):
            raise ValueError("keep_in_fp_32_modules must be a list of strings")

        if self.load_in_4bit:
            self.target_dtype = "int4"

        if self.load_in_8bit:
            self.target_dtype = torch.int8

        if self.load_in_4bit and self.llm_int8_threshold != 6.0:
            warnings.warn("llm_int8_threshold can only be used for model loaded in 8bit")

        if isinstance(self.torch_dtype, str):
            if self.torch_dtype == "fp32":
                self.torch_dtype = torch.float32
            elif self.torch_dtype == "fp16":
                self.torch_dtype = torch.float16
            elif self.torch_dtype == "bf16":
                self.torch_dtype = torch.bfloat16
            else:
                raise ValueError(f"torch_dtype must be in ['fp32','fp16','bf16'] but found {self.torch_dtype}")

        if self.load_in_8bit and self.torch_dtype is None:
            self.torch_dtype = torch.float16

        if self.load_in_4bit and self.torch_dtype is None:
            self.torch_dtype = self.bnb_4bit_compute_dtype

        if not isinstance(self.torch_dtype, torch.dtype):
            raise ValueError("torch_dtype must be a torch.dtype")
@@ -190,6 +190,7 @@ def calculate_global_norm_from_list(norm_list):
        total_norm += norm**2.0
    return math.sqrt(total_norm)


def sync_tensor(flat_tensor, tensor_list):
    """
    Synchronize the flattened tensor and unflattened tensor list. When
@@ -187,6 +187,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
        for param_group in self.optim.param_groups:
            group_params = param_group["params"]
            for param in group_params:
                if not hasattr(param, "skip_zero_check") or param.skip_zero_check is False:
                    assert (
                        param.dtype == self._dtype
                    ), f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`"
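The `skip_zero_check` flag referenced above is the attribute that `_replace_with_bnb_layers` sets on quantized weights and biases; it lets them bypass the dtype assertion, since bnb layers keep int8/4-bit storage rather than the optimizer's working dtype. A simplified sketch of the guarded check (illustrative, not the actual optimizer code):

import torch


def check_param_dtypes(params, expected_dtype=torch.float16):
    # mirrors the guarded assertion added above: parameters tagged with
    # skip_zero_check=True (set in colossalai/quantization/bnb.py) are skipped
    for param in params:
        if getattr(param, "skip_zero_check", False):
            continue
        assert param.dtype == expected_dtype, f"expected {expected_dtype}, got {param.dtype}"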
@@ -1,12 +1,10 @@
import argparse
-import logging
import os
import time

import torch
-from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
-from auto_gptq.nn_modules.qlinear import GeneralQuantLinear
-from transformers import AutoTokenizer, BloomForCausalLM, BloomTokenizerFast, LlamaForCausalLM, LlamaTokenizer
+from auto_gptq import AutoGPTQForCausalLM
+from transformers import BloomTokenizerFast

import colossalai
from colossalai.inference.tensor_parallel.engine import TPInferEngine
@@ -14,7 +12,7 @@ from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn

os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"


def print_perf_stats(latency_set, config, bs, warmup=3):
@@ -37,7 +35,6 @@ def print_perf_stats(latency_set, config, bs, warmup=3):


def bench_bloom(args):
    pretrained_model_dir = args.path
    quantized_model_dir = args.quantized_path
    max_batch_size = args.batch_size
@@ -48,9 +45,9 @@ def bench_bloom(args):
    tokenizer.pad_token = tokenizer.eos_token

    # load quantized model to the first GPU
    model = AutoGPTQForCausalLM.from_quantized(
        quantized_model_dir, device=torch.cuda.current_device(), inject_fused_attention=False
    )

    model = model.half()

@@ -60,22 +57,22 @@ def bench_bloom(args):
    generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False)

    input_tokens = {
        "input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device="cuda"),
        "attention_mask": torch.ones((max_batch_size, max_input_len), device="cuda"),
    }

    # init TPInferEngine and shard the original model
    # To benchmark torch original, comment out the line of optimizing model
    shard_config = ShardConfig(
        enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True, inference_gptq=True
    )
    infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)

    # prepare data for generation
    generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False)
    input_tokens = {
        "input_ids": torch.randint(10, 1000, (max_batch_size, max_input_len)),
        "attention_mask": torch.ones((max_batch_size, max_input_len)),
    }
    for t in input_tokens:
        if torch.is_tensor(input_tokens[t]):
@@ -99,7 +96,7 @@ def bench_bloom(args):

def check_bloom(rank, world_size, port, args):
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    bench_bloom(args)


@@ -111,12 +108,12 @@ def test_bloom(args):

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-p", "--path", type=str, help="Model path", required=True)
    parser.add_argument("-q", "--quantized_path", type=str, help="Model path", required=True)
    parser.add_argument("-tp", "--tp_size", type=int, default=1, help="Tensor parallel size")
    parser.add_argument("-b", "--batch_size", type=int, default=16, help="Maximum batch size")
    parser.add_argument("--input_len", type=int, default=1024, help="Maximum input length")
    parser.add_argument("--output_len", type=int, default=128, help="Maximum output length")

    args = parser.parse_args()
@@ -15,3 +15,4 @@ sentencepiece
google
protobuf
peft>=0.7.1
bitsandbytes>=0.39.0
@@ -1,4 +1,4 @@
from typing import Callable, Dict, Iterator, List, Tuple, Union

import torch
import torch.distributed as dist
@@ -51,7 +51,6 @@ def run_fn(stage, model_fn, data_gen_fn, output_transform_fn, lora_config=None)
        # raise e


@parameterize("stage", [2])
def check_low_level_zero_plugin(stage: int, early_stop: bool = True):
    """check low level zero plugin over model zoo
@@ -118,6 +117,7 @@ def check_low_level_zero_lora(stage, model_name, early_stop: bool = True):
    print(f"Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n")
    assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()])


def run_dist(rank, world_size, port, early_stop: bool = True):
    # init dist env
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")
@@ -1,10 +1,11 @@
from copy import deepcopy
from typing import Optional

import torch
import torch.distributed as dist
from peft import LoraConfig
from torchvision.models import resnet18
from utils import shared_tempdir

import colossalai
from colossalai.booster import Booster
@@ -131,12 +132,15 @@ def run_fn(stage, shard, offload, model_fn, data_gen_fn, output_transform_fn, lo
        # return repr(e)
        raise e


@clear_cache_before_run()
@parameterize("stage", [2])
@parameterize("shard", [True, False])
@parameterize("offload", [False, True])
@parameterize("model_name", ["transformers_llama"])
def check_low_level_zero_lora_checkpointIO(
    stage: int, shard: bool, offload: bool, model_name: str, early_stop: bool = True
):
    passed_models = []
    failed_info = {}  # (model_name, error) pair

@@ -166,6 +170,7 @@ def check_low_level_zero_lora_checkpointIO(stage: int, shard: bool, offload: boo
    print(f"Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n")
    assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()])


def run_dist(rank, world_size, port):
    colossalai.launch(config=(dict()), rank=rank, world_size=world_size, port=port, host="localhost")
    check_low_level_zero_checkpointIO()
@@ -1,16 +1,8 @@
-import math
-import time
-
-import numpy as np
import pytest
import torch
-import torch.nn as nn
-import transformers
from packaging import version

try:
-    import triton
-    import triton.language as tl
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
@@ -22,6 +14,7 @@ try:
    from exllama_kernels import prepare_buffers, set_tuning_params

    from colossalai.inference.quant.gptq import CaiQuantLinear

    HAS_AUTO_GPTQ = True
except:
    HAS_AUTO_GPTQ = False
@@ -32,13 +25,14 @@ import warnings
HAS_GPTQ_CUDA = False
try:
    from colossalai.kernel.op_builder.gptq import GPTQBuilder

    gptq_cuda = GPTQBuilder().load()
    HAS_GPTQ_CUDA = True
except ImportError:
    warnings.warn("CUDA gptq is not installed")
    HAS_GPTQ_CUDA = False

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")

max_inner_outer_dim = 1
max_input_len = 1
@@ -64,9 +58,9 @@ def init_buffer(cai_linear, use_act_order=False):
    max_input_len = 4096
    # The temp_state buffer is required to reorder X in the act-order case.
    # The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
    gptq_temp_state_buffer = torch.zeros(
        (max_input_len, max_inner_outer_dim), dtype=torch.float16, device=torch.cuda.current_device()
    )
    gptq_temp_dq_buffer = torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=torch.cuda.current_device())

    gptq_cuda.prepare_buffers(torch.device(torch.cuda.current_device()), gptq_temp_state_buffer, gptq_temp_dq_buffer)
@@ -77,10 +71,11 @@ def init_buffer(cai_linear, use_act_order=False):
    gptq_cuda.set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)


@pytest.mark.skipif(
    not TRITON_CUDA_SUPPORT or not HAS_TRITON or not HAS_AUTO_GPTQ,
    reason="triton requires cuda version to be higher than 11.4 or not install auto-gptq",
)
def test_gptq_linear():
    infeature = 1024
    outfeature = 1024
    group_size = 128
@@ -120,7 +115,7 @@ def test_gptq_linear():
    max_input_len = 2048
    buffers = {
        "temp_state": torch.zeros((max_input_len, max_inner_outer_dim), dtype=torch.float16, device=device),
        "temp_dq": torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=device),
    }

    prepare_buffers(device, buffers["temp_state"], buffers["temp_dq"])
@@ -146,5 +141,4 @@ def test_gptq_linear():


if __name__ == "__main__":
    test_gptq_linear()
@@ -4,6 +4,7 @@ from packaging import version

try:
    from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
106
tests/test_lora/test_lora.py
Normal file
106
tests/test_lora/test_lora.py
Normal file
@ -0,0 +1,106 @@
|
|||||||
|
import copy
|
||||||
|
import os
|
||||||
|
from itertools import product
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from peft import LoraConfig
|
||||||
|
from torch import distributed as dist
|
||||||
|
from torch.optim import AdamW
|
||||||
|
|
||||||
|
import colossalai
|
||||||
|
from colossalai.booster import Booster
|
||||||
|
from colossalai.booster.plugin import LowLevelZeroPlugin, TorchDDPPlugin
|
||||||
|
from colossalai.testing import check_state_dict_equal, clear_cache_before_run, rerun_if_address_is_in_use, spawn
|
||||||
|
from tests.kit.model_zoo import model_zoo
|
||||||
|
from tests.test_checkpoint_io.utils import shared_tempdir
|
||||||
|
|
||||||
|
|
||||||
|
@clear_cache_before_run()
|
||||||
|
def check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type):
    model = model_fn()
    lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1)

    test_plugins = [TorchDDPPlugin(), LowLevelZeroPlugin()]
    test_configs = [
        {
            "lora_config": lora_config,
            "quantize": False,
        },
        {
            "lora_config": lora_config,
            "quantize": True,
        },
    ]
    # each plugin is tested with plain LoRA (quantize=False) and the new QLoRA path (quantize=True)
    for plugin, test_config in product(test_plugins, test_configs):
        # checkpoint loaded model
        model_save = model_fn()
        model_load = copy.deepcopy(model_save)

        optimizer = AdamW(model.parameters(), lr=0.001)
        criterion = loss_fn

        booster = Booster(plugin=plugin)
        model_save = booster.enable_lora(model_save, **test_config)
        model_save, optimizer, criterion, _, _ = booster.boost(model_save, optimizer, criterion)

        with shared_tempdir() as tempdir:
            lora_ckpt_path = os.path.join(tempdir, "ckpt")
            booster.save_lora_as_pretrained(model_save, lora_ckpt_path)
            dist.barrier()

            # The LoRA checkpoint should be small in size
            checkpoint_size_mb = os.path.getsize(os.path.join(lora_ckpt_path, "adapter_model.bin")) / (1024 * 1024)
            assert checkpoint_size_mb < 1

            model_load = booster.enable_lora(model_load, pretrained_dir=lora_ckpt_path, **test_config)
            model_load, _, _, _, _ = booster.boost(model_load)

            check_state_dict_equal(model_save.state_dict(), model_load.state_dict())

        # test fwd bwd correctness
        test_model = model_load
        model_copy = copy.deepcopy(model_load)

        data = data_gen_fn()
        data = {
            k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()
        }

        output = test_model(**data)
        output = output_transform_fn(output)
        loss = criterion(output)

        booster.backward(loss, optimizer)
        optimizer.clip_grad_by_norm(1.0)
        optimizer.step()

        for (n1, p1), (n2, p2) in zip(test_model.named_parameters(), model_copy.named_parameters()):
            if "lora_" in n1:
                # LoRA modules require gradients and should therefore be updated by the step above
                assert p1.requires_grad
                assert not torch.testing.assert_close(p1.to(p2.device).to(p2.dtype), p2, atol=5e-3, rtol=5e-3)
            else:
                # frozen (non-LoRA) parameters should remain unchanged
                if not p1.requires_grad:
                    torch.testing.assert_close(p1.to(p2.device).to(p2.dtype), p2, atol=5e-3, rtol=5e-3)


def run_lora_test():
    sub_model_zoo = model_zoo.get_sub_registry("transformers_llama")
    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        task_type = None
        if name == "transformers_llama_for_casual_lm":
            task_type = "CAUSAL_LM"
        if name == "transformers_llama_for_sequence_classification":
            task_type = "SEQ_CLS"
        check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type)


def run_dist(rank, world_size, port):
    config = {}
    colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_lora_test()


@rerun_if_address_is_in_use()
def test_torch_ddp_lora():
    spawn(run_dist, 2)
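A quick illustration to go with the new test above: a minimal sketch of the user-facing LoRA/QLoRA flow it exercises, built only from the calls visible in this diff (booster.enable_lora with lora_config/quantize/pretrained_dir, booster.boost, and booster.save_lora_as_pretrained). The model checkpoint id, learning rate, and launch helper below are illustrative assumptions, not part of this commit.

# Hedged usage sketch (assumptions noted above); run under torchrun so launch_from_torch can read the env.
from peft import LoraConfig
from torch.optim import AdamW
from transformers import AutoModelForCausalLM

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

colossalai.launch_from_torch(config={})

# placeholder checkpoint id, replace with a real LLaMA-style model
model = AutoModelForCausalLM.from_pretrained("my-org/my-llama-checkpoint")
lora_config = LoraConfig(task_type="CAUSAL_LM", r=8, lora_alpha=32, lora_dropout=0.1)

booster = Booster(plugin=TorchDDPPlugin())
# quantize=False keeps plain LoRA; quantize=True takes the QLoRA path covered by this commit
model = booster.enable_lora(model, lora_config=lora_config, quantize=True)

optimizer = AdamW(model.parameters(), lr=1e-4)
model, optimizer, _, _, _ = booster.boost(model, optimizer)

# training loop sketch: loss = criterion(model(**batch)); booster.backward(loss, optimizer); optimizer.step()

# only the small adapter weights are written, which is what the checkpoint-size assert in the test checks
booster.save_lora_as_pretrained(model, "ckpt/lora_adapter")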
@ -1,108 +0,0 @@
import copy
import os

import torch
from peft import LoraConfig
from torch import distributed as dist
from torch.optim import AdamW

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin
from colossalai.testing import (
    assert_equal,
    assert_not_equal,
    check_state_dict_equal,
    clear_cache_before_run,
    rerun_if_address_is_in_use,
    spawn,
)
from tests.kit.model_zoo import model_zoo
from tests.test_checkpoint_io.utils import shared_tempdir


@clear_cache_before_run()
def check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type):
    model = model_fn()
    lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1)

    plugin = TorchDDPPlugin()
    booster = Booster(plugin=plugin)

    model = booster.enable_lora(model, lora_config=lora_config)
    model_copy = copy.deepcopy(model)

    optimizer = AdamW(model.parameters(), lr=0.001)
    criterion = loss_fn

    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

    data = data_gen_fn()
    data = {k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()}

    output = model(**data)
    output = output_transform_fn(output)
    loss = criterion(output)

    booster.backward(loss, optimizer)
    optimizer.clip_grad_by_norm(1.0)
    optimizer.step()

    for (n1, p1), (n2, p2) in zip(model.named_parameters(), model_copy.named_parameters()):
        if "lora_" in n1:
            # lora modules require gradients, thus updated
            assert p1.requires_grad
            assert_not_equal(p1.to(p2.device), p2)
        else:
            if not p1.requires_grad:
                assert_equal(p1.to(p2.device), p2)


@clear_cache_before_run()
def check_checkpoint(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type):
    plugin = TorchDDPPlugin()
    lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1)

    model_save = model_fn()
    model_load = copy.deepcopy(model_save)

    booster = Booster(plugin=plugin)
    model_save = booster.enable_lora(model_save, lora_config=lora_config)
    model_save, _, _, _, _ = booster.boost(model_save)

    with shared_tempdir() as tempdir:
        lora_ckpt_path = os.path.join(tempdir, "ckpt")
        booster.save_lora_as_pretrained(model_save, lora_ckpt_path)
        dist.barrier()

        # The Lora checkpoint should be small in size
        checkpoint_size_mb = os.path.getsize(os.path.join(lora_ckpt_path, "adapter_model.bin")) / (1024 * 1024)
        assert checkpoint_size_mb < 1

        model_load = booster.enable_lora(model_load, pretrained_dir=lora_ckpt_path)
        model_load, _, _, _, _ = booster.boost(model_load)

        check_state_dict_equal(model_save.state_dict(), model_load.state_dict())


def run_lora_test():
    sub_model_zoo = model_zoo.get_sub_registry("transformers_llama")
    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        task_type = None
        if name == "transformers_llama_for_casual_lm":
            task_type = "CAUSAL_LM"
        if name == "transformers_llama_for_sequence_classification":
            task_type = "SEQ_CLS"
        check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type)
        check_checkpoint(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type)


def run_dist(rank, world_size, port):
    config = {}
    colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_lora_test()


@rerun_if_address_is_in_use()
def test_torch_ddp_lora():
    spawn(run_dist, 2)