Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-06 11:32:10 +00:00
[npu] add npu support for gemini and zero (#5067)
* [npu] setup device utils (#5047)
* [npu] add npu device support
* [npu] support low level zero
* [test] update npu zero plugin test
* [hotfix] fix import
* [test] recover tests
* [npu] gemini support npu (#5052)
* [npu] refactor device utils
* [gemini] support npu
* [example] llama2+gemini support npu
* [kernel] add arm cpu adam kernel (#5065)
* [kernel] add arm cpu adam
* [optim] update adam optimizer
* [kernel] arm cpu adam remove bf16 support
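The core pattern of the NPU work in the hunks below is replacing hard-coded `.cuda()` calls with a device query (`get_current_device` from `colossalai.utils`, as used in the last hunk), so the same code path runs on CUDA or Ascend NPU. A minimal illustrative sketch of such a query, not ColossalAI's actual implementation (`current_accelerator_device` is a hypothetical name):

import torch

def current_accelerator_device() -> torch.device:
    # Prefer CUDA if present.
    if torch.cuda.is_available():
        return torch.device(f"cuda:{torch.cuda.current_device()}")
    # torch.npu is provided by the torch_npu extension, if installed.
    npu = getattr(torch, "npu", None)
    if npu is not None and npu.is_available():
        return torch.device(f"npu:{npu.current_device()}")
    return torch.device("cpu")

# Usage: instead of `buffer.data = buffer.cuda()`, move to whichever accelerator is active.
buffer = torch.zeros(4)
buffer = buffer.to(current_accelerator_device())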
@@ -10,14 +10,24 @@ import torch.nn as nn
 from torch.distributed import ProcessGroup
 from torch.distributed.distributed_c10d import _get_default_group

-from colossalai.checkpoint_io.utils import StateDictSharder
+from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param
 from colossalai.interface import ModelWrapper
 from colossalai.lazy import LazyTensor
 from colossalai.logging import get_dist_logger
 from colossalai.tensor.colo_parameter import ColoParameter
+from colossalai.tensor.d_tensor import (
+    distribute_tensor,
+    distribute_tensor_with_customization,
+    get_device_mesh,
+    get_global_shape,
+    get_sharding_spec,
+    init_as_dtensor,
+    init_tensor_as_customization_distributed,
+    is_customized_distributed_tensor,
+    is_distributed_tensor,
+)
 from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.utils import _cast_float, free_storage, get_current_device, is_ddp_ignored
-from colossalai.checkpoint_io.utils import gather_distributed_param

 from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager
 from .gemini_hook import GeminiZeROHook
@@ -25,18 +35,6 @@ from .gemini_mgr import GeminiManager
 from .memory_tracer import MemStats, OrderedParamGenerator
 from .utils import get_temp_total_chunk_on_cuda

-from colossalai.tensor.d_tensor import (
-    distribute_tensor,
-    distribute_tensor_with_customization,
-    init_tensor_as_customization_distributed,
-    get_device_mesh,
-    get_sharding_spec,
-    is_customized_distributed_tensor,
-    is_distributed_tensor,
-    get_global_shape,
-    init_as_dtensor
-)
-
 try:
     from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
 except ImportError:
@@ -162,7 +160,7 @@ class GeminiDDP(ModelWrapper):
         self._init_chunks(
             param_order=param_order,
             strict_ddp_mode=strict_ddp_mode,
-            cpu_offload=self.gemini_manager.policy_name != "cuda",
+            cpu_offload=not (self.gemini_manager.policy_name == "static" and offload_param_frac == 0),
             pin_memory=pin_memory,
         )
         super().__init__(module)
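The hunk above changes when chunk CPU offload is enabled: previously it was tied to the placement policy not being "cuda"; now offload is skipped only for a "static" placement that keeps no parameters on CPU. A minimal sketch of the new predicate, reusing the names that appear in the hunk (`policy_name`, `offload_param_frac`) and assuming nothing beyond the expression itself:

def needs_cpu_offload(policy_name: str, offload_param_frac: float) -> bool:
    # Offload is disabled only for a "static" placement with 0% of params offloaded;
    # every other configuration (e.g. "auto", or "static" with a nonzero fraction)
    # allocates chunks with CPU offload enabled.
    return not (policy_name == "static" and offload_param_frac == 0)

assert needs_cpu_offload("auto", 0.0) is True
assert needs_cpu_offload("static", 0.5) is True
assert needs_cpu_offload("static", 0.0) is False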
@@ -453,12 +451,13 @@ class GeminiDDP(ModelWrapper):
|
||||
global_shape = get_global_shape(tensor)
|
||||
device_mesh = get_device_mesh(tensor)
|
||||
shard_spec = get_sharding_spec(tensor)
|
||||
record_tensor = init_as_dtensor(record_tensor,
|
||||
device_mesh=device_mesh,
|
||||
sharding_spec=shard_spec,
|
||||
global_shape = global_shape)
|
||||
record_tensor = init_as_dtensor(
|
||||
record_tensor, device_mesh=device_mesh, sharding_spec=shard_spec, global_shape=global_shape
|
||||
)
|
||||
elif is_customized_distributed_tensor(tensor):
|
||||
init_tensor_as_customization_distributed(record_tensor, shard_fn=tensor.shard_fn, gather_fn=tensor.gather_fn)
|
||||
init_tensor_as_customization_distributed(
|
||||
record_tensor, shard_fn=tensor.shard_fn, gather_fn=tensor.gather_fn
|
||||
)
|
||||
record_tensor = gather_distributed_param(record_tensor, keep_vars=False).cpu()
|
||||
|
||||
assert tensor not in chunk_to_save_data
|
||||
@@ -634,7 +633,15 @@ class GeminiDDP(ModelWrapper):
         local_name_params = itertools.chain(self.named_parameters(), persistent_buffers.items())
         local_state = {k: v for k, v in local_name_params if v is not None}

-        def load(param_name, dest_tensor, copy_func, source_device_mesh=None, source_sharding_spec=None, shard_fn=None, gather_fn=None):
+        def load(
+            param_name,
+            dest_tensor,
+            copy_func,
+            source_device_mesh=None,
+            source_sharding_spec=None,
+            shard_fn=None,
+            gather_fn=None,
+        ):
             state_key = prefix + param_name
             if state_key in state_dict:
                 input_param = state_dict[state_key]
@@ -642,7 +649,9 @@ class GeminiDDP(ModelWrapper):
                 if source_device_mesh is not None and source_sharding_spec is not None:
                     input_param = distribute_tensor(input_param, source_device_mesh, source_sharding_spec)
                 elif shard_fn is not None and gather_fn is not None:
-                    input_param = distribute_tensor_with_customization(input_param, shard_fn=shard_fn, gather_fn=gather_fn)
+                    input_param = distribute_tensor_with_customization(
+                        input_param, shard_fn=shard_fn, gather_fn=gather_fn
+                    )

                 # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
                 if len(dest_tensor.shape) == 0 and len(input_param.shape) == 1:
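For reference, the dispatch inside `load(...)` above can be read as a standalone helper: a regular distributed tensor is re-sharded from the checkpoint's full tensor via its recorded device mesh and sharding spec, while a customized distribution goes through user-provided shard/gather callables. A minimal sketch with a hypothetical helper name, using only the d_tensor functions imported at the top of this file:

from colossalai.tensor.d_tensor import distribute_tensor, distribute_tensor_with_customization

def reshard_loaded_tensor(input_param, source_device_mesh=None, source_sharding_spec=None,
                          shard_fn=None, gather_fn=None):
    # Regular DTensor: shard the full checkpoint tensor to the recorded layout.
    if source_device_mesh is not None and source_sharding_spec is not None:
        return distribute_tensor(input_param, source_device_mesh, source_sharding_spec)
    # Customized distribution: user-provided shard/gather callables.
    if shard_fn is not None and gather_fn is not None:
        return distribute_tensor_with_customization(input_param, shard_fn=shard_fn, gather_fn=gather_fn)
    # Plain tensor: nothing to re-shard.
    return input_param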
@@ -687,7 +696,6 @@ class GeminiDDP(ModelWrapper):
         temp_chunk = get_temp_total_chunk_on_cuda(chunk, self.mixed_precision)

         for tensor, tensor_info in chunk.tensors_info.items():
-
             source_device_mesh, source_sharding_spec, shard_fn, gather_fn = None, None, None, None
             if is_distributed_tensor(tensor):
                 # shard the input param
@@ -699,7 +707,15 @@ class GeminiDDP(ModelWrapper):

             parameter_name = fp32_to_name[tensor] if self.reuse_fp16_chunk else self.param2name[tensor]
             parameter_slice = temp_chunk[tensor_info.offset : tensor_info.end]
-            load(parameter_name, tensor, partial(load_parameter, parameter_slice), source_device_mesh, source_sharding_spec, shard_fn, gather_fn)
+            load(
+                parameter_name,
+                tensor,
+                partial(load_parameter, parameter_slice),
+                source_device_mesh,
+                source_sharding_spec,
+                shard_fn,
+                gather_fn,
+            )

         if chunk.is_gathered:
             chunk.cuda_global_chunk.copy_(temp_chunk)
@@ -799,7 +815,7 @@ class GeminiDDP(ModelWrapper):
         for buffer in self.module.buffers():
             if isinstance(buffer, LazyTensor):
                 buffer.materialize()
-            buffer.data = buffer.cuda()
+            buffer.data = buffer.to(get_current_device())
             if torch.is_floating_point(buffer):
                 buffer.data = buffer.to(self.mixed_precision)
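The last hunk applies the same device-abstraction pattern to module buffers. A standalone sketch of just that pattern in plain PyTorch, omitting ColossalAI's LazyTensor materialization (the helper name and defaults are assumptions):

import torch
import torch.nn as nn

def cast_buffers(module: nn.Module, device: torch.device, mixed_precision: torch.dtype = torch.float16) -> None:
    # Move every buffer to the target accelerator, then cast floating-point buffers
    # (e.g. BatchNorm running stats) to the mixed-precision dtype; integer buffers are left alone.
    for buffer in module.buffers():
        buffer.data = buffer.to(device)
        if torch.is_floating_point(buffer):
            buffer.data = buffer.to(mixed_precision)

# Example: run on CPU here; on a GPU/NPU box the device would come from a query
# such as the get_current_device() call in the hunk above.
model = nn.BatchNorm1d(8)
cast_buffers(model, torch.device("cpu"))
assert model.running_mean.dtype == torch.float16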