Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-27 15:57:16 +00:00)
fix

parent 840b9f3266
commit 519c2d0eab

The diff below drops leftover debug print statements from the GPT-2 shardformer policies, removes an extra flash_attention_2 mask branch from _get_attention_mask, re-enables the shard_config argument for the GPT2LMHeadModel pipeline forward, and adjusts one sequence-parallel test configuration (num_microbatches 2 -> 1, use_lazy_init False -> True). One hunk also shows git merge-conflict markers committed around a remaining print.
@@ -85,8 +85,6 @@ def _get_attention_mask(
             attention_mask,
             is_causal=True,
         )
-    elif self._attn_implementation == "flash_attention_2":
-        attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
     elif attention_mask is not None:
         if batch_size <= 0:
             raise ValueError("batch_size has to be defined and > 0")
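
The removed branch follows the flash-attention-2 convention used in transformers: a 2D padding mask is only forwarded when it actually masks something (contains a 0), and an all-ones mask is replaced by None so the kernel can take the unpadded fast path. A minimal standalone sketch of that check (prune_padding_mask is a made-up name, not a ColossalAI or transformers API):

from typing import Optional

import torch

def prune_padding_mask(attention_mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    # Keep the 2D mask only when at least one position is padded (0); otherwise
    # drop it so a flash-attention style kernel can skip the masking path.
    if attention_mask is not None and (attention_mask == 0).any():
        return attention_mask
    return None

no_padding = torch.ones(2, 8, dtype=torch.long)
with_padding = torch.tensor([[1, 1, 1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1]])
assert prune_padding_mask(no_padding) is None
assert prune_padding_mask(with_padding) is with_padding
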
@@ -47,7 +47,10 @@ class GPT2Policy(Policy):
         if self.tie_weight:
             embedding_cls = col_nn.PaddingEmbedding
 
+<<<<<<< Updated upstream
         print("embedding_cls", embedding_cls)
+=======
+>>>>>>> Stashed changes
 
         if self.shard_config.enable_fused_normalization:
             norm_cls = col_nn.FusedLayerNorm
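
The three added lines are git merge-conflict markers ("Updated upstream" / "Stashed changes" come from a git stash pop), which leave the file syntactically invalid and look committed by accident. A standalone sketch of the kind of pre-commit scan that catches such markers (not part of the ColossalAI repo; find_conflict_markers is a made-up helper):

import pathlib
import sys

MARKERS = ("<<<<<<<", "=======", ">>>>>>>")

def find_conflict_markers(root: str = ".") -> list:
    # Walk the tree and report any Python line that still starts with a
    # merge-conflict marker, formatted as "path:lineno: line".
    hits = []
    for path in pathlib.Path(root).rglob("*.py"):
        for lineno, line in enumerate(path.read_text(errors="ignore").splitlines(), start=1):
            if line.startswith(MARKERS):
                hits.append(f"{path}:{lineno}: {line.strip()}")
    return hits

if __name__ == "__main__":
    found = find_conflict_markers()
    print("\n".join(found))
    sys.exit(1 if found else 0)
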
@@ -220,7 +223,6 @@ class GPT2Policy(Policy):
 
         if embedding_cls is not None:
             # padding vocabulary size when using pp to make it divisible by shard_config.make_vocab_size_divisible_by
-            print("embedding_cls", embedding_cls)
             self.append_or_create_submodule_replacement(
                 description=SubModuleReplacementDescription(
                     suffix="wte",
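
The comment kept in this hunk refers to the vocabulary-padding step: before swapping wte for the padded embedding class, the vocabulary size is rounded up to a multiple of shard_config.make_vocab_size_divisible_by so it splits evenly across tensor-parallel ranks. A small sketch of that arithmetic (pad_vocab_size and the divisor 64 are illustrative, not the repo's exact code or default):

def pad_vocab_size(vocab_size: int, make_vocab_size_divisible_by: int = 64) -> int:
    # Round the vocabulary size up to the nearest multiple of the divisor.
    remainder = vocab_size % make_vocab_size_divisible_by
    if remainder == 0:
        return vocab_size
    return vocab_size + make_vocab_size_divisible_by - remainder

# GPT-2's 50257-token vocabulary padded to a multiple of 64 becomes 50304.
assert pad_vocab_size(50257, 64) == 50304
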
@@ -391,7 +393,6 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
         module_policy = super().module_policy()
         module_policy[GPT2LMHeadModel] = ModulePolicyDescription()
         if self.shard_config.enable_tensor_parallelism:
-            print("self.shard_config.enable_tensor_parallelism", self.shard_config.enable_tensor_parallelism)
             self.append_or_create_submodule_replacement(
                 description=SubModuleReplacementDescription(
                     suffix="lm_head",
@@ -428,7 +429,7 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
             self.set_pipeline_forward(
                 model_cls=GPT2LMHeadModel,
                 new_forward=GPT2PipelineForwards.gpt2_lmhead_model_forward,
-                # shard_config=self.shard_config,
+                shard_config=self.shard_config,
                 policy=module_policy,
             )
         return module_policy
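
The only functional change in this hunk is re-enabling the shard_config=self.shard_config argument, so the pipeline replacement forward for GPT2LMHeadModel is built with the actual sharding configuration instead of a default. As a rough, standalone illustration of binding such a keyword onto a replacement forward with functools.partial (all names below are stand-ins; this is not the set_pipeline_forward implementation):

from functools import partial

def stage_forward(input_ids, *, stage_index=0, shard_config=None):
    # Stand-in forward that just reports which sharding options it was bound with.
    mode = None if shard_config is None else shard_config.get("sequence_parallelism_mode")
    return f"stage {stage_index}, sequence_parallelism_mode={mode}"

shard_config = {"sequence_parallelism_mode": "split_gather"}
bound_forward = partial(stage_forward, stage_index=0, shard_config=shard_config)
print(bound_forward([1, 2, 3]))  # stage 0, sequence_parallelism_mode=split_gather
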
@@ -165,11 +165,11 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         {
             "tp_size": 4,
             "pp_size": 1,
-            "num_microbatches": 2,
+            "num_microbatches": 1,
             "enable_sequence_parallelism": True,
             "sequence_parallelism_mode": "split_gather",
             "enable_flash_attention": True,
-            "use_lazy_init": False,
+            "use_lazy_init": True,
             "precision": "fp16",
             "initial_scale": 1,
         },
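
This last hunk tweaks one entry of the parameterized test configurations: num_microbatches drops from 2 to 1 and use_lazy_init flips to True for the tp=4, split_gather sequence-parallel case. A sketch of how such a dict is typically consumed, under the assumption (not verified against the test harness) that use_lazy_init is popped for lazy initialization and the remaining keys are forwarded as HybridParallelPlugin keyword arguments inside a launched worker:

test_config = {
    "tp_size": 4,
    "pp_size": 1,
    "num_microbatches": 1,
    "enable_sequence_parallelism": True,
    "sequence_parallelism_mode": "split_gather",
    "enable_flash_attention": True,
    "use_lazy_init": True,
    "precision": "fp16",
    "initial_scale": 1,
}

# Assumption: the harness consumes use_lazy_init itself (e.g. for a lazy-init
# context) and passes the rest through unchanged, roughly as
#     plugin = HybridParallelPlugin(**test_config)
#     booster = Booster(plugin=plugin)
# inside a colossalai-launched worker process.
use_lazy_init = test_config.pop("use_lazy_init")
assert use_lazy_init is True and test_config["num_microbatches"] == 1
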