fix: retrieval

Zach Nussbaum 2023-05-17 20:41:53 +00:00
parent db90a15911
commit cfddd78eb4
3 changed files with 13 additions and 9 deletions

View File

@@ -2,6 +2,8 @@ from .gpt_jr.configuration_gpt_jr import GPTJRConfig
 from .gpt_jr.modeling_gpt_jr import GPTJRForCausalLM
 from .pythiaseek import PythiaSeekForCausalLM, PythiaSeekConfig
+from .pythia_retro import PythiaRetroForCausalLM, PythiaRetroConfig
+from .lethe import LetheConfig, LetheForCausalLM
 __all__ = [
@@ -9,4 +11,8 @@ __all__ = [
     "GPTJRForCausalLM",
     "PythiaSeekConfig",
     "PythiaSeekForCausalLM",
+    "PythiaRetroConfig",
+    "PythiaRetroForCausalLM",
+    "LetheConfig",
+    "LetheForCausalLM"
 ]
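
The two new model families are now re-exported from the package __init__, so they can be imported straight from the package. A minimal usage sketch; the top-level package name ("models") and the default-config construction are assumptions, only the class names come from this diff.

# Sketch only: package name and default configs are assumptions.
from models import LetheConfig, LetheForCausalLM, PythiaRetroConfig, PythiaRetroForCausalLM

lethe_model = LetheForCausalLM(LetheConfig())
retro_model = PythiaRetroForCausalLM(PythiaRetroConfig())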

View File

@@ -246,7 +246,7 @@ class PythiaSeekCrossAttention(PythiaSeekAttention):
                 f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and"
                 f" `num_attention_heads`: {self.num_attention_heads})."
             )
-        self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype())
+        self.norm_factor = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype())
         self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
         self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
@@ -282,7 +282,7 @@ class PythiaSeekCrossAttention(PythiaSeekAttention):
         # query -> (bs, num_attention_heads, seq_len, head_dim)
         # key -> (bs, num_attention_heads, head_dim, neighbors)
         # attn_weights -> (bs, num_attention_heads, seq_len, neighbors)
-        attn_weights = torch.matmul(query, key)
+        attn_weights = torch.matmul(query, key) / self.norm_factor
         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
         attn_weights = attn_weights.to(value.dtype)
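
The two hunks above rename scale_attn to norm_factor and actually apply it: the cross-attention logits over the retrieved neighbors are divided by sqrt(head_dim) before the softmax, the standard scaled dot-product trick that keeps logit variance independent of head size. A standalone sketch with random stand-in tensors; shapes follow the comments in the diff, nothing here is the model's real forward pass.

import torch
import torch.nn as nn

bs, num_heads, seq_len, head_dim, neighbors = 2, 8, 16, 64, 4
query = torch.randn(bs, num_heads, seq_len, head_dim)
key = torch.randn(bs, num_heads, head_dim, neighbors)
value = torch.randn(bs, num_heads, neighbors, head_dim)

# Same scaling as the fix: divide logits by sqrt(head_dim) before the
# softmax so they do not grow with head size and saturate it.
norm_factor = torch.sqrt(torch.tensor(head_dim, dtype=torch.float32))
attn_weights = torch.matmul(query, key) / norm_factor   # (bs, num_heads, seq_len, neighbors)
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_output = torch.matmul(attn_weights, value)         # (bs, num_heads, seq_len, head_dim)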
@@ -498,11 +498,10 @@ class PythiaSeekLayer(nn.Module):
                 cross_attn_output[0]
             )
-            if step is not None:
-                if self.learnable_alpha:
-                    alpha = F.sigmoid(self.alpha)
-                else:
-                    alpha = self._update_alpha(step)
+            if self.learnable_alpha:
+                alpha = F.sigmoid(self.alpha)
+            elif step is not None:
+                alpha = self._update_alpha(step)
             else:
                 alpha = 0.5
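
This hunk reorders the alpha gating: a learnable alpha now takes precedence whether or not a training step is passed, the scheduled alpha is only used when a step is available, and the layer falls back to 0.5 otherwise. A standalone sketch of the selection logic; the scheduled_alpha ramp below is a placeholder for the layer's _update_alpha(step), whose exact schedule is not shown in this diff.

import torch

def select_alpha(learnable_alpha, alpha_param, step,
                 scheduled_alpha=lambda s: min(1.0, s / 10_000)):
    # Mirrors the fixed ordering: learnable parameter first, then the
    # step-based schedule, then a constant 50/50 blend as a fallback.
    if learnable_alpha:
        return torch.sigmoid(alpha_param)
    elif step is not None:
        return scheduled_alpha(step)
    else:
        return 0.5

alpha = select_alpha(True, torch.tensor(0.0), step=None)  # sigmoid(0.0) == 0.5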

View File

@@ -183,8 +183,7 @@ def train(accelerator, config):
             outputs = model(input_ids=batch["input_ids"],
                             labels=batch["labels"],
                             encoder_hidden_states=batch["encoder_hidden_states"],
-                            )
-                            #step=curr_step)
+                            step=curr_step)
             loss = outputs.loss
             if config["debug"]:
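
With the previously commented-out step argument restored at the call site, the training loop feeds the current step to the model so the alpha schedule above can advance. A sketch of the surrounding loop fragment; the dataloader name and step counter are assumptions, not taken from this diff.

# Fragment sketch: model and train_dataloader are assumed to exist.
for curr_step, batch in enumerate(train_dataloader):
    outputs = model(
        input_ids=batch["input_ids"],
        labels=batch["labels"],
        encoder_hidden_states=batch["encoder_hidden_states"],
        step=curr_step,
    )
    loss = outputs.loss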