mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-09 11:58:06 +00:00
restruct dir
This commit is contained in:
parent
27ab524096
commit
efb1c64c30
@ -3,13 +3,13 @@ import time
|
|||||||
import torch
|
import torch
|
||||||
import torch.fx
|
import torch.fx
|
||||||
|
|
||||||
from autochunk.chunk_codegen import ChunkCodeGen
|
from colossalai.autochunk.chunk_codegen import ChunkCodeGen
|
||||||
from colossalai.fx import ColoTracer
|
from colossalai.fx import ColoTracer
|
||||||
from colossalai.fx.graph_module import ColoGraphModule
|
from colossalai.fx.graph_module import ColoGraphModule
|
||||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||||
from colossalai.fx.profiler import MetaTensor
|
from colossalai.fx.profiler import MetaTensor
|
||||||
from autochunk.evoformer.evoformer import evoformer_base
|
from tests.test_autochunk.evoformer.evoformer import evoformer_base
|
||||||
from autochunk.openfold.evoformer import EvoformerBlock
|
from tests.test_autochunk.openfold.evoformer import EvoformerBlock
|
||||||
|
|
||||||
|
|
||||||
def _benchmark_evoformer(model: torch.nn.Module, node, pair, title, chunk_size=None):
|
def _benchmark_evoformer(model: torch.nn.Module, node, pair, title, chunk_size=None):
|
||||||
@ -94,7 +94,7 @@ def _build_openfold():
|
|||||||
def benchmark_evoformer():
|
def benchmark_evoformer():
|
||||||
# init data and model
|
# init data and model
|
||||||
msa_len = 256
|
msa_len = 256
|
||||||
pair_len = 1024
|
pair_len = 256
|
||||||
node = torch.randn(1, msa_len, pair_len, 256).cuda()
|
node = torch.randn(1, msa_len, pair_len, 256).cuda()
|
||||||
pair = torch.randn(1, pair_len, pair_len, 128).cuda()
|
pair = torch.randn(1, pair_len, pair_len, 128).cuda()
|
||||||
model = evoformer_base().cuda()
|
model = evoformer_base().cuda()
|
||||||
@ -106,11 +106,11 @@ def benchmark_evoformer():
|
|||||||
|
|
||||||
# build openfold
|
# build openfold
|
||||||
chunk_size = 64
|
chunk_size = 64
|
||||||
# openfold = _build_openfold()
|
openfold = _build_openfold()
|
||||||
|
|
||||||
# benchmark
|
# benchmark
|
||||||
# _benchmark_evoformer(model, node, pair, "base")
|
_benchmark_evoformer(model, node, pair, "base")
|
||||||
# _benchmark_evoformer(openfold, node, pair, "openfold", chunk_size=chunk_size)
|
_benchmark_evoformer(openfold, node, pair, "openfold", chunk_size=chunk_size)
|
||||||
_benchmark_evoformer(autochunk, node, pair, "autochunk")
|
_benchmark_evoformer(autochunk, node, pair, "autochunk")
|
||||||
|
|
||||||
|
|
@ -12,8 +12,8 @@ from colossalai.core import global_context as gpc
|
|||||||
from colossalai.fx.graph_module import ColoGraphModule
|
from colossalai.fx.graph_module import ColoGraphModule
|
||||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata
|
from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata
|
||||||
from colossalai.fx.profiler import MetaTensor
|
from colossalai.fx.profiler import MetaTensor
|
||||||
from autochunk.evoformer.evoformer import evoformer_base
|
from tests.test_autochunk.evoformer.evoformer import evoformer_base
|
||||||
from autochunk.chunk_codegen import ChunkCodeGen
|
from ...colossalai.autochunk.chunk_codegen import ChunkCodeGen
|
||||||
with_codegen = True
|
with_codegen = True
|
||||||
|
|
||||||
|
|
@ -19,25 +19,25 @@ import torch.nn as nn
|
|||||||
from typing import Tuple, Optional
|
from typing import Tuple, Optional
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
from openfold.primitives import Linear, LayerNorm
|
from .primitives import Linear, LayerNorm
|
||||||
from openfold.dropout import DropoutRowwise, DropoutColumnwise
|
from .dropout import DropoutRowwise, DropoutColumnwise
|
||||||
from openfold.msa import (
|
from .msa import (
|
||||||
MSARowAttentionWithPairBias,
|
MSARowAttentionWithPairBias,
|
||||||
MSAColumnAttention,
|
MSAColumnAttention,
|
||||||
MSAColumnGlobalAttention,
|
MSAColumnGlobalAttention,
|
||||||
)
|
)
|
||||||
from openfold.outer_product_mean import OuterProductMean
|
from .outer_product_mean import OuterProductMean
|
||||||
from openfold.pair_transition import PairTransition
|
from .pair_transition import PairTransition
|
||||||
from openfold.triangular_attention import (
|
from .triangular_attention import (
|
||||||
TriangleAttentionStartingNode,
|
TriangleAttentionStartingNode,
|
||||||
TriangleAttentionEndingNode,
|
TriangleAttentionEndingNode,
|
||||||
)
|
)
|
||||||
from openfold.triangular_multiplicative_update import (
|
from .triangular_multiplicative_update import (
|
||||||
TriangleMultiplicationOutgoing,
|
TriangleMultiplicationOutgoing,
|
||||||
TriangleMultiplicationIncoming,
|
TriangleMultiplicationIncoming,
|
||||||
)
|
)
|
||||||
from openfold.checkpointing import checkpoint_blocks, get_checkpoint_fn
|
from .checkpointing import checkpoint_blocks, get_checkpoint_fn
|
||||||
from openfold.tensor_utils import chunk_layer
|
from .tensor_utils import chunk_layer
|
||||||
|
|
||||||
|
|
||||||
class MSATransition(nn.Module):
|
class MSATransition(nn.Module):
|
@ -18,15 +18,15 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from typing import Optional, List, Tuple
|
from typing import Optional, List, Tuple
|
||||||
|
|
||||||
from openfold.primitives import (
|
from .primitives import (
|
||||||
Linear,
|
Linear,
|
||||||
LayerNorm,
|
LayerNorm,
|
||||||
Attention,
|
Attention,
|
||||||
GlobalAttention,
|
GlobalAttention,
|
||||||
_attention_chunked_trainable,
|
_attention_chunked_trainable,
|
||||||
)
|
)
|
||||||
from openfold.checkpointing import get_checkpoint_fn
|
from .checkpointing import get_checkpoint_fn
|
||||||
from openfold.tensor_utils import (
|
from .tensor_utils import (
|
||||||
chunk_layer,
|
chunk_layer,
|
||||||
permute_final_dims,
|
permute_final_dims,
|
||||||
flatten_final_dims,
|
flatten_final_dims,
|
@ -19,8 +19,8 @@ from typing import Optional
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from openfold.primitives import Linear
|
from .primitives import Linear
|
||||||
from openfold.tensor_utils import chunk_layer
|
from .tensor_utils import chunk_layer
|
||||||
|
|
||||||
|
|
||||||
class OuterProductMean(nn.Module):
|
class OuterProductMean(nn.Module):
|
@ -17,8 +17,8 @@ from typing import Optional
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from openfold.primitives import Linear, LayerNorm
|
from .primitives import Linear, LayerNorm
|
||||||
from openfold.tensor_utils import chunk_layer
|
from .tensor_utils import chunk_layer
|
||||||
|
|
||||||
|
|
||||||
class PairTransition(nn.Module):
|
class PairTransition(nn.Module):
|
@ -21,8 +21,8 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from openfold.checkpointing import get_checkpoint_fn
|
from .checkpointing import get_checkpoint_fn
|
||||||
from openfold.tensor_utils import (
|
from .tensor_utils import (
|
||||||
permute_final_dims,
|
permute_final_dims,
|
||||||
flatten_final_dims,
|
flatten_final_dims,
|
||||||
_chunk_slice,
|
_chunk_slice,
|
@ -20,8 +20,8 @@ from typing import Optional, List
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from openfold.primitives import Linear, LayerNorm, Attention
|
from .primitives import Linear, LayerNorm, Attention
|
||||||
from openfold.tensor_utils import (
|
from .tensor_utils import (
|
||||||
chunk_layer,
|
chunk_layer,
|
||||||
permute_final_dims,
|
permute_final_dims,
|
||||||
flatten_final_dims,
|
flatten_final_dims,
|
@ -19,8 +19,8 @@ from typing import Optional
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from openfold.primitives import Linear, LayerNorm
|
from .primitives import Linear, LayerNorm
|
||||||
from openfold.tensor_utils import permute_final_dims
|
from .tensor_utils import permute_final_dims
|
||||||
|
|
||||||
|
|
||||||
class TriangleMultiplicativeUpdate(nn.Module):
|
class TriangleMultiplicativeUpdate(nn.Module):
|
Loading…
Reference in New Issue
Block a user