Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-05 19:13:01 +00:00
Update metainfo patch branch (#2517)
* init
* rename and remove useless func
* basic chunk
* add evoformer
* align evoformer
* add meta
* basic chunk
* basic memory
* finish basic inference memory estimation
* finish memory estimation
* fix bug
* finish memory estimation
* add part of index tracer
* finish basic index tracer
* add doc string
* add doc str
* polish code
* polish code
* update active log
* polish code
* add possible region search
* finish region search loop
* finish chunk define
* support new op
* rename index tracer
* finish codegen on msa
* redesign index tracer, add source and change compute
* pass outproduct mean
* code format
* code format
* work with outerproductmean and msa
* code style
* code style
* code style
* code style
* change threshold
* support check_index_duplicate
* support index duplicate and update loop
* support output
* update memory estimate
* optimise search
* fix layernorm
* move flow tracer
* refactor flow tracer
* format code
* refactor flow search
* code style
* adapt codegen to prepose node
* code style
* remove abandoned function
* remove flow tracer
* code style
* code style
* reorder nodes
* finish node reorder
* update run
* code style
* add chunk select class
* add chunk select
* code style
* add chunksize in emit, fix bug in reassign shape
* code style
* turn off print mem
* add evoformer openfold init
* init openfold
* add benchmark
* add print
* code style
* code style
* init openfold
* update openfold
* align openfold
* use max_mem to control strategy
* update source add
* add reorder in mem estimator
* improve reorder efficiency
* support ones_like, add prompt if fit mode search fails
* fix a bug in ones_like, don't gen chunk if dim size is 1
* fix bug again
* update min memory strategy, reduce mem usage by 30%
* last version of benchmark
* refactor structure
* restruct dir
* update test
* rename
* take apart chunk code gen
* close mem and code print
* code format
* rename ambiguous variable
* separate flow tracer
* separate input node dim search
* separate prepose_nodes
* separate non chunk input
* separate reorder
* rename
* add reorder graph
* separate trace flow
* code style
* code style
* fix typo
* set benchmark
* rename test
* update codegen test
* Fix state_dict key missing issue of the ZeroDDP (#2363)
* Fix state_dict output for ZeroDDP duplicated parameters
* Rewrite state_dict based on get_static_torch_model
* Modify get_static_torch_model to be compatible with the lower version (ZeroDDP)
* update codegen test
* update codegen test
* add chunk search test
* code style
* add available
* [hotfix] fix gpt gemini example (#2404)
* [hotfix] fix gpt gemini example
* [example] add new assertions
* remove autochunk_available
* [workflow] added nightly release to pypi (#2403)
* add comments
* code style
* add doc for search chunk
* [doc] updated readme regarding pypi installation (#2406)
* add doc for search
* [doc] updated kernel-related optimisers' docstring (#2385)
* [doc] updated kernel-related optimisers' docstring
* polish doc
* rename trace_index to trace_indice
* rename function from index to indice
* rename
* rename in doc
* [polish] polish code for get_static_torch_model (#2405)
* [gemini] polish code
* [testing] remove code
* [gemini] make more robust
* rename
* rename
* remove useless function
* [workflow] added coverage test (#2399)
* [workflow] added coverage test
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* add doc for trace indice
* [docker] updated Dockerfile and release workflow (#2410)
* add doc
* update doc
* add available
* change imports
* add test in import
* [workflow] refactored the example check workflow (#2411)
* [workflow] refactored the example check workflow
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* Update parallel_context.py (#2408)
* [hotfix] add DISTPAN argument for benchmark (#2412)
* change the benchmark config file
* change config
* revert config file
* rename distpan to distplan
* [workflow] added precommit check for code consistency (#2401)
* [workflow] added precommit check for code consistency
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* adapt new fx
* [workflow] added translation for non-english comments (#2414)
* [setup] refactored setup.py for dependency graph (#2413)
* change import
* update doc
* [workflow] auto comment if precommit check fails (#2417)
* [hotfix] add norm clearing for the overflow step (#2416)
* [examples] adding tflops to PaLM (#2365)
* [workflow]auto comment with test coverage report (#2419)
* [workflow]auto comment with test coverage report
* polish code
* polish yaml
* [doc] added documentation for CI/CD (#2420)
* [doc] added documentation for CI/CD
* polish markdown
* polish markdown
* polish markdown
* [example] removed duplicated stable diffusion example (#2424)
* [zero] add inference mode and its unit test (#2418)
* [workflow] report test coverage even if below threshold (#2431)
* [example] improved the clarity of the example readme (#2427)
* [example] improved the clarity of the example readme
* polish workflow
* polish workflow
* polish workflow
* polish workflow
* polish workflow
* polish workflow
* [ddp] add is_ddp_ignored (#2434)
[ddp] rename to is_ddp_ignored
* [workflow] make test coverage report collapsable (#2436)
* [autoparallel] add shard option (#2423)
* [fx] allow native ckpt trace and codegen. (#2438)
* [cli] provided more details if colossalai run fail (#2442)
* [autoparallel] integrate device mesh initialization into autoparallelize (#2393)
* [autoparallel] integrate device mesh initialization into autoparallelize
* add megatron solution
* update gpt autoparallel examples with latest api
* adapt beta value to fit the current computation cost
* [zero] fix state_dict and load_state_dict for ddp ignored parameters (#2443)
* [ddp] add is_ddp_ignored
[ddp] rename to is_ddp_ignored
* [zero] fix state_dict and load_state_dict
* fix bugs
* [zero] update unit test for ZeroDDP
* [example] updated the hybrid parallel tutorial (#2444)
* [example] updated the hybrid parallel tutorial
* polish code
* [zero] add warning for ignored parameters (#2446)
* [example] updated large-batch optimizer tutorial (#2448)
* [example] updated large-batch optimizer tutorial
* polish code
* polish code
* [example] fixed seed error in train_dreambooth_colossalai.py (#2445)
* [workflow] fixed the on-merge condition check (#2452)
* [workflow] automated the compatibility test (#2453)
* [workflow] automated the compatibility test
* polish code
* [autoparallel] update binary elementwise handler (#2451)
* [autoparallel] update binary elementwise handler
* polish
* [workflow] automated bdist wheel build (#2459)
* [workflow] automated bdist wheel build
* polish workflow
* polish readme
* polish readme
* Fix False warning in initialize.py (#2456)
* Update initialize.py
* pre-commit run check
* [examples] update autoparallel tutorial demo (#2449)
* [examples] update autoparallel tutorial demo
* add test_ci.sh
* polish
* add conda yaml
* [cli] fixed hostname mismatch error (#2465)
* [example] integrate autoparallel demo with CI (#2466)
* [example] integrate autoparallel demo with CI
* polish code
* polish code
* polish code
* polish code
* [zero] low level optim supports ProcessGroup (#2464)
* [example] update vit ci script (#2469)
* [example] update vit ci script
* [example] update requirements
* [example] update requirements
* [example] integrate seq-parallel tutorial with CI (#2463)
* [zero] polish low level optimizer (#2473)
* polish pp middleware (#2476)
Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
* [example] update gpt gemini example ci test (#2477)
* [zero] add unit test for low-level zero init (#2474)
* [workflow] fixed the skip condition of example weekly check workflow (#2481)
* [example] stable diffusion add roadmap
* add dummy test_ci.sh
* [example] stable diffusion add roadmap (#2482)
* [CI] add test_ci.sh for palm, opt and gpt (#2475)
* polish code
* [example] titans for gpt
* polish readme
* remove license
* polish code
* update readme
* [example] titans for gpt (#2484)
* [autoparallel] support origin activation ckpt on autoprallel system (#2468)
* [autochunk] support evoformer tracer (#2485)
Support the full Evoformer tracer, which is a main module of AlphaFold; previously we only supported a simplified version of it.
1. support some Evoformer ops in fx
2. support the Evoformer test
3. add repos for test code
* [example] fix requirements (#2488)
* [zero] add unit testings for hybrid parallelism (#2486)
* [hotfix] gpt example titans bug #2493
* polish code and fix dataloader bugs
* [hotfix] gpt example titans bug #2493 (#2494)
* [fx] allow control of ckpt_codegen init (#2498)
* [fx] allow control of ckpt_codegen init
Currently in ColoGraphModule, ActivationCheckpointCodeGen is set automatically in __init__, so no other codegen can be installed. This adds an argument to control whether ActivationCheckpointCodeGen is set in __init__ (see the sketch after this commit message).
* code style
* [example] dreambooth example
* add test_ci.sh to dreambooth
* [autochunk] support autochunk on evoformer (#2497)
* Revert "Update parallel_context.py (#2408)"
This reverts commit 7d5640b9db.
* add avg partition (#2483)
Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
* [auto-chunk] support extramsa (#3) (#2504)
* [utils] lazy init. (#2148)
* [utils] lazy init.
* [utils] remove description.
* [utils] complete.
* [utils] finalize.
* [utils] fix names.
* [autochunk] support parsing blocks (#2506)
* [zero] add strict ddp mode (#2508)
* [zero] add strict ddp mode
* [polish] add comments for strict ddp mode
* [zero] fix test error
* [doc] update opt and tutorial links (#2509)
* [workflow] fixed changed file detection (#2515)
Co-authored-by: oahzxl <xuanlei.zhao@gmail.com>
Co-authored-by: eric8607242 <e0928021388@gmail.com>
Co-authored-by: HELSON <c2h214748@gmail.com>
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: Haofan Wang <haofanwang.ai@gmail.com>
Co-authored-by: Jiarui Fang <fangjiarui123@gmail.com>
Co-authored-by: ZijianYY <119492445+ZijianYY@users.noreply.github.com>
Co-authored-by: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Co-authored-by: Super Daniel <78588128+super-dainiu@users.noreply.github.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: Ziyue Jiang <ziyue.jiang97@gmail.com>
Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
Co-authored-by: oahzxl <43881818+oahzxl@users.noreply.github.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: Fazzie-Maqianli <55798671+Fazziekey@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
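A minimal editor-added sketch of the codegen switch introduced by "[fx] allow control of ckpt_codegen init (#2498)". The constructor argument name `ckpt_codegen` is taken from the commit title and the import paths reflect the ColossalAI layout of that period; treat both as assumptions rather than a confirmed API.

import torch
from colossalai.fx import ColoTracer
from colossalai.fx.graph_module import ColoGraphModule

# Trace a toy module, then build a ColoGraphModule without the default
# ActivationCheckpointCodeGen so a different codegen (e.g. the autochunk
# codegen) can be attached to gm.graph afterwards.
model = torch.nn.Linear(8, 8)
graph = ColoTracer().trace(model)
gm = ColoGraphModule(model, graph, ckpt_codegen=False)    # assumed flag name from #2498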
This commit is contained in:

examples/language/gpt/titans/model/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
from .embed import vocab_parallel_cross_entropy
from .gpt1d import *
from .pipeline_gpt1d import *
examples/language/gpt/titans/model/embed.py (new file, 599 lines)
@@ -0,0 +1,599 @@
import torch
import torch.nn.init as init
from torch import Tensor
from torch import distributed as dist
from torch import nn as nn
from torch.nn import functional as F
from torch.nn.parameter import Parameter

from colossalai.context import ParallelMode, seed
from colossalai.core import global_context as gpc
from colossalai.nn.layer.base_layer import ParallelLayer
from colossalai.nn.layer.parallel_1d._utils import gather_forward_split_backward, reduce_grad, reduce_input
from colossalai.nn.layer.parallel_1d.layers import Linear1D_Row
from colossalai.nn.layer.utils import divide
from colossalai.registry import LAYERS, LOSSES, MODELS
from colossalai.utils import get_current_device


class VocabParallelEmbedding(torch.nn.Module):
    """Language model embeddings.

    Arguments:
        hidden_size: hidden size
        vocab_size: vocabulary size
        max_sequence_length: maximum size of sequence. This
                             is used for positional embedding
        embedding_dropout_prob: dropout probability for embeddings
        init_method: weight initialization method
        num_tokentypes: size of the token-type embeddings. 0 value
                        will ignore this embedding
    """

    def __init__(self,
                 hidden_size,
                 vocab_size,
                 max_sequence_length,
                 embedding_dropout_prob,
                 num_tokentypes=0,
                 dtype=torch.float):
        super(VocabParallelEmbedding, self).__init__()

        self.hidden_size = hidden_size
        self.num_tokentypes = num_tokentypes

        # Word embeddings (parallel).
        self.word_embeddings = VocabParallelEmbedding1D(vocab_size, self.hidden_size, dtype=dtype)
        self._word_embeddings_key = 'word_embeddings'

        # Position embedding (serial).
        self.position_embeddings = torch.nn.Embedding(max_sequence_length, self.hidden_size, dtype=dtype)
        self._position_embeddings_key = 'position_embeddings'
        # Initialize the position embeddings.
        # self.init_method(self.position_embeddings.weight)

        # Token type embedding.
        # Add this as an optional field that can be added through
        # method call so we can load a pretrain model without
        # token types and add them as needed.
        self._tokentype_embeddings_key = 'tokentype_embeddings'
        if self.num_tokentypes > 0:
            self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, self.hidden_size, dtype=dtype)
            # Initialize the token-type embeddings.
            # self.init_method(self.tokentype_embeddings.weight)
        else:
            self.tokentype_embeddings = None

        # Embeddings dropout
        self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)

    def zero_parameters(self):
        """Zero out all parameters in embedding."""
        self.word_embeddings.weight.data.fill_(0)
        self.word_embeddings.weight.shared = True
        self.position_embeddings.weight.data.fill_(0)
        self.position_embeddings.weight.shared = True
        if self.num_tokentypes > 0:
            self.tokentype_embeddings.weight.data.fill_(0)
            self.tokentype_embeddings.weight.shared = True

    def add_tokentype_embeddings(self, num_tokentypes):
        """Add token-type embedding. This function is provided so we can add
        token-type embeddings in case the pretrained model does not have it.
        This allows us to load the model normally and then add this embedding.
        """
        if self.tokentype_embeddings is not None:
            raise Exception('tokentype embeddings is already initialized')
        if torch.distributed.get_rank() == 0:
            print('adding embedding for {} tokentypes'.format(num_tokentypes), flush=True)
        self.num_tokentypes = num_tokentypes
        self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes, self.hidden_size)
        # Initialize the token-type embeddings.
        # self.init_method(self.tokentype_embeddings.weight)

    def forward(self, input_ids, position_ids=None, tokentype_ids=None):
        # Embeddings.
        if input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        words_embeddings = self.word_embeddings(input_ids)

        if position_ids is not None:
            position_ids = position_ids.view(-1, input_shape[-1])
        if position_ids is None:
            position_ids = torch.arange(0, input_shape[-1] + 0, dtype=torch.long, device=get_current_device())
            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
        position_embeddings = self.position_embeddings(position_ids)

        embeddings = words_embeddings + position_embeddings

        # Dropout.
        with seed(ParallelMode.TENSOR):
            embeddings = self.embedding_dropout(embeddings)
        return embeddings

    def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
        """For easy load."""

        state_dict_ = {}
        state_dict_[self._word_embeddings_key] = self.word_embeddings.state_dict(destination, prefix, keep_vars)
        state_dict_[self._position_embeddings_key] = self.position_embeddings.state_dict(
            destination, prefix, keep_vars)
        if self.num_tokentypes > 0:
            state_dict_[self._tokentype_embeddings_key] = self.tokentype_embeddings.state_dict(
                destination, prefix, keep_vars)

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""

        # Word embedding.
        if self._word_embeddings_key in state_dict:
            state_dict_ = state_dict[self._word_embeddings_key]
        else:
            # for backward compatibility.
            state_dict_ = {}
            for key in state_dict.keys():
                if 'word_embeddings' in key:
                    state_dict_[key.split('word_embeddings.')[1]] = state_dict[key]
        self.word_embeddings.load_state_dict(state_dict_, strict=strict)

        # Position embedding.
        if self._position_embeddings_key in state_dict:
            state_dict_ = state_dict[self._position_embeddings_key]
        else:
            # for backward compatibility.
            state_dict_ = {}
            for key in state_dict.keys():
                if 'position_embeddings' in key:
                    state_dict_[key.split('position_embeddings.')[1]] = state_dict[key]
        self.position_embeddings.load_state_dict(state_dict_, strict=strict)

        # Tokentype embedding.
        if self.num_tokentypes > 0:
            state_dict_ = {}
            if self._tokentype_embeddings_key in state_dict:
                state_dict_ = state_dict[self._tokentype_embeddings_key]
            else:
                # for backward compatibility.
                for key in state_dict.keys():
                    if 'tokentype_embeddings' in key:
                        state_dict_[key.split('tokentype_embeddings.')[1]] = state_dict[key]
            if len(state_dict_.keys()) > 0:
                self.tokentype_embeddings.load_state_dict(state_dict_, strict=strict)
            else:
                print('***WARNING*** expected tokentype embeddings in the '
                      'checkpoint but could not find it',
                      flush=True)


class VocabParallelEmbedding1D(torch.nn.Module):
    """Embedding parallelized in the vocabulary dimension.

    This is mainly adapted from torch.nn.Embedding and all the default
    values are kept.
    Arguments:
        num_embeddings: vocabulary size.
        embedding_dim: size of hidden state.
        init_method: method to initialize weights.
    """

    def __init__(self, num_embeddings, embedding_dim, dtype=None, init_method=None):
        super(VocabParallelEmbedding1D, self).__init__()
        # Keep the input dimensions.
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        # Set the details for compatibility.
        self.padding_idx = None
        self.max_norm = None
        self.norm_type = 2.
        self.scale_grad_by_freq = False
        self.sparse = False
        self._weight = None
        self.tensor_model_parallel_size = gpc.tensor_parallel_size
        # Divide the weight matrix along the vocabulary dimension.
        self.vocab_start_index, self.vocab_end_index = \
            VocabUtility.vocab_range_from_global_vocab_size(
                self.num_embeddings, gpc.get_local_rank(ParallelMode.PARALLEL_1D),
                self.tensor_model_parallel_size)
        self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index

        # Allocate weights and initialize.
        factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
        self.weight = Parameter(torch.empty(self.num_embeddings_per_partition, self.embedding_dim, **factory_kwargs))
        init.uniform_(self.weight, -1, 1)

    def forward(self, input_):
        if self.tensor_model_parallel_size > 1:
            # Build the mask.
            input_mask = (input_ < self.vocab_start_index) | \
                         (input_ >= self.vocab_end_index)
            # Mask the input.
            masked_input = input_.clone() - self.vocab_start_index
            masked_input[input_mask] = 0
        else:
            masked_input = input_
        # Get the embeddings.
        output_parallel = F.embedding(masked_input, self.weight, self.padding_idx, self.max_norm, self.norm_type,
                                      self.scale_grad_by_freq, self.sparse)
        # Mask the output embedding.
        if self.tensor_model_parallel_size > 1:
            output_parallel[input_mask, :] = 0.0
        # Reduce across all the model parallel GPUs.
        output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D)
        return output


@LOSSES.register_module
class vocab_parallel_cross_entropy(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, vocab_parallel_logits, target):
        """Helper function for the cross entropy."""
        vocab_parallel_logits = vocab_parallel_logits[..., :-1, :].contiguous()
        target = target[..., 1:].contiguous()
        return _VocabParallelCrossEntropy.apply(vocab_parallel_logits.view(-1, vocab_parallel_logits.size(-1)),
                                                target.view(-1))


class _VocabParallelCrossEntropy(torch.autograd.Function):

    @staticmethod
    def forward(ctx, vocab_parallel_logits, target):

        # Maximum value along vocab dimension across all GPUs.
        logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
        torch.distributed.all_reduce(logits_max,
                                     op=torch.distributed.ReduceOp.MAX,
                                     group=gpc.get_group(ParallelMode.PARALLEL_1D))
        # Subtract the maximum value.
        vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))

        # Get the partition's vocab indices.
        get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
        partition_vocab_size = vocab_parallel_logits.size()[-1]
        rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
        world_size = gpc.tensor_parallel_size
        vocab_start_index, vocab_end_index = get_vocab_range(partition_vocab_size, rank, world_size)

        # Create a mask of valid vocab ids (1 means it needs to be masked).
        target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
        masked_target = target.clone() - vocab_start_index
        masked_target[target_mask] = 0

        # Get predicted-logits = logits[target].
        # For simplicity, we convert logits to a 2-D tensor with size
        # [*, partition-vocab-size] and target to a 1-D tensor of size [*].
        logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
        masked_target_1d = masked_target.view(-1)
        arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device)
        predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
        predicted_logits_1d = predicted_logits_1d.clone().contiguous()
        predicted_logits = predicted_logits_1d.view_as(target)
        predicted_logits[target_mask] = 0.0
        # All reduce is needed to get the chunks from other GPUs.
        torch.distributed.all_reduce(predicted_logits,
                                     op=torch.distributed.ReduceOp.SUM,
                                     group=gpc.get_group(ParallelMode.PARALLEL_1D))

        # Sum of exponential of logits along vocab dimension across all GPUs.
        exp_logits = vocab_parallel_logits
        torch.exp(vocab_parallel_logits, out=exp_logits)
        sum_exp_logits = exp_logits.sum(dim=-1)
        torch.distributed.all_reduce(sum_exp_logits,
                                     op=torch.distributed.ReduceOp.SUM,
                                     group=gpc.get_group(ParallelMode.PARALLEL_1D))

        # Loss = log(sum(exp(logits))) - predicted-logit.
        loss = torch.log(sum_exp_logits) - predicted_logits
        loss = loss.mean()
        # Store softmax, target-mask and masked-target for backward pass.
        exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
        ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)
        return loss

    @staticmethod
    def backward(ctx, grad_output):

        # Retrieve tensors from the forward path.
        softmax, target_mask, masked_target_1d = ctx.saved_tensors

        # All the inputs have softmax as their gradient.
        grad_input = softmax
        # For simplicity, work with the 2D gradient.
        partition_vocab_size = softmax.size()[-1]
        grad_2d = grad_input.view(-1, partition_vocab_size)

        # Add the gradient from matching classes.
        arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)
        grad_2d[arange_1d, masked_target_1d] -= (1.0 - target_mask.view(-1).float())

        # Finally elementwise multiplication with the output gradients.
        grad_input.mul_(grad_output.unsqueeze(dim=-1))

        return grad_input, None


class VocabUtility:
    """Split the vocabulary into `world_size` chunks and return the
    first and last index of the vocabulary belonging to the `rank`
    partition. Note that indices are in [first, last)."""

    @staticmethod
    def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size):
        index_f = rank * per_partition_vocab_size
        index_l = index_f + per_partition_vocab_size
        return index_f, index_l

    @staticmethod
    def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
        per_partition_vocab_size = divide(global_vocab_size, world_size)
        return VocabUtility.vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size)


class VocabParallelGPTLMHead1D(ParallelLayer):
    """
    Language model head that shares the same parameters with the embedding matrix.
    """

    def __init__(self, embed=None, vocab_size=None, dtype=None, embed_dim=None):
        super().__init__()
        if embed is not None:
            self.head = embed
        else:
            self.head = VocabParallelEmbedding1D(vocab_size, embed_dim, dtype=dtype)

    def forward(self, x: Tensor) -> Tensor:
        x = reduce_grad(x, ParallelMode.PARALLEL_1D)
        x = F.linear(x, self.head.weight)
        return x


###################################


class HiddenParallelEmbedding(torch.nn.Module):
    """Language model embeddings.

    Arguments:
        hidden_size: hidden size
        vocab_size: vocabulary size
        max_sequence_length: maximum size of sequence. This
                             is used for positional embedding
        embedding_dropout_prob: dropout probability for embeddings
        init_method: weight initialization method
        num_tokentypes: size of the token-type embeddings. 0 value
                        will ignore this embedding
    """

    def __init__(
        self,
        hidden_size,
        vocab_size,
        max_sequence_length,
        embedding_dropout_prob,
        dtype=torch.float,
        padding_idx: int = 0,
        num_tokentypes=0,
    ):
        super(HiddenParallelEmbedding, self).__init__()

        self.hidden_size = hidden_size
        self.num_tokentypes = num_tokentypes

        # Word embeddings (parallel).
        self.word_embeddings = HiddenParallelEmbedding1D(vocab_size, hidden_size, dtype, padding_idx)
        self._word_embeddings_key = 'word_embeddings'

        # Position embedding (serial).
        self.position_embeddings = torch.nn.Embedding(max_sequence_length, self.hidden_size)
        self._position_embeddings_key = 'position_embeddings'
        # Initialize the position embeddings.
        # self.init_method(self.position_embeddings.weight)

        # Token type embedding.
        # Add this as an optional field that can be added through
        # method call so we can load a pretrain model without
        # token types and add them as needed.
        self._tokentype_embeddings_key = 'tokentype_embeddings'
        if self.num_tokentypes > 0:
            self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, self.hidden_size)
            # Initialize the token-type embeddings.
            # self.init_method(self.tokentype_embeddings.weight)
        else:
            self.tokentype_embeddings = None

        # Embeddings dropout
        self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)

    def zero_parameters(self):
        """Zero out all parameters in embedding."""
        self.word_embeddings.weight.data.fill_(0)
        self.word_embeddings.weight.shared = True
        self.position_embeddings.weight.data.fill_(0)
        self.position_embeddings.weight.shared = True
        if self.num_tokentypes > 0:
            self.tokentype_embeddings.weight.data.fill_(0)
            self.tokentype_embeddings.weight.shared = True

    def add_tokentype_embeddings(self, num_tokentypes):
        """Add token-type embedding. This function is provided so we can add
        token-type embeddings in case the pretrained model does not have it.
        This allows us to load the model normally and then add this embedding.
        """
        if self.tokentype_embeddings is not None:
            raise Exception('tokentype embeddings is already initialized')
        if torch.distributed.get_rank() == 0:
            print('adding embedding for {} tokentypes'.format(num_tokentypes), flush=True)
        self.num_tokentypes = num_tokentypes
        self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes, self.hidden_size)
        # Initialize the token-type embeddings.
        # self.init_method(self.tokentype_embeddings.weight)

    def forward(self, input_ids, position_ids=None, tokentype_ids=None):
        if input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        words_embeddings = self.word_embeddings(input_ids)

        if position_ids is not None:
            position_ids = position_ids.view(-1, input_shape[-1])
        if position_ids is None:
            position_ids = torch.arange(0, input_shape[-1] + 0, dtype=torch.long, device=get_current_device())
            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
        position_embeddings = self.position_embeddings(position_ids)

        embeddings = words_embeddings + position_embeddings

        # Dropout.
        with seed(ParallelMode.TENSOR):
            embeddings = self.embedding_dropout(embeddings)
        return embeddings

    def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
        """For easy load."""

        state_dict_ = {}
        state_dict_[self._word_embeddings_key] = self.word_embeddings.state_dict(destination, prefix, keep_vars)
        state_dict_[self._position_embeddings_key] = self.position_embeddings.state_dict(
            destination, prefix, keep_vars)
        if self.num_tokentypes > 0:
            state_dict_[self._tokentype_embeddings_key] = self.tokentype_embeddings.state_dict(
                destination, prefix, keep_vars)

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""

        # Word embedding.
        if self._word_embeddings_key in state_dict:
            state_dict_ = state_dict[self._word_embeddings_key]
        else:
            # for backward compatibility.
            state_dict_ = {}
            for key in state_dict.keys():
                if 'word_embeddings' in key:
                    state_dict_[key.split('word_embeddings.')[1]] = state_dict[key]
        self.word_embeddings.load_state_dict(state_dict_, strict=strict)

        # Position embedding.
        if self._position_embeddings_key in state_dict:
            state_dict_ = state_dict[self._position_embeddings_key]
        else:
            # for backward compatibility.
            state_dict_ = {}
            for key in state_dict.keys():
                if 'position_embeddings' in key:
                    state_dict_[key.split('position_embeddings.')[1]] = state_dict[key]
        self.position_embeddings.load_state_dict(state_dict_, strict=strict)

        # Tokentype embedding.
        if self.num_tokentypes > 0:
            state_dict_ = {}
            if self._tokentype_embeddings_key in state_dict:
                state_dict_ = state_dict[self._tokentype_embeddings_key]
            else:
                # for backward compatibility.
                for key in state_dict.keys():
                    if 'tokentype_embeddings' in key:
                        state_dict_[key.split('tokentype_embeddings.')[1]] = state_dict[key]
            if len(state_dict_.keys()) > 0:
                self.tokentype_embeddings.load_state_dict(state_dict_, strict=strict)
            else:
                print('***WARNING*** expected tokentype embeddings in the '
                      'checkpoint but could not find it',
                      flush=True)


class HiddenParallelEmbedding1D(torch.nn.Module):
    """Embedding parallelized in the hidden dimension.

    This is mainly adapted from torch.nn.Embedding and all the default
    values are kept.
    Arguments:
        num_embeddings: vocabulary size.
        embedding_dim: size of hidden state.
        init_method: method to initialize weights.
    """

    def __init__(self, num_embeddings, embedding_dim, dtype=torch.float, padding_idx: int = None, init_method=None):
        super(HiddenParallelEmbedding1D, self).__init__()
        # Keep the input dimensions.
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        embed_dim_per_partition = divide(embedding_dim, gpc.tensor_parallel_size)
        # Set the details for compatibility.
        self.padding_idx = padding_idx
        self.max_norm = None
        self.norm_type = 2.
        self.scale_grad_by_freq = False
        self.sparse = False
        self._weight = None

        # Allocate weights and initialize.
        factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
        self.weight = Parameter(torch.empty(num_embeddings, embed_dim_per_partition, **factory_kwargs))
        init.uniform_(self.weight, -1, 1)

    def forward(self, input_):

        # Get the embeddings.
        output_parallel = F.embedding(input_, self.weight, self.padding_idx, self.max_norm, self.norm_type,
                                      self.scale_grad_by_freq, self.sparse)

        # Gather across all the model parallel GPUs.
        output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
        return output


@LAYERS.register_module
class HiddenParallelGPTLMHead1D(ParallelLayer):
    """
    Language model head that shares the same parameters with the embedding matrix.
    """

    def __init__(
        self,
        embed=None,
        embed_dim=None,
        vocab_size=None,
        dtype=None,
    ):
        super().__init__()
        if embed is not None:
            self.head = embed
            self.synced_embed = True
        else:
            # self.embedding = HiddenParallelEmbedding1D(vocab_size, hidden_size, dtype, padding_idx)
            # (hidden_size/q, vocab_size)
            self.synced_embed = False
            self.head = Linear1D_Row(in_features=embed_dim,
                                     out_features=vocab_size,
                                     bias=False,
                                     dtype=dtype,
                                     parallel_input=False)

    def forward(self, x: Tensor) -> Tensor:
        if self.synced_embed:
            x = F.linear(x, self.head.weight)
        else:
            x = self.head(x)

        return x
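As an editor-added illustration of the partitioning arithmetic used by VocabUtility and VocabParallelEmbedding1D above: each tensor-parallel rank owns a contiguous [first, last) slice of the vocabulary, out-of-range token ids are masked to zero locally, and the partial embeddings are summed back with reduce_input. The sketch below only re-implements the index arithmetic in plain Python, so it runs without ColossalAI.

# Standalone sketch of VocabUtility's range computation across tensor-parallel ranks.
def vocab_range(global_vocab_size: int, rank: int, world_size: int):
    assert global_vocab_size % world_size == 0, "divide() requires an even split"
    per_partition = global_vocab_size // world_size    # vocab_range_from_per_partition_vocab_size
    first = rank * per_partition
    last = first + per_partition
    return first, last

if __name__ == "__main__":
    vocab_size, world_size = 50304, 4
    for rank in range(world_size):
        print(rank, vocab_range(vocab_size, rank, world_size))
    # rank 0 owns tokens [0, 12576), rank 1 owns [12576, 25152), and so on.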
examples/language/gpt/titans/model/gpt1d.py (new file, 349 lines)
@@ -0,0 +1,349 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import math

import torch
from torch import Tensor
from torch import nn as nn

from colossalai import kernel
from colossalai import nn as col_nn
from colossalai.core import global_context as gpc
from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
from colossalai.nn.layer import Linear1D_Col, Linear1D_Row
from colossalai.nn.layer.base_layer import ParallelLayer
from colossalai.nn.layer.utils import ACT2FN, divide
from colossalai.utils import checkpoint
from colossalai.utils.activation_checkpoint import checkpoint

__all__ = [
    'GPTMLP1D', 'GPTSelfAttention1D', 'GPTTransformerLayer1D', 'FusedGPTSelfAttention1D', 'FusedGPTTransformerLayer1D'
]


class GPTMLP1D(ParallelLayer):

    def __init__(
        self,
        in_features: int,
        mlp_ratio: int,
        act_func: str = 'gelu',
        dropout_prob: float = 0.,
        dtype=None,
        checkpoint: bool = False,
        skip_bias_add: bool = False,
    ):
        super().__init__()

        self.in_features = in_features
        self.mlp_ratio = mlp_ratio
        self.checkpoint = checkpoint
        self.skip_bias_add = skip_bias_add

        self.act = ACT2FN[act_func]
        skip_dense_1_add_bias = False

        # Project to mlp_ratio * h.
        self.dense_1 = Linear1D_Col(
            self.in_features,
            int(self.mlp_ratio * self.in_features),
            dtype=dtype,
            gather_output=False,
            skip_bias_add=skip_dense_1_add_bias,
        )

        # Project back to h.
        self.dense_2 = Linear1D_Row(
            int(self.mlp_ratio * self.in_features),
            self.in_features,
            dtype=dtype,
            parallel_input=True,
        )

        self.dropout = col_nn.Dropout(dropout_prob)

    def _forward(self, hidden_states: Tensor) -> Tensor:
        intermediate_output = self.dense_1(hidden_states)
        intermediate_output = self.act(intermediate_output)

        output = self.dense_2(intermediate_output)
        output = self.dropout(output)
        return output

    def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
        return checkpoint(self._forward, False, hidden_states)

    def forward(self, hidden_states: Tensor) -> Tensor:
        if self.checkpoint:
            return self._checkpoint_forward(hidden_states)
        else:
            return self._forward(hidden_states)


class GenericGPTSelfAttention1D(ParallelLayer):

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        attention_dropout_prob: float,
        hidden_dropout_prob: float,
        dtype=None,
        checkpoint: bool = False,
        max_position_embeddings=1024,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.attention_head_size = divide(hidden_size, num_attention_heads)
        self.num_attention_heads_per_partition = divide(num_attention_heads, gpc.tensor_parallel_size)
        self.hidden_size_per_partition = divide(hidden_size, gpc.tensor_parallel_size)
        self.checkpoint = checkpoint
        self.query_key_value = Linear1D_Col(
            hidden_size,
            3 * hidden_size,
            dtype=dtype,
        )
        self.attention_dropout = col_nn.Dropout(attention_dropout_prob)
        self.dense = Linear1D_Row(
            hidden_size,
            hidden_size,
            dtype=dtype,
            parallel_input=True,
        )
        self.dropout = col_nn.Dropout(hidden_dropout_prob)

    def softmax_forward(self, attention_scores, attention_mask, query_layer, key_layer):
        raise NotImplementedError

    def _forward(self, hidden_states: Tensor, attention_mask=None) -> Tensor:
        query_key_value = self.query_key_value(hidden_states)
        new_qkv_shape = query_key_value.shape[:-1] + \
            (self.num_attention_heads_per_partition, 3 * self.attention_head_size)
        query_key_value = query_key_value.view(new_qkv_shape)
        query_key_value = query_key_value.permute((0, 2, 1, 3))
        query_layer, key_layer, value_layer = torch.chunk(query_key_value, 3, dim=-1)

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        attention_scores = self.softmax_forward(attention_scores, attention_mask, query_layer, key_layer)

        attention_scores = attention_scores.type(value_layer.dtype)

        attention_probs = self.attention_dropout(attention_scores)

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.transpose(1, 2)
        new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
        context_layer = context_layer.reshape(new_context_layer_shape)
        output = self.dense(context_layer)
        output = self.dropout(output)

        return output

    def _checkpoint_forward(self, hidden_states: Tensor, attention_mask=None) -> Tensor:
        return checkpoint(self._forward, False, hidden_states, attention_mask)

    def forward(self, hidden_states: Tensor, attention_mask=None) -> Tensor:
        if self.checkpoint:
            return self._checkpoint_forward(hidden_states, attention_mask)
        else:
            return self._forward(hidden_states, attention_mask)


class GPTSelfAttention1D(GenericGPTSelfAttention1D):

    def __init__(self,
                 hidden_size: int,
                 num_attention_heads: int,
                 attention_dropout_prob: float,
                 hidden_dropout_prob: float,
                 dtype=None,
                 checkpoint: bool = False,
                 max_position_embeddings=1024):
        super().__init__(hidden_size,
                         num_attention_heads,
                         attention_dropout_prob,
                         hidden_dropout_prob,
                         dtype=dtype,
                         checkpoint=checkpoint,
                         max_position_embeddings=max_position_embeddings)
        self.softmax = nn.Softmax(dim=-1)
        max_positions = max_position_embeddings
        self.register_buffer(
            "bias",
            torch.tril(torch.ones((max_positions, max_positions),
                                  dtype=torch.uint8)).view(1, 1, max_positions, max_positions),
        )
        self.register_buffer("masked_bias", torch.tensor(-1e4))

    def softmax_forward(self, attention_scores, attention_mask, query_layer, key_layer):
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        # causal mask
        query_length, key_length = query_layer.size(-2), key_layer.size(-2)
        causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length].bool()
        attention_scores = torch.where(causal_mask, attention_scores, self.masked_bias.to(attention_scores))
        if attention_mask is not None:
            # Apply the attention mask
            attention_scores = attention_scores + attention_mask
        attention_scores = self.softmax(attention_scores)
        return attention_scores


class FusedGPTSelfAttention1D(GenericGPTSelfAttention1D):

    def __init__(self,
                 hidden_size: int,
                 num_attention_heads: int,
                 attention_dropout_prob: float,
                 hidden_dropout_prob: float,
                 dtype=None,
                 checkpoint: bool = False,
                 max_position_embeddings=1024):
        super().__init__(hidden_size,
                         num_attention_heads,
                         attention_dropout_prob,
                         hidden_dropout_prob,
                         dtype=dtype,
                         checkpoint=checkpoint,
                         max_position_embeddings=max_position_embeddings)
        self.softmax = kernel.FusedScaleMaskSoftmax(input_in_fp16=True,
                                                    input_in_bf16=False,
                                                    attn_mask_type=AttnMaskType.causal,
                                                    scaled_masked_softmax_fusion=True,
                                                    mask_func=None,
                                                    softmax_in_fp32=True,
                                                    scale=math.sqrt(self.attention_head_size))

    def softmax_forward(self, attention_scores, attention_mask, query_layer, key_layer):
        return self.softmax(attention_scores, attention_mask)


class GenericGPTTransformerLayer1D(ParallelLayer):

    def __init__(self,
                 hidden_size: int,
                 num_attention_heads: int,
                 act_func: str = 'gelu',
                 mlp_ratio: float = 4.0,
                 attention_dropout_prob: float = 0.,
                 hidden_dropout_prob: float = 0.,
                 dtype=None,
                 checkpoint: bool = False,
                 max_position_embeddings: int = 1024,
                 layer_norm_epsilon: float = 1e-5,
                 apply_post_layer_norm: bool = False,
                 attention=None,
                 layer_norm=None):
        super().__init__()
        self.checkpoint = checkpoint
        self.dtype = dtype
        self.norm1 = layer_norm(hidden_size, eps=layer_norm_epsilon)
        self.apply_post_layer_norm = apply_post_layer_norm
        self.attention = attention(
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            attention_dropout_prob=attention_dropout_prob,
            hidden_dropout_prob=hidden_dropout_prob,
            dtype=dtype,
            max_position_embeddings=max_position_embeddings,
            checkpoint=False,
        )

        self.norm2 = layer_norm(hidden_size, eps=layer_norm_epsilon)
        self.mlp = GPTMLP1D(
            in_features=hidden_size,
            dropout_prob=hidden_dropout_prob,
            act_func=act_func,
            mlp_ratio=mlp_ratio,
            dtype=dtype,
            checkpoint=False,
        )

    def _forward(self, hidden_states, attention_mask) -> Tensor:
        if not self.apply_post_layer_norm:
            residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        if self.apply_post_layer_norm:
            residual = hidden_states
        attention_output = self.attention(hidden_states, attention_mask)
        hidden_states = residual + attention_output

        if not self.apply_post_layer_norm:
            residual = hidden_states
        hidden_states = self.norm2(hidden_states)
        if self.apply_post_layer_norm:
            residual = hidden_states
        feed_forward_hidden_states = self.mlp(hidden_states)
        hidden_states = residual + feed_forward_hidden_states

        output = (hidden_states, attention_mask)
        return output

    def forward(self, hidden_states, attention_mask):
        if self.checkpoint:
            return checkpoint(self._forward, False, hidden_states, attention_mask)
        else:
            return self._forward(hidden_states, attention_mask)


class GPTTransformerLayer1D(GenericGPTTransformerLayer1D):

    def __init__(self,
                 hidden_size: int,
                 num_attention_heads: int,
                 act_func: str = 'gelu',
                 mlp_ratio: float = 4,
                 attention_dropout_prob: float = 0,
                 hidden_dropout_prob: float = 0,
                 dtype=None,
                 checkpoint: bool = False,
                 max_position_embeddings: int = 1024,
                 layer_norm_epsilon: float = 0.00001,
                 apply_post_layer_norm: bool = False):
        attention = GPTSelfAttention1D
        layer_norm = nn.LayerNorm
        super().__init__(hidden_size,
                         num_attention_heads,
                         act_func=act_func,
                         mlp_ratio=mlp_ratio,
                         attention_dropout_prob=attention_dropout_prob,
                         hidden_dropout_prob=hidden_dropout_prob,
                         dtype=dtype,
                         checkpoint=checkpoint,
                         max_position_embeddings=max_position_embeddings,
                         layer_norm_epsilon=layer_norm_epsilon,
                         apply_post_layer_norm=apply_post_layer_norm,
                         attention=attention,
                         layer_norm=layer_norm)


class FusedGPTTransformerLayer1D(GenericGPTTransformerLayer1D):

    def __init__(self,
                 hidden_size: int,
                 num_attention_heads: int,
                 act_func: str = 'gelu',
                 mlp_ratio: float = 4,
                 attention_dropout_prob: float = 0,
                 hidden_dropout_prob: float = 0,
                 dtype=None,
                 checkpoint: bool = False,
                 max_position_embeddings: int = 1024,
                 layer_norm_epsilon: float = 0.00001,
                 apply_post_layer_norm: bool = False):
        attention = FusedGPTSelfAttention1D
        layer_norm = kernel.LayerNorm
        super().__init__(hidden_size,
                         num_attention_heads,
                         act_func=act_func,
                         mlp_ratio=mlp_ratio,
                         attention_dropout_prob=attention_dropout_prob,
                         hidden_dropout_prob=hidden_dropout_prob,
                         dtype=dtype,
                         checkpoint=checkpoint,
                         max_position_embeddings=max_position_embeddings,
                         layer_norm_epsilon=layer_norm_epsilon,
                         apply_post_layer_norm=apply_post_layer_norm,
                         attention=attention,
                         layer_norm=layer_norm)
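To make the residual ordering in GenericGPTTransformerLayer1D._forward easier to follow, here is an editor-added, single-GPU sketch using only standard PyTorch modules (no tensor parallelism, no fused kernels): apply_post_layer_norm decides whether the residual branch is taken before or after each LayerNorm, mirroring the branching above.

import torch
from torch import nn

class ToyBlock(nn.Module):
    """Illustrative pre-/post-LayerNorm block; not the 1D-parallel implementation."""

    def __init__(self, hidden_size: int, apply_post_layer_norm: bool = False):
        super().__init__()
        self.apply_post_layer_norm = apply_post_layer_norm
        self.norm1 = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)
        self.attn = nn.MultiheadAttention(hidden_size, num_heads=4, batch_first=True)
        self.mlp = nn.Sequential(nn.Linear(hidden_size, 4 * hidden_size), nn.GELU(),
                                 nn.Linear(4 * hidden_size, hidden_size))

    def forward(self, x):
        # Pre-LN takes the residual from the un-normalized input,
        # post-LN takes it from the normalized activations.
        residual = x if not self.apply_post_layer_norm else None
        x = self.norm1(x)
        if self.apply_post_layer_norm:
            residual = x
        attn_out, _ = self.attn(x, x, x, need_weights=False)
        x = residual + attn_out

        residual = x if not self.apply_post_layer_norm else None
        x = self.norm2(x)
        if self.apply_post_layer_norm:
            residual = x
        return residual + self.mlp(x)

print(ToyBlock(64)(torch.randn(2, 16, 64)).shape)    # torch.Size([2, 16, 64])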
examples/language/gpt/titans/model/pipeline_gpt1d.py (new file, 322 lines)
@@ -0,0 +1,322 @@
import inspect

# import model_zoo.gpt.gpt as col_gpt
import titans.model.gpt.gpt as col_gpt
import torch
import torch.nn as nn

from colossalai import kernel
from colossalai import nn as col_nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.nn.layer.wrapper import PipelineSharedModuleWrapper
from colossalai.pipeline.utils import partition_uniform

from .embed import HiddenParallelEmbedding, HiddenParallelGPTLMHead1D, VocabParallelEmbedding, VocabParallelGPTLMHead1D
from .gpt1d import FusedGPTTransformerLayer1D, GPTTransformerLayer1D

__all__ = [
    'GPT2_small_pipeline_1D',
    'GPT2_exlarge_pipeline_1D',
    'GPT3_pipeline_1D',
    'GPT2_exlarge_pipeline_hybrid',
    'GPT2_small_pipeline_hybrid',
    'GPT3_pipeline_hybrid',
]


class GenericPipelineGPT(nn.Module):

    def __init__(self, embedding=None, blocks=None, norm=None, head=None) -> None:
        super().__init__()
        self.embedding = embedding
        self.blocks = blocks
        self.norm = norm
        self.head = head
        assert blocks is not None
        if norm is not None or head is not None:
            assert norm is not None and head is not None

    def forward(self, hidden_states=None, input_ids=None, attention_mask=None):
        if self.embedding is not None:
            hidden_states = self.embedding(input_ids=input_ids)
        batch_size = hidden_states.shape[0]
        attention_mask = attention_mask.view(batch_size, -1)
        attention_mask = attention_mask[:, None, None, :]
        attention_mask = attention_mask.to(dtype=hidden_states.dtype)    # fp16 compatibility
        attention_mask = (1.0 - attention_mask) * -10000.0
        for block in self.blocks:
            hidden_states, attention_mask = block(hidden_states, attention_mask)
        if self.norm is not None:
            hidden_states = self.head(self.norm(hidden_states))
        return hidden_states


class PipelineGPT1D(GenericPipelineGPT):

    def __init__(self,
                 num_layers: int = 12,
                 hidden_size: int = 768,
                 num_attention_heads: int = 12,
                 vocab_size: int = 50304,
                 embed_drop_rate: float = 0.,
                 act_func: str = 'gelu',
                 mlp_ratio: int = 4.0,
                 attn_drop_rate: float = 0.,
                 drop_rate: float = 0.,
                 dtype: torch.dtype = torch.float,
                 checkpoint: bool = False,
                 max_position_embeddings: int = 1024,
                 layer_norm_epsilon: float = 1e-5,
                 apply_post_layer_norm: bool = False,
                 first: bool = False,
                 last: bool = False,
                 embed_split_hidden=False):
        embedding = None
        norm = None
        head = None
        embed_cls = VocabParallelEmbedding
        head_cls = VocabParallelGPTLMHead1D
        if embed_split_hidden:
            embed_cls = HiddenParallelEmbedding
            head_cls = HiddenParallelGPTLMHead1D
        if first:
            embedding = embed_cls(hidden_size, vocab_size, max_position_embeddings, embed_drop_rate, dtype=dtype)
        blocks = nn.ModuleList([
            GPTTransformerLayer1D(hidden_size,
                                  num_attention_heads,
                                  act_func=act_func,
                                  mlp_ratio=mlp_ratio,
                                  attention_dropout_prob=attn_drop_rate,
                                  hidden_dropout_prob=drop_rate,
                                  dtype=dtype,
                                  checkpoint=checkpoint,
                                  max_position_embeddings=max_position_embeddings,
                                  layer_norm_epsilon=layer_norm_epsilon,
                                  apply_post_layer_norm=apply_post_layer_norm) for _ in range(num_layers)
        ])
        if last:
            norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
            head = head_cls(vocab_size=vocab_size, embed_dim=hidden_size, dtype=dtype)
        super().__init__(embedding=embedding, blocks=blocks, norm=norm, head=head)


class FusedPipelineGPT1D(GenericPipelineGPT):

    def __init__(self,
                 num_layers: int = 12,
                 hidden_size: int = 768,
                 num_attention_heads: int = 12,
                 vocab_size: int = 50304,
                 embed_drop_rate: float = 0.,
                 act_func: str = 'gelu',
                 mlp_ratio: int = 4.0,
                 attn_drop_rate: float = 0.,
                 drop_rate: float = 0.,
                 dtype: torch.dtype = torch.float,
                 checkpoint: bool = False,
                 max_position_embeddings: int = 1024,
                 layer_norm_epsilon: float = 1e-5,
                 apply_post_layer_norm: bool = False,
                 first: bool = False,
                 last: bool = False,
                 embed_split_hidden=False):
        embedding = None
        norm = None
        head = None
        embed_cls = VocabParallelEmbedding
        head_cls = VocabParallelGPTLMHead1D
        if embed_split_hidden:
            embed_cls = HiddenParallelEmbedding
            head_cls = HiddenParallelGPTLMHead1D
        if first:
            embedding = embed_cls(hidden_size, vocab_size, max_position_embeddings, embed_drop_rate, dtype=dtype)
        blocks = nn.ModuleList([
            FusedGPTTransformerLayer1D(hidden_size,
                                       num_attention_heads,
                                       act_func=act_func,
                                       mlp_ratio=mlp_ratio,
                                       attention_dropout_prob=attn_drop_rate,
                                       hidden_dropout_prob=drop_rate,
                                       dtype=dtype,
                                       checkpoint=checkpoint,
                                       max_position_embeddings=max_position_embeddings,
                                       layer_norm_epsilon=layer_norm_epsilon,
                                       apply_post_layer_norm=apply_post_layer_norm) for _ in range(num_layers)
        ])
        if last:
            norm = kernel.LayerNorm(hidden_size, eps=layer_norm_epsilon)
            head = head_cls(vocab_size=vocab_size, embed_dim=hidden_size, dtype=dtype)
        super().__init__(embedding=embedding, blocks=blocks, norm=norm, head=head)

    def forward(self, hidden_states=None, input_ids=None, attention_mask=None):
        if self.embedding is not None:
            hidden_states = self.embedding(input_ids=input_ids)
        attention_mask = attention_mask.to(dtype=hidden_states.dtype)    # fp16 compatibility
        for block in self.blocks:
            hidden_states, attention_mask = block(hidden_states, attention_mask)
        if self.norm is not None:
            hidden_states = self.head(self.norm(hidden_states))
        return hidden_states


class PipelineGPTHybrid(GenericPipelineGPT):

    def __init__(self,
                 num_layers: int = 12,
                 hidden_size: int = 768,
                 num_attention_heads: int = 12,
                 vocab_size: int = 50304,
                 embed_drop_rate: float = 0.,
                 act_func: str = 'gelu',
                 mlp_ratio: int = 4,
                 attn_drop_rate: float = 0.,
                 drop_rate: float = 0.,
                 dtype: torch.dtype = torch.float,
                 checkpoint: bool = False,
                 max_position_embeddings: int = 1024,
                 layer_norm_epsilon: float = 1e-5,
                 apply_post_layer_norm: bool = False,
                 first: bool = False,
                 last: bool = False,
                 embed_split_hidden=False):
        embedding = None
        norm = None
        head = None
        if first:
            embedding = col_gpt.GPTEmbedding(hidden_size,
                                             vocab_size,
                                             max_position_embeddings,
                                             dropout=embed_drop_rate,
                                             dtype=dtype)
        blocks = nn.ModuleList([
            col_gpt.GPTBlock(hidden_size,
                             num_attention_heads,
                             mlp_ratio=mlp_ratio,
                             attention_dropout=attn_drop_rate,
                             dropout=drop_rate,
                             dtype=dtype,
                             checkpoint=checkpoint,
                             activation=nn.functional.gelu) for _ in range(num_layers)
        ])
        if last:
            norm = col_nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
            # head = col_gpt.GPTLMHead(vocab_size=vocab_size,
            #                          hidden_size=hidden_size,
            #                          dtype=dtype,
            #                          bias=False)
            head = col_nn.Classifier(hidden_size, vocab_size, dtype=dtype, bias=False)
        super().__init__(embedding=embedding, blocks=blocks, norm=norm, head=head)


def _filter_kwargs(func, kwargs):
    sig = inspect.signature(func)
    return {k: v for k, v in kwargs.items() if k in sig.parameters}


def _build_generic_gpt_pipeline_1d(module_cls, num_layers, num_chunks, device=torch.device('cuda'), **kwargs):
    logger = get_dist_logger()

    if gpc.is_initialized(ParallelMode.PIPELINE):
        pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
        pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
    else:
        pipeline_size = 1
        pipeline_rank = 0
    rank = gpc.get_global_rank()

    if pipeline_size > 1:
        wrapper = PipelineSharedModuleWrapper([0, pipeline_size - 1])
    else:
        wrapper = None
    parts = partition_uniform(num_layers, pipeline_size, num_chunks)[pipeline_rank]
    models = []
    for start, end in parts:
        kwargs['num_layers'] = end - start
        kwargs['first'] = start == 0
        kwargs['last'] = end == num_layers
        logger.info(f'Rank{rank} build layer {start}-{end}, {end-start}/{num_layers} layers')
        chunk = module_cls(**_filter_kwargs(module_cls.__init__, kwargs)).to(device)

        if wrapper is not None:
            if start == 0:
                wrapper.register_module(chunk.embedding.word_embeddings)
            elif end == num_layers:
                wrapper.register_module(chunk.head)
        models.append(chunk)
    if len(models) == 1:
        model = models[0]
    else:
        model = nn.ModuleList(models)

    numel = 0
    for _, param in model.named_parameters(recurse=True):
        numel += param.numel()
    logger.info(f'Rank{rank}/{pipeline_rank} model size = {numel * 2 / 1e9} GB')
    return model


def _build_gpt_pipeline_1d(num_layers, num_chunks, device=torch.device('cuda'), fused=False, **kwargs):
    model = FusedPipelineGPT1D if fused else PipelineGPT1D
    return _build_generic_gpt_pipeline_1d(model, num_layers, num_chunks, device, **kwargs)


def _build_gpt_pipeline_hybrid(num_layers, num_chunks, device=torch.device('cuda'), **kwargs):
    return _build_generic_gpt_pipeline_1d(PipelineGPTHybrid, num_layers, num_chunks, device, **kwargs)


def GPT2_small_pipeline_1D(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False, fused=False):
    cfg = dict(hidden_size=768,
               num_attention_heads=12,
               checkpoint=checkpoint,
               dtype=dtype,
               embed_split_hidden=embed_split_hidden)
    return _build_gpt_pipeline_1d(12, num_chunks, fused=fused, **cfg)


def GPT2_exlarge_pipeline_1D(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False, fused=False):
    cfg = dict(hidden_size=1600,
               num_attention_heads=32,
               checkpoint=checkpoint,
               dtype=dtype,
               embed_split_hidden=embed_split_hidden)
    return _build_gpt_pipeline_1d(48, num_chunks, fused=fused, **cfg)


def GPT3_pipeline_1D(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False, fused=False):
    cfg = dict(hidden_size=12288,
               num_attention_heads=96,
               checkpoint=checkpoint,
               max_position_embeddings=2048,
               dtype=dtype,
               embed_split_hidden=embed_split_hidden)
    return _build_gpt_pipeline_1d(96, num_chunks, fused=fused, **cfg)


def GPT2_exlarge_pipeline_hybrid(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False):
    cfg = dict(hidden_size=1600,
               num_attention_heads=32,
               checkpoint=checkpoint,
               dtype=dtype,
               embed_split_hidden=embed_split_hidden)
    return _build_gpt_pipeline_hybrid(48, num_chunks, **cfg)


def GPT2_small_pipeline_hybrid(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False):
    cfg = dict(hidden_size=768,
               num_attention_heads=12,
               checkpoint=checkpoint,
               dtype=dtype,
               embed_split_hidden=embed_split_hidden)
    return _build_gpt_pipeline_hybrid(12, num_chunks, **cfg)


def GPT3_pipeline_hybrid(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False):
    cfg = dict(hidden_size=12288,
               num_attention_heads=96,
               checkpoint=checkpoint,
               max_position_embeddings=2048,
               dtype=dtype,
               embed_split_hidden=embed_split_hidden)
    return _build_gpt_pipeline_hybrid(96, num_chunks, **cfg)
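For clarity, an editor-added sketch of the stage-construction logic in _build_generic_gpt_pipeline_1d: each pipeline rank receives a contiguous layer range, sets first/last for the ranks that hold the embedding and the head, and filters its kwargs against the constructor signature. The small uniform_parts helper is a simplified stand-in for ColossalAI's partition_uniform (with num_chunks=1), not its actual implementation.

import inspect

def uniform_parts(num_layers: int, pipeline_size: int):
    # Simplified stand-in: split num_layers into contiguous, near-equal ranges.
    base, rem = divmod(num_layers, pipeline_size)
    parts, start = [], 0
    for rank in range(pipeline_size):
        end = start + base + (1 if rank < rem else 0)
        parts.append([(start, end)])
        start = end
    return parts

def filter_kwargs(func, kwargs):
    # Same idea as _filter_kwargs above: drop keys the constructor does not accept.
    sig = inspect.signature(func)
    return {k: v for k, v in kwargs.items() if k in sig.parameters}

class DummyStage:
    def __init__(self, num_layers, first=False, last=False):
        self.num_layers, self.first, self.last = num_layers, first, last

num_layers, pipeline_size = 12, 4
for rank, part in enumerate(uniform_parts(num_layers, pipeline_size)):
    for start, end in part:
        kwargs = dict(num_layers=end - start, first=start == 0, last=end == num_layers,
                      unused_option=True)    # silently dropped by filter_kwargs
        stage = DummyStage(**filter_kwargs(DummyStage.__init__, kwargs))
        print(f"rank {rank}: layers {start}-{end}, first={stage.first}, last={stage.last}")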