ColossalAI/tests/test_elixir/test_search/test_optimal.py
Haichen Huang 206280408a
[elixir] add elixir and its unit tests (#3835)
* [elixir] add elixir

* [elixir] add unit tests

* remove useless code

* fix python 3.8 issue

* fix typo

* add test skip

* add docstrings

* add docstrings

* add readme

* fix typo
2023-05-29 09:32:37 +08:00

31 lines
821 B
Python

from copy import deepcopy
import torch
from colossalai.elixir.cuda import gpu_device
from colossalai.elixir.search import optimal_search
from colossalai.testing import run_on_environment_flag
from tests.test_elixir.utils import TEST_MODELS
def step_fn(model, inp):
    """Run one forward/backward step, as handed to ``optimal_search``.

    Args:
        model: Callable invoked with ``inp`` unpacked as keyword arguments.
            Whatever it returns must expose a zero-argument ``backward()``.
        inp: Mapping of keyword-argument names to values for ``model``.

    Returns:
        None. The result of ``backward()`` is discarded.
    """
    model(**inp).backward()
@run_on_environment_flag('ELX')
def test_optimal_search():
    """Check the chunk plans produced by ``optimal_search`` for ``gpt2_small``.

    Builds the model and its input from the ``'gpt2_small'`` entry of
    ``TEST_MODELS``, runs ``optimal_search`` with
    ``unified_dtype=torch.float16``, and asserts that every parameter chunk
    plan uses ``torch.float16`` as its chunk dtype and carries a
    ``shard_device`` kwarg equal to ``gpu_device()``.

    NOTE(review): the decorator presumably skips this test unless the ``ELX``
    environment flag is set -- confirm against ``run_on_environment_flag``.
    """
    model_fn, data_fn = TEST_MODELS.get('gpt2_small')
    model = model_fn()
    data = data_fn()

    # NOTE(review): the positional ``1`` is presumably the group/world size --
    # confirm against the ``optimal_search`` signature.
    sr = optimal_search(model,
                        1,
                        unified_dtype=torch.float16,
                        overlap=True,
                        verbose=True,
                        inp=data,
                        step_fn=step_fn)

    # Inspect a copy of the plans rather than the search result's own objects.
    chunk_plans = deepcopy(sr.param_chunk_plans)

    # Without this guard an empty plan list would let the loop below pass
    # without checking anything.
    assert chunk_plans, 'optimal_search returned no parameter chunk plans'

    for plan in chunk_plans:
        assert plan.chunk_dtype == torch.float16, \
            f'unexpected chunk dtype: {plan.chunk_dtype}'
        assert plan.kwargs.get('shard_device') == gpu_device(), \
            f"unexpected shard device: {plan.kwargs.get('shard_device')}"
# Allow running this test file directly as a script, outside of pytest.
if __name__ == '__main__':
    test_optimal_search()