diff --git a/applications/Colossal-LLaMA/colossal_llama/dataset/conversation.py b/applications/Colossal-LLaMA/colossal_llama/dataset/conversation.py index 8ec9c848b..3ffbb87db 100644 --- a/applications/Colossal-LLaMA/colossal_llama/dataset/conversation.py +++ b/applications/Colossal-LLaMA/colossal_llama/dataset/conversation.py @@ -100,7 +100,7 @@ LLaMA3_Conv = Conversation( messages=[], offset=0, sep_style=SeparatorStyle.ADD_BOS_EOS_TOKEN, - seps=["<|begin_of_text|>", "<|end_of_text|>"], + seps=["<|begin_of_text|>", "<|eot_id|>"], ) default_conversation = LLaMA3_Conv diff --git a/applications/Colossal-LLaMA/colossal_llama/dataset/spliced_and_tokenized_dataset.py b/applications/Colossal-LLaMA/colossal_llama/dataset/spliced_and_tokenized_dataset.py index 30122d283..15cb29874 100644 --- a/applications/Colossal-LLaMA/colossal_llama/dataset/spliced_and_tokenized_dataset.py +++ b/applications/Colossal-LLaMA/colossal_llama/dataset/spliced_and_tokenized_dataset.py @@ -88,7 +88,7 @@ def supervised_tokenize_sft( assert ( tokenizer.bos_token == conversation_template.seps[0] and tokenizer.eos_token == conversation_template.seps[1] - ), "`bos_token` and `eos_token` should be the same with `conversation_template.seps`." + ), f"`bos_token`{tokenizer.bos_token} and `eos_token`{tokenizer.eos_token} should be the same with `conversation_template.seps`{conversation_template.seps}." if ignore_index is None: ignore_index = IGNORE_INDEX diff --git a/applications/Colossal-LLaMA/colossal_llama/utils/ckpt_io.py b/applications/Colossal-LLaMA/colossal_llama/utils/ckpt_io.py index 05342ce41..2d712f416 100644 --- a/applications/Colossal-LLaMA/colossal_llama/utils/ckpt_io.py +++ b/applications/Colossal-LLaMA/colossal_llama/utils/ckpt_io.py @@ -43,6 +43,7 @@ def save_checkpoint( step: int, batch_size: int, coordinator: DistCoordinator, + use_lora: bool = False, ) -> None: """ Save model checkpoint, optimizer, LR scheduler and intermedidate running states. 
@@ -51,7 +52,10 @@ def save_checkpoint( save_dir = os.path.join(save_dir, f"epoch-{epoch}_step-{step}") os.makedirs(os.path.join(save_dir, "modeling"), exist_ok=True) - booster.save_model(model, os.path.join(save_dir, "modeling"), shard=True) + if use_lora: + booster.save_lora_as_pretrained(model, os.path.join(save_dir, "modeling")) + else: + booster.save_model(model, os.path.join(save_dir, "modeling"), shard=True) booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True) booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler")) diff --git a/applications/Colossal-LLaMA/train.py b/applications/Colossal-LLaMA/train.py index db23275e4..6650469f3 100644 --- a/applications/Colossal-LLaMA/train.py +++ b/applications/Colossal-LLaMA/train.py @@ -21,6 +21,7 @@ from colossal_llama.utils.ckpt_io import load_checkpoint, save_checkpoint from colossal_llama.utils.froze import freeze_non_embeds_parameters from colossal_llama.utils.neftune_patch import activate_neftune, deactivate_neftune from colossal_llama.utils.utils import all_reduce_mean, format_numel_str, get_model_numel +from peft import LoraConfig from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm from transformers import AutoModelForCausalLM, AutoTokenizer @@ -65,7 +66,7 @@ def train(args) -> None: initial_scale=2**16, max_norm=args.grad_clip, enable_gradient_accumulation=(args.accumulation_steps > 1), - enable_fused_normalization=torch.cuda.is_available(), + enable_fused_normalization=get_accelerator().is_available(), enable_flash_attention=args.use_flash_attn, ) elif args.plugin == "gemini_auto": @@ -75,7 +76,7 @@ def train(args) -> None: initial_scale=2**16, max_norm=args.grad_clip, enable_gradient_accumulation=(args.accumulation_steps > 1), - enable_fused_normalization=torch.cuda.is_available(), + enable_fused_normalization=get_accelerator().is_available(), enable_flash_attention=args.use_flash_attn, ) elif args.plugin == "zero2": @@ -101,10 +102,9 @@ def train(args) -> None: sequence_parallelism_mode=args.sp_mode, zero_stage=args.zero_stage, enable_flash_attention=args.use_flash_attn, - enable_fused_normalization=torch.cuda.is_available(), + enable_fused_normalization=get_accelerator().is_available(), enable_sequence_parallelism=args.enable_sequence_parallelism, cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False, - parallel_output=False, max_norm=args.grad_clip, precision=args.mixed_precision, microbatch_size=args.microbatch_size, @@ -117,11 +117,17 @@ def train(args) -> None: # ====================================================== # Initialize Tokenizer, Dataset, Collator and Dataloader # ====================================================== - tokenizer = AutoTokenizer.from_pretrained(args.pretrained) + tokenizer = AutoTokenizer.from_pretrained(args.pretrained, trust_remote_code=True) if args.pad_token == "eos": - tokenizer.pad_token = tokenizer.eos_token + try: + tokenizer.pad_token = tokenizer.eos_token + except AttributeError: + coordinator.print_on_master(f"pad_token can't be set") elif args.pad_token == "unk": - tokenizer.pad_token = tokenizer.unk_token + try: + tokenizer.pad_token = tokenizer.unk_token + except AttributeError: + coordinator.print_on_master(f"pad_token can't be set") tokenizer.add_bos_token = False tokenizer.add_eos_token = False @@ -164,33 +170,31 @@ def train(args) -> None: # ====================================================== # Initialize Model, Objective, Optimizer and LR Scheduler # 
====================================================== + # When training the ChatGLM model, LoRA and gradient checkpointing are incompatible. init_ctx = ( LazyInitContext(default_device=get_current_device()) - if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin)) + if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin)) and args.lora_rank == 0 else nullcontext() ) with init_ctx: - if args.use_flash_attn: - model = AutoModelForCausalLM.from_pretrained( - args.pretrained, - attn_implementation="flash_attention_2", - torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16, - trust_remote_code=True, - ) - else: - model = AutoModelForCausalLM.from_pretrained( - args.pretrained, - torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16, - trust_remote_code=True, - ) + model = AutoModelForCausalLM.from_pretrained( + args.pretrained, + torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16, + trust_remote_code=True, + ) # Freeze part of parameters. if args.freeze_non_embeds_params: freeze_non_embeds_parameters(model=model) + + if args.lora_rank > 0: + lora_config = LoraConfig(task_type="CAUSAL_LM", r=args.lora_rank, lora_alpha=32, lora_dropout=0.1) + model = booster.enable_lora(model, lora_config=lora_config) + # this is essential, otherwise the grad checkpoint will not work. model.train() if args.use_grad_checkpoint: - model.gradient_checkpointing_enable() + model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) coordinator.print_on_master(msg="Gradient checkpointing enabled successfully") model_numel = get_model_numel(model) @@ -327,6 +331,7 @@ def train(args) -> None: step=step + 1, batch_size=args.batch_size, coordinator=coordinator, + use_lora=(args.lora_rank > 0), ) coordinator.print_on_master( f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}" @@ -371,44 +376,45 @@ def train(args) -> None: total_loss.fill_(0.0) pbar.update() - # Save modeling. - save_model_condition = ( - args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0 - ) - - if not args.skip_save_each_epoch: - save_model_condition = save_model_condition or (step + 1) == len(dataloader) - - if save_model_condition and not args.benchmark: - coordinator.print_on_master("\nStart saving model checkpoint with running states") - - if args.use_neft: - coordinator.print_on_master("Deactivate NEFTune before saving model.") - deactivate_neftune(model, handle) - - accelerator.empty_cache() - save_checkpoint( - save_dir=args.save_dir, - booster=booster, - model=model, - optimizer=optimizer, - lr_scheduler=lr_scheduler, - epoch=epoch, - step=step + 1, - batch_size=args.batch_size, - coordinator=coordinator, - ) - coordinator.print_on_master( - f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}" + # Save modeling. + save_model_condition = ( + args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0 ) - if args.use_neft: - coordinator.print_on_master("Activate NEFTune.") - model, handle = activate_neftune(model) + if not args.skip_save_each_epoch: + save_model_condition = save_model_condition or (step + 1) == len(dataloader) - # Delete cache. 
- # del batch, batch_labels, batch_output, loss - accelerator.empty_cache() + if save_model_condition and not args.benchmark: + coordinator.print_on_master("\nStart saving model checkpoint with running states") + + if args.use_neft: + coordinator.print_on_master("Deactivate NEFTune before saving model.") + deactivate_neftune(model, handle) + + accelerator.empty_cache() + save_checkpoint( + save_dir=args.save_dir, + booster=booster, + model=model, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + epoch=epoch, + step=step + 1, + batch_size=args.batch_size, + coordinator=coordinator, + use_lora=(args.lora_rank > 0), + ) + coordinator.print_on_master( + f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}" + ) + + if args.use_neft: + coordinator.print_on_master("Activate NEFTune.") + model, handle = activate_neftune(model) + + # Delete cache. + # del batch, batch_labels, batch_output, loss + accelerator.empty_cache() # the continue epochs are not resumed, so we need to reset the sampler start index and start step dataloader.sampler.set_start_index(start_index=0) @@ -522,6 +528,7 @@ if __name__ == "__main__": parser.add_argument( "--microbatch_size", type=int, default=1, help="Batch size for each process in PP, used for 3d plugin." ) + parser.add_argument("--lora_rank", type=int, default=0, help="lora rank when using lora to train.") # Additional arguments for benchmark. parser.add_argument("--num_samples", type=int, default=500, help="Number of samples for benchmarking.") diff --git a/colossalai/lazy/lazy_init.py b/colossalai/lazy/lazy_init.py index b130111ba..4072bb197 100644 --- a/colossalai/lazy/lazy_init.py +++ b/colossalai/lazy/lazy_init.py @@ -509,9 +509,9 @@ class LazyInitContext: # factory_like functions (eg. torch.empty_like()) def wrapper(*args, **kwargs): orig_t = args[0] - return self.tensor_cls( - orig_target, *orig_t.shape, *args[1:], device=orig_t.device, dtype=orig_t.dtype, **kwargs - ) + device = kwargs.pop("device", orig_t.device) + dtype = kwargs.pop("dtype", orig_t.dtype) + return self.tensor_cls(orig_target, *orig_t.shape, *args[1:], device=device, dtype=dtype, **kwargs) return wrapper, target diff --git a/colossalai/legacy/communication/p2p.py b/colossalai/legacy/communication/p2p.py index cf0bd4ba2..089fcf23b 100644 --- a/colossalai/legacy/communication/p2p.py +++ b/colossalai/legacy/communication/p2p.py @@ -171,7 +171,7 @@ def _communicate( for req in reqs: req.wait() # To protect against race condition when using batch_isend_irecv(). 
- torch.cuda.synchronize() + get_accelerator().synchronize() if recv_prev and recv_prev_split: if isinstance(tensor_recv_prev, torch.Tensor): diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py index 8c319aceb..8ae8a516f 100644 --- a/colossalai/pipeline/p2p.py +++ b/colossalai/pipeline/p2p.py @@ -14,6 +14,8 @@ from torch.distributed import ProcessGroup from torch.distributed import distributed_c10d as c10d from torch.utils._pytree import tree_flatten, tree_unflatten +from colossalai.accelerator import get_accelerator + from .stage_manager import PipelineStageManager @@ -31,7 +33,7 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) - buf = tensor.numpy().tobytes()[:tensor_size] if b"cuda" in buf: buf_array = bytearray(buf) - device_index = torch.cuda.current_device() + device_index = get_accelerator().current_device() # There might be more than one output tensors during forward for cuda_str in re.finditer(b"cuda", buf_array): pos = cuda_str.start() @@ -86,7 +88,7 @@ def _broadcast_object_list( else: current_device = torch.device("cpu") if is_nccl_backend: - current_device = torch.device("cuda", torch.cuda.current_device()) + current_device = torch.device("cuda", get_accelerator().current_device()) my_rank = dist.get_rank() # Serialize object_list elements to tensors on src rank. @@ -139,14 +141,14 @@ def _broadcast_object_list( # unconsistence in device if ( isinstance(unpickle_object, torch.Tensor) - and unpickle_object.device.index != torch.cuda.current_device() + and unpickle_object.device.index != get_accelerator().current_device() ): - unpickle_object = unpickle_object.cuda() + unpickle_object = unpickle_object.to(get_accelerator().current_device()) object_list[i] = unpickle_object -def _check_for_nccl_backend(group): +def _check_for_nccl_hccl_backend(group): pg = group or c10d._get_default_group() # Gate PG wrapper check on Gloo availability. 
if c10d._GLOO_AVAILABLE: @@ -154,14 +156,14 @@ def _check_for_nccl_backend(group): while isinstance(pg, c10d._ProcessGroupWrapper): pg = pg.wrapped_pg - return c10d.is_nccl_available() and pg.name() == c10d.Backend.NCCL + return (c10d.is_nccl_available() or torch.distributed.is_hccl_available()) and pg.name() == c10d.Backend.NCCL def _check_device(group): - is_nccl_backend = _check_for_nccl_backend(group) + is_nccl_backend = _check_for_nccl_hccl_backend(group) current_device = torch.device("cpu") if is_nccl_backend: - current_device = torch.device("cuda", torch.cuda.current_device()) + current_device = torch.device(get_accelerator().current_device()) return current_device, is_nccl_backend @@ -348,8 +350,11 @@ def _send_recv_serialization_object( unpickle_object = _cuda_safe_tensor_to_object(recv_object_tensor, recv_object_size_tensor.item()) - if isinstance(unpickle_object, torch.Tensor) and unpickle_object.device.index != torch.cuda.current_device(): - unpickle_object = unpickle_object.cuda() + if ( + isinstance(unpickle_object, torch.Tensor) + and unpickle_object.device.index != get_accelerator().current_device() + ): + unpickle_object = unpickle_object.to(get_accelerator().current_device()) return unpickle_object @@ -474,9 +479,11 @@ def _p2p_comm( recv_prev_shape = None if tensor_send_next is not None: - send_next_shape = torch.tensor(tensor_send_next.size(), device=torch.cuda.current_device(), dtype=torch.int64) + send_next_shape = torch.tensor( + tensor_send_next.size(), device=get_accelerator().current_device(), dtype=torch.int64 + ) if recv_prev: - recv_prev_shape = torch.empty((3), device=torch.cuda.current_device(), dtype=torch.int64) + recv_prev_shape = torch.empty((3), device=get_accelerator().current_device(), dtype=torch.int64) ops = [] if send_next_shape is not None: @@ -501,7 +508,7 @@ def _p2p_comm( # send and recv data tensor_recv_prev = None if recv_prev: - tensor_recv_prev = torch.empty(recv_prev_shape, device=torch.cuda.current_device(), dtype=comm_dtype) + tensor_recv_prev = torch.empty(recv_prev_shape, device=get_accelerator().current_device(), dtype=comm_dtype) ops = [] if tensor_send_next is not None: diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py index 5da98364d..eeab080af 100644 --- a/colossalai/pipeline/schedule/interleaved_pp.py +++ b/colossalai/pipeline/schedule/interleaved_pp.py @@ -2,7 +2,6 @@ from functools import partial from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union import torch -import torch.cuda import torch.distributed from torch.nn import Module, ModuleList from torch.utils._pytree import tree_map @@ -18,7 +17,7 @@ from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_ from .base import PipelineSchedule -def _wait_p2p(wait_handles: List[torch.cuda.Event]) -> None: +def _wait_p2p(wait_handles) -> None: if wait_handles is not None: for req in wait_handles: req.wait() diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py index 224d63688..dcffa858c 100644 --- a/colossalai/pipeline/schedule/one_f_one_b.py +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -2,7 +2,6 @@ from functools import partial from typing import Any, Callable, Dict, Iterable, List, Optional, Union import torch -import torch.cuda from torch.nn import Module from torch.utils._pytree import tree_map diff --git a/colossalai/shardformer/layer/normalization.py b/colossalai/shardformer/layer/normalization.py index 
043bf6aeb..1c2b44fc8 100644 --- a/colossalai/shardformer/layer/normalization.py +++ b/colossalai/shardformer/layer/normalization.py @@ -1,15 +1,28 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- +import numbers import warnings from abc import ABC, abstractmethod +import torch import torch.nn as nn +from torch.nn import init +from torch.nn.parameter import Parameter from colossalai.lazy import LazyInitContext from ._operation import hook_parameter_in_backward from .utils import SeqParallelUtils +SUPPORT_NPU = False +try: + import torch_npu + + SUPPORT_NPU = True +except Exception: + pass + + __all__ = ["FusedLayerNorm", "FusedRMSNorm", "LayerNorm", "RMSNorm", "BaseLayerNorm"] try: @@ -21,7 +34,6 @@ except ImportError: try: from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm - from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm class FusedLayerNormWithHook(ApexFusedLayerNorm): def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True): @@ -32,7 +44,41 @@ try: output = hook_parameter_in_backward(output, self.weight, self.bias) return output - class FusedRMSNormWithHook(ApexFusedRMSNorm): +except ImportError: + warnings.warn("Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMSNorm kernel") + +FusedRMSNormWithHook = None +if SUPPORT_NPU: + + class NPUFusedRMSNormWithHook(nn.Module): + def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True): + super().__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + self.normalized_shape = torch.Size(normalized_shape) + self.eps = eps + self.elementwise_affine = elementwise_affine + if self.elementwise_affine: + self.weight = Parameter(torch.empty(*normalized_shape)) + else: + self.register_parameter("weight", None) + self.reset_parameters() + + def reset_parameters(self): + if self.elementwise_affine: + init.ones_(self.weight) + + def forward(self, input): + + output, _ = torch_npu.npu_rms_norm(input, self.weight, self.eps) + output = hook_parameter_in_backward(output, self.weight) + return output + + FusedRMSNormWithHook = NPUFusedRMSNormWithHook +else: + from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm + + class CUDAFusedRMSNormWithHook(ApexFusedRMSNorm): def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True): super().__init__(normalized_shape, eps, elementwise_affine) @@ -41,8 +87,7 @@ try: output = hook_parameter_in_backward(output, self.weight) return output -except ImportError: - warnings.warn("Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMSNorm kernel") + FusedRMSNormWithHook = CUDAFusedRMSNormWithHook FAST_LAYERNORM_SUPPORTED_SIZE = [ 1024, diff --git a/colossalai/shardformer/modeling/chatglm2.py b/colossalai/shardformer/modeling/chatglm2.py index a9be5c74d..be13200b5 100644 --- a/colossalai/shardformer/modeling/chatglm2.py +++ b/colossalai/shardformer/modeling/chatglm2.py @@ -9,7 +9,7 @@ from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig -from colossalai.shardformer.layer import AttnMaskType, ColoAttention +from colossalai.shardformer.layer import ColoAttention from colossalai.shardformer.layer._operation import ( all_to_all_comm, gather_sp_output, @@ -25,42 +25,7 @@ def get_flash_core_attention_forward(): def forward(self: CoreAttention, query_layer, key_layer, value_layer, attention_mask): query_layer, key_layer, 
value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]] - if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]: - attention_mask_type = AttnMaskType.CAUSAL - attn_bias = torch.zeros( - query_layer.shape[0], - 1, - query_layer.shape[2], - key_layer.shape[2], - dtype=query_layer.dtype, - device=query_layer.device, - ) - temp_mask = ( - torch.ones( - query_layer.shape[2], - key_layer.shape[2], - dtype=torch.bool, - device=query_layer.device, - ) - .tril(diagonal=0) - .expand(query_layer.shape[0], 1, -1, -1) - ) - attn_bias.masked_fill_(temp_mask.logical_not(), torch.finfo(query_layer.dtype).min) - else: - attention_mask_type = AttnMaskType.CUSTOM - if attention_mask is not None: - attn_bias = torch.zeros_like(attention_mask, dtype=query_layer.dtype) - attn_bias.masked_fill_(attention_mask, torch.finfo(query_layer.dtype).min) - dropout_p = self.attention_dropout.p if self.training else 0.0 - context_layer = ColoAttention.attention( - query_layer, - key_layer, - value_layer, - attention_mask=attn_bias, - attention_mask_type=attention_mask_type, - dropout_p=dropout_p, - scale=1.0 / self.norm_factor, - ) + context_layer = ColoAttention.attention(query_layer, key_layer, value_layer, **attention_mask) context_layer = context_layer.permute(2, 0, 1, 3) new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) context_layer = context_layer.reshape(*new_context_layer_shape) @@ -180,9 +145,20 @@ class ChatGLMPipelineForwards: ], dim=-1, ) - if full_attention_mask is None: - if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): - full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask) + + if shard_config.enable_flash_attention: + mask_shape = (batch_size, 1, seq_length, seq_length) + full_attention_mask: dict = ColoAttention.prepare_attn_kwargs( + mask_shape, + hidden_states.dtype, + hidden_states.device, + q_padding_mask=attention_mask, + is_causal=True, + ) + else: + if full_attention_mask is None: + if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): + full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask) # Support SP + PP sp_size = shard_config.sequence_parallel_size @@ -237,7 +213,7 @@ class ChatGLMPipelineForwards: layer_ret = torch.utils.checkpoint.checkpoint( layer, hidden_states, - attention_mask, + full_attention_mask, rotary_pos_emb, past_key_values[idx], use_cache, @@ -402,10 +378,19 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode, ], dim=-1, ) - - if full_attention_mask is None: - if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): - full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask) + if shard_config.enable_flash_attention: + mask_shape = (batch_size, 1, seq_length, seq_length) + full_attention_mask: dict = ColoAttention.prepare_attn_kwargs( + mask_shape, + hidden_states.dtype, + hidden_states.device, + q_padding_mask=attention_mask, + is_causal=True, + ) + else: + if full_attention_mask is None: + if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): + full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask) # Rotary positional embeddings rotary_pos_emb = self.rotary_pos_emb(self.seq_length) @@ -652,3 +637,79 @@ def 
get_chatglm_sequence_parallel_attention_forward(shard_config: ShardConfig, s return output, kv_cache return forward + + +def get_flash_attention_forward_for_chat_glm_model(): + from .chatglm2_6b.modeling_chatglm import ChatGLMModel + + def forward( + self: ChatGLMModel, + input_ids, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, + full_attention_mask: Optional[torch.BoolTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size, seq_length = input_ids.shape + + if inputs_embeds is None: + inputs_embeds = self.embedding(input_ids) + + if self.pre_seq_len is not None: + if past_key_values is None: + past_key_values = self.get_prompt( + batch_size=batch_size, device=input_ids.device, dtype=inputs_embeds.dtype + ) + if attention_mask is not None: + attention_mask = torch.cat( + [attention_mask.new_ones((batch_size, self.pre_seq_len)), attention_mask], dim=-1 + ) + + mask_shape = (batch_size, 1, seq_length, seq_length) + full_attention_mask: dict = ColoAttention.prepare_attn_kwargs( + mask_shape, + inputs_embeds.dtype, + inputs_embeds.device, + q_padding_mask=attention_mask, + is_causal=True, + ) + + # Rotary positional embeddings + rotary_pos_emb = self.rotary_pos_emb(self.seq_length) + if position_ids is not None: + rotary_pos_emb = rotary_pos_emb[position_ids] + else: + rotary_pos_emb = rotary_pos_emb[None, :seq_length] + rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous() + + # Run encoder. 
+ hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( + inputs_embeds, + full_attention_mask, + rotary_pos_emb=rotary_pos_emb, + kv_caches=past_key_values, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + ) + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + return forward diff --git a/colossalai/shardformer/policies/chatglm2.py b/colossalai/shardformer/policies/chatglm2.py index c003570a0..4ddcf8bfc 100644 --- a/colossalai/shardformer/policies/chatglm2.py +++ b/colossalai/shardformer/policies/chatglm2.py @@ -11,6 +11,7 @@ from colossalai.shardformer.modeling.chatglm2 import ChatGLMPipelineForwards from ..modeling.chatglm2 import ( get_chatglm_sequence_parallel_attention_forward, get_chatglm_sequence_parallel_forward_fn, + get_flash_attention_forward_for_chat_glm_model, get_flash_core_attention_forward, get_jit_fused_glm_block_forward, ) @@ -203,6 +204,13 @@ class ChatGLMPolicy(Policy): policy=policy, target_key="CoreAttention", ) + self.append_or_create_method_replacement( + description={ + "forward": get_flash_attention_forward_for_chat_glm_model(), + }, + policy=policy, + target_key="ChatGLMModel", + ) # use sequence parallel if self.shard_config.enable_sequence_parallelism: diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index fe5fb82ca..a033e917b 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -157,7 +157,7 @@ class GeminiDDP(ModelWrapper): self.enable_async_reduce = enable_async_reduce if enable_async_reduce: - self.async_reduce_stream = torch.cuda.Stream() + self.async_reduce_stream = get_accelerator().Stream() else: self.async_reduce_stream = None @@ -363,7 +363,7 @@ class GeminiDDP(ModelWrapper): master_weights: bool, enable_gradient_accumulation: bool, p: nn.Parameter, - async_reduce_stream: Optional[torch.cuda.Stream] = None, + async_reduce_stream=None, ): async_reduce_scatter = async_reduce_stream is not None setattr(p, "_gemini_reduced", True) @@ -402,9 +402,9 @@ class GeminiDDP(ModelWrapper): grad_chunk.add_tensor_to_chunk_slice(p, grad) if async_reduce_stream is not None: - async_reduce_stream.wait_stream(torch.cuda.current_stream()) + async_reduce_stream.wait_stream(get_accelerator().current_stream()) - with torch.cuda.stream(async_reduce_stream): + with get_accelerator().stream(async_reduce_stream): reduced = chunk_manager.reduce_chunk(grad_chunk, async_op=async_reduce_scatter) if reduced: grad_chunk.wait_async_reduce() diff --git a/colossalai/zero/gemini/gemini_hook.py b/colossalai/zero/gemini/gemini_hook.py index bf5faa0fe..786b30c24 100644 --- a/colossalai/zero/gemini/gemini_hook.py +++ b/colossalai/zero/gemini/gemini_hook.py @@ -62,7 +62,7 @@ class GeminiZeROHook(ColoParamOpHook): # # Other than that, self._gemini_manager.wait_chunks will have synced with default stream # by calling dist.Work.wait() and this line makes no diff. 
- self._gemini_manager.chunk_manager._prefetch_stream.wait_stream(torch.cuda.current_stream()) + self._gemini_manager.chunk_manager._prefetch_stream.wait_stream(get_accelerator().current_stream()) with get_accelerator().stream(self._gemini_manager.chunk_manager._prefetch_stream): for chunk in chunks_fetch_async: diff --git a/examples/language/bert/benchmark_utils.py b/examples/language/bert/benchmark_utils.py index 04d55cb2e..b70dc7496 100644 --- a/examples/language/bert/benchmark_utils.py +++ b/examples/language/bert/benchmark_utils.py @@ -9,6 +9,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.utils.data import DataLoader from tqdm import tqdm +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.cluster import DistCoordinator @@ -59,7 +60,9 @@ def warm_up( for i, data in enumerate(dataloader): if i > num_runs: break - inputs, labels = data[0].cuda(), data[1].cuda() + inputs, labels = data[0].to(get_accelerator().get_current_device()), data[1].to( + get_accelerator().get_current_device() + ) outputs = model(inputs, labels=labels) loss = criterion(outputs) booster.backward(loss, optimizer) @@ -85,7 +88,7 @@ def benchmark( warm_up_steps: int = 3, ): results = {} - model_device = torch.cuda.current_device() + model_device = get_accelerator().get_current_device() # Warm up warm_up_fn( @@ -106,8 +109,8 @@ def benchmark( # Measure Allocated Memory and Throughput memory = {} throughput = {} - torch.cuda.reset_peak_memory_stats(device=model_device) - pre_mem = torch.cuda.memory_allocated(device=model_device) + get_accelerator().reset_peak_memory_stats(device=model_device) + pre_mem = get_accelerator().memory_allocated(device=model_device) start_time = time() @@ -116,7 +119,9 @@ def benchmark( dataloader, desc=f"Epoch [{epoch + 1}/{epoch_num}]", disable=not DistCoordinator().is_master() ) as pbar: for data in pbar: - inputs, labels = data[0].cuda(), data[1].cuda() + inputs, labels = data[0].to(get_accelerator().get_current_device()), data[1].to( + get_accelerator().get_current_device() + ) outputs = model(inputs, labels=labels) loss = criterion(outputs) booster.backward(loss, optimizer) @@ -128,8 +133,8 @@ def benchmark( all_sample = epoch_num * len(dataloader) - post_mem = torch.cuda.memory_allocated(device=model_device) - max_mem = torch.cuda.max_memory_allocated(device=model_device) + post_mem = get_accelerator().memory_allocated(device=model_device) + max_mem = get_accelerator().max_memory_allocated(device=model_device) memory[f"batch_size_{batch_size}"] = { "cuda_pre_training_bytes": format_num(pre_mem, bytes=True), diff --git a/examples/language/bert/finetune.py b/examples/language/bert/finetune.py index f048abdd2..96f1bece0 100644 --- a/examples/language/bert/finetune.py +++ b/examples/language/bert/finetune.py @@ -38,7 +38,7 @@ criterion = lambda x: x.loss def move_to_cuda(batch): - return {k: v.cuda() for k, v in batch.items()} + return {k: v.to(get_accelerator().get_current_device()) for k, v in batch.items()} @torch.no_grad() @@ -266,7 +266,8 @@ def main(): cfg = AutoConfig.from_pretrained(model_name, num_labels=data_builder.num_labels) if model_name == "bert-base-uncased": - model = BertForSequenceClassification.from_pretrained(model_name, config=cfg).cuda() + model = BertForSequenceClassification.from_pretrained(model_name, config=cfg) + model = model.to(get_accelerator().get_current_device()) elif model_name == "albert-xxlarge-v2": model = AlbertForSequenceClassification.from_pretrained(model_name, 
config=cfg) else: diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py index 1e49f0aa8..2964f83f4 100644 --- a/examples/language/llama/benchmark.py +++ b/examples/language/llama/benchmark.py @@ -154,7 +154,7 @@ def main(): offload_param_frac=args.offload_param_frac, tp_size=args.tp, extra_dp_size=args.extra_dp, - enable_fused_normalization=torch.cuda.is_available(), + enable_fused_normalization=get_accelerator().is_available(), enable_flash_attention=args.xformers, max_prefetch=args.prefetch_num, enable_async_reduce=not args.disable_async_reduce, @@ -168,7 +168,7 @@ def main(): warmup_non_model_data_ratio=args.warmup_ratio, tp_size=args.tp, extra_dp_size=args.extra_dp, - enable_fused_normalization=torch.cuda.is_available(), + enable_fused_normalization=get_accelerator().is_available(), max_prefetch=args.prefetch_num, enable_async_reduce=not args.disable_async_reduce, enable_flash_attention=args.xformers, @@ -245,7 +245,7 @@ def main(): sp_size=args.sp, sequence_parallelism_mode=args.sp_mode, enable_sequence_parallelism=args.sp > 1, - enable_fused_normalization=torch.cuda.is_available(), + enable_fused_normalization=get_accelerator().is_available(), enable_flash_attention=args.xformers, microbatch_size=args.mbs, precision="bf16", @@ -264,7 +264,7 @@ def main(): num_model_chunks=args.n_chunks, zero_stage=args.zero, cpu_offload=True, - enable_fused_normalization=torch.cuda.is_available(), + enable_fused_normalization=get_accelerator().is_available(), enable_flash_attention=args.xformers, microbatch_size=args.mbs, initial_scale=2**8, @@ -287,8 +287,8 @@ def main(): config = MODEL_CONFIGS[args.config] else: config = AutoConfig.from_pretrained(args.config, trust_remote_code=True) + get_accelerator().manual_seed(42) - torch.cuda.manual_seed(42) dataset = RandomDataset( num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size ) @@ -311,7 +311,6 @@ def main(): config, trust_remote_code=True, **init_kwargs, - attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16, ) if args.grad_checkpoint: @@ -321,9 +320,13 @@ def main(): model_numel = get_model_numel(model) coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}") + if config.model_type == "chatglm": + num_layers = model.config.num_layers + else: + num_layers = model.config.num_hidden_layers performance_evaluator = PerformanceEvaluator( model_numel, - model.config.num_hidden_layers, + num_layers, model.config.hidden_size, model.config.vocab_size, args.grad_checkpoint, @@ -337,7 +340,7 @@ def main(): torch.set_default_dtype(torch.float) coordinator.print_on_master( - f"Booster init max CUDA memory: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB" + f"Booster init max device memory: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB" ) coordinator.print_on_master( f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB" @@ -389,7 +392,7 @@ def main(): performance_evaluator.on_step_end(**batch) prof.step() performance_evaluator.on_fit_end() - coordinator.print_on_master(f"Max CUDA memory usage: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB") + coordinator.print_on_master(f"Max device memory usage: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB") if __name__ == "__main__": diff --git a/extensions/pybind/flash_attention/flash_attention_npu.py b/extensions/pybind/flash_attention/flash_attention_npu.py index 8a30972b6..d53a68676 100644 --- 
a/extensions/pybind/flash_attention/flash_attention_npu.py +++ b/extensions/pybind/flash_attention/flash_attention_npu.py @@ -1,3 +1,5 @@ +import math + from ...base_extension import _Extension @@ -47,6 +49,8 @@ class FlashAttentionNpuExtension(_Extension): q_indices: Optional[torch.Tensor] = None, kv_indices: Optional[torch.Tensor] = None, ): + if scale is None: + scale = 1.0 / math.sqrt(q.size(-1)) num_heads = q.size(1) return torch_npu.npu_fusion_attention( q,
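A few annotations on the diff above follow; they are explanatory sketches, not part of the patch. First, the headline feature: LoRA fine-tuning in `train.py` is opt-in via the new `--lora_rank` flag. A `peft.LoraConfig` is built, the booster wraps the model, and checkpointing is routed through the booster's LoRA saver. A condensed sketch of that flow, assuming a `booster` and a Hugging Face causal-LM `model` already exist as they do in the script:

```python
from peft import LoraConfig


def maybe_enable_lora(model, booster, lora_rank: int):
    """Mirror of the logic added to train.py; hyper-parameters match the script's hard-coded values."""
    if lora_rank > 0:
        lora_config = LoraConfig(task_type="CAUSAL_LM", r=lora_rank, lora_alpha=32, lora_dropout=0.1)
        model = booster.enable_lora(model, lora_config=lora_config)
    return model


def save_modeling(model, booster, save_dir: str, lora_rank: int):
    """Adapter-only save when LoRA is active, full sharded save otherwise (see ckpt_io.save_checkpoint)."""
    if lora_rank > 0:
        booster.save_lora_as_pretrained(model, save_dir)
    else:
        booster.save_model(model, save_dir, shard=True)
```

Note that when LoRA is active the script also skips `LazyInitContext` and switches gradient checkpointing to the non-reentrant implementation, which the PR's comments flag as required for checkpointing to work.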
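The `lazy_init.py` hunk is a bug fix for `*_like` factories: the old wrapper always injected `device=orig_t.device, dtype=orig_t.dtype`, so a caller passing its own `device` or `dtype` raised a duplicate-keyword `TypeError`; popping the caller's values first keeps the override. A self-contained sketch of the repaired wrapper logic, with plain `torch.empty` standing in for the `LazyTensor` construction done by `self.tensor_cls`:

```python
import torch


def make_like_wrapper(orig_target):
    """Mimics LazyInitContext's handling of factory_like functions such as torch.empty_like."""

    def wrapper(*args, **kwargs):
        orig_t = args[0]
        # Prefer caller-supplied device/dtype and fall back to the source tensor's attributes.
        device = kwargs.pop("device", orig_t.device)
        dtype = kwargs.pop("dtype", orig_t.dtype)
        return orig_target(*orig_t.shape, *args[1:], device=device, dtype=dtype, **kwargs)

    return wrapper


empty_like = make_like_wrapper(torch.empty)  # stand-in for self.tensor_cls(orig_target, ...)
x = torch.ones(2, 3, dtype=torch.float32)
y = empty_like(x, dtype=torch.float16)  # with the old wrapper, dtype was passed twice -> TypeError
assert y.dtype == torch.float16 and y.shape == x.shape
```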
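Much of the remaining churn replaces direct `torch.cuda` calls with the `colossalai.accelerator.get_accelerator()` facade so the same code path runs on CUDA and Ascend NPU backends. A short illustration using only accelerator methods that appear in these hunks (runnable on a machine with a supported backend; the one-to-one mapping is inferred from the call sites above):

```python
import torch
from colossalai.accelerator import get_accelerator

accelerator = get_accelerator()

# Device discovery and placement (replaces torch.cuda.current_device() / tensor.cuda()).
device = accelerator.get_current_device()
x = torch.ones(4).to(device)

# Seeding and synchronization (replaces torch.cuda.manual_seed / torch.cuda.synchronize).
accelerator.manual_seed(42)
accelerator.synchronize()

# Streams (replaces torch.cuda.Stream / current_stream / the stream context manager).
side_stream = accelerator.Stream()
side_stream.wait_stream(accelerator.current_stream())
with accelerator.stream(side_stream):
    y = x * 2  # work queued on the side stream

# Memory accounting (replaces the torch.cuda.*memory* helpers used in the benchmarks).
accelerator.reset_peak_memory_stats(device=device)
peak_mb = accelerator.max_memory_allocated(device=device) / 1024**2
accelerator.empty_cache()
```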
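In `normalization.py`, `FusedRMSNormWithHook` now resolves to an NPU implementation backed by `torch_npu.npu_rms_norm` when `torch_npu` is importable, and falls back to the Apex kernel otherwise. For reference, this is the plain-PyTorch RMSNorm computation both fused kernels implement (standard LLaMA-style definition; exact dtype handling inside the fused kernels may differ):

```python
import torch


def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """y = x / sqrt(mean(x^2, dim=-1) + eps) * weight, normalized over the last dimension."""
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    x_normed = x.float() * torch.rsqrt(variance + eps)
    return weight * x_normed.to(x.dtype)


hidden = torch.randn(2, 5, 8, dtype=torch.float16)
weight = torch.ones(8, dtype=torch.float16)
assert rms_norm_reference(hidden, weight).shape == hidden.shape
```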
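The ChatGLM modeling changes drop the hand-built causal bias and instead call `ColoAttention.prepare_attn_kwargs`, whose returned dict the patched `CoreAttention.forward` unpacks straight into `ColoAttention.attention`. A condensed sketch of that flow outside the model, with tensors in the permuted `[batch, heads, seq, head_dim]` layout used in the hunk (illustrative only, not a drop-in helper):

```python
import torch
from colossalai.shardformer.layer import ColoAttention


def chatglm_flash_attention(query_layer, key_layer, value_layer, padding_mask=None):
    """Build flash-attention kwargs once per forward and unpack them into the fused kernel."""
    batch_size, _, seq_length, _ = query_layer.shape
    mask_shape = (batch_size, 1, seq_length, seq_length)
    # is_causal=True reproduces the lower-triangular mask the removed code built by hand;
    # q_padding_mask carries the original per-token padding mask, if any.
    attn_kwargs = ColoAttention.prepare_attn_kwargs(
        mask_shape,
        query_layer.dtype,
        query_layer.device,
        q_padding_mask=padding_mask,
        is_causal=True,
    )
    return ColoAttention.attention(query_layer, key_layer, value_layer, **attn_kwargs)
```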
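Finally, `flash_attention_npu.py` now fills in the conventional scaled-dot-product factor when the caller omits `scale`: 1/sqrt(head_dim), read from the last dimension of `q`. For example:

```python
import math

import torch

q = torch.randn(2, 8, 128, 64)  # illustrative (batch, num_heads, seq_len, head_dim) layout
scale = 1.0 / math.sqrt(q.size(-1))
print(scale)  # 0.125 when head_dim == 64
```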