This commit is contained in:
Erick Friis
2024-01-23 18:30:34 -07:00
parent f3095acf99
commit da049d3538

View File

@@ -1,5 +1,7 @@
from typing import List
import os
from typing import List, Optional
import nomic
from langchain_core.embeddings import Embeddings
@@ -14,21 +16,24 @@ class NomicEmbeddings(Embeddings):
model = NomicEmbeddings()
"""
def __init__(self, *, model: str, nomic_api_key: Optional[str] = None):
"""Initialize NomicEmbeddings model.
Args:
model: model name
"""
_api_key = nomic_api_key or os.environ.get("NOMIC_API_KEY")
nomic.login(api_key=_api_key)
self.model = model
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Embed search docs."""
raise NotImplementedError
output = nomic.embed.text(
texts=texts,
model=self.model,
)
return output["embeddings"]
def embed_query(self, text: str) -> List[float]:
"""Embed query text."""
raise NotImplementedError
# only keep aembed_documents and aembed_query if they're implemented!
# delete them otherwise to use the base class' default
# implementation, which calls the sync version in an executor
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
"""Asynchronous Embed search docs."""
raise NotImplementedError
async def aembed_query(self, text: str) -> List[float]:
"""Asynchronous Embed query text."""
raise NotImplementedError
return self.embed_documents([text])[0]