mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 09:07:51 +00:00
[Device]Support npu (#6159)
* support npu * support pretrain support pretrain fix * support lora fix fix * support chatglm fix fxi fix [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci fix fix [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci fix [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci fix fix fix * Update train.py * Update train.py * fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -509,9 +509,9 @@ class LazyInitContext:
|
||||
# factory_like functions (eg. torch.empty_like())
|
||||
def wrapper(*args, **kwargs):
|
||||
orig_t = args[0]
|
||||
return self.tensor_cls(
|
||||
orig_target, *orig_t.shape, *args[1:], device=orig_t.device, dtype=orig_t.dtype, **kwargs
|
||||
)
|
||||
device = kwargs.pop("device", orig_t.device)
|
||||
dtype = kwargs.pop("dtype", orig_t.dtype)
|
||||
return self.tensor_cls(orig_target, *orig_t.shape, *args[1:], device=device, dtype=dtype, **kwargs)
|
||||
|
||||
return wrapper, target
|
||||
|
||||
|
@@ -171,7 +171,7 @@ def _communicate(
|
||||
for req in reqs:
|
||||
req.wait()
|
||||
# To protect against race condition when using batch_isend_irecv().
|
||||
torch.cuda.synchronize()
|
||||
get_accelerator().synchronize()
|
||||
|
||||
if recv_prev and recv_prev_split:
|
||||
if isinstance(tensor_recv_prev, torch.Tensor):
|
||||
|
@@ -14,6 +14,8 @@ from torch.distributed import ProcessGroup
|
||||
from torch.distributed import distributed_c10d as c10d
|
||||
from torch.utils._pytree import tree_flatten, tree_unflatten
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
|
||||
from .stage_manager import PipelineStageManager
|
||||
|
||||
|
||||
@@ -31,7 +33,7 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -
|
||||
buf = tensor.numpy().tobytes()[:tensor_size]
|
||||
if b"cuda" in buf:
|
||||
buf_array = bytearray(buf)
|
||||
device_index = torch.cuda.current_device()
|
||||
device_index = get_accelerator().current_device()
|
||||
# There might be more than one output tensors during forward
|
||||
for cuda_str in re.finditer(b"cuda", buf_array):
|
||||
pos = cuda_str.start()
|
||||
@@ -86,7 +88,7 @@ def _broadcast_object_list(
|
||||
else:
|
||||
current_device = torch.device("cpu")
|
||||
if is_nccl_backend:
|
||||
current_device = torch.device("cuda", torch.cuda.current_device())
|
||||
current_device = torch.device("cuda", get_accelerator().current_device())
|
||||
|
||||
my_rank = dist.get_rank()
|
||||
# Serialize object_list elements to tensors on src rank.
|
||||
@@ -139,14 +141,14 @@ def _broadcast_object_list(
|
||||
# unconsistence in device
|
||||
if (
|
||||
isinstance(unpickle_object, torch.Tensor)
|
||||
and unpickle_object.device.index != torch.cuda.current_device()
|
||||
and unpickle_object.device.index != get_accelerator().current_device()
|
||||
):
|
||||
unpickle_object = unpickle_object.cuda()
|
||||
unpickle_object = unpickle_object.to(get_accelerator().current_device())
|
||||
|
||||
object_list[i] = unpickle_object
|
||||
|
||||
|
||||
def _check_for_nccl_backend(group):
|
||||
def _check_for_nccl_hccl_backend(group):
|
||||
pg = group or c10d._get_default_group()
|
||||
# Gate PG wrapper check on Gloo availability.
|
||||
if c10d._GLOO_AVAILABLE:
|
||||
@@ -154,14 +156,14 @@ def _check_for_nccl_backend(group):
|
||||
while isinstance(pg, c10d._ProcessGroupWrapper):
|
||||
pg = pg.wrapped_pg
|
||||
|
||||
return c10d.is_nccl_available() and pg.name() == c10d.Backend.NCCL
|
||||
return (c10d.is_nccl_available() or torch.distributed.is_hccl_available()) and pg.name() == c10d.Backend.NCCL
|
||||
|
||||
|
||||
def _check_device(group):
|
||||
is_nccl_backend = _check_for_nccl_backend(group)
|
||||
is_nccl_backend = _check_for_nccl_hccl_backend(group)
|
||||
current_device = torch.device("cpu")
|
||||
if is_nccl_backend:
|
||||
current_device = torch.device("cuda", torch.cuda.current_device())
|
||||
current_device = torch.device(get_accelerator().current_device())
|
||||
return current_device, is_nccl_backend
|
||||
|
||||
|
||||
@@ -348,8 +350,11 @@ def _send_recv_serialization_object(
|
||||
|
||||
unpickle_object = _cuda_safe_tensor_to_object(recv_object_tensor, recv_object_size_tensor.item())
|
||||
|
||||
if isinstance(unpickle_object, torch.Tensor) and unpickle_object.device.index != torch.cuda.current_device():
|
||||
unpickle_object = unpickle_object.cuda()
|
||||
if (
|
||||
isinstance(unpickle_object, torch.Tensor)
|
||||
and unpickle_object.device.index != get_accelerator().current_device()
|
||||
):
|
||||
unpickle_object = unpickle_object.to(get_accelerator().current_device())
|
||||
|
||||
return unpickle_object
|
||||
|
||||
@@ -474,9 +479,11 @@ def _p2p_comm(
|
||||
recv_prev_shape = None
|
||||
|
||||
if tensor_send_next is not None:
|
||||
send_next_shape = torch.tensor(tensor_send_next.size(), device=torch.cuda.current_device(), dtype=torch.int64)
|
||||
send_next_shape = torch.tensor(
|
||||
tensor_send_next.size(), device=get_accelerator().current_device(), dtype=torch.int64
|
||||
)
|
||||
if recv_prev:
|
||||
recv_prev_shape = torch.empty((3), device=torch.cuda.current_device(), dtype=torch.int64)
|
||||
recv_prev_shape = torch.empty((3), device=get_accelerator().current_device(), dtype=torch.int64)
|
||||
|
||||
ops = []
|
||||
if send_next_shape is not None:
|
||||
@@ -501,7 +508,7 @@ def _p2p_comm(
|
||||
# send and recv data
|
||||
tensor_recv_prev = None
|
||||
if recv_prev:
|
||||
tensor_recv_prev = torch.empty(recv_prev_shape, device=torch.cuda.current_device(), dtype=comm_dtype)
|
||||
tensor_recv_prev = torch.empty(recv_prev_shape, device=get_accelerator().current_device(), dtype=comm_dtype)
|
||||
|
||||
ops = []
|
||||
if tensor_send_next is not None:
|
||||
|
@@ -2,7 +2,6 @@ from functools import partial
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.cuda
|
||||
import torch.distributed
|
||||
from torch.nn import Module, ModuleList
|
||||
from torch.utils._pytree import tree_map
|
||||
@@ -18,7 +17,7 @@ from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_
|
||||
from .base import PipelineSchedule
|
||||
|
||||
|
||||
def _wait_p2p(wait_handles: List[torch.cuda.Event]) -> None:
|
||||
def _wait_p2p(wait_handles) -> None:
|
||||
if wait_handles is not None:
|
||||
for req in wait_handles:
|
||||
req.wait()
|
||||
|
@@ -2,7 +2,6 @@ from functools import partial
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.cuda
|
||||
from torch.nn import Module
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
|
@@ -1,15 +1,28 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
import numbers
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import init
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from colossalai.lazy import LazyInitContext
|
||||
|
||||
from ._operation import hook_parameter_in_backward
|
||||
from .utils import SeqParallelUtils
|
||||
|
||||
SUPPORT_NPU = False
|
||||
try:
|
||||
import torch_npu
|
||||
|
||||
SUPPORT_NPU = True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
__all__ = ["FusedLayerNorm", "FusedRMSNorm", "LayerNorm", "RMSNorm", "BaseLayerNorm"]
|
||||
|
||||
try:
|
||||
@@ -21,7 +34,6 @@ except ImportError:
|
||||
|
||||
try:
|
||||
from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm
|
||||
from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
|
||||
|
||||
class FusedLayerNormWithHook(ApexFusedLayerNorm):
|
||||
def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
|
||||
@@ -32,7 +44,41 @@ try:
|
||||
output = hook_parameter_in_backward(output, self.weight, self.bias)
|
||||
return output
|
||||
|
||||
class FusedRMSNormWithHook(ApexFusedRMSNorm):
|
||||
except ImportError:
|
||||
warnings.warn("Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMSNorm kernel")
|
||||
|
||||
FusedRMSNormWithHook = None
|
||||
if SUPPORT_NPU:
|
||||
|
||||
class NPUFusedRMSNormWithHook(nn.Module):
|
||||
def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
|
||||
super().__init__()
|
||||
if isinstance(normalized_shape, numbers.Integral):
|
||||
normalized_shape = (normalized_shape,)
|
||||
self.normalized_shape = torch.Size(normalized_shape)
|
||||
self.eps = eps
|
||||
self.elementwise_affine = elementwise_affine
|
||||
if self.elementwise_affine:
|
||||
self.weight = Parameter(torch.empty(*normalized_shape))
|
||||
else:
|
||||
self.register_parameter("weight", None)
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
if self.elementwise_affine:
|
||||
init.ones_(self.weight)
|
||||
|
||||
def forward(self, input):
|
||||
|
||||
output, _ = torch_npu.npu_rms_norm(input, self.weight, self.eps)
|
||||
output = hook_parameter_in_backward(output, self.weight)
|
||||
return output
|
||||
|
||||
FusedRMSNormWithHook = NPUFusedRMSNormWithHook
|
||||
else:
|
||||
from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
|
||||
|
||||
class CUDAFusedRMSNormWithHook(ApexFusedRMSNorm):
|
||||
def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
|
||||
super().__init__(normalized_shape, eps, elementwise_affine)
|
||||
|
||||
@@ -41,8 +87,7 @@ try:
|
||||
output = hook_parameter_in_backward(output, self.weight)
|
||||
return output
|
||||
|
||||
except ImportError:
|
||||
warnings.warn("Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMSNorm kernel")
|
||||
FusedRMSNormWithHook = CUDAFusedRMSNormWithHook
|
||||
|
||||
FAST_LAYERNORM_SUPPORTED_SIZE = [
|
||||
1024,
|
||||
|
@@ -9,7 +9,7 @@ from transformers.utils import logging
|
||||
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer import ShardConfig
|
||||
from colossalai.shardformer.layer import AttnMaskType, ColoAttention
|
||||
from colossalai.shardformer.layer import ColoAttention
|
||||
from colossalai.shardformer.layer._operation import (
|
||||
all_to_all_comm,
|
||||
gather_sp_output,
|
||||
@@ -25,42 +25,7 @@ def get_flash_core_attention_forward():
|
||||
|
||||
def forward(self: CoreAttention, query_layer, key_layer, value_layer, attention_mask):
|
||||
query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
|
||||
if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
|
||||
attention_mask_type = AttnMaskType.CAUSAL
|
||||
attn_bias = torch.zeros(
|
||||
query_layer.shape[0],
|
||||
1,
|
||||
query_layer.shape[2],
|
||||
key_layer.shape[2],
|
||||
dtype=query_layer.dtype,
|
||||
device=query_layer.device,
|
||||
)
|
||||
temp_mask = (
|
||||
torch.ones(
|
||||
query_layer.shape[2],
|
||||
key_layer.shape[2],
|
||||
dtype=torch.bool,
|
||||
device=query_layer.device,
|
||||
)
|
||||
.tril(diagonal=0)
|
||||
.expand(query_layer.shape[0], 1, -1, -1)
|
||||
)
|
||||
attn_bias.masked_fill_(temp_mask.logical_not(), torch.finfo(query_layer.dtype).min)
|
||||
else:
|
||||
attention_mask_type = AttnMaskType.CUSTOM
|
||||
if attention_mask is not None:
|
||||
attn_bias = torch.zeros_like(attention_mask, dtype=query_layer.dtype)
|
||||
attn_bias.masked_fill_(attention_mask, torch.finfo(query_layer.dtype).min)
|
||||
dropout_p = self.attention_dropout.p if self.training else 0.0
|
||||
context_layer = ColoAttention.attention(
|
||||
query_layer,
|
||||
key_layer,
|
||||
value_layer,
|
||||
attention_mask=attn_bias,
|
||||
attention_mask_type=attention_mask_type,
|
||||
dropout_p=dropout_p,
|
||||
scale=1.0 / self.norm_factor,
|
||||
)
|
||||
context_layer = ColoAttention.attention(query_layer, key_layer, value_layer, **attention_mask)
|
||||
context_layer = context_layer.permute(2, 0, 1, 3)
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
|
||||
context_layer = context_layer.reshape(*new_context_layer_shape)
|
||||
@@ -180,9 +145,20 @@ class ChatGLMPipelineForwards:
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
if full_attention_mask is None:
|
||||
if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
|
||||
full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
|
||||
|
||||
if shard_config.enable_flash_attention:
|
||||
mask_shape = (batch_size, 1, seq_length, seq_length)
|
||||
full_attention_mask: dict = ColoAttention.prepare_attn_kwargs(
|
||||
mask_shape,
|
||||
hidden_states.dtype,
|
||||
hidden_states.device,
|
||||
q_padding_mask=attention_mask,
|
||||
is_causal=True,
|
||||
)
|
||||
else:
|
||||
if full_attention_mask is None:
|
||||
if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
|
||||
full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
|
||||
|
||||
# Support SP + PP
|
||||
sp_size = shard_config.sequence_parallel_size
|
||||
@@ -237,7 +213,7 @@ class ChatGLMPipelineForwards:
|
||||
layer_ret = torch.utils.checkpoint.checkpoint(
|
||||
layer,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
full_attention_mask,
|
||||
rotary_pos_emb,
|
||||
past_key_values[idx],
|
||||
use_cache,
|
||||
@@ -402,10 +378,19 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
if full_attention_mask is None:
|
||||
if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
|
||||
full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
|
||||
if shard_config.enable_flash_attention:
|
||||
mask_shape = (batch_size, 1, seq_length, seq_length)
|
||||
full_attention_mask: dict = ColoAttention.prepare_attn_kwargs(
|
||||
mask_shape,
|
||||
hidden_states.dtype,
|
||||
hidden_states.device,
|
||||
q_padding_mask=attention_mask,
|
||||
is_causal=True,
|
||||
)
|
||||
else:
|
||||
if full_attention_mask is None:
|
||||
if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
|
||||
full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
|
||||
|
||||
# Rotary positional embeddings
|
||||
rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
|
||||
@@ -652,3 +637,79 @@ def get_chatglm_sequence_parallel_attention_forward(shard_config: ShardConfig, s
|
||||
return output, kv_cache
|
||||
|
||||
return forward
|
||||
|
||||
|
||||
def get_flash_attention_forward_for_chat_glm_model():
|
||||
from .chatglm2_6b.modeling_chatglm import ChatGLMModel
|
||||
|
||||
def forward(
|
||||
self: ChatGLMModel,
|
||||
input_ids,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.BoolTensor] = None,
|
||||
full_attention_mask: Optional[torch.BoolTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
):
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
batch_size, seq_length = input_ids.shape
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embedding(input_ids)
|
||||
|
||||
if self.pre_seq_len is not None:
|
||||
if past_key_values is None:
|
||||
past_key_values = self.get_prompt(
|
||||
batch_size=batch_size, device=input_ids.device, dtype=inputs_embeds.dtype
|
||||
)
|
||||
if attention_mask is not None:
|
||||
attention_mask = torch.cat(
|
||||
[attention_mask.new_ones((batch_size, self.pre_seq_len)), attention_mask], dim=-1
|
||||
)
|
||||
|
||||
mask_shape = (batch_size, 1, seq_length, seq_length)
|
||||
full_attention_mask: dict = ColoAttention.prepare_attn_kwargs(
|
||||
mask_shape,
|
||||
inputs_embeds.dtype,
|
||||
inputs_embeds.device,
|
||||
q_padding_mask=attention_mask,
|
||||
is_causal=True,
|
||||
)
|
||||
|
||||
# Rotary positional embeddings
|
||||
rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
|
||||
if position_ids is not None:
|
||||
rotary_pos_emb = rotary_pos_emb[position_ids]
|
||||
else:
|
||||
rotary_pos_emb = rotary_pos_emb[None, :seq_length]
|
||||
rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
|
||||
|
||||
# Run encoder.
|
||||
hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
|
||||
inputs_embeds,
|
||||
full_attention_mask,
|
||||
rotary_pos_emb=rotary_pos_emb,
|
||||
kv_caches=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
|
||||
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=presents,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
)
|
||||
|
||||
return forward
|
||||
|
@@ -11,6 +11,7 @@ from colossalai.shardformer.modeling.chatglm2 import ChatGLMPipelineForwards
|
||||
from ..modeling.chatglm2 import (
|
||||
get_chatglm_sequence_parallel_attention_forward,
|
||||
get_chatglm_sequence_parallel_forward_fn,
|
||||
get_flash_attention_forward_for_chat_glm_model,
|
||||
get_flash_core_attention_forward,
|
||||
get_jit_fused_glm_block_forward,
|
||||
)
|
||||
@@ -203,6 +204,13 @@ class ChatGLMPolicy(Policy):
|
||||
policy=policy,
|
||||
target_key="CoreAttention",
|
||||
)
|
||||
self.append_or_create_method_replacement(
|
||||
description={
|
||||
"forward": get_flash_attention_forward_for_chat_glm_model(),
|
||||
},
|
||||
policy=policy,
|
||||
target_key="ChatGLMModel",
|
||||
)
|
||||
|
||||
# use sequence parallel
|
||||
if self.shard_config.enable_sequence_parallelism:
|
||||
|
@@ -157,7 +157,7 @@ class GeminiDDP(ModelWrapper):
|
||||
self.enable_async_reduce = enable_async_reduce
|
||||
|
||||
if enable_async_reduce:
|
||||
self.async_reduce_stream = torch.cuda.Stream()
|
||||
self.async_reduce_stream = get_accelerator().Stream()
|
||||
else:
|
||||
self.async_reduce_stream = None
|
||||
|
||||
@@ -363,7 +363,7 @@ class GeminiDDP(ModelWrapper):
|
||||
master_weights: bool,
|
||||
enable_gradient_accumulation: bool,
|
||||
p: nn.Parameter,
|
||||
async_reduce_stream: Optional[torch.cuda.Stream] = None,
|
||||
async_reduce_stream=None,
|
||||
):
|
||||
async_reduce_scatter = async_reduce_stream is not None
|
||||
setattr(p, "_gemini_reduced", True)
|
||||
@@ -402,9 +402,9 @@ class GeminiDDP(ModelWrapper):
|
||||
grad_chunk.add_tensor_to_chunk_slice(p, grad)
|
||||
|
||||
if async_reduce_stream is not None:
|
||||
async_reduce_stream.wait_stream(torch.cuda.current_stream())
|
||||
async_reduce_stream.wait_stream(get_accelerator().current_stream())
|
||||
|
||||
with torch.cuda.stream(async_reduce_stream):
|
||||
with get_accelerator().stream(async_reduce_stream):
|
||||
reduced = chunk_manager.reduce_chunk(grad_chunk, async_op=async_reduce_scatter)
|
||||
if reduced:
|
||||
grad_chunk.wait_async_reduce()
|
||||
|
@@ -62,7 +62,7 @@ class GeminiZeROHook(ColoParamOpHook):
|
||||
#
|
||||
# Other than that, self._gemini_manager.wait_chunks will have synced with default stream
|
||||
# by calling dist.Work.wait() and this line makes no diff.
|
||||
self._gemini_manager.chunk_manager._prefetch_stream.wait_stream(torch.cuda.current_stream())
|
||||
self._gemini_manager.chunk_manager._prefetch_stream.wait_stream(get_accelerator().current_stream())
|
||||
|
||||
with get_accelerator().stream(self._gemini_manager.chunk_manager._prefetch_stream):
|
||||
for chunk in chunks_fetch_async:
|
||||
|
Reference in New Issue
Block a user