community: add new parameter default_headers (#28700)
Thank you for contributing to LangChain!

- [x] **PR title**: "community: 1. add new parameter `default_headers` for OCI model deployments and OCI chat model deployments. 2. update the `k` parameter in the `OCIModelDeploymentLLM` class."
- [x] **PR message**:
  - **Description:**
    1. Add a new parameter, `default_headers`, for OCI model deployments and OCI chat model deployments.
    2. Update the `k` parameter in the `OCIModelDeploymentLLM` class.
- [x] **Add tests and docs**:
  1. Unit tests
  2. Notebook

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in: parent 1e88adaca7, commit 50afa7c4e7
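For context, here is a minimal usage sketch of the new `default_headers` parameter on the chat model, mirroring the notebook and docstring examples changed below. It assumes a model deployment is already running (the endpoint URL is a placeholder) and uses resource-principal auth, which is only one of the auth modes oracle-ads supports.

```python
# Minimal sketch, assuming an OCI Data Science Model Deployment is reachable
# behind the placeholder endpoint below and oracle-ads is installed.
import ads
from langchain_community.chat_models import ChatOCIModelDeployment

# Authenticate via oracle-ads; resource principal is one of several auth modes.
ads.set_auth("resource_principal")

chat = ChatOCIModelDeployment(
    endpoint="https://modeldeployment.<region>.oci.customer-oci.com/<ocid>/predict",
    model="odsc-llm",  # default model name when deployed with AQUA
    model_kwargs={"temperature": 0.2, "max_tokens": 512},
    # New in this PR: extra headers sent with every request, e.g. the route.
    default_headers={"route": "/v1/chat/completions"},
)

print(chat.invoke("Tell me a joke.").content)
```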
Notebook (chat model integration example):

@@ -137,7 +137,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -156,6 +156,10 @@
    "        \"temperature\": 0.2,\n",
    "        \"max_tokens\": 512,\n",
    "    }, # other model params...\n",
+   "    default_headers={\n",
+   "        \"route\": \"/v1/chat/completions\",\n",
+   "        # other request headers ...\n",
+   "    },\n",
    ")"
   ]
  },

Chat model (`ChatOCIModelDeployment`):

@@ -47,6 +47,7 @@ from langchain_community.llms.oci_data_science_model_deployment_endpoint import
 )
 
 logger = logging.getLogger(__name__)
+DEFAULT_INFERENCE_ENDPOINT_CHAT = "/v1/chat/completions"
 
 
 def _is_pydantic_class(obj: Any) -> bool:
@@ -56,6 +57,13 @@ def _is_pydantic_class(obj: Any) -> bool:
 class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
     """OCI Data Science Model Deployment chat model integration.
 
+    Prerequisite
+        The OCI Model Deployment plugins are installable only on
+        python version 3.9 and above. If you're working inside the notebook,
+        try installing the python 3.10 based conda pack and running the
+        following setup.
+
+
     Setup:
         Install ``oracle-ads`` and ``langchain-openai``.
 
@@ -90,6 +98,8 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
     Key init args — client params:
         auth: dict
             ADS auth dictionary for OCI authentication.
+        default_headers: Optional[Dict]
+            The headers to be added to the Model Deployment request.
 
     Instantiate:
         .. code-block:: python
@@ -98,7 +108,7 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
 
             chat = ChatOCIModelDeployment(
                 endpoint="https://modeldeployment.<region>.oci.customer-oci.com/<ocid>/predict",
-                model="odsc-llm",
+                model="odsc-llm", # this is the default model name if deployed with AQUA
                 streaming=True,
                 max_retries=3,
                 model_kwargs={
@@ -106,6 +116,10 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
                     "temperature": 0.2,
                     # other model parameters ...
                 },
+                default_headers={
+                    "route": "/v1/chat/completions",
+                    # other request headers ...
+                },
             )
 
     Invocation:
@@ -288,6 +302,25 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
             "stream": self.streaming,
         }
 
+    def _headers(
+        self, is_async: Optional[bool] = False, body: Optional[dict] = None
+    ) -> Dict:
+        """Construct and return the headers for a request.
+
+        Args:
+            is_async (bool, optional): Indicates if the request is asynchronous.
+                Defaults to `False`.
+            body (optional): The request body to be included in the headers if
+                the request is asynchronous.
+
+        Returns:
+            Dict: A dictionary containing the appropriate headers for the request.
+        """
+        return {
+            "route": DEFAULT_INFERENCE_ENDPOINT_CHAT,
+            **super()._headers(is_async=is_async, body=body),
+        }
+
     def _generate(
         self,
         messages: List[BaseMessage],
@@ -701,7 +734,7 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
 
         for choice in choices:
            message = _convert_dict_to_message(choice["message"])
-           generation_info = dict(finish_reason=choice.get("finish_reason"))
+           generation_info = {"finish_reason": choice.get("finish_reason")}
            if "logprobs" in choice:
                generation_info["logprobs"] = choice["logprobs"]
 

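The `_headers()` override added above lists the default chat route first and then spreads the base-class headers (which now start from `default_headers`) over it, so a user-supplied `"route"` takes precedence over `DEFAULT_INFERENCE_ENDPOINT_CHAT`. A standalone sketch of that merge order, using plain dicts and simplified helper names rather than the actual classes:

```python
# Standalone illustration of the header-merge order; these helpers are
# simplified stand-ins, not the real langchain-community classes.
from typing import Dict, Optional

DEFAULT_INFERENCE_ENDPOINT_CHAT = "/v1/chat/completions"


def base_headers(default_headers: Optional[Dict] = None) -> Dict:
    # Mirrors the non-async path of BaseOCIModelDeployment._headers():
    # start from user-supplied default_headers, then add standard headers.
    headers = dict(default_headers or {})
    headers.update({"Content-Type": "application/json"})
    return headers


def chat_headers(default_headers: Optional[Dict] = None) -> Dict:
    # Mirrors ChatOCIModelDeployment._headers(): the default chat route comes
    # first, so anything returned by the base headers overrides it.
    return {
        "route": DEFAULT_INFERENCE_ENDPOINT_CHAT,
        **base_headers(default_headers),
    }


print(chat_headers())                              # route -> /v1/chat/completions
print(chat_headers({"route": "/custom/predict"}))  # route -> /custom/predict
```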
LLM base class and completion LLM (`BaseOCIModelDeployment`, `OCIModelDeploymentLLM`):

@@ -32,6 +32,7 @@ from pydantic import Field, model_validator
 from langchain_community.utilities.requests import Requests
 
 logger = logging.getLogger(__name__)
+DEFAULT_INFERENCE_ENDPOINT = "/v1/completions"
 
 
 DEFAULT_TIME_OUT = 300
@@ -81,6 +82,9 @@ class BaseOCIModelDeployment(Serializable):
     max_retries: int = 3
     """Maximum number of retries to make when generating."""
 
+    default_headers: Optional[Dict[str, Any]] = None
+    """The headers to be added to the Model Deployment request."""
+
     @model_validator(mode="before")
     @classmethod
     def validate_environment(cls, values: Dict) -> Dict:
@@ -120,12 +124,12 @@ class BaseOCIModelDeployment(Serializable):
         Returns:
             Dict: A dictionary containing the appropriate headers for the request.
         """
+        headers = self.default_headers or {}
         if is_async:
             signer = self.auth["signer"]
             _req = requests.Request("POST", self.endpoint, json=body)
             req = _req.prepare()
             req = signer(req)
-            headers = {}
             for key, value in req.headers.items():
                 headers[key] = value
 
@@ -135,7 +139,7 @@
             )
             return headers
 
-        return (
+        headers.update(
             {
                 "Content-Type": DEFAULT_CONTENT_TYPE_JSON,
                 "enable-streaming": "true",
@@ -147,6 +151,8 @@
             }
         )
 
+        return headers
+
     def completion_with_retry(
         self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
     ) -> Any:
@@ -383,6 +389,10 @@ class OCIModelDeploymentLLM(BaseLLM, BaseOCIModelDeployment):
                 model="odsc-llm",
                 streaming=True,
                 model_kwargs={"frequency_penalty": 1.0},
+                headers={
+                    "route": "/v1/completions",
+                    # other request headers ...
+                }
             )
             llm.invoke("tell me a joke.")
 
@@ -426,7 +436,7 @@ class OCIModelDeploymentLLM(BaseLLM, BaseOCIModelDeployment):
     temperature: float = 0.2
     """A non-negative float that tunes the degree of randomness in generation."""
 
-    k: int = -1
+    k: int = 50
     """Number of most likely tokens to consider at each step."""
 
     p: float = 0.75
@@ -472,6 +482,25 @@ class OCIModelDeploymentLLM(BaseLLM, BaseOCIModelDeployment):
             **self._default_params,
         }
 
+    def _headers(
+        self, is_async: Optional[bool] = False, body: Optional[dict] = None
+    ) -> Dict:
+        """Construct and return the headers for a request.
+
+        Args:
+            is_async (bool, optional): Indicates if the request is asynchronous.
+                Defaults to `False`.
+            body (optional): The request body to be included in the headers if
+                the request is asynchronous.
+
+        Returns:
+            Dict: A dictionary containing the appropriate headers for the request.
+        """
+        return {
+            "route": DEFAULT_INFERENCE_ENDPOINT,
+            **super()._headers(is_async=is_async, body=body),
+        }
+
     def _generate(
         self,
         prompts: List[str],

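Because `default_headers` is defined on `BaseOCIModelDeployment`, the completion-style `OCIModelDeploymentLLM` inherits it too, and its own `_headers()` defaults the route to `/v1/completions`; the same hunk set also changes the `k` sampling default from -1 to 50. A hedged usage sketch (placeholder endpoint; assumes oracle-ads auth is already configured and the class is importable from `langchain_community.llms`):

```python
# Minimal sketch for the completion-style LLM; the endpoint URL is a
# placeholder and OCI auth is assumed to be configured already.
from langchain_community.llms import OCIModelDeploymentLLM

llm = OCIModelDeploymentLLM(
    endpoint="https://modeldeployment.<region>.oci.customer-oci.com/<ocid>/predict",
    model="odsc-llm",
    # k now defaults to 50 (was -1); passed explicitly here only to show the change.
    k=50,
    # default_headers is inherited from BaseOCIModelDeployment; without it,
    # _headers() falls back to the "/v1/completions" route.
    default_headers={"route": "/v1/completions"},
)

print(llm.invoke("Tell me a joke."))
```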
Unit tests for the chat models:

@@ -19,6 +19,7 @@ CONST_MODEL_NAME = "odsc-vllm"
 CONST_ENDPOINT = "https://oci.endpoint/ocid/predict"
 CONST_PROMPT = "This is a prompt."
 CONST_COMPLETION = "This is a completion."
+CONST_COMPLETION_ROUTE = "/v1/chat/completions"
 CONST_COMPLETION_RESPONSE = {
     "id": "chat-123456789",
     "object": "chat.completion",
@@ -120,6 +121,7 @@ def mocked_requests_post(url: str, **kwargs: Any) -> MockResponse:
 def test_invoke_vllm(*args: Any) -> None:
     """Tests invoking vLLM endpoint."""
     llm = ChatOCIModelDeploymentVLLM(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME)
+    assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
     output = llm.invoke(CONST_PROMPT)
     assert isinstance(output, AIMessage)
     assert output.content == CONST_COMPLETION
@@ -132,6 +134,7 @@ def test_invoke_vllm(*args: Any) -> None:
 def test_invoke_tgi(*args: Any) -> None:
     """Tests invoking TGI endpoint using OpenAI Spec."""
     llm = ChatOCIModelDeploymentTGI(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME)
+    assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
     output = llm.invoke(CONST_PROMPT)
     assert isinstance(output, AIMessage)
     assert output.content == CONST_COMPLETION
@@ -146,6 +149,7 @@ def test_stream_vllm(*args: Any) -> None:
     llm = ChatOCIModelDeploymentVLLM(
         endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
     )
+    assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
     output = None
     count = 0
     for chunk in llm.stream(CONST_PROMPT):
@@ -184,6 +188,7 @@ async def test_stream_async(*args: Any) -> None:
     llm = ChatOCIModelDeploymentVLLM(
         endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
     )
+    assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
     with mock.patch.object(
         llm,
         "_aiter_sse",

Unit tests for the completion LLMs:

@@ -18,6 +18,7 @@ CONST_MODEL_NAME = "odsc-vllm"
 CONST_ENDPOINT = "https://oci.endpoint/ocid/predict"
 CONST_PROMPT = "This is a prompt."
 CONST_COMPLETION = "This is a completion."
+CONST_COMPLETION_ROUTE = "/v1/completions"
 CONST_COMPLETION_RESPONSE = {
     "choices": [
         {
@@ -114,6 +115,7 @@ async def mocked_async_streaming_response(
 def test_invoke_vllm(*args: Any) -> None:
     """Tests invoking vLLM endpoint."""
     llm = OCIModelDeploymentVLLM(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME)
+    assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
     output = llm.invoke(CONST_PROMPT)
     assert output == CONST_COMPLETION
 
@@ -126,6 +128,7 @@ def test_stream_tgi(*args: Any) -> None:
     llm = OCIModelDeploymentTGI(
         endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
     )
+    assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
     output = ""
     count = 0
     for chunk in llm.stream(CONST_PROMPT):
@@ -143,6 +146,7 @@ def test_generate_tgi(*args: Any) -> None:
     llm = OCIModelDeploymentTGI(
         endpoint=CONST_ENDPOINT, api="/generate", model=CONST_MODEL_NAME
     )
+    assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
     output = llm.invoke(CONST_PROMPT)
     assert output == CONST_COMPLETION
 
@@ -161,6 +165,7 @@ async def test_stream_async(*args: Any) -> None:
     llm = OCIModelDeploymentTGI(
         endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
     )
+    assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
     with mock.patch.object(
         llm,
         "_aiter_sse",