ColossalAI/colossalai/pipeline/schedule/one_f_one_b.py
Wang Binluo eea37da6fa
[fp8] Merge feature/fp8_comm to main branch of Colossalai (#6016)
* add SimPO

* fix dataloader

* remove debug code

* add orpo

* fix style

* fix colossalai, transformers version

* fix colossalai, transformers version

* fix colossalai, transformers version

* fix torch colossalai version

* update transformers version

* [shardformer] DeepseekMoE support (#5871)

* [Feature] deepseek moe expert parallel implement

* [misc] fix typo, remove redundant file (#5867)

* [misc] fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] deepseek support & unit test

* [misc] remove debug code & useless print

* [misc] fix typos (#5872)

* [Feature] remove modeling file, use auto config. (#5884)

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [Deepseek] remove redundant code (#5888)

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [misc] remove redundant code

* [Feature/deepseek] resolve comment. (#5889)

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [misc] remove redundant code

* [misc] mv module replacement into if branch

* [misc] add some warning message and modify some code in unit test

* [misc] fix typos

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Hoxfix] Fix CUDA_DEVICE_MAX_CONNECTIONS for comm overlap

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support (#5838)

* Diffusion Model Inference support

* Stable Diffusion 3 Support

* pixartalpha support

* [HotFix] CI,import,requirements-test for #5838 (#5892)

* [Hot Fix] CI,import,requirements-test

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] Enable PP + SP for llama (#5868)

* fix cross-PP-stage position id length diff bug

* fix typo

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* use a one cross entropy func for all shardformer models

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [ShardFormer] Add Ulysses Sequence Parallelism support for Command-R, Qwen2 and ChatGLM (#5897)

* add benchmark for sft, dpo, simpo, orpo. Add benchmarking result. Support lora with gradient checkpoint

* fix style

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix eval

* hotfix citation

* [zero] support all-gather overlap (#5898)

* [zero] support all-gather overlap

* [zero] add overlap all-gather flag

* [misc] fix typo

* [zero] update api

* fix orpo cross entropy loss

* [Auto Parallel]: Speed up intra-op plan generation by 44% (#5446)

* Remove unnecessary calls to deepcopy

* Build DimSpec's difference dict only once

This change considerably speeds up construction speed of DimSpec objects. The difference_dict is the same for each DimSpec object, so a single copy of it is enough.

* Fix documentation of DimSpec's difference method

* [ShardFormer] fix qwen2 sp (#5903)

* [compatibility] support torch 2.2 (#5875)

* Support Pytorch 2.2.2

* keep build_on_pr file and update .compatibility

* fix object_to_tensor usage when torch>=2.3.0 (#5820)

* [misc] support torch2.3 (#5893)

* [misc] support torch2.3

* [devops] update compatibility ci

* [devops] update compatibility ci

* [devops] add debug

* [devops] add debug

* [devops] add debug

* [devops] add debug

* [devops] remove debug

* [devops] remove debug

* [release] update version (#5912)

* [plugin] support all-gather overlap for hybrid parallel (#5919)

* [plugin] fixed all-gather overlap support for hybrid parallel

* add kto

* fix style, add kto data sample

* [Examples] Add lazy init to OPT and GPT examples (#5924)

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [ColossalChat] Hotfix for ColossalChat (#5910)

* add ignore and tiny llama

* fix path issue

* run style

* fix issue

* update bash

* add ignore and tiny llama

* fix path issue

* run style

* fix issue

* update bash

* fix ddp issue

* add Qwen 1.5 32B

* refactor tokenization

* [FIX BUG] UnboundLocalError: cannot access local variable 'default_conversation' where it is not associated with a value (#5931)

* cannot access local variable 'default_conversation' where it is not associated with a value

set default value for 'default_conversation'

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* fix test data

* refactor evaluation

* remove real data path

* remove real data path

* Add n_fused as an input from native_module (#5894)

* [FIX BUG] convert env param to int in (#5934)

* [Hotfix] Fix ZeRO typo #5936

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [Feature] Add a switch to control whether the model checkpoint needs to be saved after each epoch ends (#5941)

* Add a switch to control whether the model checkpoint needs to be saved after each epoch ends

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* fix style

* fix style

* fix style

* [shardformer] hotfix attn mask (#5945)

* [shardformer] hotfix attn mask (#5947)

* [Feat] Distrifusion Acceleration Support for Diffusion Inference (#5895)

* Distrifusion Support source

* comp comm overlap optimization

* sd3 benchmark

* pixart distrifusion bug fix

* sd3 bug fix and benchmark

* generation bug fix

* naming fix

* add docstring, fix counter and shape error

* add reference

* readme and requirement

* [zero] hotfix update master params (#5951)

* [release] update version (#5952)

* [Chat] Fix lora (#5946)

* fix merging

* remove filepath

* fix style

* Update README.md (#5958)

* [hotfix] Remove unused plan section (#5957)

* remove readme

* fix readme

* update

* [test] add mixtral for sequence classification

* [test] add mixtral transformer test

* [moe] fix plugin

* [test] mixtra pp shard test

* [chore] handle non member group

* [zero] solve hang

* [test] pass mixtral shardformer test

* [moe] implement transit between non moe tp and ep

* [zero] solve hang

* [misc] solve booster hang by rename the variable

* solve hang when parallel mode = pp + dp

* [moe] implement submesh initialization

* [moe] add mixtral dp grad scaling when not all experts are activated

* [chore] manually revert unintended commit

* [chore] trivial fix

* [chore] arg pass & remove drop token

* [test] add mixtral modelling test

* [moe] implement tp

* [moe] test deepseek

* [moe] clean legacy code

* [Feature] MoE Ulysses Support (#5918)

* moe sp support

* moe sp bug solve

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [chore] minor fix

* [moe] init moe plugin comm setting with sp

* moe sp + ep bug fix

* [moe] finalize test (no pp)

* [moe] full test for deepseek and mixtral (pp + sp to fix)

* [chore] minor fix after rebase

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [chore] solve moe ckpt test failure and some other arg pass failure

* [moe] remove ops

* [test] fix test: test_zero1_2

* [bug] fix: somehow logger hangs the program

* [moe] deepseek moe sp support

* [test] add check

* [deepseek] replace attn (a workaround for bug in transformers)

* [misc] skip redunant test

* [misc] remove debug/print code

* [moe] refactor mesh assignment

* Revert "[moe] implement submesh initialization"

This reverts commit 2f9bce6686.

* [chore] change moe_pg_mesh to private

* [misc] remove incompatible test config

* [misc] fix ci failure: change default value to false in moe plugin

* [misc] remove useless condition

* [chore] docstring

* [moe] remove force_overlap_comm flag and add warning instead

* [doc] add MoeHybridParallelPlugin docstring

* [moe] solve dp axis issue

* [chore] remove redundant test case, print string & reduce test tokens

* [feat] Dist Loader for Eval (#5950)

* support auto distributed data loader

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* support auto distributed data loader

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix tp error

* remove unused parameters

* remove unused

* update inference

* update docs

* update inference

---------

Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [lora] lora support hybrid parallel plugin (#5956)

* lora support hybrid plugin

* fix

* fix

* fix

* fix

* Support overall loss, update KTO logging

* [Docs] clarify launch port

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [Hotfix] README link (#5966)

* update ignore

* update readme

* run style

* update readme

* [Hotfix] Avoid fused RMSnorm import error without apex (#5985)

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [Chat] fix readme (#5989)

* fix readme

* fix readme, tokenization fully tested

* fix readme, tokenization fully tested

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: root <root@notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9-0.notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9.colossal-ai.svc.cluster.local>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* fix sync condition (#6000)

* [plugin] add cast inputs option for zero (#6003)

* [pre-commit.ci] pre-commit autoupdate (#5995)

updates:
- [github.com/psf/black-pre-commit-mirror: 24.4.2 → 24.8.0](https://github.com/psf/black-pre-commit-mirror/compare/24.4.2...24.8.0)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [misc] Bypass the huggingface bug to solve the mask mismatch problem (#5991)

* [Feature] Zigzag Ring attention (#5905)

* halfway

* fix cross-PP-stage position id length diff bug

* fix typo

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* unified cross entropy func for all shardformer models

* remove redundant lines

* add basic ring attn; debug cross entropy

* fwd bwd logic complete

* fwd bwd logic complete; add experimental triton rescale

* precision tests passed

* precision tests passed

* fix typos and remove misc files

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add sp_mode to benchmark; fix varlen interface

* update softmax_lse shape by new interface

* change tester name

* remove buffer clone; support packed seq layout

* add varlen tests

* fix typo

* all tests passed

* add dkv_group; fix mask

* remove debug statements

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [misc] update compatibility (#6008)

* [misc] update compatibility

* [misc] update requirements

* [devops] disable requirements cache

* [test] fix torch ddp test

* [test] fix rerun on address in use

* [test] fix lazy init

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix the merge

* fix the merge

* overlap kv comm with output rescale (#6017)

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* fix the merge

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix the merge

* fix

* fix

* fix the merge

* fix

* [misc] Use dist logger in plugins (#6011)

* use dist logger in plugins

* remove trash

* print on rank 0

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* fix

* fix

* fix

* fix

* fix the merge

* fix

* fix

* fix

* fix

---------

Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Haze188 <haze188@qq.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: Guangyao Zhang <xjtu521@qq.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Stephan Kö <stephankoe@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: zhurunhua <1281592874@qq.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: Gao, Ruiyuan <905370712@qq.com>
Co-authored-by: hxwang <wang1570@e.ntu.edu.sg>
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: root <root@notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9-0.notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9.colossal-ai.svc.cluster.local>
2024-08-22 09:21:34 +08:00

475 lines
20 KiB
Python

from functools import partial
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
import torch
import torch.cuda
from torch.nn import Module
from torch.utils._pytree import tree_map
from colossalai.accelerator import get_accelerator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.quantization.fp8 import cast_from_fp8_pipeline, cast_to_fp8_pipeline
from colossalai.utils import get_current_device
from ._utils import (
detach,
get_batch_size,
get_micro_batch,
merge_batch,
model_forward,
retain_grad,
to_device,
tree_map_hf,
)
from .base import PipelineSchedule
class OneForwardOneBackwardSchedule(PipelineSchedule):
def __init__(
self,
stage_manager: PipelineStageManager,
num_microbatches: Optional[int] = None,
microbatch_size: Optional[int] = None,
enable_metadata_cache: bool = True,
fp8_communication: bool = False,
) -> None:
"""1F1B pipeline schedule.
Args:
stage_manager (PipelineStageManager): Pipeline stage manager
num_microbatches (Optional[int], optional): The number of microbatches. If not provided, it will be derived from microbatch size. Defaults to None.
microbatch_size (Optional[int], optional): Microbatch size. If num_microbatches is provided, this will be ignored. Defaults to None.
"""
super().__init__(stage_manager)
assert (
num_microbatches is not None or microbatch_size is not None
), "Either num_microbatches or microbatch_size should be provided"
self.comm = PipelineP2PCommunication(stage_manager, overlap_p2p=False)
self.num_microbatches = num_microbatches
self.microbatch_size = microbatch_size
self.batch: Optional[Any] = None
self.batch_size: Optional[int] = None
self.last_batch_size: Optional[int] = None
self.microbatch_offset: Optional[int] = None
# P2PMeta cache
self.enable_metadata_cache = enable_metadata_cache
self.send_tensor_metadata = True
self.send_grad_metadata = True
self.tensor_metadata_recv = None
self.grad_metadata_recv = None
self.fp8_communication = fp8_communication
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
"""Load a batch from data iterator.
Args:
data_iter (Iterable): Data iterator.
device (Optional[torch.device], optional): Target device. Defaults to None.
"""
batch = next(data_iter)
if device is not None:
batch = tree_map(partial(to_device, device=device), batch)
self.microbatch_offset = 0
self.batch = batch
self.batch_size = get_batch_size(batch)
if self.microbatch_size is None:
assert self.batch_size % self.num_microbatches == 0, "Batch size should divided by # microbatches"
self.microbatch_size = self.batch_size // self.num_microbatches
if self.num_microbatches is None:
assert self.batch_size % self.microbatch_size == 0, "Batch size should divided by the microbatch size"
self.num_microbatches = self.batch_size // self.microbatch_size
if not self.forward_only:
assert self.last_batch_size is None or self.last_batch_size == self.batch_size
assert self.batch_size == self.microbatch_size * self.num_microbatches
assert (
self.num_microbatches >= self.stage_manager.num_stages
), "Number of microbatch should be larger than number of stages"
if self.forward_only:
self.num_microbatches = (self.batch_size - 1) // self.microbatch_size + 1
# NOTE: disable metadata cache when batch size changes (not valid anymore)
if self.batch_size != self.last_batch_size:
self.enable_metadata_cache = False
self.send_tensor_metadata = True
self.send_grad_metadata = True
self.tensor_metadata_recv = None
self.grad_metadata_recv = None
self.last_batch_size = self.batch_size
def load_micro_batch(self) -> Any:
"""Load a micro batch from the current batch.
Returns:
Any: Micro batch.
"""
assert self.microbatch_offset <= self.batch_size, "Microbatches exhausted"
micro_batch = get_micro_batch(self.batch, self.microbatch_offset, self.microbatch_size)
self.microbatch_offset += self.microbatch_size
return tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch)
def recv_forward(self, prev_rank: int = None) -> Any:
"""Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
For 1F1B.
Args:
prev_rank (int, optional): The rank of the source of the tensor.
Returns:
Any: The input tensor or input tensor list.
"""
if not self.stage_manager.is_first_stage():
input_tensor, _ = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv)
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
self.tensor_metadata_recv = create_send_metadata(input_tensor)
if self.fp8_communication:
cast_from_fp8_pipeline(input_tensor)
return input_tensor
def recv_backward(self, next_rank: int = None) -> Any:
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
For 1F1B.
Args:
next_rank (int, optional): The rank of the source of the tensor.
Returns:
Any: The input gradient tensor or gradient tensor list.
"""
if not self.stage_manager.is_last_stage():
output_tensor_grad, _ = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv)
if self.fp8_communication:
cast_from_fp8_pipeline(output_tensor_grad)
if self.enable_metadata_cache and self.grad_metadata_recv is None:
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
return output_tensor_grad
def send_forward(self, output_tensor: Any, next_rank: int = None) -> None:
"""Sends the input tensor to the next stage in pipeline.
For 1F1B.
Args:
output_object (Any): Object to be sent.
next_rank (int, optional): The rank of the recipient of the tensor.
"""
if not self.stage_manager.is_last_stage():
if self.fp8_communication:
cast_to_fp8_pipeline(output_tensor)
self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata)
self.send_tensor_metadata = not self.enable_metadata_cache
if self.fp8_communication:
cast_from_fp8_pipeline(output_tensor, del_metadata=False)
def send_backward(self, input_tensor_grad: Any, prev_rank: int = None) -> None:
"""Sends the gradient tensor to the previous stage in pipeline.
For 1F1B.
Args:
input_object (Any): Object to be sent.
prev_rank (int, optional): The rank of the recipient of the tensor
"""
if not self.stage_manager.is_first_stage():
if self.fp8_communication:
cast_to_fp8_pipeline(input_tensor_grad)
self.comm.send_backward(input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata)
self.send_grad_metadata = not self.enable_metadata_cache
if self.fp8_communication:
cast_from_fp8_pipeline(input_tensor_grad, del_metadata=False)
def send_forward_recv_backward(self, output_tensor: Any, send_first: Optional[bool] = None) -> Any:
"""Sends the input tensor to the next stage and copy the gradient tensor from the next stage in pipeline.
For 1F1B.
Args:
output_object (Any): Object to be sent.
next_rank (int, optional): The rank of the recipient of the tensor.
"""
if not self.stage_manager.is_last_stage():
if not self.send_tensor_metadata and self.grad_metadata_recv is not None:
send_first = None
if self.fp8_communication:
cast_to_fp8_pipeline(output_tensor)
output_tensor_grad, _ = self.comm.send_forward_recv_backward(
output_tensor,
send_metadata=self.send_tensor_metadata,
metadata_recv=self.grad_metadata_recv,
send_first=send_first,
)
self.send_tensor_metadata = not self.enable_metadata_cache
if self.enable_metadata_cache and self.grad_metadata_recv is None:
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
if self.fp8_communication:
cast_from_fp8_pipeline(output_tensor, del_metadata=False)
cast_from_fp8_pipeline(output_tensor_grad)
return output_tensor_grad
def send_backward_recv_forward(self, input_tensor_grad: Any, send_first: Optional[bool] = None) -> Any:
"""Sends the gradient tensor to the previous stage and copy the input tensor from the previous stage in pipeline.
For 1F1B.
Args:
output_object (Any): Object to be sent.
prev_rank (int, optional): The rank of the recipient of the tensor.
"""
if not self.stage_manager.is_first_stage():
if not self.send_grad_metadata and self.tensor_metadata_recv is not None:
send_first = None # must not fallback
if self.fp8_communication:
cast_to_fp8_pipeline(input_tensor_grad)
input_tensor, _ = self.comm.send_backward_recv_forward(
input_tensor_grad,
send_metadata=self.send_grad_metadata,
metadata_recv=self.tensor_metadata_recv,
send_first=send_first,
)
self.send_grad_metadata = not self.enable_metadata_cache
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
self.tensor_metadata_recv = create_send_metadata(input_tensor)
if self.fp8_communication:
cast_from_fp8_pipeline(input_tensor)
cast_from_fp8_pipeline(input_tensor_grad, del_metadata=False)
return input_tensor
def forward_step(
self,
model: Module,
input_obj: Optional[dict],
criterion: Callable,
accum_loss: Optional[torch.Tensor] = None,
outputs: Optional[List[Any]] = None,
) -> Union[torch.Tensor, dict]:
"""Forward one step of the pipeline
Args:
model (Module): Model to be run
input_obj (Optional[dict]): The output from the previous stage. If it is the first stage, the `input_obj` is None.
criterion (Callable): Criterion to calculate loss.
accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None.
outputs (Optional[List[Any]], optional): List to store the output of the last stage (final output). Defaults to None.
Returns:
Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor).
"""
micro_batch = self.load_micro_batch()
# for the first stage, input_obj is None
# for the non-first stage, input_obj is the output of the previous stage and it's must be a dict
output_obj = model_forward(model, micro_batch, input_obj)
if self.stage_manager.is_last_stage():
loss = criterion(output_obj, micro_batch) / self.num_microbatches
if accum_loss is not None:
accum_loss.add_(loss.detach())
if outputs is not None:
outputs.append(tree_map_hf(detach, output_obj))
return loss
else:
return output_obj
def backward_step(
self,
optimizer: OptimizerWrapper,
input_obj: Optional[dict],
output_obj: Union[dict, torch.Tensor],
output_obj_grad: Optional[dict],
) -> Optional[dict]:
"""Backward one step of the pipeline
Args:
optimizer (OptimizerWrapper): Optimizer to update the model
input_obj (Optional[dict]): Output of the previous stage. If it is the first stage, the `input_obj` is None.
output_obj (Union[dict, torch.Tensor]): Output of the current stage. If it is the last stage, the output is the loss (Tensor).
output_obj_grad (dict): Gradient of the `output_obj`. If it is the last stage, the `output_obj_grad` is None.
Returns:
Optional[dict]: Gradient of the `input_obj`. If it is the first stage, the `input_obj_grad` is None.
"""
# Retain the grad on the input_obj.
tree_map(retain_grad, input_obj)
# Backward pass.
if output_obj_grad is None:
optimizer.backward(output_obj)
else:
if "backward_tensor_keys" not in output_obj:
for k, grad in output_obj_grad.items():
optimizer.backward_by_grad(output_obj[k], grad)
else:
for k, grad in output_obj_grad.items():
output_obj[k].grad = grad
for k in output_obj["backward_tensor_keys"]:
tensor_to_backward = output_obj[k]
optimizer.backward_by_grad(tensor_to_backward, tensor_to_backward.grad)
# Collect the grad of the input_obj.
input_obj_grad = None
if input_obj is not None:
input_obj_grad = {}
for k, v in input_obj.items():
if isinstance(v, torch.Tensor) and v.grad is not None:
input_obj_grad[k] = v.grad
return input_obj_grad
def run_forward_only(
self,
model: Module,
data_iter: Iterable,
criterion: Callable[..., Any],
return_loss: bool = False,
return_outputs: bool = False,
) -> Dict:
"""
Runs forward only schedule, with communication between pipeline stages.
"""
assert self.forward_only
self.load_batch(data_iter)
accum_loss = None
if return_loss and self.stage_manager.is_last_stage():
accum_loss = torch.scalar_tensor(0, device=get_accelerator().get_current_device())
outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None
for _ in range(self.num_microbatches):
input_obj = self.recv_forward()
output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
self.send_forward(output_obj)
if outputs is not None:
if isinstance(model, ModelWrapper):
model = model.unwrap()
batch_size_dim = getattr(model, "batch_size_dim", 0)
outputs = merge_batch(outputs, batch_size_dim)
return {"loss": accum_loss, "outputs": outputs}
def run_forward_backward(
self,
model: Module,
data_iter: Iterable,
criterion: Callable[..., Any],
optimizer: Optional[OptimizerWrapper] = None,
return_loss: bool = False,
return_outputs: bool = False,
) -> Dict:
"""
Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
"""
assert not self.forward_only
self.load_batch(data_iter)
# num_warmup_microbatches is the step when not all the processes are working
num_warmup_microbatches = self.stage_manager.num_stages - self.stage_manager.stage - 1
num_warmup_microbatches = min(num_warmup_microbatches, self.num_microbatches)
num_microbatches_remaining = self.num_microbatches - num_warmup_microbatches
# Input, output tensors only need to be saved when doing backward passes
input_objs, output_objs = [], []
accum_loss = None
if return_loss and self.stage_manager.is_last_stage():
accum_loss = torch.scalar_tensor(0, device=get_current_device())
outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None
# Run warmup forward passes.
for i in range(num_warmup_microbatches):
input_obj = self.recv_forward()
output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
self.send_forward(output_obj)
input_objs.append(input_obj)
output_objs.append(output_obj)
# Before running 1F1B, need to receive first forward tensor.
# If all microbatches are run in warmup / cooldown phase, then no need to
# receive this tensor here.
if num_microbatches_remaining > 0:
input_obj = self.recv_forward()
# Run 1F1B in steady state.
for i in range(num_microbatches_remaining):
last_iteration = i == (num_microbatches_remaining - 1)
output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
output_obj_grad = self.send_forward_recv_backward(output_obj, send_first=self.stage_manager.stage % 2 == 0)
# Add input_obj and output_obj to end of list.
input_objs.append(input_obj)
output_objs.append(output_obj)
# Pop output_obj and output_obj from the start of the list for
# the backward pass.
input_obj = input_objs.pop(0)
output_obj = output_objs.pop(0)
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
if last_iteration:
self.send_backward(input_obj_grad)
else:
input_obj = self.send_backward_recv_forward(
input_obj_grad, send_first=self.stage_manager.stage % 2 == 0
)
# Run cooldown backward passes.
for i in range(num_warmup_microbatches):
input_obj = input_objs.pop(0)
output_obj = output_objs.pop(0)
output_obj_grad = self.recv_backward()
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
self.send_backward(input_obj_grad)
assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)
if outputs is not None:
if isinstance(model, ModelWrapper):
model = model.unwrap()
batch_size_dim = getattr(model, "batch_size_dim", 0)
outputs = merge_batch(outputs, batch_size_dim)
return {"loss": accum_loss, "outputs": outputs}
def forward_backward_step(
self,
model: Module,
data_iter: Iterable,
criterion: Callable[..., Any],
optimizer: Optional[OptimizerWrapper] = None,
return_loss: bool = False,
return_outputs: bool = False,
) -> dict:
"""
Args:
model (Module): Model to be trained.
data_iter (Iterable): Data iterator.
criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
return_loss (bool, optional): Whether to return loss. Defaults to False. Whether to return loss.
return_outputs (bool, optional): Whether to return model outputs. Defaults to False. Whether to return model outputs.
Returns:
dict: Dictionary containing loss and outputs.
"""
self.forward_only = not torch.is_grad_enabled()
if optimizer is None:
assert self.forward_only, "Optimizer should be passed when doing backward."
if self.forward_only:
result = self.run_forward_only(model, data_iter, criterion, return_loss, return_outputs)
else:
result = self.run_forward_backward(model, data_iter, criterion, optimizer, return_loss, return_outputs)
return result