support fp8 communication in pipeline parallelism

BurkeHulk
2024-07-12 15:25:25 +08:00
parent 1e1959467e
commit e88190184a
4 changed files with 126 additions and 1 deletion

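The change threads a single fp8_communication flag through InterleavedSchedule and brackets every point-to-point transfer with cast_to_fp8_pipeline / cast_from_fp8_pipeline from colossalai.quantization.fp8: intermediates are cast down to FP8 just before a send and cast back to their original dtype right after, on the sender as well as on the receiver once the wait completes. The cast-back on the sender side matters because the same object stays live locally (for example, it is kept for the backward pass), so it cannot be left in FP8 after the copy is on the wire. As a rough idea of what such a cast pair does, here is a minimal sketch, not the library's actual implementation: it assumes per-tensor scaling into torch.float8_e4m3fn and that pipeline intermediates travel as a dict holding a "hidden_states" tensor; the _sketch suffixes and the fp8_scale_ / dtype_ bookkeeping keys are illustrative only.

import torch

def cast_to_fp8_pipeline_sketch(inter_outputs: dict) -> None:
    # Hypothetical stand-in for cast_to_fp8_pipeline: shrink the hidden states
    # to FP8 in place and remember how to undo it.
    if inter_outputs is None or "hidden_states" not in inter_outputs:
        return
    tensor = inter_outputs["hidden_states"]
    # Per-tensor scale so the values fit the narrow FP8 dynamic range.
    scale = tensor.abs().max().clamp(min=1e-8) / torch.finfo(torch.float8_e4m3fn).max
    inter_outputs["dtype_"] = tensor.dtype      # assumed bookkeeping key
    inter_outputs["fp8_scale_"] = scale         # assumed bookkeeping key
    inter_outputs["hidden_states"] = (tensor / scale).to(torch.float8_e4m3fn)

def cast_from_fp8_pipeline_sketch(inter_outputs: dict) -> None:
    # Hypothetical inverse: restore the original dtype and magnitude after the send/recv.
    if inter_outputs is None or "fp8_scale_" not in inter_outputs:
        return
    tensor = inter_outputs["hidden_states"].to(inter_outputs.pop("dtype_"))
    inter_outputs["hidden_states"] = tensor * inter_outputs.pop("fp8_scale_")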

@@ -12,6 +12,7 @@ from colossalai.interface import OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.utils import get_current_device
from colossalai.quantization.fp8 import cast_to_fp8_pipeline, cast_from_fp8_pipeline
from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
from .base import PipelineSchedule
@@ -32,6 +33,7 @@ class InterleavedSchedule(PipelineSchedule):
        microbatch_size: Optional[int] = None,
        enable_metadata_cache: bool = True,
        overlap_p2p: bool = True,
        fp8_communication: bool = False,
    ) -> None:
        super().__init__(stage_manager)
        assert (
@@ -56,6 +58,7 @@ class InterleavedSchedule(PipelineSchedule):
        self.tensor_metadata_recv = None
        self.grad_metadata_recv = None
        self.fp8_communication = fp8_communication

    def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
        """Load a batch from data iterator.
@@ -191,8 +194,12 @@ class InterleavedSchedule(PipelineSchedule):
"""
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if not self.stage_manager.is_last_stage():
if self.fp8_communication:
cast_to_fp8_pipeline(output_tensor)
send_handles = self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata)
self.send_tensor_metadata = not self.enable_metadata_cache
if self.fp8_communication:
cast_from_fp8_pipeline(output_tensor)
return send_handles
return []
@@ -210,10 +217,14 @@ class InterleavedSchedule(PipelineSchedule):
"""
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if not self.stage_manager.is_first_stage():
if self.fp8_communication:
cast_to_fp8_pipeline(input_tensor_grad)
send_handles = self.comm.send_backward(
input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata
)
self.send_grad_metadata = not self.enable_metadata_cache
if self.fp8_communication:
cast_from_fp8_pipeline(input_tensor_grad)
return send_handles
return []
@@ -224,6 +235,8 @@ class InterleavedSchedule(PipelineSchedule):
            is_send = not self.stage_manager.is_last_stage()
        with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv):
            is_recv = not self.stage_manager.is_first_stage()
        if self.fp8_communication:
            cast_to_fp8_pipeline(output_tensor)
        input_tensor, wait_handles = self.comm.send_forward_recv_forward(
            output_tensor,
            is_send,
@@ -237,6 +250,8 @@ class InterleavedSchedule(PipelineSchedule):
        if is_recv and self.enable_metadata_cache and self.tensor_metadata_recv is None:
            self.tensor_metadata_recv = create_send_metadata(input_tensor)
        if self.fp8_communication:
            cast_from_fp8_pipeline(output_tensor)
        return input_tensor, wait_handles

    def send_backward_recv_backward(
@@ -246,6 +261,8 @@ class InterleavedSchedule(PipelineSchedule):
            is_send = not self.stage_manager.is_first_stage()
        with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv):
            is_recv = not self.stage_manager.is_last_stage()
        if self.fp8_communication:
            cast_to_fp8_pipeline(input_tensor_grad)
        output_tensor_grad, wait_handles = self.comm.send_backward_recv_backward(
            input_tensor_grad,
            is_send,
@@ -258,6 +275,8 @@ class InterleavedSchedule(PipelineSchedule):
        self.send_grad_metadata = not self.enable_metadata_cache and is_send
        if is_recv and self.enable_metadata_cache and self.grad_metadata_recv is None:
            self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
        if self.fp8_communication:
            cast_from_fp8_pipeline(input_tensor_grad)
        return output_tensor_grad, wait_handles

    def forward_step(
@@ -379,6 +398,8 @@ class InterleavedSchedule(PipelineSchedule):
            # Wait until current input is received
            _wait_p2p(fwd_wait_handles)
            if self.fp8_communication and input_obj is not None:
                cast_from_fp8_pipeline(input_obj)
            output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)

            if not last_batch:
@@ -441,6 +462,8 @@ class InterleavedSchedule(PipelineSchedule):
            # Wait for input
            _wait_p2p(fwd_wait_handles)
            if self.fp8_communication and input_obj is not None:
                cast_from_fp8_pipeline(input_obj)
            output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
            input_objs[model_chunk_id].append(input_obj)
            output_objs[model_chunk_id].append(output_obj)
@@ -467,6 +490,8 @@ class InterleavedSchedule(PipelineSchedule):
            # Wait for input.
            _wait_p2p(fwd_wait_handles)
            if self.fp8_communication and input_obj is not None:
                cast_from_fp8_pipeline(input_obj)
            output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
            # Add input_obj and output_obj to end of list.
            input_objs[model_chunk_id].append(input_obj)
@@ -511,6 +536,8 @@ class InterleavedSchedule(PipelineSchedule):
                input_obj, fwd_wait_handles = send_forward_recv_forward()

            # Wait for upstream grad
            _wait_p2p(bwd_wait_handles)
            if self.fp8_communication and output_obj_grad is not None:
                cast_from_fp8_pipeline(output_obj_grad)
            input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)

            # NOTE: It's documented by NCCL that running two concurrent communicators (batch_isend_irecv)
            # risks deadlock (https://docs.nvidia.com/deeplearning/nccl/archives/nccl_2134/user-guide/docs/usage/communicators.html)
@@ -532,6 +559,8 @@ class InterleavedSchedule(PipelineSchedule):
            # Wait for upstream grad
            _wait_p2p(bwd_wait_handles)
            if self.fp8_communication and output_obj_grad is not None:
                cast_from_fp8_pipeline(output_obj_grad)
            # backward local grads
            input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)

            if not last_batch:
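Turning the new path on is a one-keyword change when the schedule is constructed. A hypothetical construction sketch, assuming the schedule module lives at colossalai.pipeline.schedule.interleaved_pp and keeps its existing stage_manager / num_model_chunks / num_microbatch parameters (those come from the usual pipeline setup; only fp8_communication is new in this commit):

from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule

schedule = InterleavedSchedule(
    stage_manager=stage_manager,        # existing PipelineStageManager from the pipeline setup
    num_model_chunks=num_model_chunks,  # interleaving degree, unchanged by this commit
    num_microbatch=num_microbatch,
    overlap_p2p=True,
    fp8_communication=True,             # new flag: cast p2p activations/grads to FP8 around each send
)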