[test] Fix/fix testcase (#5770)

* [fix] branch for fix testcase;

* [fix] fix test_analyzer & test_auto_parallel;

* [fix] remove local change about moe;

* [fix] rm local change moe;
This commit is contained in:
duanjunwen
2024-06-03 15:26:01 +08:00
committed by GitHub
parent 3f2be80530
commit 1b76564e16
4 changed files with 10 additions and 7 deletions

View File

@@ -2,7 +2,7 @@ import torch
import torch.nn as nn
from transformers import BertConfig, BertLMHeadModel, GPT2Config, GPT2LMHeadModel
from tests.components_to_test.registry import non_distributed_component_funcs
# from tests.components_to_test.registry import non_distributed_component_funcs
class GPTLMModel(nn.Module):
@@ -55,7 +55,7 @@ class BertLMModel(nn.Module):
return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=True)[0]
@non_distributed_component_funcs.register(name="bert_")
# @non_distributed_component_funcs.register(name="bert_")
def get_bert_components():
vocab_size = 1024
seq_len = 64
@@ -74,7 +74,7 @@ def get_bert_components():
return bert_model_builder, bert_data_gen
@non_distributed_component_funcs.register(name="gpt2_")
# @non_distributed_component_funcs.register(name="gpt2_")
def get_gpt2_components():
vocab_size = 1024
seq_len = 8