Mirror of https://github.com/nomic-ai/gpt4all.git, synced 2025-07-06 12:06:54 +00:00

Commit: cfddd78eb4 ("fix: retrieval")
Parent: db90a15911
@@ -2,6 +2,8 @@ from .gpt_jr.configuration_gpt_jr import GPTJRConfig
 from .gpt_jr.modeling_gpt_jr import GPTJRForCausalLM

 from .pythiaseek import PythiaSeekForCausalLM, PythiaSeekConfig
+from .pythia_retro import PythiaRetroForCausalLM, PythiaRetroConfig
+from .lethe import LetheConfig, LetheForCausalLM


 __all__ = [
@@ -9,4 +11,8 @@ __all__ = [
     "GPTJRForCausalLM",
     "PythiaSeekConfig",
     "PythiaSeekForCausalLM",
+    "PythiaRetroConfig",
+    "PythiaRetroForCausalLM",
+    "LetheConfig",
+    "LetheForCausalLM"
 ]
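With the two new imports and the matching `__all__` entries, the retrieval-oriented model classes are exposed at the package level. A minimal usage sketch, assuming the hunk above is the package's `__init__.py` and that the package is importable as `models` (the real package path is not visible in this diff):

# Hypothetical import path; substitute the actual package name for "models".
from models import PythiaRetroConfig, PythiaRetroForCausalLM
from models import LetheConfig, LetheForCausalLM

# Because these names are now listed in __all__, a wildcard import
# (from models import *) also exposes them.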
@@ -246,7 +246,7 @@ class PythiaSeekCrossAttention(PythiaSeekAttention):
                 f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and"
                 f" `num_attention_heads`: {self.num_attention_heads})."
             )
-        self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype())
+        self.norm_factor = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype())

         self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
         self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
@@ -282,7 +282,7 @@ class PythiaSeekCrossAttention(PythiaSeekAttention):
         # query -> (bs, num_attention_heads, seq_len, head_dim)
         # key -> (bs, num_attention_heads, head_dim, neighbors)
         # attn_weights -> (bs, num_attention_heads, seq_len, neighbors)
-        attn_weights = torch.matmul(query, key)
+        attn_weights = torch.matmul(query, key) / self.norm_factor

         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
         attn_weights = attn_weights.to(value.dtype)
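The two hunks above rename the scaling tensor to `norm_factor` and, more importantly, actually divide the cross-attention logits by it before the softmax; previously the factor was computed but never applied. A minimal, self-contained sketch of that scaled dot-product step (shapes follow the comments in the diff; the concrete sizes below are made up for illustration):

import torch
import torch.nn as nn

bs, num_heads, seq_len, head_dim, neighbors = 2, 8, 16, 64, 4

query = torch.randn(bs, num_heads, seq_len, head_dim)
key = torch.randn(bs, num_heads, head_dim, neighbors)
value = torch.randn(bs, num_heads, neighbors, head_dim)

# Same normalizer the diff introduces: sqrt(head_dim), kept in the default dtype.
norm_factor = torch.sqrt(torch.tensor(head_dim, dtype=torch.float32)).to(torch.get_default_dtype())

# Without the division the logits grow with head_dim and the softmax saturates;
# dividing by sqrt(head_dim) keeps them at unit scale, as in standard attention.
attn_weights = torch.matmul(query, key) / norm_factor   # (bs, num_heads, seq_len, neighbors)
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_weights = attn_weights.to(value.dtype)

attn_output = torch.matmul(attn_weights, value)          # (bs, num_heads, seq_len, head_dim)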
@@ -498,11 +498,10 @@ class PythiaSeekLayer(nn.Module):
             cross_attn_output[0]
         )

-        if step is not None:
-            if self.learnable_alpha:
-                alpha = F.sigmoid(self.alpha)
-            else:
-                alpha = self._update_alpha(step)
+        if self.learnable_alpha:
+            alpha = F.sigmoid(self.alpha)
+        elif step is not None:
+            alpha = self._update_alpha(step)
         else:
             alpha = 0.5

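After this fix the blending coefficient `alpha` is chosen by a clear priority: a learnable gate if the layer has one, otherwise a step-driven schedule when a training step is supplied, otherwise the fixed 0.5 fallback. A standalone sketch of that selection order; the linear ramp here is only a stand-in for the repo's `_update_alpha`, whose implementation is not shown in this diff:

from typing import Optional

import torch

def select_alpha(learnable_alpha: bool,
                 alpha_param: Optional[torch.Tensor],
                 step: Optional[int],
                 total_steps: int = 10_000) -> torch.Tensor:
    # Mirrors the fixed branch order: learnable gate > step schedule > 0.5 default.
    if learnable_alpha:
        # Trainable scalar squashed into (0, 1); in the layer this is self.alpha.
        return torch.sigmoid(alpha_param)
    elif step is not None:
        # Stand-in schedule (linear ramp); the real _update_alpha may differ.
        return torch.tensor(min(step / total_steps, 1.0))
    else:
        return torch.tensor(0.5)

# Example: a layer with a learnable gate now ignores the step schedule entirely.
alpha = select_alpha(True, torch.nn.Parameter(torch.tensor(0.0)), step=1200)
# alpha == sigmoid(0.0) == 0.5 initially, but it moves as the parameter is trained.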
@@ -183,8 +183,7 @@ def train(accelerator, config):
             outputs = model(input_ids=batch["input_ids"],
                             labels=batch["labels"],
                             encoder_hidden_states=batch["encoder_hidden_states"],
-                            )
-                            #step=curr_step)
+                            step=curr_step)
             loss = outputs.loss

             if config["debug"]:
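The last hunk un-comments the `step` argument so the model actually receives the current training step during the forward pass, which is what feeds the step-based alpha schedule above. A hedged sketch of how that call might sit inside an `accelerate`-style training loop; everything other than the forward call's keyword arguments is generic boilerplate, not taken from this repo:

import torch

def training_loop(accelerator, model, optimizer, train_dataloader):
    model.train()
    curr_step = 0
    for batch in train_dataloader:
        outputs = model(input_ids=batch["input_ids"],
                        labels=batch["labels"],
                        encoder_hidden_states=batch["encoder_hidden_states"],
                        step=curr_step)   # the argument this commit starts passing
        loss = outputs.loss

        accelerator.backward(loss)        # accelerate handles mixed precision / DDP
        optimizer.step()
        optimizer.zero_grad()
        curr_step += 1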