Merge branch 'main' into sync/npu

This commit is contained in:
ver217
2024-01-18 12:05:21 +08:00
152 changed files with 8641 additions and 2138 deletions

View File

@@ -1,5 +1,5 @@
from functools import partial
from typing import Any, Callable, Iterable, List, Optional, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
import torch
import torch.cuda
@@ -8,7 +8,7 @@ from torch.utils._pytree import tree_map
from colossalai.accelerator import get_accelerator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
from colossalai.pipeline.stage_manager import PipelineStageManager
from ._utils import (
@@ -30,6 +30,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
stage_manager: PipelineStageManager,
num_microbatches: Optional[int] = None,
microbatch_size: Optional[int] = None,
enable_metadata_cache: bool = True,
) -> None:
"""1F1B pipeline schedule.
@@ -42,13 +43,21 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
assert (
num_microbatches is not None or microbatch_size is not None
), "Either num_microbatches or microbatch_size should be provided"
self.comm = PipelineP2PCommunication(stage_manager)
self.num_microbatches = num_microbatches
self.microbatch_size = microbatch_size
self.batch: Optional[Any] = None
self.batch_size: Optional[int] = None
self.last_batch_size: Optional[int] = None
self.microbatch_offset: Optional[int] = None
self._use_microbatch_size = num_microbatches is None
# P2PMeta cache
self.enable_metadata_cache = enable_metadata_cache
self.send_tensor_metadata = True
self.send_grad_metadata = True
self.tensor_metadata_recv = None
self.grad_metadata_recv = None
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
"""Load a batch from data iterator.
@@ -60,24 +69,45 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
batch = next(data_iter)
if device is not None:
batch = tree_map(partial(to_device, device=device), batch)
self.microbatch_offset = 0
self.batch = batch
self.batch_size = get_batch_size(batch)
self.microbatch_offset = 0
if not self._use_microbatch_size:
assert (
self.batch_size % self.num_microbatches == 0
), "Batch size should divided by the number of microbatches"
if self.microbatch_size is None:
assert self.batch_size % self.num_microbatches == 0, "Batch size should divided by # microbatches"
self.microbatch_size = self.batch_size // self.num_microbatches
else:
if self.num_microbatches is None:
assert self.batch_size % self.microbatch_size == 0, "Batch size should divided by the microbatch size"
self.num_microbatches = self.batch_size // self.microbatch_size
if not self.forward_only:
assert self.last_batch_size is None or self.last_batch_size == self.batch_size
assert self.batch_size == self.microbatch_size * self.num_microbatches
assert (
self.num_microbatches >= self.stage_manager.num_stages
), "Number of microbatch should be larger than number of stages"
if self.forward_only:
self.num_microbatches = (self.batch_size - 1) // self.microbatch_size + 1
# NOTE: disable metadata cache when batch size changes (not valid anymore)
if self.batch_size != self.last_batch_size:
self.enable_metadata_cache = False
self.send_tensor_metadata = True
self.send_grad_metadata = True
self.tensor_metadata_recv = None
self.grad_metadata_recv = None
self.last_batch_size = self.batch_size
def load_micro_batch(self) -> Any:
    """Slice the next micro batch out of the loaded batch and move it to the current device.

    Reads ``self.microbatch_size`` items from ``self.batch`` starting at
    ``self.microbatch_offset``, then advances the offset by one micro batch.

    Returns:
        Any: The micro batch, with ``to_device`` applied through ``tree_map``.
    """
    # NOTE(review): the bound is inclusive, so an offset equal to batch_size still passes - confirm intended.
    assert self.microbatch_offset <= self.batch_size, "Microbatches exhausted"
    start = self.microbatch_offset
    chunk = get_micro_batch(self.batch, start, self.microbatch_size)
    self.microbatch_offset = start + self.microbatch_size
    # The target device is looked up only after the slice has been taken and the offset advanced.
    move_to_current = partial(to_device, device=get_accelerator().get_current_device())
    return tree_map(move_to_current, chunk)
