[npu] add npu support for gemini and zero (#5067)

* [npu] setup device utils (#5047)

* [npu] add npu device support

* [npu] support low level zero

* [test] update npu zero plugin test

* [hotfix] fix import

* [test] recover tests

* [npu] gemini support npu (#5052)

* [npu] refactor device utils

* [gemini] support npu

* [example] llama2+gemini support npu

* [kernel] add arm cpu adam kernel (#5065)

* [kernel] add arm cpu adam

* [optim] update adam optimizer

* [kernel] arm cpu adam remove bf16 support
Author: Hongxin Liu
Date: 2023-11-20 16:12:41 +08:00 (committed by GitHub)
Parent: 8d56c9c389
Commit: e5ce4c8ea6
46 changed files with 994 additions and 233 deletions


@@ -10,14 +10,24 @@ import torch.nn as nn
 from torch.distributed import ProcessGroup
 from torch.distributed.distributed_c10d import _get_default_group
-from colossalai.checkpoint_io.utils import StateDictSharder
+from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param
 from colossalai.interface import ModelWrapper
 from colossalai.lazy import LazyTensor
 from colossalai.logging import get_dist_logger
 from colossalai.tensor.colo_parameter import ColoParameter
+from colossalai.tensor.d_tensor import (
+    distribute_tensor,
+    distribute_tensor_with_customization,
+    get_device_mesh,
+    get_global_shape,
+    get_sharding_spec,
+    init_as_dtensor,
+    init_tensor_as_customization_distributed,
+    is_customized_distributed_tensor,
+    is_distributed_tensor,
+)
 from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.utils import _cast_float, free_storage, get_current_device, is_ddp_ignored
-from colossalai.checkpoint_io.utils import gather_distributed_param
 from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager
 from .gemini_hook import GeminiZeROHook
@@ -25,18 +35,6 @@ from .gemini_mgr import GeminiManager
 from .memory_tracer import MemStats, OrderedParamGenerator
 from .utils import get_temp_total_chunk_on_cuda
-from colossalai.tensor.d_tensor import (
-    distribute_tensor,
-    distribute_tensor_with_customization,
-    init_tensor_as_customization_distributed,
-    get_device_mesh,
-    get_sharding_spec,
-    is_customized_distributed_tensor,
-    is_distributed_tensor,
-    get_global_shape,
-    init_as_dtensor
-)
 try:
     from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
 except ImportError:
@@ -162,7 +160,7 @@ class GeminiDDP(ModelWrapper):
         self._init_chunks(
             param_order=param_order,
             strict_ddp_mode=strict_ddp_mode,
-            cpu_offload=self.gemini_manager.policy_name != "cuda",
+            cpu_offload=not (self.gemini_manager.policy_name == "static" and offload_param_frac == 0),
            pin_memory=pin_memory,
         )
         super().__init__(module)
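
The hunk above reworks the CPU-offload switch: instead of keying offload off the literal policy name "cuda", parameters now stay resident only for the "static" placement policy with an offload fraction of zero, which keeps the check device-agnostic. A minimal sketch of the new predicate, using a hypothetical standalone function name for illustration:

    # Hypothetical helper mirroring the new cpu_offload predicate; not library code.
    def should_offload_params(policy_name: str, offload_param_frac: float) -> bool:
        # Offload unless the placement policy is "static" AND the user asked to
        # keep every parameter on the accelerator (offload fraction of zero).
        return not (policy_name == "static" and offload_param_frac == 0)

    assert not should_offload_params("static", 0.0)  # params stay on the device
    assert should_offload_params("static", 0.5)      # partial offload requested
    assert should_offload_params("auto", 0.0)        # dynamic policy may offload
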
@@ -453,12 +451,13 @@ class GeminiDDP(ModelWrapper):
                     global_shape = get_global_shape(tensor)
                     device_mesh = get_device_mesh(tensor)
                     shard_spec = get_sharding_spec(tensor)
-                    record_tensor = init_as_dtensor(record_tensor,
-                                                    device_mesh=device_mesh,
-                                                    sharding_spec=shard_spec,
-                                                    global_shape = global_shape)
+                    record_tensor = init_as_dtensor(
+                        record_tensor, device_mesh=device_mesh, sharding_spec=shard_spec, global_shape=global_shape
+                    )
                 elif is_customized_distributed_tensor(tensor):
-                    init_tensor_as_customization_distributed(record_tensor, shard_fn=tensor.shard_fn, gather_fn=tensor.gather_fn)
+                    init_tensor_as_customization_distributed(
+                        record_tensor, shard_fn=tensor.shard_fn, gather_fn=tensor.gather_fn
+                    )
                 record_tensor = gather_distributed_param(record_tensor, keep_vars=False).cpu()

             assert tensor not in chunk_to_save_data
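
On the save path, a recorded parameter that is a distributed tensor is first re-initialised with its device mesh, sharding spec, and global shape (or with its custom shard/gather functions), and then gathered into a full CPU tensor through gather_distributed_param. A rough sketch of that flow, reusing the call signatures visible in the diff; the helper name is hypothetical and the chunk bookkeeping is omitted:

    # Sketch of the save-side flow; signatures follow the diff above.
    from colossalai.checkpoint_io.utils import gather_distributed_param
    from colossalai.tensor.d_tensor import (
        get_device_mesh,
        get_global_shape,
        get_sharding_spec,
        init_as_dtensor,
        init_tensor_as_customization_distributed,
        is_customized_distributed_tensor,
        is_distributed_tensor,
    )

    def gather_record_for_checkpoint(tensor, record_tensor):
        """Rebuild sharding metadata on the local slice, then gather it to CPU."""
        if is_distributed_tensor(tensor):
            record_tensor = init_as_dtensor(
                record_tensor,
                device_mesh=get_device_mesh(tensor),
                sharding_spec=get_sharding_spec(tensor),
                global_shape=get_global_shape(tensor),
            )
        elif is_customized_distributed_tensor(tensor):
            init_tensor_as_customization_distributed(
                record_tensor, shard_fn=tensor.shard_fn, gather_fn=tensor.gather_fn
            )
        # Collect the full parameter and move it off the accelerator.
        return gather_distributed_param(record_tensor, keep_vars=False).cpu()
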
@@ -634,7 +633,15 @@ class GeminiDDP(ModelWrapper):
         local_name_params = itertools.chain(self.named_parameters(), persistent_buffers.items())
         local_state = {k: v for k, v in local_name_params if v is not None}

-        def load(param_name, dest_tensor, copy_func, source_device_mesh=None, source_sharding_spec=None, shard_fn=None, gather_fn=None):
+        def load(
+            param_name,
+            dest_tensor,
+            copy_func,
+            source_device_mesh=None,
+            source_sharding_spec=None,
+            shard_fn=None,
+            gather_fn=None,
+        ):
             state_key = prefix + param_name
             if state_key in state_dict:
                 input_param = state_dict[state_key]
@@ -642,7 +649,9 @@ class GeminiDDP(ModelWrapper):
                 if source_device_mesh is not None and source_sharding_spec is not None:
                     input_param = distribute_tensor(input_param, source_device_mesh, source_sharding_spec)
                 elif shard_fn is not None and gather_fn is not None:
-                    input_param = distribute_tensor_with_customization(input_param, shard_fn=shard_fn, gather_fn=gather_fn)
+                    input_param = distribute_tensor_with_customization(
+                        input_param, shard_fn=shard_fn, gather_fn=gather_fn
+                    )

                 # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
                 if len(dest_tensor.shape) == 0 and len(input_param.shape) == 1:
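
The load path mirrors the save path: the full tensor read from the checkpoint is re-sharded with the destination parameter's device mesh and sharding spec, or with its custom shard/gather functions, before being copied into the chunk. A condensed sketch with a hypothetical helper name; the real logic lives in the load closure above:

    # Sketch of the load-side redistribution shown in the diff.
    from colossalai.tensor.d_tensor import distribute_tensor, distribute_tensor_with_customization

    def reshard_loaded_param(
        input_param, source_device_mesh=None, source_sharding_spec=None, shard_fn=None, gather_fn=None
    ):
        if source_device_mesh is not None and source_sharding_spec is not None:
            # Regular DTensor: shard the full checkpoint tensor for this rank.
            return distribute_tensor(input_param, source_device_mesh, source_sharding_spec)
        if shard_fn is not None and gather_fn is not None:
            # Customised layout: defer to the parameter's own shard/gather functions.
            return distribute_tensor_with_customization(input_param, shard_fn=shard_fn, gather_fn=gather_fn)
        return input_param  # plain replicated tensor, nothing to redistribute
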
@@ -687,7 +696,6 @@ class GeminiDDP(ModelWrapper):
             temp_chunk = get_temp_total_chunk_on_cuda(chunk, self.mixed_precision)

             for tensor, tensor_info in chunk.tensors_info.items():
                 source_device_mesh, source_sharding_spec, shard_fn, gather_fn = None, None, None, None
-
                 if is_distributed_tensor(tensor):
                     # shard the input param
@@ -699,7 +707,15 @@ class GeminiDDP(ModelWrapper):
                 parameter_name = fp32_to_name[tensor] if self.reuse_fp16_chunk else self.param2name[tensor]
                 parameter_slice = temp_chunk[tensor_info.offset : tensor_info.end]
-                load(parameter_name, tensor, partial(load_parameter, parameter_slice), source_device_mesh, source_sharding_spec, shard_fn, gather_fn)
+                load(
+                    parameter_name,
+                    tensor,
+                    partial(load_parameter, parameter_slice),
+                    source_device_mesh,
+                    source_sharding_spec,
+                    shard_fn,
+                    gather_fn,
+                )

             if chunk.is_gathered:
                 chunk.cuda_global_chunk.copy_(temp_chunk)
@@ -799,7 +815,7 @@ class GeminiDDP(ModelWrapper):
         for buffer in self.module.buffers():
             if isinstance(buffer, LazyTensor):
                 buffer.materialize()
-            buffer.data = buffer.cuda()
+            buffer.data = buffer.to(get_current_device())
             if torch.is_floating_point(buffer):
                 buffer.data = buffer.to(self.mixed_precision)
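
Here buffer.cuda() becomes buffer.to(get_current_device()), so lazily materialised buffers land on whichever accelerator the process owns (CUDA GPU or Ascend NPU) instead of assuming CUDA. A minimal sketch of what such a device helper can look like, assuming the Ascend torch_npu extension is installed; the actual helper is colossalai.utils.get_current_device and its implementation may differ:

    # Hedged sketch of a device-agnostic helper, not the library's real code.
    import torch

    def get_current_device() -> torch.device:
        try:
            import torch_npu  # noqa: F401  -- registers the torch.npu namespace
            if torch.npu.is_available():
                return torch.device(f"npu:{torch.npu.current_device()}")
        except ImportError:
            pass
        if torch.cuda.is_available():
            return torch.device(f"cuda:{torch.cuda.current_device()}")
        return torch.device("cpu")

    # Usage matching the hunk above:
    #     buffer.data = buffer.to(get_current_device())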