[npu] change device to accelerator api (#5239)

* update accelerator

* fix timer

* fix amp

* update

* fix

* update bug

* add error raise

* fix autocast

* fix set device

* remove doc accelerator

* update doc

* update doc

* update doc

* use nullcontext

* update cpu

* update null context

* change time limit for example

* udpate

* update

* update

* update

* [npu] polish accelerator code

---------

Co-authored-by: Xuanlei Zhao <xuanlei.zhao@gmail.com>
Co-authored-by: zxl <43881818+oahzxl@users.noreply.github.com>
This commit is contained in:
Hongxin Liu
2024-01-09 10:20:05 +08:00
committed by GitHub
parent dd2c28a323
commit d202cc28c0
128 changed files with 1773 additions and 868 deletions

View File

@@ -1,44 +1,37 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Continual Pre-training of LLaMA-2 developed by Colossal-AI Team
Continual Pre-training of LLaMA-2 developed by Colossal-AI Team
"""
import json
import argparse
import json
import os
import resource
from contextlib import nullcontext
from tqdm import tqdm
import torch
import torch.distributed as dist
from colossal_llama2.dataset.loader import (
DataCollatorForSupervisedDataset,
StatefulDistributedSampler,
load_tokenized_dataset,
setup_distributed_dataloader,
)
from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
from colossal_llama2.utils.froze import freeze_non_embeds_parameters
from torch.utils.tensorboard import SummaryWriter
from transformers import LlamaTokenizer, LlamaForCausalLM, LlamaConfig
from tqdm import tqdm
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import (
GeminiPlugin,
LowLevelZeroPlugin,
HybridParallelPlugin,
)
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
from colossal_llama2.dataset.loader import (
load_tokenized_dataset,
setup_distributed_dataloader,
DataCollatorForSupervisedDataset,
StatefulDistributedSampler,
)
from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
from colossal_llama2.utils.froze import freeze_non_embeds_parameters
def get_model_numel(model: torch.nn.Module) -> int:
@@ -215,9 +208,18 @@ def main() -> None:
# ======================================================
# Initialize Model, Objective, Optimizer and LR Scheduler
# ======================================================
init_ctx = (
LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
)
# colossalai has changed api for get_current_device in 0.3.4 version or newer
try:
from colossalai.accelerator import get_accelerator
current_device = get_accelerator().get_current_device()
except:
from colossalai.utils import get_current_device
current_device = get_current_device()
init_ctx = LazyInitContext(default_device=current_device) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
with init_ctx:
model = LlamaForCausalLM(LlamaConfig.from_pretrained(args.pretrained))
# Freeze part of parameters.
@@ -320,7 +322,7 @@ def main() -> None:
initial=start_step,
) as pbar:
for step, batch in pbar:
batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)}
batch = {k: v.to(current_device) for k, v in batch.items() if isinstance(v, torch.Tensor)}
batch_output = model(**batch)
@@ -372,9 +374,7 @@ def main() -> None:
# Final save.
coordinator.print_on_master("Start saving final model checkpoint")
booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
coordinator.print_on_master(
f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}"
)
coordinator.print_on_master(f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}")
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")