ColossalAI/tests/test_elixir/utils/resnet.py
Haichen Huang 206280408a
[elixir] add elixir and its unit tests (#3835)
* [elixir] add elixir

* [elixir] add unit tests

* remove useless code

* fix Python 3.8 issue

* fix typo

* add test skip

* add docstrings

* add docstrings

* add readme

* fix typo
2023-05-29 09:32:37 +08:00

24 lines
466 B
Python

import torch
import torch.nn as nn
from torchvision.models import resnet18
from tests.test_elixir.utils.registry import TEST_MODELS
def resnet_data_fn():
    """Create one random input batch for the ResNet test model.

    Returns:
        dict: keyword arguments for ``ResNetModel.forward``. The single key
        ``'x'`` maps to a ``torch.randn`` tensor of shape (4, 3, 32, 32).
    """
    # Presumably (batch, channels, height, width) image layout -- TODO confirm.
    sample = torch.randn(4, 3, 32, 32)
    return {'x': sample}
class ResNetModel(nn.Module):
    """Test model that wraps torchvision's ResNet-18 and returns a scalar.

    The wrapped network is stored as ``self.r``. Its output is summed into a
    single tensor, presumably so tests get a scalar "loss" they can call
    ``backward()`` on -- TODO confirm against the test harness.
    """

    def __init__(self) -> None:
        super().__init__()
        # The attribute name ``r`` prefixes every parameter in the
        # state_dict, so it must stay as-is.
        self.r = resnet18()

    def forward(self, x):
        """Run ResNet-18 on ``x`` and reduce the result to a scalar by summing."""
        return self.r(x).sum()
# Register this model with the shared test-model registry under the key
# 'resnet', pairing the model class with its input-builder function.
# NOTE(review): assumes the registry calls ``resnet_data_fn()`` to get forward
# kwargs and instantiates ``ResNetModel()`` itself; verify in utils/registry.py.
TEST_MODELS.register('resnet', ResNetModel, resnet_data_fn)