diff --git a/autochunk_benchmark.py b/autochunk_benchmark.py
index 20f615b21..679016438 100644
--- a/autochunk_benchmark.py
+++ b/autochunk_benchmark.py
@@ -9,20 +9,27 @@ from colossalai.fx.graph_module import ColoGraphModule
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp
 from colossalai.fx.profiler import MetaTensor
 from evoformer.evoformer import evoformer_base
+from openfold.evoformer import EvoformerBlock
 
 
-def _benchmark_evoformer(model: torch.nn.Module, node, pair, title):
+def _benchmark_evoformer(model: torch.nn.Module, node, pair, title, chunk_size=None):
     torch.cuda.reset_peak_memory_stats()
     now_mem = torch.cuda.memory_allocated() / 1024**2
 
     loop = 16
     with torch.no_grad():
         for _ in range(loop // 4):
-            model(node, pair)
+            if chunk_size:
+                model(node, pair, chunk_size)
+            else:
+                model(node, pair)
         torch.cuda.synchronize()
         time1 = time.time()
 
         for _ in range(loop):
-            model(node, pair)
+            if chunk_size:
+                model(node, pair, chunk_size)
+            else:
+                model(node, pair)
         torch.cuda.synchronize()
         time2 = time.time()
@@ -64,6 +71,26 @@ def _build_autochunk(model, max_memory, node, pair):
     return gm
 
 
+def _build_openfold():
+    model = EvoformerBlock(
+        c_m=256,
+        c_z=128,
+        c_hidden_msa_att=32,
+        c_hidden_opm=32,
+        c_hidden_mul=128,
+        c_hidden_pair_att=32,
+        no_heads_msa=8,
+        no_heads_pair=4,
+        transition_n=4,
+        msa_dropout=0.15,
+        pair_dropout=0.15,
+        inf=1e4,
+        eps=1e-4,
+        is_multimer=False,
+    ).cuda()
+    return model
+
+
 def benchmark_evoformer():
     # init data and model
     msa_len = 300
@@ -74,10 +101,14 @@ def benchmark_evoformer():
 
     # build autochunk model
     max_memory = 3000  # MB
-    autochunk = _build_autochunk(model, max_memory, node, pair)
+    autochunk = _build_autochunk(evoformer_base().cuda(), max_memory, node, pair)
+
+    # build openfold
+    openfold = _build_openfold()
 
     # benchmark
-    _benchmark_evoformer(model, node, pair, "openfold")
+    _benchmark_evoformer(model, node, pair, "base")
+    _benchmark_evoformer(openfold, node, pair, "openfold", chunk_size=4)
     _benchmark_evoformer(autochunk, node, pair, "autochunk")
 
 
diff --git a/evoformer_openfold/evoformer.py b/evoformer_openfold/evoformer.py
deleted file mode 100644
index cfd2bb2a2..000000000
--- a/evoformer_openfold/evoformer.py
+++ /dev/null
@@ -1,59 +0,0 @@
-import torch
-import torch.nn as nn
-
-from .msa import MSAStack
-from .ops import OutProductMean
-from .triangle import PairStack
-
-
-def print_memory(init_mem, text=None):
-    now_mem = torch.cuda.memory_allocated() / 1024 ** 2 - init_mem
-    max_mem = torch.cuda.max_memory_allocated() / 1024 ** 2 - init_mem
-    print("%s now:%.2f max:%.2f" % ("" if text is None else text, now_mem, max_mem))
-    torch.cuda.reset_peak_memory_stats()
-
-
-class EvoformerBlock(nn.Module):
-
-    def __init__(self, d_node, d_pair):
-        super(EvoformerBlock, self).__init__()
-
-        self.msa_stack = MSAStack(d_node, d_pair, p_drop=0.15)
-        self.communication = OutProductMean(n_feat=d_node, n_feat_out=d_pair, n_feat_proj=32)
-        self.pair_stack = PairStack(d_pair=d_pair)
-
-    def forward(self, node, pair):
-        node = self.msa_stack(node, pair)
-        pair = pair + self.communication(node)
-        pair = self.pair_stack(pair)
-        return node, pair
-
-
-class Evoformer(nn.Module):
-
-    def __init__(self, d_node, d_pair):
-        super(Evoformer, self).__init__()
-
-        self.blocks = nn.ModuleList()
-        for _ in range(1):
-            self.blocks.append(EvoformerBlock(d_node, d_pair))
-
-    def forward(self, node, pair):
-        for b in self.blocks:
-            node, pair = b(node, pair)
-        return node, pair
-
-
-def evoformer_tiny():
-    return Evoformer(d_node=64, d_pair=32)
-
-
-def evoformer_base():
-    return Evoformer(d_node=256, d_pair=128)
-
-
-def evoformer_large():
-    return Evoformer(d_node=512, d_pair=256)
-
-
-__all__ = ['Evoformer', 'evoformer_base', 'evoformer_large']
diff --git a/evoformer_openfold/initializer.py b/evoformer_openfold/initializer.py
deleted file mode 100755
index c6ce0659e..000000000
--- a/evoformer_openfold/initializer.py
+++ /dev/null
@@ -1,29 +0,0 @@
-import math
-
-import numpy as np
-import torch.nn as nn
-
-
-def glorot_uniform_af(x, gain=1.0):
-    """
-    initialize tensors the same as xavier_initializer in PyTorch, but the dimensions are different:
-    In PyTorch:
-        [feature_out, feature_in, n_head ...]
-    In Jax:
-        [... n_head, feature_in, feature_out]
-    However, there is a feature in original Alphafold2 code that they use the Jax version initializer to initialize tensors like:
-        [feature_in, n_head, feature_out]
-
-    In this function, we keep this feature to initialize [feature_in, n_head, ..., feature_out] tensors
-    """
-    fan_in, fan_out = x.shape[-2:]
-    if len(x.shape) > 2:
-        receptive_field_size = np.prod(x.shape[:-2])
-        fan_in *= receptive_field_size
-        fan_out *= receptive_field_size
-    std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
-    dev = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation
-
-    nn.init.uniform_(x, -dev, dev)
-
-    return x
diff --git a/evoformer_openfold/kernel.py b/evoformer_openfold/kernel.py
deleted file mode 100644
index 26ab5dc53..000000000
--- a/evoformer_openfold/kernel.py
+++ /dev/null
@@ -1,19 +0,0 @@
-import torch
-import torch.nn.functional as F
-
-
-def bias_sigmod_ele(y, bias, z):
-    return torch.sigmoid(y + bias) * z
-
-
-def bias_dropout_add(x: torch.Tensor, bias: torch.Tensor, dropmask: torch.Tensor,
-                     residual: torch.Tensor, prob: float) -> torch.Tensor:
-    out = (x + bias) * F.dropout(dropmask, p=prob, training=False)
-    out = residual + out
-    return out
-
-
-def bias_ele_dropout_residual(ab: torch.Tensor, b: torch.Tensor, g: torch.Tensor,
-                              dropout_mask: torch.Tensor, Z_raw: torch.Tensor,
-                              prob: float) -> torch.Tensor:
-    return Z_raw + F.dropout(dropout_mask, p=prob, training=True) * (g * (ab + b))
\ No newline at end of file
diff --git a/evoformer_openfold/msa.py b/evoformer_openfold/msa.py
deleted file mode 100644
index cac456638..000000000
--- a/evoformer_openfold/msa.py
+++ /dev/null
@@ -1,95 +0,0 @@
-import math
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from einops import rearrange
-from torch.nn import LayerNorm
-
-from .kernel import bias_dropout_add
-from .ops import SelfAttention, Transition
-
-
-class MSARowAttentionWithPairBias(nn.Module):
-
-    def __init__(self, d_node, d_pair, c=32, n_head=8, p_drop=0.15):
-        super(MSARowAttentionWithPairBias, self).__init__()
-        self.d_node = d_node
-        self.d_pair = d_pair
-        self.c = c
-        self.n_head = n_head
-        self.p_drop = p_drop
-
-        self.layernormM = LayerNorm(d_node)
-        self.layernormZ = LayerNorm(d_pair)
-
-        _init_weights = torch.nn.init.normal_(torch.zeros([n_head, d_pair]),
-                                              std=1.0 / math.sqrt(d_pair))
-        self.linear_b_weights = nn.parameter.Parameter(data=_init_weights, requires_grad=True)
-
-        self.attention = SelfAttention(qkv_dim=d_node,
-                                       c=c,
-                                       n_head=n_head,
-                                       out_dim=d_node,
-                                       gating=True,
-                                       last_bias_fuse=True)
-
-        self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_node,)), requires_grad=True)
-
-    def forward(self, M_raw, Z):
-        ## Input projections
-        M = self.layernormM(M_raw)
-        Z = self.layernormZ(Z)
-        b = F.linear(Z, self.linear_b_weights)
-        b = b.permute(0, 3, 1, 2)
-        # b = rearrange(b, 'b q k h -> b h q k')
-
-        M = self.attention(M, b)
-        dropout_mask = torch.ones_like(M[:, 0:1, :, :]).to(M.device).to(M.dtype)
-
-        return bias_dropout_add(M, self.out_bias, dropout_mask, M_raw, prob=self.p_drop)
-
-
-class MSAColumnAttention(nn.Module):
-
-    def __init__(self, d_node, c=32, n_head=8):
-        super(MSAColumnAttention, self).__init__()
-        self.d_node = d_node
-        self.c = c
-        self.n_head = n_head
-
-        self.layernormM = LayerNorm(d_node)
-        self.attention = SelfAttention(qkv_dim=d_node,
-                                       c=c,
-                                       n_head=n_head,
-                                       out_dim=d_node,
-                                       gating=True)
-
-    def forward(self, M_raw):
-        M = M_raw.transpose(-2, -3)
-        M = self.layernormM(M)
-
-        M = self.attention(M)
-
-        M = M.transpose(-2, -3)
-        return M_raw + M
-
-
-class MSAStack(nn.Module):
-
-    def __init__(self, d_node, d_pair, p_drop=0.15):
-        super(MSAStack, self).__init__()
-
-        self.MSARowAttentionWithPairBias = MSARowAttentionWithPairBias(d_node=d_node,
-                                                                       d_pair=d_pair,
-                                                                       p_drop=p_drop)
-
-        self.MSAColumnAttention = MSAColumnAttention(d_node=d_node)
-        self.MSATransition = Transition(d=d_node)
-
-    def forward(self, node, pair):
-        node = self.MSARowAttentionWithPairBias(node, pair)
-        node = self.MSAColumnAttention(node)
-        node = self.MSATransition(node)
-
-        return node
diff --git a/evoformer_openfold/ops.py b/evoformer_openfold/ops.py
deleted file mode 100755
index 611b7b0fe..000000000
--- a/evoformer_openfold/ops.py
+++ /dev/null
@@ -1,176 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from einops import rearrange
-from torch.nn import LayerNorm
-
-from .initializer import glorot_uniform_af
-from .kernel import bias_sigmod_ele
-
-
-class DropoutRowwise(nn.Module):
-
-    def __init__(self, p):
-        super(DropoutRowwise, self).__init__()
-        self.p = p
-        self.dropout = nn.Dropout(p=p)
-
-    def forward(self, x):
-        dropout_mask = torch.ones_like(x[:, 0:1, :, :])
-        dropout_mask = self.dropout(dropout_mask)
-        return dropout_mask * x
-
-
-class DropoutColumnwise(nn.Module):
-
-    def __init__(self, p):
-        super(DropoutColumnwise, self).__init__()
-        self.p = p
-        self.dropout = nn.Dropout(p=p)
-
-    def forward(self, x):
-        dropout_mask = torch.ones_like(x[:, :, 0:1, :])
-        dropout_mask = self.dropout(dropout_mask)
-        return dropout_mask * x
-
-
-class Transition(nn.Module):
-
-    def __init__(self, d, n=4):
-        super(Transition, self).__init__()
-        self.norm = LayerNorm(d)
-        self.linear1 = Linear(d, n * d, initializer='relu')
-        self.linear2 = Linear(n * d, d, initializer='zeros')
-
-    def forward(self, src):
-        x = self.norm(src)
-        x = self.linear2(F.relu(self.linear1(x)))
-        return src + x
-
-
-class OutProductMean(nn.Module):
-
-    def __init__(self, n_feat=64, n_feat_out=128, n_feat_proj=32):
-        super(OutProductMean, self).__init__()
-
-        self.layernormM = LayerNorm(n_feat)
-        self.linear_a = Linear(n_feat, n_feat_proj)
-        self.linear_b = Linear(n_feat, n_feat_proj)
-
-        self.o_linear = Linear(n_feat_proj * n_feat_proj,
-                               n_feat_out,
-                               initializer='zero',
-                               use_bias=True)
-
-    def forward(self, M):
-        M = self.layernormM(M)
-        left_act = self.linear_a(M)
-        right_act = self.linear_b(M)
-
-        O = torch.einsum('bsid,bsje->bijde', left_act, right_act).contiguous()
-        # O = rearrange(O, 'b i j d e -> b i j (d e)')
-        O = O.reshape(O.shape[0], O.shape[1], O.shape[2], -1)
-        Z = self.o_linear(O)
-
-        return Z
-
-
-class Linear(nn.Linear):
-    """
-    A Linear layer with built-in nonstandard initializations. Called just
-    like torch.nn.Linear.
-    Implements the initializers in 1.11.4, plus some additional ones found
-    in the code.
-    """
-
-    def __init__(
-        self,
-        feature_in: int,
-        feature_out: int,
-        initializer: str = 'linear',
-        use_bias: bool = True,
-        bias_init: float = 0.,
-    ):
-        super(Linear, self).__init__(feature_in, feature_out, bias=use_bias)
-
-        self.use_bias = use_bias
-        if initializer == 'linear':
-            glorot_uniform_af(self.weight, gain=1.0)
-        elif initializer == 'relu':
-            glorot_uniform_af(self.weight, gain=2.0)
-        elif initializer == 'zeros':
-            nn.init.zeros_(self.weight)
-        if self.use_bias:
-            with torch.no_grad():
-                self.bias.fill_(bias_init)
-
-
-class SelfAttention(nn.Module):
-    """
-    Multi-Head SelfAttention dealing with [batch_size1, batch_size2, len, dim] tensors
-    """
-
-    def __init__(self, qkv_dim, c, n_head, out_dim, gating=True, last_bias_fuse=False):
-        super(SelfAttention, self).__init__()
-        self.qkv_dim = qkv_dim
-        self.c = c
-        self.n_head = n_head
-        self.out_dim = out_dim
-        self.gating = gating
-        self.last_bias_fuse = last_bias_fuse
-
-        self.scaling = self.c**(-0.5)
-
-        # self.to_qkv = Linear(qkv_dim, 3 * n_head * c, initializer='linear')
-        self.to_q = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False)
-        self.to_k = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False)
-        self.to_v = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False)
-
-        if gating:
-            self.gating_bias = nn.parameter.Parameter(data=torch.ones((n_head * c,)))
-            self.gating_linear = Linear(qkv_dim, n_head * c, initializer='zero', use_bias=False)
-
-        self.o_linear = Linear(n_head * c,
-                               out_dim,
-                               initializer='zero',
-                               use_bias=(not last_bias_fuse))
-
-    def forward(self, in_data, nonbatched_bias=None):
-        """
-        :param in_data: [batch_size1, batch_size2, len_qkv, qkv_dim]
-        :param bias: None or [batch_size1, batch_size2, n_head, len_q, len_kv]
-        :param nonbatched_bias: None or [batch_size1, n_head, len_q, len_kv]
-        """
-
-        # qkv = self.to_qkv(in_data).chunk(3, dim=-1)
-        # q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), qkv)
-
-        q = self.to_q(in_data)
-        k = self.to_k(in_data)
-        v = self.to_v(in_data)
-
-        # q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head),
-        #               [q, k, v])
-        q, k, v = map(lambda t: t.view(t.shape[0], t.shape[1], t.shape[2], self.n_head, -1).permute(0, 1, 3, 2, 4),
-                      [q, k, v])
-
-        q = q * self.scaling
-
-        logits = torch.matmul(q, k.transpose(-1, -2))
-
-        if nonbatched_bias is not None:
-            logits += nonbatched_bias.unsqueeze(1)
-        weights = torch.softmax(logits, dim=-1)
-        # weights = softmax(logits)
-
-        weighted_avg = torch.matmul(weights, v)
-        # weighted_avg = rearrange(weighted_avg, 'b1 b2 h n d -> b1 b2 n (h d)')
-        weighted_avg = weighted_avg.permute(0, 1, 3, 2, 4)
-        weighted_avg = weighted_avg.reshape(weighted_avg.shape[0], weighted_avg.shape[1], weighted_avg.shape[2], -1)
-
-        if self.gating:
-            gate_values = self.gating_linear(in_data)
-            weighted_avg = bias_sigmod_ele(gate_values, self.gating_bias, weighted_avg)
-
-        output = self.o_linear(weighted_avg)
-        return output
diff --git a/evoformer_openfold/triangle.py b/evoformer_openfold/triangle.py
deleted file mode 100644
index f479469c3..000000000
--- a/evoformer_openfold/triangle.py
+++ /dev/null
@@ -1,192 +0,0 @@
-import math
-
-import torch
-import torch.nn as nn
-from torch.nn import LayerNorm
-
-from .kernel import bias_dropout_add, bias_ele_dropout_residual
-from .ops import Linear, SelfAttention, Transition
-
-
-def permute_final_dims(tensor, inds):
-    zero_index = -1 * len(inds)
-    first_inds = list(range(len(tensor.shape[:zero_index])))
-    return tensor.permute(first_inds + [zero_index + i for i in inds])
-
-
-class TriangleMultiplicationOutgoing(nn.Module):
-
-    def __init__(self, d_pair, p_drop, c=128):
-        super(TriangleMultiplicationOutgoing, self).__init__()
-        self.d_pair = d_pair
-        self.c = c
-
-        self.layernorm1 = LayerNorm(d_pair)
-        self.left_projection = Linear(d_pair, c)
-        self.right_projection = Linear(d_pair, c)
-        self.left_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)
-        self.right_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)
-
-        self.output_gate = Linear(d_pair, d_pair, initializer='zeros', bias_init=1.)
-        self.layernorm2 = LayerNorm(c)
-        self.output_projection = Linear(d_pair, d_pair, initializer='zeros', use_bias=False)
-        self.output_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)
-
-        self.p_drop = p_drop
-
-    def forward(self, Z_raw):
-        Z = self.layernorm1(Z_raw)
-        left_proj_act = self.left_projection(Z)
-        right_proj_act = self.right_projection(Z)
-
-        left_proj_act = left_proj_act * torch.sigmoid(self.left_gate(Z))
-        right_proj_act = right_proj_act * torch.sigmoid(self.right_gate(Z))
-
-        g = torch.sigmoid(self.output_gate(Z))
-        # p = torch.matmul(
-        #     permute_final_dims(left_proj_act, (2, 0, 1)),
-        #     permute_final_dims(right_proj_act, (2, 1, 0)),
-        # )
-        # ab = permute_final_dims(p, (1, 2, 0))
-
-        ab = torch.einsum('bikd,bjkd->bijd', left_proj_act, right_proj_act)
-        ab = self.output_projection(self.layernorm2(ab))
-        dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype)
-        return bias_ele_dropout_residual(ab,
-                                         self.output_bias,
-                                         g,
-                                         dropout_mask,
-                                         Z_raw,
-                                         prob=self.p_drop)
-
-
-class TriangleMultiplicationIncoming(nn.Module):
-
-    def __init__(self, d_pair, p_drop, c=128):
-        super(TriangleMultiplicationIncoming, self).__init__()
-        self.d_pair = d_pair
-        self.c = c
-
-        self.layernorm1 = LayerNorm(d_pair)
-        self.left_projection = Linear(d_pair, c)
-        self.right_projection = Linear(d_pair, c)
-        self.left_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)
-        self.right_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)
-
-        self.output_gate = Linear(d_pair, d_pair, initializer='zeros', bias_init=1.)
-        self.layernorm2 = LayerNorm(c)
-        self.output_projection = Linear(d_pair, d_pair, initializer='zeros', use_bias=False)
-        self.output_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)
-
-        self.p_drop = p_drop
-
-    def forward(self, Z_raw):
-        Z = self.layernorm1(Z_raw)
-        left_proj_act = self.left_projection(Z)
-        right_proj_act = self.right_projection(Z)
-
-        left_proj_act = left_proj_act * torch.sigmoid(self.left_gate(Z))
-        right_proj_act = right_proj_act * torch.sigmoid(self.right_gate(Z))
-
-        g = torch.sigmoid(self.output_gate(Z))
-        # p = torch.matmul(
-        #     permute_final_dims(left_proj_act, (2, 1, 0)),
-        #     permute_final_dims(right_proj_act, (2, 0, 1)),
-        # )
-        # ab = permute_final_dims(p, (1, 2, 0))
-
-        ab = torch.einsum('bkid,bkjd->bijd', left_proj_act, right_proj_act)
-        ab = self.output_projection(self.layernorm2(ab))
-        dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype)
-        return bias_ele_dropout_residual(ab,
-                                         self.output_bias,
-                                         g,
-                                         dropout_mask,
-                                         Z_raw,
-                                         prob=self.p_drop)
-
-
-class TriangleAttentionStartingNode(nn.Module):
-
-    def __init__(self, d_pair, p_drop, c=32, n_head=4):
-        super(TriangleAttentionStartingNode, self).__init__()
-        self.d_pair = d_pair
-        self.c = c
-        self.n_head = n_head
-        self.p_drop = p_drop
-
-        self.layernorm1 = LayerNorm(d_pair)
-        _init_weights = torch.nn.init.normal_(torch.zeros([d_pair, n_head]),
-                                              std=1.0 / math.sqrt(d_pair))
-        self.linear_b_weights = nn.parameter.Parameter(data=_init_weights)
-        self.attention = SelfAttention(qkv_dim=d_pair,
-                                       c=c,
-                                       n_head=n_head,
-                                       out_dim=d_pair,
-                                       gating=True,
-                                       last_bias_fuse=True)
-
-        self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)
-
-    def forward(self, Z_raw):
-        Z = self.layernorm1(Z_raw)
-        b = torch.einsum('bqkc,ch->bhqk', Z, self.linear_b_weights)
-
-        Z = self.attention(Z, b)
-
-        dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype)
-        return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop)
-
-
-class TriangleAttentionEndingNode(nn.Module):
-
-    def __init__(self, d_pair, p_drop, c=32, n_head=4):
-        super(TriangleAttentionEndingNode, self).__init__()
-        self.d_pair = d_pair
-        self.c = c
-        self.n_head = n_head
-        self.p_drop = p_drop
-
-        self.layernorm1 = LayerNorm(d_pair)
-        _init_weights = torch.nn.init.normal_(torch.zeros([d_pair, n_head]),
-                                              std=1.0 / math.sqrt(d_pair))
-        self.linear_b_weights = nn.parameter.Parameter(data=_init_weights)
-        self.attention = SelfAttention(qkv_dim=d_pair,
-                                       c=c,
-                                       n_head=n_head,
-                                       out_dim=d_pair,
-                                       gating=True,
-                                       last_bias_fuse=True)
-
-        self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)
-
-    def forward(self, Z_raw):
-        Z = Z_raw.transpose(-2, -3)
-        Z = self.layernorm1(Z)
-        b = torch.einsum('bqkc,ch->bhqk', Z, self.linear_b_weights)
-
-        Z = self.attention(Z, b)
-
-        Z = Z.transpose(-2, -3)
-        dropout_mask = torch.ones_like(Z[:, :, 0:1, :]).to(Z.device).to(Z.dtype)
-        return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop)
-
-
-class PairStack(nn.Module):
-
-    def __init__(self, d_pair, p_drop=0.25):
-        super(PairStack, self).__init__()
-
-        self.TriangleMultiplicationOutgoing = TriangleMultiplicationOutgoing(d_pair, p_drop=p_drop)
-        self.TriangleMultiplicationIncoming = TriangleMultiplicationIncoming(d_pair, p_drop=p_drop)
-        self.TriangleAttentionStartingNode = TriangleAttentionStartingNode(d_pair, p_drop=p_drop)
-        self.TriangleAttentionEndingNode = TriangleAttentionEndingNode(d_pair, p_drop=p_drop)
-        self.PairTransition = Transition(d=d_pair)
-
-    def forward(self, pair):
-        pair = self.TriangleMultiplicationOutgoing(pair)
-        pair = self.TriangleMultiplicationIncoming(pair)
-        pair = self.TriangleAttentionStartingNode(pair)
-        pair = self.TriangleAttentionEndingNode(pair)
-        pair = self.PairTransition(pair)
-        return pair
diff --git a/openfold/evoformer.py b/openfold/evoformer.py
index 7fbcd8a76..ffd4c9829 100644
--- a/openfold/evoformer.py
+++ b/openfold/evoformer.py
@@ -284,104 +284,6 @@ class EvoformerBlock(nn.Module):
         return m, z
 
 
