diff --git a/tests/components_to_test/gpt.py b/tests/components_to_test/gpt.py index a0d70a2fd..3123211ad 100644 --- a/tests/components_to_test/gpt.py +++ b/tests/components_to_test/gpt.py @@ -47,8 +47,15 @@ class GPTLMModel(nn.Module): # Only return lm_logits return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0] + def gpt2_micro(checkpoint=True): - return GPTLMModel(checkpoint=checkpoint, hidden_size=32, num_layers=2, num_attention_heads=4, max_seq_len=64, vocab_size=128) + return GPTLMModel(checkpoint=checkpoint, + hidden_size=32, + num_layers=2, + num_attention_heads=4, + max_seq_len=64, + vocab_size=128) + def gpt2_s(checkpoint=True): return GPTLMModel(checkpoint=checkpoint)