[Gemini] add unit tests to check gemini correctness (#2015)

Jiarui Fang
2022-11-24 16:51:45 +08:00
committed by GitHub
parent 0b0d8f9e17
commit 2e9cbfca12
13 changed files with 135 additions and 54 deletions

View File

@@ -1 +1,2 @@
 from . import bert, gpt, inline_op_model, nested_model, no_leaf_module, repeated_computed_layer, resnet, simple_net
+from .utils import run_fwd_bwd

View File

@@ -1,10 +1,12 @@
 import torch
 import torch.nn as nn
-from .registry import non_distributed_component_funcs
 from transformers import GPT2Config, GPT2LMHeadModel
-from .utils.dummy_data_generator import DummyDataGenerator
 from colossalai.utils.cuda import get_current_device
+from .registry import non_distributed_component_funcs
+from .utils.dummy_data_generator import DummyDataGenerator
 
 class DummyDataLoader(DummyDataGenerator):
     vocab_size = 128
@@ -15,8 +17,7 @@ class DummyDataLoader(DummyDataGenerator):
         input_ids = torch.randint(0,
                                   DummyDataLoader.vocab_size, (DummyDataLoader.batch_size, DummyDataLoader.seq_len),
                                   device=get_current_device())
-        attention_mask = torch.ones_like(input_ids)
-        return input_ids, attention_mask
+        return input_ids, input_ids
 
 
 class GPTLMModel(nn.Module):
@@ -43,8 +44,9 @@ class GPTLMModel(nn.Module):
         if checkpoint:
             self.model.gradient_checkpointing_enable()
 
-    def forward(self, input_ids, attention_mask):
+    def forward(self, input_ids):
         # Only return lm_logits
+        attention_mask = torch.ones_like(input_ids)
         return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0]
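After this change the GPT test component builds an all-ones attention mask inside forward(), and its dummy loader returns input_ids twice so the same tensor doubles as the label; callers therefore only ever pass a single tensor. For context, a short self-contained sketch of the same calling pattern using the GPT2 classes imported at the top of this file; the tiny config sizes below are made up for illustration and are not the values used by the test component:

    import torch
    from transformers import GPT2Config, GPT2LMHeadModel

    # Illustrative only: a tiny GPT-2 so the sketch runs quickly; the real test
    # component configures its model elsewhere in this file.
    model = GPT2LMHeadModel(GPT2Config(n_embd=32, n_layer=2, n_head=2, vocab_size=128))

    input_ids = torch.randint(0, 128, (2, 16))
    # Same trick as the new forward(): the mask is derived from input_ids,
    # so the caller never has to supply it.
    attention_mask = torch.ones_like(input_ids)
    logits = model(input_ids=input_ids, attention_mask=attention_mask, use_cache=False)[0]
    assert logits.shape == (2, 16, 128)    # (batch, seq_len, vocab_size)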

View File

@@ -38,7 +38,7 @@ class DummyDataLoader(DummyDataGenerator):
         return data, label
 
 
-@non_distributed_component_funcs.register(name='inline_op_module')
+@non_distributed_component_funcs.register(name='inline_op_model')
 def get_training_components():
 
     def model_builder(checkpoint=True):
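The only substantive change here is the registry key: 'inline_op_module' becomes 'inline_op_model', so the key now matches the inline_op_model module name imported in the package __init__ shown in the first hunk. For readers unfamiliar with the pattern, a rough stand-in for such a decorator-based registry follows; this is an illustration of the idea, not the repo's actual registry implementation:

    class ComponentRegistry:
        """Illustrative stand-in for non_distributed_component_funcs."""

        def __init__(self):
            self._funcs = {}

        def register(self, name):
            # Used as @registry.register(name='...') above a component function.
            def wrapper(func):
                self._funcs[name] = func
                return func
            return wrapper

        def get_callable(self, name):
            # Tests look components up by the registered name.
            return self._funcs[name]


    non_distributed_component_funcs = ComponentRegistry()

    @non_distributed_component_funcs.register(name='inline_op_model')
    def get_training_components():
        ...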

View File

@@ -1 +1,2 @@
-from .dummy_data_generator import DummyDataGenerator
+from .dummy_data_generator import DummyDataGenerator
+from .executor import run_fwd_bwd

View File

@@ -0,0 +1,15 @@
+import torch
+
+
+def run_fwd_bwd(model, data, label, criterion, enable_autocast=False, use_init_ctx=False):
+    with torch.cuda.amp.autocast(enabled=enable_autocast):
+        if criterion:
+            y = model(data)
+            loss = criterion(y, label)
+        else:
+            loss = model(data, label)
+        loss = loss.float()
+    if use_init_ctx:
+        model.backward(loss)
+    else:
+        loss.backward()
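For context, this is roughly how a Gemini correctness test might drive the new run_fwd_bwd helper with a registered component. The import paths, the 'gpt2' registry key, the get_callable accessor, and the tuple returned by the component function are assumptions based on how the components_to_test registry is typically used, none of which is shown in this diff:

    import torch
    from tests.components_to_test.registry import non_distributed_component_funcs    # assumed path
    from tests.components_to_test import run_fwd_bwd                                 # assumed path

    # Assumed component contract: builder, dataloaders, optimizer class, criterion.
    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
    model_builder, train_dataloader, _, _, criterion = get_components_func()

    model = model_builder(checkpoint=False).cuda()
    for step, (data, label) in enumerate(train_dataloader):
        if step >= 1:
            break
        # Plain PyTorch path: use_init_ctx=False makes the helper call loss.backward();
        # a Gemini/ZeRO-wrapped model would pass use_init_ctx=True to use model.backward(loss).
        run_fwd_bwd(model, data.cuda(), label.cuda(), criterion, use_init_ctx=False)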