[Device]Support npu (#6159)

* support npu

* support pretrain

support pretrain

fix

* support lora

fix

fix

* support chatglm

fix

fxi

fix

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

fix

fix

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

fix

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

fix

fix

fix

* Update train.py

* Update train.py

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
flybird11111
2024-12-17 15:42:39 +08:00
committed by GitHub
parent e994c64568
commit aaafb38851
18 changed files with 295 additions and 152 deletions

View File

@@ -100,7 +100,7 @@ LLaMA3_Conv = Conversation(
messages=[],
offset=0,
sep_style=SeparatorStyle.ADD_BOS_EOS_TOKEN,
seps=["<|begin_of_text|>", "<|end_of_text|>"],
seps=["<|begin_of_text|>", "<|eot_id|>"],
)
default_conversation = LLaMA3_Conv

View File

@@ -88,7 +88,7 @@ def supervised_tokenize_sft(
assert (
tokenizer.bos_token == conversation_template.seps[0] and tokenizer.eos_token == conversation_template.seps[1]
), "`bos_token` and `eos_token` should be the same with `conversation_template.seps`."
), f"`bos_token`{tokenizer.bos_token} and `eos_token`{tokenizer.eos_token} should be the same with `conversation_template.seps`{conversation_template.seps}."
if ignore_index is None:
ignore_index = IGNORE_INDEX

View File

@@ -43,6 +43,7 @@ def save_checkpoint(
step: int,
batch_size: int,
coordinator: DistCoordinator,
use_lora: bool = False,
) -> None:
"""
Save model checkpoint, optimizer, LR scheduler and intermedidate running states.
@@ -51,7 +52,10 @@ def save_checkpoint(
save_dir = os.path.join(save_dir, f"epoch-{epoch}_step-{step}")
os.makedirs(os.path.join(save_dir, "modeling"), exist_ok=True)
booster.save_model(model, os.path.join(save_dir, "modeling"), shard=True)
if use_lora:
booster.save_lora_as_pretrained(model, os.path.join(save_dir, "modeling"))
else:
booster.save_model(model, os.path.join(save_dir, "modeling"), shard=True)
booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True)
booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))

View File

@@ -21,6 +21,7 @@ from colossal_llama.utils.ckpt_io import load_checkpoint, save_checkpoint
from colossal_llama.utils.froze import freeze_non_embeds_parameters
from colossal_llama.utils.neftune_patch import activate_neftune, deactivate_neftune
from colossal_llama.utils.utils import all_reduce_mean, format_numel_str, get_model_numel
from peft import LoraConfig
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -65,7 +66,7 @@ def train(args) -> None:
initial_scale=2**16,
max_norm=args.grad_clip,
enable_gradient_accumulation=(args.accumulation_steps > 1),
enable_fused_normalization=torch.cuda.is_available(),
enable_fused_normalization=get_accelerator().is_available(),
enable_flash_attention=args.use_flash_attn,
)
elif args.plugin == "gemini_auto":
@@ -75,7 +76,7 @@ def train(args) -> None:
initial_scale=2**16,
max_norm=args.grad_clip,
enable_gradient_accumulation=(args.accumulation_steps > 1),
enable_fused_normalization=torch.cuda.is_available(),
enable_fused_normalization=get_accelerator().is_available(),
enable_flash_attention=args.use_flash_attn,
)
elif args.plugin == "zero2":
@@ -101,10 +102,9 @@ def train(args) -> None:
sequence_parallelism_mode=args.sp_mode,
zero_stage=args.zero_stage,
enable_flash_attention=args.use_flash_attn,
enable_fused_normalization=torch.cuda.is_available(),
enable_fused_normalization=get_accelerator().is_available(),
enable_sequence_parallelism=args.enable_sequence_parallelism,
cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
parallel_output=False,
max_norm=args.grad_clip,
precision=args.mixed_precision,
microbatch_size=args.microbatch_size,
@@ -117,11 +117,17 @@ def train(args) -> None:
# ======================================================
# Initialize Tokenizer, Dataset, Collator and Dataloader
# ======================================================
tokenizer = AutoTokenizer.from_pretrained(args.pretrained)
tokenizer = AutoTokenizer.from_pretrained(args.pretrained, trust_remote_code=True)
if args.pad_token == "eos":
tokenizer.pad_token = tokenizer.eos_token
try:
tokenizer.pad_token = tokenizer.eos_token
except AttributeError:
coordinator.print_on_master(f"pad_token can't be set")
elif args.pad_token == "unk":
tokenizer.pad_token = tokenizer.unk_token
try:
tokenizer.pad_token = tokenizer.unk_token
except AttributeError:
coordinator.print_on_master(f"pad_token can't be set")
tokenizer.add_bos_token = False
tokenizer.add_eos_token = False
@@ -164,33 +170,31 @@ def train(args) -> None:
# ======================================================
# Initialize Model, Objective, Optimizer and LR Scheduler
# ======================================================
# When training the ChatGLM model, LoRA and gradient checkpointing are incompatible.
init_ctx = (
LazyInitContext(default_device=get_current_device())
if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin))
if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin)) and args.lora_rank == 0
else nullcontext()
)
with init_ctx:
if args.use_flash_attn:
model = AutoModelForCausalLM.from_pretrained(
args.pretrained,
attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
trust_remote_code=True,
)
else:
model = AutoModelForCausalLM.from_pretrained(
args.pretrained,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(
args.pretrained,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
trust_remote_code=True,
)
# Freeze part of parameters.
if args.freeze_non_embeds_params:
freeze_non_embeds_parameters(model=model)
if args.lora_rank > 0:
lora_config = LoraConfig(task_type="CAUSAL_LM", r=args.lora_rank, lora_alpha=32, lora_dropout=0.1)
model = booster.enable_lora(model, lora_config=lora_config)
# this is essential, otherwise the grad checkpoint will not work.
model.train()
if args.use_grad_checkpoint:
model.gradient_checkpointing_enable()
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
model_numel = get_model_numel(model)
@@ -327,6 +331,7 @@ def train(args) -> None:
step=step + 1,
batch_size=args.batch_size,
coordinator=coordinator,
use_lora=(args.lora_rank > 0),
)
coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
@@ -371,44 +376,45 @@ def train(args) -> None:
total_loss.fill_(0.0)
pbar.update()
# Save modeling.
save_model_condition = (
args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0
)
if not args.skip_save_each_epoch:
save_model_condition = save_model_condition or (step + 1) == len(dataloader)
if save_model_condition and not args.benchmark:
coordinator.print_on_master("\nStart saving model checkpoint with running states")
if args.use_neft:
coordinator.print_on_master("Deactivate NEFTune before saving model.")
deactivate_neftune(model, handle)
accelerator.empty_cache()
save_checkpoint(
save_dir=args.save_dir,
booster=booster,
model=model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
epoch=epoch,
step=step + 1,
batch_size=args.batch_size,
coordinator=coordinator,
)
coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
# Save modeling.
save_model_condition = (
args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0
)
if args.use_neft:
coordinator.print_on_master("Activate NEFTune.")
model, handle = activate_neftune(model)
if not args.skip_save_each_epoch:
save_model_condition = save_model_condition or (step + 1) == len(dataloader)
# Delete cache.
# del batch, batch_labels, batch_output, loss
accelerator.empty_cache()
if save_model_condition and not args.benchmark:
coordinator.print_on_master("\nStart saving model checkpoint with running states")
if args.use_neft:
coordinator.print_on_master("Deactivate NEFTune before saving model.")
deactivate_neftune(model, handle)
accelerator.empty_cache()
save_checkpoint(
save_dir=args.save_dir,
booster=booster,
model=model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
epoch=epoch,
step=step + 1,
batch_size=args.batch_size,
coordinator=coordinator,
use_lora=(args.lora_rank > 0),
)
coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
)
if args.use_neft:
coordinator.print_on_master("Activate NEFTune.")
model, handle = activate_neftune(model)
# Delete cache.
# del batch, batch_labels, batch_output, loss
accelerator.empty_cache()
# the continue epochs are not resumed, so we need to reset the sampler start index and start step
dataloader.sampler.set_start_index(start_index=0)
@@ -522,6 +528,7 @@ if __name__ == "__main__":
parser.add_argument(
"--microbatch_size", type=int, default=1, help="Batch size for each process in PP, used for 3d plugin."
)
parser.add_argument("--lora_rank", type=int, default=0, help="lora rank when using lora to train.")
# Additional arguments for benchmark.
parser.add_argument("--num_samples", type=int, default=500, help="Number of samples for benchmarking.")