feat(llms): support vLLM's OpenAI-compatible server (#9179)

This PR aims at supporting [vLLM's OpenAI-compatible server
feature](https://vllm.readthedocs.io/en/latest/getting_started/quickstart.html#openai-compatible-server),
i.e. allowing vLLM's LLMs to be called as if they were OpenAI's.

I've also updated the related notebook to provide an example usage. At
the moment, vLLM only supports the `Completion` API.
This commit is contained in:
Massimiliano Pronesti
2023-08-14 08:03:05 +02:00
committed by GitHub
parent 621da3c164
commit d95eeaedbe
3 changed files with 73 additions and 1 deletions

View File

@@ -80,7 +80,7 @@ from langchain.llms.textgen import TextGen
from langchain.llms.titan_takeoff import TitanTakeoff
from langchain.llms.tongyi import Tongyi
from langchain.llms.vertexai import VertexAI
from langchain.llms.vllm import VLLM
from langchain.llms.vllm import VLLM, VLLMOpenAI
from langchain.llms.writer import Writer
from langchain.llms.xinference import Xinference
@@ -149,6 +149,7 @@ __all__ = [
"Tongyi",
"VertexAI",
"VLLM",
"VLLMOpenAI",
"Writer",
"OctoAIEndpoint",
"Xinference",
@@ -213,6 +214,7 @@ type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
"openllm": OpenLLM,
"openllm_client": OpenLLM,
"vllm": VLLM,
"vllm_openai": VLLMOpenAI,
"writer": Writer,
"xinference": Xinference,
}

View File

@@ -4,6 +4,7 @@ from pydantic import root_validator
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import BaseLLM
from langchain.llms.openai import BaseOpenAI
from langchain.schema.output import Generation, LLMResult
@@ -127,3 +128,27 @@ class VLLM(BaseLLM):
def _llm_type(self) -> str:
"""Return type of llm."""
return "vllm"
class VLLMOpenAI(BaseOpenAI):
    """vLLM OpenAI-compatible API client.

    Reuses the stock OpenAI client machinery from ``BaseOpenAI`` while
    pointing it at a vLLM server that speaks the OpenAI ``Completion`` API.
    """

    @property
    def _invocation_params(self) -> Dict[str, Any]:
        """Get the parameters used to invoke the model."""
        # Start from the model name, then layer credentials and defaults.
        params: Dict[str, Any] = {"model": self.model_name}
        params["api_key"] = self.openai_api_key
        params["api_base"] = self.openai_api_base
        params.update(self._default_params)
        # Deliberately overrides any default: logit_bias is always sent as
        # None when talking to the vLLM endpoint (presumably unsupported
        # server-side — confirm against vLLM's API docs).
        params["logit_bias"] = None
        return params

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "vllm-openai"