mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-16 15:04:13 +00:00
community: add new parameter default_headers (#28700)
Thank you for contributing to LangChain! - [x] **PR title**: "package: description" - "community: 1. add new parameter `default_headers` for oci model deployments and oci chat model deployments. 2. updated k parameter in OCIModelDeploymentLLM class." - [x] **PR message**: - **Description:** 1. add new parameters `default_headers` for oci model deployments and oci chat model deployments. 2. updated k parameter in OCIModelDeploymentLLM class. - [x] **Add tests and docs**: 1. unit tests 2. notebook --------- Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
@@ -19,6 +19,7 @@ CONST_MODEL_NAME = "odsc-vllm"
|
||||
CONST_ENDPOINT = "https://oci.endpoint/ocid/predict"
|
||||
CONST_PROMPT = "This is a prompt."
|
||||
CONST_COMPLETION = "This is a completion."
|
||||
CONST_COMPLETION_ROUTE = "/v1/chat/completions"
|
||||
CONST_COMPLETION_RESPONSE = {
|
||||
"id": "chat-123456789",
|
||||
"object": "chat.completion",
|
||||
@@ -120,6 +121,7 @@ def mocked_requests_post(url: str, **kwargs: Any) -> MockResponse:
|
||||
def test_invoke_vllm(*args: Any) -> None:
|
||||
"""Tests invoking vLLM endpoint."""
|
||||
llm = ChatOCIModelDeploymentVLLM(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME)
|
||||
assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
|
||||
output = llm.invoke(CONST_PROMPT)
|
||||
assert isinstance(output, AIMessage)
|
||||
assert output.content == CONST_COMPLETION
|
||||
@@ -132,6 +134,7 @@ def test_invoke_vllm(*args: Any) -> None:
|
||||
def test_invoke_tgi(*args: Any) -> None:
|
||||
"""Tests invoking TGI endpoint using OpenAI Spec."""
|
||||
llm = ChatOCIModelDeploymentTGI(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME)
|
||||
assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
|
||||
output = llm.invoke(CONST_PROMPT)
|
||||
assert isinstance(output, AIMessage)
|
||||
assert output.content == CONST_COMPLETION
|
||||
@@ -146,6 +149,7 @@ def test_stream_vllm(*args: Any) -> None:
|
||||
llm = ChatOCIModelDeploymentVLLM(
|
||||
endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
|
||||
)
|
||||
assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
|
||||
output = None
|
||||
count = 0
|
||||
for chunk in llm.stream(CONST_PROMPT):
|
||||
@@ -184,6 +188,7 @@ async def test_stream_async(*args: Any) -> None:
|
||||
llm = ChatOCIModelDeploymentVLLM(
|
||||
endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
|
||||
)
|
||||
assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
|
||||
with mock.patch.object(
|
||||
llm,
|
||||
"_aiter_sse",
|
||||
|
@@ -18,6 +18,7 @@ CONST_MODEL_NAME = "odsc-vllm"
|
||||
CONST_ENDPOINT = "https://oci.endpoint/ocid/predict"
|
||||
CONST_PROMPT = "This is a prompt."
|
||||
CONST_COMPLETION = "This is a completion."
|
||||
CONST_COMPLETION_ROUTE = "/v1/completions"
|
||||
CONST_COMPLETION_RESPONSE = {
|
||||
"choices": [
|
||||
{
|
||||
@@ -114,6 +115,7 @@ async def mocked_async_streaming_response(
|
||||
def test_invoke_vllm(*args: Any) -> None:
|
||||
"""Tests invoking vLLM endpoint."""
|
||||
llm = OCIModelDeploymentVLLM(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME)
|
||||
assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
|
||||
output = llm.invoke(CONST_PROMPT)
|
||||
assert output == CONST_COMPLETION
|
||||
|
||||
@@ -126,6 +128,7 @@ def test_stream_tgi(*args: Any) -> None:
|
||||
llm = OCIModelDeploymentTGI(
|
||||
endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
|
||||
)
|
||||
assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
|
||||
output = ""
|
||||
count = 0
|
||||
for chunk in llm.stream(CONST_PROMPT):
|
||||
@@ -143,6 +146,7 @@ def test_generate_tgi(*args: Any) -> None:
|
||||
llm = OCIModelDeploymentTGI(
|
||||
endpoint=CONST_ENDPOINT, api="/generate", model=CONST_MODEL_NAME
|
||||
)
|
||||
assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
|
||||
output = llm.invoke(CONST_PROMPT)
|
||||
assert output == CONST_COMPLETION
|
||||
|
||||
@@ -161,6 +165,7 @@ async def test_stream_async(*args: Any) -> None:
|
||||
llm = OCIModelDeploymentTGI(
|
||||
endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
|
||||
)
|
||||
assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
|
||||
with mock.patch.object(
|
||||
llm,
|
||||
"_aiter_sse",
|
||||
|
Reference in New Issue
Block a user