Mirror of https://github.com/hpcaitech/ColossalAI.git
[npu] add npu support for gemini and zero (#5067)
* [npu] setup device utils (#5047)
* [npu] add npu device support
* [npu] support low level zero
* [test] update npu zero plugin test
* [hotfix] fix import
* [test] recover tests
* [npu] gemini support npu (#5052)
* [npu] refactor device utils
* [gemini] support npu
* [example] llama2+gemini support npu
* [kernel] add arm cpu adam kernel (#5065)
* [kernel] add arm cpu adam
* [optim] update adam optimizer
* [kernel] arm cpu adam remove bf16 support
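The ARM CPU Adam kernel fuses the standard Adam update for optimizer states held on the CPU. As a point of reference only (not the kernel itself, which is vectorized C++, and whose exact handling of weight decay may differ), the update it performs corresponds to this plain PyTorch step:

import torch

def adam_step(p, grad, m, v, step, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.0):
    # Reference implementation of one Adam update; hypothetical helper,
    # not a ColossalAI API.
    if weight_decay != 0.0:
        grad = grad + weight_decay * p
    m.mul_(beta1).add_(grad, alpha=1 - beta1)            # first-moment EMA
    v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)  # second-moment EMA
    bias1 = 1 - beta1**step
    bias2 = 1 - beta2**step
    # p <- p - lr * m_hat / (sqrt(v_hat) + eps)
    p.addcdiv_(m, (v / bias2).sqrt().add_(eps), value=-lr / bias1)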
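The diff below is representative of the whole change: direct torch.cuda.* calls are replaced with a device-utils indirection so the same code runs on CUDA GPUs and Ascend NPUs. A minimal sketch of how such an indirection can be built (hypothetical module; colossalai.utils.device is the real one and its internals may differ):

import torch

try:
    import torch_npu  # noqa: F401  # torch_npu patches torch with a torch.npu namespace
    _backend = torch.npu
except ImportError:
    _backend = torch.cuda

def get_current_device() -> torch.device:
    # Current accelerator as a torch.device, e.g. cuda:0 or npu:0.
    name = "cuda" if _backend is torch.cuda else "npu"
    return torch.device(f"{name}:{_backend.current_device()}")

def synchronize() -> None:
    # Block until all queued kernels on the current device have finished.
    _backend.synchronize()

def max_memory_allocated() -> int:
    # Peak bytes allocated on the current device since tracking began.
    return _backend.max_memory_allocated()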
@@ -13,6 +13,7 @@ from transformers.models.llama.configuration_llama import LlamaConfig
 from transformers.models.llama.modeling_llama import LlamaForCausalLM
 
 import colossalai
+import colossalai.utils.device as device_utils
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchFSDPPlugin
 from colossalai.cluster import DistCoordinator
@@ -194,7 +195,7 @@ def main():
     torch.set_default_dtype(torch.bfloat16)
     model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)
     torch.set_default_dtype(torch.float)
-    coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
+    coordinator.print_on_master(f"Booster init max CUDA memory: {device_utils.max_memory_allocated()/1024**2:.2f} MB")
     coordinator.print_on_master(
         f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB"
     )
@@ -220,7 +221,7 @@ def main():
         performance_evaluator.on_step_end(**batch)
 
     performance_evaluator.on_fit_end()
-    coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
+    coordinator.print_on_master(f"Max CUDA memory usage: {device_utils.max_memory_allocated()/1024**2:.2f} MB")
 
 
 if __name__ == "__main__":
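Both memory printouts in the benchmark now go through the same wrapper, so one code path reports peak accelerator memory regardless of backend. A small usage sketch of that interface (the functions appear in the diff; the toy workload is illustrative):

import torch
import colossalai.utils.device as device_utils

x = torch.randn(1024, 1024, device=device_utils.get_current_device())
y = x @ x                   # some device work
device_utils.synchronize()  # make sure the kernel has finished
print(f"peak device memory: {device_utils.max_memory_allocated() / 1024**2:.2f} MB")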
@@ -5,7 +5,9 @@ import torch
 import torch.distributed as dist
 from torch import Tensor
 
+import colossalai.utils.device as device_utils
 from colossalai.cluster import DistCoordinator
+from colossalai.utils.device import get_current_device
 
 
 def divide(x: float, y: float) -> float:
@@ -20,7 +22,7 @@ def divide(x: float, y: float) -> float:
 def all_reduce_mean(x: float, world_size: int) -> float:
     if world_size == 1:
         return x
-    tensor = torch.tensor([x], device=torch.cuda.current_device())
+    tensor = torch.tensor([x], device=get_current_device())
     dist.all_reduce(tensor)
     tensor = tensor / world_size
     return tensor.item()
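all_reduce_mean averages a per-rank scalar (e.g. a step's throughput) across all ranks, and staging it through get_current_device() keeps the reduction on whichever accelerator is active. A minimal usage sketch, assuming torch.distributed is already initialized on every rank:

import torch.distributed as dist

local_tokens_per_sec = 1234.5  # illustrative per-rank measurement
avg = all_reduce_mean(local_tokens_per_sec, dist.get_world_size())
if dist.get_rank() == 0:
    print(f"mean throughput: {avg:.1f} tokens/s")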
@@ -84,13 +86,13 @@ class PerformanceEvaluator:
         self.disable = self.ignore_steps > 0 and step < self.ignore_steps
         if self.disable:
             return
-        torch.cuda.synchronize()
+        device_utils.synchronize()
         self.timer.start()
 
     def on_step_end(self, input_ids: Tensor, **kwargs) -> None:
         if self.disable:
             return
-        torch.cuda.synchronize()
+        device_utils.synchronize()
         self.timer.end()
 
         batch_size, seq_len = input_ids.shape
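The synchronize-before-timing pattern matters because accelerator kernels launch asynchronously: without the barrier the timer would capture only launch latency, not execution time. A standalone sketch of the pattern the evaluator follows (time.perf_counter stands in for its internal timer):

import time
import torch
import colossalai.utils.device as device_utils

x = torch.randn(4096, 4096, device=device_utils.get_current_device())

device_utils.synchronize()  # drain pending work before starting the clock
start = time.perf_counter()
y = x @ x                   # asynchronous kernel launch
device_utils.synchronize()  # wait for the matmul to actually finish
print(f"matmul took {(time.perf_counter() - start) * 1e3:.2f} ms")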