ColossalAI/examples/language/gpt/titans/model/gpt1d.py
Boyuan Yao 7a58dc5ad2
Update metainfo patch branch (#2517)

Co-authored-by: oahzxl <xuanlei.zhao@gmail.com>
Co-authored-by: eric8607242 <e0928021388@gmail.com>
Co-authored-by: HELSON <c2h214748@gmail.com>
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: Haofan Wang <haofanwang.ai@gmail.com>
Co-authored-by: Jiarui Fang <fangjiarui123@gmail.com>
Co-authored-by: ZijianYY <119492445+ZijianYY@users.noreply.github.com>
Co-authored-by: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Co-authored-by: Super Daniel <78588128+super-dainiu@users.noreply.github.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: Ziyue Jiang <ziyue.jiang97@gmail.com>
Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
Co-authored-by: oahzxl <43881818+oahzxl@users.noreply.github.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: Fazzie-Maqianli <55798671+Fazziekey@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
2023-01-27 09:52:21 +08:00

350 lines
14 KiB
Python

#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import math

import torch
from torch import Tensor
from torch import nn as nn

from colossalai import kernel
from colossalai import nn as col_nn
from colossalai.core import global_context as gpc
from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
from colossalai.nn.layer import Linear1D_Col, Linear1D_Row
from colossalai.nn.layer.base_layer import ParallelLayer
from colossalai.nn.layer.utils import ACT2FN, divide
# `checkpoint` was imported twice; only the later import took effect, so keep that one.
from colossalai.utils.activation_checkpoint import checkpoint

__all__ = [
    'GPTMLP1D', 'GPTSelfAttention1D', 'GPTTransformerLayer1D', 'FusedGPTSelfAttention1D', 'FusedGPTTransformerLayer1D'
]


class GPTMLP1D(ParallelLayer):
    """Feed-forward block: column-parallel up-projection, activation, row-parallel down-projection."""

    def __init__(
        self,
        in_features: int,
        mlp_ratio: int,
        act_func: str = 'gelu',
        dropout_prob: float = 0.,
        dtype=None,
        checkpoint: bool = False,
        skip_bias_add: bool = False,
    ):
        super().__init__()

        self.in_features = in_features
        self.mlp_ratio = mlp_ratio
        self.checkpoint = checkpoint
        # Stored for reference only; it is not forwarded to either linear layer below.
        self.skip_bias_add = skip_bias_add

        self.act = ACT2FN[act_func]
        skip_dense_1_add_bias = False

        # Project to mlp_ratio * h.
        self.dense_1 = Linear1D_Col(
            self.in_features,
            int(self.mlp_ratio * self.in_features),
            dtype=dtype,
            gather_output=False,
            skip_bias_add=skip_dense_1_add_bias,
        )

        # Project back to h.
        self.dense_2 = Linear1D_Row(
            int(self.mlp_ratio * self.in_features),
            self.in_features,
            dtype=dtype,
            parallel_input=True,
        )

        self.dropout = col_nn.Dropout(dropout_prob)

    def _forward(self, hidden_states: Tensor) -> Tensor:
        intermediate_output = self.dense_1(hidden_states)
        intermediate_output = self.act(intermediate_output)

        output = self.dense_2(intermediate_output)
        output = self.dropout(output)
        return output

    def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
        # The second argument is `activation_offload`; False keeps activations on device.
        return checkpoint(self._forward, False, hidden_states)

    def forward(self, hidden_states: Tensor) -> Tensor:
        if self.checkpoint:
            return self._checkpoint_forward(hidden_states)
        else:
            return self._forward(hidden_states)


class GenericGPTSelfAttention1D(ParallelLayer):
    """Base self-attention layer; subclasses supply the masking and softmax in `softmax_forward`."""

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        attention_dropout_prob: float,
        hidden_dropout_prob: float,
        dtype=None,
        checkpoint: bool = False,
        max_position_embeddings=1024,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.attention_head_size = divide(hidden_size, num_attention_heads)
        self.num_attention_heads_per_partition = divide(num_attention_heads, gpc.tensor_parallel_size)
        self.hidden_size_per_partition = divide(hidden_size, gpc.tensor_parallel_size)
        self.checkpoint = checkpoint
        self.query_key_value = Linear1D_Col(
            hidden_size,
            3 * hidden_size,
            dtype=dtype,
        )
        self.attention_dropout = col_nn.Dropout(attention_dropout_prob)
        self.dense = Linear1D_Row(
            hidden_size,
            hidden_size,
            dtype=dtype,
            parallel_input=True,
        )
        self.dropout = col_nn.Dropout(hidden_dropout_prob)

    def softmax_forward(self, attention_scores, attention_mask, query_layer, key_layer):
        raise NotImplementedError

    def _forward(self, hidden_states: Tensor, attention_mask=None) -> Tensor:
        query_key_value = self.query_key_value(hidden_states)
        # [..., 3 * hidden / tp] -> [..., heads_per_partition, 3 * head_size]
        new_qkv_shape = query_key_value.shape[:-1] + \
            (self.num_attention_heads_per_partition, 3 * self.attention_head_size)
        query_key_value = query_key_value.view(new_qkv_shape)
        # Move the head dimension ahead of the sequence dimension.
        query_key_value = query_key_value.permute((0, 2, 1, 3))
        query_layer, key_layer, value_layer = torch.chunk(query_key_value, 3, dim=-1)

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        attention_scores = self.softmax_forward(attention_scores, attention_mask, query_layer, key_layer)

        attention_scores = attention_scores.type(value_layer.dtype)

        attention_probs = self.attention_dropout(attention_scores)

        context_layer = torch.matmul(attention_probs, value_layer)
        # Merge the local heads back into a single hidden dimension per partition.
        context_layer = context_layer.transpose(1, 2)
        new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
        context_layer = context_layer.reshape(new_context_layer_shape)
        output = self.dense(context_layer)
        output = self.dropout(output)

        return output

    def _checkpoint_forward(self, hidden_states: Tensor, attention_mask=None) -> Tensor:
        return checkpoint(self._forward, False, hidden_states, attention_mask)

    def forward(self, hidden_states: Tensor, attention_mask=None) -> Tensor:
        if self.checkpoint:
            return self._checkpoint_forward(hidden_states, attention_mask)
        else:
            return self._forward(hidden_states, attention_mask)


class GPTSelfAttention1D(GenericGPTSelfAttention1D):
    """Self-attention with an explicit lower-triangular causal mask and `nn.Softmax`."""

    def __init__(self,
                 hidden_size: int,
                 num_attention_heads: int,
                 attention_dropout_prob: float,
                 hidden_dropout_prob: float,
                 dtype=None,
                 checkpoint: bool = False,
                 max_position_embeddings=1024):
        super().__init__(hidden_size,
                         num_attention_heads,
                         attention_dropout_prob,
                         hidden_dropout_prob,
                         dtype=dtype,
                         checkpoint=checkpoint,
                         max_position_embeddings=max_position_embeddings)
        self.softmax = nn.Softmax(dim=-1)
        max_positions = max_position_embeddings
        # Lower-triangular mask of shape [1, 1, max_positions, max_positions].
        self.register_buffer(
            "bias",
            torch.tril(torch.ones((max_positions, max_positions),
                                  dtype=torch.uint8)).view(1, 1, max_positions, max_positions),
        )
        self.register_buffer("masked_bias", torch.tensor(-1e4))

    def softmax_forward(self, attention_scores, attention_mask, query_layer, key_layer):
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        # causal mask
        query_length, key_length = query_layer.size(-2), key_layer.size(-2)
        causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length].bool()
        # Masked-out positions are set to masked_bias (-1e4) before the softmax.
        attention_scores = torch.where(causal_mask, attention_scores, self.masked_bias.to(attention_scores))
        if attention_mask is not None:
            # Apply the attention mask
            attention_scores = attention_scores + attention_mask
        attention_scores = self.softmax(attention_scores)
        return attention_scores


class FusedGPTSelfAttention1D(GenericGPTSelfAttention1D):
    """Self-attention that delegates scaling, causal masking and softmax to the fused kernel."""

    def __init__(self,
                 hidden_size: int,
                 num_attention_heads: int,
                 attention_dropout_prob: float,
                 hidden_dropout_prob: float,
                 dtype=None,
                 checkpoint: bool = False,
                 max_position_embeddings=1024):
        super().__init__(hidden_size,
                         num_attention_heads,
                         attention_dropout_prob,
                         hidden_dropout_prob,
                         dtype=dtype,
                         checkpoint=checkpoint,
                         max_position_embeddings=max_position_embeddings)
        # Note: fp16 input is hard-coded here, regardless of the `dtype` argument.
        self.softmax = kernel.FusedScaleMaskSoftmax(input_in_fp16=True,
                                                    input_in_bf16=False,
                                                    attn_mask_type=AttnMaskType.causal,
                                                    scaled_masked_softmax_fusion=True,
                                                    mask_func=None,
                                                    softmax_in_fp32=True,
                                                    scale=math.sqrt(self.attention_head_size))

    def softmax_forward(self, attention_scores, attention_mask, query_layer, key_layer):
        return self.softmax(attention_scores, attention_mask)


class GenericGPTTransformerLayer1D(ParallelLayer):
    """Transformer layer built from injected `attention` and `layer_norm` classes plus `GPTMLP1D`."""

    def __init__(self,
                 hidden_size: int,
                 num_attention_heads: int,
                 act_func: str = 'gelu',
                 mlp_ratio: float = 4.0,
                 attention_dropout_prob: float = 0.,
                 hidden_dropout_prob: float = 0.,
                 dtype=None,
                 checkpoint: bool = False,
                 max_position_embeddings: int = 1024,
                 layer_norm_epsilon: float = 1e-5,
                 apply_post_layer_norm: bool = False,
                 attention=None,
                 layer_norm=None):
        super().__init__()
        self.checkpoint = checkpoint
        self.dtype = dtype
        self.norm1 = layer_norm(hidden_size, eps=layer_norm_epsilon)
        self.apply_post_layer_norm = apply_post_layer_norm
        # Sub-layers never checkpoint on their own; checkpointing is applied once, at layer level.
        self.attention = attention(
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            attention_dropout_prob=attention_dropout_prob,
            hidden_dropout_prob=hidden_dropout_prob,
            dtype=dtype,
            max_position_embeddings=max_position_embeddings,
            checkpoint=False,
        )

        self.norm2 = layer_norm(hidden_size, eps=layer_norm_epsilon)
        self.mlp = GPTMLP1D(
            in_features=hidden_size,
            dropout_prob=hidden_dropout_prob,
            act_func=act_func,
            mlp_ratio=mlp_ratio,
            dtype=dtype,
            checkpoint=False,
        )

    def _forward(self, hidden_states, attention_mask) -> Tensor:
        # With apply_post_layer_norm=False the residual is the un-normalized input;
        # with True it is taken from the layer-norm output instead.
        if not self.apply_post_layer_norm:
            residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        if self.apply_post_layer_norm:
            residual = hidden_states
        attention_output = self.attention(hidden_states, attention_mask)
        hidden_states = residual + attention_output

        if not self.apply_post_layer_norm:
            residual = hidden_states
        hidden_states = self.norm2(hidden_states)
        if self.apply_post_layer_norm:
            residual = hidden_states
        feed_forward_hidden_states = self.mlp(hidden_states)
        hidden_states = residual + feed_forward_hidden_states

        # The mask is returned alongside the hidden states so layers can be chained.
        output = (hidden_states, attention_mask)
        return output

    def forward(self, hidden_states, attention_mask):
        if self.checkpoint:
            return checkpoint(self._forward, False, hidden_states, attention_mask)
        else:
            return self._forward(hidden_states, attention_mask)


class GPTTransformerLayer1D(GenericGPTTransformerLayer1D):

    def __init__(self,
                 hidden_size: int,
                 num_attention_heads: int,
                 act_func: str = 'gelu',
                 mlp_ratio: float = 4,
                 attention_dropout_prob: float = 0,
                 hidden_dropout_prob: float = 0,
                 dtype=None,
                 checkpoint: bool = False,
                 max_position_embeddings: int = 1024,
                 layer_norm_epsilon: float = 0.00001,
                 apply_post_layer_norm: bool = False):
        attention = GPTSelfAttention1D
        layer_norm = nn.LayerNorm
        super().__init__(hidden_size,
                         num_attention_heads,
                         act_func=act_func,
                         mlp_ratio=mlp_ratio,
                         attention_dropout_prob=attention_dropout_prob,
                         hidden_dropout_prob=hidden_dropout_prob,
                         dtype=dtype,
                         checkpoint=checkpoint,
                         max_position_embeddings=max_position_embeddings,
                         layer_norm_epsilon=layer_norm_epsilon,
                         apply_post_layer_norm=apply_post_layer_norm,
                         attention=attention,
                         layer_norm=layer_norm)


class FusedGPTTransformerLayer1D(GenericGPTTransformerLayer1D):

    def __init__(self,
                 hidden_size: int,
                 num_attention_heads: int,
                 act_func: str = 'gelu',
                 mlp_ratio: float = 4,
                 attention_dropout_prob: float = 0,
                 hidden_dropout_prob: float = 0,
                 dtype=None,
                 checkpoint: bool = False,
                 max_position_embeddings: int = 1024,
                 layer_norm_epsilon: float = 0.00001,
                 apply_post_layer_norm: bool = False):
        attention = FusedGPTSelfAttention1D
        layer_norm = kernel.LayerNorm
        super().__init__(hidden_size,
                         num_attention_heads,
                         act_func=act_func,
                         mlp_ratio=mlp_ratio,
                         attention_dropout_prob=attention_dropout_prob,
                         hidden_dropout_prob=hidden_dropout_prob,
                         dtype=dtype,
                         checkpoint=checkpoint,
                         max_position_embeddings=max_position_embeddings,
                         layer_norm_epsilon=layer_norm_epsilon,
                         apply_post_layer_norm=apply_post_layer_norm,
                         attention=attention,
                         layer_norm=layer_norm)
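
Note on the 1D layout used by GPTMLP1D: the first projection is a Linear1D_Col with gather_output=False and the second is a Linear1D_Row with parallel_input=True. The sketch below is a rough illustration of why that pairing needs only one reduction at the end. It is not ColossalAI code. It assumes plain Python lists instead of tensors, ReLU instead of GELU, no biases, and a simulated all-reduce (a sum over per-rank partial outputs). All helper names (matvec, relu, mlp_reference, mlp_tensor_parallel) are hypothetical.

```python
# Illustrative sketch only: a plain-Python stand-in for a 1D tensor-parallel MLP.
# Helper names are hypothetical; ReLU replaces GELU and biases are omitted.
# Weights are stored as lists of rows with shape (out_features, in_features).


def matvec(weight, x):
    """Multiply a weight matrix (list of rows) by the vector x."""
    return [sum(w * v for w, v in zip(row, x)) for row in weight]


def relu(x):
    return [max(0.0, v) for v in x]


def mlp_reference(w1, w2, x):
    """Unsharded MLP: y = W2 @ relu(W1 @ x)."""
    return matvec(w2, relu(matvec(w1, x)))


def mlp_tensor_parallel(w1, w2, x, world_size):
    """Shard W1 over its output features and W2 over its input features."""
    hidden = len(w1)
    assert hidden % world_size == 0, "hidden size must divide evenly across ranks"
    shard = hidden // world_size
    partial_outputs = []
    for rank in range(world_size):
        lo, hi = rank * shard, (rank + 1) * shard
        w1_shard = w1[lo:hi]                   # this rank's slice of hidden units
        w2_shard = [row[lo:hi] for row in w2]  # the matching input columns of W2
        # The activation is elementwise, so no gather is needed between the layers.
        local_hidden = relu(matvec(w1_shard, x))
        partial_outputs.append(matvec(w2_shard, local_hidden))
    # Simulated all-reduce: sum the partial outputs from every rank.
    return [sum(vals) for vals in zip(*partial_outputs)]


w1 = [[1.0, -2.0], [0.5, 0.5], [-1.0, 3.0], [2.0, 1.0]]  # hidden=4, in=2
w2 = [[1.0, 0.0, -1.0, 2.0], [0.5, 1.0, 1.0, -0.5]]      # out=2, hidden=4
x = [1.0, 2.0]

ref = mlp_reference(w1, w2, x)
tp = mlp_tensor_parallel(w1, w2, x, world_size=2)
print(ref, tp)
```

Each simulated rank only ever touches its own slice of the hidden dimension, and the sharded result matches the unsharded one after the final sum. In a real run the sum would be a collective across devices rather than a Python loop.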