Mirror of https://github.com/hwchase17/langchain.git (synced 2025-08-06 03:27:55 +00:00)
nvidia-ai-endpoints[patch]: model arguments (e.g. temperature) on construction bug (#17290)
- **Issue:** Issue with model argument support (it has been there for a while, actually):
  - Non-specially-handled arguments like temperature don't work when passed through the constructor.
  - Such arguments DO work quite well with `bind`, but they also do not abide by field requirements.
  - Since the initial push, server-side error messages have gotten better, and v0.0.2 raises better exceptions. So maybe it's better to let the server side handle such issues?
- **Description:**
  - Removed ChatNVIDIA's argument fields in favor of `model_kwargs`/`model_kws` arguments, which aggregate constructor kwargs (constructor pathway) and merge them with call kwargs (bind pathway).
  - Shuffled a few functions from `_NVIDIAClient` to `ChatNVIDIA` to streamline construction for future integrations.
  - Minor/Optional: Old services didn't have stop support, so client-side stopping was implemented. Now both are done.
- **Any Breaking Changes:** Minor breaking changes if you strongly rely on `chat_model.temperature`, etc.; this is now captured by `chat_model.model_kwargs`.

PR passes tests, example notebooks, and example testing. Still going to chat with some people, so leaving as draft for now.

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
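A rough usage sketch of the two pathways described above (illustrative only, not part of the diff; the model name and prompt are placeholders, and a configured `NVIDIA_API_KEY` is assumed):

```python
from langchain_nvidia_ai_endpoints import ChatNVIDIA

# Constructor pathway: generation arguments set here are merged into every request payload.
llm = ChatNVIDIA(model="mixtral_8x7b", temperature=0.2, max_tokens=256)

# Bind pathway: call-scope kwargs are merged last, so they override constructor values.
warmer = llm.bind(temperature=0.9)

print(llm.invoke("Write a haiku about GPUs.").content)
print(warmer.invoke("Write a haiku about GPUs.").content)
```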
This commit is contained in:
Parent: 932c52c333
Commit: 5f9ac6986e
File diff suppressed because one or more lines are too long
Embeddings example notebook:

@@ -28,9 +28,17 @@
 },
 {
 "cell_type": "code",
- "execution_count": null,
+ "execution_count": 1,
 "metadata": {},
- "outputs": [],
+ "outputs": [
+  {
+   "name": "stdout",
+   "output_type": "stream",
+   "text": [
+    "Note: you may need to restart the kernel to use updated packages.\n"
+   ]
+  }
+ ],
 "source": [
  "%pip install --upgrade --quiet langchain-nvidia-ai-endpoints"
 ]
@@ -56,7 +64,7 @@
 },
 {
 "cell_type": "code",
- "execution_count": 1,
+ "execution_count": 2,
 "metadata": {
 "colab": {
 "base_uri": "https://localhost:8080/"
@@ -64,7 +72,15 @@
 "id": "hoF41-tNczS3",
 "outputId": "7f2833dc-191c-4d73-b823-7b2745a93a2f"
 },
- "outputs": [],
+ "outputs": [
+  {
+   "name": "stdout",
+   "output_type": "stream",
+   "text": [
+    "Valid NVIDIA_API_KEY already in environment. Delete to reset\n"
+   ]
+  }
+ ],
 "source": [
  "import getpass\n",
  "import os\n",
@@ -105,7 +121,7 @@
 },
 {
 "cell_type": "code",
- "execution_count": 2,
+ "execution_count": 3,
 "metadata": {
 "id": "hbXmJssPdIPX"
 },
@@ -180,7 +196,7 @@
 },
 {
 "cell_type": "code",
- "execution_count": 3,
+ "execution_count": 4,
 "metadata": {
 "colab": {
 "base_uri": "https://localhost:8080/"
@@ -194,15 +210,15 @@
 "output_type": "stream",
 "text": [
 "Single Query Embedding: \n",
- "\u001b[1mExecuted in 1.39 seconds.\u001b[0m\n",
+ "\u001b[1mExecuted in 2.19 seconds.\u001b[0m\n",
 "Shape: (1024,)\n",
 "\n",
 "Sequential Embedding: \n",
- "\u001b[1mExecuted in 3.20 seconds.\u001b[0m\n",
+ "\u001b[1mExecuted in 3.16 seconds.\u001b[0m\n",
 "Shape: (5, 1024)\n",
 "\n",
 "Batch Query Embedding: \n",
- "\u001b[1mExecuted in 1.52 seconds.\u001b[0m\n",
+ "\u001b[1mExecuted in 1.23 seconds.\u001b[0m\n",
 "Shape: (5, 1024)\n"
 ]
 }
@@ -260,7 +276,7 @@
 },
 {
 "cell_type": "code",
- "execution_count": 4,
+ "execution_count": 5,
 "metadata": {
 "colab": {
 "base_uri": "https://localhost:8080/"
@@ -274,11 +290,11 @@
 "output_type": "stream",
 "text": [
 "Single Document Embedding: \n",
- "\u001b[1mExecuted in 0.76 seconds.\u001b[0m\n",
+ "\u001b[1mExecuted in 0.52 seconds.\u001b[0m\n",
 "Shape: (1024,)\n",
 "\n",
 "Batch Document Embedding: \n",
- "\u001b[1mExecuted in 0.86 seconds.\u001b[0m\n",
+ "\u001b[1mExecuted in 0.89 seconds.\u001b[0m\n",
 "Shape: (5, 1024)\n"
 ]
 }
@@ -324,7 +340,7 @@
 },
 {
 "cell_type": "code",
- "execution_count": 5,
+ "execution_count": 6,
 "metadata": {},
 "outputs": [
 {
@@ -341,7 +357,7 @@
 },
 {
 "cell_type": "code",
- "execution_count": 6,
+ "execution_count": 7,
 "metadata": {
 "colab": {
 "base_uri": "https://localhost:8080/",
Client module (langchain_nvidia_ai_endpoints._common, _NVIDIAClient):

@@ -20,7 +20,6 @@ from typing import (

 import aiohttp
 import requests
-from langchain_core.messages import BaseMessage
 from langchain_core.pydantic_v1 import (
     BaseModel,
     Field,
@@ -440,10 +439,6 @@ class _NVIDIAClient(BaseModel):

     model: str = Field(..., description="Name of the model to invoke")

-    temperature: float = Field(0.2, le=1.0, gt=0.0)
-    top_p: float = Field(0.7, le=1.0, ge=0.0)
-    max_tokens: int = Field(1024, le=1024, ge=32)
-
     ####################################################################################

     @root_validator(pre=True)
@@ -485,67 +480,3 @@ class _NVIDIAClient(BaseModel):
         known_fns = self.client.available_functions
         fn_spec = [f for f in known_fns if f.get("id") == model_key][0]
         return fn_spec
-
-    def get_generation(
-        self,
-        inputs: Sequence[Dict],
-        labels: Optional[dict] = None,
-        stop: Optional[Sequence[str]] = None,
-        **kwargs: Any,
-    ) -> dict:
-        """Call to client generate method with call scope"""
-        payload = self.get_payload(inputs=inputs, stream=False, labels=labels, **kwargs)
-        out = self.client.get_req_generation(self.model, stop=stop, payload=payload)
-        return out
-
-    def get_stream(
-        self,
-        inputs: Sequence[Dict],
-        labels: Optional[dict] = None,
-        stop: Optional[Sequence[str]] = None,
-        **kwargs: Any,
-    ) -> Iterator:
-        """Call to client stream method with call scope"""
-        payload = self.get_payload(inputs=inputs, stream=True, labels=labels, **kwargs)
-        return self.client.get_req_stream(self.model, stop=stop, payload=payload)
-
-    def get_astream(
-        self,
-        inputs: Sequence[Dict],
-        labels: Optional[dict] = None,
-        stop: Optional[Sequence[str]] = None,
-        **kwargs: Any,
-    ) -> AsyncIterator:
-        """Call to client astream methods with call scope"""
-        payload = self.get_payload(inputs=inputs, stream=True, labels=labels, **kwargs)
-        return self.client.get_req_astream(self.model, stop=stop, payload=payload)
-
-    def get_payload(
-        self, inputs: Sequence[Dict], labels: Optional[dict] = None, **kwargs: Any
-    ) -> dict:
-        """Generates payload for the _NVIDIAClient API to send to service."""
-        return {
-            **self.preprocess(inputs=inputs, labels=labels),
-            **kwargs,
-        }
-
-    def preprocess(self, inputs: Sequence[Dict], labels: Optional[dict] = None) -> dict:
-        """Prepares a message or list of messages for the payload"""
-        messages = [self.prep_msg(m) for m in inputs]
-        if labels:
-            # (WFH) Labels are currently (?) always passed as an assistant
-            # suffix message, but this API seems less stable.
-            messages += [{"labels": labels, "role": "assistant"}]
-        return {"messages": messages}
-
-    def prep_msg(self, msg: Union[str, dict, BaseMessage]) -> dict:
-        """Helper Method: Ensures a message is a dictionary with a role and content."""
-        if isinstance(msg, str):
-            # (WFH) this shouldn't ever be reached but leaving this here bcs
-            # it's a Chesterton's fence I'm unwilling to touch
-            return dict(role="user", content=msg)
-        if isinstance(msg, dict):
-            if msg.get("content", None) is None:
-                raise ValueError(f"Message {msg} has no content")
-            return msg
-        raise ValueError(f"Unknown message received: {msg} of type {type(msg)}")
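The PR description mentions client-side stopping for services without native `stop` support; in the methods moved out of `_NVIDIAClient` above (and in the `ChatNVIDIA` methods below), `stop` is simply forwarded to the client's `get_req_*` calls. As a generic sketch of the client-side idea only, under the assumption that it amounts to truncating generated text at the first stop sequence (this is not the library's implementation):

```python
from typing import Optional, Sequence

def truncate_at_stop(text: str, stop: Optional[Sequence[str]]) -> str:
    """Cut generated text at the earliest occurrence of any stop sequence."""
    if not stop:
        return text
    cut = len(text)
    for token in stop:
        idx = text.find(token)
        if idx != -1:
            cut = min(cut, idx)
    return text[:cut]

print(truncate_at_stop("Hello world. END extra text", ["END"]))  # -> "Hello world. "
```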
Chat model module (ChatNVIDIA):

@@ -27,6 +27,7 @@ from langchain_core.callbacks.manager import (
 from langchain_core.language_models.chat_models import SimpleChatModel
 from langchain_core.messages import BaseMessage, ChatMessage, ChatMessageChunk
 from langchain_core.outputs import ChatGenerationChunk
+from langchain_core.pydantic_v1 import Field

 from langchain_nvidia_ai_endpoints import _common as nvidia_ai_endpoints

@@ -116,6 +117,14 @@ class ChatNVIDIA(nvidia_ai_endpoints._NVIDIAClient, SimpleChatModel):
         response = model.invoke("Hello")
     """

+    temperature: Optional[float] = Field(description="Sampling temperature in [0, 1]")
+    max_tokens: Optional[int] = Field(description="Maximum # of tokens to generate")
+    top_p: Optional[float] = Field(description="Top-p for distribution sampling")
+    seed: Optional[int] = Field(description="The seed for deterministic results")
+    bad: Optional[Sequence[str]] = Field(description="Bad words to avoid (cased)")
+    stop: Optional[Sequence[str]] = Field(description="Stop words (cased)")
+    labels: Optional[Dict[str, float]] = Field(description="Steering parameters")
+
     @property
     def _llm_type(self) -> str:
         """Return type of NVIDIA AI Foundation Model Interface."""
@@ -126,14 +135,11 @@ class ChatNVIDIA(nvidia_ai_endpoints._NVIDIAClient, SimpleChatModel):
         messages: List[BaseMessage],
         stop: Optional[Sequence[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
-        labels: Optional[dict] = None,
         **kwargs: Any,
     ) -> str:
         """Invoke on a single list of chat messages."""
         inputs = self.custom_preprocess(messages)
-        responses = self.get_generation(
-            inputs=inputs, stop=stop, labels=labels, **kwargs
-        )
+        responses = self.get_generation(inputs=inputs, stop=stop, **kwargs)
         outputs = self.custom_postprocess(responses)
         return outputs

@@ -148,14 +154,11 @@ class ChatNVIDIA(nvidia_ai_endpoints._NVIDIAClient, SimpleChatModel):
         messages: List[BaseMessage],
         stop: Optional[Sequence[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
-        labels: Optional[dict] = None,
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
         """Allows streaming to model!"""
         inputs = self.custom_preprocess(messages)
-        for response in self.get_stream(
-            inputs=inputs, stop=stop, labels=labels, **kwargs
-        ):
+        for response in self.get_stream(inputs=inputs, stop=stop, **kwargs):
             chunk = self._get_filled_chunk(self.custom_postprocess(response))
             yield chunk
             if run_manager:
@@ -166,13 +169,10 @@ class ChatNVIDIA(nvidia_ai_endpoints._NVIDIAClient, SimpleChatModel):
         messages: List[BaseMessage],
         stop: Optional[Sequence[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
-        labels: Optional[dict] = None,
         **kwargs: Any,
     ) -> AsyncIterator[ChatGenerationChunk]:
         inputs = self.custom_preprocess(messages)
-        async for response in self.get_astream(
-            inputs=inputs, stop=stop, labels=labels, **kwargs
-        ):
+        async for response in self.get_astream(inputs=inputs, stop=stop, **kwargs):
             chunk = self._get_filled_chunk(self.custom_postprocess(response))
             yield chunk
             if run_manager:
@@ -229,7 +229,78 @@ class ChatNVIDIA(nvidia_ai_endpoints._NVIDIAClient, SimpleChatModel):
     def custom_postprocess(self, msg: dict) -> str:
         if "content" in msg:
             return msg["content"]
-        logger.warning(
-            f"Got ambiguous message in postprocessing; returning as-is: msg = {msg}"
-        )
+        elif "b64_json" in msg:
+            return msg["b64_json"]
         return str(msg)
+
+    ######################################################################################
+    ## Core client-side interfaces
+
+    def get_generation(
+        self,
+        inputs: Sequence[Dict],
+        **kwargs: Any,
+    ) -> dict:
+        """Call to client generate method with call scope"""
+        stop = kwargs.get("stop", None)
+        payload = self.get_payload(inputs=inputs, stream=False, **kwargs)
+        out = self.client.get_req_generation(self.model, stop=stop, payload=payload)
+        return out
+
+    def get_stream(
+        self,
+        inputs: Sequence[Dict],
+        **kwargs: Any,
+    ) -> Iterator:
+        """Call to client stream method with call scope"""
+        stop = kwargs.get("stop", None)
+        payload = self.get_payload(inputs=inputs, stream=True, **kwargs)
+        return self.client.get_req_stream(self.model, stop=stop, payload=payload)
+
+    def get_astream(
+        self,
+        inputs: Sequence[Dict],
+        **kwargs: Any,
+    ) -> AsyncIterator:
+        """Call to client astream methods with call scope"""
+        stop = kwargs.get("stop", None)
+        payload = self.get_payload(inputs=inputs, stream=True, **kwargs)
+        return self.client.get_req_astream(self.model, stop=stop, payload=payload)
+
+    def get_payload(self, inputs: Sequence[Dict], **kwargs: Any) -> dict:
+        """Generates payload for the _NVIDIAClient API to send to service."""
+        attr_kwargs = {
+            "temperature": self.temperature,
+            "max_tokens": self.max_tokens,
+            "top_p": self.top_p,
+            "seed": self.seed,
+            "bad": self.bad,
+            "stop": self.stop,
+            "labels": self.labels,
+        }
+        attr_kwargs = {k: v for k, v in attr_kwargs.items() if v is not None}
+        new_kwargs = {**attr_kwargs, **kwargs}
+        return self.prep_payload(inputs=inputs, **new_kwargs)
+
+    def prep_payload(self, inputs: Sequence[Dict], **kwargs: Any) -> dict:
+        """Prepares a message or list of messages for the payload"""
+        messages = [self.prep_msg(m) for m in inputs]
+        if kwargs.get("labels"):
+            # (WFH) Labels are currently (?) always passed as an assistant
+            # suffix message, but this API seems less stable.
+            messages += [{"labels": kwargs.pop("labels"), "role": "assistant"}]
+        if kwargs.get("stop") is None:
+            kwargs.pop("stop")
+        return {"messages": messages, **kwargs}
+
+    def prep_msg(self, msg: Union[str, dict, BaseMessage]) -> dict:
+        """Helper Method: Ensures a message is a dictionary with a role and content."""
+        if isinstance(msg, str):
+            # (WFH) this shouldn't ever be reached but leaving this here bcs
+            # it's a Chesterton's fence I'm unwilling to touch
+            return dict(role="user", content=msg)
+        if isinstance(msg, dict):
+            if msg.get("content", None) is None:
+                raise ValueError(f"Message {msg} has no content")
+            return msg
+        raise ValueError(f"Unknown message received: {msg} of type {type(msg)}")
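To make the precedence in the new `get_payload` concrete, here is a standalone sketch of the same merge order (an illustration with assumed names, not the library code): constructor attributes are kept only when set, and call-scope kwargs win on conflict.

```python
from typing import Any, Dict, Optional

def merge_call_kwargs(
    attr_kwargs: Dict[str, Optional[Any]], call_kwargs: Dict[str, Any]
) -> Dict[str, Any]:
    """Drop unset (None) constructor attributes, then let call-scope kwargs override."""
    set_attrs = {k: v for k, v in attr_kwargs.items() if v is not None}
    return {**set_attrs, **call_kwargs}

# Constructor set temperature=0.2 and left seed unset; a bind/invoke call passes temperature=0.9.
print(merge_call_kwargs({"temperature": 0.2, "seed": None}, {"temperature": 0.9}))
# -> {'temperature': 0.9}
```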
Embeddings module (NVIDIAEmbeddings):

@@ -1,51 +1,21 @@
 """Embeddings Components Derived from NVEModel/Embeddings"""
-from typing import Any, List, Literal, Optional
+from typing import List, Literal, Optional

 from langchain_core.embeddings import Embeddings
-from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
+from langchain_core.pydantic_v1 import Field

-from langchain_nvidia_ai_endpoints._common import NVEModel
+from langchain_nvidia_ai_endpoints._common import _NVIDIAClient


-class NVIDIAEmbeddings(BaseModel, Embeddings):
+class NVIDIAEmbeddings(_NVIDIAClient, Embeddings):
     """NVIDIA's AI Foundation Retriever Question-Answering Asymmetric Model."""

-    client: NVEModel = Field(NVEModel)
-    model: str = Field(
-        ..., description="The embedding model to use. Example: nvolveqa_40k"
-    )
     max_length: int = Field(2048, ge=1, le=2048)
     max_batch_size: int = Field(default=50)
     model_type: Optional[Literal["passage", "query"]] = Field(
         "passage", description="The type of text to be embedded."
     )
-
-    @root_validator(pre=True)
-    def _validate_client(cls, values: Any) -> Any:
-        if "client" not in values:
-            values["client"] = NVEModel(**values)
-        return values
-
-    @property
-    def available_functions(self) -> List[dict]:
-        """Map the available functions that can be invoked."""
-        return self.client.available_functions
-
-    @property
-    def available_models(self) -> dict:
-        """Map the available models that can be invoked."""
-        return self.client.available_models
-
-    @staticmethod
-    def get_available_functions(**kwargs: Any) -> List[dict]:
-        """Map the available functions that can be invoked. Callable from class"""
-        return NVEModel(**kwargs).available_functions
-
-    @staticmethod
-    def get_available_models(**kwargs: Any) -> dict:
-        """Map the available models that can be invoked. Callable from class"""
-        return NVEModel(**kwargs).available_models

     def _embed(
         self, texts: List[str], model_type: Literal["passage", "query"]
     ) -> List[List[float]]:
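A minimal usage sketch for the refactored `NVIDIAEmbeddings` (illustrative only; the model name comes from the removed field's example, `nvolveqa_40k`, and a configured `NVIDIA_API_KEY` is assumed):

```python
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings

# Construction now routes through the shared _NVIDIAClient base instead of a local validator.
embedder = NVIDIAEmbeddings(model="nvolveqa_40k", model_type="passage")

vectors = embedder.embed_documents(["first passage", "second passage"])
print(len(vectors), len(vectors[0]))  # e.g. 2 1024
```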
Integration test (test_chat_ai_endpoints):

@@ -11,7 +11,7 @@ def test_chat_ai_endpoints() -> None:
         temperature=0.7,
     )
     message = HumanMessage(content="Hello")
-    response = chat([message])
+    response = chat.invoke([message])
     assert isinstance(response, BaseMessage)
     assert isinstance(response.content, str)
