fix: retrieval

Zach Nussbaum 2023-05-17 20:41:53 +00:00
parent db90a15911
commit cfddd78eb4
3 changed files with 13 additions and 9 deletions

View File

@@ -2,6 +2,8 @@ from .gpt_jr.configuration_gpt_jr import GPTJRConfig
 from .gpt_jr.modeling_gpt_jr import GPTJRForCausalLM
 from .pythiaseek import PythiaSeekForCausalLM, PythiaSeekConfig
+from .pythia_retro import PythiaRetroForCausalLM, PythiaRetroConfig
+from .lethe import LetheConfig, LetheForCausalLM
 __all__ = [
@@ -9,4 +11,8 @@ __all__ = [
     "GPTJRForCausalLM",
     "PythiaSeekConfig",
     "PythiaSeekForCausalLM",
+    "PythiaRetroConfig",
+    "PythiaRetroForCausalLM",
+    "LetheConfig",
+    "LetheForCausalLM"
 ]
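
With the new model classes re-exported here, downstream training and evaluation scripts can import them straight from the package. A minimal usage sketch, assuming this `__init__.py` sits in a `gpt4all.models`-style package (the file path is not shown in the diff) and that the configs construct with defaults:

```python
# Assumed package path; the diff does not show where this __init__.py lives.
from gpt4all.models import (
    LetheConfig,
    LetheForCausalLM,
    PythiaRetroConfig,
    PythiaRetroForCausalLM,
)

# Build a randomly initialized model from a default config (constructor
# arguments beyond the defaults are not shown in this commit).
config = LetheConfig()
model = LetheForCausalLM(config)
```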

View File

@@ -246,7 +246,7 @@ class PythiaSeekCrossAttention(PythiaSeekAttention):
                 f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and"
                 f" `num_attention_heads`: {self.num_attention_heads})."
             )
-        self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype())
+        self.norm_factor = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype())
         self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
         self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
@@ -282,7 +282,7 @@ class PythiaSeekCrossAttention(PythiaSeekAttention):
         # query -> (bs, num_attention_heads, seq_len, head_dim)
         # key -> (bs, num_attention_heads, head_dim, neighbors)
         # attn_weights -> (bs, num_attention_heads, seq_len, neighbors)
-        attn_weights = torch.matmul(query, key)
+        attn_weights = torch.matmul(query, key) / self.norm_factor
         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
         attn_weights = attn_weights.to(value.dtype)
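
This hunk divides the raw attention scores by `norm_factor` (the renamed `scale_attn` tensor from the previous hunk), i.e. the usual 1/sqrt(head_dim) scaling that keeps the softmax from saturating as the head dimension grows. A standalone sketch of that computation with illustrative shapes; only `norm_factor` and the tensor layout come from the diff:

```python
import torch

# Illustrative shapes following the comments in the diff.
bs, heads, seq_len, head_dim, neighbors = 2, 8, 16, 64, 4
query = torch.randn(bs, heads, seq_len, head_dim)
key = torch.randn(bs, heads, head_dim, neighbors)
value = torch.randn(bs, heads, neighbors, head_dim)

# Dividing by sqrt(head_dim) keeps the logits' variance roughly constant,
# so the softmax over retrieved neighbors does not saturate.
norm_factor = torch.sqrt(torch.tensor(head_dim, dtype=torch.float32))
attn_weights = torch.matmul(query, key) / norm_factor          # (bs, heads, seq_len, neighbors)
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
attn_output = torch.matmul(attn_weights, value)                 # (bs, heads, seq_len, head_dim)
```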
@@ -498,11 +498,10 @@ class PythiaSeekLayer(nn.Module):
             cross_attn_output[0]
         )
-        if step is not None:
-            if self.learnable_alpha:
-                alpha = F.sigmoid(self.alpha)
-            else:
-                alpha = self._update_alpha(step)
+        if self.learnable_alpha:
+            alpha = F.sigmoid(self.alpha)
+        elif step is not None:
+            alpha = self._update_alpha(step)
         else:
             alpha = 0.5
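
The rewrite makes the learnable gate take priority: `alpha` now comes from `sigmoid(self.alpha)` whenever `learnable_alpha` is set, falls back to the step schedule when a `step` is supplied, and defaults to 0.5 otherwise, instead of only consulting the gate when `step` was passed. A hedged sketch of how such a gate can blend self- and cross-attention outputs; the blend formula and the schedule below are assumptions, since the lines consuming `alpha` are not part of this hunk:

```python
import torch

# Hypothetical helper; only the alpha selection mirrors the diff.
def gated_mix(self_attn_out, cross_attn_out, alpha_param=None, step=None,
              learnable_alpha=False, schedule=lambda s: min(1.0, s / 1000)):
    if learnable_alpha:
        alpha = torch.sigmoid(alpha_param)   # learned gate, used even when step is None
    elif step is not None:
        alpha = schedule(step)               # stand-in for self._update_alpha(step)
    else:
        alpha = 0.5                          # equal weighting as a fallback
    # Assumed blend: weight cross-attention by alpha, self-attention by (1 - alpha).
    return alpha * cross_attn_out + (1 - alpha) * self_attn_out

x_self = torch.randn(2, 16, 512)
x_cross = torch.randn(2, 16, 512)
out = gated_mix(x_self, x_cross, step=200)   # scheduled alpha = 0.2
```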

View File

@@ -183,8 +183,7 @@ def train(accelerator, config):
             outputs = model(input_ids=batch["input_ids"],
                             labels=batch["labels"],
                             encoder_hidden_states=batch["encoder_hidden_states"],
-                            )
-                            #step=curr_step)
+                            step=curr_step)
             loss = outputs.loss
             if config["debug"]:
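
Un-commenting `step=curr_step` means the training loop now tells each layer how far training has progressed, so `_update_alpha` can anneal the gate over time. A sketch of the surrounding loop under assumed names; only the keyword arguments in the forward call are taken from the diff:

```python
# Hypothetical wrapper; `model`, `train_dataloader`, `optimizer`, and
# `accelerator` are assumed to be prepared elsewhere (e.g. via accelerate.prepare).
def train_epoch(model, train_dataloader, optimizer, accelerator, curr_step=0):
    for batch in train_dataloader:
        outputs = model(
            input_ids=batch["input_ids"],
            labels=batch["labels"],
            encoder_hidden_states=batch["encoder_hidden_states"],
            step=curr_step,                  # previously commented out
        )
        loss = outputs.loss
        accelerator.backward(loss)           # Accelerate-aware backward pass
        optimizer.step()
        optimizer.zero_grad()
        curr_step += 1                       # advance the schedule used by _update_alpha
    return curr_step
```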