restruct dir

This commit is contained in:
oahzxl 2023-01-06 11:39:26 +08:00
parent 27ab524096
commit efb1c64c30
19 changed files with 31 additions and 31 deletions

View File

@ -3,13 +3,13 @@ import time
import torch import torch
import torch.fx import torch.fx
from autochunk.chunk_codegen import ChunkCodeGen from colossalai.autochunk.chunk_codegen import ChunkCodeGen
from colossalai.fx import ColoTracer from colossalai.fx import ColoTracer
from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.profiler import MetaTensor from colossalai.fx.profiler import MetaTensor
from autochunk.evoformer.evoformer import evoformer_base from tests.test_autochunk.evoformer.evoformer import evoformer_base
from autochunk.openfold.evoformer import EvoformerBlock from tests.test_autochunk.openfold.evoformer import EvoformerBlock
def _benchmark_evoformer(model: torch.nn.Module, node, pair, title, chunk_size=None): def _benchmark_evoformer(model: torch.nn.Module, node, pair, title, chunk_size=None):
@ -94,7 +94,7 @@ def _build_openfold():
def benchmark_evoformer(): def benchmark_evoformer():
# init data and model # init data and model
msa_len = 256 msa_len = 256
pair_len = 1024 pair_len = 256
node = torch.randn(1, msa_len, pair_len, 256).cuda() node = torch.randn(1, msa_len, pair_len, 256).cuda()
pair = torch.randn(1, pair_len, pair_len, 128).cuda() pair = torch.randn(1, pair_len, pair_len, 128).cuda()
model = evoformer_base().cuda() model = evoformer_base().cuda()
@ -106,11 +106,11 @@ def benchmark_evoformer():
# build openfold # build openfold
chunk_size = 64 chunk_size = 64
# openfold = _build_openfold() openfold = _build_openfold()
# benchmark # benchmark
# _benchmark_evoformer(model, node, pair, "base") _benchmark_evoformer(model, node, pair, "base")
# _benchmark_evoformer(openfold, node, pair, "openfold", chunk_size=chunk_size) _benchmark_evoformer(openfold, node, pair, "openfold", chunk_size=chunk_size)
_benchmark_evoformer(autochunk, node, pair, "autochunk") _benchmark_evoformer(autochunk, node, pair, "autochunk")

View File

@ -12,8 +12,8 @@ from colossalai.core import global_context as gpc
from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata
from colossalai.fx.profiler import MetaTensor from colossalai.fx.profiler import MetaTensor
from autochunk.evoformer.evoformer import evoformer_base from tests.test_autochunk.evoformer.evoformer import evoformer_base
from autochunk.chunk_codegen import ChunkCodeGen from ...colossalai.autochunk.chunk_codegen import ChunkCodeGen
with_codegen = True with_codegen = True

View File

@ -19,25 +19,25 @@ import torch.nn as nn
from typing import Tuple, Optional from typing import Tuple, Optional
from functools import partial from functools import partial
from openfold.primitives import Linear, LayerNorm from .primitives import Linear, LayerNorm
from openfold.dropout import DropoutRowwise, DropoutColumnwise from .dropout import DropoutRowwise, DropoutColumnwise
from openfold.msa import ( from .msa import (
MSARowAttentionWithPairBias, MSARowAttentionWithPairBias,
MSAColumnAttention, MSAColumnAttention,
MSAColumnGlobalAttention, MSAColumnGlobalAttention,
) )
from openfold.outer_product_mean import OuterProductMean from .outer_product_mean import OuterProductMean
from openfold.pair_transition import PairTransition from .pair_transition import PairTransition
from openfold.triangular_attention import ( from .triangular_attention import (
TriangleAttentionStartingNode, TriangleAttentionStartingNode,
TriangleAttentionEndingNode, TriangleAttentionEndingNode,
) )
from openfold.triangular_multiplicative_update import ( from .triangular_multiplicative_update import (
TriangleMultiplicationOutgoing, TriangleMultiplicationOutgoing,
TriangleMultiplicationIncoming, TriangleMultiplicationIncoming,
) )
from openfold.checkpointing import checkpoint_blocks, get_checkpoint_fn from .checkpointing import checkpoint_blocks, get_checkpoint_fn
from openfold.tensor_utils import chunk_layer from .tensor_utils import chunk_layer
class MSATransition(nn.Module): class MSATransition(nn.Module):

View File

@ -18,15 +18,15 @@ import torch
import torch.nn as nn import torch.nn as nn
from typing import Optional, List, Tuple from typing import Optional, List, Tuple
from openfold.primitives import ( from .primitives import (
Linear, Linear,
LayerNorm, LayerNorm,
Attention, Attention,
GlobalAttention, GlobalAttention,
_attention_chunked_trainable, _attention_chunked_trainable,
) )
from openfold.checkpointing import get_checkpoint_fn from .checkpointing import get_checkpoint_fn
from openfold.tensor_utils import ( from .tensor_utils import (
chunk_layer, chunk_layer,
permute_final_dims, permute_final_dims,
flatten_final_dims, flatten_final_dims,

View File

@ -19,8 +19,8 @@ from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from openfold.primitives import Linear from .primitives import Linear
from openfold.tensor_utils import chunk_layer from .tensor_utils import chunk_layer
class OuterProductMean(nn.Module): class OuterProductMean(nn.Module):

View File

@ -17,8 +17,8 @@ from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from openfold.primitives import Linear, LayerNorm from .primitives import Linear, LayerNorm
from openfold.tensor_utils import chunk_layer from .tensor_utils import chunk_layer
class PairTransition(nn.Module): class PairTransition(nn.Module):

View File

@ -21,8 +21,8 @@ import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from openfold.checkpointing import get_checkpoint_fn from .checkpointing import get_checkpoint_fn
from openfold.tensor_utils import ( from .tensor_utils import (
permute_final_dims, permute_final_dims,
flatten_final_dims, flatten_final_dims,
_chunk_slice, _chunk_slice,

View File

@ -20,8 +20,8 @@ from typing import Optional, List
import torch import torch
import torch.nn as nn import torch.nn as nn
from openfold.primitives import Linear, LayerNorm, Attention from .primitives import Linear, LayerNorm, Attention
from openfold.tensor_utils import ( from .tensor_utils import (
chunk_layer, chunk_layer,
permute_final_dims, permute_final_dims,
flatten_final_dims, flatten_final_dims,

View File

@ -19,8 +19,8 @@ from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from openfold.primitives import Linear, LayerNorm from .primitives import Linear, LayerNorm
from openfold.tensor_utils import permute_final_dims from .tensor_utils import permute_final_dims
class TriangleMultiplicativeUpdate(nn.Module): class TriangleMultiplicativeUpdate(nn.Module):