Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-06-25 15:01:43 +00:00
update openfold

commit 5c4df01af3 (parent 289f3a45c2)

This update removes the mask arguments (msa_mask, pair_mask, mask) and the _mask_trans flag from the openfold Evoformer and MSA attention forward paths, and deletes the chunked MSA attention helper (_chunked_msa_attn).
@@ -182,33 +182,28 @@ class EvoformerBlockCore(nn.Module):
         self,
         m: torch.Tensor,
         z: torch.Tensor,
-        msa_mask: torch.Tensor,
-        pair_mask: torch.Tensor,
         chunk_size: Optional[int] = None,
-        _mask_trans: bool = True,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # DeepMind doesn't mask these transitions in the source, so _mask_trans
         # should be disabled to better approximate the exact activations of
         # the original.
-        msa_trans_mask = msa_mask if _mask_trans else None
-        pair_trans_mask = pair_mask if _mask_trans else None
 
         m = m + self.msa_transition(
-            m, mask=msa_trans_mask, chunk_size=chunk_size
+            m, chunk_size=chunk_size
         )
         z = z + self.outer_product_mean(
-            m, mask=msa_mask, chunk_size=chunk_size
+            m, chunk_size=chunk_size
         )
-        z = z + self.ps_dropout_row_layer(self.tri_mul_out(z, mask=pair_mask))
-        z = z + self.ps_dropout_row_layer(self.tri_mul_in(z, mask=pair_mask))
+        z = z + self.ps_dropout_row_layer(self.tri_mul_out(z))
+        z = z + self.ps_dropout_row_layer(self.tri_mul_in(z))
         z = z + self.ps_dropout_row_layer(
-            self.tri_att_start(z, mask=pair_mask, chunk_size=chunk_size)
+            self.tri_att_start(z, chunk_size=chunk_size)
         )
         z = z + self.ps_dropout_col_layer(
-            self.tri_att_end(z, mask=pair_mask, chunk_size=chunk_size)
+            self.tri_att_end(z, chunk_size=chunk_size)
         )
         z = z + self.pair_transition(
-            z, mask=pair_trans_mask, chunk_size=chunk_size
+            z, chunk_size=chunk_size
         )
 
         return m, z
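For readers unfamiliar with the masking machinery removed above: when a mask was passed, padded positions were zeroed in a transition's output before the residual add, and _mask_trans controlled whether that happened; after this commit the sub-modules are simply called without a mask. Below is a minimal, self-contained sketch of that call-pattern difference. TinyTransition, the shapes, and the sizes are illustrative assumptions, not the actual ColossalAI sub-modules.

from typing import Optional

import torch
import torch.nn as nn

class TinyTransition(nn.Module):
    # Illustrative stand-in for msa_transition / pair_transition.
    def __init__(self, c: int):
        super().__init__()
        self.linear = nn.Linear(c, c)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        out = self.linear(torch.relu(self.linear(x)))
        if mask is not None:
            # old behaviour: zero padded positions before the residual add
            out = out * mask.unsqueeze(-1)
        return out

m = torch.randn(4, 8, 16)      # [N_seq, N_res, C_m]
msa_mask = torch.ones(4, 8)    # 1 = real position, 0 = padding
trans = TinyTransition(16)

m_old = m + trans(m, mask=msa_mask)  # pre-commit call pattern (mask threaded through)
m_new = m + trans(m)                 # post-commit call pattern (no mask argument)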
@@ -274,22 +269,16 @@ class EvoformerBlock(nn.Module):
     def forward(self,
         m: torch.Tensor,
         z: torch.Tensor,
-        msa_mask: torch.Tensor,
-        pair_mask: torch.Tensor,
         chunk_size: Optional[int] = None,
-        _mask_trans: bool = True,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         m = m + self.msa_dropout_layer(
-            self.msa_att_row(m, z=z, mask=msa_mask, chunk_size=chunk_size)
+            self.msa_att_row(m, z=z, chunk_size=chunk_size)
         )
-        m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size)
+        m = m + self.msa_att_col(m, chunk_size=chunk_size)
         m, z = self.core(
             m,
             z,
-            msa_mask=msa_mask,
-            pair_mask=pair_mask,
             chunk_size=chunk_size,
-            _mask_trans=_mask_trans,
         )
 
         return m, z
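From the caller's perspective, EvoformerBlock.forward loses msa_mask, pair_mask, and _mask_trans. A small sketch of the call-site change follows; StandInEvoformerBlock is a hypothetical stand-in, not a real class, and the shapes follow the usual OpenFold convention of an MSA embedding m of shape [*, N_seq, N_res, C_m] and a pair embedding z of shape [*, N_res, N_res, C_z] with arbitrary channel sizes.

from typing import Optional, Tuple

import torch
import torch.nn as nn

class StandInEvoformerBlock(nn.Module):
    # Hypothetical stand-in with the post-commit signature; real sub-modules omitted.
    def forward(self,
        m: torch.Tensor,
        z: torch.Tensor,
        chunk_size: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        return m, z  # placeholder for the MSA / pair updates

block = StandInEvoformerBlock()
m = torch.randn(8, 16, 256)    # [N_seq, N_res, C_m]
z = torch.randn(16, 16, 128)   # [N_res, N_res, C_z]

# before: m, z = block(m, z, msa_mask=..., pair_mask=..., chunk_size=4, _mask_trans=True)
m, z = block(m, z, chunk_size=4)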
@@ -136,45 +136,6 @@ class MSAAttention(nn.Module):
 
         return m, mask_bias, z
 
-    @torch.jit.ignore
-    def _chunked_msa_attn(self,
-        m: torch.Tensor,
-        z: Optional[torch.Tensor],
-        mask: Optional[torch.Tensor],
-        chunk_logits: int,
-        checkpoint: bool,
-    ) -> torch.Tensor:
-        MSA_DIM = -4
-
-        def _get_qkv(m, z):
-            m, mask_bias, z = self._prep_inputs(m, z, mask)
-            q, k, v = self.mha._prep_qkv(m, m)
-            return m, q, k, v, mask_bias, z
-
-        checkpoint_fn = get_checkpoint_fn()
-
-        if(torch.is_grad_enabled() and checkpoint):
-            m, q, k, v, mask_bias, z = checkpoint_fn(_get_qkv, m, z)
-        else:
-            m, q, k, v, mask_bias, z = _get_qkv(m, z)
-
-        o = _attention_chunked_trainable(
-            query=q,
-            key=k,
-            value=v,
-            biases=[mask_bias, z],
-            chunk_size=chunk_logits,
-            chunk_dim=MSA_DIM,
-            checkpoint=checkpoint,
-        )
-
-        if(torch.is_grad_enabled() and checkpoint):
-            # Storing an additional m here is far from ideal
-            m = checkpoint_fn(self.mha._wrap_up, o, m)
-        else:
-            m = self.mha._wrap_up(o, m)
-
-        return m
-
     def forward(self,
         m: torch.Tensor,
@@ -199,12 +160,6 @@ class MSAAttention(nn.Module):
                 cost of slower execution. Chunking is not performed by default.
 
         """
-        if(_chunk_logits is not None):
-            return self._chunked_msa_attn(
-                m=m, z=z, mask=mask,
-                chunk_logits=_chunk_logits, checkpoint=_checkpoint_chunks
-            )
-
         m, mask_bias, z = self._prep_inputs(m, z, mask)
 
         biases = [mask_bias]
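The two hunks above delete the _chunked_msa_attn path, which computed MSA attention in chunks along MSA_DIM = -4 and optionally re-ran each chunk under activation checkpointing to trade compute for memory. The sketch below illustrates that general idea using only torch.split and torch.utils.checkpoint; the helper chunked_apply and the toy workload are assumptions, not the removed _attention_chunked_trainable.

import torch
from torch.utils.checkpoint import checkpoint

def chunked_apply(fn, x, chunk_size, dim, use_checkpoint=False):
    # Apply fn to slices of x along `dim` to bound peak activation memory;
    # optionally checkpoint each slice so its activations are recomputed in backward.
    outs = []
    for piece in torch.split(x, chunk_size, dim=dim):
        if use_checkpoint and torch.is_grad_enabled():
            outs.append(checkpoint(fn, piece))
        else:
            outs.append(fn(piece))
    return torch.cat(outs, dim=dim)

x = torch.randn(12, 4, 8, 16, requires_grad=True)
y = chunked_apply(torch.nn.functional.gelu, x, chunk_size=3, dim=0, use_checkpoint=True)
y.sum().backward()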
@@ -306,15 +261,11 @@ class MSAColumnAttention(nn.Module):
         """
         # [*, N_res, N_seq, C_in]
         m = m.transpose(-2, -3)
-        if mask is not None:
-            mask = mask.transpose(-1, -2)
 
-        m = self._msa_att(m, mask=mask, chunk_size=chunk_size)
+        m = self._msa_att(m, chunk_size=chunk_size)
 
         # [*, N_seq, N_res, C_in]
         m = m.transpose(-2, -3)
-        if mask is not None:
-            mask = mask.transpose(-1, -2)
 
         return m
 
@@ -344,12 +295,10 @@ class MSAColumnGlobalAttention(nn.Module):
     @torch.jit.ignore
     def _chunk(self,
         m: torch.Tensor,
-        mask: torch.Tensor,
         chunk_size: int,
     ) -> torch.Tensor:
         mha_input = {
             "m": m,
-            "mask": mask,
         }
         return chunk_layer(
             self.global_attention,
@@ -361,30 +310,20 @@ class MSAColumnGlobalAttention(nn.Module):
     def forward(
         self,
         m: torch.Tensor,
-        mask: Optional[torch.Tensor] = None,
         chunk_size: Optional[int] = None,
     ) -> torch.Tensor:
         n_seq, n_res, c_in = m.shape[-3:]
 
-        if mask is None:
-            # [*, N_seq, N_res]
-            mask = torch.ones(
-                m.shape[:-1],
-                dtype=m.dtype,
-                device=m.device,
-            ).detach()
-
         # [*, N_res, N_seq, C_in]
         m = m.transpose(-2, -3)
-        mask = mask.transpose(-1, -2)
 
         # [*, N_res, N_seq, C_in]
         m = self.layer_norm_m(m)
 
         if chunk_size is not None:
-            m = self._chunk(m, mask, chunk_size)
+            m = self._chunk(m, chunk_size)
         else:
-            m = self.global_attention(m=m, mask=mask)
+            m = self.global_attention(m=m)
 
         # [*, N_seq, N_res, C_in]
         m = m.transpose(-2, -3)
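For reference, the removed branch in MSAColumnGlobalAttention.forward built an all-ones [*, N_seq, N_res] mask when none was supplied and transposed it alongside m; with the mask gone, every sequence position is treated as valid. The standalone snippet below merely restates that deleted default, with illustrative shapes and sizes.

import torch

m = torch.randn(8, 16, 32)   # [N_seq, N_res, C_in]
# default mask the removed code created when mask was None
mask = torch.ones(m.shape[:-1], dtype=m.dtype, device=m.device).detach()  # [N_seq, N_res]

# after transposing m to [N_res, N_seq, C_in], the mask was transposed too
m_t = m.transpose(-2, -3)        # [N_res, N_seq, C_in]
mask_t = mask.transpose(-1, -2)  # [N_res, N_seq]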