fix falcon

This commit is contained in:
wangbluo 2025-05-22 16:50:40 +08:00
parent bafc80c3b0
commit 4a077e5dc3

View File

@ -246,6 +246,7 @@ class FalconPolicy(Policy):
module = self.model.transformer
stage_manager = self.pipeline_stage_manager
held_layers = []
held_layers.append(module.rotary_emb)
if stage_manager.is_interleave:
assert stage_manager.num_model_chunks is not None
layers_per_stage = stage_manager.distribute_layers(len(module.h))