mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-24 14:33:20 +00:00
[example]add gpt2 benchmark example script. (#5295)
* benchmark gpt2 * fix fix fix fix * [doc] fix typo in Colossal-LLaMA-2/README.md (#5247) * [workflow] fixed build CI (#5240) * [workflow] fixed build CI * polish * polish * polish * polish * polish * [ci] fixed booster test (#5251) * [ci] fixed booster test * [ci] fixed booster test * [ci] fixed booster test * [ci] fixed ddp test (#5254) * [ci] fixed ddp test * polish * fix typo in applications/ColossalEval/README.md (#5250) * [ci] fix shardformer tests. (#5255) * fix ci fix * revert: revert p2p * feat: add enable_metadata_cache option * revert: enable t5 tests --------- Co-authored-by: Wenhao Chen <cwher@outlook.com> * [doc] fix doc typo (#5256) * [doc] fix annotation display * [doc] fix llama2 doc * [hotfix]: add pp sanity check and fix mbs arg (#5268) * fix: fix misleading mbs arg * feat: add pp sanity check * fix: fix 1f1b sanity check * [workflow] fixed incomplete bash command (#5272) * [workflow] fixed oom tests (#5275) * [workflow] fixed oom tests * polish * polish * polish * [ci] fix test_hybrid_parallel_plugin_checkpoint_io.py (#5276) * fix ci fix * fix test * revert: revert p2p * feat: add enable_metadata_cache option * revert: enable t5 tests * fix --------- Co-authored-by: Wenhao Chen <cwher@outlook.com> * [shardformer] hybridparallelplugin support gradients accumulation. (#5246) * support gradients acc fix fix fix fix fix fix fix fix fix fix fix fix fix * fix fix * fix fix fix * [hotfix] Fix ShardFormer test execution path when using sequence parallelism (#5230) * fix auto loading gpt2 tokenizer (#5279) * [doc] add llama2-13B disyplay (#5285) * Update README.md * fix 13b typo --------- Co-authored-by: binmakeswell <binmakeswell@gmail.com> * fix llama pretrain (#5287) * fix * fix * fix fix * fix fix fix * fix fix * benchmark gpt2 * fix fix fix fix * [workflow] fixed build CI (#5240) * [workflow] fixed build CI * polish * polish * polish * polish * polish * [ci] fixed booster test (#5251) * [ci] fixed booster test * [ci] fixed booster test * [ci] fixed booster test * fix fix * fix fix fix * fix * fix fix fix fix fix * fix * Update shardformer.py --------- Co-authored-by: digger yu <digger-yu@outlook.com> Co-authored-by: Frank Lee <somerlee.9@gmail.com> Co-authored-by: Wenhao Chen <cwher@outlook.com> Co-authored-by: binmakeswell <binmakeswell@gmail.com> Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com> Co-authored-by: Michelle <97082656+MichelleMa8@users.noreply.github.com> Co-authored-by: Desperado-Jia <502205863@qq.com>
This commit is contained in:
parent
4b8312c08e
commit
29695cf70c
2
.github/workflows/build_on_pr.yml
vendored
2
.github/workflows/build_on_pr.yml
vendored
@ -201,4 +201,4 @@ jobs:
|
|||||||
uses: actions/upload-artifact@v3
|
uses: actions/upload-artifact@v3
|
||||||
with:
|
with:
|
||||||
name: report
|
name: report
|
||||||
path: report/
|
path: report/
|
2
.github/workflows/build_on_schedule.yml
vendored
2
.github/workflows/build_on_schedule.yml
vendored
@ -83,4 +83,4 @@ jobs:
|
|||||||
SERVER_URL: ${{github.server_url }}
|
SERVER_URL: ${{github.server_url }}
|
||||||
REPO: ${{ github.repository }}
|
REPO: ${{ github.repository }}
|
||||||
RUN_ID: ${{ github.run_id }}
|
RUN_ID: ${{ github.run_id }}
|
||||||
WEBHOOK_URL: ${{ secrets.LARK_NOTIFICATION_WEBHOOK_URL }}
|
WEBHOOK_URL: ${{ secrets.LARK_NOTIFICATION_WEBHOOK_URL }}
|
@ -36,6 +36,8 @@ from .pp_plugin_base import PipelinePluginBase
|
|||||||
|
|
||||||
DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2
|
DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2
|
||||||
|
|
||||||
|
PRECISION_TORCH_TYPE = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16}
|
||||||
|
|
||||||
|
|
||||||
def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
|
def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
|
||||||
if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
|
if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
|
||||||
@ -1059,6 +1061,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||||||
overlap_communication=overlap_communication,
|
overlap_communication=overlap_communication,
|
||||||
cpu_offload=cpu_offload,
|
cpu_offload=cpu_offload,
|
||||||
partition_grad=(self.zero_stage == 2),
|
partition_grad=(self.zero_stage == 2),
|
||||||
|
forced_dtype=PRECISION_TORCH_TYPE[precision],
|
||||||
)
|
)
|
||||||
|
|
||||||
self.max_norm = max_norm
|
self.max_norm = max_norm
|
||||||
|
@ -9,6 +9,7 @@ except:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
import fused_weight_gradient_mlp_cuda
|
import fused_weight_gradient_mlp_cuda
|
||||||
|
|
||||||
_grad_accum_fusion_available = True
|
_grad_accum_fusion_available = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
_grad_accum_fusion_available = False
|
_grad_accum_fusion_available = False
|
||||||
@ -78,7 +79,8 @@ class MatmulWithAsyncCommunication(torch.autograd.Function):
|
|||||||
|
|
||||||
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias.
|
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias.
|
||||||
weight = weight.view(weight.shape)
|
weight = weight.view(weight.shape)
|
||||||
bias = bias.view(bias.shape)
|
if bias is not None:
|
||||||
|
bias = bias.view(bias.shape)
|
||||||
|
|
||||||
total_input = input
|
total_input = input
|
||||||
grad_input = grad_output.matmul(weight.T)
|
grad_input = grad_output.matmul(weight.T)
|
||||||
@ -91,9 +93,8 @@ class MatmulWithAsyncCommunication(torch.autograd.Function):
|
|||||||
if ctx.async_grad_allreduce:
|
if ctx.async_grad_allreduce:
|
||||||
# Asynchronous all-reduce
|
# Asynchronous all-reduce
|
||||||
handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
|
handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
|
||||||
# Delay the start of weight gradient computation shortly (3us) to have
|
# Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
|
||||||
# all-reduce scheduled first and have GPU resources allocated
|
# all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
|
||||||
_ = torch.empty(1, device=grad_output.device) + 1
|
|
||||||
|
|
||||||
grad_weight = total_input.t().matmul(grad_output)
|
grad_weight = total_input.t().matmul(grad_output)
|
||||||
grad_bias = grad_output.sum(dim=0) if use_bias else None
|
grad_bias = grad_output.sum(dim=0) if use_bias else None
|
||||||
@ -115,7 +116,6 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
|
|||||||
ctx.use_bias = bias is not None
|
ctx.use_bias = bias is not None
|
||||||
ctx.process_group = process_group
|
ctx.process_group = process_group
|
||||||
ctx.async_grad_allreduce = async_grad_allreduce
|
ctx.async_grad_allreduce = async_grad_allreduce
|
||||||
|
|
||||||
if bias is not None:
|
if bias is not None:
|
||||||
output = F.linear(input_, weight, bias)
|
output = F.linear(input_, weight, bias)
|
||||||
else:
|
else:
|
||||||
@ -143,9 +143,8 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
|
|||||||
if ctx.async_grad_allreduce:
|
if ctx.async_grad_allreduce:
|
||||||
# Asynchronous all-reduce
|
# Asynchronous all-reduce
|
||||||
handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
|
handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
|
||||||
# Delay the start of weight gradient computation shortly (3us) to have
|
# Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
|
||||||
# all-reduce scheduled first and have GPU resources allocated
|
# all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
|
||||||
_ = torch.empty(1, device=grad_output.device) + 1
|
|
||||||
|
|
||||||
if _grad_accum_fusion_available and weight.grad is not None:
|
if _grad_accum_fusion_available and weight.grad is not None:
|
||||||
grad = weight.grad
|
grad = weight.grad
|
||||||
@ -228,9 +227,8 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
|
|||||||
input_.shape, dtype=input_parallel.dtype, device=input_parallel.device
|
input_.shape, dtype=input_parallel.dtype, device=input_parallel.device
|
||||||
).contiguous()
|
).contiguous()
|
||||||
handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
|
handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
|
||||||
# Delay the start of weight gradient computation shortly (3us) to have
|
# Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
|
||||||
# reduce-scatter scheduled first and have GPU resources allocated
|
# all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
|
||||||
_ = torch.empty(1, device=grad_output.device) + 1
|
|
||||||
|
|
||||||
if _grad_accum_fusion_available and weight.grad is not None:
|
if _grad_accum_fusion_available and weight.grad is not None:
|
||||||
grad = weight.grad
|
grad = weight.grad
|
||||||
@ -394,9 +392,8 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
|
|||||||
input_.shape, dtype=input_parallel.dtype, device=input_parallel.device
|
input_.shape, dtype=input_parallel.dtype, device=input_parallel.device
|
||||||
).contiguous()
|
).contiguous()
|
||||||
handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
|
handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
|
||||||
# Delay the start of weight gradient computation shortly (3us) to have
|
# Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
|
||||||
# reduce-scatter scheduled first and have GPU resources allocated
|
# all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
|
||||||
_ = torch.empty(1, device=grad_output.device) + 1
|
|
||||||
|
|
||||||
grad_weight = total_input.t().matmul(grad_output)
|
grad_weight = total_input.t().matmul(grad_output)
|
||||||
grad_bias = grad_output.sum(dim=0) if use_bias else None
|
grad_bias = grad_output.sum(dim=0) if use_bias else None
|
||||||
@ -431,7 +428,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
|
|||||||
input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
|
input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
|
||||||
# calculate gradient
|
# calculate gradient
|
||||||
if len(input_parallel.shape) > 2:
|
if len(input_parallel.shape) > 2:
|
||||||
input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
|
input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
|
||||||
grad_weight = input_parallel.t().matmul(grad_output)
|
grad_weight = input_parallel.t().matmul(grad_output)
|
||||||
# wait until reduce-scatter finished
|
# wait until reduce-scatter finished
|
||||||
reducescatter_handle.wait()
|
reducescatter_handle.wait()
|
||||||
|
@ -24,6 +24,8 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
|
|||||||
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
|
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
|
||||||
from colossalai.shardformer.shard import ShardConfig
|
from colossalai.shardformer.shard import ShardConfig
|
||||||
|
|
||||||
|
from ..layer import cross_entropy_1d
|
||||||
|
|
||||||
|
|
||||||
class GPT2PipelineForwards:
|
class GPT2PipelineForwards:
|
||||||
"""
|
"""
|
||||||
@ -326,7 +328,15 @@ class GPT2PipelineForwards:
|
|||||||
shift_labels = labels[..., 1:].contiguous()
|
shift_labels = labels[..., 1:].contiguous()
|
||||||
# Flatten the tokens
|
# Flatten the tokens
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
shift_logits = shift_logits.view(-1, shift_logits.size(-1))
|
||||||
|
shift_labels = shift_labels.view(-1)
|
||||||
|
if shard_config.enable_tensor_parallelism:
|
||||||
|
loss = cross_entropy_1d(
|
||||||
|
shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
loss = loss_fct(shift_logits, shift_labels)
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (lm_logits,) + outputs[1:]
|
output = (lm_logits,) + outputs[1:]
|
||||||
return ((loss,) + output) if loss is not None else output
|
return ((loss,) + output) if loss is not None else output
|
||||||
@ -1006,3 +1016,84 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return forward
|
return forward
|
||||||
|
|
||||||
|
|
||||||
|
def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
|
||||||
|
from transformers import GPT2LMHeadModel
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self: GPT2LMHeadModel,
|
||||||
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
||||||
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
token_type_ids: Optional[torch.LongTensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||||
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
|
||||||
|
r"""
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
|
||||||
|
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
|
||||||
|
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
|
||||||
|
"""
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
transformer_outputs = self.transformer(
|
||||||
|
input_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
|
position_ids=position_ids,
|
||||||
|
head_mask=head_mask,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
hidden_states = transformer_outputs[0]
|
||||||
|
|
||||||
|
lm_logits = self.lm_head(hidden_states)
|
||||||
|
|
||||||
|
loss = None
|
||||||
|
if labels is not None:
|
||||||
|
# move labels to correct device to enable model parallelism
|
||||||
|
labels = labels.to(lm_logits.device)
|
||||||
|
# Shift so that tokens < n predict n
|
||||||
|
shift_logits = lm_logits[..., :-1, :].contiguous()
|
||||||
|
shift_labels = labels[..., 1:].contiguous()
|
||||||
|
# Flatten the tokens
|
||||||
|
loss_fct = CrossEntropyLoss()
|
||||||
|
shift_logits = shift_logits.view(-1, shift_logits.size(-1))
|
||||||
|
shift_labels = shift_labels.view(-1)
|
||||||
|
if shard_config.enable_tensor_parallelism:
|
||||||
|
loss = cross_entropy_1d(
|
||||||
|
shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
loss = loss_fct(shift_logits, shift_labels)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (lm_logits,) + transformer_outputs[1:]
|
||||||
|
return ((loss,) + output) if loss is not None else output
|
||||||
|
|
||||||
|
return CausalLMOutputWithCrossAttentions(
|
||||||
|
loss=loss,
|
||||||
|
logits=lm_logits,
|
||||||
|
past_key_values=transformer_outputs.past_key_values,
|
||||||
|
hidden_states=transformer_outputs.hidden_states,
|
||||||
|
attentions=transformer_outputs.attentions,
|
||||||
|
cross_attentions=transformer_outputs.cross_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
return forward
|
||||||
|
@ -5,7 +5,12 @@ from torch import Tensor, nn
|
|||||||
|
|
||||||
import colossalai.shardformer.layer as col_nn
|
import colossalai.shardformer.layer as col_nn
|
||||||
|
|
||||||
from ..modeling.gpt2 import GPT2PipelineForwards, get_gpt2_flash_attention_forward, gpt2_sequence_parallel_forward_fn
|
from ..modeling.gpt2 import (
|
||||||
|
GPT2PipelineForwards,
|
||||||
|
get_gpt2_flash_attention_forward,
|
||||||
|
get_lm_forward_with_dist_cross_entropy,
|
||||||
|
gpt2_sequence_parallel_forward_fn,
|
||||||
|
)
|
||||||
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -87,9 +92,7 @@ class GPT2Policy(Policy):
|
|||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
suffix="mlp.c_proj",
|
suffix="mlp.c_proj",
|
||||||
target_module=col_nn.GPT2FusedLinearConv1D_Row,
|
target_module=col_nn.GPT2FusedLinearConv1D_Row,
|
||||||
kwargs={
|
kwargs={"seq_parallel": use_sequence_parallel},
|
||||||
"seq_parallel": use_sequence_parallel,
|
|
||||||
},
|
|
||||||
),
|
),
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
suffix="attn.attn_dropout",
|
suffix="attn.attn_dropout",
|
||||||
@ -167,15 +170,35 @@ class GPT2Policy(Policy):
|
|||||||
stage_manager = self.pipeline_stage_manager
|
stage_manager = self.pipeline_stage_manager
|
||||||
|
|
||||||
held_layers = []
|
held_layers = []
|
||||||
layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
|
if stage_manager.is_interleave:
|
||||||
if stage_manager.is_first_stage():
|
assert stage_manager.num_model_chunks is not None
|
||||||
held_layers.append(module.wte)
|
layers_per_stage = self.distribute_layers(
|
||||||
held_layers.append(module.wpe)
|
len(module.h), stage_manager.num_stages * stage_manager.num_model_chunks
|
||||||
held_layers.append(module.drop)
|
)
|
||||||
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
|
stage_indices = Policy.get_stage_index(
|
||||||
held_layers.extend(module.h[start_idx:end_idx])
|
layers_per_stage,
|
||||||
if stage_manager.is_last_stage():
|
stage_manager.stage,
|
||||||
held_layers.append(module.ln_f)
|
num_model_chunks=stage_manager.num_model_chunks,
|
||||||
|
num_stages=stage_manager.num_stages,
|
||||||
|
)
|
||||||
|
if stage_manager.is_first_stage(ignore_chunk=True):
|
||||||
|
held_layers.append(module.wte)
|
||||||
|
held_layers.append(module.wpe)
|
||||||
|
held_layers.append(module.drop)
|
||||||
|
for start_idx, end_idx in stage_indices:
|
||||||
|
held_layers.extend(module.h[start_idx:end_idx])
|
||||||
|
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||||
|
held_layers.append(module.ln_f)
|
||||||
|
else:
|
||||||
|
layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
|
||||||
|
if stage_manager.is_first_stage():
|
||||||
|
held_layers.append(module.wte)
|
||||||
|
held_layers.append(module.wpe)
|
||||||
|
held_layers.append(module.drop)
|
||||||
|
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||||
|
held_layers.extend(module.h[start_idx:end_idx])
|
||||||
|
if stage_manager.is_last_stage():
|
||||||
|
held_layers.append(module.ln_f)
|
||||||
return held_layers
|
return held_layers
|
||||||
|
|
||||||
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
|
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
|
||||||
@ -189,13 +212,27 @@ class GPT2Policy(Policy):
|
|||||||
else:
|
else:
|
||||||
module = self.model.transformer
|
module = self.model.transformer
|
||||||
|
|
||||||
layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages)
|
if stage_manager.is_interleave:
|
||||||
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
|
layers_per_stage = self.distribute_layers(
|
||||||
method_replacement = {
|
len(module.h), stage_manager.num_stages * stage_manager.num_model_chunks
|
||||||
"forward": partial(
|
|
||||||
new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
|
|
||||||
)
|
)
|
||||||
}
|
stage_manager.stage_indices = Policy.get_stage_index(
|
||||||
|
layers_per_stage,
|
||||||
|
stage_manager.stage,
|
||||||
|
num_model_chunks=stage_manager.num_model_chunks,
|
||||||
|
num_stages=stage_manager.num_stages,
|
||||||
|
)
|
||||||
|
method_replacement = {
|
||||||
|
"forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config)
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages)
|
||||||
|
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||||
|
method_replacement = {
|
||||||
|
"forward": partial(
|
||||||
|
new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
|
||||||
|
)
|
||||||
|
}
|
||||||
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
|
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
|
||||||
|
|
||||||
|
|
||||||
@ -232,9 +269,10 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
|
|||||||
GPT2LMHeadModel: ModulePolicyDescription(
|
GPT2LMHeadModel: ModulePolicyDescription(
|
||||||
sub_module_replacement=[
|
sub_module_replacement=[
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}
|
suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": False}
|
||||||
)
|
)
|
||||||
]
|
],
|
||||||
|
method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)},
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
module_policy.update(addon_module)
|
module_policy.update(addon_module)
|
||||||
@ -249,7 +287,7 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
|
|||||||
|
|
||||||
def get_held_layers(self) -> List[nn.Module]:
|
def get_held_layers(self) -> List[nn.Module]:
|
||||||
held_layers = super().get_held_layers()
|
held_layers = super().get_held_layers()
|
||||||
if self.pipeline_stage_manager.is_last_stage():
|
if self.pipeline_stage_manager.is_last_stage(ignore_chunk=True):
|
||||||
held_layers.append(self.model.lm_head)
|
held_layers.append(self.model.lm_head)
|
||||||
return held_layers
|
return held_layers
|
||||||
|
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import os
|
||||||
from typing import Dict, List, Tuple
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -9,6 +10,9 @@ from ..policies.base_policy import Policy
|
|||||||
from .shard_config import ShardConfig
|
from .shard_config import ShardConfig
|
||||||
from .sharder import ModelSharder
|
from .sharder import ModelSharder
|
||||||
|
|
||||||
|
# set CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that when communication and computation overlap, the order of core scheduling is correct
|
||||||
|
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
|
||||||
|
|
||||||
|
|
||||||
class ShardFormer:
|
class ShardFormer:
|
||||||
"""
|
"""
|
||||||
|
0
examples/__init__.py
Normal file
0
examples/__init__.py
Normal file
0
examples/language/__init__.py
Normal file
0
examples/language/__init__.py
Normal file
@ -121,4 +121,4 @@ class RandomDataset(Dataset):
|
|||||||
"input_ids": self.input_ids[idx],
|
"input_ids": self.input_ids[idx],
|
||||||
"attention_mask": self.attention_mask[idx],
|
"attention_mask": self.attention_mask[idx],
|
||||||
"labels": self.input_ids[idx],
|
"labels": self.input_ids[idx],
|
||||||
}
|
}
|
228
examples/language/gpt/hybridparallelism/benchmark.py
Normal file
228
examples/language/gpt/hybridparallelism/benchmark.py
Normal file
@ -0,0 +1,228 @@
|
|||||||
|
import argparse
|
||||||
|
import resource
|
||||||
|
from contextlib import nullcontext
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision
|
||||||
|
from torch.optim import Adam
|
||||||
|
from tqdm import tqdm
|
||||||
|
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
|
||||||
|
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
|
||||||
|
|
||||||
|
import colossalai
|
||||||
|
|
||||||
|
# import colossalai.utils.device as device_utils
|
||||||
|
from colossalai.booster import Booster
|
||||||
|
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchFSDPPlugin
|
||||||
|
from colossalai.cluster import DistCoordinator
|
||||||
|
from colossalai.lazy import LazyInitContext
|
||||||
|
from colossalai.utils import get_current_device
|
||||||
|
from examples.language.data_utils import RandomDataset
|
||||||
|
from examples.language.model_utils import format_numel_str, get_model_numel
|
||||||
|
from examples.language.performance_evaluator import PerformanceEvaluator
|
||||||
|
|
||||||
|
# ==============================
|
||||||
|
# Constants
|
||||||
|
# ==============================
|
||||||
|
MODEL_CONFIGS = {
|
||||||
|
"118M": GPT2Config(activation_function="gelu"),
|
||||||
|
"338M": GPT2Config(n_embd=1024, n_head=16, n_layer=24, activation_function="gelu"),
|
||||||
|
"738M": GPT2Config(n_embd=1280, n_head=20, n_layer=36, activation_function="gelu"),
|
||||||
|
"6.21B": GPT2Config(n_embd=4096, n_head=32, n_layer=32, n_positions=4096, activation_function="gelu"),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# ==============================
|
||||||
|
# Parse Arguments
|
||||||
|
# ==============================
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("-c", "--config", type=str, default="6.21B", help="Model configuration")
|
||||||
|
parser.add_argument(
|
||||||
|
"-p",
|
||||||
|
"--plugin",
|
||||||
|
choices=["gemini", "gemini_auto", "fsdp", "fsdp_cpu", "3d", "3d_cpu"],
|
||||||
|
default="gemini",
|
||||||
|
help="Choose which plugin to use",
|
||||||
|
)
|
||||||
|
parser.add_argument("-b", "--batch_size", type=int, default=2, help="Batch size")
|
||||||
|
parser.add_argument("-s", "--num_steps", type=int, default=200, help="Number of steps to run")
|
||||||
|
parser.add_argument("-i", "--ignore_steps", type=int, default=3, help="Number of steps to ignore")
|
||||||
|
parser.add_argument("-g", "--grad_checkpoint", action="store_true", help="Use gradient checkpointing")
|
||||||
|
parser.add_argument("-l", "--max_length", type=int, default=4096, help="Max sequence length")
|
||||||
|
parser.add_argument(
|
||||||
|
"-w", "--warmup_ratio", type=float, default=0.8, help="warm up ratio of non-model data. Only for gemini-auto"
|
||||||
|
)
|
||||||
|
parser.add_argument("-m", "--memory_limit", type=int, help="Gemini memory limit in mb")
|
||||||
|
parser.add_argument("--shard_param_frac", type=float, default=1.0, help="Shard param fraction. Only for gemini")
|
||||||
|
parser.add_argument("--offload_optim_frac", type=float, default=0.0, help="Offload optim fraction. Only for gemini")
|
||||||
|
parser.add_argument("--offload_param_frac", type=float, default=0.0, help="Offload param fraction. Only for gemini")
|
||||||
|
parser.add_argument("--tp", type=int, default=1, help="Tensor parallel size")
|
||||||
|
parser.add_argument("--extra_dp", type=int, default=1, help="Extra data parallel size, used for Gemini")
|
||||||
|
parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel size")
|
||||||
|
parser.add_argument("--mbs", type=int, default=1)
|
||||||
|
parser.add_argument("--zero", type=int, default=0)
|
||||||
|
parser.add_argument("--pp_style", type=str, default="1f1b")
|
||||||
|
parser.add_argument("--num_model_chunks", type=int, default=2)
|
||||||
|
parser.add_argument("--cpu_offload", action="store_true", help="Use gradient checkpointing")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
colossalai.launch_from_torch({})
|
||||||
|
coordinator = DistCoordinator()
|
||||||
|
|
||||||
|
def empty_init():
|
||||||
|
pass
|
||||||
|
|
||||||
|
# ==============================
|
||||||
|
# Initialize Booster
|
||||||
|
# ==============================
|
||||||
|
use_empty_init = True
|
||||||
|
if args.plugin == "gemini":
|
||||||
|
plugin = GeminiPlugin(
|
||||||
|
precision="bf16",
|
||||||
|
shard_param_frac=args.shard_param_frac,
|
||||||
|
offload_optim_frac=args.offload_optim_frac,
|
||||||
|
offload_param_frac=args.offload_param_frac,
|
||||||
|
tp_size=args.tp,
|
||||||
|
extra_dp_size=args.extra_dp,
|
||||||
|
)
|
||||||
|
elif args.plugin == "gemini_auto":
|
||||||
|
plugin = GeminiPlugin(
|
||||||
|
placement_policy="auto",
|
||||||
|
precision="bf16",
|
||||||
|
warmup_non_model_data_ratio=args.warmup_ratio,
|
||||||
|
tp_size=args.tp,
|
||||||
|
extra_dp_size=args.extra_dp,
|
||||||
|
)
|
||||||
|
elif args.plugin == "fsdp":
|
||||||
|
if use_empty_init:
|
||||||
|
plugin = TorchFSDPPlugin(
|
||||||
|
mixed_precision=MixedPrecision(
|
||||||
|
param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16
|
||||||
|
),
|
||||||
|
param_init_fn=empty_init(),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
plugin = TorchFSDPPlugin(
|
||||||
|
mixed_precision=MixedPrecision(
|
||||||
|
param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif args.plugin == "fsdp_cpu":
|
||||||
|
if use_empty_init:
|
||||||
|
plugin = TorchFSDPPlugin(
|
||||||
|
mixed_precision=MixedPrecision(
|
||||||
|
param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16
|
||||||
|
),
|
||||||
|
cpu_offload=CPUOffload(offload_params=True),
|
||||||
|
param_init_fn=empty_init(),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
plugin = TorchFSDPPlugin(
|
||||||
|
mixed_precision=MixedPrecision(
|
||||||
|
param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16
|
||||||
|
),
|
||||||
|
cpu_offload=CPUOffload(offload_params=True),
|
||||||
|
)
|
||||||
|
elif args.plugin == "3d":
|
||||||
|
plugin = HybridParallelPlugin(
|
||||||
|
tp_size=args.tp,
|
||||||
|
pp_size=args.pp,
|
||||||
|
pp_style=args.pp_style,
|
||||||
|
zero_stage=args.zero,
|
||||||
|
num_model_chunks=args.num_model_chunks,
|
||||||
|
enable_all_optimization=True,
|
||||||
|
num_microbatches=args.mbs,
|
||||||
|
cpu_offload=args.cpu_offload,
|
||||||
|
precision="bf16",
|
||||||
|
)
|
||||||
|
elif args.plugin == "3d_cpu":
|
||||||
|
plugin = HybridParallelPlugin(
|
||||||
|
tp_size=args.tp,
|
||||||
|
pp_size=args.pp,
|
||||||
|
zero_stage=args.zero,
|
||||||
|
cpu_offload=True,
|
||||||
|
enable_fused_normalization=torch.cuda.is_available(),
|
||||||
|
num_microbatches=args.mbs,
|
||||||
|
initial_scale=2**8,
|
||||||
|
precision="bf16",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown plugin {args.plugin}")
|
||||||
|
|
||||||
|
booster = Booster(plugin=plugin)
|
||||||
|
|
||||||
|
# ==============================
|
||||||
|
# Initialize Dataset and Dataloader
|
||||||
|
# ==============================
|
||||||
|
dp_size = plugin.dp_size if isinstance(plugin, HybridParallelPlugin) else coordinator.world_size
|
||||||
|
|
||||||
|
config = MODEL_CONFIGS[args.config]
|
||||||
|
dataset = RandomDataset(
|
||||||
|
num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size
|
||||||
|
)
|
||||||
|
dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True)
|
||||||
|
|
||||||
|
# ==============================
|
||||||
|
# Initialize Model and Optimizer
|
||||||
|
# ==============================
|
||||||
|
init_ctx = (
|
||||||
|
LazyInitContext(default_device=get_current_device())
|
||||||
|
if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin))
|
||||||
|
else nullcontext()
|
||||||
|
)
|
||||||
|
|
||||||
|
with init_ctx:
|
||||||
|
model = GPT2LMHeadModel(config)
|
||||||
|
|
||||||
|
if args.grad_checkpoint:
|
||||||
|
model.gradient_checkpointing_enable()
|
||||||
|
|
||||||
|
model_numel = get_model_numel(model)
|
||||||
|
coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
|
||||||
|
performance_evaluator = PerformanceEvaluator(
|
||||||
|
model_numel,
|
||||||
|
model.config.n_layer,
|
||||||
|
model.config.n_embd,
|
||||||
|
model.config.vocab_size,
|
||||||
|
args.grad_checkpoint,
|
||||||
|
args.ignore_steps,
|
||||||
|
dp_world_size=dp_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
optimizer = Adam(model.parameters())
|
||||||
|
torch.set_default_dtype(torch.bfloat16)
|
||||||
|
model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)
|
||||||
|
torch.set_default_dtype(torch.float)
|
||||||
|
coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
|
||||||
|
coordinator.print_on_master(
|
||||||
|
f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB"
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(plugin, HybridParallelPlugin) and args.pp > 1:
|
||||||
|
data_iter = iter(dataloader)
|
||||||
|
for step in tqdm(range(len(dataloader)), desc="Step", disable=not coordinator.is_master()):
|
||||||
|
performance_evaluator.on_step_start(step)
|
||||||
|
booster.execute_pipeline(
|
||||||
|
data_iter, model, criterion=lambda outputs, inputs: outputs[0], optimizer=optimizer, return_loss=False
|
||||||
|
)
|
||||||
|
optimizer.step()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
performance_evaluator.on_step_end(input_ids=torch.empty(args.batch_size, args.max_length))
|
||||||
|
else:
|
||||||
|
for step, batch in enumerate(tqdm(dataloader, desc="Step", disable=not coordinator.is_master())):
|
||||||
|
performance_evaluator.on_step_start(step)
|
||||||
|
outputs = model(**batch)
|
||||||
|
loss = outputs[0]
|
||||||
|
booster.backward(loss, optimizer)
|
||||||
|
optimizer.step()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
performance_evaluator.on_step_end(**batch)
|
||||||
|
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
|
||||||
|
|
||||||
|
performance_evaluator.on_fit_end()
|
||||||
|
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
@ -19,6 +19,9 @@ from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchF
|
|||||||
from colossalai.cluster import DistCoordinator
|
from colossalai.cluster import DistCoordinator
|
||||||
from colossalai.lazy import LazyInitContext
|
from colossalai.lazy import LazyInitContext
|
||||||
from colossalai.nn.optimizer import HybridAdam
|
from colossalai.nn.optimizer import HybridAdam
|
||||||
|
from examples.language.data_utils import RandomDataset
|
||||||
|
from examples.language.model_utils import format_numel_str, get_model_numel
|
||||||
|
from examples.language.performance_evaluator import PerformanceEvaluator
|
||||||
|
|
||||||
# ==============================
|
# ==============================
|
||||||
# Constants
|
# Constants
|
||||||
|
@ -102,4 +102,4 @@ class ModelZooRegistry(dict):
|
|||||||
return new_dict
|
return new_dict
|
||||||
|
|
||||||
|
|
||||||
model_zoo = ModelZooRegistry()
|
model_zoo = ModelZooRegistry()
|
@ -276,4 +276,4 @@ def test_gemini_plugin(early_stop: bool = True):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_gemini_plugin(early_stop=False)
|
test_gemini_plugin(early_stop=False)
|
@ -185,4 +185,4 @@ def test_gemini_plugin_3d(early_stop: bool = True):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_gemini_plugin(early_stop=False)
|
test_gemini_plugin(early_stop=False)
|
@ -186,4 +186,4 @@ def test_gemini_ckpIO_3d():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_gemini_ckpIO()
|
test_gemini_ckpIO()
|
@ -24,4 +24,4 @@ def test_torchvision_models_lazy_init(subset, default_device):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_torchvision_models_lazy_init("transformers", "cpu")
|
test_torchvision_models_lazy_init("transformers", "cpu")
|
@ -1,3 +1,4 @@
|
|||||||
|
import os
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -11,8 +12,10 @@ from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLin
|
|||||||
from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style
|
from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style
|
||||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||||
|
|
||||||
|
|
||||||
# This code is copied from https://github.com/huggingface/transformers
|
# This code is copied from https://github.com/huggingface/transformers
|
||||||
|
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
|
||||||
|
|
||||||
|
|
||||||
class Conv1D(nn.Module):
|
class Conv1D(nn.Module):
|
||||||
"""
|
"""
|
||||||
1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).
|
1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import os
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -11,6 +12,8 @@ from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row
|
|||||||
from colossalai.tensor.d_tensor import is_distributed_tensor
|
from colossalai.tensor.d_tensor import is_distributed_tensor
|
||||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||||
|
|
||||||
|
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
|
||||||
|
|
||||||
|
|
||||||
def check_linear_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool):
|
def check_linear_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool):
|
||||||
ctx = LazyInitContext() if lazy_init else nullcontext()
|
ctx = LazyInitContext() if lazy_init else nullcontext()
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import os
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -11,8 +12,10 @@ from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLin
|
|||||||
from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style
|
from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style
|
||||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||||
|
|
||||||
|
|
||||||
# This code is copied from https://github.com/huggingface/transformers
|
# This code is copied from https://github.com/huggingface/transformers
|
||||||
|
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
|
||||||
|
|
||||||
|
|
||||||
class Conv1D(nn.Module):
|
class Conv1D(nn.Module):
|
||||||
"""
|
"""
|
||||||
1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).
|
1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).
|
||||||
|
Loading…
Reference in New Issue
Block a user