Add Javelin integration (#10275)

We are introducing the Python integration for the Javelin AI Gateway
(www.getjavelin.io). Javelin is an enterprise-scale, fast LLM router and
gateway. Please review and let us know if anything is missing.

The Javelin AI Gateway wraps Embedding, Chat, and Completion LLMs, using
`javelin_sdk` under the covers (`pip install javelin_sdk`).
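
For quick orientation, here is a minimal sketch of the three integration points added in this PR. This is illustrative only: the gateway URI and route names are placeholders, and it assumes the corresponding routes are already configured on a running gateway.

```python
from langchain.chat_models import ChatJavelinAIGateway
from langchain.embeddings import JavelinAIGatewayEmbeddings
from langchain.llms import JavelinAIGateway

# All three wrappers send requests to a configured Javelin route via javelin_sdk.
chat = ChatJavelinAIGateway(
    gateway_uri="http://localhost:8000",  # placeholder gateway URI
    route="my-chat-route",                # placeholder route name
    params={"temperature": 0.1},
)
embeddings = JavelinAIGatewayEmbeddings(
    gateway_uri="http://localhost:8000",
    route="my-embeddings-route",
)
llm = JavelinAIGateway(
    gateway_uri="http://localhost:8000",
    route="my-completions-route",
)
```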

Author: Sharath Rajasekar, Twitter: @sharathr, @javelinai

Thanks!!
Commit 96023f94d9 (parent 957956ba6d), authored by Sharath Rajasekar on 2023-09-20 16:36:39 -07:00 and committed via GitHub.
9 changed files with 837 additions and 0 deletions.

View file: langchain/chat_models/__init__.py

@@ -26,6 +26,7 @@ from langchain.chat_models.ernie import ErnieBotChat
from langchain.chat_models.fake import FakeListChatModel
from langchain.chat_models.google_palm import ChatGooglePalm
from langchain.chat_models.human import HumanInputChatModel
from langchain.chat_models.javelin_ai_gateway import ChatJavelinAIGateway
from langchain.chat_models.jinachat import JinaChat
from langchain.chat_models.konko import ChatKonko
from langchain.chat_models.litellm import ChatLiteLLM
@@ -53,6 +54,7 @@ __all__ = [
"ChatAnyscale",
"ChatLiteLLM",
"ErnieBotChat",
"ChatJavelinAIGateway",
"ChatKonko",
"QianfanChatEndpoint",
]

View file: langchain/chat_models/javelin_ai_gateway.py (new file)

@@ -0,0 +1,223 @@
import logging
from typing import Any, Dict, List, Mapping, Optional
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.chat_models.base import BaseChatModel
from langchain.pydantic_v1 import BaseModel, Extra
from langchain.schema import (
ChatGeneration,
ChatResult,
)
from langchain.schema.messages import (
AIMessage,
BaseMessage,
ChatMessage,
FunctionMessage,
HumanMessage,
SystemMessage,
)
logger = logging.getLogger(__name__)
# Ignoring type because below is valid pydantic code
# Unexpected keyword argument "extra" for "__init_subclass__" of "object" [call-arg]
class ChatParams(BaseModel, extra=Extra.allow): # type: ignore[call-arg]
"""Parameters for the `Javelin AI Gateway` LLM."""
temperature: float = 0.0
stop: Optional[List[str]] = None
max_tokens: Optional[int] = None
class ChatJavelinAIGateway(BaseChatModel):
"""`Javelin AI Gateway` chat models API.
To use, you should have the ``javelin_sdk`` python package installed.
For more information, see https://docs.getjavelin.io
Example:
.. code-block:: python
from langchain.chat_models import ChatJavelinAIGateway
chat = ChatJavelinAIGateway(
gateway_uri="<javelin-ai-gateway-uri>",
route="<javelin-ai-gateway-chat-route>",
params={
"temperature": 0.1
}
)
"""
route: str
"""The route to use for the Javelin AI Gateway API."""
gateway_uri: Optional[str] = None
"""The URI for the Javelin AI Gateway API."""
params: Optional[ChatParams] = None
"""Parameters for the Javelin AI Gateway LLM."""
client: Any
"""javelin client."""
javelin_api_key: Optional[str] = None
"""The API key for the Javelin AI Gateway."""
def __init__(self, **kwargs: Any):
try:
from javelin_sdk import (
JavelinClient,
UnauthorizedError,
)
except ImportError:
raise ImportError(
"Could not import javelin_sdk python package. "
"Please install it with `pip install javelin_sdk`."
)
super().__init__(**kwargs)
if self.gateway_uri:
try:
self.client = JavelinClient(
base_url=self.gateway_uri, api_key=self.javelin_api_key
)
except UnauthorizedError as e:
raise ValueError("Javelin: Incorrect API Key.") from e
@property
def _default_params(self) -> Dict[str, Any]:
params: Dict[str, Any] = {
"gateway_uri": self.gateway_uri,
"javelin_api_key": self.javelin_api_key,
"route": self.route,
**(self.params.dict() if self.params else {}),
}
return params
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
message_dicts = [
ChatJavelinAIGateway._convert_message_to_dict(message)
for message in messages
]
data: Dict[str, Any] = {
"messages": message_dicts,
**(self.params.dict() if self.params else {}),
}
resp = self.client.query_route(self.route, query_body=data)
return ChatJavelinAIGateway._create_chat_result(resp.dict())
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
message_dicts = [
ChatJavelinAIGateway._convert_message_to_dict(message)
for message in messages
]
data: Dict[str, Any] = {
"messages": message_dicts,
**(self.params.dict() if self.params else {}),
}
resp = await self.client.aquery_route(self.route, query_body=data)
return ChatJavelinAIGateway._create_chat_result(resp.dict())
@property
def _identifying_params(self) -> Dict[str, Any]:
return self._default_params
def _get_invocation_params(
self, stop: Optional[List[str]] = None, **kwargs: Any
) -> Dict[str, Any]:
"""Get the parameters used to invoke the model FOR THE CALLBACKS."""
return {
**self._default_params,
**super()._get_invocation_params(stop=stop, **kwargs),
}
@property
def _llm_type(self) -> str:
"""Return type of chat model."""
return "javelin-ai-gateway-chat"
@staticmethod
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
role = _dict["role"]
content = _dict["content"]
if role == "user":
return HumanMessage(content=content)
elif role == "assistant":
return AIMessage(content=content)
elif role == "system":
return SystemMessage(content=content)
else:
return ChatMessage(content=content, role=role)
@staticmethod
def _raise_functions_not_supported() -> None:
raise ValueError(
"Function messages are not supported by the Javelin AI Gateway. Please"
" create a feature request at https://docs.getjavelin.io"
)
@staticmethod
def _convert_message_to_dict(message: BaseMessage) -> dict:
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
elif isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, FunctionMessage):
raise ValueError(
"Function messages are not supported by the Javelin AI Gateway. Please"
" create a feature request at https://docs.getjavelin.io"
)
else:
raise ValueError(f"Got unknown message type: {message}")
if "function_call" in message.additional_kwargs:
ChatJavelinAIGateway._raise_functions_not_supported()
if message.additional_kwargs:
logger.warning(
"Additional message arguments are unsupported by Javelin AI Gateway "
" and will be ignored: %s",
message.additional_kwargs,
)
return message_dict
@staticmethod
def _create_chat_result(response: Mapping[str, Any]) -> ChatResult:
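        """Convert a Javelin route response into a LangChain ``ChatResult``."""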
generations = []
for candidate in response["llm_response"]["choices"]:
message = ChatJavelinAIGateway._convert_dict_to_message(
candidate["message"]
)
message_metadata = candidate.get("metadata", {})
gen = ChatGeneration(
message=message,
generation_info=dict(message_metadata),
)
generations.append(gen)
response_metadata = response.get("metadata", {})
return ChatResult(generations=generations, llm_output=response_metadata)
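
For reviewers, a short usage sketch of the chat wrapper above. The URI and route are placeholders, and the callable interface comes from ``BaseChatModel``:

```python
from langchain.chat_models import ChatJavelinAIGateway
from langchain.schema.messages import HumanMessage, SystemMessage

chat = ChatJavelinAIGateway(
    gateway_uri="http://localhost:8000",  # placeholder
    route="my-chat-route",                # placeholder
)
# BaseChatModel makes the instance callable with a list of messages;
# the call returns the assistant's reply message.
reply = chat(
    [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="What does an LLM gateway do?"),
    ]
)
print(reply.content)
```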

View file: langchain/embeddings/__init__.py

@@ -40,6 +40,7 @@ from langchain.embeddings.huggingface import (
HuggingFaceInstructEmbeddings,
)
from langchain.embeddings.huggingface_hub import HuggingFaceHubEmbeddings
from langchain.embeddings.javelin_ai_gateway import JavelinAIGatewayEmbeddings
from langchain.embeddings.jina import JinaEmbeddings
from langchain.embeddings.llamacpp import LlamaCppEmbeddings
from langchain.embeddings.localai import LocalAIEmbeddings
@@ -107,6 +108,7 @@ __all__ = [
"AwaEmbeddings",
"HuggingFaceBgeEmbeddings",
"ErnieEmbeddings",
"JavelinAIGatewayEmbeddings",
"OllamaEmbeddings",
"QianfanEmbeddingsEndpoint",
]

View file: langchain/embeddings/javelin_ai_gateway.py (new file)

@@ -0,0 +1,110 @@
from __future__ import annotations

import logging
from typing import Any, Iterator, List, Optional

from langchain.pydantic_v1 import BaseModel
from langchain.schema.embeddings import Embeddings

logger = logging.getLogger(__name__)
def _chunk(texts: List[str], size: int) -> Iterator[List[str]]:
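    """Yield successive batches of ``size`` items from ``texts``."""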
for i in range(0, len(texts), size):
yield texts[i : i + size]
class JavelinAIGatewayEmbeddings(Embeddings, BaseModel):
"""
Wrapper around embeddings LLMs in the Javelin AI Gateway.
To use, you should have the ``javelin_sdk`` python package installed.
For more information, see https://docs.getjavelin.io
Example:
.. code-block:: python
from langchain.embeddings import JavelinAIGatewayEmbeddings
embeddings = JavelinAIGatewayEmbeddings(
gateway_uri="<javelin-ai-gateway-uri>",
route="<your-javelin-gateway-embeddings-route>"
)
"""
client: Any
"""javelin client."""
route: str
"""The route to use for the Javelin AI Gateway API."""
gateway_uri: Optional[str] = None
"""The URI for the Javelin AI Gateway API."""
javelin_api_key: Optional[str] = None
"""The API key for the Javelin AI Gateway API."""
def __init__(self, **kwargs: Any):
try:
from javelin_sdk import (
JavelinClient,
UnauthorizedError,
)
except ImportError:
raise ImportError(
"Could not import javelin_sdk python package. "
"Please install it with `pip install javelin_sdk`."
)
super().__init__(**kwargs)
if self.gateway_uri:
try:
self.client = JavelinClient(
base_url=self.gateway_uri, api_key=self.javelin_api_key
)
except UnauthorizedError as e:
raise ValueError("Javelin: Incorrect API Key.") from e
def _query(self, texts: List[str]) -> List[List[float]]:
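        """Embed ``texts`` by querying the gateway route, 20 inputs per request."""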
embeddings = []
for txt in _chunk(texts, 20):
try:
resp = self.client.query_route(self.route, query_body={"input": txt})
resp_dict = resp.dict()
embeddings_chunk = resp_dict.get("llm_response", {}).get("data", [])
for item in embeddings_chunk:
if "embedding" in item:
embeddings.append(item["embedding"])
            except ValueError as e:
                # Log and continue so a single failed batch doesn't drop the rest.
                logger.warning("Failed to query route: %s", e)
return embeddings
async def _aquery(self, texts: List[str]) -> List[List[float]]:
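        """Async variant of ``_query``: embed ``texts`` via the gateway route."""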
embeddings = []
for txt in _chunk(texts, 20):
try:
resp = await self.client.aquery_route(
self.route, query_body={"input": txt}
)
resp_dict = resp.dict()
embeddings_chunk = resp_dict.get("llm_response", {}).get("data", [])
for item in embeddings_chunk:
if "embedding" in item:
embeddings.append(item["embedding"])
            except ValueError as e:
                # Log and continue so a single failed batch doesn't drop the rest.
                logger.warning("Failed to query route: %s", e)
return embeddings
def embed_documents(self, texts: List[str]) -> List[List[float]]:
return self._query(texts)
def embed_query(self, text: str) -> List[float]:
return self._query([text])[0]
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
return await self._aquery(texts)
async def aembed_query(self, text: str) -> List[float]:
result = await self._aquery([text])
return result[0]
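
A short usage sketch of the embeddings wrapper above; the URI and route are placeholders:

```python
from langchain.embeddings import JavelinAIGatewayEmbeddings

embeddings = JavelinAIGatewayEmbeddings(
    gateway_uri="http://localhost:8000",  # placeholder
    route="my-embeddings-route",          # placeholder
)
# Documents are embedded in batches of 20; a query is a single-item batch.
doc_vectors = embeddings.embed_documents(["first document", "second document"])
query_vector = embeddings.embed_query("a search query")
```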

View file: langchain/llms/__init__.py

@@ -54,6 +54,7 @@ from langchain.llms.huggingface_hub import HuggingFaceHub
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.llms.huggingface_text_gen_inference import HuggingFaceTextGenInference
from langchain.llms.human import HumanInputLLM
from langchain.llms.javelin_ai_gateway import JavelinAIGateway
from langchain.llms.koboldai import KoboldApiLLM
from langchain.llms.llamacpp import LlamaCpp
from langchain.llms.manifest import ManifestWrapper
@@ -161,6 +162,7 @@ __all__ = [
"Writer",
"OctoAIEndpoint",
"Xinference",
"JavelinAIGateway",
"QianfanLLMEndpoint",
]
@@ -230,5 +232,6 @@ type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
"vllm_openai": VLLMOpenAI,
"writer": Writer,
"xinference": Xinference,
"javelin-ai-gateway": JavelinAIGateway,
"qianfan_endpoint": QianfanLLMEndpoint,
}

View file: langchain/llms/javelin_ai_gateway.py (new file)

@@ -0,0 +1,152 @@
from __future__ import annotations
from typing import Any, Dict, List, Mapping, Optional
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.llms.base import LLM
from langchain.pydantic_v1 import BaseModel, Extra
# Ignoring type because below is valid pydantic code
# Unexpected keyword argument "extra" for "__init_subclass__" of "object"
class Params(BaseModel, extra=Extra.allow): # type: ignore[call-arg]
"""Parameters for the Javelin AI Gateway LLM."""
temperature: float = 0.0
stop: Optional[List[str]] = None
max_tokens: Optional[int] = None
class JavelinAIGateway(LLM):
"""
Wrapper around completions LLMs in the Javelin AI Gateway.
To use, you should have the ``javelin_sdk`` python package installed.
For more information, see https://docs.getjavelin.io
Example:
.. code-block:: python
from langchain.llms import JavelinAIGateway
completions = JavelinAIGateway(
gateway_uri="<your-javelin-ai-gateway-uri>",
route="<your-javelin-ai-gateway-completions-route>",
params={
"temperature": 0.1
}
)
"""
route: str
"""The route to use for the Javelin AI Gateway API."""
client: Optional[Any] = None
"""The Javelin AI Gateway client."""
gateway_uri: Optional[str] = None
"""The URI of the Javelin AI Gateway API."""
params: Optional[Params] = None
"""Parameters for the Javelin AI Gateway API."""
javelin_api_key: Optional[str] = None
"""The API key for the Javelin AI Gateway API."""
def __init__(self, **kwargs: Any):
try:
from javelin_sdk import (
JavelinClient,
UnauthorizedError,
)
except ImportError:
raise ImportError(
"Could not import javelin_sdk python package. "
"Please install it with `pip install javelin_sdk`."
)
super().__init__(**kwargs)
if self.gateway_uri:
try:
self.client = JavelinClient(
base_url=self.gateway_uri, api_key=self.javelin_api_key
)
except UnauthorizedError as e:
raise ValueError("Javelin: Incorrect API Key.") from e
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling Javelin AI Gateway API."""
params: Dict[str, Any] = {
"gateway_uri": self.gateway_uri,
"route": self.route,
"javelin_api_key": self.javelin_api_key,
**(self.params.dict() if self.params else {}),
}
return params
@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""
return self._default_params
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call the Javelin AI Gateway API."""
data: Dict[str, Any] = {
"prompt": prompt,
**(self.params.dict() if self.params else {}),
}
if s := (stop or (self.params.stop if self.params else None)):
data["stop"] = s
if self.client is not None:
resp = self.client.query_route(self.route, query_body=data)
else:
raise ValueError("Javelin client is not initialized.")
resp_dict = resp.dict()
try:
return resp_dict["llm_response"]["choices"][0]["text"]
except KeyError:
return ""
async def _acall(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call async the Javelin AI Gateway API."""
data: Dict[str, Any] = {
"prompt": prompt,
**(self.params.dict() if self.params else {}),
}
if s := (stop or (self.params.stop if self.params else None)):
data["stop"] = s
if self.client is not None:
resp = await self.client.aquery_route(self.route, query_body=data)
else:
raise ValueError("Javelin client is not initialized.")
resp_dict = resp.dict()
try:
return resp_dict["llm_response"]["choices"][0]["text"]
except KeyError:
return ""
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "javelin-ai-gateway"
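
And a matching sketch for the completions wrapper; again, the URI and route are placeholders, and the callable interface comes from the base ``LLM`` class:

```python
from langchain.llms import JavelinAIGateway

llm = JavelinAIGateway(
    gateway_uri="http://localhost:8000",  # placeholder
    route="my-completions-route",         # placeholder
    params={"temperature": 0.1},
)
# The base LLM class makes the instance callable with a prompt string.
print(llm("Say hello from the Javelin AI Gateway."))
```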