Mirror of https://github.com/hpcaitech/ColossalAI.git
[fix] fix mixtral policy;
parent f4d023ca6e
commit 292a504bea
@@ -268,9 +268,11 @@ class MixtralPolicy(Policy):
                 held_layers.append(module.embed_tokens)
             for start_idx, end_idx in stage_indices:
                 held_layers.extend(module.layers[start_idx:end_idx])
-            if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
-                held_layers.append(module.norm)
-            elif stage_manager.is_last_stage(ignore_chunk=True):
+            if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
+                stage_manager.is_last_stage(ignore_chunk=True)
+            ):
+                # for zbv, when is_first_stage (last fwd), we append norm
+                # for interleaved, when is_last_stage (last fwd), we also append norm
                 held_layers.append(module.norm)
         else:
             layers_per_stage = stage_manager.distribute_layers(len(module.layers))
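For context, the patched condition keeps module.norm on whichever stage runs the last forward pass: the first stage under the ZBV (zero-bubble) schedule, the last stage otherwise. Below is a minimal sketch of that predicate, using a hypothetical StageManagerStub in place of ColossalAI's real PipelineStageManager; only the attributes this condition touches are modeled.

# A minimal sketch of the merged predicate above. StageManagerStub is a
# hypothetical stand-in for ColossalAI's PipelineStageManager, reduced to
# the three pieces of state this condition reads.
from dataclasses import dataclass


@dataclass
class StageManagerStub:
    use_zbv: bool      # zero-bubble (ZBV) schedule enabled
    first_stage: bool  # this rank holds the first pipeline stage
    last_stage: bool   # this rank holds the last pipeline stage

    def is_first_stage(self, ignore_chunk: bool = False) -> bool:
        return self.first_stage

    def is_last_stage(self, ignore_chunk: bool = False) -> bool:
        return self.last_stage


def holds_norm(sm: StageManagerStub) -> bool:
    # Under ZBV the first stage runs the final forward chunk, so it holds
    # module.norm; under interleaved schedules the last stage does.
    return (sm.use_zbv and sm.is_first_stage(ignore_chunk=True)) or sm.is_last_stage(
        ignore_chunk=True
    )


# ZBV schedule: norm lives on the first stage (which runs the final forward).
assert holds_norm(StageManagerStub(use_zbv=True, first_stage=True, last_stage=False))
# Non-ZBV schedule: norm lives on the last stage, as usual.
assert holds_norm(StageManagerStub(use_zbv=False, first_stage=False, last_stage=True))
# A middle stage without ZBV never holds norm.
assert not holds_norm(StageManagerStub(use_zbv=False, first_stage=True, last_stage=False))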