mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-11 07:50:47 +00:00
Add Javelin integration (#10275)
We are introducing the Python integration for the Javelin AI Gateway (www.getjavelin.io). Javelin is an enterprise-scale, fast LLM router and gateway. Could you please review and let us know if anything is missing? The Javelin AI Gateway integration wraps Embedding, Chat, and Completion LLMs, and uses javelin_sdk under the covers (pip install javelin_sdk). Author: Sharath Rajasekar, Twitter: @sharathr, @javelinai. Thanks!
This commit is contained in:
committed by
GitHub
parent
957956ba6d
commit
96023f94d9
@@ -26,6 +26,7 @@ from langchain.chat_models.ernie import ErnieBotChat
|
||||
from langchain.chat_models.fake import FakeListChatModel
|
||||
from langchain.chat_models.google_palm import ChatGooglePalm
|
||||
from langchain.chat_models.human import HumanInputChatModel
|
||||
from langchain.chat_models.javelin_ai_gateway import ChatJavelinAIGateway
|
||||
from langchain.chat_models.jinachat import JinaChat
|
||||
from langchain.chat_models.konko import ChatKonko
|
||||
from langchain.chat_models.litellm import ChatLiteLLM
|
||||
@@ -53,6 +54,7 @@ __all__ = [
|
||||
"ChatAnyscale",
|
||||
"ChatLiteLLM",
|
||||
"ErnieBotChat",
|
||||
"ChatJavelinAIGateway",
|
||||
"ChatKonko",
|
||||
"QianfanChatEndpoint",
|
||||
]
|
||||
|
223
libs/langchain/langchain/chat_models/javelin_ai_gateway.py
Normal file
223
libs/langchain/langchain/chat_models/javelin_ai_gateway.py
Normal file
@@ -0,0 +1,223 @@
|
||||
import logging
|
||||
from typing import Any, Dict, List, Mapping, Optional
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.pydantic_v1 import BaseModel, Extra
|
||||
from langchain.schema import (
|
||||
ChatGeneration,
|
||||
ChatResult,
|
||||
)
|
||||
from langchain.schema.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
ChatMessage,
|
||||
FunctionMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChatParams(BaseModel):
    """Request parameters sent along with a `Javelin AI Gateway` chat query.

    Keys that are not declared below are accepted as well, because the
    model is configured with ``Extra.allow``.
    """

    class Config:
        """Pydantic configuration for this model."""

        extra = Extra.allow

    temperature: float = 0.0
    stop: Optional[List[str]] = None
    max_tokens: Optional[int] = None
|
||||
|
||||
|
||||
class ChatJavelinAIGateway(BaseChatModel):
    """`Javelin AI Gateway` chat models API.

    To use, you should have the ``javelin_sdk`` python package installed.
    For more information, see https://docs.getjavelin.io

    Example:
        .. code-block:: python

            from langchain.chat_models import ChatJavelinAIGateway

            chat = ChatJavelinAIGateway(
                gateway_uri="<javelin-ai-gateway-uri>",
                route="<javelin-ai-gateway-chat-route>",
                params={
                    "temperature": 0.1
                }
            )
    """

    route: str
    """The route to use for the Javelin AI Gateway API."""

    gateway_uri: Optional[str] = None
    """The URI for the Javelin AI Gateway API."""

    params: Optional[ChatParams] = None
    """Parameters for the Javelin AI Gateway LLM."""

    client: Any
    """javelin client."""

    javelin_api_key: Optional[str] = None
    """The API key for the Javelin AI Gateway."""

    def __init__(self, **kwargs: Any):
        """Validate the fields and, if ``gateway_uri`` is set, build the client.

        Raises:
            ImportError: If the ``javelin_sdk`` package cannot be imported.
            ValueError: If ``JavelinClient`` raises ``UnauthorizedError``.
        """
        try:
            from javelin_sdk import (
                JavelinClient,
                UnauthorizedError,
            )
        except ImportError as e:
            raise ImportError(
                "Could not import javelin_sdk python package. "
                "Please install it with `pip install javelin_sdk`."
            ) from e

        super().__init__(**kwargs)
        if self.gateway_uri:
            try:
                self.client = JavelinClient(
                    base_url=self.gateway_uri, api_key=self.javelin_api_key
                )
            except UnauthorizedError as e:
                raise ValueError("Javelin: Incorrect API Key.") from e

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Return the gateway URI, the route and the configured ``params``.

        The API key is deliberately left out: this mapping is reused by
        ``_identifying_params`` and ``_get_invocation_params``, so including
        the key would hand the secret to callbacks.
        """
        params: Dict[str, Any] = {
            "gateway_uri": self.gateway_uri,
            "route": self.route,
            **(self.params.dict() if self.params else {}),
        }
        return params

    def _require_client(self) -> Any:
        """Return the Javelin client, raising ``ValueError`` if it is unset."""
        if self.client is None:
            raise ValueError("Javelin client is not initialized.")
        return self.client

    def _build_request(
        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
    ) -> Dict[str, Any]:
        """Build the query body shared by the sync and async code paths.

        Args:
            messages: Conversation to send to the route.
            stop: Stop sequences for this call; when empty or ``None`` the
                ``stop`` value from ``params`` (if any) is used instead.

        Returns:
            The ``query_body`` dictionary for the route query.
        """
        data: Dict[str, Any] = {
            "messages": [
                ChatJavelinAIGateway._convert_message_to_dict(message)
                for message in messages
            ],
            **(self.params.dict() if self.params else {}),
        }
        # A per-call stop list takes precedence over the configured one.
        effective_stop = stop or (self.params.stop if self.params else None)
        if effective_stop:
            data["stop"] = effective_stop
        return data

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Query the configured route synchronously and wrap the response.

        Raises:
            ValueError: If no client is available or a message cannot be
                converted.
        """
        data = self._build_request(messages, stop)
        resp = self._require_client().query_route(self.route, query_body=data)
        return ChatJavelinAIGateway._create_chat_result(resp.dict())

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Query the configured route asynchronously and wrap the response.

        Raises:
            ValueError: If no client is available or a message cannot be
                converted.
        """
        data = self._build_request(messages, stop)
        resp = await self._require_client().aquery_route(
            self.route, query_body=data
        )
        return ChatJavelinAIGateway._create_chat_result(resp.dict())

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Return the parameters that identify this model instance."""
        return self._default_params

    def _get_invocation_params(
        self, stop: Optional[List[str]] = None, **kwargs: Any
    ) -> Dict[str, Any]:
        """Get the parameters used to invoke the model FOR THE CALLBACKS."""
        return {
            **self._default_params,
            **super()._get_invocation_params(stop=stop, **kwargs),
        }

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "javelin-ai-gateway-chat"

    @staticmethod
    def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
        """Turn a ``{"role": ..., "content": ...}`` mapping into a message.

        ``user``, ``assistant`` and ``system`` map to the dedicated message
        classes; any other role is kept on a generic ``ChatMessage``.
        """
        role = _dict["role"]
        content = _dict["content"]
        if role == "user":
            return HumanMessage(content=content)
        elif role == "assistant":
            return AIMessage(content=content)
        elif role == "system":
            return SystemMessage(content=content)
        else:
            return ChatMessage(content=content, role=role)

    @staticmethod
    def _raise_functions_not_supported() -> None:
        """Raise the ``ValueError`` used for any function-calling input."""
        raise ValueError(
            "Function messages are not supported by the Javelin AI Gateway. Please"
            " create a feature request at https://docs.getjavelin.io"
        )

    @staticmethod
    def _convert_message_to_dict(message: BaseMessage) -> dict:
        """Turn a message into the ``role``/``content`` dict sent to the route.

        Raises:
            ValueError: For ``FunctionMessage`` input, for a message carrying
                a ``function_call`` in ``additional_kwargs``, or for an
                unrecognised message type.
        """
        if isinstance(message, FunctionMessage):
            # Single source of truth for the "functions unsupported" error.
            ChatJavelinAIGateway._raise_functions_not_supported()

        if isinstance(message, ChatMessage):
            message_dict = {"role": message.role, "content": message.content}
        elif isinstance(message, HumanMessage):
            message_dict = {"role": "user", "content": message.content}
        elif isinstance(message, AIMessage):
            message_dict = {"role": "assistant", "content": message.content}
        elif isinstance(message, SystemMessage):
            message_dict = {"role": "system", "content": message.content}
        else:
            raise ValueError(f"Got unknown message type: {message}")

        if "function_call" in message.additional_kwargs:
            ChatJavelinAIGateway._raise_functions_not_supported()
        if message.additional_kwargs:
            logger.warning(
                "Additional message arguments are unsupported by Javelin AI Gateway "
                " and will be ignored: %s",
                message.additional_kwargs,
            )
        return message_dict

    @staticmethod
    def _create_chat_result(response: Mapping[str, Any]) -> ChatResult:
        """Build a ``ChatResult`` from the dict form of a route response.

        One ``ChatGeneration`` is created per entry of
        ``response["llm_response"]["choices"]``; each choice's optional
        ``metadata`` becomes its ``generation_info`` and the top-level
        ``metadata`` (if present) becomes ``llm_output``.
        """
        generations = [
            ChatGeneration(
                message=ChatJavelinAIGateway._convert_dict_to_message(
                    candidate["message"]
                ),
                generation_info=dict(candidate.get("metadata", {})),
            )
            for candidate in response["llm_response"]["choices"]
        ]

        response_metadata = response.get("metadata", {})
        return ChatResult(generations=generations, llm_output=response_metadata)
|
@@ -40,6 +40,7 @@ from langchain.embeddings.huggingface import (
|
||||
HuggingFaceInstructEmbeddings,
|
||||
)
|
||||
from langchain.embeddings.huggingface_hub import HuggingFaceHubEmbeddings
|
||||
from langchain.embeddings.javelin_ai_gateway import JavelinAIGatewayEmbeddings
|
||||
from langchain.embeddings.jina import JinaEmbeddings
|
||||
from langchain.embeddings.llamacpp import LlamaCppEmbeddings
|
||||
from langchain.embeddings.localai import LocalAIEmbeddings
|
||||
@@ -107,6 +108,7 @@ __all__ = [
|
||||
"AwaEmbeddings",
|
||||
"HuggingFaceBgeEmbeddings",
|
||||
"ErnieEmbeddings",
|
||||
"JavelinAIGatewayEmbeddings",
|
||||
"OllamaEmbeddings",
|
||||
"QianfanEmbeddingsEndpoint",
|
||||
]
|
||||
|
110
libs/langchain/langchain/embeddings/javelin_ai_gateway.py
Normal file
110
libs/langchain/langchain/embeddings/javelin_ai_gateway.py
Normal file
@@ -0,0 +1,110 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Iterator, List, Optional
|
||||
|
||||
from langchain.pydantic_v1 import BaseModel
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
|
||||
|
||||
def _chunk(texts: List[str], size: int) -> Iterator[List[str]]:
|
||||
for i in range(0, len(texts), size):
|
||||
yield texts[i : i + size]
|
||||
|
||||
|
||||
class JavelinAIGatewayEmbeddings(Embeddings, BaseModel):
    """
    Wrapper around embeddings LLMs in the Javelin AI Gateway.

    To use, you should have the ``javelin_sdk`` python package installed.
    For more information, see https://docs.getjavelin.io

    Example:
        .. code-block:: python

            from langchain.embeddings import JavelinAIGatewayEmbeddings

            embeddings = JavelinAIGatewayEmbeddings(
                gateway_uri="<javelin-ai-gateway-uri>",
                route="<your-javelin-gateway-embeddings-route>"
            )
    """

    client: Any
    """javelin client."""

    route: str
    """The route to use for the Javelin AI Gateway API."""

    gateway_uri: Optional[str] = None
    """The URI for the Javelin AI Gateway API."""

    javelin_api_key: Optional[str] = None
    """The API key for the Javelin AI Gateway API."""

    def __init__(self, **kwargs: Any):
        """Validate the fields and, if ``gateway_uri`` is set, build the client.

        Raises:
            ImportError: If the ``javelin_sdk`` package cannot be imported.
            ValueError: If ``JavelinClient`` raises ``UnauthorizedError``.
        """
        try:
            from javelin_sdk import (
                JavelinClient,
                UnauthorizedError,
            )
        except ImportError as e:
            raise ImportError(
                "Could not import javelin_sdk python package. "
                "Please install it with `pip install javelin_sdk`."
            ) from e

        super().__init__(**kwargs)
        if self.gateway_uri:
            try:
                self.client = JavelinClient(
                    base_url=self.gateway_uri, api_key=self.javelin_api_key
                )
            except UnauthorizedError as e:
                raise ValueError("Javelin: Incorrect API Key.") from e

    def _require_client(self) -> Any:
        """Return the Javelin client, raising ``ValueError`` if it is unset."""
        if self.client is None:
            raise ValueError("Javelin client is not initialized.")
        return self.client

    @staticmethod
    def _extract_embeddings(resp: Any) -> List[List[float]]:
        """Collect the ``embedding`` entries from one route response.

        Reads ``llm_response.data`` from ``resp.dict()``; items without an
        ``embedding`` key are skipped.
        """
        items = resp.dict().get("llm_response", {}).get("data", [])
        return [item["embedding"] for item in items if "embedding" in item]

    @staticmethod
    def _report_failure(error: ValueError) -> None:
        """Log a failed route query instead of printing to stdout."""
        import logging

        logging.getLogger(__name__).warning("Failed to query route: %s", error)

    @staticmethod
    def _first_embedding(embeddings: List[List[float]]) -> List[float]:
        """Return the first vector, or raise ``ValueError`` if there is none."""
        if not embeddings:
            raise ValueError(
                "Javelin AI Gateway returned no embedding for the query."
            )
        return embeddings[0]

    def _query(self, texts: List[str]) -> List[List[float]]:
        """Embed ``texts`` synchronously, 20 texts per route query.

        A chunk whose query raises ``ValueError`` is logged and skipped, so
        the result can hold fewer vectors than ``texts``.

        Raises:
            ValueError: If ``texts`` is non-empty and no client is available.
        """
        if not texts:
            return []
        # Checked outside the try block so this error is never swallowed.
        client = self._require_client()

        embeddings: List[List[float]] = []
        for batch in _chunk(texts, 20):
            try:
                resp = client.query_route(self.route, query_body={"input": batch})
                embeddings.extend(self._extract_embeddings(resp))
            except ValueError as e:
                self._report_failure(e)
        return embeddings

    async def _aquery(self, texts: List[str]) -> List[List[float]]:
        """Embed ``texts`` asynchronously, 20 texts per route query.

        A chunk whose query raises ``ValueError`` is logged and skipped, so
        the result can hold fewer vectors than ``texts``.

        Raises:
            ValueError: If ``texts`` is non-empty and no client is available.
        """
        if not texts:
            return []
        # Checked outside the try block so this error is never swallowed.
        client = self._require_client()

        embeddings: List[List[float]] = []
        for batch in _chunk(texts, 20):
            try:
                resp = await client.aquery_route(
                    self.route, query_body={"input": batch}
                )
                embeddings.extend(self._extract_embeddings(resp))
            except ValueError as e:
                self._report_failure(e)
        return embeddings

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return the embeddings the route produced for ``texts``."""
        return self._query(texts)

    def embed_query(self, text: str) -> List[float]:
        """Return the embedding for a single ``text``.

        Raises:
            ValueError: If the route returned no embedding.
        """
        return self._first_embedding(self._query([text]))

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously return the embeddings the route produced for ``texts``."""
        return await self._aquery(texts)

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronously return the embedding for a single ``text``.

        Raises:
            ValueError: If the route returned no embedding.
        """
        result = await self._aquery([text])
        return self._first_embedding(result)
|
@@ -54,6 +54,7 @@ from langchain.llms.huggingface_hub import HuggingFaceHub
|
||||
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
|
||||
from langchain.llms.huggingface_text_gen_inference import HuggingFaceTextGenInference
|
||||
from langchain.llms.human import HumanInputLLM
|
||||
from langchain.llms.javelin_ai_gateway import JavelinAIGateway
|
||||
from langchain.llms.koboldai import KoboldApiLLM
|
||||
from langchain.llms.llamacpp import LlamaCpp
|
||||
from langchain.llms.manifest import ManifestWrapper
|
||||
@@ -161,6 +162,7 @@ __all__ = [
|
||||
"Writer",
|
||||
"OctoAIEndpoint",
|
||||
"Xinference",
|
||||
"JavelinAIGateway",
|
||||
"QianfanLLMEndpoint",
|
||||
]
|
||||
|
||||
@@ -230,5 +232,6 @@ type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
|
||||
"vllm_openai": VLLMOpenAI,
|
||||
"writer": Writer,
|
||||
"xinference": Xinference,
|
||||
"javelin-ai-gateway": JavelinAIGateway,
|
||||
"qianfan_endpoint": QianfanLLMEndpoint,
|
||||
}
|
||||
|
152
libs/langchain/langchain/llms/javelin_ai_gateway.py
Normal file
152
libs/langchain/langchain/llms/javelin_ai_gateway.py
Normal file
@@ -0,0 +1,152 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Mapping, Optional
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.pydantic_v1 import BaseModel, Extra
|
||||
|
||||
|
||||
class Params(BaseModel):
    """Request parameters sent along with a Javelin AI Gateway completion query.

    Keys that are not declared below are accepted as well, because the
    model is configured with ``Extra.allow``.
    """

    class Config:
        """Pydantic configuration for this model."""

        extra = Extra.allow

    temperature: float = 0.0
    stop: Optional[List[str]] = None
    max_tokens: Optional[int] = None
|
||||
|
||||
|
||||
class JavelinAIGateway(LLM):
    """
    Wrapper around completions LLMs in the Javelin AI Gateway.

    To use, you should have the ``javelin_sdk`` python package installed.
    For more information, see https://docs.getjavelin.io

    Example:
        .. code-block:: python

            from langchain.llms import JavelinAIGateway

            completions = JavelinAIGateway(
                gateway_uri="<your-javelin-ai-gateway-uri>",
                route="<your-javelin-ai-gateway-completions-route>",
                params={
                    "temperature": 0.1
                }
            )
    """

    route: str
    """The route to use for the Javelin AI Gateway API."""

    client: Optional[Any] = None
    """The Javelin AI Gateway client."""

    gateway_uri: Optional[str] = None
    """The URI of the Javelin AI Gateway API."""

    params: Optional[Params] = None
    """Parameters for the Javelin AI Gateway API."""

    javelin_api_key: Optional[str] = None
    """The API key for the Javelin AI Gateway API."""

    def __init__(self, **kwargs: Any):
        """Validate the fields and, if ``gateway_uri`` is set, build the client.

        Raises:
            ImportError: If the ``javelin_sdk`` package cannot be imported.
            ValueError: If ``JavelinClient`` raises ``UnauthorizedError``.
        """
        try:
            from javelin_sdk import (
                JavelinClient,
                UnauthorizedError,
            )
        except ImportError as e:
            raise ImportError(
                "Could not import javelin_sdk python package. "
                "Please install it with `pip install javelin_sdk`."
            ) from e
        super().__init__(**kwargs)
        if self.gateway_uri:
            try:
                self.client = JavelinClient(
                    base_url=self.gateway_uri, api_key=self.javelin_api_key
                )
            except UnauthorizedError as e:
                raise ValueError("Javelin: Incorrect API Key.") from e

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling Javelin AI Gateway API.

        The API key is deliberately left out: this mapping is returned by
        ``_identifying_params``, so including the key would expose the
        secret wherever those parameters are shown or stored.
        """
        params: Dict[str, Any] = {
            "gateway_uri": self.gateway_uri,
            "route": self.route,
            **(self.params.dict() if self.params else {}),
        }
        return params

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return self._default_params

    def _build_payload(
        self, prompt: str, stop: Optional[List[str]] = None
    ) -> Dict[str, Any]:
        """Build the query body shared by ``_call`` and ``_acall``.

        A non-empty per-call ``stop`` overrides the ``stop`` configured in
        ``params``.
        """
        data: Dict[str, Any] = {
            "prompt": prompt,
            **(self.params.dict() if self.params else {}),
        }
        if s := (stop or (self.params.stop if self.params else None)):
            data["stop"] = s
        return data

    def _require_client(self) -> Any:
        """Return the Javelin client, raising ``ValueError`` if it is unset."""
        if self.client is None:
            raise ValueError("Javelin client is not initialized.")
        return self.client

    @staticmethod
    def _extract_text(resp: Any) -> str:
        """Return the text of the first choice, or ``""`` if it is missing.

        An empty ``choices`` list (``IndexError``) and a null intermediate
        value (``TypeError``) are treated like a missing key.
        """
        resp_dict = resp.dict()
        try:
            return resp_dict["llm_response"]["choices"][0]["text"]
        except (KeyError, IndexError, TypeError):
            return ""

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call the Javelin AI Gateway API.

        Args:
            prompt: Text sent as the ``prompt`` of the query body.
            stop: Optional stop sequences for this call.
            run_manager: Accepted for interface compatibility; not used here.

        Returns:
            The text of the first returned choice, or ``""`` if absent.

        Raises:
            ValueError: If the client is not initialized.
        """
        data = self._build_payload(prompt, stop)
        resp = self._require_client().query_route(self.route, query_body=data)
        return self._extract_text(resp)

    async def _acall(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call async the Javelin AI Gateway API.

        Args:
            prompt: Text sent as the ``prompt`` of the query body.
            stop: Optional stop sequences for this call.
            run_manager: Accepted for interface compatibility; not used here.

        Returns:
            The text of the first returned choice, or ``""`` if absent.

        Raises:
            ValueError: If the client is not initialized.
        """
        data = self._build_payload(prompt, stop)
        resp = await self._require_client().aquery_route(
            self.route, query_body=data
        )
        return self._extract_text(resp)

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "javelin-ai-gateway"
|
Reference in New Issue
Block a user