mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-29 08:47:17 +00:00
code style
This commit is contained in:
parent
7a23deb584
commit
efe6fe3a33
@ -33,23 +33,6 @@ def _benchmark_evoformer(model: torch.nn.Module, node, pair, title):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def benchmark_evoformer():
|
|
||||||
# init data and model
|
|
||||||
msa_len = 300
|
|
||||||
pair_len = 800
|
|
||||||
node = torch.randn(1, msa_len, pair_len, 256).cuda()
|
|
||||||
pair = torch.randn(1, pair_len, pair_len, 128).cuda()
|
|
||||||
model = evoformer_base().cuda()
|
|
||||||
|
|
||||||
# build autochunk model
|
|
||||||
max_memory = 3000 # MB
|
|
||||||
autochunk = _build_autochunk(model, max_memory, node, pair)
|
|
||||||
|
|
||||||
# benchmark
|
|
||||||
_benchmark_evoformer(model, node, pair, "openfold")
|
|
||||||
_benchmark_evoformer(autochunk, node, pair, "autochunk")
|
|
||||||
|
|
||||||
|
|
||||||
def _build_autochunk(model, max_memory, node, pair):
|
def _build_autochunk(model, max_memory, node, pair):
|
||||||
# trace the module and replace codegen
|
# trace the module and replace codegen
|
||||||
graph = ColoTracer().trace(
|
graph = ColoTracer().trace(
|
||||||
@ -81,5 +64,22 @@ def _build_autochunk(model, max_memory, node, pair):
|
|||||||
return gm
|
return gm
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark_evoformer():
|
||||||
|
# init data and model
|
||||||
|
msa_len = 300
|
||||||
|
pair_len = 800
|
||||||
|
node = torch.randn(1, msa_len, pair_len, 256).cuda()
|
||||||
|
pair = torch.randn(1, pair_len, pair_len, 128).cuda()
|
||||||
|
model = evoformer_base().cuda()
|
||||||
|
|
||||||
|
# build autochunk model
|
||||||
|
max_memory = 3000 # MB
|
||||||
|
autochunk = _build_autochunk(model, max_memory, node, pair)
|
||||||
|
|
||||||
|
# benchmark
|
||||||
|
_benchmark_evoformer(model, node, pair, "openfold")
|
||||||
|
_benchmark_evoformer(autochunk, node, pair, "autochunk")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
benchmark_evoformer()
|
benchmark_evoformer()
|
||||||
|
Loading…
Reference in New Issue
Block a user