nvidia-ai-endpoints[patch]: fix model argument (e.g. temperature) handling on construction (#17290)

- **Issue:** Long-standing issue with model argument support:
- Arguments without special handling (e.g. temperature) have no effect when
passed through the constructor.
- Such arguments do work when passed via `bind`, but they bypass field
validation there.
- Since the initial push, server-side error messages have improved and
v0.0.2 raises better exceptions, so it may be preferable to let the
server handle such issues.
- **Description:**
- Removed ChatNVIDIA's argument fields in favor of
`model_kwargs`/`model_kws` arguments, which aggregate constructor kwargs
(constructor pathway) and merge them with call kwargs (bind pathway).
- Moved a few functions from `_NVIDIAClient` to `ChatNVIDIA` to
streamline construction for future integrations.
- Minor/Optional: older services lacked server-side stop support, so
client-side stopping was implemented; both are now applied.
- **Any Breaking Changes:** Minor breaking change if you rely directly on
`chat_model.temperature`, etc.; these values are now captured by
`chat_model.model_kwargs` (see the sketch after this list).
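
For illustration, a minimal sketch of the two pathways described above (the model name is just a placeholder; the exact accepted argument names follow the fields shown in the diff below):

```python
from langchain_nvidia_ai_endpoints import ChatNVIDIA

# Constructor pathway: model arguments passed at construction time are now
# merged into every request payload (previously they were silently dropped
# unless they had dedicated handling).
llm = ChatNVIDIA(model="mixtral_8x7b", temperature=0.7, max_tokens=256)

# Bind pathway: call-scope arguments, merged on top of the constructor values
# when the payload is assembled, so they take precedence.
low_temp = llm.bind(temperature=0.2)

print(llm.invoke("Tell me a joke.").content)
print(low_temp.invoke("Tell me a joke.").content)
```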

The PR passes the tests, example notebooks, and example testing. I still need
to chat with a few people, so leaving this as a draft for now.

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
Authored by Vadim Kudlay on 2024-02-09 15:46:02 -06:00; committed by GitHub.
commit 5f9ac6986e (parent 932c52c333)
6 changed files with 313 additions and 242 deletions

File diff suppressed because one or more lines are too long

View File

@@ -28,9 +28,17 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"%pip install --upgrade --quiet langchain-nvidia-ai-endpoints"
]
@@ -56,7 +64,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 2,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@@ -64,7 +72,15 @@
"id": "hoF41-tNczS3",
"outputId": "7f2833dc-191c-4d73-b823-7b2745a93a2f"
},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Valid NVIDIA_API_KEY already in environment. Delete to reset\n"
]
}
],
"source": [
"import getpass\n",
"import os\n",
@@ -105,7 +121,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"metadata": {
"id": "hbXmJssPdIPX"
},
@@ -180,7 +196,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@@ -194,15 +210,15 @@
"output_type": "stream",
"text": [
"Single Query Embedding: \n",
"\u001b[1mExecuted in 1.39 seconds.\u001b[0m\n",
"\u001b[1mExecuted in 2.19 seconds.\u001b[0m\n",
"Shape: (1024,)\n",
"\n",
"Sequential Embedding: \n",
"\u001b[1mExecuted in 3.20 seconds.\u001b[0m\n",
"\u001b[1mExecuted in 3.16 seconds.\u001b[0m\n",
"Shape: (5, 1024)\n",
"\n",
"Batch Query Embedding: \n",
"\u001b[1mExecuted in 1.52 seconds.\u001b[0m\n",
"\u001b[1mExecuted in 1.23 seconds.\u001b[0m\n",
"Shape: (5, 1024)\n"
]
}
@@ -260,7 +276,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@@ -274,11 +290,11 @@
"output_type": "stream",
"text": [
"Single Document Embedding: \n",
"\u001b[1mExecuted in 0.76 seconds.\u001b[0m\n",
"\u001b[1mExecuted in 0.52 seconds.\u001b[0m\n",
"Shape: (1024,)\n",
"\n",
"Batch Document Embedding: \n",
"\u001b[1mExecuted in 0.86 seconds.\u001b[0m\n",
"\u001b[1mExecuted in 0.89 seconds.\u001b[0m\n",
"Shape: (5, 1024)\n"
]
}
@@ -324,7 +340,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"metadata": {},
"outputs": [
{
@@ -341,7 +357,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 7,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",

View File

@@ -20,7 +20,6 @@ from typing import (
import aiohttp
import requests
from langchain_core.messages import BaseMessage
from langchain_core.pydantic_v1 import (
    BaseModel,
    Field,
@@ -440,10 +439,6 @@ class _NVIDIAClient(BaseModel):
    model: str = Field(..., description="Name of the model to invoke")
    temperature: float = Field(0.2, le=1.0, gt=0.0)
    top_p: float = Field(0.7, le=1.0, ge=0.0)
    max_tokens: int = Field(1024, le=1024, ge=32)

    ####################################################################################

    @root_validator(pre=True)
@@ -485,67 +480,3 @@ class _NVIDIAClient(BaseModel):
        known_fns = self.client.available_functions
        fn_spec = [f for f in known_fns if f.get("id") == model_key][0]
        return fn_spec

    def get_generation(
        self,
        inputs: Sequence[Dict],
        labels: Optional[dict] = None,
        stop: Optional[Sequence[str]] = None,
        **kwargs: Any,
    ) -> dict:
        """Call to client generate method with call scope"""
        payload = self.get_payload(inputs=inputs, stream=False, labels=labels, **kwargs)
        out = self.client.get_req_generation(self.model, stop=stop, payload=payload)
        return out

    def get_stream(
        self,
        inputs: Sequence[Dict],
        labels: Optional[dict] = None,
        stop: Optional[Sequence[str]] = None,
        **kwargs: Any,
    ) -> Iterator:
        """Call to client stream method with call scope"""
        payload = self.get_payload(inputs=inputs, stream=True, labels=labels, **kwargs)
        return self.client.get_req_stream(self.model, stop=stop, payload=payload)

    def get_astream(
        self,
        inputs: Sequence[Dict],
        labels: Optional[dict] = None,
        stop: Optional[Sequence[str]] = None,
        **kwargs: Any,
    ) -> AsyncIterator:
        """Call to client astream methods with call scope"""
        payload = self.get_payload(inputs=inputs, stream=True, labels=labels, **kwargs)
        return self.client.get_req_astream(self.model, stop=stop, payload=payload)

    def get_payload(
        self, inputs: Sequence[Dict], labels: Optional[dict] = None, **kwargs: Any
    ) -> dict:
        """Generates payload for the _NVIDIAClient API to send to service."""
        return {
            **self.preprocess(inputs=inputs, labels=labels),
            **kwargs,
        }

    def preprocess(self, inputs: Sequence[Dict], labels: Optional[dict] = None) -> dict:
        """Prepares a message or list of messages for the payload"""
        messages = [self.prep_msg(m) for m in inputs]
        if labels:
            # (WFH) Labels are currently (?) always passed as an assistant
            # suffix message, but this API seems less stable.
            messages += [{"labels": labels, "role": "assistant"}]
        return {"messages": messages}

    def prep_msg(self, msg: Union[str, dict, BaseMessage]) -> dict:
        """Helper Method: Ensures a message is a dictionary with a role and content."""
        if isinstance(msg, str):
            # (WFH) this shouldn't ever be reached but leaving this here bcs
            # it's a Chesterton's fence I'm unwilling to touch
            return dict(role="user", content=msg)
        if isinstance(msg, dict):
            if msg.get("content", None) is None:
                raise ValueError(f"Message {msg} has no content")
            return msg
        raise ValueError(f"Unknown message received: {msg} of type {type(msg)}")

View File

@@ -27,6 +27,7 @@ from langchain_core.callbacks.manager import (
from langchain_core.language_models.chat_models import SimpleChatModel
from langchain_core.messages import BaseMessage, ChatMessage, ChatMessageChunk
from langchain_core.outputs import ChatGenerationChunk
from langchain_core.pydantic_v1 import Field
from langchain_nvidia_ai_endpoints import _common as nvidia_ai_endpoints
@@ -116,6 +117,14 @@ class ChatNVIDIA(nvidia_ai_endpoints._NVIDIAClient, SimpleChatModel):
            response = model.invoke("Hello")
    """

    temperature: Optional[float] = Field(description="Sampling temperature in [0, 1]")
    max_tokens: Optional[int] = Field(description="Maximum # of tokens to generate")
    top_p: Optional[float] = Field(description="Top-p for distribution sampling")
    seed: Optional[int] = Field(description="The seed for deterministic results")
    bad: Optional[Sequence[str]] = Field(description="Bad words to avoid (cased)")
    stop: Optional[Sequence[str]] = Field(description="Stop words (cased)")
    labels: Optional[Dict[str, float]] = Field(description="Steering parameters")

    @property
    def _llm_type(self) -> str:
        """Return type of NVIDIA AI Foundation Model Interface."""
@@ -126,14 +135,11 @@ class ChatNVIDIA(nvidia_ai_endpoints._NVIDIAClient, SimpleChatModel):
        messages: List[BaseMessage],
        stop: Optional[Sequence[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        labels: Optional[dict] = None,
        **kwargs: Any,
    ) -> str:
        """Invoke on a single list of chat messages."""
        inputs = self.custom_preprocess(messages)
        responses = self.get_generation(
            inputs=inputs, stop=stop, labels=labels, **kwargs
        )
        responses = self.get_generation(inputs=inputs, stop=stop, **kwargs)
        outputs = self.custom_postprocess(responses)
        return outputs
@@ -148,14 +154,11 @@ class ChatNVIDIA(nvidia_ai_endpoints._NVIDIAClient, SimpleChatModel):
        messages: List[BaseMessage],
        stop: Optional[Sequence[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        labels: Optional[dict] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Allows streaming to model!"""
        inputs = self.custom_preprocess(messages)
        for response in self.get_stream(
            inputs=inputs, stop=stop, labels=labels, **kwargs
        ):
        for response in self.get_stream(inputs=inputs, stop=stop, **kwargs):
            chunk = self._get_filled_chunk(self.custom_postprocess(response))
            yield chunk
            if run_manager:
@@ -166,13 +169,10 @@ class ChatNVIDIA(nvidia_ai_endpoints._NVIDIAClient, SimpleChatModel):
        messages: List[BaseMessage],
        stop: Optional[Sequence[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        labels: Optional[dict] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        inputs = self.custom_preprocess(messages)
        async for response in self.get_astream(
            inputs=inputs, stop=stop, labels=labels, **kwargs
        ):
        async for response in self.get_astream(inputs=inputs, stop=stop, **kwargs):
            chunk = self._get_filled_chunk(self.custom_postprocess(response))
            yield chunk
            if run_manager:
@@ -229,7 +229,78 @@ class ChatNVIDIA(nvidia_ai_endpoints._NVIDIAClient, SimpleChatModel):
    def custom_postprocess(self, msg: dict) -> str:
        if "content" in msg:
            return msg["content"]
        logger.warning(
            f"Got ambiguous message in postprocessing; returning as-is: msg = {msg}"
        )
        elif "b64_json" in msg:
            return msg["b64_json"]
        return str(msg)

    ######################################################################################
    ## Core client-side interfaces

    def get_generation(
        self,
        inputs: Sequence[Dict],
        **kwargs: Any,
    ) -> dict:
        """Call to client generate method with call scope"""
        stop = kwargs.get("stop", None)
        payload = self.get_payload(inputs=inputs, stream=False, **kwargs)
        out = self.client.get_req_generation(self.model, stop=stop, payload=payload)
        return out

    def get_stream(
        self,
        inputs: Sequence[Dict],
        **kwargs: Any,
    ) -> Iterator:
        """Call to client stream method with call scope"""
        stop = kwargs.get("stop", None)
        payload = self.get_payload(inputs=inputs, stream=True, **kwargs)
        return self.client.get_req_stream(self.model, stop=stop, payload=payload)

    def get_astream(
        self,
        inputs: Sequence[Dict],
        **kwargs: Any,
    ) -> AsyncIterator:
        """Call to client astream methods with call scope"""
        stop = kwargs.get("stop", None)
        payload = self.get_payload(inputs=inputs, stream=True, **kwargs)
        return self.client.get_req_astream(self.model, stop=stop, payload=payload)

    def get_payload(self, inputs: Sequence[Dict], **kwargs: Any) -> dict:
        """Generates payload for the _NVIDIAClient API to send to service."""
        attr_kwargs = {
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "top_p": self.top_p,
            "seed": self.seed,
            "bad": self.bad,
            "stop": self.stop,
            "labels": self.labels,
        }
        attr_kwargs = {k: v for k, v in attr_kwargs.items() if v is not None}
        new_kwargs = {**attr_kwargs, **kwargs}
        return self.prep_payload(inputs=inputs, **new_kwargs)

    def prep_payload(self, inputs: Sequence[Dict], **kwargs: Any) -> dict:
        """Prepares a message or list of messages for the payload"""
        messages = [self.prep_msg(m) for m in inputs]
        if kwargs.get("labels"):
            # (WFH) Labels are currently (?) always passed as an assistant
            # suffix message, but this API seems less stable.
            messages += [{"labels": kwargs.pop("labels"), "role": "assistant"}]
        if kwargs.get("stop") is None:
            kwargs.pop("stop")
        return {"messages": messages, **kwargs}

    def prep_msg(self, msg: Union[str, dict, BaseMessage]) -> dict:
        """Helper Method: Ensures a message is a dictionary with a role and content."""
        if isinstance(msg, str):
            # (WFH) this shouldn't ever be reached but leaving this here bcs
            # it's a Chesterton's fence I'm unwilling to touch
            return dict(role="user", content=msg)
        if isinstance(msg, dict):
            if msg.get("content", None) is None:
                raise ValueError(f"Message {msg} has no content")
            return msg
        raise ValueError(f"Unknown message received: {msg} of type {type(msg)}")

View File

@@ -1,51 +1,21 @@
"""Embeddings Components Derived from NVEModel/Embeddings"""
from typing import Any, List, Literal, Optional
from typing import List, Literal, Optional

from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from langchain_core.pydantic_v1 import Field

from langchain_nvidia_ai_endpoints._common import NVEModel
from langchain_nvidia_ai_endpoints._common import _NVIDIAClient


class NVIDIAEmbeddings(BaseModel, Embeddings):
class NVIDIAEmbeddings(_NVIDIAClient, Embeddings):
    """NVIDIA's AI Foundation Retriever Question-Answering Asymmetric Model."""

    client: NVEModel = Field(NVEModel)
    model: str = Field(
        ..., description="The embedding model to use. Example: nvolveqa_40k"
    )
    max_length: int = Field(2048, ge=1, le=2048)
    max_batch_size: int = Field(default=50)
    model_type: Optional[Literal["passage", "query"]] = Field(
        "passage", description="The type of text to be embedded."
    )

    @root_validator(pre=True)
    def _validate_client(cls, values: Any) -> Any:
        if "client" not in values:
            values["client"] = NVEModel(**values)
        return values

    @property
    def available_functions(self) -> List[dict]:
        """Map the available functions that can be invoked."""
        return self.client.available_functions

    @property
    def available_models(self) -> dict:
        """Map the available models that can be invoked."""
        return self.client.available_models

    @staticmethod
    def get_available_functions(**kwargs: Any) -> List[dict]:
        """Map the available functions that can be invoked. Callable from class"""
        return NVEModel(**kwargs).available_functions

    @staticmethod
    def get_available_models(**kwargs: Any) -> dict:
        """Map the available models that can be invoked. Callable from class"""
        return NVEModel(**kwargs).available_models

    def _embed(
        self, texts: List[str], model_type: Literal["passage", "query"]
    ) -> List[List[float]]:
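
As an aside (not part of the diff), construction and use of `NVIDIAEmbeddings` should look roughly like this after the refactor; the model name comes from the field description above and the output dimension from the notebook output earlier:

```python
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings

# NVIDIAEmbeddings now derives from _NVIDIAClient instead of wrapping NVEModel
# directly; the user-facing call pattern stays the same.
embedder = NVIDIAEmbeddings(model="nvolveqa_40k", model_type="passage")

vectors = embedder.embed_documents(["NVIDIA AI Foundation Endpoints"])
print(len(vectors), len(vectors[0]))  # e.g. 1 document, 1024-dimensional vector
```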

View File

@@ -11,7 +11,7 @@ def test_chat_ai_endpoints() -> None:
        temperature=0.7,
    )
    message = HumanMessage(content="Hello")
    response = chat([message])
    response = chat.invoke([message])
    assert isinstance(response, BaseMessage)
    assert isinstance(response.content, str)