[test] fixed tests failed due to dtensor change (#4082)

* [test] fixed tests failed due to dtensor change

* polish code
This commit is contained in:
Frank Lee
2023-06-26 15:50:07 +08:00
parent 92f6791095
commit c4b1b65931
37 changed files with 233 additions and 289 deletions

View File

@@ -58,13 +58,4 @@ def test_evoformer_block(model, shape, max_memory):
if __name__ == "__main__":
run_test(
rank=0,
data=get_data(LATENTS_SHAPE),
max_memory=None,
model=UNet2DModel,
print_code=False,
print_mem=True,
print_est_mem=False,
print_progress=False,
)
test_evoformer_block()