mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-28 04:55:25 +00:00
[shardformer] Pipeline/whisper (#4456)
* add some base tests and policies
* finish whisper base model
* add conditional generation
* finish basic tests
* whisper
* finish whisper
* finish whisper
* del useless whisper test
* fix
* add argmin to replace
* finish revision
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
from functools import partial
|
||||
from typing import Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
from torch import Tensor, nn
|
||||
|
||||
from colossalai.shardformer.layer import (
|
||||
@@ -228,13 +229,7 @@ class T5BasePolicy(Policy):
|
||||
def objective(num_encoder_stages):
    """Imbalance score for a candidate encoder/decoder stage split.

    Given `num_encoder_stages` pipeline stages assigned to the encoder
    (the remaining `num_stages - num_encoder_stages` go to the decoder),
    return how far apart the per-stage layer loads are; smaller is a
    more balanced split.
    """
    encoder_layers_per_stage = num_encoder_layers / num_encoder_stages
    decoder_layers_per_stage = num_decoder_layers / (num_stages - num_encoder_stages)
    return abs(encoder_layers_per_stage - decoder_layers_per_stage)
|
||||
|
||||
num_encoder_stages = 0
|
||||
optimal_diff = 2**31 - 1
|
||||
for i in range(1, num_stages):
|
||||
attempt = objective(i)
|
||||
if attempt < optimal_diff:
|
||||
num_encoder_stages = i
|
||||
optimal_diff = attempt
|
||||
num_encoder_stages = np.argmin([objective(i) for i in range(1, num_stages)]) + 1
|
||||
num_decoder_stages = num_stages - num_encoder_stages
|
||||
|
||||
encoder_distribution = Policy.distribute_layers(num_encoder_layers, num_encoder_stages)
|
||||
|
Reference in New Issue
Block a user