-class ExtraMSABlock(nn.Module):
-    """
-    Almost identical to the standard EvoformerBlock, except in that the
-    ExtraMSABlock uses GlobalAttention for MSA column attention and
-    requires more fine-grained control over checkpointing. Separated from
-    its twin to preserve the TorchScript-ability of the latter.
-    """
-    def __init__(self,
-        c_m: int,
-        c_z: int,
-        c_hidden_msa_att: int,
-        c_hidden_opm: int,
-        c_hidden_mul: int,
-        c_hidden_pair_att: int,
-        no_heads_msa: int,
-        no_heads_pair: int,
-        transition_n: int,
-        msa_dropout: float,
-        pair_dropout: float,
-        inf: float,
-        eps: float,
-        ckpt: bool,
-        is_multimer: bool,
-    ):
-        super(ExtraMSABlock, self).__init__()
-
-        self.ckpt = ckpt
-
-        self.msa_att_row = MSARowAttentionWithPairBias(
-            c_m=c_m,
-            c_z=c_z,
-            c_hidden=c_hidden_msa_att,
-            no_heads=no_heads_msa,
-            inf=inf,
-        )
-
-        self.msa_att_col = MSAColumnGlobalAttention(
-            c_in=c_m,
-            c_hidden=c_hidden_msa_att,
-            no_heads=no_heads_msa,
-            inf=inf,
-            eps=eps,
-        )
-
-        self.msa_dropout_layer = DropoutRowwise(msa_dropout)
-
-        self.core = EvoformerBlockCore(
-            c_m=c_m,
-            c_z=c_z,
-            c_hidden_opm=c_hidden_opm,
-            c_hidden_mul=c_hidden_mul,
-            c_hidden_pair_att=c_hidden_pair_att,
-            no_heads_msa=no_heads_msa,
-            no_heads_pair=no_heads_pair,
-            transition_n=transition_n,
-            pair_dropout=pair_dropout,
-            inf=inf,
-            eps=eps,
-        )
-        self.is_multimer = is_multimer
-
-    def forward(self,
-        m: torch.Tensor,
-        z: torch.Tensor,
-        msa_mask: torch.Tensor,
-        pair_mask: torch.Tensor,
-        chunk_size: Optional[int] = None,
-        _chunk_logits: Optional[int] = 1024,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        m = m + self.msa_dropout_layer(
-            self.msa_att_row(
-                m.clone(),
-                z=z.clone(),
-                mask=msa_mask,
-                chunk_size=chunk_size,
-                _chunk_logits=_chunk_logits if torch.is_grad_enabled() else None,
-                _checkpoint_chunks=
-                    self.ckpt if torch.is_grad_enabled() else False,
-            )
-        )
-
-        def fn(m, z):
-            m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size)
-            m, z = self.core(
-                m, z, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size
-            )
-
-            return m, z
-
-        if(torch.is_grad_enabled() and self.ckpt):
-            checkpoint_fn = get_checkpoint_fn()
-            m, z = checkpoint_fn(fn, m, z)
-        else:
-            m, z = fn(m, z)
-
-        return m, z
-
-
 class EvoformerStack(nn.Module):
     """
     Main Evoformer trunk.
@@ -527,99 +429,3 @@ class EvoformerStack(nn.Module):
         s = self.linear(m[..., 0, :, :])
 
         return m, z, s
-
-
-class ExtraMSAStack(nn.Module):
-    """
-    Implements Algorithm 18.
-    """
-
-    def __init__(self,
-        c_m: int,
-        c_z: int,
-        c_hidden_msa_att: int,
-        c_hidden_opm: int,
-        c_hidden_mul: int,
-        c_hidden_pair_att: int,
-        no_heads_msa: int,
-        no_heads_pair: int,
-        no_blocks: int,
-        transition_n: int,
-        msa_dropout: float,
-        pair_dropout: float,
-        inf: float,
-        eps: float,
-        ckpt: bool,
-        clear_cache_between_blocks: bool = False,
-        is_multimer: bool = False,
-        **kwargs,
-    ):
-        super(ExtraMSAStack, self).__init__()
-
-        self.clear_cache_between_blocks = clear_cache_between_blocks
-        self.blocks = nn.ModuleList()
-        for _ in range(no_blocks):
-            block = ExtraMSABlock(
-                c_m=c_m,
-                c_z=c_z,
-                c_hidden_msa_att=c_hidden_msa_att,
-                c_hidden_opm=c_hidden_opm,
-                c_hidden_mul=c_hidden_mul,
-                c_hidden_pair_att=c_hidden_pair_att,
-                no_heads_msa=no_heads_msa,
-                no_heads_pair=no_heads_pair,
-                transition_n=transition_n,
-                msa_dropout=msa_dropout,
-                pair_dropout=pair_dropout,
-                inf=inf,
-                eps=eps,
-                ckpt=ckpt,
-                is_multimer=is_multimer,
-            )
-            self.blocks.append(block)
-
-    def forward(self,
-        m: torch.Tensor,
-        z: torch.Tensor,
-        chunk_size: int,
-        msa_mask: Optional[torch.Tensor] = None,
-        pair_mask: Optional[torch.Tensor] = None,
-        _mask_trans: bool = True,
-    ) -> torch.Tensor:
-        """
-        Args:
-            m:
-                [*, N_extra, N_res, C_m] extra MSA embedding
-            z:
-                [*, N_res, N_res, C_z] pair embedding
-            msa_mask:
-                Optional [*, N_extra, N_res] MSA mask
-            pair_mask:
-                Optional [*, N_res, N_res] pair mask
-        Returns:
-            [*, N_res, N_res, C_z] pair update
-        """
-        #checkpoint_fn = get_checkpoint_fn()
-        #blocks = [
-        #    partial(b, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size, _chunk_logits=None) for b in self.blocks
-        #]
-
-        #def dodo(b, *args):
-        #    torch.cuda.empty_cache()
-        #    return b(*args)
-
-        #blocks = [partial(dodo, b) for b in blocks]
-
-        #for b in blocks:
-        #    if(torch.is_grad_enabled()):
-        #        m, z = checkpoint_fn(b, *(m, z))
-        #    else:
-        #        m, z = b(m, z)
-
-        for b in self.blocks:
-            m, z = b(m, z, msa_mask, pair_mask, chunk_size=chunk_size)
-
-            if(self.clear_cache_between_blocks):
-                torch.cuda.empty_cache()
-
-        return z
\ No newline at end of file