mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-18 11:48:53 +00:00)
[fix] fix llama modeling policy;
commit 03fa79a55c
parent cc0dfddcbc
@@ -96,7 +96,8 @@ class LlamaPolicy(Policy):
                 target_key=attn_cls,
             )
 
-        if self.pipeline_stage_manager is not None:
+        # if self.pipeline_stage_manager is not None:
+        if self.pipeline_stage_manager is None:
             self.append_or_create_method_replacement(
                 description={
                     "forward": partial(
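The hunk above flips the guard so the forward replacement is registered only when no pipeline stage manager is configured, keeping the original condition as a comment. Below is a minimal, self-contained sketch of that conditional method-replacement pattern; ToyPolicy, llama_forward_shim, and apply are illustrative stand-ins, not ColossalAI's actual Policy API.

from functools import partial


def llama_forward_shim(module, hidden_states, scale=1.0):
    # Hypothetical replacement forward; stands in for the partial(...) target in the diff.
    return [h * scale for h in hidden_states]


class ToyPolicy:
    # Illustrative stand-in for a ShardFormer-style modeling policy.

    def __init__(self, pipeline_stage_manager=None):
        self.pipeline_stage_manager = pipeline_stage_manager
        self.method_replacements = {}

    def append_or_create_method_replacement(self, description, target_key):
        # Record (or extend) the replacement table for the target module class.
        self.method_replacements.setdefault(target_key, {}).update(description)

    def apply(self, target_key):
        # Mirror the guard as it reads after this commit: register the forward
        # replacement only when no pipeline stage manager is present.
        if self.pipeline_stage_manager is None:
            self.append_or_create_method_replacement(
                description={"forward": partial(llama_forward_shim, scale=2.0)},
                target_key=target_key,
            )


if __name__ == "__main__":
    policy = ToyPolicy(pipeline_stage_manager=None)
    policy.apply(target_key="LlamaAttention")
    replaced_forward = policy.method_replacements["LlamaAttention"]["forward"]
    print(replaced_forward(None, [1, 2, 3]))  # -> [2, 4, 6]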
@@ -325,6 +325,7 @@ def run_llama_test(test_config):
     ).get_v_schedule()
     test_config["scheduler_nodes"] = scheduler_nodes
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        print(f"name {name}")
         if test_config.get("sequence_parallelism_mode", None) == "ring_attn" and "causal" not in name:
             continue
         try: