mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 01:55:12 +00:00
[legacy] move communication and nn to legacy and refactor logger (#4671)
* [legacy] move communication to legacy (#4640) * [legacy] refactor logger and clean up legacy codes (#4654) * [legacy] make logger independent to gpc * [legacy] make optim independent to registry * [legacy] move test engine to legacy * [legacy] move nn to legacy (#4656) * [legacy] move nn to legacy * [checkpointio] fix save hf config * [test] remove useless rpc pp test * [legacy] fix nn init * [example] skip tutorial hybrid parallel example * [devops] test doc check * [devops] test doc check
This commit is contained in:
56
colossalai/legacy/communication/ring.py
Normal file
56
colossalai/legacy/communication/ring.py
Normal file
@@ -0,0 +1,56 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.utils import get_current_device, synchronize
|
||||
|
||||
|
||||
def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode) -> torch.Tensor:
    """Sends a tensor to the next member and receives a tensor from the previous member.
    This function returns the received tensor from the previous member.

    Args:
        tensor_send_next (:class:`torch.Tensor`): Tensor sent to next member
        parallel_mode (ParallelMode): Parallel group mode used in this communication

    Returns:
        :class:`torch.Tensor`: The tensor received from the previous member. It has the same
        shape and dtype as ``tensor_send_next`` and is allocated on the current device.

    Note:
        The parallel_mode should be included in ``ParallelMode``. More details about ``ParallelMode`` could be found
        in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
    """
    buffer_shape = tensor_send_next.size()

    ops = []
    current_rank = gpc.get_global_rank()

    # Only floating point and complex tensors are allowed to require gradients;
    # asking for ``requires_grad=True`` on an integer/bool buffer makes
    # ``torch.empty`` raise, so the flag is enabled only where it is legal.
    recv_requires_grad = tensor_send_next.is_floating_point() or tensor_send_next.is_complex()

    # Uninitialised receive buffer mirroring the shape and dtype of the outgoing tensor.
    tensor_recv_prev = torch.empty(buffer_shape,
                                   requires_grad=recv_requires_grad,
                                   device=get_current_device(),
                                   dtype=tensor_send_next.dtype)

    # send to next rank
    send_next_op = torch.distributed.P2POp(torch.distributed.isend, tensor_send_next,
                                           gpc.get_next_global_rank(parallel_mode))
    ops.append(send_next_op)

    # receive from prev rank
    recv_prev_op = torch.distributed.P2POp(torch.distributed.irecv, tensor_recv_prev,
                                           gpc.get_prev_global_rank(parallel_mode))
    ops.append(recv_prev_op)

    # Even global ranks submit the receive before the send, odd ranks the send
    # before the receive (presumably so neighbouring ranks post complementary
    # operations first -- confirm against the communication backend in use).
    if current_rank % 2 == 0:
        ops = ops[::-1]

    reqs = torch.distributed.batch_isend_irecv(ops)
    for req in reqs:
        req.wait()

    # To protect against race condition when using batch_isend_irecv().
    synchronize()

    return tensor_recv_prev
|
Reference in New Issue
Block a user