nvidia-trt[patch]: add TritonTensorRTLLM(verbose_client=False) (#16848)
- **Description:** adds a `verbose_client` flag to `TritonTensorRTLLM` so the underlying Triton gRPC client can be created with verbose logging
- **Issue:** none
- **Dependencies:** none
- **Twitter handle:**
commit d039dcb6ba (parent 1569b19191)
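For context, a minimal usage sketch of the new flag (illustrative, not part of the diff; it assumes a reachable Triton server at `localhost:8001` serving an `ensemble` model):

```python
# Hypothetical usage of the new verbose_client flag; server URL and model name
# mirror the unit tests below and are assumptions, not library defaults.
from langchain_nvidia_trt import TritonTensorRTLLM

llm = TritonTensorRTLLM(
    server_url="localhost:8001",
    model_name="ensemble",
    verbose_client=True,  # the underlying Triton gRPC client echoes each RPC to stdout
)
print(llm.invoke("What is TensorRT-LLM?"))
```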
```diff
@@ -40,6 +40,7 @@ class TritonTensorRTLLM(BaseLLM):
         length_penalty: (float) The penalty to apply repeated tokens
         tokens: (int) The maximum number of tokens to generate.
         client: The client object used to communicate with the inference server
+        verbose_client: flag to pass to the client on creation

     Example:
         .. code-block:: python
```
```diff
@@ -73,6 +74,7 @@ class TritonTensorRTLLM(BaseLLM):
         description="Request the inference server to load the specified model.\
 Certain Triton configurations do not allow for this operation.",
     )
+    verbose_client: bool = False

     def __del__(self):
         """Ensure the client streaming connection is properly shutdown"""
```
```diff
@@ -82,7 +84,9 @@ class TritonTensorRTLLM(BaseLLM):
     def validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]:
         """Validate that python package exists in environment."""
         if not values.get("client"):
-            values["client"] = grpcclient.InferenceServerClient(values["server_url"])
+            values["client"] = grpcclient.InferenceServerClient(
+                values["server_url"], verbose=values.get("verbose_client", False)
+            )
         return values

     @property
```
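For reference, a sketch of the call the wrapper now makes under the hood (assumes `tritonclient[grpc]` is installed; constructing the client does not require a live server):

```python
# Approximately what validate_environment now does: pass the flag through to
# tritonclient's gRPC client. With verbose=True each RPC is printed to stdout.
import tritonclient.grpc as grpcclient

client = grpcclient.InferenceServerClient("localhost:8001", verbose=True)
# client.is_server_live()  # with a live server, prints "is_server_live ..." to stdout
```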
And in the unit tests for the wrapper:

```diff
@@ -1,7 +1,33 @@
 """Test TritonTensorRT Chat API wrapper."""
+import sys
+from io import StringIO
+from unittest.mock import patch
+
 from langchain_nvidia_trt import TritonTensorRTLLM


 def test_initialization() -> None:
     """Test integration initialization."""
     TritonTensorRTLLM(model_name="ensemble", server_url="http://localhost:8001")
+
+
+@patch("tritonclient.grpc.service_pb2_grpc.GRPCInferenceServiceStub")
+def test_default_verbose(ignore) -> None:
+    llm = TritonTensorRTLLM(server_url="http://localhost:8001", model_name="ensemble")
+    captured = StringIO()
+    sys.stdout = captured
+    llm.client.is_server_live()
+    sys.stdout = sys.__stdout__
+    assert "is_server_live" not in captured.getvalue()
+
+
+@patch("tritonclient.grpc.service_pb2_grpc.GRPCInferenceServiceStub")
+def test_verbose(ignore) -> None:
+    llm = TritonTensorRTLLM(
+        server_url="http://localhost:8001", model_name="ensemble", verbose_client=True
+    )
+    captured = StringIO()
+    sys.stdout = captured
+    llm.client.is_server_live()
+    sys.stdout = sys.__stdout__
+    assert "is_server_live" in captured.getvalue()
```
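A hedged alternative to the manual `sys.stdout` swap above: pytest's built-in `capsys` fixture captures stdout and restores it automatically, even when the assertion fails (`test_verbose_with_capsys` is a hypothetical name, not part of the diff):

```python
# Hypothetical variant of test_verbose using pytest's capsys fixture instead of
# reassigning sys.stdout by hand; the patched gRPC stub keeps the test offline.
from unittest.mock import patch

from langchain_nvidia_trt import TritonTensorRTLLM


@patch("tritonclient.grpc.service_pb2_grpc.GRPCInferenceServiceStub")
def test_verbose_with_capsys(ignore, capsys) -> None:
    llm = TritonTensorRTLLM(
        server_url="http://localhost:8001", model_name="ensemble", verbose_client=True
    )
    llm.client.is_server_live()
    assert "is_server_live" in capsys.readouterr().out
```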