[inference] simplified config verification (#5346)

* [inference] simplified config verification

* polish

* polish
This commit is contained in:
Frank Lee
2024-02-01 15:31:01 +08:00
committed by GitHub
parent df0aa49585
commit f8e456d202
2 changed files with 40 additions and 60 deletions

View File

@@ -21,11 +21,15 @@ def setup_seed(seed):
def check_inference_engine(test_cai=False):
setup_seed(20)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
-model = LlamaForCausalLM(
-LlamaConfig(
-vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
+model = (
+LlamaForCausalLM(
+LlamaConfig(
+vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
+)
 )
-).cuda()
+.cuda()
+.half()
+)
model = model.eval()
@@ -70,7 +74,7 @@ def run_dist(rank, world_size, port):
transformer_outputs = check_inference_engine(False)
for s1, s2 in zip(cai_outputs, transformer_outputs):
-assert s1 == s2
+assert s1 == s2, f"\nColossalAI Output: {s1}\nTransformers Output: {s2}"
@pytest.mark.dist