Mirror of https://github.com/hwchase17/langchain.git (synced 2025-06-24 07:35:18 +00:00)
community: add new parameter default_headers (#28700)
Thank you for contributing to LangChain!

- [x] **PR title**: "community: 1. add new parameter `default_headers` for OCI model deployments and OCI chat model deployments. 2. update the `k` parameter in the `OCIModelDeploymentLLM` class."
- [x] **PR message**:
  - **Description:**
    1. Add a new parameter, `default_headers`, for OCI model deployments and OCI chat model deployments.
    2. Update the `k` parameter in the `OCIModelDeploymentLLM` class.
- [x] **Add tests and docs**: 1. unit tests, 2. notebook.

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
parent 1e88adaca7
commit 50afa7c4e7
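For orientation before the diff, here is a minimal usage sketch of the new parameter, mirroring the docstring and notebook examples changed below. The endpoint URL is a placeholder, the top-level import path is an assumption, and OCI authentication setup via `oracle-ads` is omitted.

```python
# Minimal sketch of the new `default_headers` parameter (mirrors the examples in this diff).
# Placeholder endpoint; assumes ChatOCIModelDeployment is exported from langchain_community.chat_models.
from langchain_community.chat_models import ChatOCIModelDeployment

chat = ChatOCIModelDeployment(
    endpoint="https://modeldeployment.<region>.oci.customer-oci.com/<ocid>/predict",
    model="odsc-llm",
    default_headers={
        "route": "/v1/chat/completions",
        # other request headers ...
    },
)
```

The full diff follows.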
```diff
@@ -137,7 +137,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -156,6 +156,10 @@
     "        \"temperature\": 0.2,\n",
     "        \"max_tokens\": 512,\n",
     "    }, # other model params...\n",
+    "    default_headers={\n",
+    "        \"route\": \"/v1/chat/completions\",\n",
+    "        # other request headers ...\n",
+    "    },\n",
     ")"
    ]
   },
```
```diff
@@ -47,6 +47,7 @@ from langchain_community.llms.oci_data_science_model_deployment_endpoint import
 )
 
 logger = logging.getLogger(__name__)
+DEFAULT_INFERENCE_ENDPOINT_CHAT = "/v1/chat/completions"
 
 
 def _is_pydantic_class(obj: Any) -> bool:
@@ -56,6 +57,13 @@ def _is_pydantic_class(obj: Any) -> bool:
 class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
     """OCI Data Science Model Deployment chat model integration.
 
+    Prerequisite
+        The OCI Model Deployment plugins are installable only on
+        python version 3.9 and above. If you're working inside the notebook,
+        try installing the python 3.10 based conda pack and running the
+        following setup.
+
+
     Setup:
         Install ``oracle-ads`` and ``langchain-openai``.
 
@@ -90,6 +98,8 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
     Key init args — client params:
         auth: dict
             ADS auth dictionary for OCI authentication.
+        default_headers: Optional[Dict]
+            The headers to be added to the Model Deployment request.
 
     Instantiate:
         .. code-block:: python
@@ -98,7 +108,7 @@
 
             chat = ChatOCIModelDeployment(
                 endpoint="https://modeldeployment.<region>.oci.customer-oci.com/<ocid>/predict",
-                model="odsc-llm",
+                model="odsc-llm",  # this is the default model name if deployed with AQUA
                 streaming=True,
                 max_retries=3,
                 model_kwargs={
@@ -106,6 +116,10 @@
                     "temperature": 0.2,
                     # other model parameters ...
                 },
+                default_headers={
+                    "route": "/v1/chat/completions",
+                    # other request headers ...
+                },
             )
 
     Invocation:
@@ -288,6 +302,25 @@
             "stream": self.streaming,
         }
 
+    def _headers(
+        self, is_async: Optional[bool] = False, body: Optional[dict] = None
+    ) -> Dict:
+        """Construct and return the headers for a request.
+
+        Args:
+            is_async (bool, optional): Indicates if the request is asynchronous.
+                Defaults to `False`.
+            body (optional): The request body to be included in the headers if
+                the request is asynchronous.
+
+        Returns:
+            Dict: A dictionary containing the appropriate headers for the request.
+        """
+        return {
+            "route": DEFAULT_INFERENCE_ENDPOINT_CHAT,
+            **super()._headers(is_async=is_async, body=body),
+        }
+
     def _generate(
         self,
         messages: List[BaseMessage],
@@ -701,7 +734,7 @@
 
         for choice in choices:
            message = _convert_dict_to_message(choice["message"])
-            generation_info = dict(finish_reason=choice.get("finish_reason"))
+            generation_info = {"finish_reason": choice.get("finish_reason")}
            if "logprobs" in choice:
                generation_info["logprobs"] = choice["logprobs"]
 
```
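The new `ChatOCIModelDeployment._headers` override above puts the hard-coded chat route first and unpacks the base-class headers after it. Since later keys win when building a dict literal, a `route` supplied via `default_headers` ends up taking precedence over `DEFAULT_INFERENCE_ENDPOINT_CHAT`. A small self-contained illustration of that precedence (the header values here are made up):

```python
# Stand-alone illustration of the merge order used by the _headers() override above.
DEFAULT_INFERENCE_ENDPOINT_CHAT = "/v1/chat/completions"


def merged_headers(base_headers: dict) -> dict:
    # base_headers plays the role of super()._headers(), which already
    # contains anything the user passed via default_headers.
    return {
        "route": DEFAULT_INFERENCE_ENDPOINT_CHAT,
        **base_headers,
    }


# A user-supplied route overrides the default chat route ...
assert merged_headers({"route": "/custom/predict"})["route"] == "/custom/predict"
# ... and the default is used when no route is supplied.
assert merged_headers({"Content-Type": "application/json"})["route"] == DEFAULT_INFERENCE_ENDPOINT_CHAT
```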
```diff
@@ -32,6 +32,7 @@ from pydantic import Field, model_validator
 from langchain_community.utilities.requests import Requests
 
 logger = logging.getLogger(__name__)
+DEFAULT_INFERENCE_ENDPOINT = "/v1/completions"
 
 
 DEFAULT_TIME_OUT = 300
@@ -81,6 +82,9 @@ class BaseOCIModelDeployment(Serializable):
     max_retries: int = 3
     """Maximum number of retries to make when generating."""
 
+    default_headers: Optional[Dict[str, Any]] = None
+    """The headers to be added to the Model Deployment request."""
+
     @model_validator(mode="before")
     @classmethod
     def validate_environment(cls, values: Dict) -> Dict:
@@ -120,12 +124,12 @@
         Returns:
             Dict: A dictionary containing the appropriate headers for the request.
         """
+        headers = self.default_headers or {}
         if is_async:
             signer = self.auth["signer"]
             _req = requests.Request("POST", self.endpoint, json=body)
             req = _req.prepare()
             req = signer(req)
-            headers = {}
             for key, value in req.headers.items():
                 headers[key] = value
 
@@ -135,7 +139,7 @@
             )
             return headers
 
-        return (
+        headers.update(
             {
                 "Content-Type": DEFAULT_CONTENT_TYPE_JSON,
                 "enable-streaming": "true",
@@ -147,6 +151,8 @@
             }
         )
 
+        return headers
+
     def completion_with_retry(
         self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
     ) -> Any:
```
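In the base class, the synchronous path now starts from `default_headers` and then layers the standard headers on top with `dict.update`, so the fixed keys (`Content-Type`, `enable-streaming`, ...) win over user-supplied duplicates, while any extra custom headers pass through untouched. A simplified stand-alone sketch of that sync-path order (the constant's value is assumed, the async/signer branch is omitted, and the defensive copy is a local simplification, not something the hunk does):

```python
from typing import Any, Dict, Optional

DEFAULT_CONTENT_TYPE_JSON = "application/json"  # assumed value of the module constant


def sync_headers(default_headers: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Simplified rendering of the non-async branch of _headers() above."""
    headers = dict(default_headers or {})  # copied here so the caller's dict is not mutated
    headers.update(
        {
            "Content-Type": DEFAULT_CONTENT_TYPE_JSON,
            "enable-streaming": "true",
        }
    )
    return headers


# Custom headers pass through; the standard keys always win.
print(sync_headers({"x-request-id": "abc", "Content-Type": "text/plain"}))
# {'x-request-id': 'abc', 'Content-Type': 'application/json', 'enable-streaming': 'true'}
```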
```diff
@@ -383,6 +389,10 @@ class OCIModelDeploymentLLM(BaseLLM, BaseOCIModelDeployment):
                 model="odsc-llm",
                 streaming=True,
                 model_kwargs={"frequency_penalty": 1.0},
+                headers={
+                    "route": "/v1/completions",
+                    # other request headers ...
+                }
             )
             llm.invoke("tell me a joke.")
 
```
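Note that the docstring example in this hunk passes `headers={...}`, while the field added to `BaseOCIModelDeployment` earlier in this diff is named `default_headers`. If the example is meant to exercise the new field, the consistent spelling would presumably be the following (a sketch only; the import uses the module path shown in an earlier hunk header, and the endpoint is a placeholder):

```python
# Sketch of the same docstring example with the new field name (`default_headers`).
from langchain_community.llms.oci_data_science_model_deployment_endpoint import (
    OCIModelDeploymentLLM,
)

llm = OCIModelDeploymentLLM(
    endpoint="https://modeldeployment.<region>.oci.customer-oci.com/<ocid>/predict",
    model="odsc-llm",
    streaming=True,
    model_kwargs={"frequency_penalty": 1.0},
    default_headers={
        "route": "/v1/completions",
        # other request headers ...
    },
)
llm.invoke("tell me a joke.")
```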
```diff
@@ -426,7 +436,7 @@
     temperature: float = 0.2
     """A non-negative float that tunes the degree of randomness in generation."""
 
-    k: int = -1
+    k: int = 50
     """Number of most likely tokens to consider at each step."""
 
     p: float = 0.75
```
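This hunk is the `k` change called out in the PR title: the default top-k moves from -1 to 50. In OpenAI-compatible servers such as vLLM, -1 conventionally disables top-k filtering, so a deployment that relied on the old behaviour can presumably restore it by passing `k` explicitly (a sketch; whether the backend still accepts -1 depends on the deployed inference server):

```python
# Sketch: pin k explicitly instead of relying on the changed default
# (same OCIModelDeploymentLLM import and placeholder endpoint as in the previous sketch).
llm = OCIModelDeploymentLLM(
    endpoint="https://modeldeployment.<region>.oci.customer-oci.com/<ocid>/predict",
    model="odsc-llm",
    k=-1,  # previous default: consider all tokens; the new default is 50
)
```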
```diff
@@ -472,6 +482,25 @@
             **self._default_params,
         }
 
+    def _headers(
+        self, is_async: Optional[bool] = False, body: Optional[dict] = None
+    ) -> Dict:
+        """Construct and return the headers for a request.
+
+        Args:
+            is_async (bool, optional): Indicates if the request is asynchronous.
+                Defaults to `False`.
+            body (optional): The request body to be included in the headers if
+                the request is asynchronous.
+
+        Returns:
+            Dict: A dictionary containing the appropriate headers for the request.
+        """
+        return {
+            "route": DEFAULT_INFERENCE_ENDPOINT,
+            **super()._headers(is_async=is_async, body=body),
+        }
+
     def _generate(
         self,
         prompts: List[str],
```
```diff
@@ -19,6 +19,7 @@ CONST_MODEL_NAME = "odsc-vllm"
 CONST_ENDPOINT = "https://oci.endpoint/ocid/predict"
 CONST_PROMPT = "This is a prompt."
 CONST_COMPLETION = "This is a completion."
+CONST_COMPLETION_ROUTE = "/v1/chat/completions"
 CONST_COMPLETION_RESPONSE = {
     "id": "chat-123456789",
     "object": "chat.completion",
@@ -120,6 +121,7 @@ def mocked_requests_post(url: str, **kwargs: Any) -> MockResponse:
 def test_invoke_vllm(*args: Any) -> None:
     """Tests invoking vLLM endpoint."""
     llm = ChatOCIModelDeploymentVLLM(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME)
+    assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
     output = llm.invoke(CONST_PROMPT)
     assert isinstance(output, AIMessage)
     assert output.content == CONST_COMPLETION
@@ -132,6 +134,7 @@ def test_invoke_vllm(*args: Any) -> None:
 def test_invoke_tgi(*args: Any) -> None:
     """Tests invoking TGI endpoint using OpenAI Spec."""
     llm = ChatOCIModelDeploymentTGI(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME)
+    assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
     output = llm.invoke(CONST_PROMPT)
     assert isinstance(output, AIMessage)
     assert output.content == CONST_COMPLETION
@@ -146,6 +149,7 @@ def test_stream_vllm(*args: Any) -> None:
     llm = ChatOCIModelDeploymentVLLM(
         endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
     )
+    assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
     output = None
     count = 0
     for chunk in llm.stream(CONST_PROMPT):
@@ -184,6 +188,7 @@ async def test_stream_async(*args: Any) -> None:
     llm = ChatOCIModelDeploymentVLLM(
         endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
     )
+    assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
     with mock.patch.object(
         llm,
         "_aiter_sse",
```
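The assertions added above only check the injected `route`. A natural companion check, not part of this diff, would verify that user-supplied `default_headers` survive the merge; a sketch in the style of the surrounding tests (same constants and class, and whatever mock/requires decorators those tests already use):

```python
def test_default_headers_passthrough(*args: Any) -> None:
    """Sketch (not in this diff): custom default_headers should appear in _headers()."""
    llm = ChatOCIModelDeploymentVLLM(
        endpoint=CONST_ENDPOINT,
        model=CONST_MODEL_NAME,
        default_headers={"x-custom-header": "abc"},  # hypothetical extra header
    )
    headers = llm._headers()
    assert headers.get("x-custom-header") == "abc"
    assert headers.get("route") == CONST_COMPLETION_ROUTE
```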
```diff
@@ -18,6 +18,7 @@ CONST_MODEL_NAME = "odsc-vllm"
 CONST_ENDPOINT = "https://oci.endpoint/ocid/predict"
 CONST_PROMPT = "This is a prompt."
 CONST_COMPLETION = "This is a completion."
+CONST_COMPLETION_ROUTE = "/v1/completions"
 CONST_COMPLETION_RESPONSE = {
     "choices": [
         {
@@ -114,6 +115,7 @@ async def mocked_async_streaming_response(
 def test_invoke_vllm(*args: Any) -> None:
     """Tests invoking vLLM endpoint."""
     llm = OCIModelDeploymentVLLM(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME)
+    assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
     output = llm.invoke(CONST_PROMPT)
     assert output == CONST_COMPLETION
 
@@ -126,6 +128,7 @@ def test_stream_tgi(*args: Any) -> None:
     llm = OCIModelDeploymentTGI(
         endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
     )
+    assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
     output = ""
     count = 0
     for chunk in llm.stream(CONST_PROMPT):
@@ -143,6 +146,7 @@ def test_generate_tgi(*args: Any) -> None:
     llm = OCIModelDeploymentTGI(
         endpoint=CONST_ENDPOINT, api="/generate", model=CONST_MODEL_NAME
     )
+    assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
     output = llm.invoke(CONST_PROMPT)
     assert output == CONST_COMPLETION
 
@@ -161,6 +165,7 @@ async def test_stream_async(*args: Any) -> None:
     llm = OCIModelDeploymentTGI(
         endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
     )
+    assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
     with mock.patch.object(
         llm,
         "_aiter_sse",
```