Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-07 12:01:39 +00:00
[npu] add npu support for gemini and zero (#5067)
* [npu] setup device utils (#5047)
* [npu] add npu device support
* [npu] support low level zero
* [test] update npu zero plugin test
* [hotfix] fix import
* [test] recover tests
* [npu] gemini support npu (#5052)
* [npu] refactor device utils
* [gemini] support npu
* [example] llama2+gemini support npu
* [kernel] add arm cpu adam kernel (#5065)
* [kernel] add arm cpu adam
* [optim] update adam optimizer
* [kernel] arm cpu adam remove bf16 support
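A minimal end-to-end sketch of what this commit enables: running the Gemini plugin on an Ascend NPU host. Assumptions not taken from the commit: the torch_npu adapter is installed, the model is a toy placeholder, and the job is launched with ColossalAI's usual torchrun entry point.

import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

colossalai.launch_from_torch(config={})

model = torch.nn.Linear(16, 16)  # placeholder model for illustration
optimizer = HybridAdam(model.parameters(), lr=1e-3)

# On NPU, Gemini only accepts the static placement policy
# (see the assert added in this diff).
plugin = GeminiPlugin(placement_policy="static", precision="fp16")
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)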
@@ -25,6 +25,7 @@ from colossalai.cluster import DistCoordinator, ProcessGroupMesh
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.shardformer import ShardConfig, ShardFormer
 from colossalai.utils import get_current_device
+from colossalai.utils.device import IS_NPU_AVAILABLE
 from colossalai.zero import GeminiDDP, GeminiOptimizer
 from colossalai.zero.gemini.memory_tracer import MemStats
 
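The new IS_NPU_AVAILABLE import centralizes Ascend detection. Its definition is not part of this diff; a plausible sketch of such a probe (the torch_npu import and the torch.npu attribute it patches in are assumptions, not code from the commit):

# Sketch of an NPU-availability probe, not the actual contents of
# colossalai/utils/device.py.
import torch

try:
    import torch_npu  # noqa: F401  # Ascend adapter; patches torch.npu
    IS_NPU_AVAILABLE = torch.npu.is_available()
except ImportError:
    IS_NPU_AVAILABLE = False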
@@ -37,6 +38,7 @@ PRECISION_STR_TO_DTYPE = {"fp16": torch.half, "bf16": torch.bfloat16}
 
 ZERO_AXIS, DP_AXIS, TP_AXIS = 0, 1, 2
 
 
 def get_param_info(optim: Optimizer):
     # Get a backup of necessary information of parameters for future use, which includes:
     # 1. A mapping from integer param_id to param32 shape.
@@ -53,6 +55,8 @@ def get_param_info(optim: Optimizer):
         start_index += len(group["params"])
 
     return param_info
 
 
 class GeminiCheckpointIO(GeneralCheckpointIO):
     def __init__(self) -> None:
         super().__init__()
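The comments above describe per-parameter bookkeeping used later for checkpointing. A self-contained sketch of that idea, as a hypothetical standalone function rather than ColossalAI's exact dict layout:

import torch
from torch.optim import Adam

def sketch_param_info(optim):
    # Map each parameter's integer id to its shape, the mapping the
    # comments above describe (param_id -> param32 shape).
    info = {"id2shape": {}}
    start_index = 0
    for group in optim.param_groups:
        for param_id, param in enumerate(group["params"], start_index):
            info["id2shape"][param_id] = param.shape
        start_index += len(group["params"])
    return info

model = torch.nn.Linear(4, 2)
print(sketch_param_info(Adam(model.parameters()))["id2shape"])
# {0: torch.Size([2, 4]), 1: torch.Size([2])}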
@@ -359,6 +363,8 @@ class GeminiPlugin(DPPluginBase):
     ) -> None:
         super().__init__()
         assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported"
+        if IS_NPU_AVAILABLE:
+            assert placement_policy == "static", "NPU only supports static placement policy"
         self.gemini_config = dict(
             chunk_config_dict=chunk_config_dict,
             chunk_init_device=(chunk_init_device or get_current_device()),
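The guard added here fails fast on unsupported configurations rather than degrading silently. A hedged illustration of the resulting behavior on an NPU host ("auto" is Gemini's dynamic placement policy):

from colossalai.booster.plugin import GeminiPlugin
from colossalai.utils.device import IS_NPU_AVAILABLE

if IS_NPU_AVAILABLE:
    try:
        GeminiPlugin(placement_policy="auto")  # dynamic placement: rejected on NPU
    except AssertionError as err:
        print(err)  # NPU only supports static placement policy
    plugin = GeminiPlugin(placement_policy="static")  # accepted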
@@ -437,7 +443,7 @@ class GeminiPlugin(DPPluginBase):
         return True
 
     def supported_devices(self) -> List[str]:
-        return ["cuda"]
+        return ["cuda", "npu"]
 
     def configure(
         self,
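With "npu" added to supported_devices, callers can do a pre-flight device check before boosting. A small sketch; the selection logic is this example's assumption, only the plugin method comes from the diff:

from colossalai.booster.plugin import GeminiPlugin
from colossalai.utils.device import IS_NPU_AVAILABLE

plugin = GeminiPlugin(placement_policy="static")
device = "npu" if IS_NPU_AVAILABLE else "cuda"
assert device in plugin.supported_devices()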
@@ -485,4 +491,4 @@ class GeminiPlugin(DPPluginBase):
         return GeminiCheckpointIO()
 
     def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
-        raise NotImplementedError
+        raise NotImplementedError