@@ -92,12 +122,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
Returns:
Any: The input tensor or input tensor list.
"""
if self.stage_manager.is_first_stage():
input_tensor = None
else:
input_tensor = self.comm.recv_forward(prev_rank)
if not self.stage_manager.is_first_stage():
input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv)
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
self.tensor_metadata_recv = create_send_metadata(input_tensor)
return input_tensor
return input_tensor
def recv_backward(self, next_rank: int = None) -> Any:
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
@@ -109,14 +139,14 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
Returns:
Any: The input gradient tensor or gradient tensor list.
"""
if self.stage_manager.is_last_stage():
output_tensor_grad = None
else:
output_tensor_grad = self.comm.recv_backward(next_rank)
if not self.stage_manager.is_last_stage():
output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv)
if self.enable_metadata_cache and self.grad_metadata_recv is None:
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
return output_tensor_grad
return output_tensor_grad
def send_forward(self, output_object: Any, next_rank: int = None) -> None:
def send_forward(self, output_tensor: Any, next_rank: int = None) -> None:
"""Sends the input tensor to the next stage in pipeline.
For 1F1B.
@@ -125,20 +155,10 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
next_rank (int, optional): The rank of the recipient of the tensor.
"""
if not self.stage_manager.is_last_stage():
self.comm.send_forward(output_object, next_rank)
self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata)
self.send_tensor_metadata = not self.enable_metadata_cache
def send_forward_recv_backward(self, output_object: Any, next_rank: int = None) -> Any:
"""Sends the input tensor to the next stage and copy the gradient tensor from the next stage in pipeline.
For 1F1B.
Args:
output_object (Any): Object to be sent.
next_rank (int, optional): The rank of the recipient of the tensor.
"""
if not self.stage_manager.is_last_stage():
return self.comm.send_forward_recv_backward(output_object, next_rank)
def send_backward(self, input_object: Any, prev_rank: int = None) -> None:
def send_backward(self, input_tensor_grad: Any, prev_rank: int = None) -> None:
"""Sends the gradient tensor to the previous stage in pipeline.
For 1F1B.
@@ -147,9 +167,38 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
prev_rank (int, optional): The rank of the recipient of the tensor
"""
if not self.stage_manager.is_first_stage():
self.comm.send_backward(input_object, prev_rank)
self.comm.send_backward(input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata)
self.send_grad_metadata = not self.enable_metadata_cache
def send_backward_recv_forward(self, output_object: Any, prev_rank: int = None) -> Any:
def send_forward_recv_backward(
    self, output_tensor: Any, next_rank: int = None, send_prior_fallback: Optional[bool] = None
) -> Any:
    """Send the forward output to the next stage and receive the gradient coming back from it.
    For 1F1B.

    Args:
        output_tensor (Any): Object to be sent.
        next_rank (int, optional): The rank of the recipient of the tensor.
        send_prior_fallback (bool, optional): Passed through to
            ``self.comm.send_forward_recv_backward``. It is overridden with ``None`` when
            tensor metadata no longer has to be sent and gradient metadata is already cached.

    Returns:
        Any: The gradient object received from the next stage, or ``None`` on the last
        stage, where no communication happens.
    """
    if self.stage_manager.is_last_stage():
        return None

    metadata_settled = not self.send_tensor_metadata and self.grad_metadata_recv is not None
    if metadata_settled:
        send_prior_fallback = None  # must not fallback

    received_grad = self.comm.send_forward_recv_backward(
        output_tensor,
        next_rank,
        send_metadata=self.send_tensor_metadata,
        metadata_recv=self.grad_metadata_recv,
        send_prior_fallback=send_prior_fallback,
    )

    # After this send, metadata is re-sent only when the cache is disabled.
    self.send_tensor_metadata = not self.enable_metadata_cache
    # Remember the metadata of the first received gradient so later receives can reuse it.
    if self.enable_metadata_cache and self.grad_metadata_recv is None:
        self.grad_metadata_recv = create_send_metadata(received_grad)

    return received_grad
