fix: forward works!

Zach Nussbaum 2023-04-21 02:50:27 +00:00
parent 09cddbedc0
commit df79fd64b0
2 changed files with 166 additions and 34 deletions


@@ -115,8 +115,7 @@ class GPTJRConfig(PretrainedConfig):
bos_token_id=50256,
eos_token_id=50256,
tie_word_embeddings=False,
encoder_ndim=4096,
alpha=.5,
encoder_dim=4096,
encoder_path=None,
**kwargs
):
@@ -139,8 +138,7 @@ class GPTJRConfig(PretrainedConfig):
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.encoder_ndim = encoder_ndim
self.alpha = alpha
self.encoder_dim = encoder_dim
self.encoder_path = encoder_path
super().__init__(
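Note: with this change the retrieval-side settings are encoder_dim (hidden size of the retrieval encoder) and encoder_path; alpha is no longer a config argument and is instead scheduled inside GPTJRBlock (see the block changes below). A minimal construction sketch, with illustrative values only (384 follows the SBERT note further down):

    # Sketch only: construct the config with the renamed fields; other arguments keep their defaults.
    config = GPTJRConfig(
        encoder_dim=384,    # hidden size of the retrieval encoder (384 for SBERT-style embeddings)
        encoder_path=None,  # optional pretrained-encoder location; how it is consumed is assumed, not shown here
    )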


@@ -154,11 +154,6 @@ class GPTJRAttention(nn.Module):
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
# Keep the attention weights computation in fp32 to avoid overflow issues
# TODO: do we need to do this with bfloat16??
# query = query.to(torch.float32)
# key = key.to(torch.float32)
attn_weights = torch.matmul(query, key.transpose(-1, -2))
mask_value = torch.finfo(attn_weights.dtype).min
@@ -188,7 +183,7 @@ class GPTJRAttention(nn.Module):
def forward(
self,
hidden_states: Optional[torch.FloatTensor],
encoder_hidden_states: Optional[torch.FloatTensor],
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
layer_past: Optional[Tuple[torch.Tensor]] = None,
head_mask: Optional[torch.FloatTensor] = None,
@@ -291,13 +286,131 @@ class GPTJRCrossAttention(GPTJRAttention):
)
self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype())
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
self.v_proj = nn.Linear(config.encoder_ndim, self.embed_dim, bias=False)
self.q_proj = nn.Linear(config.encoder_ndim, self.embed_dim, bias=False)
self.k_proj = nn.Linear(config.encoder_dim, self.embed_dim, bias=False)
self.v_proj = nn.Linear(config.encoder_dim, self.embed_dim, bias=False)
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
self.rotary_dim = None
if config.rotary_dim is not None:
self.rotary_dim = config.rotary_dim
def _split_knn_attn_heads(self, tensor, num_attention_heads, attn_head_size):
new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size)
tensor = tensor.view(new_shape)
return tensor.permute(0, 2, 1)
def _merge_heads(self, tensor, num_attention_heads, attn_head_size):
"""
Merges attn_head_size dim and num_attn_heads dim into hidden dim
"""
# tensor -> (bs, seq_len, num_attention_heads, head_dim)
new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,)
return tensor.view(new_shape)
def _attn(
self,
query,
key,
value,
attention_mask=None,
head_mask=None,
):
# compute causal mask from causal mask buffer
# since key and value don't have seq length, just use causal mask as normal
query_length = query.size(-2)
causal_mask = self.bias[:, :, : query_length, :query_length].to(torch.bool)
# Keep the attention weights computation in fp32 to avoid overflow issues
# TODO: do we need to do this with bfloat16??
# query = query.to(torch.float32)
# key = key.to(torch.float32)
# query -> (bs, seq_len, num_attention_heads, head_dim)
# key -> (bs, num_attention_heads, head_dim)
# attn_weights -> (bs, seq_len, num_attention_heads, num_attention_heads)
attn_weights = torch.matmul(query, key.transpose(-1, -2))
mask_value = torch.finfo(attn_weights.dtype).min
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
attn_weights = torch.where(causal_mask, attn_weights, mask_value)
attn_weights = attn_weights / self.scale_attn
if attention_mask is not None:
# Apply the attention mask
# attn mask (1, 1, 1, seq_len)
attn_weights = (attn_weights.permute(0, 2, 3, 1) + attention_mask).permute(0, 3, 1, 2)
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_weights = attn_weights.to(value.dtype)
attn_weights = self.attn_dropout(attn_weights)
# Mask heads if we want to
if head_mask is not None:
attn_weights = attn_weights * head_mask
# value -> (bs, num_attention_heads, head_dim)
# attn_weights -> (bs, seq_len, num_attention_heads, num_attention_heads)
# attn_output -> (bs, num_attention_heads, seq_len, head_dim)
attn_output = torch.matmul(attn_weights, value.transpose(-1, -2))
return attn_output, attn_weights
def forward(
self,
hidden_states: Optional[torch.FloatTensor],
encoder_hidden_states: Optional[torch.FloatTensor],
attention_mask: Optional[torch.FloatTensor] = None,
layer_past: Optional[Tuple[torch.Tensor]] = None,
head_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
) -> Union[
Tuple[torch.Tensor, Tuple[torch.Tensor]],
Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
]:
query = self.q_proj(hidden_states)
# if we are doing cross attention
key = self.k_proj(encoder_hidden_states)
value = self.v_proj(encoder_hidden_states)
# (bs, seq_len, dim) -> (bs, num_attention_heads, seq_len, head_dim)
query = self._split_heads(query, self.num_attention_heads, self.head_dim, False)
# (bs, dim) -> (bs, num_attention_heads, head_dim)
key = self._split_knn_attn_heads(key, self.num_attention_heads, self.head_dim)
value = self._split_knn_attn_heads(value, self.num_attention_heads, self.head_dim)
key = key.permute(0, 2, 1)
query = query.permute(0, 2, 1, 3)
if layer_past is not None:
past_key = layer_past[0]
past_value = layer_past[1]
key = torch.cat((past_key, key), dim=-2)
value = torch.cat((past_value, value), dim=-2)
if use_cache is True:
present = (key, value)
else:
present = None
# compute self-attention: V x Softmax(QK^T)
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim)
attn_output = self.out_proj(attn_output)
attn_output = self.resid_dropout(attn_output)
outputs = (attn_output, present)
if output_attentions:
outputs += (attn_weights,)
return outputs # a, present, (attentions)
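Note: a simplified, standalone sketch of the cross-attention technique this class implements — decoder queries attend to a single retrieved-neighbor embedding that has no sequence dimension. Sizes are illustrative assumptions and the module's exact reshapes differ, but the projection shapes match the diff (q_proj: embed_dim -> embed_dim, k_proj/v_proj: encoder_dim -> embed_dim):

    import torch
    import torch.nn as nn

    # Illustrative sizes (assumptions): GPT-J-like decoder width, SBERT-like neighbor width
    bs, seq_len, embed_dim, encoder_dim = 2, 8, 4096, 384
    num_heads = 16
    head_dim = embed_dim // num_heads

    q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
    k_proj = nn.Linear(encoder_dim, embed_dim, bias=False)
    v_proj = nn.Linear(encoder_dim, embed_dim, bias=False)

    hidden_states = torch.randn(bs, seq_len, embed_dim)   # decoder states
    neighbor_emb = torch.randn(bs, encoder_dim)           # one retrieved neighbor, no sequence dimension

    q = q_proj(hidden_states).view(bs, seq_len, num_heads, head_dim).permute(0, 2, 1, 3)  # (bs, heads, seq, head_dim)
    k = k_proj(neighbor_emb).view(bs, num_heads, head_dim)                                # (bs, heads, head_dim)
    v = v_proj(neighbor_emb).view(bs, num_heads, head_dim)

    scores = torch.matmul(q, k.unsqueeze(-1)) / head_dim ** 0.5   # (bs, heads, seq, 1)
    attn = torch.softmax(scores, dim=-1)                          # trivial here: one neighbor per call
    out = (attn * v.unsqueeze(-2)).permute(0, 2, 1, 3).reshape(bs, seq_len, embed_dim)
    print(out.shape)  # torch.Size([2, 8, 4096])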
@@ -330,9 +443,13 @@ class GPTJRBlock(nn.Module):
# TODO: fix for n neighbors
# for SBERT this is 384
self.ln_2 = nn.LayerNorm(config.encoder_ndim, eps=config.layer_norm_epsilon)
self.ln_2 = nn.LayerNorm(config.encoder_dim, eps=config.layer_norm_epsilon)
self.cross_attn = GPTJRCrossAttention(config)
self.alpha = config.alpha
self.cross_attn_mlp = GPTJRMLP(inner_dim, config)
self.alpha = nn.Parameter(torch.ones(1), requires_grad=False).to(self.ln_1.weight.dtype)
self.step = 1
self.world_size = torch.distributed.get_world_size() if torch.distributed.is_initialized() else torch.cuda.device_count() or 1
def forward(
self,
@@ -361,31 +478,48 @@
feed_forward_hidden_states = self.mlp(hidden_states)
self_attention_residual = attn_output + feed_forward_hidden_states + residual
# encoder_hidden_states (bs, knn, encoder_dim)
# encoder_hidden_states -> (bs, knn, encoder_dim)
encoder_normed = self.ln_2(encoder_hidden_states)
# TODO: how do we handle neighbors
# TODO: we have to make sure we're doing masking right here
# TODO: T5 passes query length to cross attention, do we need that?
cross_attn_outputs = self.cross_attn(
residual,
encoder_hidden_states=encoder_normed,
attention_mask=attention_mask,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
num_neighbors = encoder_normed.shape[1]
cross_attn_outputs = []
for k in range(num_neighbors):
# cross_attn_outputs -> (bs, seq_len, num_attention_heads, head_dim)
cross_attn_output = self.cross_attn(
residual,
encoder_hidden_states=encoder_normed[:, k, :],
attention_mask=attention_mask,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
cross_attn_outputs.append(cross_attn_output[0])
cross_attn_output = torch.stack(cross_attn_outputs, dim=1).mean(dim=1)
# gpt-j has parallel ff + attn, can do ff on encoder_normed too I guess?
cross_attn_ff = self.cross_attn_mlp(
cross_attn_output
)
cross_attn_output = cross_attn_outputs[0] # output_attn: a, present, (attentions)
cross_attn_outputs = cross_attn_outputs[1:]
hidden_states = self.alpha * cross_attn_output + (1 - self.alpha) * self_attention_residual
alpha = self.alpha if self.training else 0.5
hidden_states = (1 - alpha) * cross_attn_ff + alpha * self_attention_residual
if use_cache:
outputs = (hidden_states,) + outputs
else:
outputs = (hidden_states,) + outputs[1:]
# if training update alpha
if self.training:
self.step += 1
self._update_alpha(self.step)
return outputs # hidden_states, present, (attentions)
def _update_alpha(self, iteration):
self.alpha.data = torch.clamp(torch.tensor([1 / (iteration * self.world_size) ** 0.05]), min=torch.tensor([0.5]), max=torch.tensor([1.0]))
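Note: alpha is now a schedule rather than a config constant. During training it decays from 1.0 toward a floor of 0.5 as 1 / (step * world_size) ** 0.05, and the block mixes hidden_states = (1 - alpha) * cross_attn_ff + alpha * self_attention_residual, so the cross-attention branch contributes nothing at the first step and is blended in gradually; at eval time a fixed 0.5 is used. A small numeric sketch of the schedule:

    import torch

    def alpha_schedule(step: int, world_size: int = 1) -> float:
        # mirrors _update_alpha: clamp(1 / (step * world_size) ** 0.05, 0.5, 1.0)
        value = torch.tensor([1.0 / (step * world_size) ** 0.05])
        return torch.clamp(value, min=0.5, max=1.0).item()

    for step in (1, 10, 1_000, 100_000, 10_000_000):
        print(step, round(alpha_schedule(step), 3))
    # ~1.0, 0.891, 0.708, 0.562, 0.5 (clamped at the floor)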
class GPTJRPreTrainedModel(PreTrainedModel):
@@ -769,7 +903,7 @@ class GPTJRForCausalLM(GPTJRPreTrainedModel):
transformer_outputs = self.transformer(
input_ids,
encoder_hidden_states=encoder_outputs[0],
encoder_hidden_states=encoder_outputs,
past_key_values=past_key_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
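Note: the causal LM now hands the full retrieved-neighbor tensor, shaped (batch, num_neighbors, encoder_dim), down to the decoder instead of indexing out one slice; each GPTJRBlock then cross-attends to one neighbor at a time and averages the outputs. A minimal sketch of that loop-and-average pattern (the cross-attention stand-in and all sizes are illustrative assumptions):

    import torch

    batch_size, seq_len, embed_dim = 2, 8, 4096   # illustrative decoder shapes
    num_neighbors, encoder_dim = 4, 384           # 384 matches the SBERT note above
    encoder_hidden_states = torch.randn(batch_size, num_neighbors, encoder_dim)
    residual = torch.randn(batch_size, seq_len, embed_dim)

    def cross_attn(residual, neighbor):
        # stand-in for GPTJRCrossAttention: (decoder states, one neighbor) -> decoder-shaped output
        return torch.randn(batch_size, seq_len, embed_dim)

    per_neighbor = [cross_attn(residual, encoder_hidden_states[:, k, :])
                    for k in range(encoder_hidden_states.shape[1])]
    cross_attn_output = torch.stack(per_neighbor, dim=1).mean(dim=1)   # average over neighbors
    print(cross_attn_output.shape)  # torch.Size([2, 8, 4096])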