Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-05-02 21:48:15 +00:00)
* add SimPO
* fix dataloader
* remove debug code
* add orpo
* fix style
* fix colossalai, transformers version
* fix colossalai, transformers version
* fix colossalai, transformers version
* fix torch colossalai version
* update transformers version
* [shardformer] DeepseekMoE support (#5871)
* [Feature] deepseek moe expert parallel implement
* [misc] fix typo, remove redundant file (#5867)
* [misc] fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Feature] deepseek support & unit test
* [misc] remove debug code & useless print
* [misc] fix typos (#5872)
* [Feature] remove modeling file, use auto config. (#5884)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [Deepseek] remove redundant code (#5888)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [misc] remove redundant code
* [Feature/deepseek] resolve comment. (#5889)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [misc] remove redundant code
* [misc] mv module replacement into if branch
* [misc] add some warning message and modify some code in unit test
* [misc] fix typos
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Hotfix] Fix CUDA_DEVICE_MAX_CONNECTIONS for comm overlap
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support (#5838)
* Diffusion Model Inference support
* Stable Diffusion 3 Support
* pixartalpha support
* [HotFix] CI,import,requirements-test for #5838 (#5892)
* [Hot Fix] CI,import,requirements-test
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Feature] Enable PP + SP for llama (#5868)
* fix cross-PP-stage position id length diff bug
* fix typo
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* use one cross entropy func for all shardformer models
---------
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [ShardFormer] Add Ulysses Sequence Parallelism support for Command-R, Qwen2 and ChatGLM (#5897)
* add benchmarks for SFT, DPO, SimPO and ORPO with benchmarking results; support LoRA with gradient checkpointing
* fix style
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix eval
* hotfix citation
* [zero] support all-gather overlap (#5898)
* [zero] support all-gather overlap
* [zero] add overlap all-gather flag
* [misc] fix typo
* [zero] update api
* fix orpo cross entropy loss
* [Auto Parallel]: Speed up intra-op plan generation by 44% (#5446)
* Remove unnecessary calls to deepcopy
* Build DimSpec's difference dict only once
This change considerably speeds up the construction of DimSpec objects. The difference_dict is the same for every DimSpec object, so a single shared copy is enough.
* Fix documentation of DimSpec's difference method
* [ShardFormer] fix qwen2 sp (#5903)
* [compatibility] support torch 2.2 (#5875)
* Support Pytorch 2.2.2
* keep build_on_pr file and update .compatibility
* fix object_to_tensor usage when torch>=2.3.0 (#5820)
* [misc] support torch2.3 (#5893)
* [misc] support torch2.3
* [devops] update compatibility ci
* [devops] update compatibility ci
* [devops] add debug
* [devops] add debug
* [devops] add debug
* [devops] add debug
* [devops] remove debug
* [devops] remove debug
* [release] update version (#5912)
* [plugin] support all-gather overlap for hybrid parallel (#5919)
* [plugin] fixed all-gather overlap support for hybrid parallel
* add kto
* fix style, add kto data sample
* [Examples] Add lazy init to OPT and GPT examples (#5924)
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [ColossalChat] Hotfix for ColossalChat (#5910)
* add ignore and tiny llama
* fix path issue
* run style
* fix issue
* update bash
* add ignore and tiny llama
* fix path issue
* run style
* fix issue
* update bash
* fix ddp issue
* add Qwen 1.5 32B
* refactor tokenization
* [FIX BUG] UnboundLocalError: cannot access local variable 'default_conversation' where it is not associated with a value (#5931)
* set a default value for 'default_conversation' to fix the UnboundLocalError
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* fix test data
* refactor evaluation
* remove real data path
* remove real data path
* Add n_fused as an input from native_module (#5894)
* [FIX BUG] convert env param to int in (#5934)
* [Hotfix] Fix ZeRO typo #5936
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [Feature] Add a switch to control whether the model checkpoint needs to be saved after each epoch ends (#5941)
* Add a switch to control whether the model checkpoint needs to be saved after each epoch ends
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* fix style
* fix style
* fix style
* [shardformer] hotfix attn mask (#5945)
* [shardformer] hotfix attn mask (#5947)
* [Feat] Distrifusion Acceleration Support for Diffusion Inference (#5895)
* Distrifusion Support source
* comp comm overlap optimization
* sd3 benchmark
* pixart distrifusion bug fix
* sd3 bug fix and benchmark
* generation bug fix
* naming fix
* add docstring, fix counter and shape error
* add reference
* readme and requirement
* [zero] hotfix update master params (#5951)
* [release] update version (#5952)
* [Chat] Fix lora (#5946)
* fix merging
* remove filepath
* fix style
* Update README.md (#5958)
* [hotfix] Remove unused plan section (#5957)
* remove readme
* fix readme
* update
* [test] add mixtral for sequence classification
* [test] add mixtral transformer test
* [moe] fix plugin
* [test] mixtral pp shard test
* [chore] handle non member group
* [zero] solve hang
* [test] pass mixtral shardformer test
* [moe] implement transit between non moe tp and ep
* [zero] solve hang
* [misc] solve booster hang by renaming the variable
* solve hang when parallel mode = pp + dp
* [moe] implement submesh initialization
* [moe] add mixtral dp grad scaling when not all experts are activated
* [chore] manually revert unintended commit
* [chore] trivial fix
* [chore] arg pass & remove drop token
* [test] add mixtral modelling test
* [moe] implement tp
* [moe] test deepseek
* [moe] clean legacy code
* [Feature] MoE Ulysses Support (#5918)
* moe sp support
* moe sp bug solve
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [chore] minor fix
* [moe] init moe plugin comm setting with sp
* moe sp + ep bug fix
* [moe] finalize test (no pp)
* [moe] full test for deepseek and mixtral (pp + sp to fix)
* [chore] minor fix after rebase
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [chore] solve moe ckpt test failure and some other arg pass failure
* [moe] remove ops
* [test] fix test: test_zero1_2
* [bug] fix: somehow logger hangs the program
* [moe] deepseek moe sp support
* [test] add check
* [deepseek] replace attn (a workaround for bug in transformers)
* [misc] skip redundant test
* [misc] remove debug/print code
* [moe] refactor mesh assignment
* Revert "[moe] implement submesh initialization"
This reverts commit 2f9bce6686.
* [chore] change moe_pg_mesh to private
* [misc] remove incompatible test config
* [misc] fix ci failure: change default value to false in moe plugin
* [misc] remove useless condition
* [chore] docstring
* [moe] remove force_overlap_comm flag and add warning instead
* [doc] add MoeHybridParallelPlugin docstring
* [moe] solve dp axis issue
* [chore] remove redundant test case, print string & reduce test tokens
* [feat] Dist Loader for Eval (#5950)
* support auto distributed data loader
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* support auto distributed data loader
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix tp error
* remove unused parameters
* remove unused
* update inference
* update docs
* update inference
---------
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [lora] lora support hybrid parallel plugin (#5956)
* lora support hybrid plugin
* fix
* fix
* fix
* fix
* Support overall loss, update KTO logging
* [Docs] clarify launch port
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [Hotfix] README link (#5966)
* update ignore
* update readme
* run style
* update readme
* [Hotfix] Avoid fused RMSnorm import error without apex (#5985)
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [Chat] fix readme (#5989)
* fix readme
* fix readme, tokenization fully tested
* fix readme, tokenization fully tested
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: root <root@notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9-0.notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9.colossal-ai.svc.cluster.local>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* fix sync condition (#6000)
* [plugin] add cast inputs option for zero (#6003)
* [pre-commit.ci] pre-commit autoupdate (#5995)
updates:
- [github.com/psf/black-pre-commit-mirror: 24.4.2 → 24.8.0](https://github.com/psf/black-pre-commit-mirror/compare/24.4.2...24.8.0)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [misc] Bypass the huggingface bug to solve the mask mismatch problem (#5991)
* [Feature] Zigzag Ring attention (#5905)
* halfway
* fix cross-PP-stage position id length diff bug
* fix typo
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* unified cross entropy func for all shardformer models
* remove redundant lines
* add basic ring attn; debug cross entropy
* fwd bwd logic complete
* fwd bwd logic complete; add experimental triton rescale
* precision tests passed
* precision tests passed
* fix typos and remove misc files
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* add sp_mode to benchmark; fix varlen interface
* update softmax_lse shape by new interface
* change tester name
* remove buffer clone; support packed seq layout
* add varlen tests
* fix typo
* all tests passed
* add dkv_group; fix mask
* remove debug statements
---------
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [misc] update compatibility (#6008)
* [misc] update compatibility
* [misc] update requirements
* [devops] disable requirements cache
* [test] fix torch ddp test
* [test] fix rerun on address in use
* [test] fix lazy init
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix the merge
* fix the merge
* overlap kv comm with output rescale (#6017)
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* fix the merge
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix the merge
* fix
* fix
* fix the merge
* fix
* [misc] Use dist logger in plugins (#6011)
* use dist logger in plugins
* remove trash
* print on rank 0
---------
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* fix
* fix
* fix
* fix
* fix the merge
* fix
* fix
* fix
* fix
---------
Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Haze188 <haze188@qq.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: Guangyao Zhang <xjtu521@qq.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Stephan Kö <stephankoe@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: zhurunhua <1281592874@qq.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: Gao, Ruiyuan <905370712@qq.com>
Co-authored-by: hxwang <wang1570@e.ntu.edu.sg>
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: root <root@notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9-0.notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9.colossal-ai.svc.cluster.local>
488 lines
18 KiB
Python
from contextlib import contextmanager
from typing import List, Optional, Union

import torch
import torch.distributed as dist
from torch import nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.distributed import ProcessGroup, get_world_size

from colossalai.accelerator import get_accelerator


class SeqParallelUtils:
    @staticmethod
    def marked_as_sp_partial_derived_param(param):
        """
        Mark a parameter as partially derived in sequence parallelism.

        Args:
            param: The parameter to mark as partially derived.
        """
        setattr(param, "partial_derived", True)

    @staticmethod
    def is_sp_partial_derived_param(param):
        """
        Check if a parameter is marked as partially derived in sequence parallelism.

        Args:
            param: The parameter to check.

        Returns:
            bool: True if the parameter is marked as partially derived, False otherwise.
        """
        return getattr(param, "partial_derived", False)

    @staticmethod
    def allreduce_partial_data_grad(
        process_group: ProcessGroup,
        model: nn.Module = None,
        grads: List[torch.Tensor] = None,
    ):
        """
        Allreduce partially derived gradients across the specified process group.

        This function performs gradient synchronization for parameters that are marked as partially derived in sequence parallelism.

        Args:
            process_group (ProcessGroup): The process group for gradient synchronization.
            model (nn.Module): The model from which gradients will be synchronized.
            grads (List[torch.Tensor]): The list of gradients to be synchronized.

        Raises:
            AssertionError: If both `model` and `grads` are provided or neither is provided.
        """
        # Ensure that exactly one of `model` and `grads` is provided for gradient synchronization.
        assert (model is not None) ^ (grads is not None), "Exactly one of model and grads must be not None."

        # Get the size of the process group, which determines whether synchronization is needed.
        group_size = get_world_size(process_group) if process_group is not None else 1

        if group_size == 1:
            # If the process group size is 1, no synchronization is required.
            return

        if model is not None:
            # If `model` is provided, extract the partially derived gradients from the model's parameters.
            grads = []
            for p in model.parameters():
                if p.grad is not None and SeqParallelUtils.is_sp_partial_derived_param(p):
                    grads.append(p.grad.data)

            # Flatten and reduce the gradients using the specified process group.
            if len(grads) == 0:
                return
            coalesced = _flatten_dense_tensors(grads)
            dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=process_group)

            # Unflatten the synchronized gradients and update the model's gradients.
            for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
                buf.copy_(synced)
        else:
            # If `grads` are provided explicitly, synchronize those gradients directly.
            coalesced = _flatten_dense_tensors(grads)
            dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=process_group)
            for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
                buf.copy_(synced)

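def _example_seq_parallel_utils():
    # Editor's illustrative sketch, not part of the upstream file: mark a parameter
    # as "partially derived" under sequence parallelism and synchronize its gradient.
    # With process_group=None the all-reduce returns early, so this also runs in a
    # single process; in a real run `process_group` would be the sequence-parallel group.
    layer = nn.Linear(8, 8)
    SeqParallelUtils.marked_as_sp_partial_derived_param(layer.bias)
    layer(torch.randn(2, 8)).sum().backward()
    assert SeqParallelUtils.is_sp_partial_derived_param(layer.bias)
    SeqParallelUtils.allreduce_partial_data_grad(process_group=None, model=layer)

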
class Randomizer:
    """
    Randomizer enables the program to be executed under a different seed within the context.

    Example:

    ```python
    randomizer = Randomizer(seed=1024)

    with randomizer.fork_rng():
        # do something here with seed 1024
        do_something()
    ```

    Args:
        seed (int): The random seed to set.
    """

    _INDEX = 0

    def __init__(self, seed: int):
        self.seed = seed

        # Handle the device RNG state:
        # 1. get the current rng state
        # 2. set the seed and store the rng state
        # 3. recover the original rng state
        device_original_rng_state = get_accelerator().get_rng_state()
        get_accelerator().manual_seed(seed)
        self.device_rng_state = get_accelerator().get_rng_state()
        get_accelerator().set_rng_state(device_original_rng_state)

        # Do the same for the CPU RNG state.
        cpu_original_rng_state = torch.get_rng_state()
        torch.manual_seed(seed)
        self.cpu_rng_state = torch.get_rng_state()
        torch.set_rng_state(cpu_original_rng_state)

    def _set_device_rng_state(self, rng_state):
        get_accelerator().set_rng_state(rng_state)

    def _get_device_rng_state(self):
        current_state = get_accelerator().get_rng_state()
        return current_state

    def _set_cpu_rng_state(self, rng_state):
        torch.set_rng_state(rng_state)

    def _get_cpu_rng_state(self):
        current_state = torch.get_rng_state()
        return current_state

    @contextmanager
    def fork_rng(self, enable_cpu: bool = False):
        """
        A context manager that switches to this randomizer's RNG state and restores
        the original state on exit.

        Args:
            enable_cpu (bool): whether to fork and restore the CPU RNG state as well.

        Usage:
        ::
            >>> with randomizer.fork_rng():
            >>>     input = super().forward(input)
        """
        try:
            current_device_rng_state = self._get_device_rng_state()
            self._set_device_rng_state(self.device_rng_state)

            if enable_cpu:
                current_cpu_rng_state = self._get_cpu_rng_state()
                self._set_cpu_rng_state(self.cpu_rng_state)
            yield
        finally:
            self.device_rng_state = self._get_device_rng_state()
            self._set_device_rng_state(current_device_rng_state)

            if enable_cpu:
                self.cpu_rng_state = self._get_cpu_rng_state()
                self._set_cpu_rng_state(current_cpu_rng_state)

    @staticmethod
    def index():
        """
        Return the index of the randomizer. The index is useful when the user wants
        to introduce some randomness in the program.

        Note:
            The index is incremented by `Randomizer.increment_index()`, typically once
            per randomizer created via `create_randomizer_with_offset`.

        Example:

        ```python
        # assume we need a randomizer to init the weight of different layers
        # we can use the index of the randomizer so that
        # each layer has its own randomizer with a different seed
        base_seed = torch.random.initial_seed()
        seed = base_seed + Randomizer.index()
        randomizer = Randomizer(seed)

        with randomizer.fork_rng():
            init_weights()
        ```
        """
        idx = Randomizer._INDEX
        return idx

    @staticmethod
    def increment_index():
        """
        Increment the index of the randomizer by one.
        """
        Randomizer._INDEX += 1

    @staticmethod
    def reset_index():
        """
        Reset the index to zero.
        """
        Randomizer._INDEX = 0

    @staticmethod
    def is_randomizer_index_synchronized(process_group: ProcessGroup = None):
        """
        Return whether the randomizer index is synchronized across processes.
        """
        index = Randomizer.index()
        if dist.is_initialized():
            # convert the index to a tensor
            index_tensor = torch.tensor(index, dtype=torch.int32, device=get_accelerator().get_current_device())

            # all-gather the index
            gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))]
            dist.all_gather(gathered_index, index_tensor, process_group)

            # make sure all the gathered indices are the same
            for i in range(1, dist.get_world_size(process_group)):
                if gathered_index[i] != gathered_index[0]:
                    return False

        return True

    @staticmethod
    def synchronize_index(process_group: ProcessGroup = None):
        """
        All-gather the index and pick the largest value.
        """
        index = Randomizer.index()

        if dist.is_initialized():
            # convert the index to a tensor
            index_tensor = torch.tensor(index, dtype=torch.int32, device=get_accelerator().get_current_device())

            # all-gather the index
            gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))]
            dist.all_gather(gathered_index, index_tensor, process_group)

            # pick the largest index
            for i in range(1, dist.get_world_size(process_group)):
                if gathered_index[i] > index_tensor:
                    index_tensor = gathered_index[i]

            # set the index
            Randomizer._INDEX = index_tensor.item()

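def _example_randomizer_fork():
    # Editor's illustrative sketch, not part of the upstream file. It assumes a device
    # supported by colossalai.accelerator is available. Forking the RNG produces a
    # reproducible seed-1024 stream inside the context while leaving the outer RNG
    # stream untouched; the forked state is saved on exit, so the second block
    # continues that stream rather than repeating the first draw.
    randomizer = Randomizer(seed=1024)
    with randomizer.fork_rng(enable_cpu=True):
        a = torch.rand(3)  # drawn from the seed-1024 stream
    with randomizer.fork_rng(enable_cpu=True):
        b = torch.rand(3)  # continues the forked stream, not a repeat of `a`
    outer = torch.rand(3)  # the outer CPU stream is unaffected by the forked draws
    return a, b, outer

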
def create_randomizer_with_offset(
    seed: int, process_group: ProcessGroup = None, offset_by_rank: bool = True, offset_by_index: bool = True
):
    """
    Create a randomizer with an offset. The offset is the sum of the rank of the process and the index of the randomizer (when enabled).

    Args:
        seed (int): The base random seed to set.
        process_group (ProcessGroup): the process group to get the rank from.
        offset_by_rank (bool): whether to offset by the rank of the process, i.e., the rank of the process will be added to the seed. Default: True.
        offset_by_index (bool): whether to offset by the index of the randomizer, i.e., the index of the randomizer will be added to the seed. Default: True.

    Returns:
        Randomizer: the randomizer with offset.
    """
    base_seed = seed

    if offset_by_rank and dist.is_initialized():
        rank = dist.get_rank(process_group)
        base_seed += rank

    if offset_by_index:
        # check if the randomizer index is synchronized
        is_synchronized = Randomizer.is_randomizer_index_synchronized(process_group)
        assert is_synchronized, (
            "We detect that the randomizer index is not synchronized across processes. "
            "This is not allowed when creating a randomizer with offset by index. "
            "Please call Randomizer.synchronize_index() first."
        )

        base_seed += Randomizer.index()
        Randomizer.increment_index()

    return Randomizer(seed=base_seed)

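def _example_create_randomizer_with_offset():
    # Editor's illustrative sketch, not part of the upstream file. Without
    # torch.distributed initialized only the index offset applies, so this runs
    # in a single process; in a distributed run each rank would also add its rank
    # to the base seed and therefore draw different numbers.
    Randomizer.reset_index()
    r0 = create_randomizer_with_offset(seed=42)  # uses seed 42 + index 0
    r1 = create_randomizer_with_offset(seed=42)  # uses seed 42 + index 1
    with r0.fork_rng(enable_cpu=True):
        x0 = torch.rand(2)
    with r1.fork_rng(enable_cpu=True):
        x1 = torch.rand(2)
    assert not torch.equal(x0, x1)  # different offsets -> different streams
    return x0, x1

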
def split_batch_zigzag(
    batch: Union[torch.Tensor, List[torch.Tensor]], sp_group: ProcessGroup, seq_dim: int = 1, is_label: bool = False
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """
    Split the input along the sequence dimension for Ring Attention. Naively splitting the attention mask
    in the causal setting will result in the preceding ranks having much less workload.
    We split after "folding" the 2D attention mask in half (https://github.com/zhuzilin/ring-flash-attention/issues/2).
    For example, for sp_size = 4 and seq_len = 8, we get | s0, s7 | s1, s6 | s2, s5 | s3, s4 |.

    Args:
        batch (List[torch.Tensor] or Tensor): The input tensor(s) to split.
        sp_group (ProcessGroup): The process group for sequence parallelism.
        seq_dim (int): The sequence dimension to split.
        is_label (bool): If True, mask and shift the tensor for next token prediction.

    """
    sp_size = dist.get_world_size(sp_group)
    sp_rank = dist.get_rank(sp_group)
    if isinstance(batch, torch.Tensor):
        batch = [batch]
    seq_dim = seq_dim if seq_dim != -1 else batch[0].dim() - 1

    if sp_size > 1:
        for idx, tensor in enumerate(batch):
            assert (
                tensor.shape[seq_dim] // (sp_size * 2) > 1 and tensor.shape[seq_dim] % (sp_size * 2) == 0
            ), f"The sequence length {tensor.shape[seq_dim]} of tensor {idx} must be a multiple of 2 * sp_size = {sp_size * 2}, with more than one element per chunk."
            if is_label:
                assert tensor.dim() == 2, "Label shape should be (B, Seqlen)"
                tensor = torch.cat([tensor[:, 1:], torch.full_like(tensor[:, :1], -100)], dim=1)

            tensor = tensor.view(
                *tensor.shape[:seq_dim],
                2 * sp_size,
                tensor.shape[seq_dim] // (2 * sp_size),
                *tensor.shape[seq_dim + 1 :],
            )
            indices = torch.tensor([sp_rank, 2 * sp_size - 1 - sp_rank], device=tensor.device)
            tensor = tensor.index_select(seq_dim, indices).contiguous()
            # (B, 2, Sq // (2 * sp_size), ...) -> (B, Sq // sp_size, ...)
            batch[idx] = tensor.view(*tensor.shape[:seq_dim], -1, *tensor.shape[seq_dim + 2 :])

    if len(batch) == 1:
        return batch[0]
    return batch

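def _example_split_batch_zigzag():
    # Editor's illustrative sketch, not part of the upstream file. It assumes the
    # script is launched with torchrun (e.g. `torchrun --nproc_per_node=2 ...`) so
    # that torch.distributed can be initialized; the "gloo" backend keeps it
    # CPU-only. With sp_size = 2 and seq_len = 16, the sequence is cut into 4
    # chunks c0..c3 and rank r keeps chunks (r, 3 - r): rank 0 gets (c0, c3),
    # rank 1 gets (c1, c2).
    if not dist.is_initialized():
        dist.init_process_group(backend="gloo")
    sp_group = dist.group.WORLD
    input_ids = torch.arange(16).unsqueeze(0)  # shape (1, 16)
    local_ids = split_batch_zigzag(input_ids, sp_group, seq_dim=1)
    # Each rank now holds a (1, 8) zigzag slice of the original sequence.
    print(dist.get_rank(sp_group), local_ids.tolist())

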
def split_varlen_zigzag(
    batch: Union[List[torch.Tensor], torch.Tensor],
    cu_seqlens: torch.Tensor,
    sp_group: ProcessGroup,
    max_seqlen: int = 0,
    is_2d: bool = False,
    is_label: bool = False,
) -> Union[List[torch.Tensor], torch.Tensor]:
    """Split each sequence in a batch of packed sequences in a zigzag fashion.
    For each tensor in batch, return packed sequences if is_2d is False;
    else return a padded batch of sequences.

    Args:
        batch (List[torch.Tensor]): Packed sequences of shape (B * Sq, ...), or (B, Sq, ...) if is_2d.
        cu_seqlens (torch.Tensor): Cumulative sequence lengths of shape (B + 1) before splitting.
        sp_group (ProcessGroup): The process group for sequence parallelism.
        max_seqlen (int): The maximum sequence length in the batch before splitting.
        is_2d (bool): If True, the input has batch size and sequence length split into two dimensions.
        is_label (bool): If True, mask out the first token in each sequence (<Start of Sentence>).

    Returns:
        batch (List[torch.Tensor]): Packed sequences of shape (B * max_seqlen // sp_size),
            or (B, max_seqlen // sp_size, ...) if is_2d.
    """
    sp_size = dist.get_world_size(sp_group)
    sp_rank = dist.get_rank(sp_group)
    if is_2d:
        assert max_seqlen > 0, "max_seqlen must be provided for 2D input"

    if isinstance(batch, torch.Tensor):
        batch = [batch]
    for i, packed_seq in enumerate(batch):
        device = packed_seq.device
        dtype = packed_seq.dtype

        if is_2d:
            assert max_seqlen % (sp_size * 2) == 0
            # Recreate a padded tensor with the new max seqlen
            shape = (packed_seq.shape[0], max_seqlen // sp_size, *packed_seq.shape[2:])
            local_seq = torch.zeros(shape, dtype=dtype, device=device)
        else:
            total_seqlen = cu_seqlens[-1]
            assert (
                total_seqlen % (2 * sp_size) == 0
            ), f"total_seqlen {total_seqlen} must be divisible by 2 * sp_size = {2 * sp_size}"
            local_seq = []

        for j in range(len(cu_seqlens) - 1):
            start, end = cu_seqlens[j], cu_seqlens[j + 1]
            seqlen = end - start
            assert (
                seqlen % (2 * sp_size) == 0
            ), f"batch {i} seq {j}'s length ({seqlen}) must be divisible by 2 * sp_size = {2 * sp_size} for splitting"

            if is_2d:
                seq = packed_seq[j][:seqlen]
                if is_label:
                    # Shift one position for next token prediction
                    seq = torch.cat([seq[1:], torch.tensor([-100], dtype=dtype, device=device)])

                seq = seq.chunk(2 * sp_size, dim=0)
                half = seqlen // sp_size // 2
                local_seq[j][:half] = seq[sp_rank]
                local_seq[j][half : seqlen // sp_size] = seq[2 * sp_size - 1 - sp_rank]
            else:
                seq = packed_seq[start:end]
                if is_label:
                    seq = torch.cat([seq[1:], torch.tensor([-100], dtype=dtype, device=device)])
                seq = seq.chunk(sp_size * 2)
                local_seq.extend([seq[sp_rank], seq[2 * sp_size - 1 - sp_rank]])

        if is_2d:
            batch[i] = local_seq.contiguous()
        else:
            batch[i] = torch.cat(local_seq, dim=0)

    if len(batch) == 1:
        batch = batch[0]
    return batch

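def _example_zigzag_chunk_assignment(sp_size: int = 4):
    # Editor's illustrative sketch, not part of the upstream file: for a packed
    # sequence cut into 2 * sp_size chunks, rank r keeps chunks r and
    # (2 * sp_size - 1 - r). For sp_size = 4 this prints
    # {0: (0, 7), 1: (1, 6), 2: (2, 5), 3: (3, 4)}, matching the layout in the
    # split_batch_zigzag docstring.
    assignment = {r: (r, 2 * sp_size - 1 - r) for r in range(sp_size)}
    print(assignment)
    return assignment

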
def is_share_sp_tp(sp_mode: str):
    """sp_mode "ring" and "split_gather" use the TP group as the SP group
    to split both the vocab and the sequence, so we must gather the sequence
    to correctly get the logits at each position.
    """
    return sp_mode in ["ring", "split_gather"]


class RingComm:
    def __init__(self, process_group: dist.ProcessGroup):
        self._process_group = process_group
        self._ops = []
        self.rank = dist.get_rank(self._process_group)
        self.world_size = dist.get_world_size(self._process_group)
        self._reqs = []

        self.send_rank = (self.rank + 1) % self.world_size
        self.recv_rank = (self.rank - 1) % self.world_size

        self.send_rank = dist.get_global_rank(self._process_group, self.send_rank)
        self.recv_rank = dist.get_global_rank(self._process_group, self.recv_rank)

    def send_recv(
        self,
        send_tensor: torch.Tensor,
        recv_tensor: Optional[torch.Tensor] = None,
        commit: bool = True,
    ) -> torch.Tensor:
        if recv_tensor is None:
            res = torch.empty_like(send_tensor)
        else:
            res = recv_tensor

        # looks like batch_isend_irecv doesn't deadlock even
        # when we don't swap send recv ops based on rank
        send_op = dist.P2POp(dist.isend, send_tensor, self.send_rank, group=self._process_group)
        recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group)
        self._ops.extend([send_op, recv_op])

        if commit:
            self._reqs = dist.batch_isend_irecv(self._ops)
        return res

    def commit(self):
        assert len(self._ops) > 0, "No ops to commit"
        self._reqs = dist.batch_isend_irecv(self._ops)

    def wait(self):
        assert len(self._reqs) > 0, "No requests to wait for"
        for req in self._reqs:
            req.wait()
        self._reqs = []
        self._ops = []

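def _example_ring_comm():
    # Editor's illustrative sketch, not part of the upstream file. It assumes a
    # torchrun launch with the "gloo" backend so torch.distributed is initialized.
    # Each rank sends its tensor to the next rank in the ring and receives one
    # from the previous rank, as done for KV blocks in ring attention.
    if not dist.is_initialized():
        dist.init_process_group(backend="gloo")
    comm = RingComm(dist.group.WORLD)
    to_send = torch.full((4,), float(dist.get_rank()))
    received = comm.send_recv(to_send, commit=True)  # posts isend/irecv as a batch
    comm.wait()  # block until the batched P2P ops complete
    # `received` now holds the previous rank's tensor.
    print(dist.get_rank(), received.tolist())

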
@torch.jit.script
def get_half_index(cu_seqlens, *, front: bool):
    index = torch.zeros(cu_seqlens[-1], dtype=torch.bool, device=cu_seqlens.device)
    for i in range(len(cu_seqlens) - 1):
        start, end = cu_seqlens[i], cu_seqlens[i + 1]
        if front:
            end = (start + end) // 2
        else:
            start = (start + end) // 2
        index[start:end] = True
    return index
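

def _example_get_half_index():
    # Editor's illustrative sketch, not part of the upstream file: build boolean
    # masks selecting the front or back half of every packed sequence. With
    # cu_seqlens = [0, 4, 10], the front mask marks positions 0-1 and 4-6, and
    # the back mask marks positions 2-3 and 7-9.
    cu_seqlens = torch.tensor([0, 4, 10])
    front_mask = get_half_index(cu_seqlens, front=True)
    back_mask = get_half_index(cu_seqlens, front=False)
    assert torch.all(front_mask ^ back_mask)  # the two halves partition each sequence
    return front_mask, back_mask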