mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-23 14:10:29 +00:00
[NFC] polish test component gpt code style (#1567)
This commit is contained in:
parent
6159d45417
commit
e615cfc3a8
@ -47,8 +47,15 @@ class GPTLMModel(nn.Module):
|
||||
# Only return lm_logits
|
||||
return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0]
|
||||
|
||||
|
||||
def gpt2_micro(checkpoint=True):
|
||||
return GPTLMModel(checkpoint=checkpoint, hidden_size=32, num_layers=2, num_attention_heads=4, max_seq_len=64, vocab_size=128)
|
||||
return GPTLMModel(checkpoint=checkpoint,
|
||||
hidden_size=32,
|
||||
num_layers=2,
|
||||
num_attention_heads=4,
|
||||
max_seq_len=64,
|
||||
vocab_size=128)
|
||||
|
||||
|
||||
def gpt2_s(checkpoint=True):
|
||||
return GPTLMModel(checkpoint=checkpoint)
|
||||
|
Loading…
Reference in New Issue
Block a user