mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-03 20:16:52 +00:00
community[minor]: Nemo embeddings(#16206)
This PR is adding support for NVIDIA NeMo embeddings issue #16095. --------- Co-authored-by: Praveen Nakshatrala <pnakshatrala@gmail.com> Co-authored-by: Harrison Chase <hw.chase.17@gmail.com> Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
169
libs/community/langchain_community/embeddings/nemo.py
Normal file
169
libs/community/langchain_community/embeddings/nemo.py
Normal file
@@ -0,0 +1,169 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.pydantic_v1 import BaseModel, root_validator
|
||||
|
||||
|
||||
def is_endpoint_live(url: str, headers: Optional[dict], payload: Any) -> bool:
|
||||
"""
|
||||
Check if an endpoint is live by sending a GET request to the specified URL.
|
||||
|
||||
Args:
|
||||
url (str): The URL of the endpoint to check.
|
||||
|
||||
Returns:
|
||||
bool: True if the endpoint is live (status code 200), False otherwise.
|
||||
|
||||
Raises:
|
||||
Exception: If the endpoint returns a non-successful status code or if there is
|
||||
an error querying the endpoint.
|
||||
"""
|
||||
try:
|
||||
response = requests.request("POST", url, headers=headers, data=payload)
|
||||
|
||||
# Check if the status code is 200 (OK)
|
||||
if response.status_code == 200:
|
||||
return True
|
||||
else:
|
||||
# Raise an exception if the status code is not 200
|
||||
raise Exception(
|
||||
f"Endpoint returned a non-successful status code: "
|
||||
f"{response.status_code}"
|
||||
)
|
||||
except requests.exceptions.RequestException as e:
|
||||
# Handle any exceptions (e.g., connection errors)
|
||||
raise Exception(f"Error querying the endpoint: {e}")
|
||||
|
||||
|
||||
class NeMoEmbeddings(BaseModel, Embeddings):
|
||||
batch_size: int = 16
|
||||
model: str = "NV-Embed-QA-003"
|
||||
api_endpoint_url: str = "http://localhost:8088/v1/embeddings"
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that the end point is alive using the values that are provided."""
|
||||
|
||||
url = values["api_endpoint_url"]
|
||||
model = values["model"]
|
||||
|
||||
# Optional: A minimal test payload and headers required by the endpoint
|
||||
headers = {"Content-Type": "application/json"}
|
||||
payload = json.dumps(
|
||||
{"input": "Hello World", "model": model, "input_type": "query"}
|
||||
)
|
||||
|
||||
is_endpoint_live(url, headers, payload)
|
||||
|
||||
return values
|
||||
|
||||
async def _aembedding_func(
|
||||
self, session: Any, text: str, input_type: str
|
||||
) -> List[float]:
|
||||
"""Async call out to embedding endpoint.
|
||||
|
||||
Args:
|
||||
text: The text to embed.
|
||||
|
||||
Returns:
|
||||
Embeddings for the text.
|
||||
"""
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
|
||||
async with session.post(
|
||||
self.api_endpoint_url,
|
||||
json={"input": text, "model": self.model, "input_type": input_type},
|
||||
headers=headers,
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
answer = await response.text()
|
||||
answer = json.loads(answer)
|
||||
return answer["data"][0]["embedding"]
|
||||
|
||||
def _embedding_func(self, text: str, input_type: str) -> List[float]:
|
||||
"""Call out to Cohere's embedding endpoint.
|
||||
|
||||
Args:
|
||||
text: The text to embed.
|
||||
|
||||
Returns:
|
||||
Embeddings for the text.
|
||||
"""
|
||||
|
||||
payload = json.dumps(
|
||||
{"input": text, "model": self.model, "input_type": input_type}
|
||||
)
|
||||
headers = {"Content-Type": "application/json"}
|
||||
|
||||
response = requests.request(
|
||||
"POST", self.api_endpoint_url, headers=headers, data=payload
|
||||
)
|
||||
response_json = json.loads(response.text)
|
||||
embedding = response_json["data"][0]["embedding"]
|
||||
|
||||
return embedding
|
||||
|
||||
def embed_documents(self, documents: List[str]) -> List[List[float]]:
|
||||
"""Embed a list of document texts.
|
||||
|
||||
Args:
|
||||
texts: The list of texts to embed.
|
||||
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
return [self._embedding_func(text, input_type="passage") for text in documents]
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
return self._embedding_func(text, input_type="query")
|
||||
|
||||
async def aembed_query(self, text: str) -> List[float]:
|
||||
"""Call out to NeMo's embedding endpoint async for embedding query text.
|
||||
|
||||
Args:
|
||||
text: The text to embed.
|
||||
|
||||
Returns:
|
||||
Embedding for the text.
|
||||
"""
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
embedding = await self._aembedding_func(session, text, "passage")
|
||||
return embedding
|
||||
|
||||
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Call out to NeMo's embedding endpoint async for embedding search docs.
|
||||
|
||||
Args:
|
||||
texts: The list of texts to embed.
|
||||
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
embeddings = []
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
for batch in range(0, len(texts), self.batch_size):
|
||||
text_batch = texts[batch : batch + self.batch_size]
|
||||
|
||||
for text in text_batch:
|
||||
# Create tasks for all texts in the batch
|
||||
tasks = [
|
||||
self._aembedding_func(session, text, "passage")
|
||||
for text in text_batch
|
||||
]
|
||||
|
||||
# Run all tasks concurrently
|
||||
batch_results = await asyncio.gather(*tasks)
|
||||
|
||||
# Extend the embeddings list with results from this batch
|
||||
embeddings.extend(batch_results)
|
||||
|
||||
return embeddings
|
Reference in New Issue
Block a user