code style

This commit is contained in:
oahzxl 2022-12-29 14:47:47 +08:00
parent 7a23deb584
commit efe6fe3a33

View File

@ -33,23 +33,6 @@ def _benchmark_evoformer(model: torch.nn.Module, node, pair, title):
)
def benchmark_evoformer():
# init data and model
msa_len = 300
pair_len = 800
node = torch.randn(1, msa_len, pair_len, 256).cuda()
pair = torch.randn(1, pair_len, pair_len, 128).cuda()
model = evoformer_base().cuda()
# build autochunk model
max_memory = 3000 # MB
autochunk = _build_autochunk(model, max_memory, node, pair)
# benchmark
_benchmark_evoformer(model, node, pair, "openfold")
_benchmark_evoformer(autochunk, node, pair, "autochunk")
def _build_autochunk(model, max_memory, node, pair):
# trace the module and replace codegen
graph = ColoTracer().trace(
@ -81,5 +64,22 @@ def _build_autochunk(model, max_memory, node, pair):
return gm
def benchmark_evoformer():
# init data and model
msa_len = 300
pair_len = 800
node = torch.randn(1, msa_len, pair_len, 256).cuda()
pair = torch.randn(1, pair_len, pair_len, 128).cuda()
model = evoformer_base().cuda()
# build autochunk model
max_memory = 3000 # MB
autochunk = _build_autochunk(model, max_memory, node, pair)
# benchmark
_benchmark_evoformer(model, node, pair, "openfold")
_benchmark_evoformer(autochunk, node, pair, "autochunk")
if __name__ == "__main__":
benchmark_evoformer()