mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 13:00:52 +00:00
[test] fixed tests failed due to dtensor change (#4082)
* [test] fixed tests failed due to dtensor change * polish code
This commit is contained in:
@@ -22,7 +22,7 @@ from tests.kit.model_zoo import model_zoo
|
||||
@parameterize('use_safetensors', [False, True])
|
||||
def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: bool):
|
||||
from transformers import BertForSequenceClassification
|
||||
(model_fn, data_gen_fn, output_transform_fn, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
|
||||
(model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
|
||||
bert_model = model_fn()
|
||||
|
||||
with shared_tempdir() as tempdir:
|
||||
@@ -53,7 +53,7 @@ def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: b
|
||||
@parameterize('shard', [True, False])
|
||||
@parameterize('model_name', ['transformers_gpt'])
|
||||
def exam_state_dict(placement_policy, shard: bool, model_name: str):
|
||||
(model_fn, data_gen_fn, output_transform_fn, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
|
||||
(model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
|
||||
criterion = lambda x: x.mean()
|
||||
plugin = GeminiPlugin(placement_policy=placement_policy)
|
||||
booster = Booster(plugin=plugin)
|
||||
|
Reference in New Issue
Block a user