mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-10 05:20:33 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -2,22 +2,20 @@ import torch
|
||||
import torch.nn as nn
|
||||
from transformers import GPT2Config, GPT2LMHeadModel
|
||||
|
||||
class GPTLMModel(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
hidden_size=768,
|
||||
num_layers=12,
|
||||
num_attention_heads=12,
|
||||
max_seq_len=1024,
|
||||
vocab_size=50257):
|
||||
class GPTLMModel(nn.Module):
|
||||
def __init__(self, hidden_size=768, num_layers=12, num_attention_heads=12, max_seq_len=1024, vocab_size=50257):
|
||||
super().__init__()
|
||||
self.model = GPT2LMHeadModel(
|
||||
GPT2Config(n_embd=hidden_size,
|
||||
n_layer=num_layers,
|
||||
n_head=num_attention_heads,
|
||||
n_positions=max_seq_len,
|
||||
n_ctx=max_seq_len,
|
||||
vocab_size=vocab_size))
|
||||
GPT2Config(
|
||||
n_embd=hidden_size,
|
||||
n_layer=num_layers,
|
||||
n_head=num_attention_heads,
|
||||
n_positions=max_seq_len,
|
||||
n_ctx=max_seq_len,
|
||||
vocab_size=vocab_size,
|
||||
)
|
||||
)
|
||||
|
||||
def forward(self, input_ids, attention_mask):
|
||||
# Only return lm_logits
|
||||
@@ -25,7 +23,6 @@ class GPTLMModel(nn.Module):
|
||||
|
||||
|
||||
class GPTLMLoss(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.loss_fn = nn.CrossEntropyLoss()
|
||||
@@ -36,6 +33,7 @@ class GPTLMLoss(nn.Module):
|
||||
# Flatten the tokens
|
||||
return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
||||
|
||||
|
||||
def get_gpt2_components(model_type: str, batch_size: int):
|
||||
vocab_size = 1024
|
||||
seq_len = 8
|
||||
@@ -62,4 +60,4 @@ def get_gpt2_components(model_type: str, batch_size: int):
|
||||
kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
|
||||
return kwargs
|
||||
|
||||
return gpt2_model_builder, gpt2_data_gen
|
||||
return gpt2_model_builder, gpt2_data_gen
|
||||
|
Reference in New Issue
Block a user