def send_backward_recv_forward(
self, input_tensor_grad: Any, prev_rank: int = None, send_prior_fallback: Optional[bool] = None
) -> Any:
"""Sends the gradient tensor to the previous stage and copy the input tensor from the previous stage in pipeline.
For 1F1B.
@@ -158,23 +207,20 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
prev_rank (int, optional): The rank of the recipient of the tensor.
"""
if not self.stage_manager.is_first_stage():
return self.comm.send_backward_recv_forward(output_object, prev_rank)
if not self.send_grad_metadata and self.tensor_metadata_recv is not None:
send_prior_fallback = None # must not fallback
input_tensor = self.comm.send_backward_recv_forward(
input_tensor_grad,
prev_rank,
send_metadata=self.send_grad_metadata,
metadata_recv=self.tensor_metadata_recv,
send_prior_fallback=send_prior_fallback,
)
self.send_grad_metadata = not self.enable_metadata_cache
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
self.tensor_metadata_recv = create_send_metadata(input_tensor)
def send_forward_recv_forward(self, input_object: Any, prev_rank: int = None, next_rank: int = None) -> Any:
"""Sends the input tensor to the next stage and copy the input tensor from the previous stage in pipeline.
For 1F1B.
Args:
input_object (Any): Object to be sent.
prev_rank (int, optional): The previous rank of the recipient of the tensor.
next_rank (int, optional): The next rank of the recipient of the tensor.
"""
if self.stage_manager.is_first_stage():
return self.comm.send_forward(input_object, next_rank)
elif self.stage_manager.is_last_stage():
return self.comm.recv_forward(prev_rank)
else:
return self.comm.send_forward_recv_forward(input_object, prev_rank, next_rank)
return input_tensor
def forward_step(
self,
@@ -254,7 +300,38 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
input_obj_grad[k] = v.grad
return input_obj_grad
def forward_backward_step(
def run_forward_only(
    self,
    model: Module,
    data_iter: Iterable,
    criterion: Callable[..., Any],
    return_loss: bool = False,
    return_outputs: bool = False,
) -> Dict:
    """
    Runs forward only schedule, with communication between pipeline stages.

    Every micro batch is received from the previous stage, pushed through
    ``forward_step`` and sent on to the next stage; no backward pass is run.

    Args:
        model (Module): Model to run; a ``ModelWrapper`` is unwrapped before outputs are merged.
        data_iter (Iterable): Data iterator handed to ``load_batch``.
        criterion (Callable[..., Any]): Criterion forwarded to ``forward_step``.
        return_loss (bool, optional): Whether to accumulate the loss on the last stage. Defaults to False.
        return_outputs (bool, optional): Whether to collect outputs on the last stage. Defaults to False.

    Returns:
        Dict: A dict with keys 'loss' and 'outputs'; each value is None unless it was
        requested and this is the last stage.
    """
    assert self.forward_only

    self.load_batch(data_iter)

    # Loss and outputs are only tracked on the last pipeline stage.
    track_loss = return_loss and self.stage_manager.is_last_stage()
    accum_loss = torch.scalar_tensor(0, device=get_accelerator().get_current_device()) if track_loss else None
    if return_outputs and self.stage_manager.is_last_stage():
        outputs = []
    else:
        outputs = None

    for _ in range(self.num_microbatches):
        stage_input = self.recv_forward()
        stage_output = self.forward_step(model, stage_input, criterion, accum_loss, outputs)
        self.send_forward(stage_output)

    if outputs is not None:
        inner_model = model.unwrap() if isinstance(model, ModelWrapper) else model
        outputs = merge_batch(outputs, getattr(inner_model, "batch_size_dim", 0))
    return {"loss": accum_loss, "outputs": outputs}
def run_forward_backward(
self,
model: Module,
data_iter: Iterable,
@@ -262,23 +339,11 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
optimizer: Optional[OptimizerWrapper] = None,
return_loss: bool = False,
return_outputs: bool = False,
) -> dict:
"""Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
Args:
model (Module): Model to be trained.
data_iter (Iterable): Data iterator.
criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
return_loss (bool, optional): Whether to return loss. Defaults to False. Whether to return loss.
return_outputs (bool, optional): Whether to return model outputs. Defaults to False. Whether to return model outputs.
Returns:
dict: A dict with keys: 'loss' and 'outputs'.
) -> Dict:
"""
forward_only = not torch.is_grad_enabled()
if optimizer is None:
assert forward_only, "Optimizer should be passed when doing backward."
Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
"""
assert not self.forward_only
self.load_batch(data_iter)
@@ -288,30 +353,20 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
num_microbatches_remaining = self.num_microbatches - num_warmup_microbatches
# Input, output tensors only need to be saved when doing backward passes
input_objs = None
output_objs = None
input_objs, output_objs = [], []
if not forward_only:
input_objs = []
output_objs = []
outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None
accum_loss = None
if return_loss and self.stage_manager.is_last_stage():
accum_loss = torch.zeros(1, device=get_accelerator().get_current_device())
else:
accum_loss = None
accum_loss = torch.scalar_tensor(0, device=get_current_device())
outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None
# Run warmup forward passes.
for i in range(num_warmup_microbatches):
input_obj = self.recv_forward()
output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
self.send_forward(output_obj)
if not forward_only:
input_objs.append(input_obj)
output_objs.append(output_obj)
input_objs.append(input_obj)
output_objs.append(output_obj)
# Before running 1F1B, need to receive first forward tensor.
# If all microbatches are run in warmup / cooldown phase, then no need to
@@ -324,44 +379,72 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
last_iteration = i == (num_microbatches_remaining - 1)
output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
if forward_only:
self.send_forward(output_obj)
output_obj_grad = self.send_forward_recv_backward(
output_obj, send_prior_fallback=self.stage_manager.stage % 2 == 0
)
# Add input_obj and output_obj to end of list.
input_objs.append(input_obj)
output_objs.append(output_obj)
if not last_iteration:
input_obj = self.recv_forward()
else:
# TODO adjust here
self.send_forward(output_obj)
output_obj_grad = self.recv_backward()
# Pop output_obj and output_obj from the start of the list for
# the backward pass.
input_obj = input_objs.pop(0)
output_obj = output_objs.pop(0)
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
# Add input_obj and output_obj to end of list.
input_objs.append(input_obj)
output_objs.append(output_obj)
# Pop output_obj and output_obj from the start of the list for
# the backward pass.
input_obj = input_objs.pop(0)
output_obj = output_objs.pop(0)
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
if last_iteration:
input_obj = None
else:
input_obj = self.recv_forward()
if last_iteration:
self.send_backward(input_obj_grad)
else:
input_obj = self.send_backward_recv_forward(
input_obj_grad, send_prior_fallback=self.stage_manager.stage % 2 == 0
)
# Run cooldown backward passes.
if not forward_only:
for i in range(num_warmup_microbatches):
input_obj = input_objs.pop(0)
output_obj = output_objs.pop(0)
for i in range(num_warmup_microbatches):
input_obj = input_objs.pop(0)
output_obj = output_objs.pop(0)
output_obj_grad = self.recv_backward()
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
self.send_backward(input_obj_grad)
output_obj_grad = self.recv_backward()
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
self.send_backward(input_obj_grad)
assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)
if outputs is not None:
if isinstance(model, ModelWrapper):
model = model.unwrap()
outputs = merge_batch(outputs, getattr(model, "batch_size_dim", 0))
return {"loss": accum_loss, "outputs": outputs}
def forward_backward_step(
    self,
    model: Module,
    data_iter: Iterable,
    criterion: Callable[..., Any],
    optimizer: Optional[OptimizerWrapper] = None,
    return_loss: bool = False,
    return_outputs: bool = False,
) -> dict:
    """Run one schedule step, choosing the forward-only or forward-backward path.

    The path is selected from ``torch.is_grad_enabled()``: with gradients disabled the
    forward-only schedule runs, otherwise the forward-backward schedule runs.

    Args:
        model (Module): Model to be trained.
        data_iter (Iterable): Data iterator.
        criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
        optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
        return_loss (bool, optional): Whether to return loss. Defaults to False.
        return_outputs (bool, optional): Whether to return model outputs. Defaults to False.

    Returns:
        dict: Dictionary containing loss and outputs.

    Raises:
        AssertionError: If ``optimizer`` is None while gradients are enabled.
    """
    # The mode is stored on the instance because the run_* methods assert on it.
    self.forward_only = not torch.is_grad_enabled()
    if optimizer is None:
        assert self.forward_only, "Optimizer should be passed when doing backward."

    if self.forward_only:
        return self.run_forward_only(model, data_iter, criterion, return_loss, return_outputs)
    return self.run_forward_backward(model, data_iter, criterion, optimizer, return_loss, return_outputs)