[npu] change device to accelerator api (#5239)
* update accelerator
* fix timer
* fix amp
* update
* fix
* update bug
* add error raise
* fix autocast
* fix set device
* remove doc accelerator
* update doc
* update doc
* update doc
* use nullcontext
* update cpu
* update null context
* change time limit for example
* update
* update
* update
* update
* [npu] polish accelerator code

---------

Co-authored-by: Xuanlei Zhao <xuanlei.zhao@gmail.com>
Co-authored-by: zxl <43881818+oahzxl@users.noreply.github.com>
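The gist of the change is that device lookup no longer goes through the CUDA-only colossalai.utils.get_current_device, but through the backend-agnostic accelerator API, so the same code path can run on NPU as well as CUDA. A minimal sketch of the before/after pattern, assuming a ColossalAI version that ships colossalai.accelerator (0.3.4 or newer, per the comments in the diff below):

    import torch

    # Old pattern: CUDA-specific device lookup.
    from colossalai.utils import get_current_device
    device = get_current_device()

    # New pattern: the accelerator object resolves the current device for the
    # active backend (CUDA, NPU, ...).
    from colossalai.accelerator import get_accelerator
    device = get_accelerator().get_current_device()

    # Tensors are then placed on the resolved device, as in the diff below.
    x = torch.zeros(2, 2).to(device)
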
@@ -10,7 +10,7 @@ from torch.utils.data import DataLoader, DistributedSampler
 from tqdm import tqdm
 from transformers import PreTrainedTokenizerBase
 
-from colossalai.utils import get_current_device
+from colossalai.accelerator import get_accelerator
 
 from .base import OnPolicyTrainer
 from .callbacks import Callback
@@ -105,7 +105,7 @@ class PPOTrainer(OnPolicyTrainer):
         self.critic_optim = critic_optim
 
         self.offload_inference_models = offload_inference_models
-        self.device = get_current_device()
+        self.device = get_accelerator().get_current_device()
 
     def _before_fit(
         self,
@@ -6,7 +6,6 @@ import torch.nn as nn
 import colossalai
 from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin
 from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
-from colossalai.utils import get_current_device
 from colossalai.zero.gemini.gemini_ddp import GeminiDDP
 
 from .ddp import DDPStrategy
@@ -158,9 +157,19 @@ class GeminiStrategy(DDPStrategy):
 
         warnings.warn(f"Stage 3 only supports fp16. Precision is set to fp16.")
 
+        # colossalai has changed api for get_current_device in 0.3.4 version or newer
+        try:
+            from colossalai.accelerator import get_accelerator
+
+            chunk_init_device = get_accelerator().get_current_device()
+        except:
+            from colossalai.utils import get_current_device
+
+            chunk_init_device = get_current_device()
+
         # NOTE: dist should be initialized before calling get_current_device()
         plugin_initializer = lambda: GeminiPlugin(
-            chunk_init_device=get_current_device(),
+            chunk_init_device=chunk_init_device,
             placement_policy=placement_policy,
             shard_param_frac=shard_param_frac,
             offload_optim_frac=offload_optim_frac,
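The same try/except shim reappears below in the Colossal-LLaMA-2 training script. As a sketch only, the fallback could live in one place; the helper name resolve_current_device is hypothetical and not part of this commit, and the bare except is narrowed to ImportError:

    def resolve_current_device():
        """Return the current device, preferring the accelerator API (colossalai 0.3.4+)."""
        try:
            from colossalai.accelerator import get_accelerator

            return get_accelerator().get_current_device()
        except ImportError:
            # Older colossalai releases only expose the CUDA-specific utility.
            from colossalai.utils import get_current_device

            return get_current_device()

    chunk_init_device = resolve_current_device()
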
@@ -1,44 +1,37 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 """
-Continual Pre-training of LLaMA-2 developed by Colossal-AI Team
+Continual Pre-training of LLaMA-2 developed by Colossal-AI Team
 """
 
-import json
 import argparse
+import json
 import os
 import resource
 from contextlib import nullcontext
-from tqdm import tqdm
 
 import torch
 import torch.distributed as dist
+from colossal_llama2.dataset.loader import (
+    DataCollatorForSupervisedDataset,
+    StatefulDistributedSampler,
+    load_tokenized_dataset,
+    setup_distributed_dataloader,
+)
+from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
+from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
+from colossal_llama2.utils.froze import freeze_non_embeds_parameters
 from torch.utils.tensorboard import SummaryWriter
-from transformers import LlamaTokenizer, LlamaForCausalLM, LlamaConfig
+from tqdm import tqdm
+from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
 
 import colossalai
 from colossalai.booster import Booster
-from colossalai.booster.plugin import (
-    GeminiPlugin,
-    LowLevelZeroPlugin,
-    HybridParallelPlugin,
-)
+from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
 from colossalai.cluster import DistCoordinator
 from colossalai.lazy import LazyInitContext
 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device
-
-from colossal_llama2.dataset.loader import (
-    load_tokenized_dataset,
-    setup_distributed_dataloader,
-    DataCollatorForSupervisedDataset,
-    StatefulDistributedSampler,
-)
-
-from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
-from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
-from colossal_llama2.utils.froze import freeze_non_embeds_parameters
 
 
 def get_model_numel(model: torch.nn.Module) -> int:
@@ -215,9 +208,18 @@ def main() -> None:
     # ======================================================
     # Initialize Model, Objective, Optimizer and LR Scheduler
    # ======================================================
-    init_ctx = (
-        LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
-    )
+
+    # colossalai has changed api for get_current_device in 0.3.4 version or newer
+    try:
+        from colossalai.accelerator import get_accelerator
+
+        current_device = get_accelerator().get_current_device()
+    except:
+        from colossalai.utils import get_current_device
+
+        current_device = get_current_device()
+
+    init_ctx = LazyInitContext(default_device=current_device) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
     with init_ctx:
         model = LlamaForCausalLM(LlamaConfig.from_pretrained(args.pretrained))
         # Freeze part of parameters.
@@ -320,7 +322,7 @@ def main() -> None:
             initial=start_step,
         ) as pbar:
             for step, batch in pbar:
-                batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)}
+                batch = {k: v.to(current_device) for k, v in batch.items() if isinstance(v, torch.Tensor)}
 
                 batch_output = model(**batch)
 
@@ -372,9 +374,7 @@ def main() -> None:
     # Final save.
     coordinator.print_on_master("Start saving final model checkpoint")
     booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
-    coordinator.print_on_master(
-        f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}"
-    )
+    coordinator.print_on_master(f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}")
 
     coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
 
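One line that stays CUDA-specific after this commit is the final memory report, which still calls torch.cuda.max_memory_allocated() directly. A hedged sketch of a backend-agnostic variant, assuming the accelerator object exposes a max_memory_allocated counterpart (an assumption to verify against the installed ColossalAI version):

    from colossalai.accelerator import get_accelerator

    # Assumption: the accelerator API mirrors torch.cuda's peak-memory counter.
    peak_mb = get_accelerator().max_memory_allocated() / 1024**2
    coordinator.print_on_master(f"Max device memory usage: {peak_mb:.2f} MB")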