Mirror of https://github.com/hwchase17/langchain.git (synced 2025-08-02 01:23:07 +00:00)
nvidia-ai-endpoints[patch]: model arguments (e.g. temperature) on construction bug (#17290)
- **Issue:** Issues with model argument support (present for a while):
  - Non-specially-handled arguments such as temperature don't take effect when passed through the constructor.
  - Such arguments do work well with `bind`, but then do not abide by field requirements.
  - Since the initial push, server-side error messages have improved and v0.0.2 raises better exceptions, so it may be better to let the server handle such issues.
- **Description:**
  - Removed ChatNVIDIA's argument fields in favor of `model_kwargs`/`model_kws` arguments, which aggregate constructor kwargs (constructor pathway) and merge them with call kwargs (bind pathway).
  - Shuffled a few functions from `_NVIDIAClient` to `ChatNVIDIA` to streamline construction for future integrations.
  - Minor/optional: older services didn't support stop, so client-side stopping was implemented; both are now done.
- **Any Breaking Changes:** Minor breaking changes if you rely strongly on `chat_model.temperature`, etc.; this is now captured by `chat_model.model_kwargs` (see the sketch below).

PR passes tests, example notebooks, and example testing. Still going to chat with some people, so leaving as draft for now.

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
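A minimal usage sketch of the two argument pathways discussed above. This is illustrative only: the model name, parameter values, and environment setup are assumptions, not part of this PR.

```python
# Sketch only: assumes langchain-nvidia-ai-endpoints >= 0.0.2 is installed,
# NVIDIA_API_KEY is exported, and "mixtral_8x7b" is an available model name.
from langchain_nvidia_ai_endpoints import ChatNVIDIA

# Constructor pathway: sampling arguments given at construction are now merged
# into the request payload (see ChatNVIDIA.get_payload in the diff below).
llm = ChatNVIDIA(model="mixtral_8x7b", temperature=0.2, max_tokens=256)
print(llm.invoke("Hello").content)

# Bind pathway: call-scope kwargs still work and take precedence over
# constructor values when both are supplied.
warmer = llm.bind(temperature=0.9)
print(warmer.invoke("Hello").content)
```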
Parent: 932c52c333
Commit: 5f9ac6986e
@@ -28,9 +28,17 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 1,
 "metadata": {},
-"outputs": [],
+"outputs": [
+ {
+  "name": "stdout",
+  "output_type": "stream",
+  "text": [
+   "Note: you may need to restart the kernel to use updated packages.\n"
+  ]
+ }
+],
 "source": [
 "%pip install --upgrade --quiet langchain-nvidia-ai-endpoints"
 ]
@@ -56,7 +64,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 1,
+"execution_count": 2,
 "metadata": {
 "colab": {
 "base_uri": "https://localhost:8080/"
@@ -64,7 +72,15 @@
 "id": "hoF41-tNczS3",
 "outputId": "7f2833dc-191c-4d73-b823-7b2745a93a2f"
 },
-"outputs": [],
+"outputs": [
+ {
+  "name": "stdout",
+  "output_type": "stream",
+  "text": [
+   "Valid NVIDIA_API_KEY already in environment. Delete to reset\n"
+  ]
+ }
+],
 "source": [
 "import getpass\n",
 "import os\n",
@@ -105,7 +121,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 2,
+"execution_count": 3,
 "metadata": {
 "id": "hbXmJssPdIPX"
 },
@@ -180,7 +196,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 3,
+"execution_count": 4,
 "metadata": {
 "colab": {
 "base_uri": "https://localhost:8080/"
@@ -194,15 +210,15 @@
 "output_type": "stream",
 "text": [
 "Single Query Embedding: \n",
-"\u001b[1mExecuted in 1.39 seconds.\u001b[0m\n",
+"\u001b[1mExecuted in 2.19 seconds.\u001b[0m\n",
 "Shape: (1024,)\n",
 "\n",
 "Sequential Embedding: \n",
-"\u001b[1mExecuted in 3.20 seconds.\u001b[0m\n",
+"\u001b[1mExecuted in 3.16 seconds.\u001b[0m\n",
 "Shape: (5, 1024)\n",
 "\n",
 "Batch Query Embedding: \n",
-"\u001b[1mExecuted in 1.52 seconds.\u001b[0m\n",
+"\u001b[1mExecuted in 1.23 seconds.\u001b[0m\n",
 "Shape: (5, 1024)\n"
 ]
 }
@@ -260,7 +276,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 4,
+"execution_count": 5,
 "metadata": {
 "colab": {
 "base_uri": "https://localhost:8080/"
@@ -274,11 +290,11 @@
 "output_type": "stream",
 "text": [
 "Single Document Embedding: \n",
-"\u001b[1mExecuted in 0.76 seconds.\u001b[0m\n",
+"\u001b[1mExecuted in 0.52 seconds.\u001b[0m\n",
 "Shape: (1024,)\n",
 "\n",
 "Batch Document Embedding: \n",
-"\u001b[1mExecuted in 0.86 seconds.\u001b[0m\n",
+"\u001b[1mExecuted in 0.89 seconds.\u001b[0m\n",
 "Shape: (5, 1024)\n"
 ]
 }
@@ -324,7 +340,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 5,
+"execution_count": 6,
 "metadata": {},
 "outputs": [
 {
@@ -341,7 +357,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 6,
+"execution_count": 7,
 "metadata": {
 "colab": {
 "base_uri": "https://localhost:8080/",
@@ -20,7 +20,6 @@ from typing import (
 
 import aiohttp
 import requests
-from langchain_core.messages import BaseMessage
 from langchain_core.pydantic_v1 import (
     BaseModel,
     Field,
@@ -440,10 +439,6 @@ class _NVIDIAClient(BaseModel):
 
     model: str = Field(..., description="Name of the model to invoke")
 
-    temperature: float = Field(0.2, le=1.0, gt=0.0)
-    top_p: float = Field(0.7, le=1.0, ge=0.0)
-    max_tokens: int = Field(1024, le=1024, ge=32)
-
     ####################################################################################
 
     @root_validator(pre=True)
@@ -485,67 +480,3 @@ class _NVIDIAClient(BaseModel):
         known_fns = self.client.available_functions
         fn_spec = [f for f in known_fns if f.get("id") == model_key][0]
         return fn_spec
-
-    def get_generation(
-        self,
-        inputs: Sequence[Dict],
-        labels: Optional[dict] = None,
-        stop: Optional[Sequence[str]] = None,
-        **kwargs: Any,
-    ) -> dict:
-        """Call to client generate method with call scope"""
-        payload = self.get_payload(inputs=inputs, stream=False, labels=labels, **kwargs)
-        out = self.client.get_req_generation(self.model, stop=stop, payload=payload)
-        return out
-
-    def get_stream(
-        self,
-        inputs: Sequence[Dict],
-        labels: Optional[dict] = None,
-        stop: Optional[Sequence[str]] = None,
-        **kwargs: Any,
-    ) -> Iterator:
-        """Call to client stream method with call scope"""
-        payload = self.get_payload(inputs=inputs, stream=True, labels=labels, **kwargs)
-        return self.client.get_req_stream(self.model, stop=stop, payload=payload)
-
-    def get_astream(
-        self,
-        inputs: Sequence[Dict],
-        labels: Optional[dict] = None,
-        stop: Optional[Sequence[str]] = None,
-        **kwargs: Any,
-    ) -> AsyncIterator:
-        """Call to client astream methods with call scope"""
-        payload = self.get_payload(inputs=inputs, stream=True, labels=labels, **kwargs)
-        return self.client.get_req_astream(self.model, stop=stop, payload=payload)
-
-    def get_payload(
-        self, inputs: Sequence[Dict], labels: Optional[dict] = None, **kwargs: Any
-    ) -> dict:
-        """Generates payload for the _NVIDIAClient API to send to service."""
-        return {
-            **self.preprocess(inputs=inputs, labels=labels),
-            **kwargs,
-        }
-
-    def preprocess(self, inputs: Sequence[Dict], labels: Optional[dict] = None) -> dict:
-        """Prepares a message or list of messages for the payload"""
-        messages = [self.prep_msg(m) for m in inputs]
-        if labels:
-            # (WFH) Labels are currently (?) always passed as an assistant
-            # suffix message, but this API seems less stable.
-            messages += [{"labels": labels, "role": "assistant"}]
-        return {"messages": messages}
-
-    def prep_msg(self, msg: Union[str, dict, BaseMessage]) -> dict:
-        """Helper Method: Ensures a message is a dictionary with a role and content."""
-        if isinstance(msg, str):
-            # (WFH) this shouldn't ever be reached but leaving this here bcs
-            # it's a Chesterton's fence I'm unwilling to touch
-            return dict(role="user", content=msg)
-        if isinstance(msg, dict):
-            if msg.get("content", None) is None:
-                raise ValueError(f"Message {msg} has no content")
-            return msg
-        raise ValueError(f"Unknown message received: {msg} of type {type(msg)}")
@@ -27,6 +27,7 @@ from langchain_core.callbacks.manager import (
 from langchain_core.language_models.chat_models import SimpleChatModel
 from langchain_core.messages import BaseMessage, ChatMessage, ChatMessageChunk
 from langchain_core.outputs import ChatGenerationChunk
+from langchain_core.pydantic_v1 import Field
 
 from langchain_nvidia_ai_endpoints import _common as nvidia_ai_endpoints
 
@@ -116,6 +117,14 @@ class ChatNVIDIA(nvidia_ai_endpoints._NVIDIAClient, SimpleChatModel):
         response = model.invoke("Hello")
     """
 
+    temperature: Optional[float] = Field(description="Sampling temperature in [0, 1]")
+    max_tokens: Optional[int] = Field(description="Maximum # of tokens to generate")
+    top_p: Optional[float] = Field(description="Top-p for distribution sampling")
+    seed: Optional[int] = Field(description="The seed for deterministic results")
+    bad: Optional[Sequence[str]] = Field(description="Bad words to avoid (cased)")
+    stop: Optional[Sequence[str]] = Field(description="Stop words (cased)")
+    labels: Optional[Dict[str, float]] = Field(description="Steering parameters")
+
     @property
     def _llm_type(self) -> str:
         """Return type of NVIDIA AI Foundation Model Interface."""
@@ -126,14 +135,11 @@ class ChatNVIDIA(nvidia_ai_endpoints._NVIDIAClient, SimpleChatModel):
         messages: List[BaseMessage],
         stop: Optional[Sequence[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
-        labels: Optional[dict] = None,
         **kwargs: Any,
     ) -> str:
         """Invoke on a single list of chat messages."""
         inputs = self.custom_preprocess(messages)
-        responses = self.get_generation(
-            inputs=inputs, stop=stop, labels=labels, **kwargs
-        )
+        responses = self.get_generation(inputs=inputs, stop=stop, **kwargs)
        outputs = self.custom_postprocess(responses)
        return outputs
 
@@ -148,14 +154,11 @@ class ChatNVIDIA(nvidia_ai_endpoints._NVIDIAClient, SimpleChatModel):
         messages: List[BaseMessage],
         stop: Optional[Sequence[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
-        labels: Optional[dict] = None,
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
         """Allows streaming to model!"""
         inputs = self.custom_preprocess(messages)
-        for response in self.get_stream(
-            inputs=inputs, stop=stop, labels=labels, **kwargs
-        ):
+        for response in self.get_stream(inputs=inputs, stop=stop, **kwargs):
             chunk = self._get_filled_chunk(self.custom_postprocess(response))
             yield chunk
             if run_manager:
@@ -166,13 +169,10 @@ class ChatNVIDIA(nvidia_ai_endpoints._NVIDIAClient, SimpleChatModel):
         messages: List[BaseMessage],
         stop: Optional[Sequence[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
-        labels: Optional[dict] = None,
         **kwargs: Any,
     ) -> AsyncIterator[ChatGenerationChunk]:
         inputs = self.custom_preprocess(messages)
-        async for response in self.get_astream(
-            inputs=inputs, stop=stop, labels=labels, **kwargs
-        ):
+        async for response in self.get_astream(inputs=inputs, stop=stop, **kwargs):
             chunk = self._get_filled_chunk(self.custom_postprocess(response))
             yield chunk
             if run_manager:
@@ -229,7 +229,78 @@ class ChatNVIDIA(nvidia_ai_endpoints._NVIDIAClient, SimpleChatModel):
     def custom_postprocess(self, msg: dict) -> str:
         if "content" in msg:
             return msg["content"]
-        logger.warning(
-            f"Got ambiguous message in postprocessing; returning as-is: msg = {msg}"
-        )
+        elif "b64_json" in msg:
+            return msg["b64_json"]
         return str(msg)
+
+    ######################################################################################
+    ## Core client-side interfaces
+
+    def get_generation(
+        self,
+        inputs: Sequence[Dict],
+        **kwargs: Any,
+    ) -> dict:
+        """Call to client generate method with call scope"""
+        stop = kwargs.get("stop", None)
+        payload = self.get_payload(inputs=inputs, stream=False, **kwargs)
+        out = self.client.get_req_generation(self.model, stop=stop, payload=payload)
+        return out
+
+    def get_stream(
+        self,
+        inputs: Sequence[Dict],
+        **kwargs: Any,
+    ) -> Iterator:
+        """Call to client stream method with call scope"""
+        stop = kwargs.get("stop", None)
+        payload = self.get_payload(inputs=inputs, stream=True, **kwargs)
+        return self.client.get_req_stream(self.model, stop=stop, payload=payload)
+
+    def get_astream(
+        self,
+        inputs: Sequence[Dict],
+        **kwargs: Any,
+    ) -> AsyncIterator:
+        """Call to client astream methods with call scope"""
+        stop = kwargs.get("stop", None)
+        payload = self.get_payload(inputs=inputs, stream=True, **kwargs)
+        return self.client.get_req_astream(self.model, stop=stop, payload=payload)
+
+    def get_payload(self, inputs: Sequence[Dict], **kwargs: Any) -> dict:
+        """Generates payload for the _NVIDIAClient API to send to service."""
+        attr_kwargs = {
+            "temperature": self.temperature,
+            "max_tokens": self.max_tokens,
+            "top_p": self.top_p,
+            "seed": self.seed,
+            "bad": self.bad,
+            "stop": self.stop,
+            "labels": self.labels,
+        }
+        attr_kwargs = {k: v for k, v in attr_kwargs.items() if v is not None}
+        new_kwargs = {**attr_kwargs, **kwargs}
+        return self.prep_payload(inputs=inputs, **new_kwargs)
+
+    def prep_payload(self, inputs: Sequence[Dict], **kwargs: Any) -> dict:
+        """Prepares a message or list of messages for the payload"""
+        messages = [self.prep_msg(m) for m in inputs]
+        if kwargs.get("labels"):
+            # (WFH) Labels are currently (?) always passed as an assistant
+            # suffix message, but this API seems less stable.
+            messages += [{"labels": kwargs.pop("labels"), "role": "assistant"}]
+        if kwargs.get("stop") is None:
+            kwargs.pop("stop")
+        return {"messages": messages, **kwargs}
+
+    def prep_msg(self, msg: Union[str, dict, BaseMessage]) -> dict:
+        """Helper Method: Ensures a message is a dictionary with a role and content."""
+        if isinstance(msg, str):
+            # (WFH) this shouldn't ever be reached but leaving this here bcs
+            # it's a Chesterton's fence I'm unwilling to touch
+            return dict(role="user", content=msg)
+        if isinstance(msg, dict):
+            if msg.get("content", None) is None:
+                raise ValueError(f"Message {msg} has no content")
+            return msg
+        raise ValueError(f"Unknown message received: {msg} of type {type(msg)}")
@@ -1,51 +1,21 @@
 """Embeddings Components Derived from NVEModel/Embeddings"""
-from typing import Any, List, Literal, Optional
+from typing import List, Literal, Optional
 
 from langchain_core.embeddings import Embeddings
-from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
+from langchain_core.pydantic_v1 import Field
 
-from langchain_nvidia_ai_endpoints._common import NVEModel
+from langchain_nvidia_ai_endpoints._common import _NVIDIAClient
 
 
-class NVIDIAEmbeddings(BaseModel, Embeddings):
+class NVIDIAEmbeddings(_NVIDIAClient, Embeddings):
     """NVIDIA's AI Foundation Retriever Question-Answering Asymmetric Model."""
 
-    client: NVEModel = Field(NVEModel)
-    model: str = Field(
-        ..., description="The embedding model to use. Example: nvolveqa_40k"
-    )
     max_length: int = Field(2048, ge=1, le=2048)
     max_batch_size: int = Field(default=50)
     model_type: Optional[Literal["passage", "query"]] = Field(
         "passage", description="The type of text to be embedded."
     )
 
-    @root_validator(pre=True)
-    def _validate_client(cls, values: Any) -> Any:
-        if "client" not in values:
-            values["client"] = NVEModel(**values)
-        return values
-
-    @property
-    def available_functions(self) -> List[dict]:
-        """Map the available functions that can be invoked."""
-        return self.client.available_functions
-
-    @property
-    def available_models(self) -> dict:
-        """Map the available models that can be invoked."""
-        return self.client.available_models
-
-    @staticmethod
-    def get_available_functions(**kwargs: Any) -> List[dict]:
-        """Map the available functions that can be invoked. Callable from class"""
-        return NVEModel(**kwargs).available_functions
-
-    @staticmethod
-    def get_available_models(**kwargs: Any) -> dict:
-        """Map the available models that can be invoked. Callable from class"""
-        return NVEModel(**kwargs).available_models
-
     def _embed(
         self, texts: List[str], model_type: Literal["passage", "query"]
     ) -> List[List[float]]:
@@ -11,7 +11,7 @@ def test_chat_ai_endpoints() -> None:
         temperature=0.7,
     )
     message = HumanMessage(content="Hello")
-    response = chat([message])
+    response = chat.invoke([message])
    assert isinstance(response, BaseMessage)
    assert isinstance(response.content, str)