community: add new parameter default_headers (#28700)

Thank you for contributing to LangChain!

- [x] **PR title**: "community: 1. add new parameter `default_headers` for OCI
model deployments and OCI chat model deployments. 2. update the `k` parameter
default in the `OCIModelDeploymentLLM` class."


- [x] **PR message**:
- **Description:** 1. Add a new parameter `default_headers` to OCI model
deployments and OCI chat model deployments. 2. Update the default of the `k`
parameter in the `OCIModelDeploymentLLM` class from -1 to 50. See the usage
sketch after the checklist below.


- [x] **Add tests and docs**:
  1. unit tests
  2. notebook
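
A minimal usage sketch of both changes (values are illustrative; the endpoint URL is a
placeholder and the import path is the module shown in the diff below):

```python
from langchain_community.llms.oci_data_science_model_deployment_endpoint import (
    OCIModelDeploymentLLM,
)

llm = OCIModelDeploymentLLM(
    endpoint="https://modeldeployment.<region>.oci.customer-oci.com/<ocid>/predict",
    model="odsc-llm",
    k=50,  # `k` now defaults to 50 (previously -1); set it only to override
    # New in this PR: extra headers sent with every Model Deployment request.
    default_headers={
        "route": "/v1/completions",
        # other request headers ...
    },
)
```

Judging from the `_headers` implementations in this diff, `default_headers` seeds the
request headers, so a "route" supplied here takes precedence over the default
inference endpoint route.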

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
Lu Peng, 2024-12-18 17:33:23 -05:00, committed by GitHub
commit 50afa7c4e7 (parent 1e88adaca7)
5 changed files with 82 additions and 6 deletions


@@ -137,7 +137,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -156,6 +156,10 @@
" \"temperature\": 0.2,\n",
" \"max_tokens\": 512,\n",
" }, # other model params...\n",
" default_headers={\n",
" \"route\": \"/v1/chat/completions\",\n",
" # other request headers ...\n",
" },\n",
")"
]
},


@@ -47,6 +47,7 @@ from langchain_community.llms.oci_data_science_model_deployment_endpoint import
)
logger = logging.getLogger(__name__)
DEFAULT_INFERENCE_ENDPOINT_CHAT = "/v1/chat/completions"
def _is_pydantic_class(obj: Any) -> bool:
@@ -56,6 +57,13 @@ def _is_pydantic_class(obj: Any) -> bool:
class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
"""OCI Data Science Model Deployment chat model integration.
Prerequisite
The OCI Model Deployment plugins are installable only on
python version 3.9 and above. If you're working inside the notebook,
try installing the python 3.10 based conda pack and running the
following setup.
Setup:
Install ``oracle-ads`` and ``langchain-openai``.
@@ -90,6 +98,8 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
Key init args client params:
auth: dict
ADS auth dictionary for OCI authentication.
default_headers: Optional[Dict]
The headers to be added to the Model Deployment request.
Instantiate:
.. code-block:: python
@@ -98,7 +108,7 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
chat = ChatOCIModelDeployment(
endpoint="https://modeldeployment.<region>.oci.customer-oci.com/<ocid>/predict",
model="odsc-llm",
model="odsc-llm", # this is the default model name if deployed with AQUA
streaming=True,
max_retries=3,
model_kwargs={
@@ -106,6 +116,10 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
"temperature": 0.2,
# other model parameters ...
},
default_headers={
"route": "/v1/chat/completions",
# other request headers ...
},
)
Invocation:
@@ -288,6 +302,25 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
"stream": self.streaming,
}
def _headers(
self, is_async: Optional[bool] = False, body: Optional[dict] = None
) -> Dict:
"""Construct and return the headers for a request.
Args:
is_async (bool, optional): Indicates if the request is asynchronous.
Defaults to `False`.
body (optional): The request body to be included in the headers if
the request is asynchronous.
Returns:
Dict: A dictionary containing the appropriate headers for the request.
"""
return {
"route": DEFAULT_INFERENCE_ENDPOINT_CHAT,
**super()._headers(is_async=is_async, body=body),
}
def _generate(
self,
messages: List[BaseMessage],
@@ -701,7 +734,7 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
for choice in choices:
message = _convert_dict_to_message(choice["message"])
generation_info = dict(finish_reason=choice.get("finish_reason"))
generation_info = {"finish_reason": choice.get("finish_reason")}
if "logprobs" in choice:
generation_info["logprobs"] = choice["logprobs"]


@@ -32,6 +32,7 @@ from pydantic import Field, model_validator
from langchain_community.utilities.requests import Requests
logger = logging.getLogger(__name__)
DEFAULT_INFERENCE_ENDPOINT = "/v1/completions"
DEFAULT_TIME_OUT = 300
@@ -81,6 +82,9 @@ class BaseOCIModelDeployment(Serializable):
max_retries: int = 3
"""Maximum number of retries to make when generating."""
default_headers: Optional[Dict[str, Any]] = None
"""The headers to be added to the Model Deployment request."""
@model_validator(mode="before")
@classmethod
def validate_environment(cls, values: Dict) -> Dict:
@@ -120,12 +124,12 @@
Returns:
Dict: A dictionary containing the appropriate headers for the request.
"""
headers = self.default_headers or {}
if is_async:
signer = self.auth["signer"]
_req = requests.Request("POST", self.endpoint, json=body)
req = _req.prepare()
req = signer(req)
headers = {}
for key, value in req.headers.items():
headers[key] = value
@@ -135,7 +139,7 @@
)
return headers
return (
headers.update(
{
"Content-Type": DEFAULT_CONTENT_TYPE_JSON,
"enable-streaming": "true",
@@ -147,6 +151,8 @@
}
)
return headers
def completion_with_retry(
self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
) -> Any:
@@ -383,6 +389,10 @@ class OCIModelDeploymentLLM(BaseLLM, BaseOCIModelDeployment):
model="odsc-llm",
streaming=True,
model_kwargs={"frequency_penalty": 1.0},
default_headers={
"route": "/v1/completions",
# other request headers ...
}
)
llm.invoke("tell me a joke.")
@@ -426,7 +436,7 @@ class OCIModelDeploymentLLM(BaseLLM, BaseOCIModelDeployment):
temperature: float = 0.2
"""A non-negative float that tunes the degree of randomness in generation."""
k: int = -1
k: int = 50
"""Number of most likely tokens to consider at each step."""
p: float = 0.75
@@ -472,6 +482,25 @@ class OCIModelDeploymentLLM(BaseLLM, BaseOCIModelDeployment):
**self._default_params,
}
def _headers(
self, is_async: Optional[bool] = False, body: Optional[dict] = None
) -> Dict:
"""Construct and return the headers for a request.
Args:
is_async (bool, optional): Indicates if the request is asynchronous.
Defaults to `False`.
body (optional): The request body to be included in the headers if
the request is asynchronous.
Returns:
Dict: A dictionary containing the appropriate headers for the request.
"""
return {
"route": DEFAULT_INFERENCE_ENDPOINT,
**super()._headers(is_async=is_async, body=body),
}
def _generate(
self,
prompts: List[str],


@@ -19,6 +19,7 @@ CONST_MODEL_NAME = "odsc-vllm"
CONST_ENDPOINT = "https://oci.endpoint/ocid/predict"
CONST_PROMPT = "This is a prompt."
CONST_COMPLETION = "This is a completion."
CONST_COMPLETION_ROUTE = "/v1/chat/completions"
CONST_COMPLETION_RESPONSE = {
"id": "chat-123456789",
"object": "chat.completion",
@@ -120,6 +121,7 @@ def mocked_requests_post(url: str, **kwargs: Any) -> MockResponse:
def test_invoke_vllm(*args: Any) -> None:
"""Tests invoking vLLM endpoint."""
llm = ChatOCIModelDeploymentVLLM(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME)
assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
output = llm.invoke(CONST_PROMPT)
assert isinstance(output, AIMessage)
assert output.content == CONST_COMPLETION
@@ -132,6 +134,7 @@ def test_invoke_vllm(*args: Any) -> None:
def test_invoke_tgi(*args: Any) -> None:
"""Tests invoking TGI endpoint using OpenAI Spec."""
llm = ChatOCIModelDeploymentTGI(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME)
assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
output = llm.invoke(CONST_PROMPT)
assert isinstance(output, AIMessage)
assert output.content == CONST_COMPLETION
@@ -146,6 +149,7 @@ def test_stream_vllm(*args: Any) -> None:
llm = ChatOCIModelDeploymentVLLM(
endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
)
assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
output = None
count = 0
for chunk in llm.stream(CONST_PROMPT):
@@ -184,6 +188,7 @@ async def test_stream_async(*args: Any) -> None:
llm = ChatOCIModelDeploymentVLLM(
endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
)
assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
with mock.patch.object(
llm,
"_aiter_sse",


@@ -18,6 +18,7 @@ CONST_MODEL_NAME = "odsc-vllm"
CONST_ENDPOINT = "https://oci.endpoint/ocid/predict"
CONST_PROMPT = "This is a prompt."
CONST_COMPLETION = "This is a completion."
CONST_COMPLETION_ROUTE = "/v1/completions"
CONST_COMPLETION_RESPONSE = {
"choices": [
{
@@ -114,6 +115,7 @@ async def mocked_async_streaming_response(
def test_invoke_vllm(*args: Any) -> None:
"""Tests invoking vLLM endpoint."""
llm = OCIModelDeploymentVLLM(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME)
assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
output = llm.invoke(CONST_PROMPT)
assert output == CONST_COMPLETION
@@ -126,6 +128,7 @@ def test_stream_tgi(*args: Any) -> None:
llm = OCIModelDeploymentTGI(
endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
)
assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
output = ""
count = 0
for chunk in llm.stream(CONST_PROMPT):
@@ -143,6 +146,7 @@ def test_generate_tgi(*args: Any) -> None:
llm = OCIModelDeploymentTGI(
endpoint=CONST_ENDPOINT, api="/generate", model=CONST_MODEL_NAME
)
assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
output = llm.invoke(CONST_PROMPT)
assert output == CONST_COMPLETION
@@ -161,6 +165,7 @@ async def test_stream_async(*args: Any) -> None:
llm = OCIModelDeploymentTGI(
endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
)
assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
with mock.patch.object(
llm,
"_aiter_sse",