Mirror of https://github.com/nomic-ai/gpt4all.git (synced 2025-07-06 12:06:54 +00:00)
fix: retrieval
commit cfddd78eb4
parent db90a15911
@@ -2,6 +2,8 @@ from .gpt_jr.configuration_gpt_jr import GPTJRConfig
 from .gpt_jr.modeling_gpt_jr import GPTJRForCausalLM
 
 from .pythiaseek import PythiaSeekForCausalLM, PythiaSeekConfig
+from .pythia_retro import PythiaRetroForCausalLM, PythiaRetroConfig
+from .lethe import LetheConfig, LetheForCausalLM
 
 
 __all__ = [
@@ -9,4 +11,8 @@ __all__ = [
     "GPTJRForCausalLM",
     "PythiaSeekConfig",
     "PythiaSeekForCausalLM",
+    "PythiaRetroConfig",
+    "PythiaRetroForCausalLM",
+    "LetheConfig",
+    "LetheForCausalLM"
 ]
@@ -246,7 +246,7 @@ class PythiaSeekCrossAttention(PythiaSeekAttention):
                 f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and"
                 f" `num_attention_heads`: {self.num_attention_heads})."
             )
-        self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype())
+        self.norm_factor = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype())
 
         self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
         self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
@@ -282,7 +282,7 @@ class PythiaSeekCrossAttention(PythiaSeekAttention):
         # query -> (bs, num_attention_heads, seq_len, head_dim)
         # key -> (bs, num_attention_heads, head_dim, neighbors)
         # attn_weights -> (bs, num_attention_heads, seq_len, neighbors)
-        attn_weights = torch.matmul(query, key)
+        attn_weights = torch.matmul(query, key) / self.norm_factor
 
         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
         attn_weights = attn_weights.to(value.dtype)
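The two cross-attention hunks above are one logical change: the raw query-key scores are now divided by norm_factor = sqrt(head_dim) before the softmax, so the logits no longer grow with the head dimension and saturate the softmax. A minimal, self-contained sketch of that scaled cross-attention over retrieved neighbors (tensor shapes follow the comments in the diff; this is an illustration, not the repository's full forward pass):

import torch
import torch.nn.functional as F

def scaled_cross_attention(query, key, value):
    # query: (bs, num_attention_heads, seq_len, head_dim)
    # key:   (bs, num_attention_heads, head_dim, neighbors)
    # value: (bs, num_attention_heads, neighbors, head_dim)
    head_dim = query.size(-1)
    norm_factor = torch.sqrt(torch.tensor(head_dim, dtype=torch.float32)).to(query.dtype)
    # scores: (bs, num_attention_heads, seq_len, neighbors)
    attn_weights = torch.matmul(query, key) / norm_factor
    attn_weights = F.softmax(attn_weights, dim=-1).to(value.dtype)
    # output: (bs, num_attention_heads, seq_len, head_dim)
    return torch.matmul(attn_weights, value)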
@@ -498,11 +498,10 @@ class PythiaSeekLayer(nn.Module):
             cross_attn_output[0]
         )
 
-        if step is not None:
-            if self.learnable_alpha:
-                alpha = F.sigmoid(self.alpha)
-            else:
-                alpha = self._update_alpha(step)
+        if self.learnable_alpha:
+            alpha = F.sigmoid(self.alpha)
+        elif step is not None:
+            alpha = self._update_alpha(step)
         else:
             alpha = 0.5
 
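The reordered branches fix which gate wins: under the old nesting a learnable alpha was ignored whenever step was None and the layer silently fell back to 0.5. Now a learnable gate (the sigmoid of a trained parameter) always takes precedence, the step-based schedule applies only to non-learnable gates, and 0.5 stays the default. A minimal sketch of that selection logic, assuming alpha mixes the cross- and self-attention outputs; the linear-ramp _update_alpha below is a hypothetical schedule, not the repository's implementation:

import torch
import torch.nn as nn

class AlphaGate(nn.Module):
    # Mixes attention streams as: hidden = alpha * cross_out + (1 - alpha) * self_out.
    def __init__(self, learnable_alpha: bool, total_steps: int = 10_000):
        super().__init__()
        self.learnable_alpha = learnable_alpha
        self.total_steps = total_steps
        if learnable_alpha:
            self.alpha = nn.Parameter(torch.zeros(1))  # sigmoid(0) == 0.5 at init

    def _update_alpha(self, step: int) -> float:
        # Hypothetical schedule: ramp alpha from 0 to 0.5 over training.
        return 0.5 * min(step / self.total_steps, 1.0)

    def resolve_alpha(self, step=None):
        if self.learnable_alpha:
            return torch.sigmoid(self.alpha)   # trained gate wins, with or without a step
        elif step is not None:
            return self._update_alpha(step)    # scheduled gate needs the current step
        else:
            return 0.5                         # fallback when neither is available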
@@ -183,8 +183,7 @@ def train(accelerator, config):
             outputs = model(input_ids=batch["input_ids"],
                             labels=batch["labels"],
                             encoder_hidden_states=batch["encoder_hidden_states"],
-                            )
-                            #step=curr_step)
+                            step=curr_step)
             loss = outputs.loss
 
             if config["debug"]:
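With the branch order fixed in the layer, the training loop re-enables the previously commented-out step=curr_step argument so that scheduled (non-learnable) alpha gates can advance during training. A rough sketch of such a training step under an Accelerate-style setup; model, batch, optimizer, and accelerator are placeholders supplied by the caller, not the repository's exact train() internals:

def training_step(model, batch, optimizer, accelerator, curr_step):
    # Forward pass; `step` lets layers with a scheduled alpha gate see training progress.
    outputs = model(input_ids=batch["input_ids"],
                    labels=batch["labels"],
                    encoder_hidden_states=batch["encoder_hidden_states"],
                    step=curr_step)
    loss = outputs.loss
    accelerator.backward(loss)  # plain loss.backward() would also work without Accelerate
    optimizer.step()
    optimizer.zero_grad()
    return loss.detach()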