[Device] Support NPU (#6159)

* support npu

* support pretrain

* support lora

* support chatglm

* Update train.py

* miscellaneous fixes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
flybird11111 authored 2024-12-17 15:42:39 +08:00, committed by GitHub
parent e994c64568
commit aaafb38851
18 changed files with 295 additions and 152 deletions

View File

@@ -100,7 +100,7 @@ LLaMA3_Conv = Conversation(
     messages=[],
     offset=0,
     sep_style=SeparatorStyle.ADD_BOS_EOS_TOKEN,
-    seps=["<|begin_of_text|>", "<|end_of_text|>"],
+    seps=["<|begin_of_text|>", "<|eot_id|>"],
 )
 default_conversation = LLaMA3_Conv

View File

@@ -88,7 +88,7 @@ def supervised_tokenize_sft(
     assert (
         tokenizer.bos_token == conversation_template.seps[0] and tokenizer.eos_token == conversation_template.seps[1]
-    ), "`bos_token` and `eos_token` should be the same with `conversation_template.seps`."
+    ), f"`bos_token`{tokenizer.bos_token} and `eos_token`{tokenizer.eos_token} should be the same with `conversation_template.seps`{conversation_template.seps}."
     if ignore_index is None:
         ignore_index = IGNORE_INDEX

View File

@@ -43,6 +43,7 @@ def save_checkpoint(
     step: int,
     batch_size: int,
     coordinator: DistCoordinator,
+    use_lora: bool = False,
 ) -> None:
     """
     Save model checkpoint, optimizer, LR scheduler and intermedidate running states.
@@ -51,7 +52,10 @@ def save_checkpoint(
     save_dir = os.path.join(save_dir, f"epoch-{epoch}_step-{step}")
     os.makedirs(os.path.join(save_dir, "modeling"), exist_ok=True)

-    booster.save_model(model, os.path.join(save_dir, "modeling"), shard=True)
+    if use_lora:
+        booster.save_lora_as_pretrained(model, os.path.join(save_dir, "modeling"))
+    else:
+        booster.save_model(model, os.path.join(save_dir, "modeling"), shard=True)

     booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True)
     booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
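For orientation, the caller side looks like this: a minimal sketch mirroring the train.py change later in this commit, where `booster`, `args`, and the optimizer/scheduler objects are assumed from that script.

    from peft import LoraConfig

    if args.lora_rank > 0:
        # LoRA adapters are attached through the booster before training...
        lora_config = LoraConfig(task_type="CAUSAL_LM", r=args.lora_rank, lora_alpha=32, lora_dropout=0.1)
        model = booster.enable_lora(model, lora_config=lora_config)

    # ...and use_lora routes saving to booster.save_lora_as_pretrained instead of a full sharded save.
    save_checkpoint(
        save_dir=args.save_dir,
        booster=booster,
        model=model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        epoch=epoch,
        step=step + 1,
        batch_size=args.batch_size,
        coordinator=coordinator,
        use_lora=(args.lora_rank > 0),
    )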

View File

@@ -21,6 +21,7 @@ from colossal_llama.utils.ckpt_io import load_checkpoint, save_checkpoint
 from colossal_llama.utils.froze import freeze_non_embeds_parameters
 from colossal_llama.utils.neftune_patch import activate_neftune, deactivate_neftune
 from colossal_llama.utils.utils import all_reduce_mean, format_numel_str, get_model_numel
+from peft import LoraConfig
 from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm
 from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -65,7 +66,7 @@ def train(args) -> None:
             initial_scale=2**16,
             max_norm=args.grad_clip,
             enable_gradient_accumulation=(args.accumulation_steps > 1),
-            enable_fused_normalization=torch.cuda.is_available(),
+            enable_fused_normalization=get_accelerator().is_available(),
             enable_flash_attention=args.use_flash_attn,
         )
     elif args.plugin == "gemini_auto":
@@ -75,7 +76,7 @@ def train(args) -> None:
             initial_scale=2**16,
             max_norm=args.grad_clip,
             enable_gradient_accumulation=(args.accumulation_steps > 1),
-            enable_fused_normalization=torch.cuda.is_available(),
+            enable_fused_normalization=get_accelerator().is_available(),
             enable_flash_attention=args.use_flash_attn,
         )
     elif args.plugin == "zero2":
@@ -101,10 +102,9 @@ def train(args) -> None:
             sequence_parallelism_mode=args.sp_mode,
             zero_stage=args.zero_stage,
             enable_flash_attention=args.use_flash_attn,
-            enable_fused_normalization=torch.cuda.is_available(),
+            enable_fused_normalization=get_accelerator().is_available(),
             enable_sequence_parallelism=args.enable_sequence_parallelism,
             cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
-            parallel_output=False,
             max_norm=args.grad_clip,
             precision=args.mixed_precision,
             microbatch_size=args.microbatch_size,
@@ -117,11 +117,17 @@ def train(args) -> None:
     # ======================================================
     # Initialize Tokenizer, Dataset, Collator and Dataloader
     # ======================================================
-    tokenizer = AutoTokenizer.from_pretrained(args.pretrained)
+    tokenizer = AutoTokenizer.from_pretrained(args.pretrained, trust_remote_code=True)
     if args.pad_token == "eos":
-        tokenizer.pad_token = tokenizer.eos_token
+        try:
+            tokenizer.pad_token = tokenizer.eos_token
+        except AttributeError:
+            coordinator.print_on_master(f"pad_token can't be set")
     elif args.pad_token == "unk":
-        tokenizer.pad_token = tokenizer.unk_token
+        try:
+            tokenizer.pad_token = tokenizer.unk_token
+        except AttributeError:
+            coordinator.print_on_master(f"pad_token can't be set")
     tokenizer.add_bos_token = False
     tokenizer.add_eos_token = False
@@ -164,33 +170,31 @@ def train(args) -> None:
     # ======================================================
     # Initialize Model, Objective, Optimizer and LR Scheduler
     # ======================================================
+    # When training the ChatGLM model, LoRA and gradient checkpointing are incompatible.
     init_ctx = (
         LazyInitContext(default_device=get_current_device())
-        if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin))
+        if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin)) and args.lora_rank == 0
         else nullcontext()
     )
     with init_ctx:
-        if args.use_flash_attn:
-            model = AutoModelForCausalLM.from_pretrained(
-                args.pretrained,
-                attn_implementation="flash_attention_2",
-                torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
-                trust_remote_code=True,
-            )
-        else:
-            model = AutoModelForCausalLM.from_pretrained(
-                args.pretrained,
-                torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
-                trust_remote_code=True,
-            )
+        model = AutoModelForCausalLM.from_pretrained(
+            args.pretrained,
+            torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
+            trust_remote_code=True,
+        )
         # Freeze part of parameters.
         if args.freeze_non_embeds_params:
             freeze_non_embeds_parameters(model=model)
+
+    if args.lora_rank > 0:
+        lora_config = LoraConfig(task_type="CAUSAL_LM", r=args.lora_rank, lora_alpha=32, lora_dropout=0.1)
+        model = booster.enable_lora(model, lora_config=lora_config)
+
     # this is essential, otherwise the grad checkpoint will not work.
     model.train()

     if args.use_grad_checkpoint:
-        model.gradient_checkpointing_enable()
+        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
         coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")

     model_numel = get_model_numel(model)
@@ -327,6 +331,7 @@ def train(args) -> None:
                         step=step + 1,
                         batch_size=args.batch_size,
                         coordinator=coordinator,
+                        use_lora=(args.lora_rank > 0),
                     )
                     coordinator.print_on_master(
                         f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
@@ -371,44 +376,45 @@ def train(args) -> None:
                     total_loss.fill_(0.0)
                     pbar.update()

                 # Save modeling.
                 save_model_condition = (
                     args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0
                 )

                 if not args.skip_save_each_epoch:
                     save_model_condition = save_model_condition or (step + 1) == len(dataloader)

                 if save_model_condition and not args.benchmark:
                     coordinator.print_on_master("\nStart saving model checkpoint with running states")

                     if args.use_neft:
                         coordinator.print_on_master("Deactivate NEFTune before saving model.")
                         deactivate_neftune(model, handle)

                     accelerator.empty_cache()
                     save_checkpoint(
                         save_dir=args.save_dir,
                         booster=booster,
                         model=model,
                         optimizer=optimizer,
                         lr_scheduler=lr_scheduler,
                         epoch=epoch,
                         step=step + 1,
                         batch_size=args.batch_size,
                         coordinator=coordinator,
+                        use_lora=(args.lora_rank > 0),
                     )
                     coordinator.print_on_master(
                         f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
                     )

                     if args.use_neft:
                         coordinator.print_on_master("Activate NEFTune.")
                         model, handle = activate_neftune(model)

                     # Delete cache.
                     # del batch, batch_labels, batch_output, loss
                     accelerator.empty_cache()

         # the continue epochs are not resumed, so we need to reset the sampler start index and start step
         dataloader.sampler.set_start_index(start_index=0)
@@ -522,6 +528,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--microbatch_size", type=int, default=1, help="Batch size for each process in PP, used for 3d plugin."
     )
+    parser.add_argument("--lora_rank", type=int, default=0, help="lora rank when using lora to train.")

     # Additional arguments for benchmark.
     parser.add_argument("--num_samples", type=int, default=500, help="Number of samples for benchmarking.")

View File

@@ -509,9 +509,9 @@ class LazyInitContext:
         # factory_like functions (eg. torch.empty_like())
         def wrapper(*args, **kwargs):
             orig_t = args[0]
-            return self.tensor_cls(
-                orig_target, *orig_t.shape, *args[1:], device=orig_t.device, dtype=orig_t.dtype, **kwargs
-            )
+            device = kwargs.pop("device", orig_t.device)
+            dtype = kwargs.pop("dtype", orig_t.dtype)
+            return self.tensor_cls(orig_target, *orig_t.shape, *args[1:], device=device, dtype=dtype, **kwargs)

         return wrapper, target
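Why the pop: factory-like calls such as torch.empty_like(t, device=...) forward device/dtype through **kwargs, and the old wrapper also passed them explicitly, so the same keyword arrived twice. A toy illustration of the failure mode the new code avoids; `make_like_old`/`make_like_new` are stand-ins for the LazyInitContext wrapper above, not real ColossalAI APIs.

    import torch

    def make_like_old(orig_t, **kwargs):
        # duplicates 'dtype'/'device' when the caller also passes one:
        return torch.empty(*orig_t.shape, device=orig_t.device, dtype=orig_t.dtype, **kwargs)

    def make_like_new(orig_t, **kwargs):
        device = kwargs.pop("device", orig_t.device)   # honor an explicit request, e.g. an NPU device
        dtype = kwargs.pop("dtype", orig_t.dtype)
        return torch.empty(*orig_t.shape, device=device, dtype=dtype, **kwargs)

    t = torch.zeros(2, 3)
    make_like_new(t, dtype=torch.float16)     # fine
    # make_like_old(t, dtype=torch.float16)   # TypeError: got multiple values for keyword argument 'dtype'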

View File

@@ -171,7 +171,7 @@ def _communicate(
         for req in reqs:
             req.wait()
         # To protect against race condition when using batch_isend_irecv().
-        torch.cuda.synchronize()
+        get_accelerator().synchronize()

     if recv_prev and recv_prev_split:
         if isinstance(tensor_recv_prev, torch.Tensor):

View File

@@ -14,6 +14,8 @@ from torch.distributed import ProcessGroup
 from torch.distributed import distributed_c10d as c10d
 from torch.utils._pytree import tree_flatten, tree_unflatten

+from colossalai.accelerator import get_accelerator
+
 from .stage_manager import PipelineStageManager

@@ -31,7 +33,7 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -
     buf = tensor.numpy().tobytes()[:tensor_size]
     if b"cuda" in buf:
         buf_array = bytearray(buf)
-        device_index = torch.cuda.current_device()
+        device_index = get_accelerator().current_device()
         # There might be more than one output tensors during forward
         for cuda_str in re.finditer(b"cuda", buf_array):
             pos = cuda_str.start()
@@ -86,7 +88,7 @@ def _broadcast_object_list(
     else:
         current_device = torch.device("cpu")
         if is_nccl_backend:
-            current_device = torch.device("cuda", torch.cuda.current_device())
+            current_device = torch.device("cuda", get_accelerator().current_device())

     my_rank = dist.get_rank()
     # Serialize object_list elements to tensors on src rank.
@@ -139,14 +141,14 @@ def _broadcast_object_list(
             # unconsistence in device
             if (
                 isinstance(unpickle_object, torch.Tensor)
-                and unpickle_object.device.index != torch.cuda.current_device()
+                and unpickle_object.device.index != get_accelerator().current_device()
             ):
-                unpickle_object = unpickle_object.cuda()
+                unpickle_object = unpickle_object.to(get_accelerator().current_device())

             object_list[i] = unpickle_object


-def _check_for_nccl_backend(group):
+def _check_for_nccl_hccl_backend(group):
     pg = group or c10d._get_default_group()
     # Gate PG wrapper check on Gloo availability.
     if c10d._GLOO_AVAILABLE:
@@ -154,14 +156,14 @@ def _check_for_nccl_backend(group):
         while isinstance(pg, c10d._ProcessGroupWrapper):
             pg = pg.wrapped_pg

-    return c10d.is_nccl_available() and pg.name() == c10d.Backend.NCCL
+    return (c10d.is_nccl_available() or torch.distributed.is_hccl_available()) and pg.name() == c10d.Backend.NCCL


 def _check_device(group):
-    is_nccl_backend = _check_for_nccl_backend(group)
+    is_nccl_backend = _check_for_nccl_hccl_backend(group)
     current_device = torch.device("cpu")
     if is_nccl_backend:
-        current_device = torch.device("cuda", torch.cuda.current_device())
+        current_device = torch.device(get_accelerator().current_device())
     return current_device, is_nccl_backend

@@ -348,8 +350,11 @@ def _send_recv_serialization_object(
         unpickle_object = _cuda_safe_tensor_to_object(recv_object_tensor, recv_object_size_tensor.item())

-        if isinstance(unpickle_object, torch.Tensor) and unpickle_object.device.index != torch.cuda.current_device():
-            unpickle_object = unpickle_object.cuda()
+        if (
+            isinstance(unpickle_object, torch.Tensor)
+            and unpickle_object.device.index != get_accelerator().current_device()
+        ):
+            unpickle_object = unpickle_object.to(get_accelerator().current_device())

     return unpickle_object

@@ -474,9 +479,11 @@ def _p2p_comm(
     recv_prev_shape = None

     if tensor_send_next is not None:
-        send_next_shape = torch.tensor(tensor_send_next.size(), device=torch.cuda.current_device(), dtype=torch.int64)
+        send_next_shape = torch.tensor(
+            tensor_send_next.size(), device=get_accelerator().current_device(), dtype=torch.int64
+        )
     if recv_prev:
-        recv_prev_shape = torch.empty((3), device=torch.cuda.current_device(), dtype=torch.int64)
+        recv_prev_shape = torch.empty((3), device=get_accelerator().current_device(), dtype=torch.int64)

     ops = []
     if send_next_shape is not None:
@@ -501,7 +508,7 @@ def _p2p_comm(
     # send and recv data
     tensor_recv_prev = None
     if recv_prev:
-        tensor_recv_prev = torch.empty(recv_prev_shape, device=torch.cuda.current_device(), dtype=comm_dtype)
+        tensor_recv_prev = torch.empty(recv_prev_shape, device=get_accelerator().current_device(), dtype=comm_dtype)

     ops = []
     if tensor_send_next is not None:
View File

@@ -2,7 +2,6 @@ from functools import partial
 from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

 import torch
-import torch.cuda
 import torch.distributed
 from torch.nn import Module, ModuleList
 from torch.utils._pytree import tree_map
@@ -18,7 +17,7 @@ from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_
 from .base import PipelineSchedule


-def _wait_p2p(wait_handles: List[torch.cuda.Event]) -> None:
+def _wait_p2p(wait_handles) -> None:
     if wait_handles is not None:
         for req in wait_handles:
             req.wait()

View File

@@ -2,7 +2,6 @@ from functools import partial
 from typing import Any, Callable, Dict, Iterable, List, Optional, Union

 import torch
-import torch.cuda
 from torch.nn import Module
 from torch.utils._pytree import tree_map

View File

@@ -1,15 +1,28 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
+import numbers
 import warnings
 from abc import ABC, abstractmethod

+import torch
 import torch.nn as nn
+from torch.nn import init
+from torch.nn.parameter import Parameter

 from colossalai.lazy import LazyInitContext

 from ._operation import hook_parameter_in_backward
 from .utils import SeqParallelUtils

+SUPPORT_NPU = False
+try:
+    import torch_npu
+
+    SUPPORT_NPU = True
+except Exception:
+    pass
+
 __all__ = ["FusedLayerNorm", "FusedRMSNorm", "LayerNorm", "RMSNorm", "BaseLayerNorm"]

 try:
@@ -21,7 +34,6 @@ except ImportError:

 try:
     from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm
-    from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm

     class FusedLayerNormWithHook(ApexFusedLayerNorm):
         def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
@@ -32,7 +44,41 @@ try:
             output = hook_parameter_in_backward(output, self.weight, self.bias)
             return output

-    class FusedRMSNormWithHook(ApexFusedRMSNorm):
+except ImportError:
+    warnings.warn("Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMSNorm kernel")
+
+FusedRMSNormWithHook = None
+if SUPPORT_NPU:
+
+    class NPUFusedRMSNormWithHook(nn.Module):
+        def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
+            super().__init__()
+            if isinstance(normalized_shape, numbers.Integral):
+                normalized_shape = (normalized_shape,)
+            self.normalized_shape = torch.Size(normalized_shape)
+            self.eps = eps
+            self.elementwise_affine = elementwise_affine
+            if self.elementwise_affine:
+                self.weight = Parameter(torch.empty(*normalized_shape))
+            else:
+                self.register_parameter("weight", None)
+            self.reset_parameters()
+
+        def reset_parameters(self):
+            if self.elementwise_affine:
+                init.ones_(self.weight)
+
+        def forward(self, input):
+            output, _ = torch_npu.npu_rms_norm(input, self.weight, self.eps)
+            output = hook_parameter_in_backward(output, self.weight)
+            return output
+
+    FusedRMSNormWithHook = NPUFusedRMSNormWithHook
+else:
+    from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
+
+    class CUDAFusedRMSNormWithHook(ApexFusedRMSNorm):
         def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
             super().__init__(normalized_shape, eps, elementwise_affine)

@@ -41,8 +87,7 @@ try:
             output = hook_parameter_in_backward(output, self.weight)
             return output

-except ImportError:
-    warnings.warn("Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMSNorm kernel")
+    FusedRMSNormWithHook = CUDAFusedRMSNormWithHook

 FAST_LAYERNORM_SUPPORTED_SIZE = [
     1024,
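For readers unfamiliar with torch_npu.npu_rms_norm: it computes the same RMSNorm as the Apex fused kernel. A plain-PyTorch reference of the formula, offered only as a sketch of the math (not the fused implementation), normalizing over the last dimension:

    import torch

    def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
        # RMSNorm: scale by the root-mean-square over the last dimension, no mean subtraction.
        variance = x.float().pow(2).mean(dim=-1, keepdim=True)
        x_normed = x.float() * torch.rsqrt(variance + eps)
        return (weight * x_normed).to(x.dtype)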

View File

@@ -9,7 +9,7 @@ from transformers.utils import logging

 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig
-from colossalai.shardformer.layer import AttnMaskType, ColoAttention
+from colossalai.shardformer.layer import ColoAttention
 from colossalai.shardformer.layer._operation import (
     all_to_all_comm,
     gather_sp_output,
@@ -25,42 +25,7 @@ def get_flash_core_attention_forward():
     def forward(self: CoreAttention, query_layer, key_layer, value_layer, attention_mask):
         query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
-        if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
-            attention_mask_type = AttnMaskType.CAUSAL
-            attn_bias = torch.zeros(
-                query_layer.shape[0],
-                1,
-                query_layer.shape[2],
-                key_layer.shape[2],
-                dtype=query_layer.dtype,
-                device=query_layer.device,
-            )
-            temp_mask = (
-                torch.ones(
-                    query_layer.shape[2],
-                    key_layer.shape[2],
-                    dtype=torch.bool,
-                    device=query_layer.device,
-                )
-                .tril(diagonal=0)
-                .expand(query_layer.shape[0], 1, -1, -1)
-            )
-            attn_bias.masked_fill_(temp_mask.logical_not(), torch.finfo(query_layer.dtype).min)
-        else:
-            attention_mask_type = AttnMaskType.CUSTOM
-            if attention_mask is not None:
-                attn_bias = torch.zeros_like(attention_mask, dtype=query_layer.dtype)
-                attn_bias.masked_fill_(attention_mask, torch.finfo(query_layer.dtype).min)
-        dropout_p = self.attention_dropout.p if self.training else 0.0
-        context_layer = ColoAttention.attention(
-            query_layer,
-            key_layer,
-            value_layer,
-            attention_mask=attn_bias,
-            attention_mask_type=attention_mask_type,
-            dropout_p=dropout_p,
-            scale=1.0 / self.norm_factor,
-        )
+        context_layer = ColoAttention.attention(query_layer, key_layer, value_layer, **attention_mask)
         context_layer = context_layer.permute(2, 0, 1, 3)
         new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
         context_layer = context_layer.reshape(*new_context_layer_shape)
@@ -180,9 +145,20 @@ class ChatGLMPipelineForwards:
                     ],
                     dim=-1,
                 )
-        if full_attention_mask is None:
-            if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
-                full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
+
+        if shard_config.enable_flash_attention:
+            mask_shape = (batch_size, 1, seq_length, seq_length)
+            full_attention_mask: dict = ColoAttention.prepare_attn_kwargs(
+                mask_shape,
+                hidden_states.dtype,
+                hidden_states.device,
+                q_padding_mask=attention_mask,
+                is_causal=True,
+            )
+        else:
+            if full_attention_mask is None:
+                if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
+                    full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)

         # Support SP + PP
         sp_size = shard_config.sequence_parallel_size
@@ -237,7 +213,7 @@ class ChatGLMPipelineForwards:
                 layer_ret = torch.utils.checkpoint.checkpoint(
                     layer,
                     hidden_states,
-                    attention_mask,
+                    full_attention_mask,
                     rotary_pos_emb,
                     past_key_values[idx],
                     use_cache,
@@ -402,10 +378,19 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode,
                     ],
                     dim=-1,
                 )
-        if full_attention_mask is None:
-            if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
-                full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
+
+        if shard_config.enable_flash_attention:
+            mask_shape = (batch_size, 1, seq_length, seq_length)
+            full_attention_mask: dict = ColoAttention.prepare_attn_kwargs(
+                mask_shape,
+                hidden_states.dtype,
+                hidden_states.device,
+                q_padding_mask=attention_mask,
+                is_causal=True,
+            )
+        else:
+            if full_attention_mask is None:
+                if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
+                    full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)

         # Rotary positional embeddings
         rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
@@ -652,3 +637,79 @@ def get_chatglm_sequence_parallel_attention_forward(shard_config: ShardConfig, s
         return output, kv_cache

     return forward
+
+
+def get_flash_attention_forward_for_chat_glm_model():
+    from .chatglm2_6b.modeling_chatglm import ChatGLMModel
+
+    def forward(
+        self: ChatGLMModel,
+        input_ids,
+        position_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.BoolTensor] = None,
+        full_attention_mask: Optional[torch.BoolTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ):
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        batch_size, seq_length = input_ids.shape
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embedding(input_ids)
+
+        if self.pre_seq_len is not None:
+            if past_key_values is None:
+                past_key_values = self.get_prompt(
+                    batch_size=batch_size, device=input_ids.device, dtype=inputs_embeds.dtype
+                )
+            if attention_mask is not None:
+                attention_mask = torch.cat(
+                    [attention_mask.new_ones((batch_size, self.pre_seq_len)), attention_mask], dim=-1
+                )
+
+        mask_shape = (batch_size, 1, seq_length, seq_length)
+        full_attention_mask: dict = ColoAttention.prepare_attn_kwargs(
+            mask_shape,
+            inputs_embeds.dtype,
+            inputs_embeds.device,
+            q_padding_mask=attention_mask,
+            is_causal=True,
+        )
+
+        # Rotary positional embeddings
+        rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
+        if position_ids is not None:
+            rotary_pos_emb = rotary_pos_emb[position_ids]
+        else:
+            rotary_pos_emb = rotary_pos_emb[None, :seq_length]
+        rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
+
+        # Run encoder.
+        hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
+            inputs_embeds,
+            full_attention_mask,
+            rotary_pos_emb=rotary_pos_emb,
+            kv_caches=past_key_values,
+            use_cache=use_cache,
+            output_hidden_states=output_hidden_states,
+        )
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
+
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=presents,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+        )
+
+    return forward
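The pattern introduced in this file: ColoAttention.prepare_attn_kwargs builds the mask/metadata kwargs once, and the attention call splats that dict, which is what lets the same forward dispatch to CUDA flash attention or the NPU fusion-attention extension. A condensed, self-contained sketch of the call shape (assuming a CUDA/NPU device with the attention kernels installed; shapes are illustrative only):

    import torch
    from colossalai.accelerator import get_accelerator
    from colossalai.shardformer.layer import ColoAttention

    device = get_accelerator().get_current_device()
    batch_size, num_heads, seq_length, head_dim = 2, 8, 16, 64
    q = k = v = torch.randn(batch_size, num_heads, seq_length, head_dim, dtype=torch.float16, device=device)
    padding_mask = torch.ones(batch_size, seq_length, dtype=torch.int64, device=device)

    # Build the attention kwargs (bias/mask tensors plus mask type) once...
    attn_kwargs: dict = ColoAttention.prepare_attn_kwargs(
        (batch_size, 1, seq_length, seq_length),
        q.dtype,
        q.device,
        q_padding_mask=padding_mask,
        is_causal=True,
    )
    # ...and splat them into the dispatching attention call, as the forwards above do.
    out = ColoAttention.attention(q, k, v, **attn_kwargs)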

View File

@@ -11,6 +11,7 @@ from colossalai.shardformer.modeling.chatglm2 import ChatGLMPipelineForwards
 from ..modeling.chatglm2 import (
     get_chatglm_sequence_parallel_attention_forward,
     get_chatglm_sequence_parallel_forward_fn,
+    get_flash_attention_forward_for_chat_glm_model,
     get_flash_core_attention_forward,
     get_jit_fused_glm_block_forward,
 )
@@ -203,6 +204,13 @@ class ChatGLMPolicy(Policy):
                 policy=policy,
                 target_key="CoreAttention",
             )
+            self.append_or_create_method_replacement(
+                description={
+                    "forward": get_flash_attention_forward_for_chat_glm_model(),
+                },
+                policy=policy,
+                target_key="ChatGLMModel",
+            )

         # use sequence parallel
         if self.shard_config.enable_sequence_parallelism:
# use sequence parallel # use sequence parallel
if self.shard_config.enable_sequence_parallelism: if self.shard_config.enable_sequence_parallelism:

View File

@@ -157,7 +157,7 @@ class GeminiDDP(ModelWrapper):
         self.enable_async_reduce = enable_async_reduce

         if enable_async_reduce:
-            self.async_reduce_stream = torch.cuda.Stream()
+            self.async_reduce_stream = get_accelerator().Stream()
         else:
             self.async_reduce_stream = None

@@ -363,7 +363,7 @@ class GeminiDDP(ModelWrapper):
         master_weights: bool,
         enable_gradient_accumulation: bool,
         p: nn.Parameter,
-        async_reduce_stream: Optional[torch.cuda.Stream] = None,
+        async_reduce_stream=None,
     ):
         async_reduce_scatter = async_reduce_stream is not None
         setattr(p, "_gemini_reduced", True)
@@ -402,9 +402,9 @@ class GeminiDDP(ModelWrapper):
             grad_chunk.add_tensor_to_chunk_slice(p, grad)

         if async_reduce_stream is not None:
-            async_reduce_stream.wait_stream(torch.cuda.current_stream())
+            async_reduce_stream.wait_stream(get_accelerator().current_stream())

-        with torch.cuda.stream(async_reduce_stream):
+        with get_accelerator().stream(async_reduce_stream):
             reduced = chunk_manager.reduce_chunk(grad_chunk, async_op=async_reduce_scatter)
             if reduced:
                 grad_chunk.wait_async_reduce()
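The async-reduce change keeps the same stream-overlap handshake but routes it through the accelerator API. A device-agnostic sketch of that pattern, using only the accelerator stream methods that appear in the hunk above (the summed tensor stands in for the chunk reduction):

    import torch
    from colossalai.accelerator import get_accelerator

    accelerator = get_accelerator()
    side_stream = accelerator.Stream()               # replaces torch.cuda.Stream()
    x = torch.randn(1024, 1024, device=accelerator.get_current_device())

    # The side stream first waits for work already queued on the default stream,
    # then the reduction runs there and can overlap with later compute.
    side_stream.wait_stream(accelerator.current_stream())
    with accelerator.stream(side_stream):
        y = x.sum()                                  # stands in for chunk_manager.reduce_chunk(...)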

View File

@@ -62,7 +62,7 @@ class GeminiZeROHook(ColoParamOpHook):
             #
             # Other than that, self._gemini_manager.wait_chunks will have synced with default stream
             # by calling dist.Work.wait() and this line makes no diff.
-            self._gemini_manager.chunk_manager._prefetch_stream.wait_stream(torch.cuda.current_stream())
+            self._gemini_manager.chunk_manager._prefetch_stream.wait_stream(get_accelerator().current_stream())

             with get_accelerator().stream(self._gemini_manager.chunk_manager._prefetch_stream):
                 for chunk in chunks_fetch_async:

View File

@@ -9,6 +9,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
 from tqdm import tqdm

+from colossalai.accelerator import get_accelerator
 from colossalai.booster import Booster
 from colossalai.cluster import DistCoordinator

@@ -59,7 +60,9 @@ def warm_up(
     for i, data in enumerate(dataloader):
         if i > num_runs:
             break
-        inputs, labels = data[0].cuda(), data[1].cuda()
+        inputs, labels = data[0].to(get_accelerator().get_current_device()), data[1].to(
+            get_accelerator().get_current_device()
+        )
         outputs = model(inputs, labels=labels)
         loss = criterion(outputs)
         booster.backward(loss, optimizer)
@@ -85,7 +88,7 @@ def benchmark(
     warm_up_steps: int = 3,
 ):
     results = {}
-    model_device = torch.cuda.current_device()
+    model_device = get_accelerator().get_current_device()

     # Warm up
     warm_up_fn(
@@ -106,8 +109,8 @@ def benchmark(
     # Measure Allocated Memory and Throughput
     memory = {}
     throughput = {}
-    torch.cuda.reset_peak_memory_stats(device=model_device)
-    pre_mem = torch.cuda.memory_allocated(device=model_device)
+    get_accelerator().reset_peak_memory_stats(device=model_device)
+    pre_mem = get_accelerator().memory_allocated(device=model_device)

     start_time = time()

@@ -116,7 +119,9 @@ def benchmark(
             dataloader, desc=f"Epoch [{epoch + 1}/{epoch_num}]", disable=not DistCoordinator().is_master()
         ) as pbar:
             for data in pbar:
-                inputs, labels = data[0].cuda(), data[1].cuda()
+                inputs, labels = data[0].to(get_accelerator().get_current_device()), data[1].to(
+                    get_accelerator().get_current_device()
+                )
                 outputs = model(inputs, labels=labels)
                 loss = criterion(outputs)
                 booster.backward(loss, optimizer)
@@ -128,8 +133,8 @@ def benchmark(
     all_sample = epoch_num * len(dataloader)

-    post_mem = torch.cuda.memory_allocated(device=model_device)
-    max_mem = torch.cuda.max_memory_allocated(device=model_device)
+    post_mem = get_accelerator().memory_allocated(device=model_device)
+    max_mem = get_accelerator().max_memory_allocated(device=model_device)

     memory[f"batch_size_{batch_size}"] = {
         "cuda_pre_training_bytes": format_num(pre_mem, bytes=True),

View File

@@ -38,7 +38,7 @@ criterion = lambda x: x.loss


 def move_to_cuda(batch):
-    return {k: v.cuda() for k, v in batch.items()}
+    return {k: v.to(get_accelerator().get_current_device()) for k, v in batch.items()}


 @torch.no_grad()
@@ -266,7 +266,8 @@ def main():
     cfg = AutoConfig.from_pretrained(model_name, num_labels=data_builder.num_labels)

     if model_name == "bert-base-uncased":
-        model = BertForSequenceClassification.from_pretrained(model_name, config=cfg).cuda()
+        model = BertForSequenceClassification.from_pretrained(model_name, config=cfg)
+        model = model.to(get_accelerator().get_current_device())
     elif model_name == "albert-xxlarge-v2":
         model = AlbertForSequenceClassification.from_pretrained(model_name, config=cfg)
     else:

View File

@@ -154,7 +154,7 @@ def main():
             offload_param_frac=args.offload_param_frac,
             tp_size=args.tp,
             extra_dp_size=args.extra_dp,
-            enable_fused_normalization=torch.cuda.is_available(),
+            enable_fused_normalization=get_accelerator().is_available(),
             enable_flash_attention=args.xformers,
             max_prefetch=args.prefetch_num,
             enable_async_reduce=not args.disable_async_reduce,
@@ -168,7 +168,7 @@ def main():
             warmup_non_model_data_ratio=args.warmup_ratio,
             tp_size=args.tp,
             extra_dp_size=args.extra_dp,
-            enable_fused_normalization=torch.cuda.is_available(),
+            enable_fused_normalization=get_accelerator().is_available(),
             max_prefetch=args.prefetch_num,
             enable_async_reduce=not args.disable_async_reduce,
             enable_flash_attention=args.xformers,
@@ -245,7 +245,7 @@ def main():
             sp_size=args.sp,
             sequence_parallelism_mode=args.sp_mode,
             enable_sequence_parallelism=args.sp > 1,
-            enable_fused_normalization=torch.cuda.is_available(),
+            enable_fused_normalization=get_accelerator().is_available(),
             enable_flash_attention=args.xformers,
             microbatch_size=args.mbs,
             precision="bf16",
@@ -264,7 +264,7 @@ def main():
             num_model_chunks=args.n_chunks,
             zero_stage=args.zero,
             cpu_offload=True,
-            enable_fused_normalization=torch.cuda.is_available(),
+            enable_fused_normalization=get_accelerator().is_available(),
             enable_flash_attention=args.xformers,
             microbatch_size=args.mbs,
             initial_scale=2**8,
@@ -287,8 +287,8 @@ def main():
         config = MODEL_CONFIGS[args.config]
     else:
         config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
+    get_accelerator().manual_seed(42)

-    torch.cuda.manual_seed(42)
     dataset = RandomDataset(
         num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size
     )
@@ -311,7 +311,6 @@ def main():
             config,
             trust_remote_code=True,
             **init_kwargs,
-            attn_implementation="flash_attention_2",
             torch_dtype=torch.bfloat16,
         )
     if args.grad_checkpoint:
@@ -321,9 +320,13 @@ def main():

     model_numel = get_model_numel(model)
     coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
+    if config.model_type == "chatglm":
+        num_layers = model.config.num_layers
+    else:
+        num_layers = model.config.num_hidden_layers
     performance_evaluator = PerformanceEvaluator(
         model_numel,
-        model.config.num_hidden_layers,
+        num_layers,
         model.config.hidden_size,
         model.config.vocab_size,
         args.grad_checkpoint,
@@ -337,7 +340,7 @@ def main():
     torch.set_default_dtype(torch.float)
     coordinator.print_on_master(
-        f"Booster init max CUDA memory: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB"
+        f"Booster init max device memory: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB"
     )
     coordinator.print_on_master(
         f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB"
@@ -389,7 +392,7 @@ def main():
             performance_evaluator.on_step_end(**batch)
             prof.step()
     performance_evaluator.on_fit_end()
-    coordinator.print_on_master(f"Max CUDA memory usage: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB")
+    coordinator.print_on_master(f"Max device memory usage: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB")


 if __name__ == "__main__":

View File

@@ -1,3 +1,5 @@
+import math
+
 from ...base_extension import _Extension

@@ -47,6 +49,8 @@ class FlashAttentionNpuExtension(_Extension):
             q_indices: Optional[torch.Tensor] = None,
             kv_indices: Optional[torch.Tensor] = None,
         ):
+            if scale is None:
+                scale = 1.0 / math.sqrt(q.size(-1))
             num_heads = q.size(1)
             return torch_npu.npu_fusion_attention(
                 q,
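The added default follows the standard scaled-dot-product convention: when no scale is supplied, use 1/sqrt(head_dim) of the query. A one-line check of what the fallback computes, with a purely illustrative tensor:

    import math
    import torch

    q = torch.randn(2, 8, 128, 64)        # (batch, heads, seq, head_dim) -- illustrative only
    scale = 1.0 / math.sqrt(q.size(-1))   # 1/sqrt(64) = 0.125, matching the new default above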