mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 21:09:18 +00:00
Merge branch 'main' into sync/npu
This commit is contained in:
@@ -1,14 +1,14 @@
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Iterable, List, Optional, Union
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.cuda
|
||||
from torch.nn import Module
|
||||
from torch.nn import Module, ModuleList
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.interface import OptimizerWrapper
|
||||
from colossalai.pipeline.p2p import PipelineP2PCommunication
|
||||
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
|
||||
from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
|
||||
@@ -16,18 +16,35 @@ from .base import PipelineSchedule
|
||||
|
||||
|
||||
class InterleavedSchedule(PipelineSchedule):
|
||||
def __init__(self, num_microbatches: int, num_model_chunks: int, stage_manager: PipelineStageManager) -> None:
|
||||
self.num_model_chunks = num_model_chunks
|
||||
assert (
|
||||
num_microbatches % self.num_model_chunks == 0
|
||||
), "Number of microbatches should be an integer multiple of number of model chunks"
|
||||
def __init__(
|
||||
self,
|
||||
stage_manager: PipelineStageManager,
|
||||
num_model_chunks: int,
|
||||
num_microbatch: Optional[int] = None,
|
||||
microbatch_size: Optional[int] = None,
|
||||
enable_metadata_cache: bool = True,
|
||||
) -> None:
|
||||
super().__init__(stage_manager)
|
||||
assert (
|
||||
num_microbatch is not None or microbatch_size is not None
|
||||
), "Either num_microbatch or microbatch_size should be provided"
|
||||
|
||||
self.comm = PipelineP2PCommunication(stage_manager)
|
||||
self.num_microbatches = num_microbatches
|
||||
self.batch: Optional[Any] = None
|
||||
self.batch_size: Optional[int] = None
|
||||
self.microbatch_offset: Optional[int] = None
|
||||
self.microbatch_size: Optional[int] = None
|
||||
self.num_microbatch = num_microbatch
|
||||
self.microbatch_size = microbatch_size
|
||||
self.num_model_chunks = num_model_chunks
|
||||
|
||||
self.batch: Any
|
||||
self.batch_size: int
|
||||
self.last_batch_size: Optional[int] = None
|
||||
self.microbatch_offset: List[int]
|
||||
|
||||
# P2PMeta cache
|
||||
self.enable_metadata_cache = enable_metadata_cache
|
||||
self.send_tensor_metadata = True
|
||||
self.send_grad_metadata = True
|
||||
self.tensor_metadata_recv = None
|
||||
self.grad_metadata_recv = None
|
||||
|
||||
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
|
||||
"""Load a batch from data iterator.
|
||||
@@ -39,11 +56,37 @@ class InterleavedSchedule(PipelineSchedule):
|
||||
batch = next(data_iter)
|
||||
if device is not None:
|
||||
batch = tree_map(partial(to_device, device=device), batch)
|
||||
|
||||
self.microbatch_offset = [0 for _ in range(self.num_model_chunks)]
|
||||
self.batch = batch
|
||||
self.batch_size = get_batch_size(batch)
|
||||
self.microbatch_offset = [0 for _ in range(self.num_model_chunks)]
|
||||
assert self.batch_size % self.num_microbatches == 0, "Batch size should divided by the number of microbatches"
|
||||
self.microbatch_size = self.batch_size // self.num_microbatches
|
||||
|
||||
if self.microbatch_size is None:
|
||||
assert self.batch_size % self.num_microbatch == 0, "Batch size should divided by the number of microbatch"
|
||||
self.microbatch_size = self.batch_size // self.num_microbatch
|
||||
if self.num_microbatch is None:
|
||||
assert self.batch_size % self.microbatch_size == 0, "Batch size should divided by the microbatch size"
|
||||
self.num_microbatch = self.batch_size // self.microbatch_size
|
||||
|
||||
if not self.forward_only:
|
||||
assert self.last_batch_size is None or self.last_batch_size == self.batch_size
|
||||
assert self.batch_size == self.microbatch_size * self.num_microbatch
|
||||
|
||||
assert (
|
||||
self.num_microbatch % self.stage_manager.num_stages == 0
|
||||
), "Number of microbatch should be an integer multiple of number of pipeline parallel devices"
|
||||
|
||||
if self.forward_only:
|
||||
self.num_microbatch = (self.batch_size - 1) // self.microbatch_size + 1
|
||||
# NOTE: disable metadata cache when batch size changes (not valid anymore)
|
||||
if self.batch_size != self.last_batch_size:
|
||||
self.enable_metadata_cache = False
|
||||
self.send_tensor_metadata = True
|
||||
self.send_grad_metadata = True
|
||||
self.tensor_metadata_recv = None
|
||||
self.grad_metadata_recv = None
|
||||
|
||||
self.last_batch_size = self.batch_size
|
||||
|
||||
def load_micro_batch(self, model_chunk_id: int) -> Any:
|
||||
"""Load a micro batch from the current batch.
|
||||
@@ -54,11 +97,12 @@ class InterleavedSchedule(PipelineSchedule):
|
||||
Returns:
|
||||
Any: Micro batch.
|
||||
"""
|
||||
assert self.microbatch_offset[model_chunk_id] <= self.batch_size, "Microbatches exhausted"
|
||||
micro_batch = get_micro_batch(self.batch, self.microbatch_offset[model_chunk_id], self.microbatch_size)
|
||||
self.microbatch_offset[model_chunk_id] += self.microbatch_size
|
||||
return tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch)
|
||||
|
||||
def get_model_chunk_id(self, microbatch_id: int, forward: bool) -> int:
|
||||
def get_model_chunk_id(self, microbatch_id: int, is_forward: bool) -> int:
|
||||
"""Helper method to get the model chunk ID given the iteration number.
|
||||
|
||||
Args:
|
||||
@@ -68,38 +112,13 @@ class InterleavedSchedule(PipelineSchedule):
|
||||
Returns:
|
||||
int: The model chunk idx of the input microbatch_id
|
||||
"""
|
||||
microbatch_id_in_group = (microbatch_id) % (self.stage_manager.num_stages * self.num_model_chunks)
|
||||
assert microbatch_id < self.num_microbatch * self.num_model_chunks
|
||||
microbatch_id_in_group = microbatch_id % (self.stage_manager.num_stages * self.num_model_chunks)
|
||||
model_chunk_id = microbatch_id_in_group // self.stage_manager.num_stages
|
||||
if not forward:
|
||||
if not is_forward:
|
||||
model_chunk_id = self.num_model_chunks - model_chunk_id - 1
|
||||
return model_chunk_id
|
||||
|
||||
def is_first_stage(self, model_chunk_id: int) -> bool:
|
||||
"""Is the current virtual stage the first stage
|
||||
|
||||
Args:
|
||||
model_chunk_id (int): The current model chunk idx.
|
||||
|
||||
Returns:
|
||||
bool: Whether the current virtual stage is the first stage.
|
||||
"""
|
||||
if self.stage_manager.is_first_stage() and model_chunk_id == 0:
|
||||
return True
|
||||
return False
|
||||
|
||||
def is_last_stage(self, model_chunk_id: int) -> bool:
|
||||
"""Is the current virtual stage the last stage
|
||||
|
||||
Args:
|
||||
model_chunk_id (int): The current model chunk idx.
|
||||
|
||||
Returns:
|
||||
bool: Whether the current virtual stage is the last stage.
|
||||
"""
|
||||
if self.stage_manager.is_last_stage() and model_chunk_id == self.num_model_chunks - 1:
|
||||
return True
|
||||
return False
|
||||
|
||||
def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Any:
|
||||
"""Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
|
||||
For interleaved 1F1B.
|
||||
@@ -111,12 +130,13 @@ class InterleavedSchedule(PipelineSchedule):
|
||||
Returns:
|
||||
Any: The input tensor or input tensor list.
|
||||
"""
|
||||
if self.is_first_stage(model_chunk_id):
|
||||
input_tensor = None
|
||||
else:
|
||||
input_tensor = self.comm.recv_forward(prev_rank)
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||
if not self.stage_manager.is_first_stage():
|
||||
input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv)
|
||||
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
|
||||
self.tensor_metadata_recv = create_send_metadata(input_tensor)
|
||||
|
||||
return input_tensor
|
||||
return input_tensor
|
||||
|
||||
def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Any:
|
||||
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
|
||||
@@ -129,14 +149,15 @@ class InterleavedSchedule(PipelineSchedule):
|
||||
Returns:
|
||||
Any: The input gradient tensor or gradient tensor list.
|
||||
"""
|
||||
if self.is_last_stage(model_chunk_id):
|
||||
output_tensor_grad = None
|
||||
else:
|
||||
output_tensor_grad = self.comm.recv_backward(next_rank)
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||
if not self.stage_manager.is_last_stage():
|
||||
output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv)
|
||||
if self.enable_metadata_cache and self.grad_metadata_recv is None:
|
||||
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
|
||||
|
||||
return output_tensor_grad
|
||||
return output_tensor_grad
|
||||
|
||||
def send_forward(self, model_chunk_id, output_object: Any, next_rank: int = None) -> None:
|
||||
def send_forward(self, model_chunk_id: int, output_tensor: Any, next_rank: int = None) -> None:
|
||||
"""Sends the input tensor to the next stage in pipeline.
|
||||
For interleaved 1F1B.
|
||||
|
||||
@@ -145,10 +166,12 @@ class InterleavedSchedule(PipelineSchedule):
|
||||
output_object (Any): Object to be sent.
|
||||
next_rank (int, optional): The rank of the recipient of the tensor.
|
||||
"""
|
||||
if not self.is_last_stage(model_chunk_id):
|
||||
self.comm.send_forward(output_object, next_rank)
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||
if not self.stage_manager.is_last_stage():
|
||||
self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata)
|
||||
self.send_tensor_metadata = not self.enable_metadata_cache
|
||||
|
||||
def send_backward(self, model_chunk_id, input_object: Any, prev_rank: int = None) -> None:
|
||||
def send_backward(self, model_chunk_id: int, input_tensor_grad: Any, prev_rank: int = None) -> None:
|
||||
"""Sends the gradient tensor to the previous stage in pipeline.
|
||||
For interleaved 1F1B.
|
||||
|
||||
@@ -157,12 +180,102 @@ class InterleavedSchedule(PipelineSchedule):
|
||||
input_object (Any): Object to be sent.
|
||||
prev_rank (int, optional): The rank of the recipient of the tensor
|
||||
"""
|
||||
if not self.is_first_stage(model_chunk_id):
|
||||
self.comm.send_backward(input_object, prev_rank)
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||
if not self.stage_manager.is_first_stage():
|
||||
self.comm.send_backward(input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata)
|
||||
self.send_grad_metadata = not self.enable_metadata_cache
|
||||
|
||||
def send_forward_recv_backward(
|
||||
self,
|
||||
model_chunk_id_send: int,
|
||||
model_chunk_id_recv: int,
|
||||
output_tensor: Any,
|
||||
next_rank: Optional[int] = None,
|
||||
send_prior_fallback: Optional[bool] = None,
|
||||
) -> Any:
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id_send):
|
||||
send_data = not self.stage_manager.is_last_stage()
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv):
|
||||
recv_data = not self.stage_manager.is_last_stage()
|
||||
|
||||
if send_data and recv_data:
|
||||
if not self.send_forward_recv_backward and self.grad_metadata_recv is not None:
|
||||
send_prior_fallback = None # must not fallback
|
||||
output_tensor_grad = self.comm.send_forward_recv_backward(
|
||||
output_tensor,
|
||||
next_rank,
|
||||
send_metadata=self.send_tensor_metadata,
|
||||
metadata_recv=self.grad_metadata_recv,
|
||||
send_prior_fallback=send_prior_fallback,
|
||||
)
|
||||
self.send_tensor_metadata = not self.enable_metadata_cache
|
||||
if self.enable_metadata_cache and self.grad_metadata_recv is None:
|
||||
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
|
||||
return output_tensor_grad
|
||||
|
||||
# send only or recv only
|
||||
self.send_forward(model_chunk_id_send, output_tensor)
|
||||
return self.recv_backward(model_chunk_id_recv)
|
||||
|
||||
def send_backward_recv_forward(
|
||||
self,
|
||||
model_chunk_id_send: int,
|
||||
model_chunk_id_recv: int,
|
||||
input_tensor_grad: Any,
|
||||
prev_rank: Optional[int] = None,
|
||||
send_prior_fallback: Optional[bool] = None,
|
||||
) -> Any:
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id_send):
|
||||
send_data = not self.stage_manager.is_first_stage()
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv):
|
||||
recv_data = not self.stage_manager.is_first_stage()
|
||||
|
||||
if send_data and recv_data:
|
||||
if not self.send_backward_recv_backward and self.tensor_metadata_recv is not None:
|
||||
send_prior_fallback = None # must not fallback
|
||||
input_tensor = self.comm.send_backward_recv_forward(
|
||||
input_tensor_grad,
|
||||
prev_rank,
|
||||
send_metadata=self.send_grad_metadata,
|
||||
metadata_recv=self.tensor_metadata_recv,
|
||||
send_prior_fallback=send_prior_fallback,
|
||||
)
|
||||
self.send_grad_metadata = not self.enable_metadata_cache
|
||||
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
|
||||
self.tensor_metadata_recv = create_send_metadata(input_tensor)
|
||||
return input_tensor
|
||||
|
||||
# send only or recv only
|
||||
self.send_backward(model_chunk_id_send, input_tensor_grad)
|
||||
return self.recv_forward(model_chunk_id_recv)
|
||||
|
||||
def send_forward_recv_forward(
|
||||
self, model_chunk_id_send: int, model_chunk_id_recv: int, output_tensor: Any, send_prior: bool
|
||||
):
|
||||
if send_prior:
|
||||
self.send_forward(model_chunk_id_send, output_tensor)
|
||||
input_tensor = self.recv_forward(model_chunk_id_recv)
|
||||
else:
|
||||
input_tensor = self.recv_forward(model_chunk_id_recv)
|
||||
self.send_forward(model_chunk_id_send, output_tensor)
|
||||
|
||||
return input_tensor
|
||||
|
||||
def send_backward_recv_backward(
|
||||
self, model_chunk_id_send: int, model_chunk_id_recv: int, input_tensor_grad: Any, send_prior: bool
|
||||
):
|
||||
if send_prior:
|
||||
self.send_backward(model_chunk_id_send, input_tensor_grad)
|
||||
output_tensor_grad = self.recv_backward(model_chunk_id_recv)
|
||||
else:
|
||||
output_tensor_grad = self.recv_backward(model_chunk_id_recv)
|
||||
self.send_backward(model_chunk_id_send, input_tensor_grad)
|
||||
|
||||
return output_tensor_grad
|
||||
|
||||
def forward_step(
|
||||
self,
|
||||
model_chunk: Module,
|
||||
model_chunk: Union[ModuleList, Module],
|
||||
model_chunk_id: int,
|
||||
input_obj: Optional[dict],
|
||||
criterion: Callable,
|
||||
@@ -171,7 +284,7 @@ class InterleavedSchedule(PipelineSchedule):
|
||||
) -> Union[torch.Tensor, dict]:
|
||||
"""Forward one step of the pipeline
|
||||
Args:
|
||||
model (Module): Model Chunk to be run
|
||||
model (ModuleList or Module): Model Chunk to be run
|
||||
input_obj (Optional[dict]): The output from the previous stage. If it is the first stage, the `input_obj` is None.
|
||||
criterion (Callable): Criterion to calculate loss.
|
||||
accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None.
|
||||
@@ -184,17 +297,25 @@ class InterleavedSchedule(PipelineSchedule):
|
||||
|
||||
# for the first stage, input_obj is None
|
||||
# for the non-first stage, input_obj is the output of the previous stage and it's must be a dict
|
||||
output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj)
|
||||
|
||||
if self.is_last_stage(model_chunk_id):
|
||||
loss = criterion(output_obj, micro_batch) / self.num_microbatches
|
||||
if accum_loss is not None:
|
||||
accum_loss.add_(loss.detach())
|
||||
if outputs is not None:
|
||||
outputs.append(tree_map(detach, output_obj))
|
||||
return loss
|
||||
else:
|
||||
return output_obj
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||
if isinstance(model_chunk, ModuleList):
|
||||
output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj)
|
||||
else:
|
||||
# NOTE: in shardformer, each device still has the entire model, so we need to use relevant stage layers
|
||||
internal_inputs = {} if input_obj is None else input_obj
|
||||
internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
|
||||
output_obj = model_forward(model_chunk, micro_batch, internal_inputs)
|
||||
|
||||
if self.stage_manager.is_last_stage():
|
||||
loss = criterion(output_obj, micro_batch) / self.num_microbatch
|
||||
if accum_loss is not None:
|
||||
accum_loss.add_(loss.detach())
|
||||
if outputs is not None:
|
||||
outputs.append(tree_map(detach, output_obj))
|
||||
return loss
|
||||
else:
|
||||
return output_obj
|
||||
|
||||
def backward_step(
|
||||
self,
|
||||
@@ -241,19 +362,193 @@ class InterleavedSchedule(PipelineSchedule):
|
||||
input_obj_grad[k] = v.grad
|
||||
return input_obj_grad
|
||||
|
||||
def run_forward_only(
|
||||
self,
|
||||
model_chunk: Union[ModuleList, Module],
|
||||
data_iter: Iterable,
|
||||
criterion: Callable[..., Any],
|
||||
return_loss: bool = False,
|
||||
return_outputs: bool = False,
|
||||
) -> Dict:
|
||||
assert self.forward_only
|
||||
|
||||
self.load_batch(data_iter)
|
||||
|
||||
outputs = [] if return_outputs and self.stage_manager.is_last_stage(ignore_chunk=True) else None
|
||||
|
||||
accum_loss = None
|
||||
if return_loss and self.stage_manager.is_last_stage(ignore_chunk=True):
|
||||
accum_loss = torch.scalar_tensor(0, device=get_current_device())
|
||||
|
||||
model_chunk_id = self.get_model_chunk_id(0, is_forward=True)
|
||||
input_obj = self.recv_forward(model_chunk_id)
|
||||
|
||||
for i in range(self.num_microbatch * self.num_model_chunks):
|
||||
last_iteration = i == self.num_microbatch * self.num_model_chunks - 1
|
||||
model_chunk_id = self.get_model_chunk_id(i, is_forward=True)
|
||||
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
|
||||
|
||||
if not last_iteration:
|
||||
input_obj = self.send_forward_recv_forward(
|
||||
model_chunk_id_send=model_chunk_id,
|
||||
model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=True),
|
||||
output_tensor=output_obj,
|
||||
send_prior=self.stage_manager.stage % 2 == 0,
|
||||
)
|
||||
else:
|
||||
self.send_forward(model_chunk_id, output_obj)
|
||||
|
||||
if outputs is not None:
|
||||
outputs = merge_batch(outputs)
|
||||
return {"loss": accum_loss, "outputs": outputs}
|
||||
|
||||
def run_forward_backward(
|
||||
self,
|
||||
model_chunk: Union[ModuleList, Module],
|
||||
data_iter: Iterable,
|
||||
criterion: Callable[..., Any],
|
||||
optimizer: Optional[OptimizerWrapper] = None,
|
||||
return_loss: bool = False,
|
||||
return_outputs: bool = False,
|
||||
) -> Dict:
|
||||
"""
|
||||
Runs interleaved schedule, with communication between pipeline stages.
|
||||
"""
|
||||
assert not self.forward_only
|
||||
|
||||
self.load_batch(data_iter)
|
||||
|
||||
num_microbatch = self.num_microbatch * self.num_model_chunks
|
||||
num_warmup_microbatch = (self.stage_manager.num_stages - self.stage_manager.stage - 1) * 2
|
||||
num_warmup_microbatch += (self.num_model_chunks - 1) * self.stage_manager.num_stages
|
||||
num_warmup_microbatch = min(num_warmup_microbatch, num_microbatch)
|
||||
num_microbatch_remaining = num_microbatch - num_warmup_microbatch
|
||||
|
||||
# Input, output tensors only need to be saved when doing backward passes
|
||||
input_objs = [[] for _ in range(self.num_model_chunks)]
|
||||
output_objs = [[] for _ in range(self.num_model_chunks)]
|
||||
|
||||
outputs = [] if return_outputs and self.stage_manager.is_last_stage(ignore_chunk=True) else None
|
||||
|
||||
accum_loss = None
|
||||
if return_loss and self.stage_manager.is_last_stage(ignore_chunk=True):
|
||||
accum_loss = torch.scalar_tensor(0, device=get_current_device())
|
||||
|
||||
model_chunk_id = self.get_model_chunk_id(0, is_forward=True)
|
||||
input_obj = self.recv_forward(model_chunk_id)
|
||||
# Run warmup forward passes.
|
||||
for i in range(num_warmup_microbatch):
|
||||
last_iteration = i == num_warmup_microbatch - 1
|
||||
model_chunk_id = self.get_model_chunk_id(i, is_forward=True)
|
||||
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
|
||||
input_objs[model_chunk_id].append(input_obj)
|
||||
output_objs[model_chunk_id].append(output_obj)
|
||||
|
||||
if last_iteration and num_microbatch_remaining == 0:
|
||||
self.send_forward(model_chunk_id, output_obj)
|
||||
else:
|
||||
input_obj = self.send_forward_recv_forward(
|
||||
model_chunk_id_send=model_chunk_id,
|
||||
model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=True),
|
||||
output_tensor=output_obj,
|
||||
send_prior=self.stage_manager.stage % 2 == 0,
|
||||
)
|
||||
|
||||
if num_microbatch_remaining > 0:
|
||||
model_chunk_id = self.get_model_chunk_id(0, is_forward=False)
|
||||
output_obj_grad = self.recv_backward(model_chunk_id)
|
||||
|
||||
# Run 1F1B in steady state.
|
||||
for i in range(num_microbatch_remaining):
|
||||
last_iteration = i == num_microbatch_remaining - 1
|
||||
|
||||
model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True)
|
||||
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
|
||||
# Add input_obj and output_obj to end of list.
|
||||
input_objs[model_chunk_id].append(input_obj)
|
||||
output_objs[model_chunk_id].append(output_obj)
|
||||
|
||||
model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
|
||||
# Pop output_obj and output_obj from the start of the list for the backward pass.
|
||||
_input_obj = input_objs[model_chunk_id].pop(0)
|
||||
_output_obj = output_objs[model_chunk_id].pop(0)
|
||||
input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)
|
||||
|
||||
# NOTE: perform 2x communication for forward and backward
|
||||
def send_forward_recv_backward():
|
||||
if last_iteration and num_microbatch == num_microbatch_remaining:
|
||||
model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True)
|
||||
self.send_forward(model_chunk_id, output_obj)
|
||||
else:
|
||||
output_obj_grad = self.send_forward_recv_backward(
|
||||
model_chunk_id_send=self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True),
|
||||
model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=False),
|
||||
output_tensor=output_obj,
|
||||
send_prior_fallback=self.stage_manager.stage % 2 == 0,
|
||||
)
|
||||
return output_obj_grad
|
||||
|
||||
def send_backward_recv_forward():
|
||||
if last_iteration:
|
||||
model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
|
||||
self.send_backward(model_chunk_id, input_obj_grad)
|
||||
else:
|
||||
input_obj = self.send_backward_recv_forward(
|
||||
model_chunk_id_send=self.get_model_chunk_id(i, is_forward=False),
|
||||
model_chunk_id_recv=self.get_model_chunk_id(i + num_warmup_microbatch + 1, is_forward=True),
|
||||
input_tensor_grad=input_obj_grad,
|
||||
send_prior_fallback=self.stage_manager.stage % 2 == 0 and i > 0,
|
||||
)
|
||||
return input_obj
|
||||
|
||||
if self.stage_manager.stage % 2 == 0:
|
||||
output_obj_grad = send_forward_recv_backward()
|
||||
input_obj = send_backward_recv_forward()
|
||||
else:
|
||||
input_obj = send_backward_recv_forward()
|
||||
output_obj_grad = send_forward_recv_backward()
|
||||
|
||||
if num_microbatch_remaining == 0:
|
||||
model_chunk_id = self.get_model_chunk_id(0, is_forward=False)
|
||||
output_obj_grad = self.recv_backward(model_chunk_id)
|
||||
# Run cooldown backward passes.
|
||||
for i in range(num_microbatch_remaining, num_microbatch):
|
||||
last_iteration = i == num_microbatch - 1
|
||||
model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
|
||||
_input_obj = input_objs[model_chunk_id].pop(0)
|
||||
_output_obj = output_objs[model_chunk_id].pop(0)
|
||||
# output_obj_grad = self.recv_backward(model_chunk_id)
|
||||
input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)
|
||||
|
||||
if not last_iteration:
|
||||
output_obj_grad = self.send_backward_recv_backward(
|
||||
model_chunk_id_send=self.get_model_chunk_id(i, is_forward=False),
|
||||
model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=False),
|
||||
input_tensor_grad=input_obj_grad,
|
||||
send_prior=self.stage_manager.stage % 2 == 0 and i > num_microbatch_remaining,
|
||||
)
|
||||
else:
|
||||
model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
|
||||
self.send_backward(model_chunk_id, input_obj_grad)
|
||||
|
||||
assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)
|
||||
|
||||
if outputs is not None:
|
||||
outputs = merge_batch(outputs)
|
||||
return {"loss": accum_loss, "outputs": outputs}
|
||||
|
||||
def forward_backward_step(
|
||||
self,
|
||||
model_chunk: Module,
|
||||
model_chunk: Union[ModuleList, Module],
|
||||
data_iter: Iterable,
|
||||
criterion: Callable[..., Any],
|
||||
optimizer: Optional[OptimizerWrapper] = None,
|
||||
return_loss: bool = False,
|
||||
return_outputs: bool = False,
|
||||
) -> dict:
|
||||
"""Runs interleaved 1F1B schedule, with communication between pipeline stages.
|
||||
|
||||
"""
|
||||
Args:
|
||||
model_chunk (List[Module]): Model Chunk to be trained.
|
||||
model_chunk (ModuleList or Module): Model Chunk to be trained. Original interleaved uses a module list whereas shardformer uses entire model + layer specification
|
||||
data_iter (Iterable): Data iterator.
|
||||
criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
|
||||
optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
|
||||
@@ -263,118 +558,15 @@ class InterleavedSchedule(PipelineSchedule):
|
||||
Returns:
|
||||
dict: A dict with keys: 'loss' and 'outputs'.
|
||||
"""
|
||||
forward_only = not torch.is_grad_enabled()
|
||||
self.forward_only = not torch.is_grad_enabled()
|
||||
if optimizer is None:
|
||||
assert forward_only, "Optimizer should be passed when doing backward."
|
||||
assert self.forward_only, "Optimizer should be passed when doing backward."
|
||||
|
||||
self.load_batch(data_iter)
|
||||
num_model_chunks = len(model_chunk)
|
||||
|
||||
# num_warmup_microbatches is the step when not all the processes are working
|
||||
num_microbatches = self.num_microbatches * num_model_chunks
|
||||
if forward_only:
|
||||
num_warmup_microbatches = num_microbatches
|
||||
if self.forward_only:
|
||||
result = self.run_forward_only(model_chunk, data_iter, criterion, return_loss, return_outputs)
|
||||
else:
|
||||
num_warmup_microbatches = (self.stage_manager.num_stages - self.stage_manager.stage - 1) * 2
|
||||
num_warmup_microbatches += (num_model_chunks - 1) * self.stage_manager.num_stages
|
||||
num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
|
||||
result = self.run_forward_backward(
|
||||
model_chunk, data_iter, criterion, optimizer, return_loss, return_outputs
|
||||
)
|
||||
|
||||
num_microbatches_remaining = num_microbatches - num_warmup_microbatches
|
||||
|
||||
# Input, output tensors only need to be saved when doing backward passes
|
||||
input_objs = None
|
||||
output_objs = None
|
||||
|
||||
if not forward_only:
|
||||
input_objs = [[] for _ in range(num_model_chunks)]
|
||||
output_objs = [[] for _ in range(num_model_chunks)]
|
||||
|
||||
outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None
|
||||
|
||||
if return_loss and self.stage_manager.is_last_stage():
|
||||
accum_loss = torch.zeros(1, device=get_accelerator().get_current_device())
|
||||
else:
|
||||
accum_loss = None
|
||||
|
||||
# for ranks except the first one, get into recv state
|
||||
# print(self.stage_manager.stage,num_microbatches, num_warmup_microbatches, num_microbatches_remaining)
|
||||
input_obj = self.recv_forward(0)
|
||||
input_objs[0].append(input_obj)
|
||||
# Run warmup forward passes.
|
||||
for i in range(num_warmup_microbatches):
|
||||
model_chunk_id = self.get_model_chunk_id(i, forward=True)
|
||||
|
||||
# recv first on first rank to avoid sending or recving at the same time
|
||||
if self.stage_manager.is_first_stage():
|
||||
input_obj = self.recv_forward(model_chunk_id)
|
||||
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
|
||||
self.send_forward(model_chunk_id, output_obj)
|
||||
if not forward_only:
|
||||
input_objs[model_chunk_id].append(input_obj)
|
||||
output_objs[model_chunk_id].append(output_obj)
|
||||
else:
|
||||
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
|
||||
if not forward_only:
|
||||
output_objs[model_chunk_id].append(output_obj)
|
||||
self.send_forward(model_chunk_id, output_obj)
|
||||
if num_microbatches_remaining == 0 and i + 1 == num_warmup_microbatches:
|
||||
break
|
||||
else:
|
||||
model_chunk_id = self.get_model_chunk_id(i + 1, forward=True)
|
||||
|
||||
input_obj = self.recv_forward(model_chunk_id)
|
||||
if not forward_only:
|
||||
input_objs[model_chunk_id].append(input_obj)
|
||||
|
||||
# Run 1F1B in steady state.
|
||||
for i in range(num_microbatches_remaining):
|
||||
model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatches, forward=True)
|
||||
last_iteration = i == (num_microbatches_remaining - 1)
|
||||
|
||||
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
|
||||
if forward_only:
|
||||
self.send_forward(model_chunk_id, output_obj)
|
||||
|
||||
if not last_iteration:
|
||||
input_obj = self.recv_forward(model_chunk_id)
|
||||
|
||||
else:
|
||||
self.send_forward(model_chunk_id, output_obj)
|
||||
# Add input_obj and output_obj to end of list.
|
||||
input_objs[model_chunk_id].append(input_obj)
|
||||
output_objs[model_chunk_id].append(output_obj)
|
||||
|
||||
model_chunk_id = self.get_model_chunk_id(i, forward=False)
|
||||
output_obj_grad = self.recv_backward(model_chunk_id)
|
||||
|
||||
# Pop output_obj and output_obj from the start of the list for
|
||||
# the backward pass.
|
||||
input_obj = input_objs[model_chunk_id].pop(0)
|
||||
output_obj = output_objs[model_chunk_id].pop(0)
|
||||
|
||||
# backward
|
||||
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
|
||||
|
||||
if last_iteration:
|
||||
input_obj = None
|
||||
else:
|
||||
model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatches + 1, forward=True)
|
||||
input_obj = self.recv_forward(model_chunk_id)
|
||||
model_chunk_id = self.get_model_chunk_id(i, forward=False)
|
||||
self.send_backward(model_chunk_id, input_obj_grad)
|
||||
|
||||
# Run cooldown backward passes.
|
||||
if not forward_only:
|
||||
for i in range(num_microbatches_remaining, num_microbatches):
|
||||
model_chunk_id = self.get_model_chunk_id(i, forward=False)
|
||||
# print(f"{self.stage_manager.stage}/{model_chunk_id}: {len(input_objs[model_chunk_id])} {len(output_objs[model_chunk_id])} {i}")
|
||||
input_obj = input_objs[model_chunk_id].pop(0)
|
||||
output_obj = output_objs[model_chunk_id].pop(0)
|
||||
|
||||
output_obj_grad = self.recv_backward(model_chunk_id)
|
||||
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
|
||||
self.send_backward(model_chunk_id, input_obj_grad)
|
||||
|
||||
if outputs is not None:
|
||||
outputs = merge_batch(outputs)
|
||||
return {"loss": accum_loss, "outputs": outputs}
|
||||
return result
|
||||
|
@@ -1,5 +1,5 @@
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Iterable, List, Optional, Union
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.cuda
|
||||
@@ -8,7 +8,7 @@ from torch.utils._pytree import tree_map
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
||||
from colossalai.pipeline.p2p import PipelineP2PCommunication
|
||||
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
|
||||
from ._utils import (
|
||||
@@ -30,6 +30,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
||||
stage_manager: PipelineStageManager,
|
||||
num_microbatches: Optional[int] = None,
|
||||
microbatch_size: Optional[int] = None,
|
||||
enable_metadata_cache: bool = True,
|
||||
) -> None:
|
||||
"""1F1B pipeline schedule.
|
||||
|
||||
@@ -42,13 +43,21 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
||||
assert (
|
||||
num_microbatches is not None or microbatch_size is not None
|
||||
), "Either num_microbatches or microbatch_size should be provided"
|
||||
|
||||
self.comm = PipelineP2PCommunication(stage_manager)
|
||||
self.num_microbatches = num_microbatches
|
||||
self.microbatch_size = microbatch_size
|
||||
self.batch: Optional[Any] = None
|
||||
self.batch_size: Optional[int] = None
|
||||
self.last_batch_size: Optional[int] = None
|
||||
self.microbatch_offset: Optional[int] = None
|
||||
self._use_microbatch_size = num_microbatches is None
|
||||
|
||||
# P2PMeta cache
|
||||
self.enable_metadata_cache = enable_metadata_cache
|
||||
self.send_tensor_metadata = True
|
||||
self.send_grad_metadata = True
|
||||
self.tensor_metadata_recv = None
|
||||
self.grad_metadata_recv = None
|
||||
|
||||
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
|
||||
"""Load a batch from data iterator.
|
||||
@@ -60,24 +69,45 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
||||
batch = next(data_iter)
|
||||
if device is not None:
|
||||
batch = tree_map(partial(to_device, device=device), batch)
|
||||
|
||||
self.microbatch_offset = 0
|
||||
self.batch = batch
|
||||
self.batch_size = get_batch_size(batch)
|
||||
self.microbatch_offset = 0
|
||||
if not self._use_microbatch_size:
|
||||
assert (
|
||||
self.batch_size % self.num_microbatches == 0
|
||||
), "Batch size should divided by the number of microbatches"
|
||||
|
||||
if self.microbatch_size is None:
|
||||
assert self.batch_size % self.num_microbatches == 0, "Batch size should divided by # microbatches"
|
||||
self.microbatch_size = self.batch_size // self.num_microbatches
|
||||
else:
|
||||
if self.num_microbatches is None:
|
||||
assert self.batch_size % self.microbatch_size == 0, "Batch size should divided by the microbatch size"
|
||||
self.num_microbatches = self.batch_size // self.microbatch_size
|
||||
|
||||
if not self.forward_only:
|
||||
assert self.last_batch_size is None or self.last_batch_size == self.batch_size
|
||||
assert self.batch_size == self.microbatch_size * self.num_microbatches
|
||||
|
||||
assert (
|
||||
self.num_microbatches >= self.stage_manager.num_stages
|
||||
), "Number of microbatch should be larger than number of stages"
|
||||
|
||||
if self.forward_only:
|
||||
self.num_microbatches = (self.batch_size - 1) // self.microbatch_size + 1
|
||||
# NOTE: disable metadata cache when batch size changes (not valid anymore)
|
||||
if self.batch_size != self.last_batch_size:
|
||||
self.enable_metadata_cache = False
|
||||
self.send_tensor_metadata = True
|
||||
self.send_grad_metadata = True
|
||||
self.tensor_metadata_recv = None
|
||||
self.grad_metadata_recv = None
|
||||
|
||||
self.last_batch_size = self.batch_size
|
||||
|
||||
def load_micro_batch(self) -> Any:
|
||||
"""Load a micro batch from the current batch.
|
||||
|
||||
Returns:
|
||||
Any: Micro batch.
|
||||
"""
|
||||
assert self.microbatch_offset <= self.batch_size, "Microbatches exhausted"
|
||||
micro_batch = get_micro_batch(self.batch, self.microbatch_offset, self.microbatch_size)
|
||||
self.microbatch_offset += self.microbatch_size
|
||||
return tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch)
|
||||
@@ -92,12 +122,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
||||
Returns:
|
||||
Any: The input tensor or input tensor list.
|
||||
"""
|
||||
if self.stage_manager.is_first_stage():
|
||||
input_tensor = None
|
||||
else:
|
||||
input_tensor = self.comm.recv_forward(prev_rank)
|
||||
if not self.stage_manager.is_first_stage():
|
||||
input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv)
|
||||
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
|
||||
self.tensor_metadata_recv = create_send_metadata(input_tensor)
|
||||
|
||||
return input_tensor
|
||||
return input_tensor
|
||||
|
||||
def recv_backward(self, next_rank: int = None) -> Any:
|
||||
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
|
||||
@@ -109,14 +139,14 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
||||
Returns:
|
||||
Any: The input gradient tensor or gradient tensor list.
|
||||
"""
|
||||
if self.stage_manager.is_last_stage():
|
||||
output_tensor_grad = None
|
||||
else:
|
||||
output_tensor_grad = self.comm.recv_backward(next_rank)
|
||||
if not self.stage_manager.is_last_stage():
|
||||
output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv)
|
||||
if self.enable_metadata_cache and self.grad_metadata_recv is None:
|
||||
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
|
||||
|
||||
return output_tensor_grad
|
||||
return output_tensor_grad
|
||||
|
||||
def send_forward(self, output_object: Any, next_rank: int = None) -> None:
|
||||
def send_forward(self, output_tensor: Any, next_rank: int = None) -> None:
|
||||
"""Sends the input tensor to the next stage in pipeline.
|
||||
For 1F1B.
|
||||
|
||||
@@ -125,20 +155,10 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
||||
next_rank (int, optional): The rank of the recipient of the tensor.
|
||||
"""
|
||||
if not self.stage_manager.is_last_stage():
|
||||
self.comm.send_forward(output_object, next_rank)
|
||||
self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata)
|
||||
self.send_tensor_metadata = not self.enable_metadata_cache
|
||||
|
||||
def send_forward_recv_backward(self, output_object: Any, next_rank: int = None) -> Any:
|
||||
"""Sends the input tensor to the next stage and copy the gradient tensor from the next stage in pipeline.
|
||||
For 1F1B.
|
||||
|
||||
Args:
|
||||
output_object (Any): Object to be sent.
|
||||
next_rank (int, optional): The rank of the recipient of the tensor.
|
||||
"""
|
||||
if not self.stage_manager.is_last_stage():
|
||||
return self.comm.send_forward_recv_backward(output_object, next_rank)
|
||||
|
||||
def send_backward(self, input_object: Any, prev_rank: int = None) -> None:
|
||||
def send_backward(self, input_tensor_grad: Any, prev_rank: int = None) -> None:
|
||||
"""Sends the gradient tensor to the previous stage in pipeline.
|
||||
For 1F1B.
|
||||
|
||||
@@ -147,9 +167,38 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
||||
prev_rank (int, optional): The rank of the recipient of the tensor
|
||||
"""
|
||||
if not self.stage_manager.is_first_stage():
|
||||
self.comm.send_backward(input_object, prev_rank)
|
||||
self.comm.send_backward(input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata)
|
||||
self.send_grad_metadata = not self.enable_metadata_cache
|
||||
|
||||
def send_backward_recv_forward(self, output_object: Any, prev_rank: int = None) -> Any:
|
||||
def send_forward_recv_backward(
|
||||
self, output_tensor: Any, next_rank: int = None, send_prior_fallback: Optional[bool] = None
|
||||
) -> Any:
|
||||
"""Sends the input tensor to the next stage and copy the gradient tensor from the next stage in pipeline.
|
||||
For 1F1B.
|
||||
|
||||
Args:
|
||||
output_object (Any): Object to be sent.
|
||||
next_rank (int, optional): The rank of the recipient of the tensor.
|
||||
"""
|
||||
if not self.stage_manager.is_last_stage():
|
||||
if not self.send_tensor_metadata and self.grad_metadata_recv is not None:
|
||||
send_prior_fallback = None # must not fallback
|
||||
output_tensor_grad = self.comm.send_forward_recv_backward(
|
||||
output_tensor,
|
||||
next_rank,
|
||||
send_metadata=self.send_tensor_metadata,
|
||||
metadata_recv=self.grad_metadata_recv,
|
||||
send_prior_fallback=send_prior_fallback,
|
||||
)
|
||||
self.send_tensor_metadata = not self.enable_metadata_cache
|
||||
if self.enable_metadata_cache and self.grad_metadata_recv is None:
|
||||
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
|
||||
|
||||
return output_tensor_grad
|
||||
|
||||
def send_backward_recv_forward(
|
||||
self, input_tensor_grad: Any, prev_rank: int = None, send_prior_fallback: Optional[bool] = None
|
||||
) -> Any:
|
||||
"""Sends the gradient tensor to the previous stage and copy the input tensor from the previous stage in pipeline.
|
||||
For 1F1B.
|
||||
|
||||
@@ -158,23 +207,20 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
||||
prev_rank (int, optional): The rank of the recipient of the tensor.
|
||||
"""
|
||||
if not self.stage_manager.is_first_stage():
|
||||
return self.comm.send_backward_recv_forward(output_object, prev_rank)
|
||||
if not self.send_grad_metadata and self.tensor_metadata_recv is not None:
|
||||
send_prior_fallback = None # must not fallback
|
||||
input_tensor = self.comm.send_backward_recv_forward(
|
||||
input_tensor_grad,
|
||||
prev_rank,
|
||||
send_metadata=self.send_grad_metadata,
|
||||
metadata_recv=self.tensor_metadata_recv,
|
||||
send_prior_fallback=send_prior_fallback,
|
||||
)
|
||||
self.send_grad_metadata = not self.enable_metadata_cache
|
||||
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
|
||||
self.tensor_metadata_recv = create_send_metadata(input_tensor)
|
||||
|
||||
def send_forward_recv_forward(self, input_object: Any, prev_rank: int = None, next_rank: int = None) -> Any:
|
||||
"""Sends the input tensor to the next stage and copy the input tensor from the previous stage in pipeline.
|
||||
For 1F1B.
|
||||
|
||||
Args:
|
||||
input_object (Any): Object to be sent.
|
||||
prev_rank (int, optional): The previous rank of the recipient of the tensor.
|
||||
next_rank (int, optional): The next rank of the recipient of the tensor.
|
||||
"""
|
||||
if self.stage_manager.is_first_stage():
|
||||
return self.comm.send_forward(input_object, next_rank)
|
||||
elif self.stage_manager.is_last_stage():
|
||||
return self.comm.recv_forward(prev_rank)
|
||||
else:
|
||||
return self.comm.send_forward_recv_forward(input_object, prev_rank, next_rank)
|
||||
return input_tensor
|
||||
|
||||
def forward_step(
|
||||
self,
|
||||
@@ -254,7 +300,38 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
||||
input_obj_grad[k] = v.grad
|
||||
return input_obj_grad
|
||||
|
||||
def forward_backward_step(
|
||||
def run_forward_only(
|
||||
self,
|
||||
model: Module,
|
||||
data_iter: Iterable,
|
||||
criterion: Callable[..., Any],
|
||||
return_loss: bool = False,
|
||||
return_outputs: bool = False,
|
||||
) -> Dict:
|
||||
"""
|
||||
Runs forward only schedule, with communication between pipeline stages.
|
||||
"""
|
||||
assert self.forward_only
|
||||
|
||||
self.load_batch(data_iter)
|
||||
|
||||
accum_loss = None
|
||||
if return_loss and self.stage_manager.is_last_stage():
|
||||
accum_loss = torch.scalar_tensor(0, device=get_accelerator().get_current_device())
|
||||
outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None
|
||||
|
||||
for _ in range(self.num_microbatches):
|
||||
input_obj = self.recv_forward()
|
||||
output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
|
||||
self.send_forward(output_obj)
|
||||
|
||||
if outputs is not None:
|
||||
if isinstance(model, ModelWrapper):
|
||||
model = model.unwrap()
|
||||
outputs = merge_batch(outputs, getattr(model, "batch_size_dim", 0))
|
||||
return {"loss": accum_loss, "outputs": outputs}
|
||||
|
||||
def run_forward_backward(
|
||||
self,
|
||||
model: Module,
|
||||
data_iter: Iterable,
|
||||
@@ -262,23 +339,11 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
||||
optimizer: Optional[OptimizerWrapper] = None,
|
||||
return_loss: bool = False,
|
||||
return_outputs: bool = False,
|
||||
) -> dict:
|
||||
"""Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
|
||||
|
||||
Args:
|
||||
model (Module): Model to be trained.
|
||||
data_iter (Iterable): Data iterator.
|
||||
criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
|
||||
optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
|
||||
return_loss (bool, optional): Whether to return loss. Defaults to False. Whether to return loss.
|
||||
return_outputs (bool, optional): Whether to return model outputs. Defaults to False. Whether to return model outputs.
|
||||
|
||||
Returns:
|
||||
dict: A dict with keys: 'loss' and 'outputs'.
|
||||
) -> Dict:
|
||||
"""
|
||||
forward_only = not torch.is_grad_enabled()
|
||||
if optimizer is None:
|
||||
assert forward_only, "Optimizer should be passed when doing backward."
|
||||
Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
|
||||
"""
|
||||
assert not self.forward_only
|
||||
|
||||
self.load_batch(data_iter)
|
||||
|
||||
@@ -288,30 +353,20 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
||||
num_microbatches_remaining = self.num_microbatches - num_warmup_microbatches
|
||||
|
||||
# Input, output tensors only need to be saved when doing backward passes
|
||||
input_objs = None
|
||||
output_objs = None
|
||||
input_objs, output_objs = [], []
|
||||
|
||||
if not forward_only:
|
||||
input_objs = []
|
||||
output_objs = []
|
||||
|
||||
outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None
|
||||
accum_loss = None
|
||||
if return_loss and self.stage_manager.is_last_stage():
|
||||
accum_loss = torch.zeros(1, device=get_accelerator().get_current_device())
|
||||
else:
|
||||
accum_loss = None
|
||||
accum_loss = torch.scalar_tensor(0, device=get_current_device())
|
||||
outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None
|
||||
|
||||
# Run warmup forward passes.
|
||||
for i in range(num_warmup_microbatches):
|
||||
input_obj = self.recv_forward()
|
||||
|
||||
output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
|
||||
|
||||
self.send_forward(output_obj)
|
||||
|
||||
if not forward_only:
|
||||
input_objs.append(input_obj)
|
||||
output_objs.append(output_obj)
|
||||
input_objs.append(input_obj)
|
||||
output_objs.append(output_obj)
|
||||
|
||||
# Before running 1F1B, need to receive first forward tensor.
|
||||
# If all microbatches are run in warmup / cooldown phase, then no need to
|
||||
@@ -324,44 +379,72 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
||||
last_iteration = i == (num_microbatches_remaining - 1)
|
||||
|
||||
output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
|
||||
if forward_only:
|
||||
self.send_forward(output_obj)
|
||||
output_obj_grad = self.send_forward_recv_backward(
|
||||
output_obj, send_prior_fallback=self.stage_manager.stage % 2 == 0
|
||||
)
|
||||
# Add input_obj and output_obj to end of list.
|
||||
input_objs.append(input_obj)
|
||||
output_objs.append(output_obj)
|
||||
|
||||
if not last_iteration:
|
||||
input_obj = self.recv_forward()
|
||||
else:
|
||||
# TODO adjust here
|
||||
self.send_forward(output_obj)
|
||||
output_obj_grad = self.recv_backward()
|
||||
# Pop output_obj and output_obj from the start of the list for
|
||||
# the backward pass.
|
||||
input_obj = input_objs.pop(0)
|
||||
output_obj = output_objs.pop(0)
|
||||
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
|
||||
|
||||
# Add input_obj and output_obj to end of list.
|
||||
input_objs.append(input_obj)
|
||||
output_objs.append(output_obj)
|
||||
|
||||
# Pop output_obj and output_obj from the start of the list for
|
||||
# the backward pass.
|
||||
input_obj = input_objs.pop(0)
|
||||
output_obj = output_objs.pop(0)
|
||||
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
|
||||
|
||||
if last_iteration:
|
||||
input_obj = None
|
||||
else:
|
||||
input_obj = self.recv_forward()
|
||||
if last_iteration:
|
||||
self.send_backward(input_obj_grad)
|
||||
else:
|
||||
input_obj = self.send_backward_recv_forward(
|
||||
input_obj_grad, send_prior_fallback=self.stage_manager.stage % 2 == 0
|
||||
)
|
||||
|
||||
# Run cooldown backward passes.
|
||||
if not forward_only:
|
||||
for i in range(num_warmup_microbatches):
|
||||
input_obj = input_objs.pop(0)
|
||||
output_obj = output_objs.pop(0)
|
||||
for i in range(num_warmup_microbatches):
|
||||
input_obj = input_objs.pop(0)
|
||||
output_obj = output_objs.pop(0)
|
||||
|
||||
output_obj_grad = self.recv_backward()
|
||||
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
|
||||
self.send_backward(input_obj_grad)
|
||||
output_obj_grad = self.recv_backward()
|
||||
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
|
||||
self.send_backward(input_obj_grad)
|
||||
|
||||
assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)
|
||||
|
||||
if outputs is not None:
|
||||
if isinstance(model, ModelWrapper):
|
||||
model = model.unwrap()
|
||||
outputs = merge_batch(outputs, getattr(model, "batch_size_dim", 0))
|
||||
return {"loss": accum_loss, "outputs": outputs}
|
||||
|
||||
def forward_backward_step(
|
||||
self,
|
||||
model: Module,
|
||||
data_iter: Iterable,
|
||||
criterion: Callable[..., Any],
|
||||
optimizer: Optional[OptimizerWrapper] = None,
|
||||
return_loss: bool = False,
|
||||
return_outputs: bool = False,
|
||||
) -> dict:
|
||||
"""
|
||||
Args:
|
||||
model (Module): Model to be trained.
|
||||
data_iter (Iterable): Data iterator.
|
||||
criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
|
||||
optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
|
||||
return_loss (bool, optional): Whether to return loss. Defaults to False. Whether to return loss.
|
||||
return_outputs (bool, optional): Whether to return model outputs. Defaults to False. Whether to return model outputs.
|
||||
|
||||
Returns:
|
||||
dict: Dictionary containing loss and outputs.
|
||||
"""
|
||||
|
||||
self.forward_only = not torch.is_grad_enabled()
|
||||
if optimizer is None:
|
||||
assert self.forward_only, "Optimizer should be passed when doing backward."
|
||||
|
||||
if self.forward_only:
|
||||
result = self.run_forward_only(model, data_iter, criterion, return_loss, return_outputs)
|
||||
else:
|
||||
result = self.run_forward_backward(model, data_iter, criterion, optimizer, return_loss, return_outputs)
|
||||
|
||||
return result
|
||||
|
Reference in New Issue
Block a user