[Partner] NVIDIA TRT Package (#14733)

Simplify #13976 and add it as a separate package.

- [ ] Add README
- [X] Add doc notebook
- [X] Add simple LLM integration
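
A minimal usage sketch (mirroring the doc notebook; assumes a Triton server with a TRT-LLM model named `ensemble` reachable at `localhost:8001`):

```python
from langchain_nvidia_trt.llms import TritonTensorRTLLM

# Point this at your own Triton deployment; "ensemble" is the TRT-LLM model name used in the docs.
llm = TritonTensorRTLLM(server_url="localhost:8001", model_name="ensemble", tokens=500)
print(llm.invoke("What is LangChain?"))
```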

---------

Co-authored-by: Jeremy Dyer <jdye64@gmail.com>
William FH 2023-12-18 19:08:25 -08:00 committed by GitHub
parent 0d4cbbcc85
commit 583696732c
21 changed files with 2993 additions and 0 deletions

libs/partners/nvidia-trt/.gitignore vendored Normal file
@@ -0,0 +1 @@
__pycache__

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2023 LangChain, Inc.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

@@ -0,0 +1,59 @@
.PHONY: all format lint test tests integration_tests docker_tests help extended_tests
# Default target executed when no arguments are given to make.
all: help
# Define a variable for the test file path.
TEST_FILE ?= tests/unit_tests/
test:
poetry run pytest $(TEST_FILE)
tests:
poetry run pytest $(TEST_FILE)
######################
# LINTING AND FORMATTING
######################
# Define a variable for Python and notebook files.
PYTHON_FILES=.
MYPY_CACHE=.mypy_cache
lint format: PYTHON_FILES=.
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/partners/nvidia-trt --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
lint_package: PYTHON_FILES=langchain_nvidia_trt
lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test
lint lint_diff lint_package lint_tests:
poetry run ruff .
poetry run ruff format $(PYTHON_FILES) --diff
poetry run ruff --select I $(PYTHON_FILES)
mkdir -p $(MYPY_CACHE); poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
format format_diff:
poetry run ruff format $(PYTHON_FILES)
poetry run ruff --select I --fix $(PYTHON_FILES)
spell_check:
poetry run codespell --toml pyproject.toml
spell_fix:
poetry run codespell --toml pyproject.toml -w
check_imports: $(shell find langchain_nvidia_trt -name '*.py')
poetry run python ./scripts/check_imports.py $^
######################
# HELP
######################
help:
@echo '----'
@echo 'check_imports - check imports'
@echo 'format - run code formatters'
@echo 'lint - run linters'
@echo 'test - run unit tests'
@echo 'tests - run unit tests'
@echo 'test TEST_FILE=<test_file> - run all tests in file'

@@ -0,0 +1 @@
# langchain-nvidia-trt

@@ -0,0 +1,106 @@
{
"cells": [
{
"cell_type": "raw",
"id": "67db2992",
"metadata": {},
"source": [
"---\n",
"sidebar_label: TritonTensorRT\n",
"---"
]
},
{
"cell_type": "markdown",
"id": "b56b221d",
"metadata": {},
"source": [
"# Nvidia Triton+TRT-LLM\n",
"\n",
"Nvidia's Triton is an inference server that provides API-style access to hosted LLM models. Nvidia TensorRT-LLM, often abbreviated as TRT-LLM, is a GPU-accelerated SDK for optimizing and running inference on LLMs. This connector lets LangChain interact with a remote Triton inference server over gRPC or HTTP to perform accelerated inference.\n",
"\n",
"[Triton Inference Server GitHub](https://github.com/triton-inference-server/server)\n",
"\n",
"\n",
"## TritonTensorRTLLM\n",
"\n",
"This example goes over how to use LangChain to interact with TRT-LLM models through the `TritonTensorRTLLM` class. To install the integration package, run the following command:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "59c710c4",
"metadata": {},
"outputs": [],
"source": [
"# install package\n",
"%pip install -U langchain-nvidia-trt"
]
},
{
"cell_type": "markdown",
"id": "0ee90032",
"metadata": {},
"source": [
"## Create the Triton+TRT-LLM instance\n",
"\n",
"Remember that a Triton instance represents a running server, so make sure you have a valid server configuration running and change `localhost:8001` to the correct IP/hostname:port combination for your server.\n",
"\n",
"An example of setting up this environment can be found in Nvidia's [GenerativeAIExamples GitHub repo](https://github.com/NVIDIA/GenerativeAIExamples/tree/main/RetrievalAugmentedGeneration)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "035dea0f",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from langchain_core.prompts import PromptTemplate\n",
"from langchain_nvidia_trt.llms import TritonTensorRTLLM\n",
"\n",
"template = \"\"\"Question: {question}\n",
"\n",
"Answer: Let's think step by step.\"\"\"\n",
"\n",
"prompt = PromptTemplate.from_template(template)\n",
"\n",
"# Connect to the TRT-LLM Llama-2 model running on the Triton server at the url below\n",
"triton_llm = TritonTensorRTLLM(server_url =\"localhost:8001\", model_name=\"ensemble\", tokens=500)\n",
"\n",
"chain = prompt | triton_llm \n",
"\n",
"chain.invoke({\"question\": \"What is LangChain?\"})"
]
}
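,
{
"cell_type": "markdown",
"id": "2f5a9d73",
"metadata": {},
"source": [
"## Streaming (sketch)\n",
"\n",
"`TritonTensorRTLLM` also supports token streaming through the standard LangChain `stream` interface. A minimal sketch, assuming the same Triton server and `ensemble` model configured above:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8c1d4f2e",
"metadata": {},
"outputs": [],
"source": [
"# Stream tokens as they are generated (assumes `triton_llm` from the cell above).\n",
"for token in triton_llm.stream(\"What is LangChain?\"):\n",
"    print(token, end=\"\", flush=True)"
]
}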
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
},
"vscode": {
"interpreter": {
"hash": "e971737741ff4ec9aff7dc6155a1060a59a8a6d52c757dbbe66bf8ee389494b1"
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}

@@ -0,0 +1,3 @@
from langchain_nvidia_trt.llms import TritonTensorRTLLM
__all__ = ["TritonTensorRTLLM"]

@@ -0,0 +1,404 @@
from __future__ import annotations
import json
import queue
import random
import time
from functools import partial
from typing import Any, Dict, Iterator, List, Optional, Sequence, Union
import google.protobuf.json_format
import numpy as np
import tritonclient.grpc as grpcclient
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseLLM
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
from langchain_core.pydantic_v1 import Field, root_validator
from tritonclient.grpc.service_pb2 import ModelInferResponse
from tritonclient.utils import np_to_triton_dtype
class TritonTensorRTError(Exception):
"""Base exception for TritonTensorRT."""
class TritonTensorRTRuntimeError(TritonTensorRTError, RuntimeError):
"""Runtime error for TritonTensorRT."""
class TritonTensorRTLLM(BaseLLM):
"""TRTLLM triton models.
Arguments:
server_url: (str) The URL of the Triton inference server to use.
model_name: (str) The name of the Triton TRT model to use.
temperature: (float) Temperature to use for sampling
top_p: (float) The top-p value to use for sampling
top_k: (int) The top-k value to use for sampling
beam_width: (int) The number of beams to use for beam search
repetition_penalty: (float) The penalty to apply to repeated tokens
length_penalty: (float) The penalty to apply to the length of the output
tokens: (int) The maximum number of tokens to generate.
client: The client object used to communicate with the inference server
Example:
.. code-block:: python
from langchain_nvidia_trt import TritonTensorRTLLM
model = TritonTensorRTLLM()
"""
server_url: Optional[str] = Field(None, alias="server_url")
model_name: str = Field(
..., description="The name of the model to use, such as 'ensemble'."
)
## Optional args for the model
temperature: float = 1.0
top_p: float = 0
top_k: int = 1
tokens: int = 100
beam_width: int = 1
repetition_penalty: float = 1.0
length_penalty: float = 1.0
client: grpcclient.InferenceServerClient
stop: List[str] = Field(
default_factory=lambda: ["</s>"], description="Stop tokens."
)
seed: int = Field(42, description="The seed to use for random generation.")
load_model: bool = Field(
True,
description="Request the inference server to load the specified model.\
Certain Triton configurations do not allow for this operation.",
)
def __del__(self):
"""Ensure the client streaming connection is properly shut down."""
self.client.close()
@root_validator(pre=True, allow_reuse=True)
def validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Create the gRPC client when one is not provided."""
if not values.get("client"):
values["client"] = grpcclient.InferenceServerClient(values["server_url"])
return values
@property
def _llm_type(self) -> str:
"""Return type of LLM."""
return "nvidia-trt-llm"
@property
def _model_default_parameters(self) -> Dict[str, Any]:
return {
"tokens": self.tokens,
"top_k": self.top_k,
"top_p": self.top_p,
"temperature": self.temperature,
"repetition_penalty": self.repetition_penalty,
"length_penalty": self.length_penalty,
"beam_width": self.beam_width,
}
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get all the identifying parameters."""
return {
"server_url": self.server_url,
"model_name": self.model_name,
**self._model_default_parameters,
}
def _get_invocation_params(self, **kwargs: Any) -> Dict[str, Any]:
return {**self._model_default_parameters, **kwargs}
def get_model_list(self) -> List[str]:
"""Get a list of models loaded in the triton server."""
res = self.client.get_model_repository_index(as_json=True)
return [model["name"] for model in res["models"]]
def _load_model(self, model_name: str, timeout: int = 1000) -> None:
"""Load a model into the server."""
if self.client.is_model_ready(model_name):
return
self.client.load_model(model_name)
t0 = time.perf_counter()
t1 = t0
while not self.client.is_model_ready(model_name) and t1 - t0 < timeout:
t1 = time.perf_counter()
if not self.client.is_model_ready(model_name):
raise TritonTensorRTRuntimeError(
f"Failed to load {model_name} on Triton in {timeout}s"
)
def _generate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
self._load_model(self.model_name)
invocation_params = self._get_invocation_params(**kwargs)
stop_words = stop if stop is not None else self.stop
generations = []
# TODO: We should handle the native batching instead.
for prompt in prompts:
invoc_params = {**invocation_params, "prompt": [[prompt]]}
result: str = self._request(
self.model_name,
stop=stop_words,
**invoc_params,
)
generations.append([Generation(text=result, generation_info={})])
return LLMResult(generations=generations)
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
self._load_model(self.model_name)
invocation_params = self._get_invocation_params(**kwargs, prompt=[[prompt]])
stop_words = stop if stop is not None else self.stop
inputs = self._generate_inputs(stream=True, **invocation_params)
outputs = self._generate_outputs()
result_queue = self._invoke_triton(self.model_name, inputs, outputs, stop_words)
for token in result_queue:
yield GenerationChunk(text=token)
if run_manager:
run_manager.on_llm_new_token(token)
self.client.stop_stream()
##### BELOW ARE METHODS PREVIOUSLY ONLY IN THE GRPC CLIENT
def _request(
self,
model_name: str,
prompt: Sequence[Sequence[str]],
stop: Optional[List[str]] = None,
**params: Any,
) -> str:
"""Request inferencing from the triton server."""
# create model inputs and outputs
inputs = self._generate_inputs(stream=False, prompt=prompt, **params)
outputs = self._generate_outputs()
result_queue = self._invoke_triton(self.model_name, inputs, outputs, stop)
result_str = ""
for token in result_queue:
result_str += token
self.client.stop_stream()
return result_str
def _invoke_triton(self, model_name, inputs, outputs, stop_words):
if not self.client.is_model_ready(model_name):
raise RuntimeError("Cannot request streaming, model is not loaded")
request_id = str(random.randint(1, 9999999)) # nosec
result_queue = StreamingResponseGenerator(
self,
request_id,
force_batch=False,
stop_words=stop_words,
)
self.client.start_stream(
callback=partial(
self._stream_callback,
result_queue,
stop_words=stop_words,
)
)
# Even though this request may not be a streaming request, certain Triton
# configurations prevent the gRPC server from accepting non-streaming connections.
# Therefore we call the streaming API and combine the streamed results.
self.client.async_stream_infer(
model_name=model_name,
inputs=inputs,
outputs=outputs,
request_id=request_id,
)
return result_queue
def _generate_outputs(
self,
) -> List[grpcclient.InferRequestedOutput]:
"""Generate the expected output structure."""
return [grpcclient.InferRequestedOutput("text_output")]
def _prepare_tensor(
self, name: str, input_data: np.ndarray
) -> grpcclient.InferInput:
"""Prepare an input data structure."""
t = grpcclient.InferInput(
name, input_data.shape, np_to_triton_dtype(input_data.dtype)
)
t.set_data_from_numpy(input_data)
return t
def _generate_inputs(
self,
prompt: Sequence[Sequence[str]],
tokens: int = 300,
temperature: float = 1.0,
top_k: int = 1,
top_p: float = 0,
beam_width: int = 1,
repetition_penalty: float = 1,
length_penalty: float = 1.0,
stream: bool = True,
) -> List[grpcclient.InferInput]:
"""Create the input for the triton inference server."""
query = np.array(prompt).astype(object)
request_output_len = np.array([tokens]).astype(np.uint32).reshape((1, -1))
runtime_top_k = np.array([top_k]).astype(np.uint32).reshape((1, -1))
runtime_top_p = np.array([top_p]).astype(np.float32).reshape((1, -1))
temperature_array = np.array([temperature]).astype(np.float32).reshape((1, -1))
len_penalty = np.array([length_penalty]).astype(np.float32).reshape((1, -1))
repetition_penalty_array = (
np.array([repetition_penalty]).astype(np.float32).reshape((1, -1))
)
random_seed = np.array([self.seed]).astype(np.uint64).reshape((1, -1))
beam_width_array = np.array([beam_width]).astype(np.uint32).reshape((1, -1))
streaming_data = np.array([[stream]], dtype=bool)
inputs = [
self._prepare_tensor("text_input", query),
self._prepare_tensor("max_tokens", request_output_len),
self._prepare_tensor("top_k", runtime_top_k),
self._prepare_tensor("top_p", runtime_top_p),
self._prepare_tensor("temperature", temperature_array),
self._prepare_tensor("length_penalty", len_penalty),
self._prepare_tensor("repetition_penalty", repetition_penalty_array),
self._prepare_tensor("random_seed", random_seed),
self._prepare_tensor("beam_width", beam_width_array),
self._prepare_tensor("stream", streaming_data),
]
return inputs
def _send_stop_signals(self, model_name: str, request_id: str) -> None:
"""Send the stop signal to the Triton Inference server."""
stop_inputs = self._generate_stop_signals()
self.client.async_stream_infer(
model_name,
stop_inputs,
request_id=request_id,
parameters={"Streaming": True},
)
def _generate_stop_signals(
self,
) -> List[grpcclient.InferInput]:
"""Generate the signal to stop the stream."""
inputs = [
grpcclient.InferInput("input_ids", [1, 1], "INT32"),
grpcclient.InferInput("input_lengths", [1, 1], "INT32"),
grpcclient.InferInput("request_output_len", [1, 1], "UINT32"),
grpcclient.InferInput("stop", [1, 1], "BOOL"),
]
inputs[0].set_data_from_numpy(np.empty([1, 1], dtype=np.int32))
inputs[1].set_data_from_numpy(np.zeros([1, 1], dtype=np.int32))
inputs[2].set_data_from_numpy(np.array([[0]], dtype=np.uint32))
inputs[3].set_data_from_numpy(np.array([[True]], dtype="bool"))
return inputs
@staticmethod
def _process_result(result: Dict[str, str]) -> str:
"""Post-process the result from the server."""
message = ModelInferResponse()
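# Parse the streamed JSON result back into a protobuf response so the text_output tensor can be decoded.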
google.protobuf.json_format.Parse(json.dumps(result), message)
infer_result = grpcclient.InferResult(message)
np_res = infer_result.as_numpy("text_output")
generated_text = ""
if np_res is not None:
generated_text = "".join([token.decode() for token in np_res])
return generated_text
def _stream_callback(
self,
result_queue: queue.Queue[Union[Optional[Dict[str, str]], str]],
result: grpcclient.InferResult,
error: str,
stop_words: List[str],
) -> None:
"""Add streamed result to queue."""
if error:
result_queue.put(error)
else:
response_raw: dict = result.get_response(as_json=True)
# TODO: Check the response is a map rather than a string
if "outputs" in response_raw:
# the very last response might have no output, just the final flag
response = self._process_result(response_raw)
if response in stop_words:
result_queue.put(None)
else:
result_queue.put(response)
if response_raw["parameters"]["triton_final_response"]["bool_param"]:
# end of the generation
result_queue.put(None)
def stop_stream(
self, model_name: str, request_id: str, signal: bool = True
) -> None:
"""Close the streaming connection."""
if signal:
self._send_stop_signals(model_name, request_id)
self.client.stop_stream()
class StreamingResponseGenerator(queue.Queue):
"""A Generator that provides the inference results from an LLM."""
def __init__(
self,
client: TritonTensorRTLLM,
request_id: str,
force_batch: bool,
stop_words: Sequence[str],
) -> None:
"""Instantiate the generator class."""
super().__init__()
self.client = client
self.request_id = request_id
self._batch = force_batch
self._stop_words = stop_words
def __iter__(self) -> StreamingResponseGenerator:
"""Return self as a generator."""
return self
def __next__(self) -> str:
"""Return the next retrieved token."""
val = self.get()
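# A None sentinel or a stop word marks the end of generation; close the stream and stop iterating.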
if val is None or val in self._stop_words:
self.client.stop_stream(
"tensorrt_llm", self.request_id, signal=not self._batch
)
raise StopIteration()
return val

@@ -0,0 +1,4 @@
[mypy]
# Empty global config
[mypy-tritonclient.*]
ignore_missing_imports = True

libs/partners/nvidia-trt/poetry.lock generated Normal file

File diff suppressed because it is too large.

@@ -0,0 +1,90 @@
[tool.poetry]
name = "langchain-nvidia-trt"
version = "0.0.1"
description = "An integration package connecting TritonTensorRT and LangChain"
authors = []
readme = "README.md"
[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
langchain-core = ">=0.0.12"
tritonclient = { extras = ["all"], version = "^2.40.0" }
lint = "^1.2.1"
types-protobuf = "^4.24.0.4"
protobuf = "^3.5.0"
[tool.poetry.group.test]
optional = true
[tool.poetry.group.test.dependencies]
pytest = "^7.3.0"
freezegun = "^1.2.2"
pytest-mock = "^3.10.0"
syrupy = "^4.0.2"
pytest-watcher = "^0.3.4"
pytest-asyncio = "^0.21.1"
langchain-core = { path = "../../core", develop = true }
[tool.poetry.group.codespell]
optional = true
[tool.poetry.group.codespell.dependencies]
codespell = "^2.2.0"
[tool.poetry.group.test_integration]
optional = true
[tool.poetry.group.test_integration.dependencies]
[tool.poetry.group.lint]
optional = true
[tool.poetry.group.lint.dependencies]
ruff = "^0.1.5"
[tool.poetry.group.typing.dependencies]
mypy = "^0.991"
langchain-core = { path = "../../core", develop = true }
[tool.poetry.group.dev]
optional = true
[tool.poetry.group.dev.dependencies]
langchain-core = { path = "../../core", develop = true }
[tool.ruff]
select = [
"E", # pycodestyle
"F", # pyflakes
"I", # isort
]
[tool.mypy]
disallow_untyped_defs = "True"
[tool.coverage.run]
omit = ["tests/*"]
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
[tool.pytest.ini_options]
# --strict-markers will raise errors on unknown marks.
# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
#
# https://docs.pytest.org/en/7.1.x/reference/reference.html
# --strict-config any warnings encountered while parsing the `pytest`
# section of the configuration file raise errors.
#
# https://github.com/tophat/syrupy
# --snapshot-warn-unused Prints a warning on unused snapshots rather than fail the test suite.
addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5"
# Registering custom markers.
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
markers = [
"requires: mark tests as requiring a specific library",
"asyncio: mark tests as requiring asyncio",
"compile: mark placeholder test used to compile integration tests without running them",
]
asyncio_mode = "auto"

@@ -0,0 +1,17 @@
import sys
import traceback
from importlib.machinery import SourceFileLoader
if __name__ == "__main__":
files = sys.argv[1:]
has_failure = False
for file in files:
try:
SourceFileLoader("x", file).load_module()
except Exception:
has_failure = True
print(file)
traceback.print_exc()
print()
sys.exit(1 if has_failure else 0)

@@ -0,0 +1,27 @@
#!/bin/bash
#
# This script searches for lines starting with "import pydantic" or "from pydantic"
# in tracked files within a Git repository.
#
# Usage: ./scripts/check_pydantic.sh /path/to/repository
# Check if a path argument is provided
if [ $# -ne 1 ]; then
echo "Usage: $0 /path/to/repository"
exit 1
fi
repository_path="$1"
# Search for lines matching the pattern within the specified repository
result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic')
# Check if any matching lines were found
if [ -n "$result" ]; then
echo "ERROR: The following lines need to be updated:"
echo "$result"
echo "Please replace the code with an import from langchain_core.pydantic_v1."
echo "For example, replace 'from pydantic import BaseModel'"
echo "with 'from langchain_core.pydantic_v1 import BaseModel'"
exit 1
fi

@@ -0,0 +1,17 @@
#!/bin/bash
set -eu
# Initialize a variable to keep track of errors
errors=0
# make sure not importing from langchain or langchain_experimental
git --no-pager grep '^from langchain\.' . && errors=$((errors+1))
git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1))
# Decide on an exit status based on the errors
if [ "$errors" -gt 0 ]; then
exit 1
else
exit 0
fi

@@ -0,0 +1,7 @@
import pytest
@pytest.mark.compile
def test_placeholder() -> None:
"""Used for compiling integration tests without running any real tests."""
pass

@@ -0,0 +1,74 @@
"""Test TritonTensorRTLLM llm."""
import pytest
from langchain_nvidia_trt.llms import TritonTensorRTLLM
_MODEL_NAME = "ensemble"
@pytest.mark.skip(reason="Need a working Triton server")
def test_stream() -> None:
"""Test streaming tokens from TritonTensorRTLLM."""
llm = TritonTensorRTLLM(model_name=_MODEL_NAME)
for token in llm.stream("I'm Pickle Rick"):
assert isinstance(token, str)
@pytest.mark.skip(reason="Need a working Triton server")
async def test_astream() -> None:
"""Test streaming tokens from TritonTensorRTLLM."""
llm = TritonTensorRTLLM(model_name=_MODEL_NAME)
async for token in llm.astream("I'm Pickle Rick"):
assert isinstance(token, str)
@pytest.mark.skip(reason="Need a working Triton server")
async def test_abatch() -> None:
"""Test batch tokens from TritonTensorRTLLM."""
llm = TritonTensorRTLLM(model_name=_MODEL_NAME)
result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
for token in result:
assert isinstance(token, str)
@pytest.mark.skip(reason="Need a working Triton server")
async def test_abatch_tags() -> None:
"""Test batch tokens from TritonTensorRTLLM."""
llm = TritonTensorRTLLM(model_name=_MODEL_NAME)
result = await llm.abatch(
["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
)
for token in result:
assert isinstance(token, str)
@pytest.mark.skip(reason="Need a working Triton server")
def test_batch() -> None:
"""Test batch tokens from TritonTensorRTLLM."""
llm = TritonTensorRTLLM(model_name=_MODEL_NAME)
result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
for token in result:
assert isinstance(token, str)
@pytest.mark.skip(reason="Need a working Triton server")
async def test_ainvoke() -> None:
"""Test invoke tokens from TritonTensorRTLLM."""
llm = TritonTensorRTLLM(model_name=_MODEL_NAME)
result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
assert isinstance(result, str)
@pytest.mark.skip(reason="Need a working Triton server")
def test_invoke() -> None:
"""Test invoke tokens from TritonTensorRTLLM."""
llm = TritonTensorRTLLM(model_name=_MODEL_NAME)
result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
assert isinstance(result, str)

@@ -0,0 +1,7 @@
from langchain_nvidia_trt import __all__
EXPECTED_ALL = ["TritonTensorRTLLM"]
def test_all_imports() -> None:
assert sorted(EXPECTED_ALL) == sorted(__all__)

@@ -0,0 +1,7 @@
"""Test TritonTensorRT Chat API wrapper."""
from langchain_nvidia_trt import TritonTensorRTLLM
def test_initialization() -> None:
"""Test integration initialization."""
TritonTensorRTLLM(model_name="ensemble", server_url="http://localhost:8001")