ColossalAI/colossalai/shardformer/layer/linear.py
Wang Binluo eea37da6fa
[fp8] Merge feature/fp8_comm to main branch of Colossalai (#6016)

Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Haze188 <haze188@qq.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: Guangyao Zhang <xjtu521@qq.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Stephan Kö <stephankoe@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: zhurunhua <1281592874@qq.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: Gao, Ruiyuan <905370712@qq.com>
Co-authored-by: hxwang <wang1570@e.ntu.edu.sg>
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: root <root@notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9-0.notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9.colossal-ai.svc.cluster.local>
2024-08-22 09:21:34 +08:00

667 lines
27 KiB
Python

#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import math
from typing import Callable, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter

from colossalai.lazy import LazyInitContext
from colossalai.nn import init as init
from colossalai.nn.layer.utils import divide
from colossalai.tensor.d_tensor.api import (
    is_distributed_tensor,
    shard_colwise,
    shard_rowwise,
    sharded_tensor_to_existing_param,
)

from ._operation import (
    gather_forward_reducescatter_backward,
    gather_forward_split_backward,
    linear_gather_forward_reducescatter_backward,
    linear_reducescatter_forward_gather_backward,
    linear_with_async_comm,
    reduce_forward,
    reducescatter_forward_gather_backward,
    split_forward_gather_backward,
)
from .parallel_module import PaddingParallelModule, ParallelModule
from .utils import create_randomizer_with_offset

__all__ = ["Linear1D_Col", "Linear1D_Row"]


class Linear1D_Col(ParallelModule):
    r"""Linear layer with column parallelism.

    The linear layer is defined as :math:`Y = XA + b`. A is parallelized along
    its second dimension as :math:`A = [A_1, ..., A_p]`.

    Args:
        in_features (int): size of each input sample.
        out_features (int): size of each output sample.
        bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
        dtype (`torch.dtype`): The dtype of parameters, defaults to None.
        device (`torch.device`): The device of parameters, defaults to None.
        process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
        gather_output (bool, optional): If true, call all-gather on output and make Y available
                    to all GPUs, otherwise, every GPU will have its output
                    which is :math:`Y_i = XA_i`, defaults to False
        seq_parallel_mode (`str`): The sequence parallel mode. ``"split_gather"`` and ``"ring"`` are handled in ``forward``;
                    any other value, including None, uses the plain tensor parallel path. Defaults to None.
        seq_parallel_dim (`int`): The dimension along which sequence parallelism splits and gathers the sequence, defaults to 1.
        overlap (`bool`): If set to ``True``, it will overlap input all-gather with gradient computation during backward, defaults to False.
        skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
            which is preserved for kernel fusion, defaults to False
        weight_initializer (`typing.Callable`):
            The initializer of weight, defaults to kaiming uniform initializer.
        bias_initializer (`typing.Callable`):
            The initializer of bias, defaults to xavier uniform initializer.
        fp8_communication (bool): Forwarded to the communication ops as ``fp8_communication``, defaults to False.

    More details about ``initializer`` please refer to
    `init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        dtype: torch.dtype = None,
        device: torch.device = None,
        process_group: ProcessGroup = None,
        gather_output: bool = False,
        seq_parallel_mode: str = None,
        seq_parallel_dim: int = 1,
        # The docstring documents a boolean flag; None and False behave the same in a truth test.
        overlap: bool = False,
        skip_bias_add: bool = False,
        weight: Optional[Parameter] = None,
        bias_: Optional[Parameter] = None,
        weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
        bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
        fp8_communication: bool = False,
        **kwargs,
    ):
        super().__init__(weight=weight, bias_=bias_, **kwargs)

        # Keep input parameters
        self.in_features = in_features
        self.out_features = out_features
        self.gather_output = gather_output
        self.seq_parallel_mode = seq_parallel_mode
        self.seq_parallel_dim = seq_parallel_dim
        self.overlap = overlap
        self.skip_bias_add = skip_bias_add
        self.device = device
        self.process_group = process_group
        self.fp8_communication = fp8_communication

        if skip_bias_add and not bias:
            raise ValueError("cannot skip bias addition if bias is None")

        # offset the seed with randomizer index and rank
        seed = torch.random.initial_seed()
        self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)

        # sanity check
        if weight is not None:
            assert not bias or bias_ is not None, "bias_ must be provided if bias is True when weight is not None"
        else:
            assert bias_ is None, "bias_ must be None if weight is None"

        # Parameters.
        if weight is None:
            factory_kwargs = {"device": device, "dtype": dtype}
            self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs))
        else:
            weight.data = weight.data.to(device=device, dtype=dtype)
            self.weight = weight
        if not is_distributed_tensor(self.weight):
            # The weight is stored as (out_features, in_features), so splitting the
            # output dimension means sharding its rows.
            sharded_weight = shard_rowwise(self.weight.data, self.process_group)
            sharded_tensor_to_existing_param(sharded_weight, self.weight)

        if bias:
            if bias_ is None:
                self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
            else:
                bias_.data = bias_.data.to(device=device, dtype=dtype)
                self.bias = bias_
            if not is_distributed_tensor(self.bias):
                sharded_bias = shard_colwise(self.bias.data, self.process_group)
                sharded_tensor_to_existing_param(sharded_bias, self.bias)
        else:
            self.bias = None

        if weight is None:
            # init weights
            self.reset_parameters(weight_initializer, bias_initializer)

    @staticmethod
    def from_native_module(
        module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs
    ) -> ParallelModule:
        r"""
        Convert a native PyTorch linear layer to a parallelized linear layer.
        """
        LazyInitContext.materialize(module)
        # get the attributes
        in_features = module.in_features
        out_features = module.out_features
        bias = module.bias is not None
        device = module.weight.device

        # ensure only one process group is passed
        if isinstance(process_group, (list, tuple)):
            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
            process_group = process_group[0]

        tp_size = dist.get_world_size(process_group)
        if out_features < tp_size:
            return module

        if out_features % tp_size != 0:
            raise ValueError(
                f"The size of out_features:{out_features} is not integer multiples of tensor parallel size: {tp_size}!"
            )

        linear_1d = Linear1D_Col(
            in_features=in_features,
            out_features=out_features,
            bias=bias,
            device=device,
            process_group=process_group,
            weight=module.weight,
            bias_=module.bias,
            **kwargs,
        )

        return linear_1d

    def reset_parameters(self, weight_initializer, bias_initializer) -> None:
        with self.randomizer.fork_rng(enable_cpu=True):
            fan_in, fan_out = self.in_features, self.out_features
            weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
            if self.bias is not None:
                bias_initializer(self.bias, fan_in=fan_in)

    def forward(self, input_: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        # Returns ``output`` alone, or ``(output, bias)`` when skip_bias_add is True.
        assert (
            input_.shape[-1] == self.weight.shape[-1]
        ), "Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.".format(
            input_.shape, self.weight.shape, self.weight.shape[-1]
        )

        # Set up backprop all-reduce.
        input_parallel = input_

        # Matrix multiply.
        bias = self.bias if not self.skip_bias_add else None
        if self.seq_parallel_mode == "split_gather":
            input_parallel = gather_forward_reducescatter_backward(
                input_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
            )
            output_parallel = linear_with_async_comm(
                input_parallel, self.weight, bias, self.process_group, False, fp8_communication=self.fp8_communication
            )
        elif self.seq_parallel_mode == "ring":
            output_parallel = linear_gather_forward_reducescatter_backward(
                input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap, True
            )
        else:
            output_parallel = linear_with_async_comm(
                input_parallel, self.weight, bias, self.process_group, True, fp8_communication=self.fp8_communication
            )

        if self.gather_output:
            # All-gather across the partitions.
            output = gather_forward_split_backward(
                output_parallel, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
            )
        else:
            output = output_parallel

        if self.skip_bias_add:
            return output, self.bias
        else:
            return output
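

# Illustrative sketch, not part of the ColossalAI API: the identity behind column
# parallelism described in the Linear1D_Col docstring, Y = XA with A = [A_1, ..., A_p].
# It uses plain Python lists so that it runs without torch or a process group. The
# helper name and its list-based matrices are assumptions made for this example, and
# A is laid out as (in_features, out_features), the transpose of how the layer stores
# its weight.
def _sketch_column_parallel_matmul(x, a, num_parts):
    """Return (gathered shard outputs, unsharded reference) for Y = X @ A."""

    def matmul(lhs, rhs):
        # Multiply two matrices given as lists of rows.
        return [[sum(p * q for p, q in zip(row, col)) for col in zip(*rhs)] for row in lhs]

    n_out = len(a[0])
    assert n_out % num_parts == 0, "output features must divide evenly across shards"
    width = n_out // num_parts

    partial_outputs = []
    for rank in range(num_parts):
        # A_i: the block of columns owned by this rank.
        a_shard = [row[rank * width : (rank + 1) * width] for row in a]
        # Y_i = X @ A_i, which is what each rank keeps when gather_output is False.
        partial_outputs.append(matmul(x, a_shard))

    # Concatenating Y_1..Y_p along the last dimension mimics the all-gather step.
    gathered = [sum((out[r] for out in partial_outputs), []) for r in range(len(x))]
    return gathered, matmul(x, a)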


class Linear1D_Row(ParallelModule):
    r"""Linear layer with row parallelism

    Args:
        in_features (int): size of each input sample.
        out_features (int): size of each output sample.
        bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
        dtype (`torch.dtype`): The dtype of parameters, defaults to None.
        device (`torch.device`): The device of parameters, defaults to None.
        parallel_input (bool): If set to ``True``, it's assumed that the input is already split, defaults to True.
        process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
        seq_parallel_mode (`str`): The type of sp mode, it will use sequence parallel when `seq_parallel_mode` is not None. Defaults to None.
        seq_parallel_dim (`int`): The dimension along which sequence parallelism splits and gathers the sequence, defaults to 1.
        skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
            which is preserved for kernel fusion, defaults to False
        weight_initializer (:class:`typing.Callable`, optional):
            The initializer of weight, defaults to kaiming uniform initializer.
        bias_initializer (:class:`typing.Callable`, optional):
            The initializer of bias, defaults to xavier uniform initializer.
        stream_chunk_num (int): Number of chunks the weight is split into along its output dimension.
            Values greater than 1 are supported for inference only, defaults to 1.
        fp8_communication (bool): Forwarded to the communication ops as ``fp8_communication``, defaults to False.

    More details about ``initializer`` please refer to
    `init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        dtype: torch.dtype = None,
        device: torch.device = None,
        process_group: ProcessGroup = None,
        seq_parallel_mode: str = None,
        seq_parallel_dim: int = 1,
        parallel_input: bool = True,
        skip_bias_add: bool = False,
        weight: Optional[Parameter] = None,
        bias_: Optional[Parameter] = None,
        weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
        bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
        stream_chunk_num: int = 1,
        fp8_communication: bool = False,
    ):
        super().__init__()

        self.stream_chunk_num = stream_chunk_num

        # Keep input parameters
        self.in_features = in_features
        self.out_features = out_features
        self.parallel_input = parallel_input
        self.skip_bias_add = skip_bias_add
        self.process_group = process_group
        self.seq_parallel_mode = seq_parallel_mode
        self.seq_parallel_dim = seq_parallel_dim
        self.num_partitions = dist.get_world_size(self.process_group)
        self.fp8_communication = fp8_communication

        if skip_bias_add and not bias:
            raise ValueError("cannot skip bias addition if bias is None")

        # offset the seed with randomizer index and rank
        seed = torch.random.initial_seed()
        self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)

        # sanity check
        if weight is not None:
            assert not bias or bias_ is not None, "bias_ must be provided if bias is True when weight is not None"
        else:
            assert bias_ is None, "bias_ must be None if weight is None"

        # Parameters.
        if weight is None:
            # Initialize weight.
            factory_kwargs = {"device": device, "dtype": dtype}
            self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs))
        else:
            weight.data = weight.data.to(device=device, dtype=dtype)
            self.weight = weight
        if not is_distributed_tensor(self.weight):
            sharded_weight = shard_colwise(self.weight.data, self.process_group)
            sharded_tensor_to_existing_param(sharded_weight, self.weight)

        if self.stream_chunk_num > 1:
            # TODO() work for inference only
            self.chunk_weight()

        if bias:
            if bias_ is None:
                self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
            else:
                bias_.data = bias_.data.to(device=device, dtype=dtype)
                self.bias = bias_
        else:
            self.bias = None

        if weight is None:
            with self.randomizer.fork_rng(enable_cpu=True):
                self.reset_parameters(weight_initializer, bias_initializer)

    @staticmethod
    def from_native_module(
        module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs
    ) -> ParallelModule:
        r"""
        Convert a native PyTorch linear layer to a parallelized linear layer.
        """
        LazyInitContext.materialize(module)
        # get the attributes
        in_features = module.in_features
        out_features = module.out_features
        bias = module.bias is not None
        device = module.weight.device

        # ensure only one process group is passed
        if isinstance(process_group, (list, tuple)):
            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
            process_group = process_group[0]

        tp_size = dist.get_world_size(process_group)
        if in_features < tp_size:
            return module

        if in_features % tp_size != 0:
            raise ValueError(
                f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!"
            )

        linear_1d = Linear1D_Row(
            in_features=in_features,
            out_features=out_features,
            bias=bias,
            device=device,
            process_group=process_group,
            weight=module.weight,
            bias_=module.bias,
            **kwargs,
        )

        return linear_1d

    def chunk_weight(self):
        self.weight_list = torch.chunk(self.weight, self.stream_chunk_num, dim=0)

    @torch.no_grad()
    def reset_parameters(self, weight_initializer, bias_initializer) -> None:
        fan_in, fan_out = self.in_features, self.out_features
        weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)

        if self.bias is not None:
            bias_initializer(self.bias, fan_in=fan_in)
            if self.process_group is None:
                src_rank = 0
            else:
                src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0)

            # The bias is not sharded, so broadcast it from the first rank of the
            # group to keep every rank's copy identical.
            origin_device = self.bias.device
            bias = self.bias.cuda()
            dist.broadcast(bias, src=src_rank, group=self.process_group)
            bias = bias.to(origin_device)
            self.bias.copy_(bias)

    def forward(self, input_: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        # Returns ``output`` alone, or ``(output, bias)`` when skip_bias_add is True.
        # Set up backprop all-reduce.
        if self.parallel_input:
            assert (
                input_.shape[-1] == self.weight.shape[-1]
            ), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format(
                input_.shape, self.weight.shape, self.weight.shape[-1]
            )
        else:
            assert (
                divide(input_.shape[-1], self.num_partitions) == self.weight.shape[-1]
            ), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format(
                input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions
            )
            input_ = split_forward_gather_backward(
                input_, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
            )

        if self.stream_chunk_num > 1:
            if self.training:
                raise RuntimeError("use stream_chunk_num=1 in Linear1D_Row for training!")
            with torch.no_grad():
                output_parallel_list = [None for i in range(self.stream_chunk_num)]
                handle_list = []
                for i in range(self.stream_chunk_num):
                    output_parallel_list[i] = F.linear(input_, self.weight_list[i])
                    # Launch the all-reduce asynchronously so it can overlap with the next chunk's matmul.
                    handle = torch.distributed.all_reduce(
                        output_parallel_list[i], group=self.process_group, async_op=True
                    )
                    handle_list.append(handle)
                for handle in handle_list:
                    handle.wait()
                output = torch.cat(output_parallel_list, dim=-1)
        else:
            if self.seq_parallel_mode is None:
                output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
                output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
            elif self.seq_parallel_mode == "split_gather":
                output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
                output = reducescatter_forward_gather_backward(
                    output_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
                )
            elif self.seq_parallel_mode == "ring":
                output = linear_reducescatter_forward_gather_backward(
                    input_,
                    self.weight,
                    process_group=self.process_group,
                    dim=self.seq_parallel_dim,
                    ring=True,
                )
            else:
                output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
                output = reduce_forward(output_parallel, self.process_group)

        if not self.skip_bias_add:
            if self.bias is not None:
                output = output + self.bias
            return output
        else:
            return output, self.bias
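

# Illustrative sketch, not part of the ColossalAI API: the identity behind row
# parallelism. X is split along its last dimension and A along its first, each rank
# computes X_i @ A_i, and the partial results are summed (the all-reduce step). Plain
# Python lists stand in for tensors so the sketch runs without torch. The helper name
# is an assumption made for this example, and A is laid out as
# (in_features, out_features), the transpose of how the layer stores its weight.
def _sketch_row_parallel_matmul(x, a, num_parts):
    """Return (summed shard outputs, unsharded reference) for Y = X @ A."""

    def matmul(lhs, rhs):
        # Multiply two matrices given as lists of rows.
        return [[sum(p * q for p, q in zip(row, col)) for col in zip(*rhs)] for row in lhs]

    n_in = len(a)
    assert n_in % num_parts == 0, "input features must divide evenly across shards"
    width = n_in // num_parts

    reduced = [[0] * len(a[0]) for _ in x]
    for rank in range(num_parts):
        # X_i: this rank's slice of the input features; A_i: the matching rows of A.
        x_shard = [row[rank * width : (rank + 1) * width] for row in x]
        a_shard = a[rank * width : (rank + 1) * width]
        partial = matmul(x_shard, a_shard)
        # Element-wise accumulation plays the role of the all-reduce.
        reduced = [[p + q for p, q in zip(r_row, p_row)] for r_row, p_row in zip(reduced, partial)]
    return reduced, matmul(x, a)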


class PaddingLMHead(PaddingParallelModule):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        dtype: torch.dtype = None,
        device: torch.device = None,
        weight: Optional[Parameter] = None,
        bias_: Optional[Parameter] = None,
        make_vocab_size_divisible_by: int = 64,
        weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
        bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
    ):
        # Keep input parameters
        self.in_features = in_features
        self.out_features = out_features
        if out_features % make_vocab_size_divisible_by != 0:
            self.out_features = (
                out_features + make_vocab_size_divisible_by - (out_features % make_vocab_size_divisible_by)
            )

        # Record whether the weight is created here; ``weight`` is reassigned below,
        # so checking ``weight is None`` after that point would always be False.
        need_init = weight is None
        factory_kwargs = {"device": device, "dtype": dtype}
        if weight is None:
            weight = Parameter(torch.empty(out_features, self.in_features, **factory_kwargs))
        else:
            weight.data = weight.data.to(device=device, dtype=dtype)

        if bias:
            if bias_ is None:
                # Hand the new bias to the parent constructor instead of assigning it
                # to ``self`` before the module has been initialized.
                bias_ = Parameter(torch.empty(out_features, **factory_kwargs))
            else:
                bias_.data = bias_.data.to(device=device, dtype=dtype)
        else:
            bias_ = None

        # resize embeddings
        super().__init__(self.out_features, out_features, weight, bias_)

        if need_init:
            self.reset_parameters(weight_initializer, bias_initializer)

    def reset_parameters(self, weight_initializer, bias_initializer) -> None:
        fan_in, fan_out = self.in_features, self.out_features
        weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
        if self.bias is not None:
            bias_initializer(self.bias, fan_in=fan_in)

    @staticmethod
    def from_native_module(
        module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs
    ) -> PaddingParallelModule:
        r"""
        Convert a native PyTorch linear layer to a padded linear layer.
        ``process_group`` is accepted for interface compatibility and is not used here.
        """
        LazyInitContext.materialize(module)
        # get the attributes
        in_features = module.in_features
        out_features = module.out_features
        bias = module.bias is not None
        device = module.weight.device

        lm_head_linear = PaddingLMHead(
            in_features=in_features,
            out_features=out_features,
            bias=bias,
            device=device,
            weight=module.weight,
            bias_=module.bias,
            **kwargs,
        )

        return lm_head_linear

    def forward(self, input: Tensor) -> Tensor:
        output = F.linear(input, self.weight, self.bias)
        # Drop the padded columns so the output matches the original vocabulary size.
        output = output[..., : self.old_num_embeddings]
        return output
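

# Illustrative sketch, not part of the ColossalAI API: the rounding rule that
# PaddingLMHead applies to out_features, isolated as a pure function. The helper name
# is an assumption made for this example; the arithmetic mirrors the expression used
# in PaddingLMHead.__init__.
def _sketch_pad_to_multiple(size, multiple=64):
    """Round ``size`` up to the nearest multiple of ``multiple``."""
    remainder = size % multiple
    if remainder == 0:
        return size
    return size + multiple - remainder


# For example, a vocabulary of 50257 entries is padded to 50304 when the multiple is
# 64, while 32000 is already divisible by 64 and stays unchanged.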


class VocabParallelLMHead1D(Linear1D_Col, PaddingParallelModule):
    r"""Column-parallel LM head with a padded vocabulary dimension.

    The linear layer is defined as :math:`Y = XA + b`. A is parallelized along
    its second dimension as :math:`A = [A_1, ..., A_p]`. ``out_features`` is rounded up
    to a multiple of ``make_vocab_size_divisible_by`` times the tensor parallel size,
    and the padded entries are sliced off the output in ``forward``.

    Args:
        in_features (int): size of each input sample.
        out_features (int): size of each output sample.
        bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
        dtype (`torch.dtype`): The dtype of parameters, defaults to None.
        device (`torch.device`): The device of parameters, defaults to None.
        process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
        make_vocab_size_divisible_by (int): Together with the tensor parallel size, determines the multiple
            that ``out_features`` is padded to, defaults to 64.
        fp8_communication (bool): Forwarded to the communication ops as ``fp8_communication``, defaults to False.
        **kwargs: Remaining keyword arguments, such as ``gather_output``, ``seq_parallel_mode``, ``overlap``
            and ``skip_bias_add``, are passed on to :class:`Linear1D_Col`.

    More details about ``initializer`` please refer to
    `init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        dtype: torch.dtype = None,
        device: torch.device = None,
        process_group: ProcessGroup = None,
        weight: Optional[Parameter] = None,
        bias_: Optional[Parameter] = None,
        make_vocab_size_divisible_by: int = 64,
        fp8_communication: bool = False,
        **kwargs,
    ):
        # create weight and bias
        factory_kwargs = {"device": device, "dtype": dtype}
        if weight is None:
            # self.in_features is only set by the parent constructor, so use the argument here.
            weight = Parameter(torch.empty(out_features, in_features, **factory_kwargs))
        if bias:
            if bias_ is None:
                bias_ = Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            bias_ = None

        # calculate new vocab size
        self.tensor_parallel_size = dist.get_world_size(group=process_group)
        new_out_features = out_features
        multiple = make_vocab_size_divisible_by * self.tensor_parallel_size
        if out_features % multiple != 0:
            new_out_features = out_features + multiple - (out_features % multiple)

        super().__init__(
            in_features=in_features,
            out_features=new_out_features,
            bias=bias,
            device=device,
            process_group=process_group,
            weight=weight,
            bias_=bias_,
            **kwargs,
            new_num_embeddings=new_out_features,
            old_num_embeddings=out_features,
            fp8_communication=fp8_communication,
        )

        # get the length of valid embeddings
        tp_rank = dist.get_rank(process_group)
        partition_size = self.new_num_embeddings // dist.get_world_size(process_group)
        if self.old_num_embeddings >= (tp_rank + 1) * partition_size:
            self.num_valid_embeddings_local = partition_size
        elif self.old_num_embeddings >= tp_rank * partition_size:
            self.num_valid_embeddings_local = self.old_num_embeddings - tp_rank * partition_size
        else:
            self.num_valid_embeddings_local = 0

    @staticmethod
    def from_native_module(
        module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs
    ) -> PaddingParallelModule:
        r"""
        Convert a native PyTorch linear layer to a parallelized linear layer.
        """
        LazyInitContext.materialize(module)
        # get the attributes
        in_features = module.in_features
        out_features = module.out_features
        bias = module.bias is not None
        device = module.weight.device

        lm_head_linear = VocabParallelLMHead1D(
            in_features=in_features,
            out_features=out_features,
            bias=bias,
            device=device,
            process_group=process_group,
            weight=module.weight,
            bias_=module.bias,
            **kwargs,
        )

        return lm_head_linear

    def forward(self, input_: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        # get forward output
        if self.skip_bias_add:
            output, bias = super().forward(input_)
        else:
            output = super().forward(input_)

        # delete the padding of output
        if self.gather_output:
            output = output[..., : self.old_num_embeddings]
        else:
            output = output[..., : self.num_valid_embeddings_local]

        # return
        if self.skip_bias_add:
            return output, bias
        return output
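

# Illustrative sketch, not part of the ColossalAI API: how many unpadded vocabulary
# entries each tensor parallel rank holds, following the arithmetic in
# VocabParallelLMHead1D.__init__ but taking the rank and world size as plain integers
# instead of querying a process group. The helper name and its parameters are
# assumptions made for this example.
def _sketch_num_valid_embeddings_local(old_vocab, tp_size, tp_rank, divisible_by=64):
    """Return the count of real (unpadded) output columns owned by ``tp_rank``."""
    multiple = divisible_by * tp_size
    new_vocab = old_vocab
    if old_vocab % multiple != 0:
        new_vocab = old_vocab + multiple - (old_vocab % multiple)

    partition_size = new_vocab // tp_size
    if old_vocab >= (tp_rank + 1) * partition_size:
        # The whole partition lies inside the original vocabulary.
        return partition_size
    if old_vocab >= tp_rank * partition_size:
        # The partition straddles the boundary between real and padded entries.
        return old_vocab - tp_rank * partition_size
    # The partition consists of padding only.
    return 0


# With a vocabulary of 50257 and 4 ranks, the padded size is 50432 and each partition
# has 12608 columns; ranks 0 to 2 are full and rank 3 keeps 12433 real columns.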