Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-06-21 13:11:27 +00:00

Commit fff493c202 ("init openfold")
Parent: 69af93107f
@@ -1,59 +0,0 @@
import torch
import torch.nn as nn

from .msa import MSAStack
from .ops import OutProductMean
from .triangle import PairStack


def print_memory(init_mem, text=None):
    now_mem = torch.cuda.memory_allocated() / 1024 ** 2 - init_mem
    max_mem = torch.cuda.max_memory_allocated() / 1024 ** 2 - init_mem
    print("%s now:%.2f max:%.2f" % ("" if text is None else text, now_mem, max_mem))
    torch.cuda.reset_peak_memory_stats()


class EvoformerBlock(nn.Module):

    def __init__(self, d_node, d_pair):
        super(EvoformerBlock, self).__init__()

        self.msa_stack = MSAStack(d_node, d_pair, p_drop=0.15)
        self.communication = OutProductMean(n_feat=d_node, n_feat_out=d_pair, n_feat_proj=32)
        self.pair_stack = PairStack(d_pair=d_pair)

    def forward(self, node, pair):
        node = self.msa_stack(node, pair)
        pair = pair + self.communication(node)
        pair = self.pair_stack(pair)
        return node, pair


class Evoformer(nn.Module):

    def __init__(self, d_node, d_pair):
        super(Evoformer, self).__init__()

        self.blocks = nn.ModuleList()
        for _ in range(1):
            self.blocks.append(EvoformerBlock(d_node, d_pair))

    def forward(self, node, pair):
        for b in self.blocks:
            node, pair = b(node, pair)
        return node, pair


def evoformer_tiny():
    return Evoformer(d_node=64, d_pair=32)


def evoformer_base():
    return Evoformer(d_node=256, d_pair=128)


def evoformer_large():
    return Evoformer(d_node=512, d_pair=256)


__all__ = ['Evoformer', 'evoformer_base', 'evoformer_large']
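For reference, a minimal sketch of how the removed Evoformer module was driven. The shapes are illustrative assumptions inferred from the einsums in the removed .msa/.ops/.triangle modules, not taken from the diff itself:

import torch

model = evoformer_tiny()             # d_node=64, d_pair=32 per the factory above
node = torch.randn(1, 8, 16, 64)     # assumed [batch, n_seq, n_res, d_node]
pair = torch.randn(1, 16, 16, 32)    # assumed [batch, n_res, n_res, d_pair]
node, pair = model(node, pair)       # the block preserves both shapes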
@@ -1,29 +0,0 @@
import math

import numpy as np
import torch.nn as nn


def glorot_uniform_af(x, gain=1.0):
    """
    Initialize a tensor like PyTorch's Xavier (Glorot) uniform initializer, but with a
    different dimension convention:
        In PyTorch: [feature_out, feature_in, n_head, ...]
        In Jax:     [..., n_head, feature_in, feature_out]
    The original AlphaFold 2 code additionally applies the Jax-style initializer to tensors
    shaped [feature_in, n_head, feature_out].

    This function keeps that behaviour and initializes [feature_in, n_head, ..., feature_out] tensors.
    """
    fan_in, fan_out = x.shape[-2:]
    if len(x.shape) > 2:
        receptive_field_size = np.prod(x.shape[:-2])
        fan_in *= receptive_field_size
        fan_out *= receptive_field_size
    std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
    dev = math.sqrt(3.0) * std    # Calculate uniform bounds from standard deviation

    nn.init.uniform_(x, -dev, dev)

    return x
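A small sanity sketch (illustrative tensor sizes only) showing how the fan computation above bounds the initialized values for a [feature_in, n_head, feature_out] weight:

import torch

w = torch.empty(64, 8, 32)                          # hypothetical [feature_in, n_head, feature_out]
glorot_uniform_af(w, gain=1.0)
# fan_in = 8 * 64, fan_out = 32 * 64, dev = sqrt(3) * gain * sqrt(2 / (fan_in + fan_out))
bound = (3.0 * 2.0 / (8 * 64 + 32 * 64)) ** 0.5
assert w.abs().max() <= bound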
@@ -1,19 +0,0 @@
import torch
import torch.nn.functional as F


def bias_sigmod_ele(y, bias, z):
    return torch.sigmoid(y + bias) * z


def bias_dropout_add(x: torch.Tensor, bias: torch.Tensor, dropmask: torch.Tensor,
                     residual: torch.Tensor, prob: float) -> torch.Tensor:
    out = (x + bias) * F.dropout(dropmask, p=prob, training=False)
    out = residual + out
    return out


def bias_ele_dropout_residual(ab: torch.Tensor, b: torch.Tensor, g: torch.Tensor,
                              dropout_mask: torch.Tensor, Z_raw: torch.Tensor,
                              prob: float) -> torch.Tensor:
    return Z_raw + F.dropout(dropout_mask, p=prob, training=True) * (g * (ab + b))
@@ -1,95 +0,0 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn import LayerNorm

from .kernel import bias_dropout_add
from .ops import SelfAttention, Transition


class MSARowAttentionWithPairBias(nn.Module):

    def __init__(self, d_node, d_pair, c=32, n_head=8, p_drop=0.15):
        super(MSARowAttentionWithPairBias, self).__init__()
        self.d_node = d_node
        self.d_pair = d_pair
        self.c = c
        self.n_head = n_head
        self.p_drop = p_drop

        self.layernormM = LayerNorm(d_node)
        self.layernormZ = LayerNorm(d_pair)

        _init_weights = torch.nn.init.normal_(torch.zeros([n_head, d_pair]),
                                              std=1.0 / math.sqrt(d_pair))
        self.linear_b_weights = nn.parameter.Parameter(data=_init_weights, requires_grad=True)

        self.attention = SelfAttention(qkv_dim=d_node,
                                       c=c,
                                       n_head=n_head,
                                       out_dim=d_node,
                                       gating=True,
                                       last_bias_fuse=True)

        self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_node,)), requires_grad=True)

    def forward(self, M_raw, Z):
        ## Input projections
        M = self.layernormM(M_raw)
        Z = self.layernormZ(Z)
        b = F.linear(Z, self.linear_b_weights)
        b = b.permute(0, 3, 1, 2)
        # b = rearrange(b, 'b q k h -> b h q k')

        M = self.attention(M, b)
        dropout_mask = torch.ones_like(M[:, 0:1, :, :]).to(M.device).to(M.dtype)

        return bias_dropout_add(M, self.out_bias, dropout_mask, M_raw, prob=self.p_drop)


class MSAColumnAttention(nn.Module):

    def __init__(self, d_node, c=32, n_head=8):
        super(MSAColumnAttention, self).__init__()
        self.d_node = d_node
        self.c = c
        self.n_head = n_head

        self.layernormM = LayerNorm(d_node)
        self.attention = SelfAttention(qkv_dim=d_node,
                                       c=c,
                                       n_head=n_head,
                                       out_dim=d_node,
                                       gating=True)

    def forward(self, M_raw):
        M = M_raw.transpose(-2, -3)
        M = self.layernormM(M)

        M = self.attention(M)

        M = M.transpose(-2, -3)
        return M_raw + M


class MSAStack(nn.Module):

    def __init__(self, d_node, d_pair, p_drop=0.15):
        super(MSAStack, self).__init__()

        self.MSARowAttentionWithPairBias = MSARowAttentionWithPairBias(d_node=d_node,
                                                                       d_pair=d_pair,
                                                                       p_drop=p_drop)

        self.MSAColumnAttention = MSAColumnAttention(d_node=d_node)
        self.MSATransition = Transition(d=d_node)

    def forward(self, node, pair):
        node = self.MSARowAttentionWithPairBias(node, pair)
        node = self.MSAColumnAttention(node)
        node = self.MSATransition(node)

        return node
@@ -1,176 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn import LayerNorm

from .initializer import glorot_uniform_af
from .kernel import bias_sigmod_ele


class DropoutRowwise(nn.Module):

    def __init__(self, p):
        super(DropoutRowwise, self).__init__()
        self.p = p
        self.dropout = nn.Dropout(p=p)

    def forward(self, x):
        dropout_mask = torch.ones_like(x[:, 0:1, :, :])
        dropout_mask = self.dropout(dropout_mask)
        return dropout_mask * x


class DropoutColumnwise(nn.Module):

    def __init__(self, p):
        super(DropoutColumnwise, self).__init__()
        self.p = p
        self.dropout = nn.Dropout(p=p)

    def forward(self, x):
        dropout_mask = torch.ones_like(x[:, :, 0:1, :])
        dropout_mask = self.dropout(dropout_mask)
        return dropout_mask * x


class Transition(nn.Module):

    def __init__(self, d, n=4):
        super(Transition, self).__init__()
        self.norm = LayerNorm(d)
        self.linear1 = Linear(d, n * d, initializer='relu')
        self.linear2 = Linear(n * d, d, initializer='zeros')

    def forward(self, src):
        x = self.norm(src)
        x = self.linear2(F.relu(self.linear1(x)))
        return src + x


class OutProductMean(nn.Module):

    def __init__(self, n_feat=64, n_feat_out=128, n_feat_proj=32):
        super(OutProductMean, self).__init__()

        self.layernormM = LayerNorm(n_feat)
        self.linear_a = Linear(n_feat, n_feat_proj)
        self.linear_b = Linear(n_feat, n_feat_proj)

        self.o_linear = Linear(n_feat_proj * n_feat_proj,
                               n_feat_out,
                               initializer='zero',
                               use_bias=True)

    def forward(self, M):
        M = self.layernormM(M)
        left_act = self.linear_a(M)
        right_act = self.linear_b(M)

        O = torch.einsum('bsid,bsje->bijde', left_act, right_act).contiguous()
        # O = rearrange(O, 'b i j d e -> b i j (d e)')
        O = O.reshape(O.shape[0], O.shape[1], O.shape[2], -1)
        Z = self.o_linear(O)

        return Z


class Linear(nn.Linear):
    """
    A Linear layer with built-in nonstandard initializations. Called just
    like torch.nn.Linear.

    Implements the initializers in 1.11.4, plus some additional ones found
    in the code.
    """

    def __init__(
        self,
        feature_in: int,
        feature_out: int,
        initializer: str = 'linear',
        use_bias: bool = True,
        bias_init: float = 0.,
    ):
        super(Linear, self).__init__(feature_in, feature_out, bias=use_bias)

        self.use_bias = use_bias
        if initializer == 'linear':
            glorot_uniform_af(self.weight, gain=1.0)
        elif initializer == 'relu':
            glorot_uniform_af(self.weight, gain=2.0)
        elif initializer == 'zeros':
            nn.init.zeros_(self.weight)
        if self.use_bias:
            with torch.no_grad():
                self.bias.fill_(bias_init)


class SelfAttention(nn.Module):
    """
    Multi-Head SelfAttention dealing with [batch_size1, batch_size2, len, dim] tensors
    """

    def __init__(self, qkv_dim, c, n_head, out_dim, gating=True, last_bias_fuse=False):
        super(SelfAttention, self).__init__()
        self.qkv_dim = qkv_dim
        self.c = c
        self.n_head = n_head
        self.out_dim = out_dim
        self.gating = gating
        self.last_bias_fuse = last_bias_fuse

        self.scaling = self.c**(-0.5)

        # self.to_qkv = Linear(qkv_dim, 3 * n_head * c, initializer='linear')
        self.to_q = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False)
        self.to_k = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False)
        self.to_v = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False)

        if gating:
            self.gating_bias = nn.parameter.Parameter(data=torch.ones((n_head * c,)))
            self.gating_linear = Linear(qkv_dim, n_head * c, initializer='zero', use_bias=False)

        self.o_linear = Linear(n_head * c,
                               out_dim,
                               initializer='zero',
                               use_bias=(not last_bias_fuse))

    def forward(self, in_data, nonbatched_bias=None):
        """
        :param in_data: [batch_size1, batch_size2, len_qkv, qkv_dim]
        :param bias: None or [batch_size1, batch_size2, n_head, len_q, len_kv]
        :param nonbatched_bias: None or [batch_size1, n_head, len_q, len_kv]
        """

        # qkv = self.to_qkv(in_data).chunk(3, dim=-1)
        # q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), qkv)

        q = self.to_q(in_data)
        k = self.to_k(in_data)
        v = self.to_v(in_data)

        # q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head),
        #               [q, k, v])
        q, k, v = map(lambda t: t.view(t.shape[0], t.shape[1], t.shape[2], self.n_head, -1).permute(0, 1, 3, 2, 4),
                      [q, k, v])

        q = q * self.scaling

        logits = torch.matmul(q, k.transpose(-1, -2))

        if nonbatched_bias is not None:
            logits += nonbatched_bias.unsqueeze(1)
        weights = torch.softmax(logits, dim=-1)
        # weights = softmax(logits)

        weighted_avg = torch.matmul(weights, v)
        # weighted_avg = rearrange(weighted_avg, 'b1 b2 h n d -> b1 b2 n (h d)')
        weighted_avg = weighted_avg.permute(0, 1, 3, 2, 4)
        weighted_avg = weighted_avg.reshape(weighted_avg.shape[0], weighted_avg.shape[1], weighted_avg.shape[2], -1)

        if self.gating:
            gate_values = self.gating_linear(in_data)
            weighted_avg = bias_sigmod_ele(gate_values, self.gating_bias, weighted_avg)

        output = self.o_linear(weighted_avg)
        return output
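A shape walkthrough of the SelfAttention class above, with assumed toy sizes (the bias layout follows the docstring's nonbatched_bias convention):

import torch

attn = SelfAttention(qkv_dim=64, c=8, n_head=4, out_dim=64, gating=True)
x = torch.randn(1, 4, 16, 64)        # [batch_size1, batch_size2, len, qkv_dim]
bias = torch.randn(1, 4, 16, 16)     # [batch_size1, n_head, len_q, len_kv], broadcast over batch_size2
out = attn(x, nonbatched_bias=bias)  # -> [1, 4, 16, 64]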
@@ -1,192 +0,0 @@
import math

import torch
import torch.nn as nn
from torch.nn import LayerNorm

from .kernel import bias_dropout_add, bias_ele_dropout_residual
from .ops import Linear, SelfAttention, Transition


def permute_final_dims(tensor, inds):
    zero_index = -1 * len(inds)
    first_inds = list(range(len(tensor.shape[:zero_index])))
    return tensor.permute(first_inds + [zero_index + i for i in inds])


class TriangleMultiplicationOutgoing(nn.Module):

    def __init__(self, d_pair, p_drop, c=128):
        super(TriangleMultiplicationOutgoing, self).__init__()
        self.d_pair = d_pair
        self.c = c

        self.layernorm1 = LayerNorm(d_pair)
        self.left_projection = Linear(d_pair, c)
        self.right_projection = Linear(d_pair, c)
        self.left_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)
        self.right_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)

        self.output_gate = Linear(d_pair, d_pair, initializer='zeros', bias_init=1.)
        self.layernorm2 = LayerNorm(c)
        self.output_projection = Linear(d_pair, d_pair, initializer='zeros', use_bias=False)
        self.output_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)

        self.p_drop = p_drop

    def forward(self, Z_raw):
        Z = self.layernorm1(Z_raw)
        left_proj_act = self.left_projection(Z)
        right_proj_act = self.right_projection(Z)

        left_proj_act = left_proj_act * torch.sigmoid(self.left_gate(Z))
        right_proj_act = right_proj_act * torch.sigmoid(self.right_gate(Z))

        g = torch.sigmoid(self.output_gate(Z))
        # p = torch.matmul(
        #     permute_final_dims(left_proj_act, (2, 0, 1)),
        #     permute_final_dims(right_proj_act, (2, 1, 0)),
        # )
        # ab = permute_final_dims(p, (1, 2, 0))

        ab = torch.einsum('bikd,bjkd->bijd', left_proj_act, right_proj_act)
        ab = self.output_projection(self.layernorm2(ab))
        dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype)
        return bias_ele_dropout_residual(ab,
                                         self.output_bias,
                                         g,
                                         dropout_mask,
                                         Z_raw,
                                         prob=self.p_drop)


class TriangleMultiplicationIncoming(nn.Module):

    def __init__(self, d_pair, p_drop, c=128):
        super(TriangleMultiplicationIncoming, self).__init__()
        self.d_pair = d_pair
        self.c = c

        self.layernorm1 = LayerNorm(d_pair)
        self.left_projection = Linear(d_pair, c)
        self.right_projection = Linear(d_pair, c)
        self.left_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)
        self.right_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)

        self.output_gate = Linear(d_pair, d_pair, initializer='zeros', bias_init=1.)
        self.layernorm2 = LayerNorm(c)
        self.output_projection = Linear(d_pair, d_pair, initializer='zeros', use_bias=False)
        self.output_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)

        self.p_drop = p_drop

    def forward(self, Z_raw):
        Z = self.layernorm1(Z_raw)
        left_proj_act = self.left_projection(Z)
        right_proj_act = self.right_projection(Z)

        left_proj_act = left_proj_act * torch.sigmoid(self.left_gate(Z))
        right_proj_act = right_proj_act * torch.sigmoid(self.right_gate(Z))

        g = torch.sigmoid(self.output_gate(Z))
        # p = torch.matmul(
        #     permute_final_dims(left_proj_act, (2, 1, 0)),
        #     permute_final_dims(right_proj_act, (2, 0, 1)),
        # )
        # ab = permute_final_dims(p, (1, 2, 0))

        ab = torch.einsum('bkid,bkjd->bijd', left_proj_act, right_proj_act)
        ab = self.output_projection(self.layernorm2(ab))
        dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype)
        return bias_ele_dropout_residual(ab,
                                         self.output_bias,
                                         g,
                                         dropout_mask,
                                         Z_raw,
                                         prob=self.p_drop)


class TriangleAttentionStartingNode(nn.Module):

    def __init__(self, d_pair, p_drop, c=32, n_head=4):
        super(TriangleAttentionStartingNode, self).__init__()
        self.d_pair = d_pair
        self.c = c
        self.n_head = n_head
        self.p_drop = p_drop

        self.layernorm1 = LayerNorm(d_pair)
        _init_weights = torch.nn.init.normal_(torch.zeros([d_pair, n_head]),
                                              std=1.0 / math.sqrt(d_pair))
        self.linear_b_weights = nn.parameter.Parameter(data=_init_weights)
        self.attention = SelfAttention(qkv_dim=d_pair,
                                       c=c,
                                       n_head=n_head,
                                       out_dim=d_pair,
                                       gating=True,
                                       last_bias_fuse=True)

        self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)

    def forward(self, Z_raw):
        Z = self.layernorm1(Z_raw)
        b = torch.einsum('bqkc,ch->bhqk', Z, self.linear_b_weights)

        Z = self.attention(Z, b)

        dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype)
        return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop)


class TriangleAttentionEndingNode(nn.Module):

    def __init__(self, d_pair, p_drop, c=32, n_head=4):
        super(TriangleAttentionEndingNode, self).__init__()
        self.d_pair = d_pair
        self.c = c
        self.n_head = n_head
        self.p_drop = p_drop

        self.layernorm1 = LayerNorm(d_pair)
        _init_weights = torch.nn.init.normal_(torch.zeros([d_pair, n_head]),
                                              std=1.0 / math.sqrt(d_pair))
        self.linear_b_weights = nn.parameter.Parameter(data=_init_weights)
        self.attention = SelfAttention(qkv_dim=d_pair,
                                       c=c,
                                       n_head=n_head,
                                       out_dim=d_pair,
                                       gating=True,
                                       last_bias_fuse=True)

        self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)

    def forward(self, Z_raw):
        Z = Z_raw.transpose(-2, -3)
        Z = self.layernorm1(Z)
        b = torch.einsum('bqkc,ch->bhqk', Z, self.linear_b_weights)

        Z = self.attention(Z, b)

        Z = Z.transpose(-2, -3)
        dropout_mask = torch.ones_like(Z[:, :, 0:1, :]).to(Z.device).to(Z.dtype)
        return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop)


class PairStack(nn.Module):

    def __init__(self, d_pair, p_drop=0.25):
        super(PairStack, self).__init__()

        self.TriangleMultiplicationOutgoing = TriangleMultiplicationOutgoing(d_pair, p_drop=p_drop)
        self.TriangleMultiplicationIncoming = TriangleMultiplicationIncoming(d_pair, p_drop=p_drop)
        self.TriangleAttentionStartingNode = TriangleAttentionStartingNode(d_pair, p_drop=p_drop)
        self.TriangleAttentionEndingNode = TriangleAttentionEndingNode(d_pair, p_drop=p_drop)
        self.PairTransition = Transition(d=d_pair)

    def forward(self, pair):
        pair = self.TriangleMultiplicationOutgoing(pair)
        pair = self.TriangleMultiplicationIncoming(pair)
        pair = self.TriangleAttentionStartingNode(pair)
        pair = self.TriangleAttentionEndingNode(pair)
        pair = self.PairTransition(pair)
        return pair
openfold/checkpointing.py (new file, 84 lines)
@@ -0,0 +1,84 @@
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.utils.checkpoint
from typing import Any, Tuple, List, Callable, Optional


BLOCK_ARG = Any
BLOCK_ARGS = List[BLOCK_ARG]


def get_checkpoint_fn():
    checkpoint = torch.utils.checkpoint.checkpoint

    return checkpoint


@torch.jit.ignore
def checkpoint_blocks(
    blocks: List[Callable],
    args: BLOCK_ARGS,
    blocks_per_ckpt: Optional[int],
) -> BLOCK_ARGS:
    """
    Chunk a list of blocks and run each chunk with activation
    checkpointing. We define a "block" as a callable whose only inputs are
    the outputs of the previous block.

    Implements Subsection 1.11.8

    Args:
        blocks:
            List of blocks
        args:
            Tuple of arguments for the first block.
        blocks_per_ckpt:
            Size of each chunk. A higher value corresponds to fewer
            checkpoints, and trades memory for speed. If None, no checkpointing
            is performed.
    Returns:
        The output of the final block
    """
    def wrap(a):
        return (a,) if type(a) is not tuple else a

    def exec(b, a):
        for block in b:
            a = wrap(block(*a))
        return a

    def chunker(s, e):
        def exec_sliced(*a):
            return exec(blocks[s:e], a)

        return exec_sliced

    # Avoids mishaps when the blocks take just one argument
    args = wrap(args)

    if blocks_per_ckpt is None:
        return exec(blocks, args)
    elif blocks_per_ckpt < 1 or blocks_per_ckpt > len(blocks):
        raise ValueError("blocks_per_ckpt must be between 1 and len(blocks)")

    checkpoint = get_checkpoint_fn()

    for s in range(0, len(blocks), blocks_per_ckpt):
        e = s + blocks_per_ckpt
        args = checkpoint(chunker(s, e), *args)
        args = wrap(args)

    return args
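A minimal usage sketch of checkpoint_blocks (toy blocks, not part of the diff): each callable takes only the previous block's outputs, and blocks_per_ckpt controls how many blocks share one activation checkpoint.

import torch
import torch.nn as nn

layers = [nn.Linear(16, 16) for _ in range(4)]
x = torch.randn(8, 16, requires_grad=True)
out, = checkpoint_blocks(
    [lambda a, b=b: b(a) for b in layers],   # wrap each layer as a single-output block
    args=(x,),
    blocks_per_ckpt=2,                       # two layers per checkpointed chunk
)
out.sum().backward()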
openfold/dropout.py (new file, 78 lines)
@@ -0,0 +1,78 @@
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
import torch.nn as nn
from functools import partialmethod
from typing import Union, List


class Dropout(nn.Module):
    """
    Implementation of dropout with the ability to share the dropout mask
    along a particular dimension.

    If not in training mode, this module computes the identity function.
    """

    def __init__(self, r: float, batch_dim: Union[int, List[int]]):
        """
        Args:
            r:
                Dropout rate
            batch_dim:
                Dimension(s) along which the dropout mask is shared
        """
        super(Dropout, self).__init__()

        self.r = r
        if type(batch_dim) == int:
            batch_dim = [batch_dim]
        self.batch_dim = batch_dim
        self.dropout = nn.Dropout(self.r)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x:
                Tensor to which dropout is applied. Can have any shape
                compatible with self.batch_dim
        """
        shape = list(x.shape)
        if self.batch_dim is not None:
            for bd in self.batch_dim:
                shape[bd] = 1
        mask = x.new_ones(shape)
        mask = self.dropout(mask)
        x *= mask
        return x


class DropoutRowwise(Dropout):
    """
    Convenience class for rowwise dropout as described in subsection
    1.11.6.
    """

    __init__ = partialmethod(Dropout.__init__, batch_dim=-3)


class DropoutColumnwise(Dropout):
    """
    Convenience class for columnwise dropout as described in subsection
    1.11.6.
    """

    __init__ = partialmethod(Dropout.__init__, batch_dim=-2)
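A short usage sketch of the shared-mask dropout above (illustrative shapes): with batch_dim=-3 the mask has size 1 along dim -3, so the same dropout pattern is reused for every index along that dimension.

import torch

drop = DropoutRowwise(0.25)
drop.train()
x = torch.randn(2, 8, 8, 16)   # e.g. a [*, N_res, N_res, C_z]-shaped pair activation
y = drop(x)                    # mask shape is [2, 1, 8, 16]; note x is scaled in place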
openfold/evoformer.py (new file, 636 lines)
@@ -0,0 +1,636 @@
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import torch
import torch.nn as nn
from typing import Tuple, Optional
from functools import partial

from openfold.primitives import Linear, LayerNorm
from openfold.dropout import DropoutRowwise, DropoutColumnwise
from openfold.msa import (
    MSARowAttentionWithPairBias,
    MSAColumnAttention,
    MSAColumnGlobalAttention,
)
from openfold.outer_product_mean import OuterProductMean
from openfold.pair_transition import PairTransition
from openfold.triangular_attention import (
    TriangleAttentionStartingNode,
    TriangleAttentionEndingNode,
)
from openfold.triangular_multiplicative_update import (
    TriangleMultiplicationOutgoing,
    TriangleMultiplicationIncoming,
)
from openfold.checkpointing import checkpoint_blocks, get_checkpoint_fn
from openfold.tensor_utils import chunk_layer


class MSATransition(nn.Module):
    """
    Feed-forward network applied to MSA activations after attention.

    Implements Algorithm 9
    """
    def __init__(self, c_m, n):
        """
        Args:
            c_m:
                MSA channel dimension
            n:
                Factor multiplied to c_m to obtain the hidden channel
                dimension
        """
        super(MSATransition, self).__init__()

        self.c_m = c_m
        self.n = n

        self.layer_norm = LayerNorm(self.c_m)
        self.linear_1 = Linear(self.c_m, self.n * self.c_m, init="relu")
        self.relu = nn.ReLU()
        self.linear_2 = Linear(self.n * self.c_m, self.c_m, init="final")

    def _transition(self, m, mask):
        m = self.linear_1(m)
        m = self.relu(m)
        m = self.linear_2(m) * mask
        return m

    @torch.jit.ignore
    def _chunk(self,
        m: torch.Tensor,
        mask: torch.Tensor,
        chunk_size: int,
    ) -> torch.Tensor:
        return chunk_layer(
            self._transition,
            {"m": m, "mask": mask},
            chunk_size=chunk_size,
            no_batch_dims=len(m.shape[:-2]),
        )

    def forward(
        self,
        m: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Args:
            m:
                [*, N_seq, N_res, C_m] MSA activation
            mask:
                [*, N_seq, N_res, C_m] MSA mask
        Returns:
            m:
                [*, N_seq, N_res, C_m] MSA activation update
        """

        # DISCREPANCY: DeepMind forgets to apply the MSA mask here.
        if mask is None:
            mask = m.new_ones(m.shape[:-1])

        # [*, N_seq, N_res, 1]
        mask = mask.unsqueeze(-1)

        m = self.layer_norm(m)

        if chunk_size is not None:
            m = self._chunk(m, mask, chunk_size)
        else:
            m = self._transition(m, mask)

        return m


class EvoformerBlockCore(nn.Module):
    def __init__(
        self,
        c_m: int,
        c_z: int,
        c_hidden_opm: int,
        c_hidden_mul: int,
        c_hidden_pair_att: int,
        no_heads_msa: int,
        no_heads_pair: int,
        transition_n: int,
        pair_dropout: float,
        inf: float,
        eps: float,
        _is_extra_msa_stack: bool = False,
        is_multimer: bool = False,
    ):
        super(EvoformerBlockCore, self).__init__()
        self.is_multimer = is_multimer
        self.msa_transition = MSATransition(
            c_m=c_m,
            n=transition_n,
        )

        self.outer_product_mean = OuterProductMean(
            c_m,
            c_z,
            c_hidden_opm,
        )

        self.tri_mul_out = TriangleMultiplicationOutgoing(
            c_z,
            c_hidden_mul,
        )
        self.tri_mul_in = TriangleMultiplicationIncoming(
            c_z,
            c_hidden_mul,
        )

        self.tri_att_start = TriangleAttentionStartingNode(
            c_z,
            c_hidden_pair_att,
            no_heads_pair,
            inf=inf,
        )
        self.tri_att_end = TriangleAttentionEndingNode(
            c_z,
            c_hidden_pair_att,
            no_heads_pair,
            inf=inf,
        )

        self.pair_transition = PairTransition(
            c_z,
            transition_n,
        )

        self.ps_dropout_row_layer = DropoutRowwise(pair_dropout)
        self.ps_dropout_col_layer = DropoutColumnwise(pair_dropout)

    def forward(
        self,
        m: torch.Tensor,
        z: torch.Tensor,
        msa_mask: torch.Tensor,
        pair_mask: torch.Tensor,
        chunk_size: Optional[int] = None,
        _mask_trans: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # DeepMind doesn't mask these transitions in the source, so _mask_trans
        # should be disabled to better approximate the exact activations of
        # the original.
        msa_trans_mask = msa_mask if _mask_trans else None
        pair_trans_mask = pair_mask if _mask_trans else None

        m = m + self.msa_transition(
            m, mask=msa_trans_mask, chunk_size=chunk_size
        )
        z = z + self.outer_product_mean(
            m, mask=msa_mask, chunk_size=chunk_size
        )
        z = z + self.ps_dropout_row_layer(self.tri_mul_out(z, mask=pair_mask))
        z = z + self.ps_dropout_row_layer(self.tri_mul_in(z, mask=pair_mask))
        z = z + self.ps_dropout_row_layer(
            self.tri_att_start(z, mask=pair_mask, chunk_size=chunk_size)
        )
        z = z + self.ps_dropout_col_layer(
            self.tri_att_end(z, mask=pair_mask, chunk_size=chunk_size)
        )
        z = z + self.pair_transition(
            z, mask=pair_trans_mask, chunk_size=chunk_size
        )

        return m, z


class EvoformerBlock(nn.Module):
    def __init__(self,
        c_m: int,
        c_z: int,
        c_hidden_msa_att: int,
        c_hidden_opm: int,
        c_hidden_mul: int,
        c_hidden_pair_att: int,
        no_heads_msa: int,
        no_heads_pair: int,
        transition_n: int,
        msa_dropout: float,
        pair_dropout: float,
        inf: float,
        eps: float,
        is_multimer: bool,
    ):
        super(EvoformerBlock, self).__init__()

        self.msa_att_row = MSARowAttentionWithPairBias(
            c_m=c_m,
            c_z=c_z,
            c_hidden=c_hidden_msa_att,
            no_heads=no_heads_msa,
            inf=inf,
        )

        self.msa_att_col = MSAColumnAttention(
            c_m,
            c_hidden_msa_att,
            no_heads_msa,
            inf=inf,
        )

        self.msa_dropout_layer = DropoutRowwise(msa_dropout)

        self.core = EvoformerBlockCore(
            c_m=c_m,
            c_z=c_z,
            c_hidden_opm=c_hidden_opm,
            c_hidden_mul=c_hidden_mul,
            c_hidden_pair_att=c_hidden_pair_att,
            no_heads_msa=no_heads_msa,
            no_heads_pair=no_heads_pair,
            transition_n=transition_n,
            pair_dropout=pair_dropout,
            inf=inf,
            eps=eps,
        )

        self.outer_product_mean = OuterProductMean(
            c_m,
            c_z,
            c_hidden_opm,
        )
        self.is_multimer = is_multimer

    def forward(self,
        m: torch.Tensor,
        z: torch.Tensor,
        msa_mask: torch.Tensor,
        pair_mask: torch.Tensor,
        chunk_size: Optional[int] = None,
        _mask_trans: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        m = m + self.msa_dropout_layer(
            self.msa_att_row(m, z=z, mask=msa_mask, chunk_size=chunk_size)
        )
        m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size)
        m, z = self.core(
            m,
            z,
            msa_mask=msa_mask,
            pair_mask=pair_mask,
            chunk_size=chunk_size,
            _mask_trans=_mask_trans,
        )

        return m, z


class ExtraMSABlock(nn.Module):
    """
    Almost identical to the standard EvoformerBlock, except in that the
    ExtraMSABlock uses GlobalAttention for MSA column attention and
    requires more fine-grained control over checkpointing. Separated from
    its twin to preserve the TorchScript-ability of the latter.
    """
    def __init__(self,
        c_m: int,
        c_z: int,
        c_hidden_msa_att: int,
        c_hidden_opm: int,
        c_hidden_mul: int,
        c_hidden_pair_att: int,
        no_heads_msa: int,
        no_heads_pair: int,
        transition_n: int,
        msa_dropout: float,
        pair_dropout: float,
        inf: float,
        eps: float,
        ckpt: bool,
        is_multimer: bool,
    ):
        super(ExtraMSABlock, self).__init__()

        self.ckpt = ckpt

        self.msa_att_row = MSARowAttentionWithPairBias(
            c_m=c_m,
            c_z=c_z,
            c_hidden=c_hidden_msa_att,
            no_heads=no_heads_msa,
            inf=inf,
        )

        self.msa_att_col = MSAColumnGlobalAttention(
            c_in=c_m,
            c_hidden=c_hidden_msa_att,
            no_heads=no_heads_msa,
            inf=inf,
            eps=eps,
        )

        self.msa_dropout_layer = DropoutRowwise(msa_dropout)

        self.core = EvoformerBlockCore(
            c_m=c_m,
            c_z=c_z,
            c_hidden_opm=c_hidden_opm,
            c_hidden_mul=c_hidden_mul,
            c_hidden_pair_att=c_hidden_pair_att,
            no_heads_msa=no_heads_msa,
            no_heads_pair=no_heads_pair,
            transition_n=transition_n,
            pair_dropout=pair_dropout,
            inf=inf,
            eps=eps,
        )
        self.is_multimer = is_multimer

    def forward(self,
        m: torch.Tensor,
        z: torch.Tensor,
        msa_mask: torch.Tensor,
        pair_mask: torch.Tensor,
        chunk_size: Optional[int] = None,
        _chunk_logits: Optional[int] = 1024,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        m = m + self.msa_dropout_layer(
            self.msa_att_row(
                m.clone(),
                z=z.clone(),
                mask=msa_mask,
                chunk_size=chunk_size,
                _chunk_logits=_chunk_logits if torch.is_grad_enabled() else None,
                _checkpoint_chunks=self.ckpt if torch.is_grad_enabled() else False,
            )
        )

        def fn(m, z):
            m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size)
            m, z = self.core(
                m, z, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size
            )

            return m, z

        if(torch.is_grad_enabled() and self.ckpt):
            checkpoint_fn = get_checkpoint_fn()
            m, z = checkpoint_fn(fn, m, z)
        else:
            m, z = fn(m, z)

        return m, z


class EvoformerStack(nn.Module):
    """
    Main Evoformer trunk.

    Implements Algorithm 6.
    """

    def __init__(
        self,
        c_m: int,
        c_z: int,
        c_hidden_msa_att: int,
        c_hidden_opm: int,
        c_hidden_mul: int,
        c_hidden_pair_att: int,
        c_s: int,
        no_heads_msa: int,
        no_heads_pair: int,
        no_blocks: int,
        transition_n: int,
        msa_dropout: float,
        pair_dropout: float,
        blocks_per_ckpt: int,
        inf: float,
        eps: float,
        clear_cache_between_blocks: bool = False,
        is_multimer: bool = False,
        **kwargs,
    ):
        """
        Args:
            c_m:
                MSA channel dimension
            c_z:
                Pair channel dimension
            c_hidden_msa_att:
                Hidden dimension in MSA attention
            c_hidden_opm:
                Hidden dimension in outer product mean module
            c_hidden_mul:
                Hidden dimension in multiplicative updates
            c_hidden_pair_att:
                Hidden dimension in triangular attention
            c_s:
                Channel dimension of the output "single" embedding
            no_heads_msa:
                Number of heads used for MSA attention
            no_heads_pair:
                Number of heads used for pair attention
            no_blocks:
                Number of Evoformer blocks in the stack
            transition_n:
                Factor by which to multiply c_m to obtain the MSATransition
                hidden dimension
            msa_dropout:
                Dropout rate for MSA activations
            pair_dropout:
                Dropout used for pair activations
            blocks_per_ckpt:
                Number of Evoformer blocks in each activation checkpoint
            clear_cache_between_blocks:
                Whether to clear CUDA's GPU memory cache between blocks of the
                stack. Slows down each block but can reduce fragmentation
        """
        super(EvoformerStack, self).__init__()

        self.blocks_per_ckpt = blocks_per_ckpt
        self.clear_cache_between_blocks = clear_cache_between_blocks

        self.blocks = nn.ModuleList()

        for _ in range(no_blocks):
            block = EvoformerBlock(
                c_m=c_m,
                c_z=c_z,
                c_hidden_msa_att=c_hidden_msa_att,
                c_hidden_opm=c_hidden_opm,
                c_hidden_mul=c_hidden_mul,
                c_hidden_pair_att=c_hidden_pair_att,
                no_heads_msa=no_heads_msa,
                no_heads_pair=no_heads_pair,
                transition_n=transition_n,
                msa_dropout=msa_dropout,
                pair_dropout=pair_dropout,
                inf=inf,
                eps=eps,
                is_multimer=is_multimer,
            )
            self.blocks.append(block)

        self.linear = Linear(c_m, c_s)

    def forward(self,
        m: torch.Tensor,
        z: torch.Tensor,
        msa_mask: torch.Tensor,
        pair_mask: torch.Tensor,
        chunk_size: int,
        _mask_trans: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """
        Args:
            m:
                [*, N_seq, N_res, C_m] MSA embedding
            z:
                [*, N_res, N_res, C_z] pair embedding
            msa_mask:
                [*, N_seq, N_res] MSA mask
            pair_mask:
                [*, N_res, N_res] pair mask
        Returns:
            m:
                [*, N_seq, N_res, C_m] MSA embedding
            z:
                [*, N_res, N_res, C_z] pair embedding
            s:
                [*, N_res, C_s] single embedding (or None if extra MSA stack)
        """
        blocks = [
            partial(
                b,
                msa_mask=msa_mask,
                pair_mask=pair_mask,
                chunk_size=chunk_size,
                _mask_trans=_mask_trans,
            )
            for b in self.blocks
        ]

        if(self.clear_cache_between_blocks):
            def block_with_cache_clear(block, *args):
                torch.cuda.empty_cache()
                return block(*args)

            blocks = [partial(block_with_cache_clear, b) for b in blocks]

        m, z = checkpoint_blocks(
            blocks,
            args=(m, z),
            blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
        )

        s = self.linear(m[..., 0, :, :])

        return m, z, s


class ExtraMSAStack(nn.Module):
    """
    Implements Algorithm 18.
    """

    def __init__(self,
        c_m: int,
        c_z: int,
        c_hidden_msa_att: int,
        c_hidden_opm: int,
        c_hidden_mul: int,
        c_hidden_pair_att: int,
        no_heads_msa: int,
        no_heads_pair: int,
        no_blocks: int,
        transition_n: int,
        msa_dropout: float,
        pair_dropout: float,
        inf: float,
        eps: float,
        ckpt: bool,
        clear_cache_between_blocks: bool = False,
        is_multimer: bool = False,
        **kwargs,
    ):
        super(ExtraMSAStack, self).__init__()

        self.clear_cache_between_blocks = clear_cache_between_blocks
        self.blocks = nn.ModuleList()
        for _ in range(no_blocks):
            block = ExtraMSABlock(
                c_m=c_m,
                c_z=c_z,
                c_hidden_msa_att=c_hidden_msa_att,
                c_hidden_opm=c_hidden_opm,
                c_hidden_mul=c_hidden_mul,
                c_hidden_pair_att=c_hidden_pair_att,
                no_heads_msa=no_heads_msa,
                no_heads_pair=no_heads_pair,
                transition_n=transition_n,
                msa_dropout=msa_dropout,
                pair_dropout=pair_dropout,
                inf=inf,
                eps=eps,
                ckpt=ckpt,
                is_multimer=is_multimer,
            )
            self.blocks.append(block)

    def forward(self,
        m: torch.Tensor,
        z: torch.Tensor,
        chunk_size: int,
        msa_mask: Optional[torch.Tensor] = None,
        pair_mask: Optional[torch.Tensor] = None,
        _mask_trans: bool = True,
    ) -> torch.Tensor:
        """
        Args:
            m:
                [*, N_extra, N_res, C_m] extra MSA embedding
            z:
                [*, N_res, N_res, C_z] pair embedding
            msa_mask:
                Optional [*, N_extra, N_res] MSA mask
            pair_mask:
                Optional [*, N_res, N_res] pair mask
        Returns:
            [*, N_res, N_res, C_z] pair update
        """
        #checkpoint_fn = get_checkpoint_fn()
        #blocks = [
        #    partial(b, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size, _chunk_logits=None) for b in self.blocks
        #]

        #def dodo(b, *args):
        #    torch.cuda.empty_cache()
        #    return b(*args)

        #blocks = [partial(dodo, b) for b in blocks]

        #for b in blocks:
        #    if(torch.is_grad_enabled()):
        #        m, z = checkpoint_fn(b, *(m, z))
        #    else:
        #        m, z = b(m, z)

        for b in self.blocks:
            m, z = b(m, z, msa_mask, pair_mask, chunk_size=chunk_size)

            if(self.clear_cache_between_blocks):
                torch.cuda.empty_cache()

        return z
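A minimal construction sketch of the EvoformerStack defined above. The hyperparameters are invented toy values for illustration, not the released AlphaFold configuration, and running it assumes the other openfold modules in this commit are importable:

import torch

stack = EvoformerStack(
    c_m=32, c_z=16, c_hidden_msa_att=8, c_hidden_opm=8, c_hidden_mul=16,
    c_hidden_pair_att=8, c_s=48, no_heads_msa=4, no_heads_pair=2,
    no_blocks=2, transition_n=2, msa_dropout=0.15, pair_dropout=0.25,
    blocks_per_ckpt=1, inf=1e9, eps=1e-8,
)
m = torch.randn(1, 5, 10, 32)      # [*, N_seq, N_res, C_m]
z = torch.randn(1, 10, 10, 16)     # [*, N_res, N_res, C_z]
msa_mask = torch.ones(1, 5, 10)
pair_mask = torch.ones(1, 10, 10)
m, z, s = stack(m, z, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=None)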
openfold/msa.py (new file, 392 lines)
@@ -0,0 +1,392 @@
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import torch
import torch.nn as nn
from typing import Optional, List, Tuple

from openfold.primitives import (
    Linear,
    LayerNorm,
    Attention,
    GlobalAttention,
    _attention_chunked_trainable,
)
from openfold.checkpointing import get_checkpoint_fn
from openfold.tensor_utils import (
    chunk_layer,
    permute_final_dims,
    flatten_final_dims,
)


class MSAAttention(nn.Module):
    def __init__(
        self,
        c_in,
        c_hidden,
        no_heads,
        pair_bias=False,
        c_z=None,
        inf=1e9,
    ):
        """
        Args:
            c_in:
                Input channel dimension
            c_hidden:
                Per-head hidden channel dimension
            no_heads:
                Number of attention heads
            pair_bias:
                Whether to use pair embedding bias
            c_z:
                Pair embedding channel dimension. Ignored unless pair_bias
                is true
            inf:
                A large number to be used in computing the attention mask
        """
        super(MSAAttention, self).__init__()

        self.c_in = c_in
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.pair_bias = pair_bias
        self.c_z = c_z
        self.inf = inf

        self.layer_norm_m = LayerNorm(self.c_in)

        self.layer_norm_z = None
        self.linear_z = None
        if self.pair_bias:
            self.layer_norm_z = LayerNorm(self.c_z)
            self.linear_z = Linear(
                self.c_z, self.no_heads, bias=False, init="normal"
            )

        self.mha = Attention(
            self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
        )

    @torch.jit.ignore
    def _chunk(self,
        m: torch.Tensor,
        biases: List[torch.Tensor],
        chunk_size: int,
    ) -> torch.Tensor:
        return chunk_layer(
            self.mha,
            {"q_x": m, "kv_x": m, "biases": biases},
            chunk_size=chunk_size,
            no_batch_dims=len(m.shape[:-2]),
        )

    def _prep_inputs(self,
        m: torch.Tensor,
        z: Optional[torch.Tensor],
        mask: Optional[torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # [*, N_seq, N_res, C_m]
        m = self.layer_norm_m(m)

        n_seq, n_res = m.shape[-3:-1]
        if mask is None:
            # [*, N_seq, N_res]
            mask = m.new_ones(
                m.shape[:-3] + (n_seq, n_res),
            )

        # [*, N_seq, 1, 1, N_res]
        mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]

        # This step simply returns a larger view of the bias, and does not
        # consume additional memory.
        # [*, N_seq, no_heads, N_res, N_res]
        #bias = bias.expand(
        #    ((-1,) * len(bias.shape[:-4])) + (-1, self.no_heads, n_res, -1)
        #)

        if (self.pair_bias and
            z is not None and                       # For the
            self.layer_norm_z is not None and       # benefit of
            self.linear_z is not None               # TorchScript
        ):
            # [*, N_res, N_res, C_z]
            z = self.layer_norm_z(z)

            # [*, N_res, N_res, no_heads]
            z = self.linear_z(z)

            # [*, 1, no_heads, N_res, N_res]
            z = permute_final_dims(z, (2, 0, 1)).unsqueeze(-4)

        return m, mask_bias, z

    @torch.jit.ignore
    def _chunked_msa_attn(self,
        m: torch.Tensor,
        z: Optional[torch.Tensor],
        mask: Optional[torch.Tensor],
        chunk_logits: int,
        checkpoint: bool,
    ) -> torch.Tensor:
        MSA_DIM = -4

        def _get_qkv(m, z):
            m, mask_bias, z = self._prep_inputs(m, z, mask)
            q, k, v = self.mha._prep_qkv(m, m)
            return m, q, k, v, mask_bias, z

        checkpoint_fn = get_checkpoint_fn()

        if(torch.is_grad_enabled() and checkpoint):
            m, q, k, v, mask_bias, z = checkpoint_fn(_get_qkv, m, z)
        else:
            m, q, k, v, mask_bias, z = _get_qkv(m, z)

        o = _attention_chunked_trainable(
            query=q,
            key=k,
            value=v,
            biases=[mask_bias, z],
            chunk_size=chunk_logits,
            chunk_dim=MSA_DIM,
            checkpoint=checkpoint,
        )

        if(torch.is_grad_enabled() and checkpoint):
            # Storing an additional m here is far from ideal
            m = checkpoint_fn(self.mha._wrap_up, o, m)
        else:
            m = self.mha._wrap_up(o, m)

        return m

    def forward(self,
        m: torch.Tensor,
        z: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None,
        _chunk_logits: Optional[int] = None,
        _checkpoint_chunks: Optional[bool] = None,
    ) -> torch.Tensor:
        """
        Args:
            m:
                [*, N_seq, N_res, C_m] MSA embedding
            z:
                [*, N_res, N_res, C_z] pair embedding. Required only if
                pair_bias is True
            mask:
                [*, N_seq, N_res] MSA mask
            chunk_size:
                Size of chunks into which the inputs are split along their
                batch dimensions. A low value decreases memory overhead at the
                cost of slower execution. Chunking is not performed by default.

        """
        if(_chunk_logits is not None):
            return self._chunked_msa_attn(
                m=m, z=z, mask=mask,
                chunk_logits=_chunk_logits, checkpoint=_checkpoint_chunks
            )

        m, mask_bias, z = self._prep_inputs(m, z, mask)

        biases = [mask_bias]
        if(z is not None):
            biases.append(z)

        if chunk_size is not None:
            m = self._chunk(m, biases, chunk_size)
        else:
            m = self.mha(
                q_x=m,
                kv_x=m,
                biases=biases
            )

        return m


class MSARowAttentionWithPairBias(MSAAttention):
    """
    Implements Algorithm 7.
    """

    def __init__(self, c_m, c_z, c_hidden, no_heads, inf=1e9):
        """
        Args:
            c_m:
                Input channel dimension
            c_z:
                Pair embedding channel dimension
            c_hidden:
                Per-head hidden channel dimension
            no_heads:
                Number of attention heads
            inf:
                Large number used to construct attention masks
        """
        super(MSARowAttentionWithPairBias, self).__init__(
            c_m,
            c_hidden,
            no_heads,
            pair_bias=True,
            c_z=c_z,
            inf=inf,
        )


class MSAColumnAttention(nn.Module):
    """
    Implements Algorithm 8.

    By rights, this should also be a subclass of MSAAttention. Alas,
    most inheritance isn't supported by TorchScript.
    """

    def __init__(self, c_m, c_hidden, no_heads, inf=1e9):
        """
        Args:
            c_m:
                MSA channel dimension
            c_hidden:
                Per-head hidden channel dimension
            no_heads:
                Number of attention heads
            inf:
                Large number used to construct attention masks
        """
        super(MSAColumnAttention, self).__init__()

        self.c_m = c_m
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.inf = inf

        self._msa_att = MSAAttention(
            c_in=c_m,
            c_hidden=c_hidden,
            no_heads=no_heads,
            pair_bias=False,
            c_z=None,
            inf=inf,
        )

    def forward(self,
        m: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None
    ) -> torch.Tensor:
        """
        Args:
            m:
                [*, N_seq, N_res, C_m] MSA embedding
            mask:
                [*, N_seq, N_res] MSA mask
            chunk_size:
                Size of chunks into which the inputs are split along their
                batch dimensions. A low value decreases memory overhead at the
                cost of slower execution. Chunking is not performed by default.
        """
        # [*, N_res, N_seq, C_in]
        m = m.transpose(-2, -3)
        if mask is not None:
            mask = mask.transpose(-1, -2)

        m = self._msa_att(m, mask=mask, chunk_size=chunk_size)

        # [*, N_seq, N_res, C_in]
        m = m.transpose(-2, -3)
        if mask is not None:
            mask = mask.transpose(-1, -2)

        return m


class MSAColumnGlobalAttention(nn.Module):
    def __init__(
        self, c_in, c_hidden, no_heads, inf=1e9, eps=1e-10,
    ):
        super(MSAColumnGlobalAttention, self).__init__()

        self.c_in = c_in
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.inf = inf
        self.eps = eps

        self.layer_norm_m = nn.LayerNorm(c_in)

        self.global_attention = GlobalAttention(
            c_in=c_in,
            c_hidden=c_hidden,
            no_heads=no_heads,
            inf=inf,
            eps=eps,
        )

    @torch.jit.ignore
    def _chunk(self,
        m: torch.Tensor,
        mask: torch.Tensor,
        chunk_size: int,
    ) -> torch.Tensor:
        mha_input = {
            "m": m,
            "mask": mask,
        }
        return chunk_layer(
            self.global_attention,
            mha_input,
            chunk_size=chunk_size,
            no_batch_dims=len(m.shape[:-2]),
        )

    def forward(
        self,
        m: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None,
    ) -> torch.Tensor:
        n_seq, n_res, c_in = m.shape[-3:]

        if mask is None:
            # [*, N_seq, N_res]
            mask = torch.ones(
                m.shape[:-1],
                dtype=m.dtype,
                device=m.device,
            ).detach()

        # [*, N_res, N_seq, C_in]
        m = m.transpose(-2, -3)
        mask = mask.transpose(-1, -2)

        # [*, N_res, N_seq, C_in]
        m = self.layer_norm_m(m)

        if chunk_size is not None:
            m = self._chunk(m, mask, chunk_size)
        else:
            m = self.global_attention(m=m, mask=mask)

        # [*, N_seq, N_res, C_in]
        m = m.transpose(-2, -3)

        return m
129
openfold/outer_product_mean.py
Normal file
129
openfold/outer_product_mean.py
Normal file
@ -0,0 +1,129 @@
|
||||
# Copyright 2021 AlQuraishi Laboratory
|
||||
# Copyright 2021 DeepMind Technologies Limited
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from functools import partial
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from openfold.primitives import Linear
|
||||
from openfold.tensor_utils import chunk_layer
|
||||
|
||||
|
||||
class OuterProductMean(nn.Module):
|
||||
"""
|
||||
Implements Algorithm 10.
|
||||
"""
|
||||
|
||||
def __init__(self, c_m, c_z, c_hidden, eps=1e-3):
|
||||
"""
|
||||
Args:
|
||||
c_m:
|
||||
MSA embedding channel dimension
|
||||
c_z:
|
||||
Pair embedding channel dimension
|
||||
c_hidden:
|
||||
Hidden channel dimension
|
||||
"""
|
||||
super(OuterProductMean, self).__init__()
|
||||
|
||||
self.c_m = c_m
|
||||
self.c_z = c_z
|
||||
self.c_hidden = c_hidden
|
||||
self.eps = eps
|
||||
|
||||
self.layer_norm = nn.LayerNorm(c_m)
|
||||
self.linear_1 = Linear(c_m, c_hidden)
|
||||
self.linear_2 = Linear(c_m, c_hidden)
|
||||
self.linear_out = Linear(c_hidden ** 2, c_z, init="final")
|
||||
|
||||
def _opm(self, a, b):
|
||||
# [*, N_res, N_res, C, C]
|
||||
outer = torch.einsum("...bac,...dae->...bdce", a, b)
|
||||
|
||||
# [*, N_res, N_res, C * C]
|
||||
outer = outer.reshape(outer.shape[:-2] + (-1,))
|
||||
|
||||
# [*, N_res, N_res, C_z]
|
||||
outer = self.linear_out(outer)
|
||||
|
||||
return outer
|
||||
|
||||
@torch.jit.ignore
|
||||
def _chunk(self,
|
||||
a: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
chunk_size: int
|
||||
) -> torch.Tensor:
|
||||
# Since the "batch dim" in this case is not a true batch dimension
|
||||
# (in that the shape of the output depends on it), we need to
|
||||
# iterate over it ourselves
|
||||
a_reshape = a.reshape((-1,) + a.shape[-3:])
|
||||
b_reshape = b.reshape((-1,) + b.shape[-3:])
|
||||
out = []
|
||||
for a_prime, b_prime in zip(a_reshape, b_reshape):
|
||||
outer = chunk_layer(
|
||||
partial(self._opm, b=b_prime),
|
||||
{"a": a_prime},
|
||||
chunk_size=chunk_size,
|
||||
no_batch_dims=1,
|
||||
)
|
||||
out.append(outer)
|
||||
outer = torch.stack(out, dim=0)
|
||||
outer = outer.reshape(a.shape[:-3] + outer.shape[1:])
|
||||
|
||||
return outer
|
||||
|
||||
def forward(self,
|
||||
m: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
chunk_size: Optional[int] = None
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
m:
|
||||
[*, N_seq, N_res, C_m] MSA embedding
|
||||
mask:
|
||||
[*, N_seq, N_res] MSA mask
|
||||
Returns:
|
||||
[*, N_res, N_res, C_z] pair embedding update
|
||||
"""
|
||||
if mask is None:
|
||||
mask = m.new_ones(m.shape[:-1])
|
||||
|
||||
# [*, N_seq, N_res, C_m]
|
||||
m = self.layer_norm(m)
|
||||
|
||||
# [*, N_seq, N_res, C]
|
||||
mask = mask.unsqueeze(-1)
|
||||
a = self.linear_1(m) * mask
|
||||
b = self.linear_2(m) * mask
|
||||
|
||||
a = a.transpose(-2, -3)
|
||||
b = b.transpose(-2, -3)
|
||||
|
||||
if chunk_size is not None:
|
||||
outer = self._chunk(a, b, chunk_size)
|
||||
else:
|
||||
outer = self._opm(a, b)
|
||||
|
||||
# [*, N_res, N_res, 1]
|
||||
norm = torch.einsum("...abc,...adc->...bdc", mask, mask)
|
||||
|
||||
# [*, N_res, N_res, C_z]
|
||||
outer = outer / (self.eps + norm)
|
||||
|
||||
return outer
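A usage sketch of OuterProductMean with illustrative shapes (the import path is assumed from the file location above): the mean outer product over the sequence dimension is projected down to the pair channel dimension, and chunk_size bounds how many residue rows are processed at once.

import torch
from openfold.outer_product_mean import OuterProductMean  # assumed import path

opm = OuterProductMean(c_m=256, c_z=128, c_hidden=32)

m = torch.randn(1, 8, 64, 256)   # [*, N_seq, N_res, C_m] MSA embedding
mask = torch.ones(1, 8, 64)      # [*, N_seq, N_res]

# chunk_size trades peak memory for extra iterations; None disables chunking.
z_update = opm(m, mask=mask, chunk_size=16)
print(z_update.shape)            # torch.Size([1, 64, 64, 128])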
|
openfold/pair_transition.py (new file, 99 lines)
@ -0,0 +1,99 @@
|
||||
# Copyright 2021 AlQuraishi Laboratory
|
||||
# Copyright 2021 DeepMind Technologies Limited
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from openfold.primitives import Linear, LayerNorm
|
||||
from openfold.tensor_utils import chunk_layer
|
||||
|
||||
|
||||
class PairTransition(nn.Module):
|
||||
"""
|
||||
Implements Algorithm 15.
|
||||
"""
|
||||
|
||||
def __init__(self, c_z, n):
|
||||
"""
|
||||
Args:
|
||||
c_z:
|
||||
Pair transition channel dimension
|
||||
n:
|
||||
Factor by which c_z is multiplied to obtain hidden channel
|
||||
dimension
|
||||
"""
|
||||
super(PairTransition, self).__init__()
|
||||
|
||||
self.c_z = c_z
|
||||
self.n = n
|
||||
|
||||
self.layer_norm = LayerNorm(self.c_z)
|
||||
self.linear_1 = Linear(self.c_z, self.n * self.c_z, init="relu")
|
||||
self.relu = nn.ReLU()
|
||||
self.linear_2 = Linear(self.n * self.c_z, c_z, init="final")
|
||||
|
||||
def _transition(self, z, mask):
|
||||
# [*, N_res, N_res, C_hidden]
|
||||
z = self.linear_1(z)
|
||||
z = self.relu(z)
|
||||
|
||||
# [*, N_res, N_res, C_z]
|
||||
z = self.linear_2(z) * mask
|
||||
|
||||
return z
|
||||
|
||||
@torch.jit.ignore
|
||||
def _chunk(self,
|
||||
z: torch.Tensor,
|
||||
mask: torch.Tensor,
|
||||
chunk_size: int,
|
||||
) -> torch.Tensor:
|
||||
return chunk_layer(
|
||||
self._transition,
|
||||
{"z": z, "mask": mask},
|
||||
chunk_size=chunk_size,
|
||||
no_batch_dims=len(z.shape[:-2]),
|
||||
)
|
||||
|
||||
|
||||
def forward(self,
|
||||
z: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
chunk_size: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
z:
|
||||
[*, N_res, N_res, C_z] pair embedding
|
||||
Returns:
|
||||
[*, N_res, N_res, C_z] pair embedding update
|
||||
"""
|
||||
# DISCREPANCY: DeepMind forgets to apply the mask in this module.
|
||||
if mask is None:
|
||||
mask = z.new_ones(z.shape[:-1])
|
||||
|
||||
# [*, N_res, N_res, 1]
|
||||
mask = mask.unsqueeze(-1)
|
||||
|
||||
# [*, N_res, N_res, C_z]
|
||||
z = self.layer_norm(z)
|
||||
|
||||
if chunk_size is not None:
|
||||
z = self._chunk(z, mask, chunk_size)
|
||||
else:
|
||||
z = self._transition(z=z, mask=mask)
|
||||
|
||||
return z
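A usage sketch of PairTransition under the same import assumption. The transition is a position-wise two-layer MLP whose hidden width is n * c_z, so chunking over the leading pair dimensions caps the size of the intermediate activation.

import torch
from openfold.pair_transition import PairTransition  # assumed import path

pt = PairTransition(c_z=128, n=4)   # hidden width is 4 * 128 = 512

z = torch.randn(1, 64, 64, 128)     # [*, N_res, N_res, C_z] pair embedding
mask = torch.ones(1, 64, 64)

z_update = pt(z, mask=mask, chunk_size=16)
print(z_update.shape)               # torch.Size([1, 64, 64, 128])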
|
openfold/primitives.py (new file, 529 lines)
@ -0,0 +1,529 @@
|
||||
# Copyright 2021 AlQuraishi Laboratory
|
||||
# Copyright 2021 DeepMind Technologies Limited
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from functools import partial
|
||||
import math
|
||||
from typing import Optional, Callable, List, Tuple, Sequence
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from openfold.checkpointing import get_checkpoint_fn
|
||||
from openfold.tensor_utils import (
|
||||
permute_final_dims,
|
||||
flatten_final_dims,
|
||||
_chunk_slice,
|
||||
)
|
||||
|
||||
|
||||
def _prod(nums):
|
||||
out = 1
|
||||
for n in nums:
|
||||
out = out * n
|
||||
return out
|
||||
|
||||
|
||||
def _calculate_fan(linear_weight_shape, fan="fan_in"):
|
||||
fan_out, fan_in = linear_weight_shape
|
||||
|
||||
if fan == "fan_in":
|
||||
f = fan_in
|
||||
elif fan == "fan_out":
|
||||
f = fan_out
|
||||
elif fan == "fan_avg":
|
||||
f = (fan_in + fan_out) / 2
|
||||
else:
|
||||
raise ValueError("Invalid fan option")
|
||||
|
||||
return f
|
||||
|
||||
|
||||
def glorot_uniform_init_(weights):
|
||||
nn.init.xavier_uniform_(weights, gain=1)
|
||||
|
||||
|
||||
def final_init_(weights):
|
||||
with torch.no_grad():
|
||||
weights.fill_(0.0)
|
||||
|
||||
|
||||
def gating_init_(weights):
|
||||
with torch.no_grad():
|
||||
weights.fill_(0.0)
|
||||
|
||||
|
||||
def normal_init_(weights):
|
||||
torch.nn.init.kaiming_normal_(weights, nonlinearity="linear")
|
||||
|
||||
|
||||
def ipa_point_weights_init_(weights):
|
||||
with torch.no_grad():
|
||||
softplus_inverse_1 = 0.541324854612918
|
||||
weights.fill_(softplus_inverse_1)
|
||||
|
||||
|
||||
class Linear(nn.Linear):
|
||||
"""
|
||||
A Linear layer with built-in nonstandard initializations. Called just
|
||||
like torch.nn.Linear.
|
||||
|
||||
Implements the initializers in 1.11.4, plus some additional ones found
|
||||
in the code.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_dim: int,
|
||||
out_dim: int,
|
||||
bias: bool = True,
|
||||
init: str = "default",
|
||||
init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
in_dim:
|
||||
The final dimension of inputs to the layer
|
||||
out_dim:
|
||||
The final dimension of layer outputs
|
||||
bias:
|
||||
Whether to learn an additive bias. True by default
|
||||
init:
|
||||
The initializer to use. Choose from:
|
||||
|
||||
"default": LeCun fan-in truncated normal initialization
|
||||
"relu": He initialization w/ truncated normal distribution
|
||||
"glorot": Fan-average Glorot uniform initialization
|
||||
"gating": Weights=0, Bias=1
|
||||
"normal": Normal initialization with std=1/sqrt(fan_in)
|
||||
"final": Weights=0, Bias=0
|
||||
|
||||
Overridden by init_fn if the latter is not None.
|
||||
init_fn:
|
||||
A custom initializer taking weight and bias as inputs.
|
||||
Overrides init if not None.
|
||||
"""
|
||||
super(Linear, self).__init__(in_dim, out_dim, bias=bias)
|
||||
|
||||
if bias:
|
||||
with torch.no_grad():
|
||||
self.bias.fill_(0)
|
||||
|
||||
if init_fn is not None:
|
||||
init_fn(self.weight, self.bias)
|
||||
else:
|
||||
if init == "default":
|
||||
normal_init_(self.weight)
|
||||
elif init == "relu":
|
||||
normal_init_(self.weight)
|
||||
elif init == "glorot":
|
||||
glorot_uniform_init_(self.weight)
|
||||
elif init == "gating":
|
||||
gating_init_(self.weight)
|
||||
if bias:
|
||||
with torch.no_grad():
|
||||
self.bias.fill_(1.0)
|
||||
elif init == "normal":
|
||||
normal_init_(self.weight)
|
||||
elif init == "final":
|
||||
final_init_(self.weight)
|
||||
else:
|
||||
raise ValueError("Invalid init string.")
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
|
||||
def __init__(self, c_in, eps=1e-5):
|
||||
super(LayerNorm, self).__init__()
|
||||
|
||||
self.c_in = (c_in,)
|
||||
self.eps = eps
|
||||
|
||||
self.weight = nn.Parameter(torch.ones(c_in))
|
||||
self.bias = nn.Parameter(torch.zeros(c_in))
|
||||
|
||||
def forward(self, x):
|
||||
out = nn.functional.layer_norm(
|
||||
x,
|
||||
self.c_in,
|
||||
self.weight,
|
||||
self.bias,
|
||||
self.eps,
|
||||
)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@torch.jit.ignore
|
||||
def softmax(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
||||
"""
|
||||
Softmax, but without automatic casting to fp32 when the input is of
|
||||
type bfloat16
|
||||
"""
|
||||
s = torch.nn.functional.softmax(t, dim=dim)
|
||||
|
||||
return s
|
||||
|
||||
|
||||
#@torch.jit.script
|
||||
def _attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
|
||||
biases: List[torch.Tensor]) -> torch.Tensor:
|
||||
# [*, H, Q, C_hidden]
|
||||
query = permute_final_dims(query, (1, 0, 2))
|
||||
|
||||
# [*, H, C_hidden, K]
|
||||
key = permute_final_dims(key, (1, 2, 0))
|
||||
|
||||
# [*, H, V, C_hidden]
|
||||
value = permute_final_dims(value, (1, 0, 2))
|
||||
|
||||
# [*, H, Q, K]
|
||||
a = torch.matmul(query, key)
|
||||
|
||||
for b in biases:
|
||||
a += b
|
||||
|
||||
a = softmax(a, -1)
|
||||
|
||||
# [*, H, Q, C_hidden]
|
||||
a = torch.matmul(a, value)
|
||||
|
||||
# [*, Q, H, C_hidden]
|
||||
a = a.transpose(-2, -3)
|
||||
|
||||
return a
|
||||
|
||||
|
||||
@torch.jit.ignore
|
||||
def _attention_chunked_trainable(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
biases,
|
||||
chunk_size,
|
||||
chunk_dim,
|
||||
checkpoint,
|
||||
):
|
||||
if (checkpoint and len(biases) > 2):
|
||||
raise ValueError("Checkpointed version permits only permits two bias terms")
|
||||
|
||||
def _checkpointable_attention(q, k, v, b1, b2):
|
||||
bs = [b for b in [b1, b2] if b is not None]
|
||||
return _attention(q, k, v, bs)
|
||||
|
||||
o_chunks = []
|
||||
checkpoint_fn = get_checkpoint_fn()
|
||||
count = query.shape[chunk_dim]
|
||||
for start in range(0, count, chunk_size):
|
||||
end = start + chunk_size
|
||||
idx = [slice(None)] * len(query.shape)
|
||||
idx[chunk_dim] = slice(start, end)
|
||||
idx_tup = tuple(idx)
|
||||
q_chunk = query[idx_tup]
|
||||
k_chunk = key[idx_tup]
|
||||
v_chunk = value[idx_tup]
|
||||
|
||||
def _slice_bias(b):
|
||||
idx[chunk_dim] = (slice(start, end) if b.shape[chunk_dim] != 1 else slice(None))
|
||||
return b[tuple(idx)]
|
||||
|
||||
if (checkpoint):
|
||||
bias_1_chunk, bias_2_chunk = [
|
||||
_slice_bias(b) if b is not None else None for b in (biases + [None, None])[:2]
|
||||
]
|
||||
|
||||
o_chunk = checkpoint_fn(_checkpointable_attention, q_chunk, k_chunk, v_chunk,
|
||||
bias_1_chunk, bias_2_chunk)
|
||||
else:
|
||||
bias_chunks = [_slice_bias(b) for b in biases]
|
||||
|
||||
o_chunk = _attention(q_chunk, k_chunk, v_chunk, bias_chunks)
|
||||
|
||||
o_chunks.append(o_chunk)
|
||||
|
||||
o = torch.cat(o_chunks, dim=chunk_dim)
|
||||
return o
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
"""
|
||||
Standard multi-head attention using AlphaFold's default layer
|
||||
initialization. Allows multiple bias vectors.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
c_q: int,
|
||||
c_k: int,
|
||||
c_v: int,
|
||||
c_hidden: int,
|
||||
no_heads: int,
|
||||
gating: bool = True,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
c_q:
|
||||
Input dimension of query data
|
||||
c_k:
|
||||
Input dimension of key data
|
||||
c_v:
|
||||
Input dimension of value data
|
||||
c_hidden:
|
||||
Per-head hidden dimension
|
||||
no_heads:
|
||||
Number of attention heads
|
||||
gating:
|
||||
Whether the output should be gated using query data
|
||||
"""
|
||||
super(Attention, self).__init__()
|
||||
|
||||
self.c_q = c_q
|
||||
self.c_k = c_k
|
||||
self.c_v = c_v
|
||||
self.c_hidden = c_hidden
|
||||
self.no_heads = no_heads
|
||||
self.gating = gating
|
||||
|
||||
# DISCREPANCY: c_hidden is not the per-head channel dimension, as
|
||||
# stated in the supplement, but the overall channel dimension.
|
||||
|
||||
self.linear_q = Linear(self.c_q, self.c_hidden * self.no_heads, bias=False, init="glorot")
|
||||
self.linear_k = Linear(self.c_k, self.c_hidden * self.no_heads, bias=False, init="glorot")
|
||||
self.linear_v = Linear(self.c_v, self.c_hidden * self.no_heads, bias=False, init="glorot")
|
||||
self.linear_o = Linear(self.c_hidden * self.no_heads, self.c_q, init="final")
|
||||
|
||||
self.linear_g = None
|
||||
if self.gating:
|
||||
self.linear_g = Linear(self.c_q, self.c_hidden * self.no_heads, init="gating")
|
||||
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
def _prep_qkv(self, q_x: torch.Tensor,
|
||||
kv_x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
# [*, Q/K/V, H * C_hidden]
|
||||
q = self.linear_q(q_x)
|
||||
k = self.linear_k(kv_x)
|
||||
v = self.linear_v(kv_x)
|
||||
|
||||
# [*, Q/K, H, C_hidden]
|
||||
q = q.view(q.shape[:-1] + (self.no_heads, -1))
|
||||
k = k.view(k.shape[:-1] + (self.no_heads, -1))
|
||||
v = v.view(v.shape[:-1] + (self.no_heads, -1))
|
||||
|
||||
q /= math.sqrt(self.c_hidden)
|
||||
|
||||
return q, k, v
|
||||
|
||||
def _wrap_up(self, o: torch.Tensor, q_x: torch.Tensor) -> torch.Tensor:
|
||||
if (self.linear_g is not None):
|
||||
g = self.sigmoid(self.linear_g(q_x))
|
||||
|
||||
# [*, Q, H, C_hidden]
|
||||
g = g.view(g.shape[:-1] + (self.no_heads, -1))
|
||||
o = o * g
|
||||
|
||||
# [*, Q, H * C_hidden]
|
||||
o = flatten_final_dims(o, 2)
|
||||
|
||||
# [*, Q, C_q]
|
||||
o = self.linear_o(o)
|
||||
|
||||
return o
|
||||
|
||||
def forward(
|
||||
self,
|
||||
q_x: torch.Tensor,
|
||||
kv_x: torch.Tensor,
|
||||
biases: Optional[List[torch.Tensor]] = None,
|
||||
use_lma: bool = False,
|
||||
q_chunk_size: Optional[int] = None,
|
||||
kv_chunk_size: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
q_x:
|
||||
[*, Q, C_q] query data
|
||||
kv_x:
|
||||
[*, K, C_k] key data
|
||||
biases:
|
||||
List of biases that broadcast to [*, H, Q, K]
|
||||
use_lma:
|
||||
Whether to use low-memory attention
|
||||
q_chunk_size:
|
||||
Query chunk size (for LMA)
|
||||
kv_chunk_size:
|
||||
Key/Value chunk size (for LMA)
|
||||
Returns:
|
||||
[*, Q, C_q] attention update
|
||||
"""
|
||||
if (biases is None):
|
||||
biases = []
|
||||
if (use_lma and (q_chunk_size is None or kv_chunk_size is None)):
|
||||
raise ValueError("If use_lma is specified, q_chunk_size and kv_chunk_size must "
|
||||
"be provided")
|
||||
|
||||
q, k, v = self._prep_qkv(q_x, kv_x)
|
||||
|
||||
if (use_lma):
|
||||
biases = [b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],)) for b in biases]
|
||||
|
||||
o = _lma(q, k, v, biases, q_chunk_size, kv_chunk_size)
|
||||
else:
|
||||
o = _attention(q, k, v, biases)
|
||||
|
||||
o = self._wrap_up(o, q_x)
|
||||
|
||||
return o
|
||||
|
||||
|
||||
class GlobalAttention(nn.Module):
|
||||
|
||||
def __init__(self, c_in, c_hidden, no_heads, inf, eps):
|
||||
super(GlobalAttention, self).__init__()
|
||||
|
||||
self.c_in = c_in
|
||||
self.c_hidden = c_hidden
|
||||
self.no_heads = no_heads
|
||||
self.inf = inf
|
||||
self.eps = eps
|
||||
|
||||
self.linear_q = Linear(c_in, c_hidden * no_heads, bias=False, init="glorot")
|
||||
|
||||
self.linear_k = Linear(
|
||||
c_in,
|
||||
c_hidden,
|
||||
bias=False,
|
||||
init="glorot",
|
||||
)
|
||||
self.linear_v = Linear(
|
||||
c_in,
|
||||
c_hidden,
|
||||
bias=False,
|
||||
init="glorot",
|
||||
)
|
||||
self.linear_g = Linear(c_in, c_hidden * no_heads, init="gating")
|
||||
self.linear_o = Linear(c_hidden * no_heads, c_in, init="final")
|
||||
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
def forward(self, m: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
|
||||
# [*, N_res, C_in]
|
||||
q = torch.sum(m * mask.unsqueeze(-1),
|
||||
dim=-2) / (torch.sum(mask, dim=-1)[..., None] + self.eps)
|
||||
|
||||
# [*, N_res, H * C_hidden]
|
||||
q = self.linear_q(q)
|
||||
q *= (self.c_hidden**(-0.5))
|
||||
|
||||
# [*, N_res, H, C_hidden]
|
||||
q = q.view(q.shape[:-1] + (self.no_heads, -1))
|
||||
|
||||
# [*, N_res, N_seq, C_hidden]
|
||||
k = self.linear_k(m)
|
||||
v = self.linear_v(m)
|
||||
|
||||
# [*, N_res, H, N_seq]
|
||||
a = torch.matmul(
|
||||
q,
|
||||
k.transpose(-1, -2), # [*, N_res, C_hidden, N_seq]
|
||||
)
|
||||
bias = (self.inf * (mask - 1))[..., :, None, :]
|
||||
a += bias
|
||||
a = softmax(a)
|
||||
|
||||
# [*, N_res, H, C_hidden]
|
||||
o = torch.matmul(
|
||||
a,
|
||||
v,
|
||||
)
|
||||
|
||||
# [*, N_res, N_seq, H * C_hidden]
|
||||
g = self.sigmoid(self.linear_g(m))
|
||||
|
||||
# [*, N_res, N_seq, H, C_hidden]
|
||||
g = g.view(g.shape[:-1] + (self.no_heads, -1))
|
||||
|
||||
# [*, N_res, N_seq, H, C_hidden]
|
||||
o = o.unsqueeze(-3) * g
|
||||
|
||||
# [*, N_res, N_seq, H * C_hidden]
|
||||
o = o.reshape(o.shape[:-2] + (-1,))
|
||||
|
||||
# [*, N_res, N_seq, C_in]
|
||||
m = self.linear_o(o)
|
||||
|
||||
return m
|
||||
|
||||
|
||||
def _lma(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
biases: List[torch.Tensor],
|
||||
q_chunk_size: int,
|
||||
kv_chunk_size: int,
|
||||
):
|
||||
no_q, no_kv = q.shape[-3], k.shape[-3]
|
||||
|
||||
# [*, Q, H, C_hidden]
|
||||
o = q.new_zeros(q.shape)
|
||||
for q_s in range(0, no_q, q_chunk_size):
|
||||
q_chunk = q[..., q_s:q_s + q_chunk_size, :, :]
|
||||
large_bias_chunks = [b[..., q_s:q_s + q_chunk_size, :] for b in biases]
|
||||
|
||||
maxes = []
|
||||
weights = []
|
||||
values = []
|
||||
for kv_s in range(0, no_kv, kv_chunk_size):
|
||||
k_chunk = k[..., kv_s:kv_s + kv_chunk_size, :, :]
|
||||
v_chunk = v[..., kv_s:kv_s + kv_chunk_size, :, :]
|
||||
small_bias_chunks = [b[..., kv_s:kv_s + kv_chunk_size] for b in large_bias_chunks]
|
||||
|
||||
a = torch.einsum(
|
||||
"...qhd,...khd->...hqk",
|
||||
q_chunk,
|
||||
k_chunk,
|
||||
)
|
||||
|
||||
for b in small_bias_chunks:
|
||||
a += b
|
||||
|
||||
a = a.transpose(-2, -3)
|
||||
|
||||
max_a = torch.max(a, dim=-1, keepdim=True)[0]
|
||||
exp_a = torch.exp(a - max_a)
|
||||
exp_v = torch.einsum("...vhf,...qhv->...qhf", v_chunk, exp_a)
|
||||
|
||||
maxes.append(max_a.detach().squeeze(-1))
|
||||
weights.append(torch.sum(exp_a, dim=-1))
|
||||
values.append(exp_v)
|
||||
|
||||
chunk_max = torch.stack(maxes, dim=-3)
|
||||
chunk_weights = torch.stack(weights, dim=-3)
|
||||
chunk_values = torch.stack(values, dim=-4)
|
||||
|
||||
global_max = torch.max(chunk_max, dim=-3, keepdim=True)[0]
|
||||
max_diffs = torch.exp(chunk_max - global_max)
|
||||
chunk_values *= max_diffs.unsqueeze(-1)
|
||||
chunk_weights *= max_diffs
|
||||
|
||||
all_values = torch.sum(chunk_values, dim=-4)
|
||||
all_weights = torch.sum(chunk_weights.unsqueeze(-1), dim=-4)
|
||||
|
||||
q_chunk_out = all_values / all_weights
|
||||
|
||||
o[..., q_s:q_s + q_chunk_size, :, :] = q_chunk_out
|
||||
|
||||
return o
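A sketch of the Attention primitive with one additive bias, together with the equivalent low-memory (LMA) call; the sizes are illustrative and the import path is assumed. The two paths compute the same attention, so their outputs should agree up to floating-point error.

import torch
from openfold.primitives import Attention  # assumed import path

attn = Attention(c_q=64, c_k=64, c_v=64, c_hidden=16, no_heads=4, gating=True)

q_x = torch.randn(1, 32, 64)        # [*, Q, C_q]
kv_x = torch.randn(1, 48, 64)       # [*, K, C_k]
bias = torch.zeros(1, 1, 32, 48)    # broadcastable to [*, H, Q, K]

out = attn(q_x, kv_x, biases=[bias])
out_lma = attn(q_x, kv_x, biases=[bias], use_lma=True,
               q_chunk_size=16, kv_chunk_size=16)
print(out.shape)                                 # torch.Size([1, 32, 64])
print(torch.allclose(out, out_lma, atol=1e-5))   # True up to float error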
|
openfold/tensor_utils.py (new file, 408 lines)
@ -0,0 +1,408 @@
|
||||
# Copyright 2021 AlQuraishi Laboratory
|
||||
# Copyright 2021 DeepMind Technologies Limited
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from functools import partial
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Tuple, List, Callable, Any, Dict, Sequence, Optional
|
||||
|
||||
|
||||
def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
|
||||
zero_index = -1 * len(inds)
|
||||
first_inds = list(range(len(tensor.shape[:zero_index])))
|
||||
return tensor.permute(first_inds + [zero_index + i for i in inds])
|
||||
|
||||
|
||||
def flatten_final_dims(t: torch.Tensor, no_dims: int):
|
||||
return t.reshape(t.shape[:-no_dims] + (-1,))
|
||||
|
||||
|
||||
def masked_mean(mask, value, dim, eps=1e-4):
|
||||
mask = mask.expand(*value.shape)
|
||||
return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim))
|
||||
|
||||
|
||||
def pts_to_distogram(pts, min_bin=2.3125, max_bin=21.6875, no_bins=64):
|
||||
boundaries = torch.linspace(
|
||||
min_bin, max_bin, no_bins - 1, device=pts.device
|
||||
)
|
||||
dists = torch.sqrt(
|
||||
torch.sum((pts.unsqueeze(-2) - pts.unsqueeze(-3)) ** 2, dim=-1)
|
||||
)
|
||||
return torch.bucketize(dists, boundaries)
|
||||
|
||||
|
||||
def dict_multimap(fn, dicts):
|
||||
first = dicts[0]
|
||||
new_dict = {}
|
||||
for k, v in first.items():
|
||||
all_v = [d[k] for d in dicts]
|
||||
if type(v) is dict:
|
||||
new_dict[k] = dict_multimap(fn, all_v)
|
||||
else:
|
||||
new_dict[k] = fn(all_v)
|
||||
|
||||
return new_dict
|
||||
|
||||
|
||||
def one_hot(x, v_bins):
|
||||
reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),))
|
||||
diffs = x[..., None] - reshaped_bins
|
||||
am = torch.argmin(torch.abs(diffs), dim=-1)
|
||||
return nn.functional.one_hot(am, num_classes=len(v_bins)).float()
|
||||
|
||||
|
||||
def batched_gather(data, inds, dim=0, no_batch_dims=0):
|
||||
ranges = []
|
||||
for i, s in enumerate(data.shape[:no_batch_dims]):
|
||||
r = torch.arange(s)
|
||||
r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1))))
|
||||
ranges.append(r)
|
||||
|
||||
remaining_dims = [
|
||||
slice(None) for _ in range(len(data.shape) - no_batch_dims)
|
||||
]
|
||||
remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds
|
||||
ranges.extend(remaining_dims)
|
||||
return data[ranges]
|
||||
|
||||
|
||||
# Together with tree_map below, a poor man's version of JAX's tree_map
|
||||
def dict_map(fn, dic, leaf_type):
|
||||
new_dict = {}
|
||||
for k, v in dic.items():
|
||||
if type(v) is dict:
|
||||
new_dict[k] = dict_map(fn, v, leaf_type)
|
||||
else:
|
||||
new_dict[k] = tree_map(fn, v, leaf_type)
|
||||
|
||||
return new_dict
|
||||
|
||||
|
||||
def tree_map(fn, tree, leaf_type):
|
||||
if isinstance(tree, dict):
|
||||
return dict_map(fn, tree, leaf_type)
|
||||
elif isinstance(tree, list):
|
||||
return [tree_map(fn, x, leaf_type) for x in tree]
|
||||
elif isinstance(tree, tuple):
|
||||
return tuple([tree_map(fn, x, leaf_type) for x in tree])
|
||||
elif isinstance(tree, leaf_type):
|
||||
return fn(tree)
|
||||
else:
|
||||
print(type(tree))
|
||||
raise ValueError("Not supported")
|
||||
|
||||
|
||||
tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)
|
||||
|
||||
def _fetch_dims(tree):
|
||||
shapes = []
|
||||
tree_type = type(tree)
|
||||
if tree_type is dict:
|
||||
for v in tree.values():
|
||||
shapes.extend(_fetch_dims(v))
|
||||
elif tree_type is list or tree_type is tuple:
|
||||
for t in tree:
|
||||
shapes.extend(_fetch_dims(t))
|
||||
elif tree_type is torch.Tensor:
|
||||
shapes.append(tree.shape)
|
||||
else:
|
||||
raise ValueError("Not supported")
|
||||
|
||||
return shapes
|
||||
|
||||
|
||||
@torch.jit.ignore
|
||||
def _flat_idx_to_idx(
|
||||
flat_idx: int,
|
||||
dims: Tuple[int],
|
||||
) -> Tuple[int]:
|
||||
idx = []
|
||||
for d in reversed(dims):
|
||||
idx.append(flat_idx % d)
|
||||
flat_idx = flat_idx // d
|
||||
|
||||
return tuple(reversed(idx))
|
||||
|
||||
|
||||
@torch.jit.ignore
|
||||
def _get_minimal_slice_set(
|
||||
start: Sequence[int],
|
||||
end: Sequence[int],
|
||||
dims: int,
|
||||
start_edges: Optional[Sequence[bool]] = None,
|
||||
end_edges: Optional[Sequence[bool]] = None,
|
||||
) -> Sequence[Tuple[int]]:
|
||||
"""
|
||||
Produces an ordered sequence of tensor slices that, when used in
|
||||
sequence on a tensor with shape dims, yields tensors that contain every
|
||||
leaf in the contiguous range [start, end]. Care is taken to yield a
|
||||
short sequence of slices, and perhaps even the shortest possible (I'm
|
||||
pretty sure it's the latter).
|
||||
|
||||
end is INCLUSIVE.
|
||||
"""
|
||||
# start_edges and end_edges both indicate whether, starting from any given
|
||||
# dimension, the start/end index is at the top/bottom edge of the
|
||||
# corresponding tensor, modeled as a tree
|
||||
def reduce_edge_list(l):
|
||||
tally = 1
|
||||
for i in range(len(l)):
|
||||
reversed_idx = -1 * (i + 1)
|
||||
l[reversed_idx] *= tally
|
||||
tally = l[reversed_idx]
|
||||
|
||||
if(start_edges is None):
|
||||
start_edges = [s == 0 for s in start]
|
||||
reduce_edge_list(start_edges)
|
||||
if(end_edges is None):
|
||||
end_edges = [e == (d - 1) for e,d in zip(end, dims)]
|
||||
reduce_edge_list(end_edges)
|
||||
|
||||
# Base cases. Either start/end are empty and we're done, or the final,
|
||||
# one-dimensional tensor can be simply sliced
|
||||
if(len(start) == 0):
|
||||
return [tuple()]
|
||||
elif(len(start) == 1):
|
||||
return [(slice(start[0], end[0] + 1),)]
|
||||
|
||||
slices = []
|
||||
path = []
|
||||
|
||||
# Dimensions common to start and end can be selected directly
|
||||
for s,e in zip(start, end):
|
||||
if(s == e):
|
||||
path.append(slice(s, s + 1))
|
||||
else:
|
||||
break
|
||||
|
||||
path = tuple(path)
|
||||
divergence_idx = len(path)
|
||||
|
||||
# start == end, and we're done
|
||||
if(divergence_idx == len(dims)):
|
||||
return [tuple(path)]
|
||||
|
||||
def upper():
|
||||
sdi = start[divergence_idx]
|
||||
return [
|
||||
path + (slice(sdi, sdi + 1),) + s for s in
|
||||
_get_minimal_slice_set(
|
||||
start[divergence_idx + 1:],
|
||||
[d - 1 for d in dims[divergence_idx + 1:]],
|
||||
dims[divergence_idx + 1:],
|
||||
start_edges=start_edges[divergence_idx + 1:],
|
||||
end_edges=[1 for _ in end_edges[divergence_idx + 1:]]
|
||||
)
|
||||
]
|
||||
|
||||
def lower():
|
||||
edi = end[divergence_idx]
|
||||
return [
|
||||
path + (slice(edi, edi + 1),) + s for s in
|
||||
_get_minimal_slice_set(
|
||||
[0 for _ in start[divergence_idx + 1:]],
|
||||
end[divergence_idx + 1:],
|
||||
dims[divergence_idx + 1:],
|
||||
start_edges=[1 for _ in start_edges[divergence_idx + 1:]],
|
||||
end_edges=end_edges[divergence_idx + 1:],
|
||||
)
|
||||
]
|
||||
|
||||
# If both start and end are at the edges of the subtree rooted at
|
||||
# divergence_idx, we can just select the whole subtree at once
|
||||
if(start_edges[divergence_idx] and end_edges[divergence_idx]):
|
||||
slices.append(
|
||||
path + (slice(start[divergence_idx], end[divergence_idx] + 1),)
|
||||
)
|
||||
# If just start is at the edge, we can grab almost all of the subtree,
|
||||
# treating only the ragged bottom edge as an edge case
|
||||
elif(start_edges[divergence_idx]):
|
||||
slices.append(
|
||||
path + (slice(start[divergence_idx], end[divergence_idx]),)
|
||||
)
|
||||
slices.extend(lower())
|
||||
# Analogous to the previous case, but the top is ragged this time
|
||||
elif(end_edges[divergence_idx]):
|
||||
slices.extend(upper())
|
||||
slices.append(
|
||||
path + (slice(start[divergence_idx] + 1, end[divergence_idx] + 1),)
|
||||
)
|
||||
# If both sides of the range are ragged, we need to handle both sides
|
||||
# separately. If there's contiguous meat in between them, we can index it
|
||||
# in one big chunk
|
||||
else:
|
||||
slices.extend(upper())
|
||||
middle_ground = end[divergence_idx] - start[divergence_idx]
|
||||
if(middle_ground > 1):
|
||||
slices.append(
|
||||
path + (slice(start[divergence_idx] + 1, end[divergence_idx]),)
|
||||
)
|
||||
slices.extend(lower())
|
||||
|
||||
return [tuple(s) for s in slices]
|
||||
|
||||
|
||||
@torch.jit.ignore
|
||||
def _chunk_slice(
|
||||
t: torch.Tensor,
|
||||
flat_start: int,
|
||||
flat_end: int,
|
||||
no_batch_dims: int,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Equivalent to
|
||||
|
||||
t.reshape((-1,) + t.shape[no_batch_dims:])[flat_start:flat_end]
|
||||
|
||||
but without the need for the initial reshape call, which can be
|
||||
memory-intensive in certain situations. The only reshape operations
|
||||
in this function are performed on sub-tensors that scale with
|
||||
(flat_end - flat_start), the chunk size.
|
||||
"""
|
||||
|
||||
batch_dims = t.shape[:no_batch_dims]
|
||||
start_idx = list(_flat_idx_to_idx(flat_start, batch_dims))
|
||||
# _get_minimal_slice_set is inclusive
|
||||
end_idx = list(_flat_idx_to_idx(flat_end - 1, batch_dims))
|
||||
|
||||
# Get an ordered list of slices to perform
|
||||
slices = _get_minimal_slice_set(
|
||||
start_idx,
|
||||
end_idx,
|
||||
batch_dims,
|
||||
)
|
||||
|
||||
sliced_tensors = [t[s] for s in slices]
|
||||
|
||||
return torch.cat(
|
||||
[s.view((-1,) + t.shape[no_batch_dims:]) for s in sliced_tensors]
|
||||
)
|
||||
|
||||
|
||||
def chunk_layer(
|
||||
layer: Callable,
|
||||
inputs: Dict[str, Any],
|
||||
chunk_size: int,
|
||||
no_batch_dims: int,
|
||||
low_mem: bool = False,
|
||||
) -> Any:
|
||||
"""
|
||||
Implements the "chunking" procedure described in section 1.11.8.
|
||||
|
||||
Layer outputs and inputs are assumed to be simple "pytrees,"
|
||||
consisting only of (arbitrarily nested) lists, tuples, and dicts with
|
||||
torch.Tensor leaves.
|
||||
|
||||
Args:
|
||||
layer:
|
||||
The layer to be applied chunk-wise
|
||||
inputs:
|
||||
A (non-nested) dictionary of keyworded inputs. All leaves must
|
||||
be tensors and must share the same batch dimensions.
|
||||
chunk_size:
|
||||
The number of sub-batches per chunk. If multiple batch
|
||||
dimensions are specified, a "sub-batch" is defined as a single
|
||||
indexing of all batch dimensions simultaneously (s.t. the
|
||||
number of sub-batches is the product of the batch dimensions).
|
||||
no_batch_dims:
|
||||
How many of the initial dimensions of each input tensor can
|
||||
be considered batch dimensions.
|
||||
low_mem:
|
||||
Avoids flattening potentially large input tensors. Unnecessary
|
||||
in most cases, and is ever so slightly slower than the default
|
||||
setting.
|
||||
Returns:
|
||||
The reassembled output of the layer on the inputs.
|
||||
"""
|
||||
if not (len(inputs) > 0):
|
||||
raise ValueError("Must provide at least one input")
|
||||
|
||||
initial_dims = [shape[:no_batch_dims] for shape in _fetch_dims(inputs)]
|
||||
orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)])
|
||||
|
||||
def _prep_inputs(t):
|
||||
# TODO: make this more memory efficient. This sucks
|
||||
if(not low_mem):
|
||||
if not sum(t.shape[:no_batch_dims]) == no_batch_dims:
|
||||
t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
|
||||
t = t.reshape(-1, *t.shape[no_batch_dims:])
|
||||
else:
|
||||
t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
|
||||
return t
|
||||
|
||||
prepped_inputs = tensor_tree_map(_prep_inputs, inputs)
|
||||
|
||||
flat_batch_dim = 1
|
||||
for d in orig_batch_dims:
|
||||
flat_batch_dim *= d
|
||||
|
||||
no_chunks = flat_batch_dim // chunk_size + (
|
||||
flat_batch_dim % chunk_size != 0
|
||||
)
|
||||
|
||||
i = 0
|
||||
out = None
|
||||
for _ in range(no_chunks):
|
||||
# Chunk the input
|
||||
if(not low_mem):
|
||||
select_chunk = (
|
||||
lambda t: t[i : i + chunk_size] if t.shape[0] != 1 else t
|
||||
)
|
||||
else:
|
||||
select_chunk = (
|
||||
partial(
|
||||
_chunk_slice,
|
||||
flat_start=i,
|
||||
flat_end=min(flat_batch_dim, i + chunk_size),
|
||||
no_batch_dims=len(orig_batch_dims)
|
||||
)
|
||||
)
|
||||
|
||||
chunks = tensor_tree_map(select_chunk, prepped_inputs)
|
||||
|
||||
# Run the layer on the chunk
|
||||
output_chunk = layer(**chunks)
|
||||
|
||||
# Allocate space for the output
|
||||
if out is None:
|
||||
allocate = lambda t: t.new_zeros((flat_batch_dim,) + t.shape[1:])
|
||||
out = tensor_tree_map(allocate, output_chunk)
|
||||
|
||||
# Put the chunk in its pre-allocated space
|
||||
out_type = type(output_chunk)
|
||||
if out_type is dict:
|
||||
def assign(d1, d2):
|
||||
for k, v in d1.items():
|
||||
if type(v) is dict:
|
||||
assign(v, d2[k])
|
||||
else:
|
||||
v[i : i + chunk_size] = d2[k]
|
||||
|
||||
assign(out, output_chunk)
|
||||
elif out_type is tuple:
|
||||
for x1, x2 in zip(out, output_chunk):
|
||||
x1[i : i + chunk_size] = x2
|
||||
elif out_type is torch.Tensor:
|
||||
out[i : i + chunk_size] = output_chunk
|
||||
else:
|
||||
raise ValueError("Not supported")
|
||||
|
||||
i += chunk_size
|
||||
|
||||
reshape = lambda t: t.view(orig_batch_dims + t.shape[1:])
|
||||
out = tensor_tree_map(reshape, out)
|
||||
|
||||
return out
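A self-contained sketch of chunk_layer on a toy keyword-argument layer (all names here are illustrative): the first two dimensions are treated as batch dimensions, flattened, processed in sub-batches of chunk_size, and the outputs are written into a pre-allocated buffer and reshaped back to the original batch shape. Because the toy layer is element-wise, chunking cannot change the result.

import torch
from openfold.tensor_utils import chunk_layer  # assumed import path

def toy_layer(a, b):
    # Element-wise, so the chunked and unchunked results match exactly.
    return {"sum": a + b, "prod": a * b}

a = torch.randn(4, 16, 8)
b = torch.randn(4, 16, 8)

out = chunk_layer(toy_layer, {"a": a, "b": b}, chunk_size=7, no_batch_dims=2)
assert torch.allclose(out["sum"], a + b)
assert torch.allclose(out["prod"], a * b)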
|
openfold/triangular_attention.py (new file, 139 lines)
@ -0,0 +1,139 @@
|
||||
# Copyright 2021 AlQuraishi Laboratory
|
||||
# Copyright 2021 DeepMind Technologies Limited
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from functools import partialmethod, partial
|
||||
import math
|
||||
from typing import Optional, List
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from openfold.primitives import Linear, LayerNorm, Attention
|
||||
from openfold.tensor_utils import (
|
||||
chunk_layer,
|
||||
permute_final_dims,
|
||||
flatten_final_dims,
|
||||
)
|
||||
|
||||
|
||||
class TriangleAttention(nn.Module):
|
||||
def __init__(
|
||||
self, c_in, c_hidden, no_heads, starting, inf=1e9
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
c_in:
|
||||
Input channel dimension
|
||||
c_hidden:
|
||||
Overall hidden channel dimension (not per-head)
|
||||
no_heads:
|
||||
Number of attention heads
|
||||
"""
|
||||
super(TriangleAttention, self).__init__()
|
||||
|
||||
self.c_in = c_in
|
||||
self.c_hidden = c_hidden
|
||||
self.no_heads = no_heads
|
||||
self.starting = starting
|
||||
self.inf = inf
|
||||
|
||||
self.layer_norm = LayerNorm(self.c_in)
|
||||
|
||||
self.linear = Linear(c_in, self.no_heads, bias=False, init="normal")
|
||||
|
||||
self.mha = Attention(
|
||||
self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
|
||||
)
|
||||
|
||||
@torch.jit.ignore
|
||||
def _chunk(self,
|
||||
x: torch.Tensor,
|
||||
biases: List[torch.Tensor],
|
||||
chunk_size: int,
|
||||
) -> torch.Tensor:
|
||||
mha_inputs = {
|
||||
"q_x": x,
|
||||
"kv_x": x,
|
||||
"biases": biases,
|
||||
}
|
||||
return chunk_layer(
|
||||
partial(self.mha),
|
||||
mha_inputs,
|
||||
chunk_size=chunk_size,
|
||||
no_batch_dims=len(x.shape[:-2]),
|
||||
)
|
||||
|
||||
def forward(self,
|
||||
x: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
chunk_size: Optional[int] = None
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
[*, I, J, C_in] input tensor (e.g. the pair representation)
|
||||
Returns:
|
||||
[*, I, J, C_in] output tensor
|
||||
"""
|
||||
if mask is None:
|
||||
# [*, I, J]
|
||||
mask = x.new_ones(
|
||||
x.shape[:-1],
|
||||
)
|
||||
|
||||
# Shape annotations assume self.starting. Else, I and J are flipped
|
||||
if not self.starting:
|
||||
x = x.transpose(-2, -3)
|
||||
mask = mask.transpose(-1, -2)
|
||||
|
||||
# [*, I, J, C_in]
|
||||
x = self.layer_norm(x)
|
||||
|
||||
# [*, I, 1, 1, J]
|
||||
mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]
|
||||
|
||||
# [*, H, I, J]
|
||||
triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1))
|
||||
|
||||
# [*, 1, H, I, J]
|
||||
triangle_bias = triangle_bias.unsqueeze(-4)
|
||||
|
||||
biases = [mask_bias, triangle_bias]
|
||||
|
||||
if chunk_size is not None:
|
||||
x = self._chunk(x, biases, chunk_size)
|
||||
else:
|
||||
x = self.mha(q_x=x, kv_x=x, biases=biases)
|
||||
|
||||
if not self.starting:
|
||||
x = x.transpose(-2, -3)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class TriangleAttentionStartingNode(TriangleAttention):
|
||||
"""
|
||||
Implements Algorithm 13.
|
||||
"""
|
||||
|
||||
__init__ = partialmethod(TriangleAttention.__init__, starting=True)
|
||||
|
||||
|
||||
class TriangleAttentionEndingNode(TriangleAttention):
|
||||
"""
|
||||
Implements Algorithm 14.
|
||||
"""
|
||||
|
||||
__init__ = partialmethod(TriangleAttention.__init__, starting=False)
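A usage sketch of the two specializations (illustrative sizes, assumed import path). Both act on the pair representation and are normally applied as residual updates; the starting-node variant attends over rows of the pair representation, the ending-node variant over columns.

import torch
from openfold.triangular_attention import (  # assumed import path
    TriangleAttentionStartingNode,
    TriangleAttentionEndingNode,
)

c_z, c_hidden, no_heads = 128, 32, 4
tri_start = TriangleAttentionStartingNode(c_z, c_hidden, no_heads)
tri_end = TriangleAttentionEndingNode(c_z, c_hidden, no_heads)

z = torch.randn(1, 64, 64, c_z)   # [*, N_res, N_res, C_z] pair representation
mask = torch.ones(1, 64, 64)

z = z + tri_start(z, mask=mask, chunk_size=16)
z = z + tri_end(z, mask=mask)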
|
openfold/triangular_multiplicative_update.py (new file, 127 lines)
@ -0,0 +1,127 @@
|
||||
# Copyright 2021 AlQuraishi Laboratory
|
||||
# Copyright 2021 DeepMind Technologies Limited
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from functools import partialmethod
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from openfold.primitives import Linear, LayerNorm
|
||||
from openfold.tensor_utils import permute_final_dims
|
||||
|
||||
|
||||
class TriangleMultiplicativeUpdate(nn.Module):
|
||||
"""
|
||||
Implements Algorithms 11 and 12.
|
||||
"""
|
||||
def __init__(self, c_z, c_hidden, _outgoing=True):
|
||||
"""
|
||||
Args:
|
||||
c_z:
|
||||
Input channel dimension
|
||||
c_hidden:
|
||||
Hidden channel dimension
|
||||
"""
|
||||
super(TriangleMultiplicativeUpdate, self).__init__()
|
||||
self.c_z = c_z
|
||||
self.c_hidden = c_hidden
|
||||
self._outgoing = _outgoing
|
||||
|
||||
self.linear_a_p = Linear(self.c_z, self.c_hidden)
|
||||
self.linear_a_g = Linear(self.c_z, self.c_hidden, init="gating")
|
||||
self.linear_b_p = Linear(self.c_z, self.c_hidden)
|
||||
self.linear_b_g = Linear(self.c_z, self.c_hidden, init="gating")
|
||||
self.linear_g = Linear(self.c_z, self.c_z, init="gating")
|
||||
self.linear_z = Linear(self.c_hidden, self.c_z, init="final")
|
||||
|
||||
self.layer_norm_in = LayerNorm(self.c_z)
|
||||
self.layer_norm_out = LayerNorm(self.c_hidden)
|
||||
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
def _combine_projections(self,
|
||||
a: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError("This method needs to be overridden")
|
||||
|
||||
def forward(self,
|
||||
z: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
z:
|
||||
[*, N_res, N_res, C_z] input tensor
|
||||
mask:
|
||||
[*, N_res, N_res] input mask
|
||||
Returns:
|
||||
[*, N_res, N_res, C_z] output tensor
|
||||
"""
|
||||
if mask is None:
|
||||
mask = z.new_ones(z.shape[:-1])
|
||||
|
||||
mask = mask.unsqueeze(-1)
|
||||
|
||||
z = self.layer_norm_in(z)
|
||||
a = self.linear_a_p(z) * self.sigmoid(self.linear_a_g(z))
|
||||
a = a * mask
|
||||
b = self.linear_b_p(z) * self.sigmoid(self.linear_b_g(z))
|
||||
b = b * mask
|
||||
x = self._combine_projections(a, b)
|
||||
x = self.layer_norm_out(x)
|
||||
x = self.linear_z(x)
|
||||
g = self.sigmoid(self.linear_g(z))
|
||||
z = x * g
|
||||
|
||||
return z
|
||||
|
||||
|
||||
class TriangleMultiplicationOutgoing(TriangleMultiplicativeUpdate):
|
||||
"""
|
||||
Implements Algorithm 11.
|
||||
"""
|
||||
def _combine_projections(self,
|
||||
a: torch.Tensor, # [*, N_i, N_k, C]
|
||||
b: torch.Tensor, # [*, N_j, N_k, C]
|
||||
):
|
||||
# [*, C, N_i, N_j]
|
||||
p = torch.matmul(
|
||||
permute_final_dims(a, (2, 0, 1)),
|
||||
permute_final_dims(b, (2, 1, 0)),
|
||||
)
|
||||
|
||||
# [*, N_i, N_j, C]
|
||||
return permute_final_dims(p, (1, 2, 0))
|
||||
|
||||
|
||||
class TriangleMultiplicationIncoming(TriangleMultiplicativeUpdate):
|
||||
"""
|
||||
Implements Algorithm 12.
|
||||
"""
|
||||
def _combine_projections(self,
|
||||
a: torch.Tensor, # [*, N_k, N_i, C]
|
||||
b: torch.Tensor, # [*, N_k, N_j, C]
|
||||
):
|
||||
# [*, C, N_i, N_j]
|
||||
p = torch.matmul(
|
||||
permute_final_dims(a, (2, 1, 0)),
|
||||
permute_final_dims(b, (2, 0, 1)),
|
||||
)
|
||||
|
||||
# [*, N_i, N_j, C]
|
||||
return permute_final_dims(p, (1, 2, 0))
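A usage sketch of the two multiplicative updates (illustrative sizes, assumed import path); like triangle attention, they are typically applied as residual updates to the pair representation.

import torch
from openfold.triangular_multiplicative_update import (  # assumed import path
    TriangleMultiplicationOutgoing,
    TriangleMultiplicationIncoming,
)

c_z, c_hidden = 128, 128
tri_mul_out = TriangleMultiplicationOutgoing(c_z, c_hidden)
tri_mul_in = TriangleMultiplicationIncoming(c_z, c_hidden)

z = torch.randn(1, 64, 64, c_z)     # [*, N_res, N_res, C_z]
mask = torch.ones(1, 64, 64)

z = z + tri_mul_out(z, mask=mask)   # Algorithm 11, "outgoing" edges
z = z + tri_mul_in(z, mask=mask)    # Algorithm 12, "incoming" edges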