mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-26 15:32:22 +00:00
code style
This commit is contained in:
parent
7a23deb584
commit
efe6fe3a33
@ -33,23 +33,6 @@ def _benchmark_evoformer(model: torch.nn.Module, node, pair, title):
|
||||
)
|
||||
|
||||
|
||||
def benchmark_evoformer():
|
||||
# init data and model
|
||||
msa_len = 300
|
||||
pair_len = 800
|
||||
node = torch.randn(1, msa_len, pair_len, 256).cuda()
|
||||
pair = torch.randn(1, pair_len, pair_len, 128).cuda()
|
||||
model = evoformer_base().cuda()
|
||||
|
||||
# build autochunk model
|
||||
max_memory = 3000 # MB
|
||||
autochunk = _build_autochunk(model, max_memory, node, pair)
|
||||
|
||||
# benchmark
|
||||
_benchmark_evoformer(model, node, pair, "openfold")
|
||||
_benchmark_evoformer(autochunk, node, pair, "autochunk")
|
||||
|
||||
|
||||
def _build_autochunk(model, max_memory, node, pair):
|
||||
# trace the module and replace codegen
|
||||
graph = ColoTracer().trace(
|
||||
@ -81,5 +64,22 @@ def _build_autochunk(model, max_memory, node, pair):
|
||||
return gm
|
||||
|
||||
|
||||
def benchmark_evoformer():
|
||||
# init data and model
|
||||
msa_len = 300
|
||||
pair_len = 800
|
||||
node = torch.randn(1, msa_len, pair_len, 256).cuda()
|
||||
pair = torch.randn(1, pair_len, pair_len, 128).cuda()
|
||||
model = evoformer_base().cuda()
|
||||
|
||||
# build autochunk model
|
||||
max_memory = 3000 # MB
|
||||
autochunk = _build_autochunk(model, max_memory, node, pair)
|
||||
|
||||
# benchmark
|
||||
_benchmark_evoformer(model, node, pair, "openfold")
|
||||
_benchmark_evoformer(autochunk, node, pair, "autochunk")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
benchmark_evoformer()
|
||||
|
Loading…
Reference in New Issue
Block a user