Mirror of https://github.com/hwchase17/langchain.git (synced 2025-09-17 07:26:16 +00:00)
community[minor]: Prem AI langchain integration (#19113)
### Prem SDK integration in LangChain

This PR integrates [PremAI's](https://www.premai.io/) prem-sdk with LangChain. Users can now access their deployed models (LLMs/embeddings) and use them within LangChain's ecosystem.

### This PR adds the following:

- [x] Chat support
- [x] Embedding support
- [x] Integration tests
  - [x] tests for chat
  - [x] tests for embeddings
- [x] Unit tests
  - [x] tests for chat
  - [x] tests for embeddings
- [x] Documentation
  - [x] documentation for chat
  - [x] documentation for embeddings
- [x] Run `make test`
- [x] Run `make lint`, `make lint_diff`
- [x] Final checks (spell check, lint, format and overall testing)

---------

Co-authored-by: Anindyadeep Sannigrahi <anindyadeepsannigrahi@Anindyadeeps-MacBook-Pro.local>
Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: Erick Friis <erick@langchain.dev>
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
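Not part of the diff itself: a minimal usage sketch of the chat model added below, assuming `PREMAI_API_KEY` is exported and a project exists on the Prem platform (`project_id=8` is simply the value the tests use):

```python
from langchain_core.messages import HumanMessage, SystemMessage

from langchain_community.chat_models import ChatPremAI

# The root validator picks up PREMAI_API_KEY from the environment
# and builds the underlying `premai` client.
chat = ChatPremAI(project_id=8)

# A SystemMessage overrides the default LaunchPad system prompt.
response = chat.invoke(
    [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="What can I do with Prem?"),
    ]
)
print(response.content)
```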
@@ -64,6 +64,7 @@ _module_lookup = {
     "PromptLayerChatOpenAI": "langchain_community.chat_models.promptlayer_openai",
     "QianfanChatEndpoint": "langchain_community.chat_models.baidu_qianfan_endpoint",
     "VolcEngineMaasChat": "langchain_community.chat_models.volcengine_maas",
+    "ChatPremAI": "langchain_community.chat_models.premai",
 }
416  libs/community/langchain_community/chat_models/premai.py  Normal file
@@ -0,0 +1,416 @@
"""Wrapper around Prem's Chat API."""

from __future__ import annotations

import logging
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    Iterator,
    List,
    Optional,
    Tuple,
    Type,
    Union,
)

from langchain_core.callbacks import (
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.language_models.llms import create_base_retry_decorator
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    BaseMessageChunk,
    ChatMessage,
    ChatMessageChunk,
    HumanMessage,
    HumanMessageChunk,
    SystemMessage,
    SystemMessageChunk,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Extra, SecretStr, root_validator
from langchain_core.utils import get_from_dict_or_env

if TYPE_CHECKING:
    from premai.api.chat_completions.v1_chat_completions_create import (
        ChatCompletionResponseStream,
    )
    from premai.models.chat_completion_response import ChatCompletionResponse

logger = logging.getLogger(__name__)


class ChatPremAPIError(Exception):
    """Error with the `PremAI` API."""


def _truncate_at_stop_tokens(
    text: str,
    stop: Optional[List[str]],
) -> str:
    """Truncates text at the earliest stop token found."""
    if stop is None:
        return text

    for stop_token in stop:
        stop_token_idx = text.find(stop_token)
        if stop_token_idx != -1:
            text = text[:stop_token_idx]
    return text


def _response_to_result(
    response: ChatCompletionResponse,
    stop: Optional[List[str]],
) -> ChatResult:
    """Converts a Prem API response into a LangChain result"""

    if not response.choices:
        raise ChatPremAPIError("ChatResponse must have at least one candidate")
    generations: List[ChatGeneration] = []
    for choice in response.choices:
        role = choice.message.role
        if role is None:
            raise ChatPremAPIError(f"ChatResponse {choice} must have a role.")

        # If content is None then it will be replaced by ""
        content = _truncate_at_stop_tokens(text=choice.message.content or "", stop=stop)
        if content is None:
            raise ChatPremAPIError(f"ChatResponse must have a content: {content}")

        if role == "assistant":
            generations.append(
                ChatGeneration(text=content, message=AIMessage(content=content))
            )
        elif role == "user":
            generations.append(
                ChatGeneration(text=content, message=HumanMessage(content=content))
            )
        else:
            generations.append(
                ChatGeneration(
                    text=content, message=ChatMessage(role=role, content=content)
                )
            )
    return ChatResult(generations=generations)


def _convert_delta_response_to_message_chunk(
    response: ChatCompletionResponseStream, default_class: Type[BaseMessageChunk]
) -> Tuple[
    Union[BaseMessageChunk, HumanMessageChunk, AIMessageChunk, SystemMessageChunk],
    Optional[str],
]:
    """Converts delta response to message chunk"""
    _delta = response.choices[0].delta  # type: ignore
    role = _delta.get("role", "")  # type: ignore
    content = _delta.get("content", "")  # type: ignore
    additional_kwargs: Dict = {}

    if role is None or role == "":
        raise ChatPremAPIError("Role can not be None. Please check the response")

    finish_reasons: Optional[str] = response.choices[0].finish_reason

    if role == "user" or default_class == HumanMessageChunk:
        return HumanMessageChunk(content=content), finish_reasons
    elif role == "assistant" or default_class == AIMessageChunk:
        return (
            AIMessageChunk(content=content, additional_kwargs=additional_kwargs),
            finish_reasons,
        )
    elif role == "system" or default_class == SystemMessageChunk:
        return SystemMessageChunk(content=content), finish_reasons
    elif role or default_class == ChatMessageChunk:
        return ChatMessageChunk(content=content, role=role), finish_reasons
    else:
        return default_class(content=content), finish_reasons


def _messages_to_prompt_dict(
    input_messages: List[BaseMessage],
) -> Tuple[Optional[str], List[Dict[str, str]]]:
    """Converts a list of LangChain Messages into a simple dict
    which is the message structure in Prem"""

    system_prompt: Optional[str] = None
    examples_and_messages: List[Dict[str, str]] = []

    for input_msg in input_messages:
        if isinstance(input_msg, SystemMessage):
            system_prompt = str(input_msg.content)
        elif isinstance(input_msg, HumanMessage):
            examples_and_messages.append(
                {"role": "user", "content": str(input_msg.content)}
            )
        elif isinstance(input_msg, AIMessage):
            examples_and_messages.append(
                {"role": "assistant", "content": str(input_msg.content)}
            )
        else:
            raise ChatPremAPIError("No such role explicitly exists")
    return system_prompt, examples_and_messages


class ChatPremAI(BaseChatModel, BaseModel):
    """Use any LLM provider with Prem and LangChain.

    To use, you will need to have an API key. You can find your existing API Key
    or generate a new one here: https://app.premai.io/api_keys/
    """

    # TODO: Need to add the default parameters through prem-sdk here

    project_id: int
    """The project ID in which the experiments or deployments are carried out.
    You can find all your projects here: https://app.premai.io/projects/"""
    premai_api_key: Optional[SecretStr] = None
    """Prem AI API Key. Get it here: https://app.premai.io/api_keys/"""

    model: Optional[str] = None
    """Name of the model. This is an optional parameter.
    The default model is the one deployed from Prem's LaunchPad: https://app.premai.io/projects/8/launchpad
    If the model name is other than the default model then it will override the calls
    from the model deployed from LaunchPad."""

    session_id: Optional[str] = None
    """The ID of the session to use. It helps to track the chat history."""

    temperature: Optional[float] = None
    """Model temperature. Value should be >= 0 and <= 1.0"""

    top_p: Optional[float] = None
    """top_p adjusts the number of choices for each predicted token based on
    cumulative probabilities. Value should be ranging between 0.0 and 1.0.
    """

    max_tokens: Optional[int] = None
    """The maximum number of tokens to generate"""

    max_retries: int = 1
    """Max number of retries to call the API"""

    system_prompt: Optional[str] = ""
    """Acts like a default instruction that helps the LLM act or generate
    in a specific way. This is an Optional Parameter. By default the
    system prompt would be using Prem's LaunchPad models system prompt.
    Changing the system prompt would override the default system prompt.
    """

    streaming: Optional[bool] = False
    """Whether to stream the responses or not."""

    tools: Optional[Dict[str, Any]] = None
    """A list of tools the model may call. Currently, only functions are
    supported as a tool"""

    frequency_penalty: Optional[float] = None
    """Number between -2.0 and 2.0. Positive values penalize new tokens based
    on their existing frequency in the text so far."""

    presence_penalty: Optional[float] = None
    """Number between -2.0 and 2.0. Positive values penalize new tokens based
    on whether they appear in the text so far."""

    logit_bias: Optional[dict] = None
    """JSON object that maps tokens to an associated bias value from -100 to 100."""

    stop: Optional[Union[str, List[str]]] = None
    """Up to 4 sequences where the API will stop generating further tokens."""

    seed: Optional[int] = None
    """This feature is in Beta. If specified, our system will make a best effort
    to sample deterministically."""

    client: Any

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environments(cls, values: Dict) -> Dict:
        """Validate that the package is installed and that the API token is valid"""
        try:
            from premai import Prem
        except ImportError as error:
            raise ImportError(
                "Could not import Prem Python package. "
                "Please install it with: `pip install premai`"
            ) from error

        try:
            premai_api_key = get_from_dict_or_env(
                values, "premai_api_key", "PREMAI_API_KEY"
            )
            values["client"] = Prem(api_key=premai_api_key)
        except Exception as error:
            raise ValueError("Your API Key is incorrect. Please try again.") from error
        return values

    @property
    def _llm_type(self) -> str:
        return "premai"

    @property
    def _default_params(self) -> Dict[str, Any]:
        # FIXME: n and stop is not supported, so hardcoding to current default value
        return {
            "model": self.model,
            "system_prompt": self.system_prompt,
            "top_p": self.top_p,
            "temperature": self.temperature,
            "logit_bias": self.logit_bias,
            "max_tokens": self.max_tokens,
            "presence_penalty": self.presence_penalty,
            "frequency_penalty": self.frequency_penalty,
            "seed": self.seed,
            "stop": None,
        }

    def _get_all_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
        all_kwargs = {**self._default_params, **kwargs}
        for key in list(self._default_params.keys()):
            if all_kwargs.get(key) is None or all_kwargs.get(key) == "":
                all_kwargs.pop(key, None)
        return all_kwargs

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        system_prompt, messages_to_pass = _messages_to_prompt_dict(messages)  # type: ignore

        kwargs["stop"] = stop
        if system_prompt is not None and system_prompt != "":
            kwargs["system_prompt"] = system_prompt

        all_kwargs = self._get_all_kwargs(**kwargs)
        response = chat_with_retry(
            self,
            project_id=self.project_id,
            messages=messages_to_pass,
            stream=False,
            run_manager=run_manager,
            **all_kwargs,
        )

        return _response_to_result(response=response, stop=stop)

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        system_prompt, messages_to_pass = _messages_to_prompt_dict(messages)
        kwargs["stop"] = stop

        if "system_prompt" not in kwargs:
            if system_prompt is not None and system_prompt != "":
                kwargs["system_prompt"] = system_prompt

        all_kwargs = self._get_all_kwargs(**kwargs)

        default_chunk_class = AIMessageChunk

        for streamed_response in chat_with_retry(
            self,
            project_id=self.project_id,
            messages=messages_to_pass,
            stream=True,
            run_manager=run_manager,
            **all_kwargs,
        ):
            try:
                chunk, finish_reason = _convert_delta_response_to_message_chunk(
                    response=streamed_response, default_class=default_chunk_class
                )
                generation_info = (
                    dict(finish_reason=finish_reason)
                    if finish_reason is not None
                    else None
                )
                cg_chunk = ChatGenerationChunk(
                    message=chunk, generation_info=generation_info
                )
                if run_manager:
                    run_manager.on_llm_new_token(cg_chunk.text, chunk=cg_chunk)
                yield cg_chunk
            except Exception as _:
                continue


def create_prem_retry_decorator(
    llm: ChatPremAI,
    *,
    max_retries: int = 1,
    run_manager: Optional[Union[CallbackManagerForLLMRun]] = None,
) -> Callable[[Any], Any]:
    import premai.models

    errors = [
        premai.models.api_response_validation_error.APIResponseValidationError,
        premai.models.conflict_error.ConflictError,
        premai.models.model_not_found_error.ModelNotFoundError,
        premai.models.permission_denied_error.PermissionDeniedError,
        premai.models.provider_api_connection_error.ProviderAPIConnectionError,
        premai.models.provider_api_status_error.ProviderAPIStatusError,
        premai.models.provider_api_timeout_error.ProviderAPITimeoutError,
        premai.models.provider_internal_server_error.ProviderInternalServerError,
        premai.models.provider_not_found_error.ProviderNotFoundError,
        premai.models.rate_limit_error.RateLimitError,
        premai.models.unprocessable_entity_error.UnprocessableEntityError,
        premai.models.validation_error.ValidationError,
    ]

    decorator = create_base_retry_decorator(
        error_types=errors, max_retries=max_retries, run_manager=run_manager
    )
    return decorator


def chat_with_retry(
    llm: ChatPremAI,
    project_id: int,
    messages: List[dict],
    stream: bool = False,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Using tenacity for retry in completion call"""
    retry_decorator = create_prem_retry_decorator(
        llm, max_retries=llm.max_retries, run_manager=run_manager
    )

    @retry_decorator
    def _completion_with_retry(
        project_id: int,
        messages: List[dict],
        stream: Optional[bool] = False,
        **kwargs: Any,
    ) -> Any:
        response = llm.client.chat.completions.create(
            project_id=project_id,
            messages=messages,
            stream=stream,
            **kwargs,
        )
        return response

    return _completion_with_retry(
        project_id=project_id,
        messages=messages,
        stream=stream,
        **kwargs,
    )
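For reference, a hedged sketch of streaming with the wrapper above; it mirrors the integration test added later in this PR and assumes the same API key and project setup:

```python
from langchain_community.chat_models import ChatPremAI

chat = ChatPremAI(project_id=8, streaming=True)

# Each chunk is an AIMessageChunk produced by _stream() via chat_with_retry(stream=True).
for chunk in chat.stream("Write one sentence about Prem."):
    print(chunk.content, end="", flush=True)
```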
@@ -80,6 +80,7 @@ _module_lookup = {
     "VolcanoEmbeddings": "langchain_community.embeddings.volcengine",
     "VoyageEmbeddings": "langchain_community.embeddings.voyageai",
     "XinferenceEmbeddings": "langchain_community.embeddings.xinference",
+    "PremAIEmbeddings": "langchain_community.embeddings.premai",
     "YandexGPTEmbeddings": "langchain_community.embeddings.yandex",
 }
121  libs/community/langchain_community/embeddings/premai.py  Normal file
@@ -0,0 +1,121 @@
from __future__ import annotations

import logging
from typing import Any, Callable, Dict, List, Optional, Union

from langchain_core.embeddings import Embeddings
from langchain_core.language_models.llms import create_base_retry_decorator
from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator
from langchain_core.utils import get_from_dict_or_env

logger = logging.getLogger(__name__)


class PremAIEmbeddings(BaseModel, Embeddings):
    """Prem's Embedding APIs"""

    project_id: int
    """The project ID in which the experiments or deployments are carried out.
    You can find all your projects here: https://app.premai.io/projects/"""

    premai_api_key: Optional[SecretStr] = None
    """Prem AI API Key. Get it here: https://app.premai.io/api_keys/"""

    model: str
    """The Embedding model to choose from"""

    show_progress_bar: bool = False
    """Whether to show a tqdm progress bar. Must have `tqdm` installed."""

    max_retries: int = 1
    """Max number of retries for tenacity"""

    client: Any

    @root_validator()
    def validate_environments(cls, values: Dict) -> Dict:
        """Validate that the package is installed and that the API token is valid"""
        try:
            from premai import Prem
        except ImportError as error:
            raise ImportError(
                "Could not import Prem Python package. "
                "Please install it with: `pip install premai`"
            ) from error

        try:
            premai_api_key = get_from_dict_or_env(
                values, "premai_api_key", "PREMAI_API_KEY"
            )
            values["client"] = Prem(api_key=premai_api_key)
        except Exception as error:
            raise ValueError("Your API Key is incorrect. Please try again.") from error
        return values

    def embed_query(self, text: str) -> List[float]:
        """Embed query text"""
        embeddings = embed_with_retry(
            self, model=self.model, project_id=self.project_id, input=text
        )
        return embeddings.data[0].embedding

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        embeddings = embed_with_retry(
            self, model=self.model, project_id=self.project_id, input=texts
        ).data

        return [embedding.embedding for embedding in embeddings]


def create_prem_retry_decorator(
    embedder: PremAIEmbeddings,
    *,
    max_retries: int = 1,
) -> Callable[[Any], Any]:
    import premai.models

    errors = [
        premai.models.api_response_validation_error.APIResponseValidationError,
        premai.models.conflict_error.ConflictError,
        premai.models.model_not_found_error.ModelNotFoundError,
        premai.models.permission_denied_error.PermissionDeniedError,
        premai.models.provider_api_connection_error.ProviderAPIConnectionError,
        premai.models.provider_api_status_error.ProviderAPIStatusError,
        premai.models.provider_api_timeout_error.ProviderAPITimeoutError,
        premai.models.provider_internal_server_error.ProviderInternalServerError,
        premai.models.provider_not_found_error.ProviderNotFoundError,
        premai.models.rate_limit_error.RateLimitError,
        premai.models.unprocessable_entity_error.UnprocessableEntityError,
        premai.models.validation_error.ValidationError,
    ]

    decorator = create_base_retry_decorator(
        error_types=errors, max_retries=max_retries, run_manager=None
    )
    return decorator


def embed_with_retry(
    embedder: PremAIEmbeddings,
    model: str,
    project_id: int,
    input: Union[str, List[str]],
) -> Any:
    """Using tenacity for retry in embedding calls"""
    retry_decorator = create_prem_retry_decorator(
        embedder, max_retries=embedder.max_retries
    )

    @retry_decorator
    def _embed_with_retry(
        embedder: PremAIEmbeddings,
        project_id: int,
        model: str,
        input: Union[str, List[str]],
    ) -> Any:
        embedding_response = embedder.client.embeddings.create(
            project_id=project_id, model=model, input=input
        )
        return embedding_response

    return _embed_with_retry(embedder, project_id=project_id, model=model, input=input)
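A short usage sketch for the embeddings wrapper above; the model name and the 1536-dimensional vectors come from the integration tests below and will differ if your project deploys a different embedding model:

```python
from langchain_community.embeddings import PremAIEmbeddings

embedder = PremAIEmbeddings(project_id=8, model="text-embedding-3-small")

query_vector = embedder.embed_query("foo bar")  # -> List[float]
doc_vectors = embedder.embed_documents(["foo bar", "bar foo"])  # -> List[List[float]]
print(len(query_vector), len(doc_vectors))  # e.g. 1536 2
```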
101  libs/community/poetry.lock  generated
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
 
 [[package]]
 name = "aenum"
@@ -1167,10 +1167,7 @@ files = [
 [package.dependencies]
 jmespath = ">=0.7.1,<2.0.0"
 python-dateutil = ">=2.1,<3.0.0"
-urllib3 = [
-    {version = ">=1.25.4,<1.27", markers = "python_version < \"3.10\""},
-    {version = ">=1.25.4,<2.1", markers = "python_version >= \"3.10\""},
-]
+urllib3 = {version = ">=1.25.4,<2.1", markers = "python_version >= \"3.10\""}
 
 [package.extras]
 crt = ["awscrt (==0.19.19)"]
@@ -2689,6 +2686,37 @@ uvicorn = ">=0.23.2,<0.24.0"
 [package.extras]
 mllib = ["accelerate (==0.21.0)", "datasets (==2.16.0)", "einops (>=0.6.1,<0.7.0)", "h5py (>=3.9.0,<4.0.0)", "peft (==0.6.0)", "transformers (==4.36.2)"]
 
+[[package]]
+name = "friendli-client"
+version = "1.3.1"
+description = "Client of Friendli Suite."
+optional = true
+python-versions = "<4.0.0,>=3.8.1"
+files = [
+    {file = "friendli_client-1.3.1-py3-none-any.whl", hash = "sha256:1a77b046c57b0d70bac8d13ac6ecc861f8fc84d3c63e39b34543f862373a670b"},
+    {file = "friendli_client-1.3.1.tar.gz", hash = "sha256:85f87976f7bb75eb424f384e3e73ac3256b7aad477361b51341e520c2aed3a0e"},
+]
+
+[package.dependencies]
+fastapi = ">=0.104.0,<0.105.0"
+gql = ">=3.4.1,<4.0.0"
+httpx = ">=0.24.1,<0.25.0"
+injector = ">=0.21.0,<0.22.0"
+jsonschema = ">=4.17.3,<5.0.0"
+pathspec = ">=0.9.0,<0.10.0"
+protobuf = ">=4.24.2,<5.0.0"
+pydantic = {version = ">=1.9.0,<3", extras = ["email"]}
+PyYaml = ">=6.0.1,<7.0.0"
+requests = ">=2,<3"
+rich = ">=12.2.0,<13.0.0"
+tqdm = ">=4.48.0,<5.0.0"
+typer = ">=0.9.0,<0.10.0"
+types-protobuf = ">=4.24.0.1,<5.0.0.0"
+uvicorn = ">=0.23.2,<0.24.0"
+
+[package.extras]
+mllib = ["accelerate (==0.21.0)", "datasets (==2.16.0)", "einops (>=0.6.1,<0.7.0)", "h5py (>=3.9.0,<4.0.0)", "peft (==0.6.0)", "transformers (==4.36.2)"]
+
 [[package]]
 name = "frozenlist"
 version = "1.4.1"
@@ -5965,6 +5993,23 @@ dev = ["packaging", "prawcore[lint]", "prawcore[test]"]
 lint = ["pre-commit", "ruff (>=0.0.291)"]
 test = ["betamax (>=0.8,<0.9)", "pytest (>=2.7.3)", "urllib3 (==1.26.*)"]
 
+[[package]]
+name = "premai"
+version = "0.3.25"
+description = "A client library for accessing Prem APIs"
+optional = true
+python-versions = ">=3.8,<4.0"
+files = [
+    {file = "premai-0.3.25-py3-none-any.whl", hash = "sha256:bddace7340e1827f048b410748d365e8663e4bbeb6bf7e8b8657f3cc267f7f28"},
+    {file = "premai-0.3.25.tar.gz", hash = "sha256:c387980ecf3bdcb07886dd4f7a1c0f0701df67e772e62f444394cea97d5970a0"},
+]
+
+[package.dependencies]
+attrs = ">=21.3.0"
+httpx = ">=0.20.0,<0.27.0"
+python-dateutil = ">=2.8.0,<3.0.0"
+typing_extensions = ">=4.9.0"
+
 [[package]]
 name = "prometheus-client"
 version = "0.20.0"
@@ -9071,20 +9116,6 @@ files = [
 cryptography = ">=35.0.0"
 types-pyOpenSSL = "*"
 
-[[package]]
-name = "types-requests"
-version = "2.31.0.6"
-description = "Typing stubs for requests"
-optional = false
-python-versions = ">=3.7"
-files = [
-    {file = "types-requests-2.31.0.6.tar.gz", hash = "sha256:cd74ce3b53c461f1228a9b783929ac73a666658f223e28ed29753771477b3bd0"},
-    {file = "types_requests-2.31.0.6-py3-none-any.whl", hash = "sha256:a2db9cb228a81da8348b49ad6db3f5519452dd20a9c1e1a868c83c5fe88fd1a9"},
-]
-
-[package.dependencies]
-types-urllib3 = "*"
-
 [[package]]
 name = "types-requests"
 version = "2.31.0.20240311"
@@ -9121,17 +9152,6 @@ files = [
     {file = "types_toml-0.10.8.20240310-py3-none-any.whl", hash = "sha256:627b47775d25fa29977d9c70dc0cbab3f314f32c8d8d0c012f2ef5de7aaec05d"},
 ]
 
-[[package]]
-name = "types-urllib3"
-version = "1.26.25.14"
-description = "Typing stubs for urllib3"
-optional = false
-python-versions = "*"
-files = [
-    {file = "types-urllib3-1.26.25.14.tar.gz", hash = "sha256:229b7f577c951b8c1b92c1bc2b2fdb0b49847bd2af6d1cc2a2e3dd340f3bda8f"},
-    {file = "types_urllib3-1.26.25.14-py3-none-any.whl", hash = "sha256:9683bbb7fb72e32bfe9d2be6e04875fbe1b3eeec3cbb4ea231435aa7fd6b4f0e"},
-]
-
 [[package]]
 name = "typing"
 version = "3.7.4.3"
@@ -9228,22 +9248,6 @@ files = [
 [package.extras]
 dev = ["flake8", "flake8-annotations", "flake8-bandit", "flake8-bugbear", "flake8-commas", "flake8-comprehensions", "flake8-continuation", "flake8-datetimez", "flake8-docstrings", "flake8-import-order", "flake8-literal", "flake8-modern-annotations", "flake8-noqa", "flake8-pyproject", "flake8-requirements", "flake8-typechecking-import", "flake8-use-fstring", "mypy", "pep8-naming", "types-PyYAML"]
 
-[[package]]
-name = "urllib3"
-version = "1.26.18"
-description = "HTTP library with thread-safe connection pooling, file post, and more."
-optional = false
-python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*"
-files = [
-    {file = "urllib3-1.26.18-py2.py3-none-any.whl", hash = "sha256:34b97092d7e0a3a8cf7cd10e386f401b3737364026c45e622aa02903dffe0f07"},
-    {file = "urllib3-1.26.18.tar.gz", hash = "sha256:f8ecc1bba5667413457c529ab955bf8c67b45db799d159066261719e328580a0"},
-]
-
-[package.extras]
-brotli = ["brotli (==1.0.9)", "brotli (>=1.0.9)", "brotlicffi (>=0.8.0)", "brotlipy (>=0.6.0)"]
-secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "ipaddress", "pyOpenSSL (>=0.14)", "urllib3-secure-extra"]
-socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"]
-
 [[package]]
 name = "urllib3"
 version = "2.0.7"
@@ -9319,7 +9323,6 @@ files = [
 
 [package.dependencies]
 PyYAML = "*"
-urllib3 = {version = "<2", markers = "platform_python_implementation == \"PyPy\" or python_version < \"3.10\""}
 wrapt = "*"
 yarl = "*"
 
@@ -9873,9 +9876,9 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
 
 [extras]
 cli = ["typer"]
-extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "azure-ai-documentintelligence", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cloudpickle", "cloudpickle", "cohere", "databricks-vectorsearch", "datasets", "dgml-utils", "elasticsearch", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "friendli-client", "geopandas", "gitpython", "google-cloud-documentai", "gql", "gradientai", "hdbcli", "hologres-vector", "html2text", "httpx", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "nvidia-riva-client", "oci", "openai", "openapi-pydantic", "oracle-ads", "pandas", "pdfminer-six", "pgvector", "praw", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "rdflib", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "tidb-vector", "timescale-vector", "tqdm", "tree-sitter", "tree-sitter-languages", "upstash-redis", "xata", "xmltodict", "zhipuai"]
+extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "azure-ai-documentintelligence", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cloudpickle", "cloudpickle", "cohere", "databricks-vectorsearch", "datasets", "dgml-utils", "elasticsearch", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "friendli-client", "geopandas", "gitpython", "google-cloud-documentai", "gql", "gradientai", "hdbcli", "hologres-vector", "html2text", "httpx", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "nvidia-riva-client", "oci", "openai", "openapi-pydantic", "oracle-ads", "pandas", "pdfminer-six", "pgvector", "praw", "premai", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "rdflib", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "tidb-vector", "timescale-vector", "tqdm", "tree-sitter", "tree-sitter-languages", "upstash-redis", "xata", "xmltodict", "zhipuai"]
 
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "5b2a17ed079fa4cc1776f0474a9e73a428c10bf30a22b7185d2f7a77b2d146e5"
+content-hash = "dcaae2110a70843fa3cb375618bebbe16b3da9bfdbc1e471e57f144d0906f58b"
@@ -97,6 +97,7 @@ rdflib = {version = "7.0.0", optional = true}
 nvidia-riva-client = {version = "^2.14.0", optional = true}
 tidb-vector = {version = ">=0.0.3,<1.0.0", optional = true}
 friendli-client = {version = "^1.2.4", optional = true}
+premai = {version = "^0.3.25", optional = true}
 
 [tool.poetry.group.test]
 optional = true
@@ -267,7 +268,8 @@ extended_testing = [
     "rdflib",
     "tidb-vector",
     "cloudpickle",
-    "friendli-client"
+    "friendli-client",
+    "premai"
 ]
 
 [tool.ruff]
@@ -0,0 +1,70 @@
"""Test ChatPremAI from PremAI API wrapper.

Note: This test must be run with the PREMAI_API_KEY environment variable set to a valid
API key and a valid project_id.
For this we need to have a project setup in PremAI's platform: https://app.premai.io
"""

import pytest
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain_core.outputs import ChatGeneration, LLMResult

from langchain_community.chat_models import ChatPremAI


@pytest.fixture
def chat() -> ChatPremAI:
    return ChatPremAI(project_id=8)


def test_chat_premai() -> None:
    """Test ChatPremAI wrapper."""
    chat = ChatPremAI(project_id=8)
    message = HumanMessage(content="Hello")
    response = chat([message])
    assert isinstance(response, BaseMessage)
    assert isinstance(response.content, str)


def test_chat_prem_system_message() -> None:
    """Test ChatPremAI wrapper for system message"""
    chat = ChatPremAI(project_id=8)
    system_message = SystemMessage(content="You are to chat with the user.")
    human_message = HumanMessage(content="Hello")
    response = chat([system_message, human_message])
    assert isinstance(response, BaseMessage)
    assert isinstance(response.content, str)


def test_chat_prem_model() -> None:
    """Test ChatPremAI wrapper handles model_name."""
    chat = ChatPremAI(model="foo", project_id=8)
    assert chat.model == "foo"


def test_chat_prem_generate() -> None:
    """Test ChatPremAI wrapper with generate."""
    chat = ChatPremAI(project_id=8)
    message = HumanMessage(content="Hello")
    response = chat.generate([[message], [message]])
    assert isinstance(response, LLMResult)
    assert len(response.generations) == 2
    for generations in response.generations:
        for generation in generations:
            assert isinstance(generation, ChatGeneration)
            assert isinstance(generation.text, str)
            assert generation.text == generation.message.content


async def test_prem_invoke(chat: ChatPremAI) -> None:
    """Tests chat completion with invoke"""
    result = chat.invoke("How is the weather in New York today?")
    assert isinstance(result.content, str)


def test_prem_streaming() -> None:
    """Test streaming tokens from Prem."""
    chat = ChatPremAI(project_id=8, streaming=True)

    for token in chat.stream("I'm Pickle Rick"):
        assert isinstance(token.content, str)
@@ -0,0 +1,40 @@
"""Test PremAIEmbeddings from PremAI API wrapper.

Note: This test must be run with the PREMAI_API_KEY environment variable set to a valid
API key and a valid project_id. This requires a project set up in PremAI's platform.
You can check it out here: https://app.premai.io
"""

import pytest

from langchain_community.embeddings.premai import PremAIEmbeddings


@pytest.fixture
def embedder() -> PremAIEmbeddings:
    return PremAIEmbeddings(project_id=8, model="text-embedding-3-small")


def test_prem_embedding_documents(embedder: PremAIEmbeddings) -> None:
    """Test Prem embeddings."""
    documents = ["foo bar"]
    output = embedder.embed_documents(documents)
    assert len(output) == 1
    assert len(output[0]) == 1536


def test_prem_embedding_documents_multiple(embedder: PremAIEmbeddings) -> None:
    """Test Prem embeddings for multiple queries or documents."""
    documents = ["foo bar", "bar foo", "foo"]
    output = embedder.embed_documents(documents)
    assert len(output) == 3
    assert len(output[0]) == 1536
    assert len(output[1]) == 1536
    assert len(output[2]) == 1536


def test_prem_embedding_query(embedder: PremAIEmbeddings) -> None:
    """Test Prem embeddings for single query"""
    document = "foo bar"
    output = embedder.embed_query(document)
    assert len(output) == 1536
@@ -44,6 +44,7 @@ EXPECTED_ALL = [
     "ChatPerplexity",
     "ChatKinetica",
     "ChatFriendli",
+    "ChatPremAI",
 ]
47  libs/community/tests/unit_tests/chat_models/test_premai.py  Normal file
@@ -0,0 +1,47 @@
"""Test PremChat model"""

import pytest
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.pydantic_v1 import SecretStr
from pytest import CaptureFixture

from langchain_community.chat_models import ChatPremAI
from langchain_community.chat_models.premai import _messages_to_prompt_dict


@pytest.mark.requires("premai")
def test_api_key_is_string() -> None:
    llm = ChatPremAI(premai_api_key="secret-api-key", project_id=8)
    assert isinstance(llm.premai_api_key, SecretStr)


@pytest.mark.requires("premai")
def test_api_key_masked_when_passed_via_constructor(
    capsys: CaptureFixture,
) -> None:
    llm = ChatPremAI(premai_api_key="secret-api-key", project_id=8)
    print(llm.premai_api_key, end="")  # noqa: T201
    captured = capsys.readouterr()

    assert captured.out == "**********"


def test_messages_to_prompt_dict_with_valid_messages() -> None:
    system_message, result = _messages_to_prompt_dict(
        [
            SystemMessage(content="System Prompt"),
            HumanMessage(content="User message #1"),
            AIMessage(content="AI message #1"),
            HumanMessage(content="User message #2"),
            AIMessage(content="AI message #2"),
        ]
    )
    expected = [
        {"role": "user", "content": "User message #1"},
        {"role": "assistant", "content": "AI message #1"},
        {"role": "user", "content": "User message #2"},
        {"role": "assistant", "content": "AI message #2"},
    ]

    assert system_message == "System Prompt"
    assert result == expected
@@ -65,6 +65,7 @@ EXPECTED_ALL = [
     "QuantizedBiEncoderEmbeddings",
     "NeMoEmbeddings",
     "SparkLLMTextEmbeddings",
+    "PremAIEmbeddings",
     "YandexGPTEmbeddings",
 ]
28  libs/community/tests/unit_tests/embeddings/test_premai.py  Normal file
@@ -0,0 +1,28 @@
"""Test PremAIEmbeddings embeddings"""

import pytest
from langchain_core.pydantic_v1 import SecretStr
from pytest import CaptureFixture

from langchain_community.embeddings import PremAIEmbeddings


@pytest.mark.requires("premai")
def test_api_key_is_string() -> None:
    llm = PremAIEmbeddings(
        premai_api_key="secret-api-key", project_id=8, model="fake-model"
    )
    assert isinstance(llm.premai_api_key, SecretStr)


@pytest.mark.requires("premai")
def test_api_key_masked_when_passed_via_constructor(
    capsys: CaptureFixture,
) -> None:
    llm = PremAIEmbeddings(
        premai_api_key="secret-api-key", project_id=8, model="fake-model"
    )
    print(llm.premai_api_key, end="")  # noqa: T201
    captured = capsys.readouterr()

    assert captured.out == "**********"