[autochunk] support vit (#3084)

support vit for autochunk
* support some new ops for vit
* fix some bugs
* add test for vit
This commit is contained in:
Xuanlei Zhao
2023-03-10 10:23:26 +08:00
committed by GitHub
parent e58a3c804c
commit 10c61de2f7
8 changed files with 445 additions and 57 deletions

View File

@@ -39,7 +39,7 @@ def get_data(shape: tuple) -> Tuple[List, List]:
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("shape", [LATENTS_SHAPE])
@pytest.mark.parametrize("max_memory", [None])
@pytest.mark.parametrize("max_memory", [None, 150, 300])
def test_evoformer_block(model, shape, max_memory):
run_func = partial(
run_test,
@@ -57,7 +57,7 @@ if __name__ == "__main__":
max_memory=None,
model=UNet2DModel,
print_code=False,
print_mem=False,
print_mem=True,
print_est_mem=False,
print_progress=False,
)