Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-05 02:51:59 +00:00
[npu] add npu support for gemini and zero (#5067)
* [npu] setup device utils (#5047)
* [npu] add npu device support
* [npu] support low level zero
* [test] update npu zero plugin test
* [hotfix] fix import
* [test] recover tests
* [npu] gemini support npu (#5052)
* [npu] refactor device utils
* [gemini] support npu
* [example] llama2+gemini support npu
* [kernel] add arm cpu adam kernel (#5065)
* [kernel] add arm cpu adam
* [optim] update adam optimizer
* [kernel] arm cpu adam remove bf16 support
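The commit message describes setting up and refactoring device utilities so that Gemini and ZeRO can target Ascend NPU as well as CUDA GPUs. The following is only a minimal sketch of that kind of helper, not the ColossalAI implementation: it assumes the optional torch_npu extension (which registers a torch.npu namespace mirroring torch.cuda) and shows how a get_current_device-style function, like the one imported from colossalai.utils in the diff below, can hide the backend choice from call sites.

# Hypothetical sketch of a backend-agnostic device helper; the real
# colossalai.utils.get_current_device may be implemented differently.
import torch

try:
    import torch_npu  # noqa: F401  # registers the torch.npu namespace if installed
except ImportError:
    pass


def get_accelerator_name() -> str:
    """Name of the available accelerator backend, falling back to CPU."""
    if hasattr(torch, "npu") and torch.npu.is_available():
        return "npu"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


def get_current_device() -> torch.device:
    """Device of the accelerator bound to this process (index-aware for NPU/CUDA)."""
    name = get_accelerator_name()
    if name == "npu":
        return torch.device("npu", torch.npu.current_device())
    if name == "cuda":
        return torch.device("cuda", torch.cuda.current_device())
    return torch.device("cpu")

With a helper like this, call sites can write tensor.to(get_current_device()) instead of hard-coding .cuda(), which is the direction the hunks below move in.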
@@ -7,31 +7,29 @@ from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple, Union
 import torch
 import torch.distributed as dist
 from packaging.version import Version
+from torch.distributed import ProcessGroup
 from torch.nn import Parameter
 from torch.optim import Optimizer
-from torch.distributed import ProcessGroup
 
 from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin
-from colossalai.checkpoint_io.utils import StateDictSharder
+from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param
 from colossalai.interface import OptimizerWrapper
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import CPUAdam, FusedAdam, HybridAdam
+from colossalai.tensor.d_tensor import (
+    distribute_tensor,
+    distribute_tensor_with_customization,
+    get_device_mesh,
+    get_sharding_spec,
+    init_as_dtensor,
+    init_tensor_as_customization_distributed,
+    is_customized_distributed_tensor,
+    is_distributed_tensor,
+)
 from colossalai.utils import disposable, get_current_device, is_ddp_ignored
 
 from .chunk import Chunk, ChunkManager
 from .gemini_ddp import GeminiDDP
-from colossalai.checkpoint_io.utils import gather_distributed_param
-from colossalai.tensor.d_tensor import (
-    distribute_tensor,
-    distribute_tensor_with_customization,
-    init_tensor_as_customization_distributed,
-    get_device_mesh,
-    get_sharding_spec,
-    is_customized_distributed_tensor,
-    is_distributed_tensor,
-    get_global_shape,
-    init_as_dtensor
-)
 
 __all__ = ["GeminiOptimizer", "GeminiAdamOptimizer"]
 
@@ -312,7 +310,7 @@ class GeminiOptimizer(OptimizerWrapper):
                 chunk16 = self.param_to_chunk16[fake_param]
                 chunk32 = chunk16.paired_chunk
 
-                if chunk32.device_type == "cuda":
+                if chunk32.device_type == "cuda" or chunk32.device_type == "npu":
                     continue
 
                 if fp32_params_used_cuda_margin_mem + chunk32.payload_mem < fp32_params_available_cuda_margin_mem:
@@ -326,7 +324,7 @@ class GeminiOptimizer(OptimizerWrapper):
             for fake_param in group["params"]:
                 chunk16 = self.param_to_chunk16[fake_param]
                 chunk32 = chunk16.paired_chunk
-                if chunk32.device_type == "cuda":
+                if chunk32.device_type == "cuda" or chunk32.device_type == "npu":
                     state = self.optim.state[fake_param]
                     for k, v in state.items():
                         if isinstance(v, torch.Tensor):
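The two hunks above apply the same one-line change: a paired fp32 chunk that already sits on the accelerator is treated identically whether that accelerator is a CUDA GPU or an NPU. Purely as an illustration (the helper below is not part of the ColossalAI code base), the repeated condition could be factored into a small predicate:

# Hypothetical helper, not in the diff: centralizes the "already on an
# accelerator" test that the two hunks above spell out inline.
_ACCELERATOR_DEVICE_TYPES = ("cuda", "npu")


def chunk_on_accelerator(chunk) -> bool:
    """True if the chunk's payload already lives on a CUDA GPU or an Ascend NPU."""
    return chunk.device_type in _ACCELERATOR_DEVICE_TYPES

With such a predicate both call sites would read `if chunk_on_accelerator(chunk32): ...`, keeping any future backend additions in one place.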
@@ -479,15 +477,19 @@ class GeminiOptimizer(OptimizerWrapper):
                     state_tensor = states[state_name].detach().clone().to(torch.float32).cpu()
                     if is_dtensor:
                         state_tensor = torch.reshape(state_tensor, param.shape).to(param.device)
-                        state_tensor = init_as_dtensor(state_tensor,
-                                                       device_mesh=device_mesh,
-                                                       sharding_spec=shard_spec,
-                                                       global_shape = global_shape)
+                        state_tensor = init_as_dtensor(
+                            state_tensor,
+                            device_mesh=device_mesh,
+                            sharding_spec=shard_spec,
+                            global_shape=global_shape,
+                        )
                     elif is_customized_distributed:
                         state_tensor = torch.reshape(state_tensor, param.shape).to(param.device)
-                        init_tensor_as_customization_distributed(state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn)
+                        init_tensor_as_customization_distributed(
+                            state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn
+                        )
                     state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu()
 
 
                     collected_states[state_name] = state_tensor.reshape(global_shape)
             return collected_states
@@ -533,13 +535,14 @@ class GeminiOptimizer(OptimizerWrapper):
                 collected_states[state_name] = torch.reshape(state_tensor, param.shape)
                 if is_dtensor:
                     state_tensor = state_tensor.to(param.device)
-                    state_tensor = init_as_dtensor(state_tensor,
-                                                   sharding_spec=shard_spec,
-                                                   device_mesh=device_mesh,
-                                                   global_shape=global_shape)
+                    state_tensor = init_as_dtensor(
+                        state_tensor, sharding_spec=shard_spec, device_mesh=device_mesh, global_shape=global_shape
+                    )
                 elif is_customized_distributed:
                     state_tensor = state_tensor.to(param.device)
-                    init_tensor_as_customization_distributed(state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn)
+                    init_tensor_as_customization_distributed(
+                        state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn
+                    )
                 state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu()
 
         return collected_states
@@ -548,7 +551,7 @@ class GeminiOptimizer(OptimizerWrapper):
         self,
         param_id: int,
         state_names: list,
-        device: torch.device = torch.device("cuda"),
+        device: torch.device = get_current_device(),
         dtype: torch.dtype = torch.float32,
     ) -> torch.Tensor:
         """
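The hunk above swaps a hard-coded torch.device("cuda") default for get_current_device(), so packed optimizer-state tensors land on whichever accelerator the process is actually using. One caveat worth keeping in mind, sketched below with a hypothetical helper rather than the real ColossalAI one: a Python default argument is evaluated once, when the function is defined, so the helper must already know the right backend at import time, or the default can be deferred with a None sentinel.

from typing import Optional

import torch


def pick_device() -> torch.device:
    # Hypothetical stand-in for a get_current_device()-style helper.
    return torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


# Default evaluated once, at function definition time.
def pack_eager(device: torch.device = pick_device()) -> torch.device:
    return device


# Sentinel default: the device is resolved on every call instead, which
# matters if the preferred backend is only known after distributed init.
def pack_lazy(device: Optional[torch.device] = None) -> torch.device:
    return device if device is not None else pick_device()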
@@ -705,7 +708,7 @@ class GeminiOptimizer(OptimizerWrapper):
                     ret_val = torch.zeros(
                         state_end - state_start, dtype=torch.float32, device=param.device, requires_grad=False
                     )
 
                     if is_dtensor:
                         value = torch.reshape(value, global_shape)
                         value = distribute_tensor(value, sharding_spec=shard_spec, device_mesh=device_mesh)