[npu] add npu support for gemini and zero (#5067)

* [npu] setup device utils (#5047)

* [npu] add npu device support

* [npu] support low level zero

* [test] update npu zero plugin test

* [hotfix] fix import

* [test] recover tests

* [npu] gemini support npu (#5052)

* [npu] refactor device utils

* [gemini] support npu

* [example] llama2+gemini support npu

* [kernel] add arm cpu adam kernel (#5065)

* [kernel] add arm cpu adam

* [optim] update adam optimizer

* [kernel] arm cpu adam remove bf16 support
Author: Hongxin Liu
Date: 2023-11-20 16:12:41 +08:00
Committed by: GitHub
Parent: 8d56c9c389
Commit: e5ce4c8ea6
46 changed files with 994 additions and 233 deletions
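The hunks below are representative of the whole change: CUDA-only assumptions ("cuda" device-type checks, torch.device("cuda") defaults) become device-agnostic so the same code path runs on Ascend NPUs. As a rough illustration of the kind of device-dispatch helper the "[npu] setup device utils" step refers to, here is a minimal sketch that assumes the Huawei Ascend torch_npu plugin; the helper name and structure are illustrative, not ColossalAI's actual get_current_device implementation.

# Minimal sketch of a device-dispatch helper in the spirit of "[npu] setup device utils".
# Assumes the Huawei Ascend torch_npu plugin; current_accelerator() is an illustrative
# name, not ColossalAI's actual API.
import torch

try:
    import torch_npu  # noqa: F401  (side effect: registers the "npu" device with torch)

    _NPU_AVAILABLE = torch.npu.is_available()
except ImportError:
    _NPU_AVAILABLE = False


def current_accelerator() -> torch.device:
    """Return the local accelerator: NPU if present, else CUDA, else CPU."""
    if _NPU_AVAILABLE:
        return torch.device(f"npu:{torch.npu.current_device()}")
    if torch.cuda.is_available():
        return torch.device(f"cuda:{torch.cuda.current_device()}")
    return torch.device("cpu")

A helper of this shape is what allows the optimizer signature further down in the diff to replace the hard-coded torch.device("cuda") default with get_current_device().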


@@ -7,31 +7,29 @@ from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple, Union
import torch
import torch.distributed as dist
from packaging.version import Version
+from torch.distributed import ProcessGroup
from torch.nn import Parameter
from torch.optim import Optimizer
-from torch.distributed import ProcessGroup
from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin
-from colossalai.checkpoint_io.utils import StateDictSharder
+from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param
from colossalai.interface import OptimizerWrapper
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import CPUAdam, FusedAdam, HybridAdam
+from colossalai.tensor.d_tensor import (
+    distribute_tensor,
+    distribute_tensor_with_customization,
+    get_device_mesh,
+    get_sharding_spec,
+    init_as_dtensor,
+    init_tensor_as_customization_distributed,
+    is_customized_distributed_tensor,
+    is_distributed_tensor,
+)
from colossalai.utils import disposable, get_current_device, is_ddp_ignored
from .chunk import Chunk, ChunkManager
from .gemini_ddp import GeminiDDP
-from colossalai.checkpoint_io.utils import gather_distributed_param
-from colossalai.tensor.d_tensor import (
-    distribute_tensor,
-    distribute_tensor_with_customization,
-    init_tensor_as_customization_distributed,
-    get_device_mesh,
-    get_sharding_spec,
-    is_customized_distributed_tensor,
-    is_distributed_tensor,
-    get_global_shape,
-    init_as_dtensor
-)
__all__ = ["GeminiOptimizer", "GeminiAdamOptimizer"]
@@ -312,7 +310,7 @@ class GeminiOptimizer(OptimizerWrapper):
chunk16 = self.param_to_chunk16[fake_param]
chunk32 = chunk16.paired_chunk
if chunk32.device_type == "cuda":
if chunk32.device_type == "cuda" or chunk32.device_type == "npu":
continue
if fp32_params_used_cuda_margin_mem + chunk32.payload_mem < fp32_params_available_cuda_margin_mem:
@@ -326,7 +324,7 @@ class GeminiOptimizer(OptimizerWrapper):
for fake_param in group["params"]:
chunk16 = self.param_to_chunk16[fake_param]
chunk32 = chunk16.paired_chunk
if chunk32.device_type == "cuda":
if chunk32.device_type == "cuda" or chunk32.device_type == "npu":
state = self.optim.state[fake_param]
for k, v in state.items():
if isinstance(v, torch.Tensor):
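Both hunks in this method apply the same guard change: a paired fp32 chunk is treated as already resident on the accelerator when its device_type is either "cuda" or "npu". A membership test keeps that intent in one place; the helper below is a hypothetical equivalent written for illustration, not part of this commit.

# Hypothetical equivalent of the repeated guard above; not part of this commit.
_ACCELERATOR_DEVICE_TYPES = ("cuda", "npu")


def on_accelerator(device_type: str) -> bool:
    """True if the chunk already lives on an accelerator (CUDA GPU or Ascend NPU)."""
    return device_type in _ACCELERATOR_DEVICE_TYPES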
@@ -479,15 +477,19 @@ class GeminiOptimizer(OptimizerWrapper):
state_tensor = states[state_name].detach().clone().to(torch.float32).cpu()
if is_dtensor:
state_tensor = torch.reshape(state_tensor, param.shape).to(param.device)
-state_tensor = init_as_dtensor(state_tensor,
-                               device_mesh=device_mesh,
-                               sharding_spec=shard_spec,
-                               global_shape = global_shape)
+state_tensor = init_as_dtensor(
+    state_tensor,
+    device_mesh=device_mesh,
+    sharding_spec=shard_spec,
+    global_shape=global_shape,
+)
elif is_customized_distributed:
state_tensor = torch.reshape(state_tensor, param.shape).to(param.device)
-init_tensor_as_customization_distributed(state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn)
+init_tensor_as_customization_distributed(
+    state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn
+)
state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu()
collected_states[state_name] = state_tensor.reshape(global_shape)
return collected_states
@@ -533,13 +535,14 @@ class GeminiOptimizer(OptimizerWrapper):
collected_states[state_name] = torch.reshape(state_tensor, param.shape)
if is_dtensor:
state_tensor = state_tensor.to(param.device)
-state_tensor = init_as_dtensor(state_tensor,
-                               sharding_spec=shard_spec,
-                               device_mesh=device_mesh,
-                               global_shape=global_shape)
+state_tensor = init_as_dtensor(
+    state_tensor, sharding_spec=shard_spec, device_mesh=device_mesh, global_shape=global_shape
+)
elif is_customized_distributed:
state_tensor = state_tensor.to(param.device)
-init_tensor_as_customization_distributed(state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn)
+init_tensor_as_customization_distributed(
+    state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn
+)
state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu()
return collected_states
@@ -548,7 +551,7 @@ class GeminiOptimizer(OptimizerWrapper):
self,
param_id: int,
state_names: list,
device: torch.device = torch.device("cuda"),
device: torch.device = get_current_device(),
dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""
@@ -705,7 +708,7 @@ class GeminiOptimizer(OptimizerWrapper):
ret_val = torch.zeros(
state_end - state_start, dtype=torch.float32, device=param.device, requires_grad=False
)
if is_dtensor:
value = torch.reshape(value, global_shape)
value = distribute_tensor(value, sharding_spec=shard_spec, device_mesh=device_mesh)
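In the hunk above, a flat optimizer-state tensor is reshaped to the parameter's global shape and then handed to distribute_tensor, which scatters it across ranks according to the sharding spec and device mesh. The sketch below is a simplified stand-in for that step, assuming an initialized torch.distributed process group and a single, evenly divisible split dimension; real sharding specs and device meshes are richer than this.

# Simplified stand-in for "reshape to global shape, then keep only the local shard".
# Assumes torch.distributed is initialized and the split dimension divides evenly
# across ranks; a single dim replaces the real sharding spec / device mesh.
import torch
import torch.distributed as dist


def shard_state(flat_state: torch.Tensor, global_shape: torch.Size, dim: int = 0) -> torch.Tensor:
    """Reshape a flat state tensor to its global shape and slice out this rank's chunk."""
    full = flat_state.reshape(global_shape)
    chunks = full.chunk(dist.get_world_size(), dim=dim)
    return chunks[dist.get_rank()].contiguous()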