[Partner] NVIDIA TRT Package (#14733)
Simplify #13976 and add as a separate package.

- [ ] Add README
- [x] Add doc notebook
- [x] Add simple LLM integration

---------

Co-authored-by: Jeremy Dyer <jdye64@gmail.com>
parent 0d4cbbcc85
commit 583696732c
libs/partners/nvidia-trt/.gitignore (vendored, new file, +1)
@@ -0,0 +1 @@
__pycache__
libs/partners/nvidia-trt/LICENSE (new file, +21)
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2023 LangChain, Inc.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
libs/partners/nvidia-trt/Makefile (new file, +59)
@@ -0,0 +1,59 @@
.PHONY: all format lint test tests integration_tests docker_tests help extended_tests

# Default target executed when no arguments are given to make.
all: help

# Define a variable for the test file path.
TEST_FILE ?= tests/unit_tests/

test:
	poetry run pytest $(TEST_FILE)

tests:
	poetry run pytest $(TEST_FILE)


######################
# LINTING AND FORMATTING
######################

# Define a variable for Python and notebook files.
PYTHON_FILES=.
MYPY_CACHE=.mypy_cache
lint format: PYTHON_FILES=.
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/partners/nvidia-trt --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
lint_package: PYTHON_FILES=langchain_nvidia_trt
lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test

lint lint_diff lint_package lint_tests:
	poetry run ruff .
	poetry run ruff format $(PYTHON_FILES) --diff
	poetry run ruff --select I $(PYTHON_FILES)
	mkdir $(MYPY_CACHE); poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)

format format_diff:
	poetry run ruff format $(PYTHON_FILES)
	poetry run ruff --select I --fix $(PYTHON_FILES)

spell_check:
	poetry run codespell --toml pyproject.toml

spell_fix:
	poetry run codespell --toml pyproject.toml -w

check_imports: $(shell find langchain_nvidia_trt -name '*.py')
	poetry run python ./scripts/check_imports.py $^

######################
# HELP
######################

help:
	@echo '----'
	@echo 'check_imports                - check imports'
	@echo 'format                       - run code formatters'
	@echo 'lint                         - run linters'
	@echo 'test                         - run unit tests'
	@echo 'tests                        - run unit tests'
	@echo 'test TEST_FILE=<test_file>   - run all tests in file'
libs/partners/nvidia-trt/README.md (new file, +1)
@@ -0,0 +1 @@
# langchain-nvidia-trt
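The README added here is only a heading; a minimal usage sketch based on the docs notebook in this commit, assuming a Triton server that serves a TRT-LLM model named "ensemble" is reachable at localhost:8001:

# Minimal usage sketch (not part of this commit's README); server address and
# model name are assumptions taken from the docs notebook below.
# First: pip install -U langchain-nvidia-trt
from langchain_nvidia_trt.llms import TritonTensorRTLLM

# Connect to the Triton server over gRPC and generate up to 500 tokens.
llm = TritonTensorRTLLM(server_url="localhost:8001", model_name="ensemble", tokens=500)

print(llm.invoke("What is LangChain?"))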
libs/partners/nvidia-trt/docs/llms.ipynb (new file, +106)
@@ -0,0 +1,106 @@
{
 "cells": [
  {
   "cell_type": "raw",
   "id": "67db2992",
   "metadata": {},
   "source": [
    "---\n",
    "sidebar_label: TritonTensorRT\n",
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b56b221d",
   "metadata": {},
   "source": [
    "# Nvidia Triton+TRT-LLM\n",
    "\n",
    "Nvidia's Triton is an inference server that provides API-style access to hosted LLM models. Likewise, Nvidia TensorRT-LLM, often abbreviated as TRT-LLM, is a GPU-accelerated SDK for running optimizations and inference on LLM models. This connector allows LangChain to remotely interact with a Triton inference server over GRPC or HTTP to perform accelerated inference operations.\n",
    "\n",
    "[Triton Inference Server Github](https://github.com/triton-inference-server/server)\n",
    "\n",
    "\n",
    "## TritonTensorRTLLM\n",
    "\n",
    "This example goes over how to use LangChain to interact with `TritonTensorRT` LLMs. To install, run the following command:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59c710c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# install package\n",
    "%pip install -U langchain-nvidia-trt"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0ee90032",
   "metadata": {},
   "source": [
    "## Create the Triton+TRT-LLM instance\n",
    "\n",
    "Remember that a Triton instance represents a running server instance, so make sure you have a valid server configuration running and change `localhost:8001` to the correct IP/hostname:port combination for your server.\n",
    "\n",
    "An example of setting up this environment can be found at Nvidia's [GenerativeAIExamples Github Repo](https://github.com/NVIDIA/GenerativeAIExamples/tree/main/RetrievalAugmentedGeneration)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "035dea0f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from langchain_core.prompts import PromptTemplate\n",
    "from langchain_nvidia_trt.llms import TritonTensorRTLLM\n",
    "\n",
    "template = \"\"\"Question: {question}\n",
    "\n",
    "Answer: Let's think step by step.\"\"\"\n",
    "\n",
    "prompt = PromptTemplate.from_template(template)\n",
    "\n",
    "# Connect to the TRT-LLM Llama-2 model running on the Triton server at the url below\n",
    "triton_llm = TritonTensorRTLLM(server_url=\"localhost:8001\", model_name=\"ensemble\", tokens=500)\n",
    "\n",
    "chain = prompt | triton_llm\n",
    "\n",
    "chain.invoke({\"question\": \"What is LangChain?\"})"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "vscode": {
   "interpreter": {
    "hash": "e971737741ff4ec9aff7dc6155a1060a59a8a6d52c757dbbe66bf8ee389494b1"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
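The notebook above only shows a blocking chain invocation; a hedged streaming sketch, using the same assumed localhost:8001 / "ensemble" setup and the `stream` interface that `BaseLLM` provides (exercised by the integration tests later in this commit):

from langchain_nvidia_trt.llms import TritonTensorRTLLM

# Server address and model name are assumptions; adjust for your deployment.
llm = TritonTensorRTLLM(server_url="localhost:8001", model_name="ensemble", tokens=500)

# Tokens arrive incrementally from Triton's gRPC streaming API and stop at "</s>".
for token in llm.stream("What is LangChain?"):
    print(token, end="", flush=True)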
libs/partners/nvidia-trt/langchain_nvidia_trt/__init__.py (new file, +3)
@@ -0,0 +1,3 @@
from langchain_nvidia_trt.llms import TritonTensorRTLLM

__all__ = ["TritonTensorRTLLM"]
libs/partners/nvidia-trt/langchain_nvidia_trt/llms.py (new file, +404)
@@ -0,0 +1,404 @@
from __future__ import annotations

import json
import queue
import random
import time
from functools import partial
from typing import Any, Dict, Iterator, List, Optional, Sequence, Union

import google.protobuf.json_format
import numpy as np
import tritonclient.grpc as grpcclient
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseLLM
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
from langchain_core.pydantic_v1 import Field, root_validator
from tritonclient.grpc.service_pb2 import ModelInferResponse
from tritonclient.utils import np_to_triton_dtype


class TritonTensorRTError(Exception):
    """Base exception for TritonTensorRT."""


class TritonTensorRTRuntimeError(TritonTensorRTError, RuntimeError):
    """Runtime error for TritonTensorRT."""


class TritonTensorRTLLM(BaseLLM):
    """TRTLLM triton models.

    Arguments:
        server_url: (str) The URL of the Triton inference server to use.
        model_name: (str) The name of the Triton TRT model to use.
        temperature: (float) Temperature to use for sampling.
        top_p: (float) The top-p value to use for sampling.
        top_k: (int) The top-k value to use for sampling.
        beam_width: (int) The number of beams to use when generating.
        repetition_penalty: (float) The penalty to apply to repeated tokens.
        length_penalty: (float) The penalty to apply to the length of the output.
        tokens: (int) The maximum number of tokens to generate.
        client: The client object used to communicate with the inference server.

    Example:
        .. code-block:: python

            from langchain_nvidia_trt import TritonTensorRTLLM

            model = TritonTensorRTLLM()


    """

    server_url: Optional[str] = Field(None, alias="server_url")
    model_name: str = Field(
        ..., description="The name of the model to use, such as 'ensemble'."
    )
    ## Optional args for the model
    temperature: float = 1.0
    top_p: float = 0
    top_k: int = 1
    tokens: int = 100
    beam_width: int = 1
    repetition_penalty: float = 1.0
    length_penalty: float = 1.0
    client: grpcclient.InferenceServerClient
    stop: List[str] = Field(
        default_factory=lambda: ["</s>"], description="Stop tokens."
    )
    seed: int = Field(42, description="The seed to use for random generation.")
    load_model: bool = Field(
        True,
        description="Request the inference server to load the specified model.\
            Certain Triton configurations do not allow for this operation.",
    )

    def __del__(self):
        """Ensure the client streaming connection is properly shut down."""
        self.client.close()

    @root_validator(pre=True, allow_reuse=True)
    def validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Validate that python package exists in environment."""
        if not values.get("client"):
            values["client"] = grpcclient.InferenceServerClient(values["server_url"])
        return values

    @property
    def _llm_type(self) -> str:
        """Return type of LLM."""
        return "nvidia-trt-llm"

    @property
    def _model_default_parameters(self) -> Dict[str, Any]:
        return {
            "tokens": self.tokens,
            "top_k": self.top_k,
            "top_p": self.top_p,
            "temperature": self.temperature,
            "repetition_penalty": self.repetition_penalty,
            "length_penalty": self.length_penalty,
            "beam_width": self.beam_width,
        }

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get all the identifying parameters."""
        return {
            "server_url": self.server_url,
            "model_name": self.model_name,
            **self._model_default_parameters,
        }

    def _get_invocation_params(self, **kwargs: Any) -> Dict[str, Any]:
        return {**self._model_default_parameters, **kwargs}

    def get_model_list(self) -> List[str]:
        """Get a list of models loaded in the triton server."""
        res = self.client.get_model_repository_index(as_json=True)
        return [model["name"] for model in res["models"]]

    def _load_model(self, model_name: str, timeout: int = 1000) -> None:
        """Load a model into the server."""
        if self.client.is_model_ready(model_name):
            return

        self.client.load_model(model_name)
        t0 = time.perf_counter()
        t1 = t0
        # Poll until the model reports ready or the timeout elapses.
        while not self.client.is_model_ready(model_name) and t1 - t0 < timeout:
            t1 = time.perf_counter()

        if not self.client.is_model_ready(model_name):
            raise TritonTensorRTRuntimeError(
                f"Failed to load {model_name} on Triton in {timeout}s"
            )

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        self._load_model(self.model_name)

        invocation_params = self._get_invocation_params(**kwargs)
        stop_words = stop if stop is not None else self.stop
        generations = []
        # TODO: We should handle the native batching instead.
        for prompt in prompts:
            invoc_params = {**invocation_params, "prompt": [[prompt]]}
            result: str = self._request(
                self.model_name,
                stop=stop_words,
                **invoc_params,
            )
            generations.append([Generation(text=result, generation_info={})])
        return LLMResult(generations=generations)

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        self._load_model(self.model_name)

        invocation_params = self._get_invocation_params(**kwargs, prompt=[[prompt]])
        stop_words = stop if stop is not None else self.stop

        inputs = self._generate_inputs(stream=True, **invocation_params)
        outputs = self._generate_outputs()

        result_queue = self._invoke_triton(self.model_name, inputs, outputs, stop_words)

        for token in result_queue:
            yield GenerationChunk(text=token)
            if run_manager:
                run_manager.on_llm_new_token(token)

        self.client.stop_stream()

    ##### BELOW ARE METHODS PREVIOUSLY ONLY IN THE GRPC CLIENT

    def _request(
        self,
        model_name: str,
        prompt: Sequence[Sequence[str]],
        stop: Optional[List[str]] = None,
        **params: Any,
    ) -> str:
        """Request inferencing from the triton server."""
        # create model inputs and outputs
        inputs = self._generate_inputs(stream=False, prompt=prompt, **params)
        outputs = self._generate_outputs()

        result_queue = self._invoke_triton(self.model_name, inputs, outputs, stop)

        result_str = ""
        for token in result_queue:
            result_str += token

        self.client.stop_stream()

        return result_str

    def _invoke_triton(self, model_name, inputs, outputs, stop_words):
        if not self.client.is_model_ready(model_name):
            raise RuntimeError("Cannot request streaming, model is not loaded")

        request_id = str(random.randint(1, 9999999))  # nosec

        result_queue = StreamingResponseGenerator(
            self,
            request_id,
            force_batch=False,
            stop_words=stop_words,
        )

        self.client.start_stream(
            callback=partial(
                self._stream_callback,
                result_queue,
                stop_words=stop_words,
            )
        )

        # Even though this request may not be a streaming request, certain
        # configurations in Triton prevent the GRPC server from accepting
        # non-streaming connections. Therefore we call the streaming API and
        # combine the streamed results.
        self.client.async_stream_infer(
            model_name=model_name,
            inputs=inputs,
            outputs=outputs,
            request_id=request_id,
        )

        return result_queue

    def _generate_outputs(
        self,
    ) -> List[grpcclient.InferRequestedOutput]:
        """Generate the expected output structure."""
        return [grpcclient.InferRequestedOutput("text_output")]

    def _prepare_tensor(
        self, name: str, input_data: np.ndarray
    ) -> grpcclient.InferInput:
        """Prepare an input data structure."""

        t = grpcclient.InferInput(
            name, input_data.shape, np_to_triton_dtype(input_data.dtype)
        )
        t.set_data_from_numpy(input_data)
        return t

    def _generate_inputs(
        self,
        prompt: Sequence[Sequence[str]],
        tokens: int = 300,
        temperature: float = 1.0,
        top_k: float = 1,
        top_p: float = 0,
        beam_width: int = 1,
        repetition_penalty: float = 1,
        length_penalty: float = 1.0,
        stream: bool = True,
    ) -> List[grpcclient.InferInput]:
        """Create the input for the triton inference server."""
        query = np.array(prompt).astype(object)
        request_output_len = np.array([tokens]).astype(np.uint32).reshape((1, -1))
        runtime_top_k = np.array([top_k]).astype(np.uint32).reshape((1, -1))
        runtime_top_p = np.array([top_p]).astype(np.float32).reshape((1, -1))
        temperature_array = np.array([temperature]).astype(np.float32).reshape((1, -1))
        len_penalty = np.array([length_penalty]).astype(np.float32).reshape((1, -1))
        repetition_penalty_array = (
            np.array([repetition_penalty]).astype(np.float32).reshape((1, -1))
        )
        random_seed = np.array([self.seed]).astype(np.uint64).reshape((1, -1))
        beam_width_array = np.array([beam_width]).astype(np.uint32).reshape((1, -1))
        streaming_data = np.array([[stream]], dtype=bool)

        inputs = [
            self._prepare_tensor("text_input", query),
            self._prepare_tensor("max_tokens", request_output_len),
            self._prepare_tensor("top_k", runtime_top_k),
            self._prepare_tensor("top_p", runtime_top_p),
            self._prepare_tensor("temperature", temperature_array),
            self._prepare_tensor("length_penalty", len_penalty),
            self._prepare_tensor("repetition_penalty", repetition_penalty_array),
            self._prepare_tensor("random_seed", random_seed),
            self._prepare_tensor("beam_width", beam_width_array),
            self._prepare_tensor("stream", streaming_data),
        ]
        return inputs

    def _send_stop_signals(self, model_name: str, request_id: str) -> None:
        """Send the stop signal to the Triton Inference server."""
        stop_inputs = self._generate_stop_signals()
        self.client.async_stream_infer(
            model_name,
            stop_inputs,
            request_id=request_id,
            parameters={"Streaming": True},
        )

    def _generate_stop_signals(
        self,
    ) -> List[grpcclient.InferInput]:
        """Generate the signal to stop the stream."""
        inputs = [
            grpcclient.InferInput("input_ids", [1, 1], "INT32"),
            grpcclient.InferInput("input_lengths", [1, 1], "INT32"),
            grpcclient.InferInput("request_output_len", [1, 1], "UINT32"),
            grpcclient.InferInput("stop", [1, 1], "BOOL"),
        ]
        inputs[0].set_data_from_numpy(np.empty([1, 1], dtype=np.int32))
        inputs[1].set_data_from_numpy(np.zeros([1, 1], dtype=np.int32))
        inputs[2].set_data_from_numpy(np.array([[0]], dtype=np.uint32))
        inputs[3].set_data_from_numpy(np.array([[True]], dtype="bool"))
        return inputs

    @staticmethod
    def _process_result(result: Dict[str, str]) -> str:
        """Post-process the result from the server."""

        message = ModelInferResponse()
        google.protobuf.json_format.Parse(json.dumps(result), message)
        infer_result = grpcclient.InferResult(message)
        np_res = infer_result.as_numpy("text_output")

        generated_text = ""
        if np_res is not None:
            generated_text = "".join([token.decode() for token in np_res])

        return generated_text

    def _stream_callback(
        self,
        result_queue: queue.Queue[Union[Optional[Dict[str, str]], str]],
        result: grpcclient.InferResult,
        error: str,
        stop_words: List[str],
    ) -> None:
        """Add streamed result to queue."""
        if error:
            result_queue.put(error)
        else:
            response_raw: dict = result.get_response(as_json=True)
            # TODO: Check the response is a map rather than a string
            if "outputs" in response_raw:
                # the very last response might have no output, just the final flag
                response = self._process_result(response_raw)

                if response in stop_words:
                    result_queue.put(None)
                else:
                    result_queue.put(response)

            if response_raw["parameters"]["triton_final_response"]["bool_param"]:
                # end of the generation
                result_queue.put(None)

    def stop_stream(
        self, model_name: str, request_id: str, signal: bool = True
    ) -> None:
        """Close the streaming connection."""
        if signal:
            self._send_stop_signals(model_name, request_id)
        self.client.stop_stream()


class StreamingResponseGenerator(queue.Queue):
    """A Generator that provides the inference results from an LLM."""

    def __init__(
        self,
        client: TritonTensorRTLLM,
        request_id: str,
        force_batch: bool,
        stop_words: Sequence[str],
    ) -> None:
        """Instantiate the generator class."""
        super().__init__()
        self.client = client
        self.request_id = request_id
        self._batch = force_batch
        self._stop_words = stop_words

    def __iter__(self) -> StreamingResponseGenerator:
        """Return self as a generator."""
        return self

    def __next__(self) -> str:
        """Return the next retrieved token."""
        val = self.get()
        if val is None or val in self._stop_words:
            self.client.stop_stream(
                "tensorrt_llm", self.request_id, signal=not self._batch
            )
            raise StopIteration()
        return val
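For reference, a hedged construction sketch showing the sampling-related fields the class above exposes; the server address, model name, and the specific values are illustrative assumptions, not part of this commit:

from langchain_nvidia_trt.llms import TritonTensorRTLLM

# Each keyword maps to a pydantic field defined on TritonTensorRTLLM above;
# the values are examples only.
llm = TritonTensorRTLLM(
    server_url="localhost:8001",  # assumed Triton gRPC endpoint
    model_name="ensemble",        # assumed TRT-LLM model name
    tokens=250,                   # max tokens to generate
    temperature=0.7,
    top_k=40,
    top_p=0.9,
    beam_width=1,
    repetition_penalty=1.1,
    length_penalty=1.0,
    seed=42,
    stop=["</s>"],
)

print(llm.invoke("Summarize what Triton Inference Server does."))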
libs/partners/nvidia-trt/mypy.ini (new file, +4)
@@ -0,0 +1,4 @@
[mypy]
# Empty global config
[mypy-tritonclient.*]
ignore_missing_imports = True
libs/partners/nvidia-trt/poetry.lock (generated, new file, +2148)
File diff suppressed because it is too large.
libs/partners/nvidia-trt/pyproject.toml (new file, +90)
@@ -0,0 +1,90 @@
[tool.poetry]
name = "langchain-nvidia-trt"
version = "0.0.1"
description = "An integration package connecting TritonTensorRT and LangChain"
authors = []
readme = "README.md"

[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
langchain-core = ">=0.0.12"
tritonclient = { extras = ["all"], version = "^2.40.0" }
lint = "^1.2.1"
types-protobuf = "^4.24.0.4"
protobuf = "^3.5.0"

[tool.poetry.group.test]
optional = true

[tool.poetry.group.test.dependencies]
pytest = "^7.3.0"
freezegun = "^1.2.2"
pytest-mock = "^3.10.0"
syrupy = "^4.0.2"
pytest-watcher = "^0.3.4"
pytest-asyncio = "^0.21.1"
langchain-core = { path = "../../core", develop = true }

[tool.poetry.group.codespell]
optional = true

[tool.poetry.group.codespell.dependencies]
codespell = "^2.2.0"

[tool.poetry.group.test_integration]
optional = true

[tool.poetry.group.test_integration.dependencies]

[tool.poetry.group.lint]
optional = true

[tool.poetry.group.lint.dependencies]
ruff = "^0.1.5"

[tool.poetry.group.typing.dependencies]
mypy = "^0.991"
langchain-core = { path = "../../core", develop = true }

[tool.poetry.group.dev]
optional = true

[tool.poetry.group.dev.dependencies]
langchain-core = { path = "../../core", develop = true }

[tool.ruff]
select = [
  "E",  # pycodestyle
  "F",  # pyflakes
  "I",  # isort
]

[tool.mypy]
disallow_untyped_defs = "True"

[tool.coverage.run]
omit = ["tests/*"]

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

[tool.pytest.ini_options]
# --strict-markers will raise errors on unknown marks.
# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
#
# https://docs.pytest.org/en/7.1.x/reference/reference.html
# --strict-config       any warnings encountered while parsing the `pytest`
#                       section of the configuration file raise errors.
#
# https://github.com/tophat/syrupy
# --snapshot-warn-unused    Prints a warning on unused snapshots rather than fail the test suite.
addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5"
# Registering custom markers.
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
markers = [
  "requires: mark tests as requiring a specific library",
  "asyncio: mark tests as requiring asyncio",
  "compile: mark placeholder test used to compile integration tests without running them",
]
asyncio_mode = "auto"
libs/partners/nvidia-trt/scripts/check_imports.py (new file, +17)
@@ -0,0 +1,17 @@
import sys
import traceback
from importlib.machinery import SourceFileLoader

if __name__ == "__main__":
    files = sys.argv[1:]
    has_failure = False
    for file in files:
        try:
            SourceFileLoader("x", file).load_module()
        except Exception:
            has_failure = True
            print(file)
            traceback.print_exc()
            print()

    sys.exit(1 if has_failure else 0)
libs/partners/nvidia-trt/scripts/check_pydantic.sh (executable, new file, +27)
@@ -0,0 +1,27 @@
#!/bin/bash
#
# This script searches for lines starting with "import pydantic" or "from pydantic"
# in tracked files within a Git repository.
#
# Usage: ./scripts/check_pydantic.sh /path/to/repository

# Check if a path argument is provided
if [ $# -ne 1 ]; then
  echo "Usage: $0 /path/to/repository"
  exit 1
fi

repository_path="$1"

# Search for lines matching the pattern within the specified repository
result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic')

# Check if any matching lines were found
if [ -n "$result" ]; then
  echo "ERROR: The following lines need to be updated:"
  echo "$result"
  echo "Please replace the code with an import from langchain_core.pydantic_v1."
  echo "For example, replace 'from pydantic import BaseModel'"
  echo "with 'from langchain_core.pydantic_v1 import BaseModel'"
  exit 1
fi
libs/partners/nvidia-trt/scripts/lint_imports.sh (executable, new file, +17)
@@ -0,0 +1,17 @@
#!/bin/bash

set -eu

# Initialize a variable to keep track of errors
errors=0

# make sure not importing from langchain or langchain_experimental
git --no-pager grep '^from langchain\.' . && errors=$((errors+1))
git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1))

# Decide on an exit status based on the errors
if [ "$errors" -gt 0 ]; then
  exit 1
else
  exit 0
fi
libs/partners/nvidia-trt/tests/__init__.py (new empty file)

libs/partners/nvidia-trt/tests/integration_tests/test_compile.py (new file, +7)
@@ -0,0 +1,7 @@
import pytest


@pytest.mark.compile
def test_placeholder() -> None:
    """Used for compiling integration tests without running any real tests."""
    pass
libs/partners/nvidia-trt/tests/integration_tests/test_llms.py (new file, +74)
@@ -0,0 +1,74 @@
"""Test TritonTensorRTLLM llm."""
import pytest

from langchain_nvidia_trt.llms import TritonTensorRTLLM

_MODEL_NAME = "ensemble"


@pytest.mark.skip(reason="Need a working Triton server")
def test_stream() -> None:
    """Test streaming tokens from TritonTensorRTLLM."""
    llm = TritonTensorRTLLM(model_name=_MODEL_NAME)

    for token in llm.stream("I'm Pickle Rick"):
        assert isinstance(token, str)


@pytest.mark.skip(reason="Need a working Triton server")
async def test_astream() -> None:
    """Test streaming tokens from TritonTensorRTLLM."""
    llm = TritonTensorRTLLM(model_name=_MODEL_NAME)

    async for token in llm.astream("I'm Pickle Rick"):
        assert isinstance(token, str)


@pytest.mark.skip(reason="Need a working Triton server")
async def test_abatch() -> None:
    """Test batch tokens from TritonTensorRTLLM."""
    llm = TritonTensorRTLLM(model_name=_MODEL_NAME)

    result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
    for token in result:
        assert isinstance(token, str)


@pytest.mark.skip(reason="Need a working Triton server")
async def test_abatch_tags() -> None:
    """Test batch tokens from TritonTensorRTLLM."""
    llm = TritonTensorRTLLM(model_name=_MODEL_NAME)

    result = await llm.abatch(
        ["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
    )
    for token in result:
        assert isinstance(token, str)


@pytest.mark.skip(reason="Need a working Triton server")
def test_batch() -> None:
    """Test batch tokens from TritonTensorRTLLM."""
    llm = TritonTensorRTLLM(model_name=_MODEL_NAME)

    result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
    for token in result:
        assert isinstance(token, str)


@pytest.mark.skip(reason="Need a working Triton server")
async def test_ainvoke() -> None:
    """Test invoke tokens from TritonTensorRTLLM."""
    llm = TritonTensorRTLLM(model_name=_MODEL_NAME)

    result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
    assert isinstance(result, str)


@pytest.mark.skip(reason="Need a working Triton server")
def test_invoke() -> None:
    """Test invoke tokens from TritonTensorRTLLM."""
    llm = TritonTensorRTLLM(model_name=_MODEL_NAME)

    result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
    assert isinstance(result, str)
libs/partners/nvidia-trt/tests/unit_tests/test_imports.py (new file, +7)
@@ -0,0 +1,7 @@
from langchain_nvidia_trt import __all__

EXPECTED_ALL = ["TritonTensorRTLLM"]


def test_all_imports() -> None:
    assert sorted(EXPECTED_ALL) == sorted(__all__)
libs/partners/nvidia-trt/tests/unit_tests/test_llms.py (new file, +7)
@@ -0,0 +1,7 @@
"""Test TritonTensorRT Chat API wrapper."""
from langchain_nvidia_trt import TritonTensorRTLLM


def test_initialization() -> None:
    """Test integration initialization."""
    TritonTensorRTLLM(model_name="ensemble", server_url="http://localhost:8001")