mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 17:17:05 +00:00
[inference] simplified config verification (#5346)
* [inference] simplified config verification * polish * polish
This commit is contained in:
@@ -21,11 +21,15 @@ def setup_seed(seed):
|
||||
def check_inference_engine(test_cai=False):
|
||||
setup_seed(20)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
|
||||
model = LlamaForCausalLM(
|
||||
LlamaConfig(
|
||||
vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
|
||||
model = (
|
||||
LlamaForCausalLM(
|
||||
LlamaConfig(
|
||||
vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
|
||||
)
|
||||
)
|
||||
).cuda()
|
||||
.cuda()
|
||||
.half()
|
||||
)
|
||||
|
||||
model = model.eval()
|
||||
|
||||
@@ -70,7 +74,7 @@ def run_dist(rank, world_size, port):
|
||||
transformer_outputs = check_inference_engine(False)
|
||||
|
||||
for s1, s2 in zip(cai_outputs, transformer_outputs):
|
||||
assert s1 == s2
|
||||
assert s1 == s2, f"\nColossalAI Output: {s1}\nTransformers Output: {s2}"
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
|
Reference in New Issue
Block a user