diff --git a/evoformer_openfold/evoformer.py b/evoformer_openfold/evoformer.py deleted file mode 100644 index cfd2bb2a2..000000000 --- a/evoformer_openfold/evoformer.py +++ /dev/null @@ -1,59 +0,0 @@ -import torch -import torch.nn as nn - -from .msa import MSAStack -from .ops import OutProductMean -from .triangle import PairStack - - -def print_memory(init_mem, text=None): - now_mem = torch.cuda.memory_allocated() / 1024 ** 2 - init_mem - max_mem = torch.cuda.max_memory_allocated() / 1024 ** 2 - init_mem - print("%s now:%.2f max:%.2f" % ("" if text is None else text, now_mem, max_mem)) - torch.cuda.reset_peak_memory_stats() - - -class EvoformerBlock(nn.Module): - - def __init__(self, d_node, d_pair): - super(EvoformerBlock, self).__init__() - - self.msa_stack = MSAStack(d_node, d_pair, p_drop=0.15) - self.communication = OutProductMean(n_feat=d_node, n_feat_out=d_pair, n_feat_proj=32) - self.pair_stack = PairStack(d_pair=d_pair) - - def forward(self, node, pair): - node = self.msa_stack(node, pair) - pair = pair + self.communication(node) - pair = self.pair_stack(pair) - return node, pair - - -class Evoformer(nn.Module): - - def __init__(self, d_node, d_pair): - super(Evoformer, self).__init__() - - self.blocks = nn.ModuleList() - for _ in range(1): - self.blocks.append(EvoformerBlock(d_node, d_pair)) - - def forward(self, node, pair): - for b in self.blocks: - node, pair = b(node, pair) - return node, pair - - -def evoformer_tiny(): - return Evoformer(d_node=64, d_pair=32) - - -def evoformer_base(): - return Evoformer(d_node=256, d_pair=128) - - -def evoformer_large(): - return Evoformer(d_node=512, d_pair=256) - - -__all__ = ['Evoformer', 'evoformer_base', 'evoformer_large'] diff --git a/evoformer_openfold/initializer.py b/evoformer_openfold/initializer.py deleted file mode 100755 index c6ce0659e..000000000 --- a/evoformer_openfold/initializer.py +++ /dev/null @@ -1,29 +0,0 @@ -import math - -import numpy as np -import torch.nn as nn - - -def glorot_uniform_af(x, gain=1.0): - """ - initialize tensors the same as xavier_initializer in PyTorch, but the dimensions are different: - In PyTorch: - [feature_out, feature_in, n_head ...] - In Jax: - [... 
n_head, feature_in, feature_out] - However, there is a feature in original Alphafold2 code that they use the Jax version initializer to initialize tensors like: - [feature_in, n_head, feature_out] - - In this function, we keep this feature to initialize [feature_in, n_head, ..., feature_out] tensors - """ - fan_in, fan_out = x.shape[-2:] - if len(x.shape) > 2: - receptive_field_size = np.prod(x.shape[:-2]) - fan_in *= receptive_field_size - fan_out *= receptive_field_size - std = gain * math.sqrt(2.0 / float(fan_in + fan_out)) - dev = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation - - nn.init.uniform_(x, -dev, dev) - - return x diff --git a/evoformer_openfold/kernel.py b/evoformer_openfold/kernel.py deleted file mode 100644 index 26ab5dc53..000000000 --- a/evoformer_openfold/kernel.py +++ /dev/null @@ -1,19 +0,0 @@ -import torch -import torch.nn.functional as F - - -def bias_sigmod_ele(y, bias, z): - return torch.sigmoid(y + bias) * z - - -def bias_dropout_add(x: torch.Tensor, bias: torch.Tensor, dropmask: torch.Tensor, - residual: torch.Tensor, prob: float) -> torch.Tensor: - out = (x + bias) * F.dropout(dropmask, p=prob, training=False) - out = residual + out - return out - - -def bias_ele_dropout_residual(ab: torch.Tensor, b: torch.Tensor, g: torch.Tensor, - dropout_mask: torch.Tensor, Z_raw: torch.Tensor, - prob: float) -> torch.Tensor: - return Z_raw + F.dropout(dropout_mask, p=prob, training=True) * (g * (ab + b)) \ No newline at end of file diff --git a/evoformer_openfold/msa.py b/evoformer_openfold/msa.py deleted file mode 100644 index cac456638..000000000 --- a/evoformer_openfold/msa.py +++ /dev/null @@ -1,95 +0,0 @@ -import math - -import torch -import torch.nn as nn -import torch.nn.functional as F -from einops import rearrange -from torch.nn import LayerNorm - -from .kernel import bias_dropout_add -from .ops import SelfAttention, Transition - - -class MSARowAttentionWithPairBias(nn.Module): - - def __init__(self, d_node, d_pair, c=32, n_head=8, p_drop=0.15): - super(MSARowAttentionWithPairBias, self).__init__() - self.d_node = d_node - self.d_pair = d_pair - self.c = c - self.n_head = n_head - self.p_drop = p_drop - - self.layernormM = LayerNorm(d_node) - self.layernormZ = LayerNorm(d_pair) - - _init_weights = torch.nn.init.normal_(torch.zeros([n_head, d_pair]), - std=1.0 / math.sqrt(d_pair)) - self.linear_b_weights = nn.parameter.Parameter(data=_init_weights, requires_grad=True) - - self.attention = SelfAttention(qkv_dim=d_node, - c=c, - n_head=n_head, - out_dim=d_node, - gating=True, - last_bias_fuse=True) - - self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_node,)), requires_grad=True) - - def forward(self, M_raw, Z): - ## Input projections - M = self.layernormM(M_raw) - Z = self.layernormZ(Z) - b = F.linear(Z, self.linear_b_weights) - b = b.permute(0, 3, 1, 2) - # b = rearrange(b, 'b q k h -> b h q k') - - M = self.attention(M, b) - dropout_mask = torch.ones_like(M[:, 0:1, :, :]).to(M.device).to(M.dtype) - - return bias_dropout_add(M, self.out_bias, dropout_mask, M_raw, prob=self.p_drop) - - -class MSAColumnAttention(nn.Module): - - def __init__(self, d_node, c=32, n_head=8): - super(MSAColumnAttention, self).__init__() - self.d_node = d_node - self.c = c - self.n_head = n_head - - self.layernormM = LayerNorm(d_node) - self.attention = SelfAttention(qkv_dim=d_node, - c=c, - n_head=n_head, - out_dim=d_node, - gating=True) - - def forward(self, M_raw): - M = M_raw.transpose(-2, -3) - M = self.layernormM(M) - - M = self.attention(M) - - M = 
M.transpose(-2, -3) - return M_raw + M - - -class MSAStack(nn.Module): - - def __init__(self, d_node, d_pair, p_drop=0.15): - super(MSAStack, self).__init__() - - self.MSARowAttentionWithPairBias = MSARowAttentionWithPairBias(d_node=d_node, - d_pair=d_pair, - p_drop=p_drop) - - self.MSAColumnAttention = MSAColumnAttention(d_node=d_node) - self.MSATransition = Transition(d=d_node) - - def forward(self, node, pair): - node = self.MSARowAttentionWithPairBias(node, pair) - node = self.MSAColumnAttention(node) - node = self.MSATransition(node) - - return node diff --git a/evoformer_openfold/ops.py b/evoformer_openfold/ops.py deleted file mode 100755 index 611b7b0fe..000000000 --- a/evoformer_openfold/ops.py +++ /dev/null @@ -1,176 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from einops import rearrange -from torch.nn import LayerNorm - -from .initializer import glorot_uniform_af -from .kernel import bias_sigmod_ele - - -class DropoutRowwise(nn.Module): - - def __init__(self, p): - super(DropoutRowwise, self).__init__() - self.p = p - self.dropout = nn.Dropout(p=p) - - def forward(self, x): - dropout_mask = torch.ones_like(x[:, 0:1, :, :]) - dropout_mask = self.dropout(dropout_mask) - return dropout_mask * x - - -class DropoutColumnwise(nn.Module): - - def __init__(self, p): - super(DropoutColumnwise, self).__init__() - self.p = p - self.dropout = nn.Dropout(p=p) - - def forward(self, x): - dropout_mask = torch.ones_like(x[:, :, 0:1, :]) - dropout_mask = self.dropout(dropout_mask) - return dropout_mask * x - - -class Transition(nn.Module): - - def __init__(self, d, n=4): - super(Transition, self).__init__() - self.norm = LayerNorm(d) - self.linear1 = Linear(d, n * d, initializer='relu') - self.linear2 = Linear(n * d, d, initializer='zeros') - - def forward(self, src): - x = self.norm(src) - x = self.linear2(F.relu(self.linear1(x))) - return src + x - - -class OutProductMean(nn.Module): - - def __init__(self, n_feat=64, n_feat_out=128, n_feat_proj=32): - super(OutProductMean, self).__init__() - - self.layernormM = LayerNorm(n_feat) - self.linear_a = Linear(n_feat, n_feat_proj) - self.linear_b = Linear(n_feat, n_feat_proj) - - self.o_linear = Linear(n_feat_proj * n_feat_proj, - n_feat_out, - initializer='zero', - use_bias=True) - - def forward(self, M): - M = self.layernormM(M) - left_act = self.linear_a(M) - right_act = self.linear_b(M) - - O = torch.einsum('bsid,bsje->bijde', left_act, right_act).contiguous() - # O = rearrange(O, 'b i j d e -> b i j (d e)') - O = O.reshape(O.shape[0], O.shape[1], O.shape[2], -1) - Z = self.o_linear(O) - - return Z - - -class Linear(nn.Linear): - """ - A Linear layer with built-in nonstandard initializations. Called just - like torch.nn.Linear. - Implements the initializers in 1.11.4, plus some additional ones found - in the code. 
- """ - - def __init__( - self, - feature_in: int, - feature_out: int, - initializer: str = 'linear', - use_bias: bool = True, - bias_init: float = 0., - ): - super(Linear, self).__init__(feature_in, feature_out, bias=use_bias) - - self.use_bias = use_bias - if initializer == 'linear': - glorot_uniform_af(self.weight, gain=1.0) - elif initializer == 'relu': - glorot_uniform_af(self.weight, gain=2.0) - elif initializer == 'zeros': - nn.init.zeros_(self.weight) - if self.use_bias: - with torch.no_grad(): - self.bias.fill_(bias_init) - - -class SelfAttention(nn.Module): - """ - Multi-Head SelfAttention dealing with [batch_size1, batch_size2, len, dim] tensors - """ - - def __init__(self, qkv_dim, c, n_head, out_dim, gating=True, last_bias_fuse=False): - super(SelfAttention, self).__init__() - self.qkv_dim = qkv_dim - self.c = c - self.n_head = n_head - self.out_dim = out_dim - self.gating = gating - self.last_bias_fuse = last_bias_fuse - - self.scaling = self.c**(-0.5) - - # self.to_qkv = Linear(qkv_dim, 3 * n_head * c, initializer='linear') - self.to_q = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False) - self.to_k = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False) - self.to_v = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False) - - if gating: - self.gating_bias = nn.parameter.Parameter(data=torch.ones((n_head * c,))) - self.gating_linear = Linear(qkv_dim, n_head * c, initializer='zero', use_bias=False) - - self.o_linear = Linear(n_head * c, - out_dim, - initializer='zero', - use_bias=(not last_bias_fuse)) - - def forward(self, in_data, nonbatched_bias=None): - """ - :param in_data: [batch_size1, batch_size2, len_qkv, qkv_dim] - :param bias: None or [batch_size1, batch_size2, n_head, len_q, len_kv] - :param nonbatched_bias: None or [batch_size1, n_head, len_q, len_kv] - """ - - # qkv = self.to_qkv(in_data).chunk(3, dim=-1) - # q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), qkv) - - q = self.to_q(in_data) - k = self.to_k(in_data) - v = self.to_v(in_data) - - # q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), - # [q, k, v]) - q, k, v = map(lambda t: t.view(t.shape[0], t.shape[1], t.shape[2], self.n_head, -1).permute(0, 1, 3, 2, 4), - [q, k, v]) - - q = q * self.scaling - - logits = torch.matmul(q, k.transpose(-1, -2)) - - if nonbatched_bias is not None: - logits += nonbatched_bias.unsqueeze(1) - weights = torch.softmax(logits, dim=-1) - # weights = softmax(logits) - - weighted_avg = torch.matmul(weights, v) - # weighted_avg = rearrange(weighted_avg, 'b1 b2 h n d -> b1 b2 n (h d)') - weighted_avg = weighted_avg.permute(0, 1, 3, 2, 4) - weighted_avg = weighted_avg.reshape(weighted_avg.shape[0], weighted_avg.shape[1], weighted_avg.shape[2], -1) - - if self.gating: - gate_values = self.gating_linear(in_data) - weighted_avg = bias_sigmod_ele(gate_values, self.gating_bias, weighted_avg) - - output = self.o_linear(weighted_avg) - return output diff --git a/evoformer_openfold/triangle.py b/evoformer_openfold/triangle.py deleted file mode 100644 index f479469c3..000000000 --- a/evoformer_openfold/triangle.py +++ /dev/null @@ -1,192 +0,0 @@ -import math - -import torch -import torch.nn as nn -from torch.nn import LayerNorm - -from .kernel import bias_dropout_add, bias_ele_dropout_residual -from .ops import Linear, SelfAttention, Transition - - -def permute_final_dims(tensor, inds): - zero_index = -1 * len(inds) - first_inds = list(range(len(tensor.shape[:zero_index]))) - return 
tensor.permute(first_inds + [zero_index + i for i in inds]) - - -class TriangleMultiplicationOutgoing(nn.Module): - - def __init__(self, d_pair, p_drop, c=128): - super(TriangleMultiplicationOutgoing, self).__init__() - self.d_pair = d_pair - self.c = c - - self.layernorm1 = LayerNorm(d_pair) - self.left_projection = Linear(d_pair, c) - self.right_projection = Linear(d_pair, c) - self.left_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.) - self.right_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.) - - self.output_gate = Linear(d_pair, d_pair, initializer='zeros', bias_init=1.) - self.layernorm2 = LayerNorm(c) - self.output_projection = Linear(d_pair, d_pair, initializer='zeros', use_bias=False) - self.output_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True) - - self.p_drop = p_drop - - def forward(self, Z_raw): - Z = self.layernorm1(Z_raw) - left_proj_act = self.left_projection(Z) - right_proj_act = self.right_projection(Z) - - left_proj_act = left_proj_act * torch.sigmoid(self.left_gate(Z)) - right_proj_act = right_proj_act * torch.sigmoid(self.right_gate(Z)) - - g = torch.sigmoid(self.output_gate(Z)) - # p = torch.matmul( - # permute_final_dims(left_proj_act, (2, 0, 1)), - # permute_final_dims(right_proj_act, (2, 1, 0)), - # ) - # ab = permute_final_dims(p, (1, 2, 0)) - - ab = torch.einsum('bikd,bjkd->bijd', left_proj_act, right_proj_act) - ab = self.output_projection(self.layernorm2(ab)) - dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype) - return bias_ele_dropout_residual(ab, - self.output_bias, - g, - dropout_mask, - Z_raw, - prob=self.p_drop) - - -class TriangleMultiplicationIncoming(nn.Module): - - def __init__(self, d_pair, p_drop, c=128): - super(TriangleMultiplicationIncoming, self).__init__() - self.d_pair = d_pair - self.c = c - - self.layernorm1 = LayerNorm(d_pair) - self.left_projection = Linear(d_pair, c) - self.right_projection = Linear(d_pair, c) - self.left_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.) - self.right_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.) - - self.output_gate = Linear(d_pair, d_pair, initializer='zeros', bias_init=1.) 
- self.layernorm2 = LayerNorm(c) - self.output_projection = Linear(d_pair, d_pair, initializer='zeros', use_bias=False) - self.output_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True) - - self.p_drop = p_drop - - def forward(self, Z_raw): - Z = self.layernorm1(Z_raw) - left_proj_act = self.left_projection(Z) - right_proj_act = self.right_projection(Z) - - left_proj_act = left_proj_act * torch.sigmoid(self.left_gate(Z)) - right_proj_act = right_proj_act * torch.sigmoid(self.right_gate(Z)) - - g = torch.sigmoid(self.output_gate(Z)) - # p = torch.matmul( - # permute_final_dims(left_proj_act, (2, 1, 0)), - # permute_final_dims(right_proj_act, (2, 0, 1)), - # ) - # ab = permute_final_dims(p, (1, 2, 0)) - - ab = torch.einsum('bkid,bkjd->bijd', left_proj_act, right_proj_act) - ab = self.output_projection(self.layernorm2(ab)) - dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype) - return bias_ele_dropout_residual(ab, - self.output_bias, - g, - dropout_mask, - Z_raw, - prob=self.p_drop) - - -class TriangleAttentionStartingNode(nn.Module): - - def __init__(self, d_pair, p_drop, c=32, n_head=4): - super(TriangleAttentionStartingNode, self).__init__() - self.d_pair = d_pair - self.c = c - self.n_head = n_head - self.p_drop = p_drop - - self.layernorm1 = LayerNorm(d_pair) - _init_weights = torch.nn.init.normal_(torch.zeros([d_pair, n_head]), - std=1.0 / math.sqrt(d_pair)) - self.linear_b_weights = nn.parameter.Parameter(data=_init_weights) - self.attention = SelfAttention(qkv_dim=d_pair, - c=c, - n_head=n_head, - out_dim=d_pair, - gating=True, - last_bias_fuse=True) - - self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True) - - def forward(self, Z_raw): - Z = self.layernorm1(Z_raw) - b = torch.einsum('bqkc,ch->bhqk', Z, self.linear_b_weights) - - Z = self.attention(Z, b) - - dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype) - return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop) - - -class TriangleAttentionEndingNode(nn.Module): - - def __init__(self, d_pair, p_drop, c=32, n_head=4): - super(TriangleAttentionEndingNode, self).__init__() - self.d_pair = d_pair - self.c = c - self.n_head = n_head - self.p_drop = p_drop - - self.layernorm1 = LayerNorm(d_pair) - _init_weights = torch.nn.init.normal_(torch.zeros([d_pair, n_head]), - std=1.0 / math.sqrt(d_pair)) - self.linear_b_weights = nn.parameter.Parameter(data=_init_weights) - self.attention = SelfAttention(qkv_dim=d_pair, - c=c, - n_head=n_head, - out_dim=d_pair, - gating=True, - last_bias_fuse=True) - - self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True) - - def forward(self, Z_raw): - Z = Z_raw.transpose(-2, -3) - Z = self.layernorm1(Z) - b = torch.einsum('bqkc,ch->bhqk', Z, self.linear_b_weights) - - Z = self.attention(Z, b) - - Z = Z.transpose(-2, -3) - dropout_mask = torch.ones_like(Z[:, :, 0:1, :]).to(Z.device).to(Z.dtype) - return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop) - - -class PairStack(nn.Module): - - def __init__(self, d_pair, p_drop=0.25): - super(PairStack, self).__init__() - - self.TriangleMultiplicationOutgoing = TriangleMultiplicationOutgoing(d_pair, p_drop=p_drop) - self.TriangleMultiplicationIncoming = TriangleMultiplicationIncoming(d_pair, p_drop=p_drop) - self.TriangleAttentionStartingNode = TriangleAttentionStartingNode(d_pair, p_drop=p_drop) - self.TriangleAttentionEndingNode = TriangleAttentionEndingNode(d_pair, p_drop=p_drop) - 
self.PairTransition = Transition(d=d_pair) - - def forward(self, pair): - pair = self.TriangleMultiplicationOutgoing(pair) - pair = self.TriangleMultiplicationIncoming(pair) - pair = self.TriangleAttentionStartingNode(pair) - pair = self.TriangleAttentionEndingNode(pair) - pair = self.PairTransition(pair) - return pair diff --git a/openfold/checkpointing.py b/openfold/checkpointing.py new file mode 100644 index 000000000..83e77c638 --- /dev/null +++ b/openfold/checkpointing.py @@ -0,0 +1,84 @@ +# Copyright 2021 AlQuraishi Laboratory +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.utils.checkpoint +from typing import Any, Tuple, List, Callable, Optional + + +BLOCK_ARG = Any +BLOCK_ARGS = List[BLOCK_ARG] + + +def get_checkpoint_fn(): + checkpoint = torch.utils.checkpoint.checkpoint + + return checkpoint + + +@torch.jit.ignore +def checkpoint_blocks( + blocks: List[Callable], + args: BLOCK_ARGS, + blocks_per_ckpt: Optional[int], +) -> BLOCK_ARGS: + """ + Chunk a list of blocks and run each chunk with activation + checkpointing. We define a "block" as a callable whose only inputs are + the outputs of the previous block. + + Implements Subsection 1.11.8 + + Args: + blocks: + List of blocks + args: + Tuple of arguments for the first block. + blocks_per_ckpt: + Size of each chunk. A higher value corresponds to fewer + checkpoints, and trades memory for speed. If None, no checkpointing + is performed. + Returns: + The output of the final block + """ + def wrap(a): + return (a,) if type(a) is not tuple else a + + def exec(b, a): + for block in b: + a = wrap(block(*a)) + return a + + def chunker(s, e): + def exec_sliced(*a): + return exec(blocks[s:e], a) + + return exec_sliced + + # Avoids mishaps when the blocks take just one argument + args = wrap(args) + + if blocks_per_ckpt is None: + return exec(blocks, args) + elif blocks_per_ckpt < 1 or blocks_per_ckpt > len(blocks): + raise ValueError("blocks_per_ckpt must be between 1 and len(blocks)") + + checkpoint = get_checkpoint_fn() + + for s in range(0, len(blocks), blocks_per_ckpt): + e = s + blocks_per_ckpt + args = checkpoint(chunker(s, e), *args) + args = wrap(args) + + return args diff --git a/openfold/dropout.py b/openfold/dropout.py new file mode 100644 index 000000000..651b9775e --- /dev/null +++ b/openfold/dropout.py @@ -0,0 +1,78 @@ +# Copyright 2021 AlQuraishi Laboratory +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
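# --- Hedged usage sketch (editor's illustration, not part of the original diff) -------
# checkpoint_blocks, defined in openfold/checkpointing.py above, treats each block as a
# callable whose only inputs are the previous block's outputs, groups the blocks into
# chunks of blocks_per_ckpt, and reruns each chunk under torch.utils.checkpoint during
# the backward pass. The ToyBlock module and the tensor shapes below are assumptions
# chosen only to show the calling convention.
import torch
import torch.nn as nn

from openfold.checkpointing import checkpoint_blocks


class ToyBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x, y):
        # Return a tuple so the next block receives (x, y) again.
        return self.linear(x), y


blocks = [ToyBlock(8) for _ in range(4)]
x = torch.randn(2, 8, requires_grad=True)
y = torch.randn(2, 8)

# blocks_per_ckpt=2 stores activations only at every second block boundary and
# recomputes the rest on backward; blocks_per_ckpt=None disables checkpointing.
x, y = checkpoint_blocks(blocks, args=(x, y), blocks_per_ckpt=2)
x.sum().backward()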
+ + +import torch +import torch.nn as nn +from functools import partialmethod +from typing import Union, List + + +class Dropout(nn.Module): + """ + Implementation of dropout with the ability to share the dropout mask + along a particular dimension. + + If not in training mode, this module computes the identity function. + """ + + def __init__(self, r: float, batch_dim: Union[int, List[int]]): + """ + Args: + r: + Dropout rate + batch_dim: + Dimension(s) along which the dropout mask is shared + """ + super(Dropout, self).__init__() + + self.r = r + if type(batch_dim) == int: + batch_dim = [batch_dim] + self.batch_dim = batch_dim + self.dropout = nn.Dropout(self.r) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: + Tensor to which dropout is applied. Can have any shape + compatible with self.batch_dim + """ + shape = list(x.shape) + if self.batch_dim is not None: + for bd in self.batch_dim: + shape[bd] = 1 + mask = x.new_ones(shape) + mask = self.dropout(mask) + x *= mask + return x + + +class DropoutRowwise(Dropout): + """ + Convenience class for rowwise dropout as described in subsection + 1.11.6. + """ + + __init__ = partialmethod(Dropout.__init__, batch_dim=-3) + + +class DropoutColumnwise(Dropout): + """ + Convenience class for columnwise dropout as described in subsection + 1.11.6. + """ + + __init__ = partialmethod(Dropout.__init__, batch_dim=-2) diff --git a/openfold/evoformer.py b/openfold/evoformer.py new file mode 100644 index 000000000..21e422b04 --- /dev/null +++ b/openfold/evoformer.py @@ -0,0 +1,636 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import torch +import torch.nn as nn +from typing import Tuple, Optional +from functools import partial + +from openfold.primitives import Linear, LayerNorm +from openfold.dropout import DropoutRowwise, DropoutColumnwise +from openfold.msa import ( + MSARowAttentionWithPairBias, + MSAColumnAttention, + MSAColumnGlobalAttention, +) +from openfold.outer_product_mean import OuterProductMean +from openfold.pair_transition import PairTransition +from openfold.triangular_attention import ( + TriangleAttentionStartingNode, + TriangleAttentionEndingNode, +) +from openfold.triangular_multiplicative_update import ( + TriangleMultiplicationOutgoing, + TriangleMultiplicationIncoming, +) +from openfold.checkpointing import checkpoint_blocks, get_checkpoint_fn +from openfold.tensor_utils import chunk_layer + + +class MSATransition(nn.Module): + """ + Feed-forward network applied to MSA activations after attention. 
+ + Implements Algorithm 9 + """ + def __init__(self, c_m, n): + """ + Args: + c_m: + MSA channel dimension + n: + Factor multiplied to c_m to obtain the hidden channel + dimension + """ + super(MSATransition, self).__init__() + + self.c_m = c_m + self.n = n + + self.layer_norm = LayerNorm(self.c_m) + self.linear_1 = Linear(self.c_m, self.n * self.c_m, init="relu") + self.relu = nn.ReLU() + self.linear_2 = Linear(self.n * self.c_m, self.c_m, init="final") + + def _transition(self, m, mask): + m = self.linear_1(m) + m = self.relu(m) + m = self.linear_2(m) * mask + return m + + @torch.jit.ignore + def _chunk(self, + m: torch.Tensor, + mask: torch.Tensor, + chunk_size: int, + ) -> torch.Tensor: + return chunk_layer( + self._transition, + {"m": m, "mask": mask}, + chunk_size=chunk_size, + no_batch_dims=len(m.shape[:-2]), + ) + + def forward( + self, + m: torch.Tensor, + mask: Optional[torch.Tensor] = None, + chunk_size: Optional[int] = None, + ) -> torch.Tensor: + """ + Args: + m: + [*, N_seq, N_res, C_m] MSA activation + mask: + [*, N_seq, N_res, C_m] MSA mask + Returns: + m: + [*, N_seq, N_res, C_m] MSA activation update + """ + + # DISCREPANCY: DeepMind forgets to apply the MSA mask here. + if mask is None: + mask = m.new_ones(m.shape[:-1]) + + # [*, N_seq, N_res, 1] + mask = mask.unsqueeze(-1) + + m = self.layer_norm(m) + + if chunk_size is not None: + m = self._chunk(m, mask, chunk_size) + else: + m = self._transition(m, mask) + + return m + + +class EvoformerBlockCore(nn.Module): + def __init__( + self, + c_m: int, + c_z: int, + c_hidden_opm: int, + c_hidden_mul: int, + c_hidden_pair_att: int, + no_heads_msa: int, + no_heads_pair: int, + transition_n: int, + pair_dropout: float, + inf: float, + eps: float, + _is_extra_msa_stack: bool = False, + is_multimer: bool = False, + ): + super(EvoformerBlockCore, self).__init__() + self.is_multimer = is_multimer + self.msa_transition = MSATransition( + c_m=c_m, + n=transition_n, + ) + + self.outer_product_mean = OuterProductMean( + c_m, + c_z, + c_hidden_opm, + ) + + self.tri_mul_out = TriangleMultiplicationOutgoing( + c_z, + c_hidden_mul, + ) + self.tri_mul_in = TriangleMultiplicationIncoming( + c_z, + c_hidden_mul, + ) + + self.tri_att_start = TriangleAttentionStartingNode( + c_z, + c_hidden_pair_att, + no_heads_pair, + inf=inf, + ) + self.tri_att_end = TriangleAttentionEndingNode( + c_z, + c_hidden_pair_att, + no_heads_pair, + inf=inf, + ) + + self.pair_transition = PairTransition( + c_z, + transition_n, + ) + + self.ps_dropout_row_layer = DropoutRowwise(pair_dropout) + self.ps_dropout_col_layer = DropoutColumnwise(pair_dropout) + + def forward( + self, + m: torch.Tensor, + z: torch.Tensor, + msa_mask: torch.Tensor, + pair_mask: torch.Tensor, + chunk_size: Optional[int] = None, + _mask_trans: bool = True, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # DeepMind doesn't mask these transitions in the source, so _mask_trans + # should be disabled to better approximate the exact activations of + # the original. 
+ msa_trans_mask = msa_mask if _mask_trans else None + pair_trans_mask = pair_mask if _mask_trans else None + + m = m + self.msa_transition( + m, mask=msa_trans_mask, chunk_size=chunk_size + ) + z = z + self.outer_product_mean( + m, mask=msa_mask, chunk_size=chunk_size + ) + z = z + self.ps_dropout_row_layer(self.tri_mul_out(z, mask=pair_mask)) + z = z + self.ps_dropout_row_layer(self.tri_mul_in(z, mask=pair_mask)) + z = z + self.ps_dropout_row_layer( + self.tri_att_start(z, mask=pair_mask, chunk_size=chunk_size) + ) + z = z + self.ps_dropout_col_layer( + self.tri_att_end(z, mask=pair_mask, chunk_size=chunk_size) + ) + z = z + self.pair_transition( + z, mask=pair_trans_mask, chunk_size=chunk_size + ) + + return m, z + + +class EvoformerBlock(nn.Module): + def __init__(self, + c_m: int, + c_z: int, + c_hidden_msa_att: int, + c_hidden_opm: int, + c_hidden_mul: int, + c_hidden_pair_att: int, + no_heads_msa: int, + no_heads_pair: int, + transition_n: int, + msa_dropout: float, + pair_dropout: float, + inf: float, + eps: float, + is_multimer: bool, + ): + super(EvoformerBlock, self).__init__() + + self.msa_att_row = MSARowAttentionWithPairBias( + c_m=c_m, + c_z=c_z, + c_hidden=c_hidden_msa_att, + no_heads=no_heads_msa, + inf=inf, + ) + + self.msa_att_col = MSAColumnAttention( + c_m, + c_hidden_msa_att, + no_heads_msa, + inf=inf, + ) + + self.msa_dropout_layer = DropoutRowwise(msa_dropout) + + self.core = EvoformerBlockCore( + c_m=c_m, + c_z=c_z, + c_hidden_opm=c_hidden_opm, + c_hidden_mul=c_hidden_mul, + c_hidden_pair_att=c_hidden_pair_att, + no_heads_msa=no_heads_msa, + no_heads_pair=no_heads_pair, + transition_n=transition_n, + pair_dropout=pair_dropout, + inf=inf, + eps=eps, + ) + + self.outer_product_mean = OuterProductMean( + c_m, + c_z, + c_hidden_opm, + ) + self.is_multimer = is_multimer + + def forward(self, + m: torch.Tensor, + z: torch.Tensor, + msa_mask: torch.Tensor, + pair_mask: torch.Tensor, + chunk_size: Optional[int] = None, + _mask_trans: bool = True, + ) -> Tuple[torch.Tensor, torch.Tensor]: + m = m + self.msa_dropout_layer( + self.msa_att_row(m, z=z, mask=msa_mask, chunk_size=chunk_size) + ) + m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size) + m, z = self.core( + m, + z, + msa_mask=msa_mask, + pair_mask=pair_mask, + chunk_size=chunk_size, + _mask_trans=_mask_trans, + ) + + return m, z + + +class ExtraMSABlock(nn.Module): + """ + Almost identical to the standard EvoformerBlock, except in that the + ExtraMSABlock uses GlobalAttention for MSA column attention and + requires more fine-grained control over checkpointing. Separated from + its twin to preserve the TorchScript-ability of the latter. 
+ """ + def __init__(self, + c_m: int, + c_z: int, + c_hidden_msa_att: int, + c_hidden_opm: int, + c_hidden_mul: int, + c_hidden_pair_att: int, + no_heads_msa: int, + no_heads_pair: int, + transition_n: int, + msa_dropout: float, + pair_dropout: float, + inf: float, + eps: float, + ckpt: bool, + is_multimer: bool, + ): + super(ExtraMSABlock, self).__init__() + + self.ckpt = ckpt + + self.msa_att_row = MSARowAttentionWithPairBias( + c_m=c_m, + c_z=c_z, + c_hidden=c_hidden_msa_att, + no_heads=no_heads_msa, + inf=inf, + ) + + self.msa_att_col = MSAColumnGlobalAttention( + c_in=c_m, + c_hidden=c_hidden_msa_att, + no_heads=no_heads_msa, + inf=inf, + eps=eps, + ) + + self.msa_dropout_layer = DropoutRowwise(msa_dropout) + + self.core = EvoformerBlockCore( + c_m=c_m, + c_z=c_z, + c_hidden_opm=c_hidden_opm, + c_hidden_mul=c_hidden_mul, + c_hidden_pair_att=c_hidden_pair_att, + no_heads_msa=no_heads_msa, + no_heads_pair=no_heads_pair, + transition_n=transition_n, + pair_dropout=pair_dropout, + inf=inf, + eps=eps, + ) + self.is_multimer = is_multimer + + def forward(self, + m: torch.Tensor, + z: torch.Tensor, + msa_mask: torch.Tensor, + pair_mask: torch.Tensor, + chunk_size: Optional[int] = None, + _chunk_logits: Optional[int] = 1024, + ) -> Tuple[torch.Tensor, torch.Tensor]: + m = m + self.msa_dropout_layer( + self.msa_att_row( + m.clone(), + z=z.clone(), + mask=msa_mask, + chunk_size=chunk_size, + _chunk_logits=_chunk_logits if torch.is_grad_enabled() else None, + _checkpoint_chunks= + self.ckpt if torch.is_grad_enabled() else False, + ) + ) + + def fn(m, z): + m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size) + m, z = self.core( + m, z, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size + ) + + return m, z + + if(torch.is_grad_enabled() and self.ckpt): + checkpoint_fn = get_checkpoint_fn() + m, z = checkpoint_fn(fn, m, z) + else: + m, z = fn(m, z) + + return m, z + + +class EvoformerStack(nn.Module): + """ + Main Evoformer trunk. + + Implements Algorithm 6. + """ + + def __init__( + self, + c_m: int, + c_z: int, + c_hidden_msa_att: int, + c_hidden_opm: int, + c_hidden_mul: int, + c_hidden_pair_att: int, + c_s: int, + no_heads_msa: int, + no_heads_pair: int, + no_blocks: int, + transition_n: int, + msa_dropout: float, + pair_dropout: float, + blocks_per_ckpt: int, + inf: float, + eps: float, + clear_cache_between_blocks: bool = False, + is_multimer: bool = False, + **kwargs, + ): + """ + Args: + c_m: + MSA channel dimension + c_z: + Pair channel dimension + c_hidden_msa_att: + Hidden dimension in MSA attention + c_hidden_opm: + Hidden dimension in outer product mean module + c_hidden_mul: + Hidden dimension in multiplicative updates + c_hidden_pair_att: + Hidden dimension in triangular attention + c_s: + Channel dimension of the output "single" embedding + no_heads_msa: + Number of heads used for MSA attention + no_heads_pair: + Number of heads used for pair attention + no_blocks: + Number of Evoformer blocks in the stack + transition_n: + Factor by which to multiply c_m to obtain the MSATransition + hidden dimension + msa_dropout: + Dropout rate for MSA activations + pair_dropout: + Dropout used for pair activations + blocks_per_ckpt: + Number of Evoformer blocks in each activation checkpoint + clear_cache_between_blocks: + Whether to clear CUDA's GPU memory cache between blocks of the + stack. 
Slows down each block but can reduce fragmentation + """ + super(EvoformerStack, self).__init__() + + self.blocks_per_ckpt = blocks_per_ckpt + self.clear_cache_between_blocks = clear_cache_between_blocks + + self.blocks = nn.ModuleList() + + for _ in range(no_blocks): + block = EvoformerBlock( + c_m=c_m, + c_z=c_z, + c_hidden_msa_att=c_hidden_msa_att, + c_hidden_opm=c_hidden_opm, + c_hidden_mul=c_hidden_mul, + c_hidden_pair_att=c_hidden_pair_att, + no_heads_msa=no_heads_msa, + no_heads_pair=no_heads_pair, + transition_n=transition_n, + msa_dropout=msa_dropout, + pair_dropout=pair_dropout, + inf=inf, + eps=eps, + is_multimer=is_multimer, + ) + self.blocks.append(block) + + self.linear = Linear(c_m, c_s) + + def forward(self, + m: torch.Tensor, + z: torch.Tensor, + msa_mask: torch.Tensor, + pair_mask: torch.Tensor, + chunk_size: int, + _mask_trans: bool = True, + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """ + Args: + m: + [*, N_seq, N_res, C_m] MSA embedding + z: + [*, N_res, N_res, C_z] pair embedding + msa_mask: + [*, N_seq, N_res] MSA mask + pair_mask: + [*, N_res, N_res] pair mask + Returns: + m: + [*, N_seq, N_res, C_m] MSA embedding + z: + [*, N_res, N_res, C_z] pair embedding + s: + [*, N_res, C_s] single embedding (or None if extra MSA stack) + """ + blocks = [ + partial( + b, + msa_mask=msa_mask, + pair_mask=pair_mask, + chunk_size=chunk_size, + _mask_trans=_mask_trans, + ) + for b in self.blocks + ] + + if(self.clear_cache_between_blocks): + def block_with_cache_clear(block, *args): + torch.cuda.empty_cache() + return block(*args) + + blocks = [partial(block_with_cache_clear, b) for b in blocks] + + m, z = checkpoint_blocks( + blocks, + args=(m, z), + blocks_per_ckpt=self.blocks_per_ckpt if self.training else None, + ) + + s = self.linear(m[..., 0, :, :]) + + return m, z, s + + +class ExtraMSAStack(nn.Module): + """ + Implements Algorithm 18. 
+ """ + + def __init__(self, + c_m: int, + c_z: int, + c_hidden_msa_att: int, + c_hidden_opm: int, + c_hidden_mul: int, + c_hidden_pair_att: int, + no_heads_msa: int, + no_heads_pair: int, + no_blocks: int, + transition_n: int, + msa_dropout: float, + pair_dropout: float, + inf: float, + eps: float, + ckpt: bool, + clear_cache_between_blocks: bool = False, + is_multimer: bool = False, + **kwargs, + ): + super(ExtraMSAStack, self).__init__() + + self.clear_cache_between_blocks = clear_cache_between_blocks + self.blocks = nn.ModuleList() + for _ in range(no_blocks): + block = ExtraMSABlock( + c_m=c_m, + c_z=c_z, + c_hidden_msa_att=c_hidden_msa_att, + c_hidden_opm=c_hidden_opm, + c_hidden_mul=c_hidden_mul, + c_hidden_pair_att=c_hidden_pair_att, + no_heads_msa=no_heads_msa, + no_heads_pair=no_heads_pair, + transition_n=transition_n, + msa_dropout=msa_dropout, + pair_dropout=pair_dropout, + inf=inf, + eps=eps, + ckpt=ckpt, + is_multimer=is_multimer, + ) + self.blocks.append(block) + + def forward(self, + m: torch.Tensor, + z: torch.Tensor, + chunk_size: int, + msa_mask: Optional[torch.Tensor] = None, + pair_mask: Optional[torch.Tensor] = None, + _mask_trans: bool = True, + ) -> torch.Tensor: + """ + Args: + m: + [*, N_extra, N_res, C_m] extra MSA embedding + z: + [*, N_res, N_res, C_z] pair embedding + msa_mask: + Optional [*, N_extra, N_res] MSA mask + pair_mask: + Optional [*, N_res, N_res] pair mask + Returns: + [*, N_res, N_res, C_z] pair update + """ + #checkpoint_fn = get_checkpoint_fn() + #blocks = [ + # partial(b, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size, _chunk_logits=None) for b in self.blocks + #] + + #def dodo(b, *args): + # torch.cuda.empty_cache() + # return b(*args) + + #blocks = [partial(dodo, b) for b in blocks] + + #for b in blocks: + # if(torch.is_grad_enabled()): + # m, z = checkpoint_fn(b, *(m, z)) + # else: + # m, z = b(m, z) + + for b in self.blocks: + m, z = b(m, z, msa_mask, pair_mask, chunk_size=chunk_size) + + if(self.clear_cache_between_blocks): + torch.cuda.empty_cache() + + return z \ No newline at end of file diff --git a/openfold/msa.py b/openfold/msa.py new file mode 100644 index 000000000..172b26def --- /dev/null +++ b/openfold/msa.py @@ -0,0 +1,392 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import math +import torch +import torch.nn as nn +from typing import Optional, List, Tuple + +from openfold.primitives import ( + Linear, + LayerNorm, + Attention, + GlobalAttention, + _attention_chunked_trainable, +) +from openfold.checkpointing import get_checkpoint_fn +from openfold.tensor_utils import ( + chunk_layer, + permute_final_dims, + flatten_final_dims, +) + + +class MSAAttention(nn.Module): + def __init__( + self, + c_in, + c_hidden, + no_heads, + pair_bias=False, + c_z=None, + inf=1e9, + ): + """ + Args: + c_in: + Input channel dimension + c_hidden: + Per-head hidden channel dimension + no_heads: + Number of attention heads + pair_bias: + Whether to use pair embedding bias + c_z: + Pair embedding channel dimension. Ignored unless pair_bias + is true + inf: + A large number to be used in computing the attention mask + """ + super(MSAAttention, self).__init__() + + self.c_in = c_in + self.c_hidden = c_hidden + self.no_heads = no_heads + self.pair_bias = pair_bias + self.c_z = c_z + self.inf = inf + + self.layer_norm_m = LayerNorm(self.c_in) + + self.layer_norm_z = None + self.linear_z = None + if self.pair_bias: + self.layer_norm_z = LayerNorm(self.c_z) + self.linear_z = Linear( + self.c_z, self.no_heads, bias=False, init="normal" + ) + + self.mha = Attention( + self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads + ) + + @torch.jit.ignore + def _chunk(self, + m: torch.Tensor, + biases: List[torch.Tensor], + chunk_size: int, + ) -> torch.Tensor: + return chunk_layer( + self.mha, + {"q_x": m, "kv_x": m, "biases": biases}, + chunk_size=chunk_size, + no_batch_dims=len(m.shape[:-2]), + ) + + def _prep_inputs(self, + m: torch.Tensor, + z: Optional[torch.Tensor], + mask: Optional[torch.Tensor] + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # [*, N_seq, N_res, C_m] + m = self.layer_norm_m(m) + + n_seq, n_res = m.shape[-3:-1] + if mask is None: + # [*, N_seq, N_res] + mask = m.new_ones( + m.shape[:-3] + (n_seq, n_res), + ) + + # [*, N_seq, 1, 1, N_res] + mask_bias = (self.inf * (mask - 1))[..., :, None, None, :] + + # This step simply returns a larger view of the bias, and does not + # consume additional memory. 
+ # [*, N_seq, no_heads, N_res, N_res] + #bias = bias.expand( + # ((-1,) * len(bias.shape[:-4])) + (-1, self.no_heads, n_res, -1) + #) + + if (self.pair_bias and + z is not None and # For the + self.layer_norm_z is not None and # benefit of + self.linear_z is not None # TorchScript + ): + # [*, N_res, N_res, C_z] + z = self.layer_norm_z(z) + + # [*, N_res, N_res, no_heads] + z = self.linear_z(z) + + # [*, 1, no_heads, N_res, N_res] + z = permute_final_dims(z, (2, 0, 1)).unsqueeze(-4) + + return m, mask_bias, z + + @torch.jit.ignore + def _chunked_msa_attn(self, + m: torch.Tensor, + z: Optional[torch.Tensor], + mask: Optional[torch.Tensor], + chunk_logits: int, + checkpoint: bool, + ) -> torch.Tensor: + MSA_DIM = -4 + + def _get_qkv(m, z): + m, mask_bias, z = self._prep_inputs(m, z, mask) + q, k, v = self.mha._prep_qkv(m, m) + return m, q, k, v, mask_bias, z + + checkpoint_fn = get_checkpoint_fn() + + if(torch.is_grad_enabled() and checkpoint): + m, q, k, v, mask_bias, z = checkpoint_fn(_get_qkv, m, z) + else: + m, q, k, v, mask_bias, z = _get_qkv(m, z) + + o = _attention_chunked_trainable( + query=q, + key=k, + value=v, + biases=[mask_bias, z], + chunk_size=chunk_logits, + chunk_dim=MSA_DIM, + checkpoint=checkpoint, + ) + + if(torch.is_grad_enabled() and checkpoint): + # Storing an additional m here is far from ideal + m = checkpoint_fn(self.mha._wrap_up, o, m) + else: + m = self.mha._wrap_up(o, m) + + return m + + def forward(self, + m: torch.Tensor, + z: Optional[torch.Tensor] = None, + mask: Optional[torch.Tensor] = None, + chunk_size: Optional[int] = None, + _chunk_logits: Optional[int] = None, + _checkpoint_chunks: Optional[bool] = None, + ) -> torch.Tensor: + """ + Args: + m: + [*, N_seq, N_res, C_m] MSA embedding + z: + [*, N_res, N_res, C_z] pair embedding. Required only if + pair_bias is True + mask: + [*, N_seq, N_res] MSA mask + chunk_size: + Size of chunks into which the inputs are split along their + batch dimensions. A low value decreases memory overhead at the + cost of slower execution. Chunking is not performed by default. + + """ + if(_chunk_logits is not None): + return self._chunked_msa_attn( + m=m, z=z, mask=mask, + chunk_logits=_chunk_logits, checkpoint=_checkpoint_chunks + ) + + m, mask_bias, z = self._prep_inputs(m, z, mask) + + biases = [mask_bias] + if(z is not None): + biases.append(z) + + if chunk_size is not None: + m = self._chunk(m, biases, chunk_size) + else: + m = self.mha( + q_x=m, + kv_x=m, + biases=biases + ) + + return m + + +class MSARowAttentionWithPairBias(MSAAttention): + """ + Implements Algorithm 7. + """ + + def __init__(self, c_m, c_z, c_hidden, no_heads, inf=1e9): + """ + Args: + c_m: + Input channel dimension + c_z: + Pair embedding channel dimension + c_hidden: + Per-head hidden channel dimension + no_heads: + Number of attention heads + inf: + Large number used to construct attention masks + """ + super(MSARowAttentionWithPairBias, self).__init__( + c_m, + c_hidden, + no_heads, + pair_bias=True, + c_z=c_z, + inf=inf, + ) + + +class MSAColumnAttention(nn.Module): + """ + Implements Algorithm 8. + + By rights, this should also be a subclass of MSAAttention. Alas, + most inheritance isn't supported by TorchScript. 
+ """ + + def __init__(self, c_m, c_hidden, no_heads, inf=1e9): + """ + Args: + c_m: + MSA channel dimension + c_hidden: + Per-head hidden channel dimension + no_heads: + Number of attention heads + inf: + Large number used to construct attention masks + """ + super(MSAColumnAttention, self).__init__() + + self.c_m = c_m + self.c_hidden = c_hidden + self.no_heads = no_heads + self.inf = inf + + self._msa_att = MSAAttention( + c_in=c_m, + c_hidden=c_hidden, + no_heads=no_heads, + pair_bias=False, + c_z=None, + inf=inf, + ) + + def forward(self, + m: torch.Tensor, + mask: Optional[torch.Tensor] = None, + chunk_size: Optional[int] = None + ) -> torch.Tensor: + """ + Args: + m: + [*, N_seq, N_res, C_m] MSA embedding + mask: + [*, N_seq, N_res] MSA mask + chunk_size: + Size of chunks into which the inputs are split along their + batch dimensions. A low value decreases memory overhead at the + cost of slower execution. Chunking is not performed by default. + """ + # [*, N_res, N_seq, C_in] + m = m.transpose(-2, -3) + if mask is not None: + mask = mask.transpose(-1, -2) + + m = self._msa_att(m, mask=mask, chunk_size=chunk_size) + + # [*, N_seq, N_res, C_in] + m = m.transpose(-2, -3) + if mask is not None: + mask = mask.transpose(-1, -2) + + return m + + +class MSAColumnGlobalAttention(nn.Module): + def __init__( + self, c_in, c_hidden, no_heads, inf=1e9, eps=1e-10, + ): + super(MSAColumnGlobalAttention, self).__init__() + + self.c_in = c_in + self.c_hidden = c_hidden + self.no_heads = no_heads + self.inf = inf + self.eps = eps + + self.layer_norm_m = nn.LayerNorm(c_in) + + self.global_attention = GlobalAttention( + c_in=c_in, + c_hidden=c_hidden, + no_heads=no_heads, + inf=inf, + eps=eps, + ) + + @torch.jit.ignore + def _chunk(self, + m: torch.Tensor, + mask: torch.Tensor, + chunk_size: int, + ) -> torch.Tensor: + mha_input = { + "m": m, + "mask": mask, + } + return chunk_layer( + self.global_attention, + mha_input, + chunk_size=chunk_size, + no_batch_dims=len(m.shape[:-2]), + ) + + def forward( + self, + m: torch.Tensor, + mask: Optional[torch.Tensor] = None, + chunk_size: Optional[int] = None, + ) -> torch.Tensor: + n_seq, n_res, c_in = m.shape[-3:] + + if mask is None: + # [*, N_seq, N_res] + mask = torch.ones( + m.shape[:-1], + dtype=m.dtype, + device=m.device, + ).detach() + + # [*, N_res, N_seq, C_in] + m = m.transpose(-2, -3) + mask = mask.transpose(-1, -2) + + # [*, N_res, N_seq, C_in] + m = self.layer_norm_m(m) + + if chunk_size is not None: + m = self._chunk(m, mask, chunk_size) + else: + m = self.global_attention(m=m, mask=mask) + + # [*, N_seq, N_res, C_in] + m = m.transpose(-2, -3) + + return m diff --git a/openfold/outer_product_mean.py b/openfold/outer_product_mean.py new file mode 100644 index 000000000..43d853833 --- /dev/null +++ b/openfold/outer_product_mean.py @@ -0,0 +1,129 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from functools import partial +from typing import Optional + +import torch +import torch.nn as nn + +from openfold.primitives import Linear +from openfold.tensor_utils import chunk_layer + + +class OuterProductMean(nn.Module): + """ + Implements Algorithm 10. + """ + + def __init__(self, c_m, c_z, c_hidden, eps=1e-3): + """ + Args: + c_m: + MSA embedding channel dimension + c_z: + Pair embedding channel dimension + c_hidden: + Hidden channel dimension + """ + super(OuterProductMean, self).__init__() + + self.c_m = c_m + self.c_z = c_z + self.c_hidden = c_hidden + self.eps = eps + + self.layer_norm = nn.LayerNorm(c_m) + self.linear_1 = Linear(c_m, c_hidden) + self.linear_2 = Linear(c_m, c_hidden) + self.linear_out = Linear(c_hidden ** 2, c_z, init="final") + + def _opm(self, a, b): + # [*, N_res, N_res, C, C] + outer = torch.einsum("...bac,...dae->...bdce", a, b) + + # [*, N_res, N_res, C * C] + outer = outer.reshape(outer.shape[:-2] + (-1,)) + + # [*, N_res, N_res, C_z] + outer = self.linear_out(outer) + + return outer + + @torch.jit.ignore + def _chunk(self, + a: torch.Tensor, + b: torch.Tensor, + chunk_size: int + ) -> torch.Tensor: + # Since the "batch dim" in this case is not a true batch dimension + # (in that the shape of the output depends on it), we need to + # iterate over it ourselves + a_reshape = a.reshape((-1,) + a.shape[-3:]) + b_reshape = b.reshape((-1,) + b.shape[-3:]) + out = [] + for a_prime, b_prime in zip(a_reshape, b_reshape): + outer = chunk_layer( + partial(self._opm, b=b_prime), + {"a": a_prime}, + chunk_size=chunk_size, + no_batch_dims=1, + ) + out.append(outer) + outer = torch.stack(out, dim=0) + outer = outer.reshape(a.shape[:-3] + outer.shape[1:]) + + return outer + + def forward(self, + m: torch.Tensor, + mask: Optional[torch.Tensor] = None, + chunk_size: Optional[int] = None + ) -> torch.Tensor: + """ + Args: + m: + [*, N_seq, N_res, C_m] MSA embedding + mask: + [*, N_seq, N_res] MSA mask + Returns: + [*, N_res, N_res, C_z] pair embedding update + """ + if mask is None: + mask = m.new_ones(m.shape[:-1]) + + # [*, N_seq, N_res, C_m] + m = self.layer_norm(m) + + # [*, N_seq, N_res, C] + mask = mask.unsqueeze(-1) + a = self.linear_1(m) * mask + b = self.linear_2(m) * mask + + a = a.transpose(-2, -3) + b = b.transpose(-2, -3) + + if chunk_size is not None: + outer = self._chunk(a, b, chunk_size) + else: + outer = self._opm(a, b) + + # [*, N_res, N_res, 1] + norm = torch.einsum("...abc,...adc->...bdc", mask, mask) + + # [*, N_res, N_res, C_z] + outer = outer / (self.eps + norm) + + return outer diff --git a/openfold/pair_transition.py b/openfold/pair_transition.py new file mode 100644 index 000000000..de7630641 --- /dev/null +++ b/openfold/pair_transition.py @@ -0,0 +1,99 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
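# --- Hedged illustration (editor's sketch, not part of the original diff) -------------
# What OuterProductMean (openfold/outer_product_mean.py above) computes, written out with
# plain einsums and assumed toy shapes: for every residue pair (i, j) the outer products
# of the two projected MSA columns are summed over the sequence dimension and normalized
# by the number of co-unmasked sequences. The random tensors stand in for the masked
# linear_1/linear_2 projections; in the module, linear_out maps the flattened C*C
# features to C_z before this normalization.
import torch

n_seq, n_res, c = 4, 6, 3
a = torch.randn(n_seq, n_res, c)               # stands in for linear_1(m) * mask
b = torch.randn(n_seq, n_res, c)               # stands in for linear_2(m) * mask
mask = torch.ones(n_seq, n_res)

outer = torch.einsum("sic,sjd->ijcd", a, b)    # [N_res, N_res, C, C], summed over sequences
norm = torch.einsum("si,sj->ij", mask, mask)   # [N_res, N_res] co-unmasked sequence counts
mean = outer.reshape(n_res, n_res, -1) / (norm.unsqueeze(-1) + 1e-3)
print(mean.shape)                              # torch.Size([6, 6, 9]) -> linear_out -> C_z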
+from typing import Optional + +import torch +import torch.nn as nn + +from openfold.primitives import Linear, LayerNorm +from openfold.tensor_utils import chunk_layer + + +class PairTransition(nn.Module): + """ + Implements Algorithm 15. + """ + + def __init__(self, c_z, n): + """ + Args: + c_z: + Pair transition channel dimension + n: + Factor by which c_z is multiplied to obtain hidden channel + dimension + """ + super(PairTransition, self).__init__() + + self.c_z = c_z + self.n = n + + self.layer_norm = LayerNorm(self.c_z) + self.linear_1 = Linear(self.c_z, self.n * self.c_z, init="relu") + self.relu = nn.ReLU() + self.linear_2 = Linear(self.n * self.c_z, c_z, init="final") + + def _transition(self, z, mask): + # [*, N_res, N_res, C_hidden] + z = self.linear_1(z) + z = self.relu(z) + + # [*, N_res, N_res, C_z] + z = self.linear_2(z) * mask + + return z + + @torch.jit.ignore + def _chunk(self, + z: torch.Tensor, + mask: torch.Tensor, + chunk_size: int, + ) -> torch.Tensor: + return chunk_layer( + self._transition, + {"z": z, "mask": mask}, + chunk_size=chunk_size, + no_batch_dims=len(z.shape[:-2]), + ) + + + def forward(self, + z: torch.Tensor, + mask: Optional[torch.Tensor] = None, + chunk_size: Optional[int] = None, + ) -> torch.Tensor: + """ + Args: + z: + [*, N_res, N_res, C_z] pair embedding + Returns: + [*, N_res, N_res, C_z] pair embedding update + """ + # DISCREPANCY: DeepMind forgets to apply the mask in this module. + if mask is None: + mask = z.new_ones(z.shape[:-1]) + + # [*, N_res, N_res, 1] + mask = mask.unsqueeze(-1) + + # [*, N_res, N_res, C_z] + z = self.layer_norm(z) + + if chunk_size is not None: + z = self._chunk(z, mask, chunk_size) + else: + z = self._transition(z=z, mask=mask) + + return z diff --git a/openfold/primitives.py b/openfold/primitives.py new file mode 100644 index 000000000..bbc156f21 --- /dev/null +++ b/openfold/primitives.py @@ -0,0 +1,529 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
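# --- Hedged conceptual sketch (editor's illustration, not part of the original diff) ---
# The _chunk methods above (MSATransition, PairTransition, MSAAttention, ...) all defer
# to chunk_layer from openfold.tensor_utils, which applies a layer to chunk_size-sized
# slices of the batch-like leading dimensions and concatenates the results, trading speed
# for a smaller peak activation footprint. _simple_chunk below is an assumed, simplified
# stand-in that chunks a single leading dimension, shown only to illustrate the idea.
import torch

def _simple_chunk(fn, inputs, chunk_size, dim=0):
    # Apply fn to chunk_size slices of every tensor in `inputs` along `dim`, then re-join.
    n = next(iter(inputs.values())).shape[dim]
    outs = []
    for start in range(0, n, chunk_size):
        length = min(chunk_size, n - start)
        outs.append(fn(**{k: v.narrow(dim, start, length) for k, v in inputs.items()}))
    return torch.cat(outs, dim=dim)

layer = torch.nn.Linear(8, 8)
z = torch.randn(64, 10, 8)
full = layer(z)
chunked = _simple_chunk(lambda z: layer(z), {"z": z}, chunk_size=16)
assert torch.allclose(full, chunked, atol=1e-6)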
+ +from functools import partial +import math +from typing import Optional, Callable, List, Tuple, Sequence +import numpy as np + +import torch +import torch.nn as nn + +from openfold.checkpointing import get_checkpoint_fn +from openfold.tensor_utils import ( + permute_final_dims, + flatten_final_dims, + _chunk_slice, +) + + +def _prod(nums): + out = 1 + for n in nums: + out = out * n + return out + + +def _calculate_fan(linear_weight_shape, fan="fan_in"): + fan_out, fan_in = linear_weight_shape + + if fan == "fan_in": + f = fan_in + elif fan == "fan_out": + f = fan_out + elif fan == "fan_avg": + f = (fan_in + fan_out) / 2 + else: + raise ValueError("Invalid fan option") + + return f + + +def glorot_uniform_init_(weights): + nn.init.xavier_uniform_(weights, gain=1) + + +def final_init_(weights): + with torch.no_grad(): + weights.fill_(0.0) + + +def gating_init_(weights): + with torch.no_grad(): + weights.fill_(0.0) + + +def normal_init_(weights): + torch.nn.init.kaiming_normal_(weights, nonlinearity="linear") + + +def ipa_point_weights_init_(weights): + with torch.no_grad(): + softplus_inverse_1 = 0.541324854612918 + weights.fill_(softplus_inverse_1) + + +class Linear(nn.Linear): + """ + A Linear layer with built-in nonstandard initializations. Called just + like torch.nn.Linear. + + Implements the initializers in 1.11.4, plus some additional ones found + in the code. + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + bias: bool = True, + init: str = "default", + init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None, + ): + """ + Args: + in_dim: + The final dimension of inputs to the layer + out_dim: + The final dimension of layer outputs + bias: + Whether to learn an additive bias. True by default + init: + The initializer to use. Choose from: + + "default": LeCun fan-in truncated normal initialization + "relu": He initialization w/ truncated normal distribution + "glorot": Fan-average Glorot uniform initialization + "gating": Weights=0, Bias=1 + "normal": Normal initialization with std=1/sqrt(fan_in) + "final": Weights=0, Bias=0 + + Overridden by init_fn if the latter is not None. + init_fn: + A custom initializer taking weight and bias as inputs. + Overrides init if not None. 
+ """ + super(Linear, self).__init__(in_dim, out_dim, bias=bias) + + if bias: + with torch.no_grad(): + self.bias.fill_(0) + + if init_fn is not None: + init_fn(self.weight, self.bias) + else: + if init == "default": + normal_init_(self.weight) + elif init == "relu": + normal_init_(self.weight) + elif init == "glorot": + glorot_uniform_init_(self.weight) + elif init == "gating": + gating_init_(self.weight) + if bias: + with torch.no_grad(): + self.bias.fill_(1.0) + elif init == "normal": + normal_init_(self.weight) + elif init == "final": + final_init_(self.weight) + else: + raise ValueError("Invalid init string.") + + +class LayerNorm(nn.Module): + + def __init__(self, c_in, eps=1e-5): + super(LayerNorm, self).__init__() + + self.c_in = (c_in,) + self.eps = eps + + self.weight = nn.Parameter(torch.ones(c_in)) + self.bias = nn.Parameter(torch.zeros(c_in)) + + def forward(self, x): + out = nn.functional.layer_norm( + x, + self.c_in, + self.weight, + self.bias, + self.eps, + ) + + return out + + +@torch.jit.ignore +def softmax(t: torch.Tensor, dim: int = -1) -> torch.Tensor: + """ + Softmax, but without automatic casting to fp32 when the input is of + type bfloat16 + """ + s = torch.nn.functional.softmax(t, dim=dim) + + return s + + +#@torch.jit.script +def _attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, + biases: List[torch.Tensor]) -> torch.Tensor: + # [*, H, Q, C_hidden] + query = permute_final_dims(query, (1, 0, 2)) + + # [*, H, C_hidden, K] + key = permute_final_dims(key, (1, 2, 0)) + + # [*, H, V, C_hidden] + value = permute_final_dims(value, (1, 0, 2)) + + # [*, H, Q, K] + a = torch.matmul(query, key) + + for b in biases: + a += b + + a = softmax(a, -1) + + # [*, H, Q, C_hidden] + a = torch.matmul(a, value) + + # [*, Q, H, C_hidden] + a = a.transpose(-2, -3) + + return a + + +@torch.jit.ignore +def _attention_chunked_trainable( + query, + key, + value, + biases, + chunk_size, + chunk_dim, + checkpoint, +): + if (checkpoint and len(biases) > 2): + raise ValueError("Checkpointed version permits only permits two bias terms") + + def _checkpointable_attention(q, k, v, b1, b2): + bs = [b for b in [b1, b2] if b is not None] + return _attention(q, k, v, bs) + + o_chunks = [] + checkpoint_fn = get_checkpoint_fn() + count = query.shape[chunk_dim] + for start in range(0, count, chunk_size): + end = start + chunk_size + idx = [slice(None)] * len(query.shape) + idx[chunk_dim] = slice(start, end) + idx_tup = tuple(idx) + q_chunk = query[idx_tup] + k_chunk = key[idx_tup] + v_chunk = value[idx_tup] + + def _slice_bias(b): + idx[chunk_dim] = (slice(start, end) if b.shape[chunk_dim] != 1 else slice(None)) + return b[tuple(idx)] + + if (checkpoint): + bias_1_chunk, bias_2_chunk = [ + _slice_bias(b) if b is not None else None for b in (biases + [None, None])[:2] + ] + + o_chunk = checkpoint_fn(_checkpointable_attention, q_chunk, k_chunk, v_chunk, + bias_1_chunk, bias_2_chunk) + else: + bias_chunks = [_slice_bias(b) for b in biases] + + o_chunk = _attention(q_chunk, k_chunk, v_chunk, bias_chunks) + + o_chunks.append(o_chunk) + + o = torch.cat(o_chunks, dim=chunk_dim) + return o + + +class Attention(nn.Module): + """ + Standard multi-head attention using AlphaFold's default layer + initialization. Allows multiple bias vectors. 
+ """ + + def __init__( + self, + c_q: int, + c_k: int, + c_v: int, + c_hidden: int, + no_heads: int, + gating: bool = True, + ): + """ + Args: + c_q: + Input dimension of query data + c_k: + Input dimension of key data + c_v: + Input dimension of value data + c_hidden: + Per-head hidden dimension + no_heads: + Number of attention heads + gating: + Whether the output should be gated using query data + """ + super(Attention, self).__init__() + + self.c_q = c_q + self.c_k = c_k + self.c_v = c_v + self.c_hidden = c_hidden + self.no_heads = no_heads + self.gating = gating + + # DISCREPANCY: c_hidden is not the per-head channel dimension, as + # stated in the supplement, but the overall channel dimension. + + self.linear_q = Linear(self.c_q, self.c_hidden * self.no_heads, bias=False, init="glorot") + self.linear_k = Linear(self.c_k, self.c_hidden * self.no_heads, bias=False, init="glorot") + self.linear_v = Linear(self.c_v, self.c_hidden * self.no_heads, bias=False, init="glorot") + self.linear_o = Linear(self.c_hidden * self.no_heads, self.c_q, init="final") + + self.linear_g = None + if self.gating: + self.linear_g = Linear(self.c_q, self.c_hidden * self.no_heads, init="gating") + + self.sigmoid = nn.Sigmoid() + + def _prep_qkv(self, q_x: torch.Tensor, + kv_x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # [*, Q/K/V, H * C_hidden] + q = self.linear_q(q_x) + k = self.linear_k(kv_x) + v = self.linear_v(kv_x) + + # [*, Q/K, H, C_hidden] + q = q.view(q.shape[:-1] + (self.no_heads, -1)) + k = k.view(k.shape[:-1] + (self.no_heads, -1)) + v = v.view(v.shape[:-1] + (self.no_heads, -1)) + + q /= math.sqrt(self.c_hidden) + + return q, k, v + + def _wrap_up(self, o: torch.Tensor, q_x: torch.Tensor) -> torch.Tensor: + if (self.linear_g is not None): + g = self.sigmoid(self.linear_g(q_x)) + + # [*, Q, H, C_hidden] + g = g.view(g.shape[:-1] + (self.no_heads, -1)) + o = o * g + + # [*, Q, H * C_hidden] + o = flatten_final_dims(o, 2) + + # [*, Q, C_q] + o = self.linear_o(o) + + return o + + def forward( + self, + q_x: torch.Tensor, + kv_x: torch.Tensor, + biases: Optional[List[torch.Tensor]] = None, + use_lma: bool = False, + q_chunk_size: Optional[int] = None, + kv_chunk_size: Optional[int] = None, + ) -> torch.Tensor: + """ + Args: + q_x: + [*, Q, C_q] query data + kv_x: + [*, K, C_k] key data + biases: + List of biases that broadcast to [*, H, Q, K] + use_lma: + Whether to use low-memory attention + q_chunk_size: + Query chunk size (for LMA) + kv_chunk_size: + Key/Value chunk size (for LMA) + Returns + [*, Q, C_q] attention update + """ + if (biases is None): + biases = [] + if (use_lma and (q_chunk_size is None or kv_chunk_size is None)): + raise ValueError("If use_lma is specified, q_chunk_size and kv_chunk_size must " + "be provided") + + q, k, v = self._prep_qkv(q_x, kv_x) + + if (use_lma): + biases = [b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],)) for b in biases] + + o = _lma(q, k, v, biases, q_chunk_size, kv_chunk_size) + else: + o = _attention(q, k, v, biases) + + o = self._wrap_up(o, q_x) + + return o + + +class GlobalAttention(nn.Module): + + def __init__(self, c_in, c_hidden, no_heads, inf, eps): + super(GlobalAttention, self).__init__() + + self.c_in = c_in + self.c_hidden = c_hidden + self.no_heads = no_heads + self.inf = inf + self.eps = eps + + self.linear_q = Linear(c_in, c_hidden * no_heads, bias=False, init="glorot") + + self.linear_k = Linear( + c_in, + c_hidden, + bias=False, + init="glorot", + ) + self.linear_v = Linear( + c_in, + c_hidden, + 
bias=False, + init="glorot", + ) + self.linear_g = Linear(c_in, c_hidden * no_heads, init="gating") + self.linear_o = Linear(c_hidden * no_heads, c_in, init="final") + + self.sigmoid = nn.Sigmoid() + + def forward(self, m: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + # [*, N_res, C_in] + q = torch.sum(m * mask.unsqueeze(-1), + dim=-2) / (torch.sum(mask, dim=-1)[..., None] + self.eps) + + # [*, N_res, H * C_hidden] + q = self.linear_q(q) + q *= (self.c_hidden**(-0.5)) + + # [*, N_res, H, C_hidden] + q = q.view(q.shape[:-1] + (self.no_heads, -1)) + + # [*, N_res, N_seq, C_hidden] + k = self.linear_k(m) + v = self.linear_v(m) + + # [*, N_res, H, N_seq] + a = torch.matmul( + q, + k.transpose(-1, -2), # [*, N_res, C_hidden, N_seq] + ) + bias = (self.inf * (mask - 1))[..., :, None, :] + a += bias + a = softmax(a) + + # [*, N_res, H, C_hidden] + o = torch.matmul( + a, + v, + ) + + # [*, N_res, N_seq, C_hidden] + g = self.sigmoid(self.linear_g(m)) + + # [*, N_res, N_seq, H, C_hidden] + g = g.view(g.shape[:-1] + (self.no_heads, -1)) + + # [*, N_res, N_seq, H, C_hidden] + o = o.unsqueeze(-3) * g + + # [*, N_res, N_seq, H * C_hidden] + o = o.reshape(o.shape[:-2] + (-1,)) + + # [*, N_res, N_seq, C_in] + m = self.linear_o(o) + + return m + + +def _lma( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + biases: List[torch.Tensor], + q_chunk_size: int, + kv_chunk_size: int, +): + no_q, no_kv = q.shape[-3], k.shape[-3] + + # [*, Q, H, C_hidden] + o = q.new_zeros(q.shape) + for q_s in range(0, no_q, q_chunk_size): + q_chunk = q[..., q_s:q_s + q_chunk_size, :, :] + large_bias_chunks = [b[..., q_s:q_s + q_chunk_size, :] for b in biases] + + maxes = [] + weights = [] + values = [] + for kv_s in range(0, no_kv, kv_chunk_size): + k_chunk = k[..., kv_s:kv_s + kv_chunk_size, :, :] + v_chunk = v[..., kv_s:kv_s + kv_chunk_size, :, :] + small_bias_chunks = [b[..., kv_s:kv_s + kv_chunk_size] for b in large_bias_chunks] + + a = torch.einsum( + "...qhd,...khd->...hqk", + q_chunk, + k_chunk, + ) + + for b in small_bias_chunks: + a += b + + a = a.transpose(-2, -3) + + max_a = torch.max(a, dim=-1, keepdim=True)[0] + exp_a = torch.exp(a - max_a) + exp_v = torch.einsum("...vhf,...qhv->...qhf", v_chunk, exp_a) + + maxes.append(max_a.detach().squeeze(-1)) + weights.append(torch.sum(exp_a, dim=-1)) + values.append(exp_v) + + chunk_max = torch.stack(maxes, dim=-3) + chunk_weights = torch.stack(weights, dim=-3) + chunk_values = torch.stack(values, dim=-4) + + global_max = torch.max(chunk_max, dim=-3, keepdim=True)[0] + max_diffs = torch.exp(chunk_max - global_max) + chunk_values *= max_diffs.unsqueeze(-1) + chunk_weights *= max_diffs + + all_values = torch.sum(chunk_values, dim=-4) + all_weights = torch.sum(chunk_weights.unsqueeze(-1), dim=-4) + + q_chunk_out = all_values / all_weights + + o[..., q_s:q_s + q_chunk_size, :, :] = q_chunk_out + + return o diff --git a/openfold/tensor_utils.py b/openfold/tensor_utils.py new file mode 100644 index 000000000..7e5e8e4b6 --- /dev/null +++ b/openfold/tensor_utils.py @@ -0,0 +1,408 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial +import torch +import torch.nn as nn +from typing import Tuple, List, Callable, Any, Dict, Sequence, Optional + + +def permute_final_dims(tensor: torch.Tensor, inds: List[int]): + zero_index = -1 * len(inds) + first_inds = list(range(len(tensor.shape[:zero_index]))) + return tensor.permute(first_inds + [zero_index + i for i in inds]) + + +def flatten_final_dims(t: torch.Tensor, no_dims: int): + return t.reshape(t.shape[:-no_dims] + (-1,)) + + +def masked_mean(mask, value, dim, eps=1e-4): + mask = mask.expand(*value.shape) + return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim)) + + +def pts_to_distogram(pts, min_bin=2.3125, max_bin=21.6875, no_bins=64): + boundaries = torch.linspace( + min_bin, max_bin, no_bins - 1, device=pts.device + ) + dists = torch.sqrt( + torch.sum((pts.unsqueeze(-2) - pts.unsqueeze(-3)) ** 2, dim=-1) + ) + return torch.bucketize(dists, boundaries) + + +def dict_multimap(fn, dicts): + first = dicts[0] + new_dict = {} + for k, v in first.items(): + all_v = [d[k] for d in dicts] + if type(v) is dict: + new_dict[k] = dict_multimap(fn, all_v) + else: + new_dict[k] = fn(all_v) + + return new_dict + + +def one_hot(x, v_bins): + reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),)) + diffs = x[..., None] - reshaped_bins + am = torch.argmin(torch.abs(diffs), dim=-1) + return nn.functional.one_hot(am, num_classes=len(v_bins)).float() + + +def batched_gather(data, inds, dim=0, no_batch_dims=0): + ranges = [] + for i, s in enumerate(data.shape[:no_batch_dims]): + r = torch.arange(s) + r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1)))) + ranges.append(r) + + remaining_dims = [ + slice(None) for _ in range(len(data.shape) - no_batch_dims) + ] + remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds + ranges.extend(remaining_dims) + return data[ranges] + + +# With tree_map, a poor man's JAX tree_map +def dict_map(fn, dic, leaf_type): + new_dict = {} + for k, v in dic.items(): + if type(v) is dict: + new_dict[k] = dict_map(fn, v, leaf_type) + else: + new_dict[k] = tree_map(fn, v, leaf_type) + + return new_dict + + +def tree_map(fn, tree, leaf_type): + if isinstance(tree, dict): + return dict_map(fn, tree, leaf_type) + elif isinstance(tree, list): + return [tree_map(fn, x, leaf_type) for x in tree] + elif isinstance(tree, tuple): + return tuple([tree_map(fn, x, leaf_type) for x in tree]) + elif isinstance(tree, leaf_type): + return fn(tree) + else: + print(type(tree)) + raise ValueError("Not supported") + + +tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor) + +def _fetch_dims(tree): + shapes = [] + tree_type = type(tree) + if tree_type is dict: + for v in tree.values(): + shapes.extend(_fetch_dims(v)) + elif tree_type is list or tree_type is tuple: + for t in tree: + shapes.extend(_fetch_dims(t)) + elif tree_type is torch.Tensor: + shapes.append(tree.shape) + else: + raise ValueError("Not supported") + + return shapes + + +@torch.jit.ignore +def _flat_idx_to_idx( + flat_idx: int, + dims: Tuple[int], +) -> Tuple[int]: + idx = [] + for d in reversed(dims): + 
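+        # Peel off one index per dimension, innermost (fastest-varying)
+        # dimension first; the list is reversed below to restore the original
+        # dimension order.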
idx.append(flat_idx % d) + flat_idx = flat_idx // d + + return tuple(reversed(idx)) + + +@torch.jit.ignore +def _get_minimal_slice_set( + start: Sequence[int], + end: Sequence[int], + dims: int, + start_edges: Optional[Sequence[bool]] = None, + end_edges: Optional[Sequence[bool]] = None, +) -> Sequence[Tuple[int]]: + """ + Produces an ordered sequence of tensor slices that, when used in + sequence on a tensor with shape dims, yields tensors that contain every + leaf in the contiguous range [start, end]. Care is taken to yield a + short sequence of slices, and perhaps even the shortest possible (I'm + pretty sure it's the latter). + + end is INCLUSIVE. + """ + # start_edges and end_edges both indicate whether, starting from any given + # dimension, the start/end index is at the top/bottom edge of the + # corresponding tensor, modeled as a tree + def reduce_edge_list(l): + tally = 1 + for i in range(len(l)): + reversed_idx = -1 * (i + 1) + l[reversed_idx] *= tally + tally = l[reversed_idx] + + if(start_edges is None): + start_edges = [s == 0 for s in start] + reduce_edge_list(start_edges) + if(end_edges is None): + end_edges = [e == (d - 1) for e,d in zip(end, dims)] + reduce_edge_list(end_edges) + + # Base cases. Either start/end are empty and we're done, or the final, + # one-dimensional tensor can be simply sliced + if(len(start) == 0): + return [tuple()] + elif(len(start) == 1): + return [(slice(start[0], end[0] + 1),)] + + slices = [] + path = [] + + # Dimensions common to start and end can be selected directly + for s,e in zip(start, end): + if(s == e): + path.append(slice(s, s + 1)) + else: + break + + path = tuple(path) + divergence_idx = len(path) + + # start == end, and we're done + if(divergence_idx == len(dims)): + return [tuple(path)] + + def upper(): + sdi = start[divergence_idx] + return [ + path + (slice(sdi, sdi + 1),) + s for s in + _get_minimal_slice_set( + start[divergence_idx + 1:], + [d - 1 for d in dims[divergence_idx + 1:]], + dims[divergence_idx + 1:], + start_edges=start_edges[divergence_idx + 1:], + end_edges=[1 for _ in end_edges[divergence_idx + 1:]] + ) + ] + + def lower(): + edi = end[divergence_idx] + return [ + path + (slice(edi, edi + 1),) + s for s in + _get_minimal_slice_set( + [0 for _ in start[divergence_idx + 1:]], + end[divergence_idx + 1:], + dims[divergence_idx + 1:], + start_edges=[1 for _ in start_edges[divergence_idx + 1:]], + end_edges=end_edges[divergence_idx + 1:], + ) + ] + + # If both start and end are at the edges of the subtree rooted at + # divergence_idx, we can just select the whole subtree at once + if(start_edges[divergence_idx] and end_edges[divergence_idx]): + slices.append( + path + (slice(start[divergence_idx], end[divergence_idx] + 1),) + ) + # If just start is at the edge, we can grab almost all of the subtree, + # treating only the ragged bottom edge as an edge case + elif(start_edges[divergence_idx]): + slices.append( + path + (slice(start[divergence_idx], end[divergence_idx]),) + ) + slices.extend(lower()) + # Analogous to the previous case, but the top is ragged this time + elif(end_edges[divergence_idx]): + slices.extend(upper()) + slices.append( + path + (slice(start[divergence_idx] + 1, end[divergence_idx] + 1),) + ) + # If both sides of the range are ragged, we need to handle both sides + # separately. 
If there's contiguous meat in between them, we can index it + # in one big chunk + else: + slices.extend(upper()) + middle_ground = end[divergence_idx] - start[divergence_idx] + if(middle_ground > 1): + slices.append( + path + (slice(start[divergence_idx] + 1, end[divergence_idx]),) + ) + slices.extend(lower()) + + return [tuple(s) for s in slices] + + +@torch.jit.ignore +def _chunk_slice( + t: torch.Tensor, + flat_start: int, + flat_end: int, + no_batch_dims: int, +) -> torch.Tensor: + """ + Equivalent to + + t.reshape((-1,) + t.shape[no_batch_dims:])[flat_start:flat_end] + + but without the need for the initial reshape call, which can be + memory-intensive in certain situations. The only reshape operations + in this function are performed on sub-tensors that scale with + (flat_end - flat_start), the chunk size. + """ + + batch_dims = t.shape[:no_batch_dims] + start_idx = list(_flat_idx_to_idx(flat_start, batch_dims)) + # _get_minimal_slice_set is inclusive + end_idx = list(_flat_idx_to_idx(flat_end - 1, batch_dims)) + + # Get an ordered list of slices to perform + slices = _get_minimal_slice_set( + start_idx, + end_idx, + batch_dims, + ) + + sliced_tensors = [t[s] for s in slices] + + return torch.cat( + [s.view((-1,) + t.shape[no_batch_dims:]) for s in sliced_tensors] + ) + + +def chunk_layer( + layer: Callable, + inputs: Dict[str, Any], + chunk_size: int, + no_batch_dims: int, + low_mem: bool = False, +) -> Any: + """ + Implements the "chunking" procedure described in section 1.11.8. + + Layer outputs and inputs are assumed to be simple "pytrees," + consisting only of (arbitrarily nested) lists, tuples, and dicts with + torch.Tensor leaves. + + Args: + layer: + The layer to be applied chunk-wise + inputs: + A (non-nested) dictionary of keyworded inputs. All leaves must + be tensors and must share the same batch dimensions. + chunk_size: + The number of sub-batches per chunk. If multiple batch + dimensions are specified, a "sub-batch" is defined as a single + indexing of all batch dimensions simultaneously (s.t. the + number of sub-batches is the product of the batch dimensions). + no_batch_dims: + How many of the initial dimensions of each input tensor can + be considered batch dimensions. + low_mem: + Avoids flattening potentially large input tensors. Unnecessary + in most cases, and is ever so slightly slower than the default + setting. + Returns: + The reassembled output of the layer on the inputs. + """ + if not (len(inputs) > 0): + raise ValueError("Must provide at least one input") + + initial_dims = [shape[:no_batch_dims] for shape in _fetch_dims(inputs)] + orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)]) + + def _prep_inputs(t): + # TODO: make this more memory efficient. 
This sucks + if(not low_mem): + if not sum(t.shape[:no_batch_dims]) == no_batch_dims: + t = t.expand(orig_batch_dims + t.shape[no_batch_dims:]) + t = t.reshape(-1, *t.shape[no_batch_dims:]) + else: + t = t.expand(orig_batch_dims + t.shape[no_batch_dims:]) + return t + + prepped_inputs = tensor_tree_map(_prep_inputs, inputs) + + flat_batch_dim = 1 + for d in orig_batch_dims: + flat_batch_dim *= d + + no_chunks = flat_batch_dim // chunk_size + ( + flat_batch_dim % chunk_size != 0 + ) + + i = 0 + out = None + for _ in range(no_chunks): + # Chunk the input + if(not low_mem): + select_chunk = ( + lambda t: t[i : i + chunk_size] if t.shape[0] != 1 else t + ) + else: + select_chunk = ( + partial( + _chunk_slice, + flat_start=i, + flat_end=min(flat_batch_dim, i + chunk_size), + no_batch_dims=len(orig_batch_dims) + ) + ) + + chunks = tensor_tree_map(select_chunk, prepped_inputs) + + # Run the layer on the chunk + output_chunk = layer(**chunks) + + # Allocate space for the output + if out is None: + allocate = lambda t: t.new_zeros((flat_batch_dim,) + t.shape[1:]) + out = tensor_tree_map(allocate, output_chunk) + + # Put the chunk in its pre-allocated space + out_type = type(output_chunk) + if out_type is dict: + def assign(d1, d2): + for k, v in d1.items(): + if type(v) is dict: + assign(v, d2[k]) + else: + v[i : i + chunk_size] = d2[k] + + assign(out, output_chunk) + elif out_type is tuple: + for x1, x2 in zip(out, output_chunk): + x1[i : i + chunk_size] = x2 + elif out_type is torch.Tensor: + out[i : i + chunk_size] = output_chunk + else: + raise ValueError("Not supported") + + i += chunk_size + + reshape = lambda t: t.view(orig_batch_dims + t.shape[1:]) + out = tensor_tree_map(reshape, out) + + return out diff --git a/openfold/triangular_attention.py b/openfold/triangular_attention.py new file mode 100644 index 000000000..6d3e37f4c --- /dev/null +++ b/openfold/triangular_attention.py @@ -0,0 +1,139 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
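+
+# A minimal, hypothetical sketch (not part of the original patch) of how
+# chunk_layer from openfold/tensor_utils.py is meant to be driven; "toy_layer"
+# and every tensor shape below are assumptions chosen purely for illustration.
+def _chunk_layer_usage_sketch():
+    import torch
+
+    from openfold.tensor_utils import chunk_layer
+
+    def toy_layer(q_x, kv_x):
+        # A stand-in layer: combines the two inputs with broadcasting.
+        return q_x + kv_x.mean(dim=-2, keepdim=True)
+
+    q = torch.randn(4, 128, 64, 32)   # two batch dims: (4, 128)
+    kv = torch.randn(4, 128, 96, 32)
+    out = chunk_layer(
+        toy_layer,
+        {"q_x": q, "kv_x": kv},
+        chunk_size=16,                # sub-batches of the flattened 4 * 128 batch per call
+        no_batch_dims=2,
+    )
+    assert out.shape == q.shape
+    return out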
+ +from functools import partialmethod, partial +import math +from typing import Optional, List + +import torch +import torch.nn as nn + +from openfold.primitives import Linear, LayerNorm, Attention +from openfold.tensor_utils import ( + chunk_layer, + permute_final_dims, + flatten_final_dims, +) + + +class TriangleAttention(nn.Module): + def __init__( + self, c_in, c_hidden, no_heads, starting, inf=1e9 + ): + """ + Args: + c_in: + Input channel dimension + c_hidden: + Overall hidden channel dimension (not per-head) + no_heads: + Number of attention heads + """ + super(TriangleAttention, self).__init__() + + self.c_in = c_in + self.c_hidden = c_hidden + self.no_heads = no_heads + self.starting = starting + self.inf = inf + + self.layer_norm = LayerNorm(self.c_in) + + self.linear = Linear(c_in, self.no_heads, bias=False, init="normal") + + self.mha = Attention( + self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads + ) + + @torch.jit.ignore + def _chunk(self, + x: torch.Tensor, + biases: List[torch.Tensor], + chunk_size: int, + ) -> torch.Tensor: + mha_inputs = { + "q_x": x, + "kv_x": x, + "biases": biases, + } + return chunk_layer( + partial(self.mha), + mha_inputs, + chunk_size=chunk_size, + no_batch_dims=len(x.shape[:-2]), + ) + + def forward(self, + x: torch.Tensor, + mask: Optional[torch.Tensor] = None, + chunk_size: Optional[int] = None + ) -> torch.Tensor: + """ + Args: + x: + [*, I, J, C_in] input tensor (e.g. the pair representation) + Returns: + [*, I, J, C_in] output tensor + """ + if mask is None: + # [*, I, J] + mask = x.new_ones( + x.shape[:-1], + ) + + # Shape annotations assume self.starting. Else, I and J are flipped + if not self.starting: + x = x.transpose(-2, -3) + mask = mask.transpose(-1, -2) + + # [*, I, J, C_in] + x = self.layer_norm(x) + + # [*, I, 1, 1, J] + mask_bias = (self.inf * (mask - 1))[..., :, None, None, :] + + # [*, H, I, J] + triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1)) + + # [*, 1, H, I, J] + triangle_bias = triangle_bias.unsqueeze(-4) + + biases = [mask_bias, triangle_bias] + + if chunk_size is not None: + x = self._chunk(x, biases, chunk_size) + else: + x = self.mha(q_x=x, kv_x=x, biases=biases) + + if not self.starting: + x = x.transpose(-2, -3) + + return x + + +class TriangleAttentionStartingNode(TriangleAttention): + """ + Implements Algorithm 13. + """ + + __init__ = partialmethod(TriangleAttention.__init__, starting=True) + + +class TriangleAttentionEndingNode(TriangleAttention): + """ + Implements Algorithm 14. + """ + + __init__ = partialmethod(TriangleAttention.__init__, starting=False) diff --git a/openfold/triangular_multiplicative_update.py b/openfold/triangular_multiplicative_update.py new file mode 100644 index 000000000..2406e2bac --- /dev/null +++ b/openfold/triangular_multiplicative_update.py @@ -0,0 +1,127 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
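+
+# A hypothetical usage sketch (not part of the original patch) for the two
+# TriangleAttention variants defined above in openfold/triangular_attention.py;
+# the pair-representation shape, head count and chunk size are illustrative
+# assumptions. The modules return only the attention update, so the residual
+# add is left to the caller.
+def _triangle_attention_usage_sketch():
+    import torch
+
+    from openfold.triangular_attention import (
+        TriangleAttentionEndingNode,
+        TriangleAttentionStartingNode,
+    )
+
+    z = torch.randn(1, 64, 64, 128)   # [*, I, J, C_in] pair representation
+    mask = torch.ones(1, 64, 64)      # [*, I, J]
+
+    tri_start = TriangleAttentionStartingNode(c_in=128, c_hidden=32, no_heads=4)
+    tri_end = TriangleAttentionEndingNode(c_in=128, c_hidden=32, no_heads=4)
+
+    z = z + tri_start(z, mask=mask, chunk_size=4)   # around starting node (Algorithm 13)
+    z = z + tri_end(z, mask=mask)                   # around ending node (Algorithm 14)
+    return z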
+ +from functools import partialmethod +from typing import Optional + +import torch +import torch.nn as nn + +from openfold.primitives import Linear, LayerNorm +from openfold.tensor_utils import permute_final_dims + + +class TriangleMultiplicativeUpdate(nn.Module): + """ + Implements Algorithms 11 and 12. + """ + def __init__(self, c_z, c_hidden, _outgoing=True): + """ + Args: + c_z: + Input channel dimension + c: + Hidden channel dimension + """ + super(TriangleMultiplicativeUpdate, self).__init__() + self.c_z = c_z + self.c_hidden = c_hidden + self._outgoing = _outgoing + + self.linear_a_p = Linear(self.c_z, self.c_hidden) + self.linear_a_g = Linear(self.c_z, self.c_hidden, init="gating") + self.linear_b_p = Linear(self.c_z, self.c_hidden) + self.linear_b_g = Linear(self.c_z, self.c_hidden, init="gating") + self.linear_g = Linear(self.c_z, self.c_z, init="gating") + self.linear_z = Linear(self.c_hidden, self.c_z, init="final") + + self.layer_norm_in = LayerNorm(self.c_z) + self.layer_norm_out = LayerNorm(self.c_hidden) + + self.sigmoid = nn.Sigmoid() + + def _combine_projections(self, + a: torch.Tensor, + b: torch.Tensor, + ) -> torch.Tensor: + raise NotImplementedError("This method needs to be overridden") + + def forward(self, + z: torch.Tensor, + mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """ + Args: + x: + [*, N_res, N_res, C_z] input tensor + mask: + [*, N_res, N_res] input mask + Returns: + [*, N_res, N_res, C_z] output tensor + """ + if mask is None: + mask = z.new_ones(z.shape[:-1]) + + mask = mask.unsqueeze(-1) + + z = self.layer_norm_in(z) + a = self.linear_a_p(z) * self.sigmoid(self.linear_a_g(z)) + a = a * mask + b = self.linear_b_p(z) * self.sigmoid(self.linear_b_g(z)) + b = b * mask + x = self._combine_projections(a, b) + x = self.layer_norm_out(x) + x = self.linear_z(x) + g = self.sigmoid(self.linear_g(z)) + z = x * g + + return z + + +class TriangleMultiplicationOutgoing(TriangleMultiplicativeUpdate): + """ + Implements Algorithm 11. + """ + def _combine_projections(self, + a: torch.Tensor, # [*, N_i, N_k, C] + b: torch.Tensor, # [*, N_j, N_k, C] + ): + # [*, C, N_i, N_j] + p = torch.matmul( + permute_final_dims(a, (2, 0, 1)), + permute_final_dims(b, (2, 1, 0)), + ) + + # [*, N_i, N_j, C] + return permute_final_dims(p, (1, 2, 0)) + + +class TriangleMultiplicationIncoming(TriangleMultiplicativeUpdate): + """ + Implements Algorithm 12. + """ + def _combine_projections(self, + a: torch.Tensor, # [*, N_k, N_i, C] + b: torch.Tensor, # [*, N_k, N_j, C] + ): + # [*, C, N_i, N_j] + p = torch.matmul( + permute_final_dims(a, (2, 1, 0)), + permute_final_dims(b, (2, 0, 1)), + ) + + # [*, N_i, N_j, C] + return permute_final_dims(p, (1, 2, 0)) +
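+
+# A hypothetical closing sketch (not part of the original patch) for the two
+# multiplicative updates above; the shapes are illustrative assumptions. The
+# trailing comment spells out the einsum form that the permute_final_dims /
+# matmul pattern in _combine_projections is equivalent to.
+def _triangle_multiplication_usage_sketch():
+    z = torch.randn(1, 64, 64, 128)   # [*, N_res, N_res, C_z]
+    mask = torch.ones(1, 64, 64)      # [*, N_res, N_res]
+
+    tri_out = TriangleMultiplicationOutgoing(c_z=128, c_hidden=64)
+    tri_in = TriangleMultiplicationIncoming(c_z=128, c_hidden=64)
+
+    # Both modules return only the update; the residual add is the caller's job.
+    z = z + tri_out(z, mask=mask)     # "outgoing" edges, Algorithm 11
+    z = z + tri_in(z, mask=mask)      # "incoming" edges, Algorithm 12
+
+    # Per channel c, the outgoing combination computes
+    #     p[i, j, c] = sum_k a[i, k, c] * b[j, k, c]
+    # i.e. torch.einsum("...ikc,...jkc->...ijc", a, b); the incoming variant
+    # is torch.einsum("...kic,...kjc->...ijc", a, b).
+    return z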