diff --git a/gpt4all/models/__init__.py b/gpt4all/models/__init__.py
index b0ce79f9..575d8e2f 100644
--- a/gpt4all/models/__init__.py
+++ b/gpt4all/models/__init__.py
@@ -2,6 +2,8 @@
 from .gpt_jr.configuration_gpt_jr import GPTJRConfig
 from .gpt_jr.modeling_gpt_jr import GPTJRForCausalLM
 from .pythiaseek import PythiaSeekForCausalLM, PythiaSeekConfig
+from .pythia_retro import PythiaRetroForCausalLM, PythiaRetroConfig
+from .lethe import LetheConfig, LetheForCausalLM
 
 
 __all__ = [
@@ -9,4 +11,8 @@ __all__ = [
     "GPTJRForCausalLM",
     "PythiaSeekConfig",
     "PythiaSeekForCausalLM",
+    "PythiaRetroConfig",
+    "PythiaRetroForCausalLM",
+    "LetheConfig",
+    "LetheForCausalLM"
 ]
\ No newline at end of file
diff --git a/gpt4all/models/pythiaseek/modeling_pythiaseek.py b/gpt4all/models/pythiaseek/modeling_pythiaseek.py
index ef95a3f6..f4e8fcd1 100644
--- a/gpt4all/models/pythiaseek/modeling_pythiaseek.py
+++ b/gpt4all/models/pythiaseek/modeling_pythiaseek.py
@@ -246,7 +246,7 @@ class PythiaSeekCrossAttention(PythiaSeekAttention):
                 f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and"
                 f" `num_attention_heads`: {self.num_attention_heads})."
             )
-        self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype())
+        self.norm_factor = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype())
 
         self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
         self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
@@ -282,7 +282,7 @@ class PythiaSeekCrossAttention(PythiaSeekAttention):
         # query -> (bs, num_attention_heads, seq_len, head_dim)
         # key -> (bs, num_attention_heads, head_dim, neighbors)
         # attn_weights -> (bs, num_attention_heads, seq_len, neighbors)
-        attn_weights = torch.matmul(query, key)
+        attn_weights = torch.matmul(query, key) / self.norm_factor
 
         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
         attn_weights = attn_weights.to(value.dtype)
@@ -498,11 +498,10 @@ class PythiaSeekLayer(nn.Module):
                 cross_attn_output[0]
             )
 
-            if step is not None:
-                if self.learnable_alpha:
-                    alpha = F.sigmoid(self.alpha)
-                else:
-                    alpha = self._update_alpha(step)
+            if self.learnable_alpha:
+                alpha = F.sigmoid(self.alpha)
+            elif step is not None:
+                alpha = self._update_alpha(step)
             else:
                 alpha = 0.5
 
diff --git a/gpt4all/train/train_retrieval.py b/gpt4all/train/train_retrieval.py
index 0c6baaf2..b689920d 100644
--- a/gpt4all/train/train_retrieval.py
+++ b/gpt4all/train/train_retrieval.py
@@ -183,8 +183,7 @@ def train(accelerator, config):
             outputs = model(input_ids=batch["input_ids"],
                             labels=batch["labels"],
                             encoder_hidden_states=batch["encoder_hidden_states"],
-                            )
-                            #step=curr_step)
+                            step=curr_step)
             loss = outputs.loss
 
             if config["debug"]:
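
For reference, not part of the patch: a minimal standalone sketch of the two behavioral changes above (attention-score scaling and alpha-gate precedence). The tensor shapes and the step schedule stand-in are illustrative assumptions; it does not call the repository's modules.

import torch
import torch.nn.functional as F

bs, heads, seq_len, head_dim, neighbors = 2, 4, 8, 16, 3
query = torch.randn(bs, heads, seq_len, head_dim)
key = torch.randn(bs, heads, head_dim, neighbors)

# Scaled dot-product scores: dividing by sqrt(head_dim) keeps the logit variance
# roughly independent of head_dim, so the softmax does not saturate.
norm_factor = torch.sqrt(torch.tensor(head_dim, dtype=torch.float32))
attn_weights = torch.matmul(query, key) / norm_factor
attn_weights = F.softmax(attn_weights, dim=-1)

# Alpha selection after the fix: a learnable gate takes precedence over the
# step-based schedule, and 0.5 is the fallback when neither is available.
learnable_alpha = True
alpha_param = torch.tensor(0.0)        # hypothetical learnable parameter
step = None
if learnable_alpha:
    alpha = torch.sigmoid(alpha_param)
elif step is not None:
    alpha = min(step / 1000, 1.0)      # stand-in for self._update_alpha(step)
else:
    alpha = 0.5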