mirror of
https://github.com/hwchase17/langchain.git
synced 2025-11-13 08:13:44 +00:00
… (#14723) - **Description:** Minor updates per marketing requests. Namely, name decisions (AI Foundation Models / AI Playground) - **Tag maintainer:** @hinthornw Do want to pass around the PR for a bit and ask a few more marketing questions before merge, but just want to make sure I'm not working in a vacuum. No major changes to code functionality intended; the PR should be for documentation and only minor tweaks. Note: QA model is a bit borked across staging/prod right now. Relevant teams have been informed and are looking into it, and I've placeholdered the response with that of a working version in the notebook. Co-authored-by: Vadim Kudlay <32310964+VKudlay@users.noreply.github.com>
75 lines
3.0 KiB
Python
75 lines
3.0 KiB
Python
"""Embeddings Components Derived from NVEModel/Embeddings"""
|
|
from typing import Any, List, Literal, Optional
|
|
|
|
from langchain_core.embeddings import Embeddings
|
|
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
|
|
|
|
import langchain_nvidia_ai_endpoints._common as nvai_common
|
|
|
|
|
|
class NVIDIAEmbeddings(BaseModel, Embeddings):
    """NVIDIA's AI Foundation Retriever Question-Answering Asymmetric Model.

    Embeds texts through an ``NVEModel`` client against NVIDIA AI Foundation
    endpoints. The service embeds "passage" and "query" inputs asymmetrically;
    ``embed_documents`` and ``embed_query`` select the appropriate type
    (overridable via ``model_type``).
    """

    # Low-level endpoint client. Use default_factory so a fresh instance (not
    # the class object) is the default; the pre-validator below also ensures
    # one is present before field validation runs.
    client: nvai_common.NVEModel = Field(default_factory=nvai_common.NVEModel)
    model: str = Field(
        ..., description="The embedding model to use. Example: nvolveqa_40k"
    )
    # Per the service docs, each input string must not exceed 2048 characters.
    max_length: int = Field(2048, ge=1, le=2048)
    # Per the service docs, each request may carry at most 50 input strings.
    max_batch_size: int = Field(default=50)
    model_type: Optional[Literal["passage", "query"]] = Field(
        "passage", description="The type of text to be embedded."
    )

    @root_validator(pre=True)
    def _validate_client(cls, values: Any) -> Any:
        """Ensure a client instance exists before field validation."""
        if "client" not in values:
            values["client"] = nvai_common.NVEModel()
        return values

    @property
    def available_models(self) -> dict:
        """Map the available models that can be invoked."""
        return self.client.available_models

    def _truncate(self, texts: List[str]) -> List[str]:
        """Clip each text to the service's per-input character limit."""
        return [text[: self.max_length] for text in texts]

    def _embed(
        self, texts: List[str], model_type: Literal["passage", "query"]
    ) -> List[List[float]]:
        """Embed a batch of texts as either passage or query type.

        Args:
            texts: The strings to embed (already batched and truncated).
            model_type: Which asymmetric embedding mode to request.

        Returns:
            One embedding vector per input, in input order (the service
            reports an ``index`` per result, used to re-sort below).

        Raises:
            ValueError: If the response payload is not a list of embeddings.
        """
        response = self.client.get_req(
            model_name=self.model,
            payload={
                "input": texts,
                "model": model_type,
                "encoding_format": "float",
            },
        )
        response.raise_for_status()
        result = response.json()
        data = result["data"]
        if not isinstance(data, list):
            raise ValueError(f"Expected a list of embeddings. Got: {data}")
        # Re-order by the service-reported index so outputs align with inputs.
        embedding_list = [(res["embedding"], res["index"]) for res in data]
        return [x[0] for x in sorted(embedding_list, key=lambda x: x[1])]

    def embed_query(self, text: str) -> List[float]:
        """Input pathway for query embeddings."""
        # Truncate here too, consistent with embed_documents — the service
        # limits every input string to max_length characters.
        return self._embed(
            self._truncate([text]), model_type=self.model_type or "query"
        )[0]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Input pathway for document embeddings."""
        # From https://catalog.ngc.nvidia.com/orgs/nvidia/teams/ai-foundation/models/nvolve-40k/documentation
        # The input must not exceed the 2048 max input characters and inputs above 512
        # model tokens will be truncated. The input array must not exceed 50 input
        # strings.
        all_embeddings: List[List[float]] = []
        for start in range(0, len(texts), self.max_batch_size):
            batch = self._truncate(texts[start : start + self.max_batch_size])
            all_embeddings.extend(
                self._embed(batch, model_type=self.model_type or "passage")
            )
        return all_embeddings
|