Mirror of https://github.com/hwchase17/langchain.git (synced 2026-01-23 05:09:12 +00:00)

Compare commits: sr/agent-i...erick/comm (18 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 058239cdd2 | |
| | c732db1173 | |
| | c935b9e025 | |
| | a7a1690845 | |
| | 5964630e7b | |
| | 0903f1d3e0 | |
| | bd9d6d314a | |
| | f62ac4e98c | |
| | ace0a27473 | |
| | f451e2dfc2 | |
| | 126cb87b30 | |
| | 9979a36392 | |
| | 1484c356e9 | |
| | 3c9e133c70 | |
| | 35e3d3fa39 | |
| | ff2f6f398d | |
| | 0ede50c344 | |
| | 983c2a2f88 | |
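Read together, the hunks below appear to be one sweep of lint and type-checking fixes: most pairs simply append `# type: ignore` where mypy, checking against openai>=1.0 type stubs, flags attributes such as `openai.ChatCompletion` and `openai.error.*` that only exist in openai<1.0; the rest add missing parameter and return annotations, reflow imports, and fix one genuine Makefile bug. A couple of docstring pairs appear to differ only in trailing whitespace.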
@@ -41,7 +41,7 @@ lint lint_diff lint_package lint_tests:
 	poetry run ruff .
 	[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
 	[ "$(PYTHON_FILES)" = "" ] || poetry run ruff --select I $(PYTHON_FILES)
-	[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) || poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
+	[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)

 format format_diff:
 	poetry run ruff format $(PYTHON_FILES)
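The Makefile change is a real bug fix, not style: `||` and `&&` chain left to right with equal precedence, so the old recipe `test || mkdir || mypy` ran mypy only when the file-list test failed *and* `mkdir -p` also failed. Since `mkdir -p` virtually always succeeds, the lint target silently skipped mypy. With `test || mkdir && mypy`, a non-empty file list leads to `mkdir` and then, on its success, to the actual type check.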
@@ -4,9 +4,17 @@ from __future__ import annotations
 import warnings
 from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Sequence, Union

-from langchain_core.messages import AIMessage, SystemMessage
-from langchain_core.prompts import BasePromptTemplate, PromptTemplate
+from langchain_core.messages import (
+    AIMessage,
+    BaseMessage,
+    SystemMessage,
+)
+from langchain_core.prompts import (
+    BasePromptTemplate,
+    PromptTemplate,
+)
 from langchain_core.prompts.chat import (
     BaseMessagePromptTemplate,
     ChatPromptTemplate,
     HumanMessagePromptTemplate,
     MessagesPlaceholder,

@@ -132,6 +140,7 @@ def create_sql_agent(
     toolkit = toolkit or SQLDatabaseToolkit(llm=llm, db=db)
     agent_type = agent_type or AgentType.ZERO_SHOT_REACT_DESCRIPTION
     tools = toolkit.get_tools() + list(extra_tools)
+    prefix = prefix or ""
     if prompt is None:
         prefix = prefix or SQL_PREFIX
         prefix = prefix.format(dialect=toolkit.dialect, top_k=top_k)

@@ -152,6 +161,8 @@ def create_sql_agent(
         tool for tool in tools if not isinstance(tool, ListSQLDatabaseTool)
     ]

+    messages: List[Union[BaseMessage, BaseMessagePromptTemplate]] = []
+
     if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION:
         if prompt is None:
             from langchain.agents.mrkl import prompt as react_prompt
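The new `messages` annotation in the last hunk is the standard fix for an empty-list accumulator. A minimal sketch of the mypy behavior it addresses, with stand-in classes:

```python
from typing import List, Union

class Plain: ...
class Template: ...

# mypy can't infer an element type for a bare `[]` (error: "Need type
# annotation"), and inferring from the first append would be too narrow,
# so the diff declares the union up front before appending either kind.
messages: List[Union[Plain, Template]] = []
messages.append(Plain())
messages.append(Template())
```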
@@ -931,7 +931,7 @@ class CassandraCache(BaseCache):
         self.table_name = table_name
         self.ttl_seconds = ttl_seconds

-        self.kv_cache = ElasticCassandraTable(
+        self.kv_cache: Any = ElasticCassandraTable(
             session=self.session,
             keyspace=self.keyspace,
             table=self.table_name,

@@ -1067,7 +1067,7 @@ class CassandraSemanticCache(BaseCache):
         self._get_embedding = _cache_embedding
         self.embedding_dimension = self._get_embedding_dimension()

-        self.table = MetadataVectorCassandraTable(
+        self.table: Any = MetadataVectorCassandraTable(
             session=self.session,
             keyspace=self.keyspace,
             table=self.table_name,
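Annotating `kv_cache` and `table` as `Any` is a common escape hatch when the value comes from a third-party package (presumably cassio's table classes here) that mypy sees as untyped. A hedged sketch with a stand-in class:

```python
from typing import Any

class ThirdPartyTable:
    """Stand-in for an untyped third-party table class."""

class CassandraLikeCache:
    def __init__(self) -> None:
        # Annotating the attribute as `Any` (as the diff does) keeps mypy
        # from inferring a type it can't verify, and from flagging every
        # later attribute access on the object.
        self.kv_cache: Any = ThirdPartyTable()
```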
@@ -117,13 +117,13 @@ class ChatAnyscale(ChatOpenAI):
                 "ANYSCALE_API_KEY",
             )
         )
-        values["openai_api_base"] = get_from_dict_or_env(
+        values["openai_api_base"] = get_from_dict_or_env(  # type: ignore
             values,
             "anyscale_api_base",
             "ANYSCALE_API_BASE",
             default=DEFAULT_API_BASE,
         )
-        values["openai_proxy"] = get_from_dict_or_env(
+        values["openai_proxy"] = get_from_dict_or_env(  # type: ignore
             values,
             "anyscale_proxy",
             "ANYSCALE_PROXY",

@@ -141,7 +141,7 @@ class ChatAnyscale(ChatOpenAI):
             if is_openai_v1():
                 client_params = {
                     "api_key": values["openai_api_key"],
-                    "base_url": values["openai_api_base"],
+                    "base_url": values["openai_api_base"],  # type: ignore
                     # To do: future support
                     # "organization": values["openai_organization"],
                     # "timeout": values["request_timeout"],

@@ -152,7 +152,7 @@ class ChatAnyscale(ChatOpenAI):
                 }
                 values["client"] = openai.OpenAI(**client_params).chat.completions
             else:
-                values["client"] = openai.ChatCompletion
+                values["client"] = openai.ChatCompletion  # type: ignore
         except AttributeError as exc:
             raise ValueError(
                 "`openai` has no `ChatCompletion` attribute, this is likely "

@@ -167,7 +167,7 @@ class ChatAnyscale(ChatOpenAI):

         available_models = cls.get_available_models(
             values["openai_api_key"],
-            values["openai_api_base"],
+            values["openai_api_base"],  # type: ignore
         )

         if model_name not in available_models:
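The recurring `values["client"] = openai.ChatCompletion  # type: ignore` pattern exists because openai>=1.0 removed the module-level resource classes, so the attribute is only reachable on the legacy package and mypy flags it against the v1 stubs. A minimal sketch of the version split being papered over:

```python
import openai

if hasattr(openai, "ChatCompletion"):  # openai < 1.0: legacy resource class
    client = openai.ChatCompletion  # type: ignore[attr-defined]
else:  # openai >= 1.0: client object instead of module-level resources
    client = openai.OpenAI(api_key="sk-...").chat.completions
```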
@@ -96,7 +96,7 @@ class AzureChatOpenAI(ChatOpenAI):
     openai_api_type: str = ""
     """Legacy, for openai<1.0.0 support."""
     validate_base_url: bool = True
-    """For backwards compatibility. If legacy val openai_api_base is passed in, try to
+    """For backwards compatibility. If legacy val openai_api_base is passed in, try to
     infer if it is a base_url or azure_endpoint and update accordingly.
     """

@@ -121,7 +121,7 @@ class AzureChatOpenAI(ChatOpenAI):
             or os.getenv("AZURE_OPENAI_API_KEY")
             or os.getenv("OPENAI_API_KEY")
         )
-        values["openai_api_base"] = values["openai_api_base"] or os.getenv(
+        values["openai_api_base"] = values["openai_api_base"] or os.getenv(  # type: ignore
             "OPENAI_API_BASE"
         )
         values["openai_api_version"] = values["openai_api_version"] or os.getenv(

@@ -143,8 +143,11 @@ class AzureChatOpenAI(ChatOpenAI):
         values["openai_api_type"] = get_from_dict_or_env(
             values, "openai_api_type", "OPENAI_API_TYPE", default="azure"
         )
-        values["openai_proxy"] = get_from_dict_or_env(
-            values, "openai_proxy", "OPENAI_PROXY", default=""
+        values["openai_proxy"] = get_from_dict_or_env(  # type: ignore
+            values,
+            "openai_proxy",
+            "OPENAI_PROXY",
+            default="",  # type: ignore
         )

         try:
@@ -157,37 +160,37 @@ class AzureChatOpenAI(ChatOpenAI):
             )
         if is_openai_v1():
             # For backwards compatibility. Before openai v1, no distinction was made
-            # between azure_endpoint and base_url (openai_api_base).
-            openai_api_base = values["openai_api_base"]
-            if openai_api_base and values["validate_base_url"]:
-                if "/openai" not in openai_api_base:
-                    values["openai_api_base"] = (
-                        values["openai_api_base"].rstrip("/") + "/openai"
+            # between azure_endpoint and base_url (openai_api_base).  # type: ignore
+            openai_api_base = values["openai_api_base"]  # type: ignore
+            if openai_api_base and values["validate_base_url"]:  # type: ignore
+                if "/openai" not in openai_api_base:  # type: ignore
+                    values["openai_api_base"] = (  # type: ignore
+                        values["openai_api_base"].rstrip("/") + "/openai"  # type: ignore
                     )
                     warnings.warn(
                         "As of openai>=1.0.0, Azure endpoints should be specified via "
-                        f"the `azure_endpoint` param not `openai_api_base` "
-                        f"(or alias `base_url`). Updating `openai_api_base` from "
-                        f"{openai_api_base} to {values['openai_api_base']}."
+                        f"the `azure_endpoint` param not `openai_api_base` "  # type: ignore
+                        f"(or alias `base_url`). Updating `openai_api_base` from "  # type: ignore
+                        f"{openai_api_base} to {values['openai_api_base']}."  # type: ignore
                     )
             if values["deployment_name"]:
                 warnings.warn(
                     "As of openai>=1.0.0, if `deployment_name` (or alias "
                     "`azure_deployment`) is specified then "
-                    "`openai_api_base` (or alias `base_url`) should not be. "
+                    "`openai_api_base` (or alias `base_url`) should not be. "  # type: ignore
                     "Instead use `deployment_name` (or alias `azure_deployment`) "
                     "and `azure_endpoint`."
                 )
-                if values["deployment_name"] not in values["openai_api_base"]:
+                if values["deployment_name"] not in values["openai_api_base"]:  # type: ignore
                     warnings.warn(
-                        "As of openai>=1.0.0, if `openai_api_base` "
+                        "As of openai>=1.0.0, if `openai_api_base` "  # type: ignore
                         "(or alias `base_url`) is specified it is expected to be "
                         "of the form "
                         "https://example-resource.azure.openai.com/openai/deployments/example-deployment. "  # noqa: E501
-                        f"Updating {openai_api_base} to "
-                        f"{values['openai_api_base']}."
+                        f"Updating {openai_api_base} to "  # type: ignore
+                        f"{values['openai_api_base']}."  # type: ignore
                     )
-                    values["openai_api_base"] += (
+                    values["openai_api_base"] += (  # type: ignore
                         "/deployments/" + values["deployment_name"]
                     )
                     values["deployment_name"] = None

@@ -199,7 +202,7 @@ class AzureChatOpenAI(ChatOpenAI):
                 "azure_ad_token": values["azure_ad_token"],
                 "azure_ad_token_provider": values["azure_ad_token_provider"],
                 "organization": values["openai_organization"],
-                "base_url": values["openai_api_base"],
+                "base_url": values["openai_api_base"],  # type: ignore
                 "timeout": values["request_timeout"],
                 "max_retries": values["max_retries"],
                 "default_headers": values["default_headers"],

@@ -211,7 +214,7 @@ class AzureChatOpenAI(ChatOpenAI):
                 **client_params
             ).chat.completions
         else:
-            values["client"] = openai.ChatCompletion
+            values["client"] = openai.ChatCompletion  # type: ignore
         return values

     @property
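The warnings in the large hunk describe the openai v1 migration for Azure: the endpoint and deployment move out of `openai_api_base` and into dedicated parameters. A hedged sketch of the v1-style client the code ultimately builds (endpoint and deployment names are placeholders):

```python
from openai import AzureOpenAI

client = AzureOpenAI(
    azure_endpoint="https://example-resource.openai.azure.com",  # was openai_api_base
    azure_deployment="example-deployment",  # was the /deployments/... URL suffix
    api_version="2023-05-15",
    api_key="...",
)
response = client.chat.completions.create(
    model="example-deployment",
    messages=[{"role": "user", "content": "ping"}],
)
```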
@@ -83,7 +83,7 @@ class ChatEverlyAI(ChatOpenAI):
             "everlyai_api_key",
             "EVERLYAI_API_KEY",
         )
-        values["openai_api_base"] = DEFAULT_API_BASE
+        values["openai_api_base"] = DEFAULT_API_BASE  # type: ignore

         try:
             import openai

@@ -94,7 +94,7 @@ class ChatEverlyAI(ChatOpenAI):
                 "Please install it with `pip install openai`.",
             ) from e
         try:
-            values["client"] = openai.ChatCompletion
+            values["client"] = openai.ChatCompletion  # type: ignore
         except AttributeError as exc:
             raise ValueError(
                 "`openai` has no `ChatCompletion` attribute, this is likely "
@@ -39,7 +39,11 @@ from langchain_community.adapters.openai import (
 from langchain_community.chat_models.openai import _convert_delta_to_message_chunk

 if TYPE_CHECKING:
-    from gpt_router.models import ChunkedGenerationResponse, GenerationResponse
+    from gpt_router.models import (
+        ChunkedGenerationResponse,
+        GenerationResponse,
+        ModelGenerationRequest,
+    )


 logger = logging.getLogger(__name__)

@@ -57,8 +61,8 @@ class GPTRouterModel(BaseModel):


 def get_ordered_generation_requests(
-    models_priority_list: List[GPTRouterModel], **kwargs
-):
+    models_priority_list: List[GPTRouterModel], **kwargs: Any
+) -> List[ModelGenerationRequest]:
     """
     Return the body for the model router input.
     """

@@ -100,7 +104,7 @@ def completion_with_retry(
     models_priority_list: List[GPTRouterModel],
     run_manager: Optional[CallbackManagerForLLMRun] = None,
     **kwargs: Any,
-) -> Union[GenerationResponse, Generator[ChunkedGenerationResponse]]:
+) -> Union[GenerationResponse, Generator[ChunkedGenerationResponse, None, None]]:
     """Use tenacity to retry the completion call."""
     retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

@@ -122,7 +126,7 @@ async def acompletion_with_retry(
     models_priority_list: List[GPTRouterModel],
     run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
     **kwargs: Any,
-) -> Union[GenerationResponse, AsyncGenerator[ChunkedGenerationResponse]]:
+) -> Union[GenerationResponse, AsyncGenerator[ChunkedGenerationResponse, None]]:
     """Use tenacity to retry the async completion call."""

     retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

@@ -283,8 +287,8 @@ class GPTRouter(BaseChatModel):
         return self._create_chat_result(response)

     def _create_chat_generation_chunk(
-        self, data: Mapping[str, Any], default_chunk_class
-    ):
+        self, data: Mapping[str, Any], default_chunk_class: Any
+    ) -> Tuple[ChatGenerationChunk, Any]:
         chunk = _convert_delta_to_message_chunk(
             {"content": data.get("text", "")}, default_chunk_class
         )

@@ -293,8 +297,8 @@ class GPTRouter(BaseChatModel):
             dict(finish_reason=finish_reason) if finish_reason is not None else None
         )
         default_chunk_class = chunk.__class__
-        chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)
-        return chunk, default_chunk_class
+        chunk_ = ChatGenerationChunk(message=chunk, generation_info=generation_info)
+        return chunk_, default_chunk_class

     def _stream(
         self,
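Two of the GPTRouter fixes are pure typing corrections: `typing.Generator` is generic over three parameters (yield, send, return) and `AsyncGenerator` over two, so the bare one-argument forms were invalid; and `chunk_` is introduced because mypy rejects re-binding `chunk` (inferred as a message chunk) to a `ChatGenerationChunk` of a different type. A small self-contained illustration of the generator annotations:

```python
from typing import AsyncGenerator, Generator

def count_up(limit: int) -> Generator[int, None, None]:
    # yields ints, accepts nothing via send(), returns nothing
    yield from range(limit)

async def acount_up(limit: int) -> AsyncGenerator[int, None]:
    # async generators take only (yield type, send type)
    for i in range(limit):
        yield i
```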
@@ -68,11 +68,11 @@ def _create_retry_decorator(llm: JinaChat) -> Callable[[Any], Any]:
         stop=stop_after_attempt(llm.max_retries),
         wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
         retry=(
-            retry_if_exception_type(openai.error.Timeout)
-            | retry_if_exception_type(openai.error.APIError)
-            | retry_if_exception_type(openai.error.APIConnectionError)
-            | retry_if_exception_type(openai.error.RateLimitError)
-            | retry_if_exception_type(openai.error.ServiceUnavailableError)
+            retry_if_exception_type(openai.error.Timeout)  # type: ignore
+            | retry_if_exception_type(openai.error.APIError)  # type: ignore
+            | retry_if_exception_type(openai.error.APIConnectionError)  # type: ignore
+            | retry_if_exception_type(openai.error.RateLimitError)  # type: ignore
+            | retry_if_exception_type(openai.error.ServiceUnavailableError)  # type: ignore
         ),
         before_sleep=before_sleep_log(logger, logging.WARNING),
     )

@@ -232,7 +232,7 @@ class JinaChat(BaseChatModel):
                 "Please install it with `pip install openai`."
             )
         try:
-            values["client"] = openai.ChatCompletion
+            values["client"] = openai.ChatCompletion  # type: ignore
         except AttributeError:
             raise ValueError(
                 "`openai` has no `ChatCompletion` attribute, this is likely "

@@ -264,11 +264,11 @@ class JinaChat(BaseChatModel):
         stop=stop_after_attempt(self.max_retries),
         wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
         retry=(
-            retry_if_exception_type(openai.error.Timeout)
-            | retry_if_exception_type(openai.error.APIError)
-            | retry_if_exception_type(openai.error.APIConnectionError)
-            | retry_if_exception_type(openai.error.RateLimitError)
-            | retry_if_exception_type(openai.error.ServiceUnavailableError)
+            retry_if_exception_type(openai.error.Timeout)  # type: ignore
+            | retry_if_exception_type(openai.error.APIError)  # type: ignore
+            | retry_if_exception_type(openai.error.APIConnectionError)  # type: ignore
+            | retry_if_exception_type(openai.error.RateLimitError)  # type: ignore
+            | retry_if_exception_type(openai.error.ServiceUnavailableError)  # type: ignore
         ),
         before_sleep=before_sleep_log(logger, logging.WARNING),
     )
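The five-way `retry_if_exception_type(...) | ...` expression that keeps acquiring ignores is ordinary tenacity usage; only the `openai.error` attribute lookups are the problem under the v1 stubs. A runnable sketch of the same retry shape with stand-in exception classes:

```python
import logging
from tenacity import (
    before_sleep_log,
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

logger = logging.getLogger(__name__)

class FakeTimeout(Exception): ...
class FakeAPIError(Exception): ...

@retry(
    stop=stop_after_attempt(6),
    wait=wait_exponential(multiplier=1, min=1, max=10),
    # predicates compose with `|`, exactly as in the diffs above
    retry=retry_if_exception_type(FakeTimeout) | retry_if_exception_type(FakeAPIError),
    before_sleep=before_sleep_log(logger, logging.WARNING),
)
def flaky_call() -> str:
    return "ok"
```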
@@ -84,11 +84,11 @@ def _create_retry_decorator(
     import openai

     errors = [
-        openai.error.Timeout,
-        openai.error.APIError,
-        openai.error.APIConnectionError,
-        openai.error.RateLimitError,
-        openai.error.ServiceUnavailableError,
+        openai.error.Timeout,  # type: ignore
+        openai.error.APIError,  # type: ignore
+        openai.error.APIConnectionError,  # type: ignore
+        openai.error.RateLimitError,  # type: ignore
+        openai.error.ServiceUnavailableError,  # type: ignore
     ]
     return create_base_retry_decorator(
         error_types=errors, max_retries=llm.max_retries, run_manager=run_manager

@@ -179,11 +179,11 @@ class ChatOpenAI(BaseChatModel):
         if self.openai_organization:
             attributes["openai_organization"] = self.openai_organization

-        if self.openai_api_base:
-            attributes["openai_api_base"] = self.openai_api_base
+        if self.openai_api_base:  # type: ignore
+            attributes["openai_api_base"] = self.openai_api_base  # type: ignore

-        if self.openai_proxy:
-            attributes["openai_proxy"] = self.openai_proxy
+        if self.openai_proxy:  # type: ignore
+            attributes["openai_proxy"] = self.openai_proxy  # type: ignore

         return attributes
@@ -205,13 +205,13 @@ class ChatOpenAI(BaseChatModel):
     # may assume openai_api_key is a str)
     openai_api_key: Optional[str] = Field(default=None, alias="api_key")
     """Automatically inferred from env var `OPENAI_API_KEY` if not provided."""
-    openai_api_base: Optional[str] = Field(default=None, alias="base_url")
+    openai_api_base: Optional[str] = Field(default=None, alias="base_url")  # type: ignore
     """Base URL path for API requests, leave blank if not using a proxy or service
     emulator."""
     openai_organization: Optional[str] = Field(default=None, alias="organization")
     """Automatically inferred from env var `OPENAI_ORG_ID` if not provided."""
     # to support explicit proxy for OpenAI
-    openai_proxy: Optional[str] = None
+    openai_proxy: Optional[str] = None  # type: ignore
     request_timeout: Union[float, Tuple[float, float], Any, None] = Field(
         default=None, alias="timeout"
     )

@@ -290,12 +290,12 @@ class ChatOpenAI(BaseChatModel):
             or os.getenv("OPENAI_ORG_ID")
             or os.getenv("OPENAI_ORGANIZATION")
         )
-        values["openai_api_base"] = values["openai_api_base"] or os.getenv(
+        values["openai_api_base"] = values["openai_api_base"] or os.getenv(  # type: ignore
             "OPENAI_API_BASE"
         )
-        values["openai_proxy"] = get_from_dict_or_env(
+        values["openai_proxy"] = get_from_dict_or_env(  # type: ignore
             values,
-            "openai_proxy",
+            "openai_proxy",  # type: ignore
             "OPENAI_PROXY",
             default="",
         )

@@ -312,7 +312,7 @@ class ChatOpenAI(BaseChatModel):
             client_params = {
                 "api_key": values["openai_api_key"],
                 "organization": values["openai_organization"],
-                "base_url": values["openai_api_base"],
+                "base_url": values["openai_api_base"],  # type: ignore
                 "timeout": values["request_timeout"],
                 "max_retries": values["max_retries"],
                 "default_headers": values["default_headers"],

@@ -327,7 +327,7 @@ class ChatOpenAI(BaseChatModel):
                 **client_params
             ).chat.completions
         elif not values.get("client"):
-            values["client"] = openai.ChatCompletion
+            values["client"] = openai.ChatCompletion  # type: ignore
         else:
             pass
         return values

@@ -547,14 +547,14 @@ class ChatOpenAI(BaseChatModel):
             openai_creds.update(
                 {
                     "api_key": self.openai_api_key,
-                    "api_base": self.openai_api_base,
+                    "api_base": self.openai_api_base,  # type: ignore
                     "organization": self.openai_organization,
                 }
             )
-        if self.openai_proxy:
+        if self.openai_proxy:  # type: ignore
             import openai

-            openai.proxy = {"http": self.openai_proxy, "https": self.openai_proxy}  # type: ignore[assignment]  # noqa: E501
+            openai.proxy = {"http": self.openai_proxy, "https": self.openai_proxy}  # type: ignore  # noqa: E501
         return {**self._default_params, **openai_creds}

     def _get_invocation_params(
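Note that the last pair above also swaps `# type: ignore[assignment]` for a bare `# type: ignore`: a scoped ignore suppresses only the named error code, while a bare one suppresses everything mypy reports on the line, which matters once a second, differently coded error appears there. A tiny illustration (the type lies are deliberate; it still runs):

```python
x: int = "oops"  # type: ignore[assignment]  # silences only this error code
y: int = "also oops"  # type: ignore  # silences any error on the line
print(x, y)
```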
@@ -35,7 +35,7 @@ class GCSDirectoryLoader(BaseLoader):
     def load(self) -> List[Document]:
         """Load documents."""
         try:
-            from google.cloud import storage
+            from google.cloud import storage  # type: ignore
         except ImportError:
             raise ImportError(
                 "Could not import google-cloud-storage python package. "

@@ -51,7 +51,7 @@ class GCSFileLoader(BaseLoader):
     def load(self) -> List[Document]:
         """Load documents."""
         try:
-            from google.cloud import storage
+            from google.cloud import storage  # type: ignore
         except ImportError:
             raise ImportError(
                 "Could not import google-cloud-storage python package. "

@@ -73,7 +73,7 @@ class OpenAIWhisperParser(BaseBlobParser):
                     model="whisper-1", file=file_obj
                 )
             else:
-                transcript = openai.Audio.transcribe("whisper-1", file_obj)
+                transcript = openai.Audio.transcribe("whisper-1", file_obj)  # type: ignore
             break
         except Exception as e:
             attempts += 1
@@ -27,7 +27,7 @@ class GoogleTranslateTransformer(BaseDocumentTransformer):
         """
         try:
             from google.api_core.client_options import ClientOptions
-            from google.cloud import translate
+            from google.cloud import translate  # type: ignore
         except ImportError as exc:
             raise ImportError(
                 "Install Google Cloud Translate to use this parser."

@@ -70,7 +70,7 @@ class GoogleTranslateTransformer(BaseDocumentTransformer):
             Options: `text/plain`, `text/html`
         """
         try:
-            from google.cloud import translate
+            from google.cloud import translate  # type: ignore
         except ImportError as exc:
             raise ImportError(
                 "Install Google Cloud Translate to use this parser."
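The `from google.cloud import ...  # type: ignore` lines follow from the `google.cloud` namespace packages shipping without type stubs or a `py.typed` marker, so mypy reports the imports as untyped; the inline ignore is the per-line alternative to enabling `ignore_missing_imports` for those modules in the mypy configuration.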
@@ -64,7 +64,7 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
             or os.getenv("AZURE_OPENAI_API_KEY")
             or os.getenv("OPENAI_API_KEY")
         )
-        values["openai_api_base"] = values["openai_api_base"] or os.getenv(
+        values["openai_api_base"] = values["openai_api_base"] or os.getenv(  # type: ignore
            "OPENAI_API_BASE"
         )
         values["openai_api_version"] = values["openai_api_version"] or os.getenv(

@@ -78,9 +78,9 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
             or os.getenv("OPENAI_ORG_ID")
             or os.getenv("OPENAI_ORGANIZATION")
         )
-        values["openai_proxy"] = get_from_dict_or_env(
+        values["openai_proxy"] = get_from_dict_or_env(  # type: ignore
             values,
-            "openai_proxy",
+            "openai_proxy",  # type: ignore
             "OPENAI_PROXY",
             default="",
         )

@@ -104,35 +104,35 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
         )
         if is_openai_v1():
             # For backwards compatibility. Before openai v1, no distinction was made
-            # between azure_endpoint and base_url (openai_api_base).
-            openai_api_base = values["openai_api_base"]
-            if openai_api_base and values["validate_base_url"]:
-                if "/openai" not in openai_api_base:
-                    values["openai_api_base"] += "/openai"
+            # between azure_endpoint and base_url (openai_api_base).  # type: ignore
+            openai_api_base = values["openai_api_base"]  # type: ignore
+            if openai_api_base and values["validate_base_url"]:  # type: ignore
+                if "/openai" not in openai_api_base:  # type: ignore
+                    values["openai_api_base"] += "/openai"  # type: ignore
                     warnings.warn(
                         "As of openai>=1.0.0, Azure endpoints should be specified via "
-                        f"the `azure_endpoint` param not `openai_api_base` "
-                        f"(or alias `base_url`). Updating `openai_api_base` from "
-                        f"{openai_api_base} to {values['openai_api_base']}."
+                        f"the `azure_endpoint` param not `openai_api_base` "  # type: ignore
+                        f"(or alias `base_url`). Updating `openai_api_base` from "  # type: ignore
+                        f"{openai_api_base} to {values['openai_api_base']}."  # type: ignore
                     )
             if values["deployment"]:
                 warnings.warn(
                     "As of openai>=1.0.0, if `deployment` (or alias "
                     "`azure_deployment`) is specified then "
-                    "`openai_api_base` (or alias `base_url`) should not be. "
+                    "`openai_api_base` (or alias `base_url`) should not be. "  # type: ignore
                     "Instead use `deployment` (or alias `azure_deployment`) "
                     "and `azure_endpoint`."
                 )
-                if values["deployment"] not in values["openai_api_base"]:
+                if values["deployment"] not in values["openai_api_base"]:  # type: ignore
                     warnings.warn(
-                        "As of openai>=1.0.0, if `openai_api_base` "
+                        "As of openai>=1.0.0, if `openai_api_base` "  # type: ignore
                         "(or alias `base_url`) is specified it is expected to be "
                         "of the form "
                         "https://example-resource.azure.openai.com/openai/deployments/example-deployment. "  # noqa: E501
-                        f"Updating {openai_api_base} to "
-                        f"{values['openai_api_base']}."
+                        f"Updating {openai_api_base} to "  # type: ignore
+                        f"{values['openai_api_base']}."  # type: ignore
                     )
-                    values["openai_api_base"] += (
+                    values["openai_api_base"] += (  # type: ignore
                         "/deployments/" + values["deployment"]
                     )
                     values["deployment"] = None

@@ -144,7 +144,7 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
                 "azure_ad_token": values["azure_ad_token"],
                 "azure_ad_token_provider": values["azure_ad_token_provider"],
                 "organization": values["openai_organization"],
-                "base_url": values["openai_api_base"],
+                "base_url": values["openai_api_base"],  # type: ignore
                 "timeout": values["request_timeout"],
                 "max_retries": values["max_retries"],
                 "default_headers": values["default_headers"],

@@ -154,7 +154,7 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
             values["client"] = openai.AzureOpenAI(**client_params).embeddings
             values["async_client"] = openai.AsyncAzureOpenAI(**client_params).embeddings
         else:
-            values["client"] = openai.Embedding
+            values["client"] = openai.Embedding  # type: ignore
         return values

     @property
@@ -42,11 +42,11 @@ def _create_retry_decorator(embeddings: LocalAIEmbeddings) -> Callable[[Any], An
         stop=stop_after_attempt(embeddings.max_retries),
         wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
         retry=(
-            retry_if_exception_type(openai.error.Timeout)
-            | retry_if_exception_type(openai.error.APIError)
-            | retry_if_exception_type(openai.error.APIConnectionError)
-            | retry_if_exception_type(openai.error.RateLimitError)
-            | retry_if_exception_type(openai.error.ServiceUnavailableError)
+            retry_if_exception_type(openai.error.Timeout)  # type: ignore
+            | retry_if_exception_type(openai.error.APIError)  # type: ignore
+            | retry_if_exception_type(openai.error.APIConnectionError)  # type: ignore
+            | retry_if_exception_type(openai.error.RateLimitError)  # type: ignore
+            | retry_if_exception_type(openai.error.ServiceUnavailableError)  # type: ignore
         ),
         before_sleep=before_sleep_log(logger, logging.WARNING),
     )

@@ -64,11 +64,11 @@ def _async_retry_decorator(embeddings: LocalAIEmbeddings) -> Any:
         stop=stop_after_attempt(embeddings.max_retries),
         wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
         retry=(
-            retry_if_exception_type(openai.error.Timeout)
-            | retry_if_exception_type(openai.error.APIError)
-            | retry_if_exception_type(openai.error.APIConnectionError)
-            | retry_if_exception_type(openai.error.RateLimitError)
-            | retry_if_exception_type(openai.error.ServiceUnavailableError)
+            retry_if_exception_type(openai.error.Timeout)  # type: ignore
+            | retry_if_exception_type(openai.error.APIError)  # type: ignore
+            | retry_if_exception_type(openai.error.APIConnectionError)  # type: ignore
+            | retry_if_exception_type(openai.error.RateLimitError)  # type: ignore
+            | retry_if_exception_type(openai.error.ServiceUnavailableError)  # type: ignore
         ),
         before_sleep=before_sleep_log(logger, logging.WARNING),
     )

@@ -89,7 +89,7 @@ def _check_response(response: dict) -> dict:
     if any(len(d["embedding"]) == 1 for d in response["data"]):
         import openai

-        raise openai.error.APIError("LocalAI API returned an empty embedding")
+        raise openai.error.APIError("LocalAI API returned an empty embedding")  # type: ignore
     return response


@@ -132,7 +132,7 @@ class LocalAIEmbeddings(BaseModel, Embeddings):
             from langchain_community.embeddings import LocalAIEmbeddings
             openai = LocalAIEmbeddings(
                 openai_api_key="random-string",
-                openai_api_base="http://localhost:8080"
+                openai_api_base="http://localhost:8080"  # type: ignore
             )

     """

@@ -141,9 +141,9 @@ class LocalAIEmbeddings(BaseModel, Embeddings):
     model: str = "text-embedding-ada-002"
     deployment: str = model
     openai_api_version: Optional[str] = None
-    openai_api_base: Optional[str] = None
+    openai_api_base: Optional[str] = None  # type: ignore
     # to support explicit proxy for LocalAI
-    openai_proxy: Optional[str] = None
+    openai_proxy: Optional[str] = None  # type: ignore
     embedding_ctx_length: int = 8191
     """The maximum number of tokens to embed at once."""
     openai_api_key: Optional[str] = None

@@ -199,15 +199,15 @@ class LocalAIEmbeddings(BaseModel, Embeddings):
         values["openai_api_key"] = get_from_dict_or_env(
             values, "openai_api_key", "OPENAI_API_KEY"
         )
-        values["openai_api_base"] = get_from_dict_or_env(
+        values["openai_api_base"] = get_from_dict_or_env(  # type: ignore
             values,
-            "openai_api_base",
+            "openai_api_base",  # type: ignore
             "OPENAI_API_BASE",
             default="",
         )
-        values["openai_proxy"] = get_from_dict_or_env(
+        values["openai_proxy"] = get_from_dict_or_env(  # type: ignore
             values,
-            "openai_proxy",
+            "openai_proxy",  # type: ignore
             "OPENAI_PROXY",
             default="",
         )

@@ -228,7 +228,7 @@ class LocalAIEmbeddings(BaseModel, Embeddings):
         try:
             import openai

-            values["client"] = openai.Embedding
+            values["client"] = openai.Embedding  # type: ignore
         except ImportError:
             raise ImportError(
                 "Could not import openai python package. "

@@ -244,16 +244,16 @@ class LocalAIEmbeddings(BaseModel, Embeddings):
             "headers": self.headers,
             "api_key": self.openai_api_key,
             "organization": self.openai_organization,
-            "api_base": self.openai_api_base,
+            "api_base": self.openai_api_base,  # type: ignore
             "api_version": self.openai_api_version,
             **self.model_kwargs,
         }
-        if self.openai_proxy:
+        if self.openai_proxy:  # type: ignore
             import openai

-            openai.proxy = {
-                "http": self.openai_proxy,
-                "https": self.openai_proxy,
+            openai.proxy = {  # type: ignore
+                "http": self.openai_proxy,  # type: ignore
+                "https": self.openai_proxy,  # type: ignore
             }  # type: ignore[assignment]  # noqa: E501
         return openai_args
@@ -54,11 +54,11 @@ def _create_retry_decorator(embeddings: OpenAIEmbeddings) -> Callable[[Any], Any
             max=embeddings.retry_max_seconds,
         ),
         retry=(
-            retry_if_exception_type(openai.error.Timeout)
-            | retry_if_exception_type(openai.error.APIError)
-            | retry_if_exception_type(openai.error.APIConnectionError)
-            | retry_if_exception_type(openai.error.RateLimitError)
-            | retry_if_exception_type(openai.error.ServiceUnavailableError)
+            retry_if_exception_type(openai.error.Timeout)  # type: ignore
+            | retry_if_exception_type(openai.error.APIError)  # type: ignore
+            | retry_if_exception_type(openai.error.APIConnectionError)  # type: ignore
+            | retry_if_exception_type(openai.error.RateLimitError)  # type: ignore
+            | retry_if_exception_type(openai.error.ServiceUnavailableError)  # type: ignore
         ),
         before_sleep=before_sleep_log(logger, logging.WARNING),
     )

@@ -81,11 +81,11 @@ def _async_retry_decorator(embeddings: OpenAIEmbeddings) -> Any:
             max=embeddings.retry_max_seconds,
         ),
         retry=(
-            retry_if_exception_type(openai.error.Timeout)
-            | retry_if_exception_type(openai.error.APIError)
-            | retry_if_exception_type(openai.error.APIConnectionError)
-            | retry_if_exception_type(openai.error.RateLimitError)
-            | retry_if_exception_type(openai.error.ServiceUnavailableError)
+            retry_if_exception_type(openai.error.Timeout)  # type: ignore
+            | retry_if_exception_type(openai.error.APIError)  # type: ignore
+            | retry_if_exception_type(openai.error.APIConnectionError)  # type: ignore
+            | retry_if_exception_type(openai.error.RateLimitError)  # type: ignore
+            | retry_if_exception_type(openai.error.ServiceUnavailableError)  # type: ignore
         ),
         before_sleep=before_sleep_log(logger, logging.WARNING),
     )

@@ -106,7 +106,7 @@ def _check_response(response: dict, skip_empty: bool = False) -> dict:
     if any(len(d["embedding"]) == 1 for d in response["data"]) and not skip_empty:
         import openai

-        raise openai.error.APIError("OpenAI API returned an empty embedding")
+        raise openai.error.APIError("OpenAI API returned an empty embedding")  # type: ignore
     return response


@@ -194,13 +194,13 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
     openai_api_version: Optional[str] = Field(default=None, alias="api_version")
     """Automatically inferred from env var `OPENAI_API_VERSION` if not provided."""
     # to support Azure OpenAI Service custom endpoints
-    openai_api_base: Optional[str] = Field(default=None, alias="base_url")
+    openai_api_base: Optional[str] = Field(default=None, alias="base_url")  # type: ignore
     """Base URL path for API requests, leave blank if not using a proxy or service
     emulator."""
     # to support Azure OpenAI Service custom endpoints
     openai_api_type: Optional[str] = None
     # to support explicit proxy for OpenAI
-    openai_proxy: Optional[str] = None
+    openai_proxy: Optional[str] = None  # type: ignore
     embedding_ctx_length: int = 8191
     """The maximum number of tokens to embed at once."""
     openai_api_key: Optional[str] = Field(default=None, alias="api_key")

@@ -288,7 +288,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
         values["openai_api_key"] = get_from_dict_or_env(
             values, "openai_api_key", "OPENAI_API_KEY"
         )
-        values["openai_api_base"] = values["openai_api_base"] or os.getenv(
+        values["openai_api_base"] = values["openai_api_base"] or os.getenv(  # type: ignore
             "OPENAI_API_BASE"
         )
         values["openai_api_type"] = get_from_dict_or_env(

@@ -297,9 +297,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
             "OPENAI_API_TYPE",
             default="",
         )
-        values["openai_proxy"] = get_from_dict_or_env(
+        values["openai_proxy"] = get_from_dict_or_env(  # type: ignore
             values,
-            "openai_proxy",
+            "openai_proxy",  # type: ignore
             "OPENAI_PROXY",
             default="",
         )

@@ -340,7 +340,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
             client_params = {
                 "api_key": values["openai_api_key"],
                 "organization": values["openai_organization"],
-                "base_url": values["openai_api_base"],
+                "base_url": values["openai_api_base"],  # type: ignore
                 "timeout": values["request_timeout"],
                 "max_retries": values["max_retries"],
                 "default_headers": values["default_headers"],

@@ -354,7 +354,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
                 **client_params
             ).embeddings
         elif not values.get("client"):
-            values["client"] = openai.Embedding
+            values["client"] = openai.Embedding  # type: ignore
         else:
             pass
         return values

@@ -370,7 +370,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
             "headers": self.headers,
             "api_key": self.openai_api_key,
             "organization": self.openai_organization,
-            "api_base": self.openai_api_base,
+            "api_base": self.openai_api_base,  # type: ignore
             "api_type": self.openai_api_type,
             "api_version": self.openai_api_version,
             **self.model_kwargs,

@@ -378,7 +378,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
         if self.openai_api_type in ("azure", "azure_ad", "azuread"):
             openai_args["engine"] = self.deployment
         # TODO: Look into proxy with openai v1.
-        if self.openai_proxy:
+        if self.openai_proxy:  # type: ignore
             try:
                 import openai
             except ImportError:

@@ -387,9 +387,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
                     "Please install it with `pip install openai`."
                 )

-            openai.proxy = {
-                "http": self.openai_proxy,
-                "https": self.openai_proxy,
+            openai.proxy = {  # type: ignore
+                "http": self.openai_proxy,  # type: ignore
+                "https": self.openai_proxy,  # type: ignore
             }  # type: ignore[assignment]  # noqa: E501
         return openai_args
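The repeated `openai.proxy = {...}` pairs touch a pre-1.0 mechanism: the legacy SDK read proxy settings from a module-level attribute that no longer exists in the v1 stubs, hence the ignores (and the widening from `# type: ignore[assignment]` to a bare ignore). In openai>=1.0 the equivalent is passing a configured `http_client` to the client constructor, which is presumably what the `# TODO: Look into proxy with openai v1.` comment above refers to.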
@@ -107,7 +107,7 @@ class Anyscale(BaseOpenAI):
             import openai

             ## Always create ChatComplete client, replacing the legacy Complete client
-            values["client"] = openai.ChatCompletion
+            values["client"] = openai.ChatCompletion  # type: ignore
         except ImportError:
             raise ImportError(
                 "Could not import openai python package. "

@@ -97,8 +97,8 @@ class GooseAI(LLM):
             import openai

             openai.api_key = gooseai_api_key.get_secret_value()
-            openai.api_base = "https://api.goose.ai/v1"
-            values["client"] = openai.Completion
+            openai.api_base = "https://api.goose.ai/v1"  # type: ignore
+            values["client"] = openai.Completion  # type: ignore
         except ImportError:
             raise ImportError(
                 "Could not import openai python package. "
@@ -94,11 +94,11 @@ def _create_retry_decorator(
     import openai

     errors = [
-        openai.error.Timeout,
-        openai.error.APIError,
-        openai.error.APIConnectionError,
-        openai.error.RateLimitError,
-        openai.error.ServiceUnavailableError,
+        openai.error.Timeout,  # type: ignore
+        openai.error.APIError,  # type: ignore
+        openai.error.APIConnectionError,  # type: ignore
+        openai.error.RateLimitError,  # type: ignore
+        openai.error.ServiceUnavailableError,  # type: ignore
     ]
     return create_base_retry_decorator(
         error_types=errors, max_retries=llm.max_retries, run_manager=run_manager

@@ -157,14 +157,14 @@ class BaseOpenAI(BaseLLM):
     @property
     def lc_attributes(self) -> Dict[str, Any]:
         attributes: Dict[str, Any] = {}
-        if self.openai_api_base:
-            attributes["openai_api_base"] = self.openai_api_base
+        if self.openai_api_base:  # type: ignore
+            attributes["openai_api_base"] = self.openai_api_base  # type: ignore

         if self.openai_organization:
             attributes["openai_organization"] = self.openai_organization

-        if self.openai_proxy:
-            attributes["openai_proxy"] = self.openai_proxy
+        if self.openai_proxy:  # type: ignore
+            attributes["openai_proxy"] = self.openai_proxy  # type: ignore

         return attributes

@@ -199,13 +199,13 @@ class BaseOpenAI(BaseLLM):
     # may assume openai_api_key is a str)
     openai_api_key: Optional[str] = Field(default=None, alias="api_key")
     """Automatically inferred from env var `OPENAI_API_KEY` if not provided."""
-    openai_api_base: Optional[str] = Field(default=None, alias="base_url")
+    openai_api_base: Optional[str] = Field(default=None, alias="base_url")  # type: ignore
     """Base URL path for API requests, leave blank if not using a proxy or service
     emulator."""
     openai_organization: Optional[str] = Field(default=None, alias="organization")
     """Automatically inferred from env var `OPENAI_ORG_ID` if not provided."""
     # to support explicit proxy for OpenAI
-    openai_proxy: Optional[str] = None
+    openai_proxy: Optional[str] = None  # type: ignore
     batch_size: int = 20
     """Batch size to use when passing multiple documents to generate."""
     request_timeout: Union[float, Tuple[float, float], Any, None] = Field(

@@ -282,12 +282,12 @@ class BaseOpenAI(BaseLLM):
         values["openai_api_key"] = get_from_dict_or_env(
             values, "openai_api_key", "OPENAI_API_KEY"
         )
-        values["openai_api_base"] = values["openai_api_base"] or os.getenv(
+        values["openai_api_base"] = values["openai_api_base"] or os.getenv(  # type: ignore
             "OPENAI_API_BASE"
         )
-        values["openai_proxy"] = get_from_dict_or_env(
+        values["openai_proxy"] = get_from_dict_or_env(  # type: ignore
             values,
-            "openai_proxy",
+            "openai_proxy",  # type: ignore
             "OPENAI_PROXY",
             default="",
         )

@@ -308,7 +308,7 @@ class BaseOpenAI(BaseLLM):
             client_params = {
                 "api_key": values["openai_api_key"],
                 "organization": values["openai_organization"],
-                "base_url": values["openai_api_base"],
+                "base_url": values["openai_api_base"],  # type: ignore
                 "timeout": values["request_timeout"],
                 "max_retries": values["max_retries"],
                 "default_headers": values["default_headers"],

@@ -320,7 +320,7 @@ class BaseOpenAI(BaseLLM):
         if not values.get("async_client"):
             values["async_client"] = openai.AsyncOpenAI(**client_params).completions
         elif not values.get("client"):
-            values["client"] = openai.Completion
+            values["client"] = openai.Completion  # type: ignore
         else:
             pass

@@ -597,14 +597,14 @@ class BaseOpenAI(BaseLLM):
             openai_creds.update(
                 {
                     "api_key": self.openai_api_key,
-                    "api_base": self.openai_api_base,
+                    "api_base": self.openai_api_base,  # type: ignore
                     "organization": self.openai_organization,
                 }
             )
-        if self.openai_proxy:
+        if self.openai_proxy:  # type: ignore
             import openai

-            openai.proxy = {"http": self.openai_proxy, "https": self.openai_proxy}  # type: ignore[assignment]  # noqa: E501
+            openai.proxy = {"http": self.openai_proxy, "https": self.openai_proxy}  # type: ignore  # noqa: E501
         return {**openai_creds, **self._default_params}

     @property
@@ -807,7 +807,7 @@ class AzureOpenAI(BaseOpenAI):
     openai_api_type: str = ""
     """Legacy, for openai<1.0.0 support."""
     validate_base_url: bool = True
-    """For backwards compatibility. If legacy val openai_api_base is passed in, try to
+    """For backwards compatibility. If legacy val openai_api_base is passed in, try to
     infer if it is a base_url or azure_endpoint and update accordingly.
     """

@@ -841,12 +841,12 @@ class AzureOpenAI(BaseOpenAI):
         values["azure_ad_token"] = values["azure_ad_token"] or os.getenv(
             "AZURE_OPENAI_AD_TOKEN"
         )
-        values["openai_api_base"] = values["openai_api_base"] or os.getenv(
+        values["openai_api_base"] = values["openai_api_base"] or os.getenv(  # type: ignore
             "OPENAI_API_BASE"
         )
-        values["openai_proxy"] = get_from_dict_or_env(
+        values["openai_proxy"] = get_from_dict_or_env(  # type: ignore
             values,
-            "openai_proxy",
+            "openai_proxy",  # type: ignore
             "OPENAI_PROXY",
             default="",
         )

@@ -870,37 +870,37 @@ class AzureOpenAI(BaseOpenAI):
         )
         if is_openai_v1():
             # For backwards compatibility. Before openai v1, no distinction was made
-            # between azure_endpoint and base_url (openai_api_base).
-            openai_api_base = values["openai_api_base"]
-            if openai_api_base and values["validate_base_url"]:
-                if "/openai" not in openai_api_base:
-                    values["openai_api_base"] = (
-                        values["openai_api_base"].rstrip("/") + "/openai"
+            # between azure_endpoint and base_url (openai_api_base).  # type: ignore
+            openai_api_base = values["openai_api_base"]  # type: ignore
+            if openai_api_base and values["validate_base_url"]:  # type: ignore
+                if "/openai" not in openai_api_base:  # type: ignore
+                    values["openai_api_base"] = (  # type: ignore
+                        values["openai_api_base"].rstrip("/") + "/openai"  # type: ignore
                     )
                     warnings.warn(
                         "As of openai>=1.0.0, Azure endpoints should be specified via "
-                        f"the `azure_endpoint` param not `openai_api_base` "
-                        f"(or alias `base_url`). Updating `openai_api_base` from "
-                        f"{openai_api_base} to {values['openai_api_base']}."
+                        f"the `azure_endpoint` param not `openai_api_base` "  # type: ignore
+                        f"(or alias `base_url`). Updating `openai_api_base` from "  # type: ignore
+                        f"{openai_api_base} to {values['openai_api_base']}."  # type: ignore
                     )
             if values["deployment_name"]:
                 warnings.warn(
                     "As of openai>=1.0.0, if `deployment_name` (or alias "
                     "`azure_deployment`) is specified then "
-                    "`openai_api_base` (or alias `base_url`) should not be. "
+                    "`openai_api_base` (or alias `base_url`) should not be. "  # type: ignore
                     "Instead use `deployment_name` (or alias `azure_deployment`) "
                     "and `azure_endpoint`."
                 )
-                if values["deployment_name"] not in values["openai_api_base"]:
+                if values["deployment_name"] not in values["openai_api_base"]:  # type: ignore
                     warnings.warn(
-                        "As of openai>=1.0.0, if `openai_api_base` "
+                        "As of openai>=1.0.0, if `openai_api_base` "  # type: ignore
                         "(or alias `base_url`) is specified it is expected to be "
                         "of the form "
                         "https://example-resource.azure.openai.com/openai/deployments/example-deployment. "  # noqa: E501
-                        f"Updating {openai_api_base} to "
-                        f"{values['openai_api_base']}."
+                        f"Updating {openai_api_base} to "  # type: ignore
+                        f"{values['openai_api_base']}."  # type: ignore
                     )
-                    values["openai_api_base"] += (
+                    values["openai_api_base"] += (  # type: ignore
                         "/deployments/" + values["deployment_name"]
                     )
                     values["deployment_name"] = None

@@ -912,7 +912,7 @@ class AzureOpenAI(BaseOpenAI):
                 "azure_ad_token": values["azure_ad_token"],
                 "azure_ad_token_provider": values["azure_ad_token_provider"],
                 "organization": values["openai_organization"],
-                "base_url": values["openai_api_base"],
+                "base_url": values["openai_api_base"],  # type: ignore
                 "timeout": values["request_timeout"],
                 "max_retries": values["max_retries"],
                 "default_headers": values["default_headers"],

@@ -925,7 +925,7 @@ class AzureOpenAI(BaseOpenAI):
             ).completions

         else:
-            values["client"] = openai.Completion
+            values["client"] = openai.Completion  # type: ignore

         return values
@@ -993,11 +993,11 @@ class OpenAIChat(BaseLLM):
     # may assume openai_api_key is a str)
     openai_api_key: Optional[str] = Field(default=None, alias="api_key")
     """Automatically inferred from env var `OPENAI_API_KEY` if not provided."""
-    openai_api_base: Optional[str] = Field(default=None, alias="base_url")
+    openai_api_base: Optional[str] = Field(default=None, alias="base_url")  # type: ignore
     """Base URL path for API requests, leave blank if not using a proxy or service
     emulator."""
     # to support explicit proxy for OpenAI
-    openai_proxy: Optional[str] = None
+    openai_proxy: Optional[str] = None  # type: ignore
     max_retries: int = 6
     """Maximum number of retries to make when generating."""
     prefix_messages: List = Field(default_factory=list)

@@ -1029,15 +1029,15 @@ class OpenAIChat(BaseLLM):
         openai_api_key = get_from_dict_or_env(
             values, "openai_api_key", "OPENAI_API_KEY"
         )
-        openai_api_base = get_from_dict_or_env(
+        openai_api_base = get_from_dict_or_env(  # type: ignore
             values,
-            "openai_api_base",
+            "openai_api_base",  # type: ignore
             "OPENAI_API_BASE",
             default="",
         )
-        openai_proxy = get_from_dict_or_env(
+        openai_proxy = get_from_dict_or_env(  # type: ignore
             values,
-            "openai_proxy",
+            "openai_proxy",  # type: ignore
             "OPENAI_PROXY",
             default="",
         )

@@ -1048,19 +1048,19 @@ class OpenAIChat(BaseLLM):
             import openai

             openai.api_key = openai_api_key
-            if openai_api_base:
-                openai.api_base = openai_api_base
+            if openai_api_base:  # type: ignore
+                openai.api_base = openai_api_base  # type: ignore
             if openai_organization:
                 openai.organization = openai_organization
-            if openai_proxy:
-                openai.proxy = {"http": openai_proxy, "https": openai_proxy}  # type: ignore[assignment]  # noqa: E501
+            if openai_proxy:  # type: ignore
+                openai.proxy = {"http": openai_proxy, "https": openai_proxy}  # type: ignore  # noqa: E501
         except ImportError:
             raise ImportError(
                 "Could not import openai python package. "
                 "Please install it with `pip install openai`."
             )
         try:
-            values["client"] = openai.ChatCompletion
+            values["client"] = openai.ChatCompletion  # type: ignore
         except AttributeError:
             raise ValueError(
                 "`openai` has no `ChatCompletion` attribute, this is likely "
@@ -164,7 +164,7 @@ class VLLMOpenAI(BaseOpenAI):
         params.update(
             {
                 "api_key": self.openai_api_key,
-                "api_base": self.openai_api_base,
+                "api_base": self.openai_api_base,  # type: ignore
             }
         )

@@ -51,6 +51,8 @@ class AmadeusClosestAirport(AmadeusBaseTool):
         location: str,
         run_manager: Optional[CallbackManagerForToolRun] = None,
     ) -> str:
+        if not self.llm:
+            raise ValueError("No language model (llm) has been set for this tool.")
         content = (
             f" What is the nearest airport to {location}? Please respond with the "
             " airport's International Air Transport Association (IATA) Location "
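The Amadeus hunk is one place this changeset adds a runtime check instead of an ignore: `llm` is optional, so the code narrows it (and fails loudly) before use. The same pattern in isolation:

```python
from typing import Optional

class Tool:
    llm: Optional[object] = None

    def run(self) -> str:
        if not self.llm:
            raise ValueError("No language model (llm) has been set for this tool.")
        # after the guard, mypy treats self.llm as non-None here
        return f"using {self.llm}"
```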
@@ -9,12 +9,12 @@ from langchain_core.tools import BaseTool
 from langchain_community.utilities.vertexai import get_client_info

 if TYPE_CHECKING:
-    from google.cloud import texttospeech
+    from google.cloud import texttospeech  # type: ignore


 def _import_google_cloud_texttospeech() -> Any:
     try:
-        from google.cloud import texttospeech
+        from google.cloud import texttospeech  # type: ignore
     except ImportError as e:
         raise ImportError(
             "Cannot import google.cloud.texttospeech, please install "
@@ -31,13 +31,13 @@ class DallEAPIWrapper(BaseModel):
     model_kwargs: Dict[str, Any] = Field(default_factory=dict)
     openai_api_key: Optional[str] = Field(default=None, alias="api_key")
     """Automatically inferred from env var `OPENAI_API_KEY` if not provided."""
-    openai_api_base: Optional[str] = Field(default=None, alias="base_url")
+    openai_api_base: Optional[str] = Field(default=None, alias="base_url")  # type: ignore
     """Base URL path for API requests, leave blank if not using a proxy or service
     emulator."""
     openai_organization: Optional[str] = Field(default=None, alias="organization")
     """Automatically inferred from env var `OPENAI_ORG_ID` if not provided."""
     # to support explicit proxy for OpenAI
-    openai_proxy: Optional[str] = None
+    openai_proxy: Optional[str] = None  # type: ignore
     request_timeout: Union[float, Tuple[float, float], Any, None] = Field(
         default=None, alias="timeout"
     )

@@ -102,12 +102,12 @@ class DallEAPIWrapper(BaseModel):
             or os.getenv("OPENAI_ORGANIZATION")
             or None
         )
-        values["openai_api_base"] = values["openai_api_base"] or os.getenv(
+        values["openai_api_base"] = values["openai_api_base"] or os.getenv(  # type: ignore
             "OPENAI_API_BASE"
         )
-        values["openai_proxy"] = get_from_dict_or_env(
+        values["openai_proxy"] = get_from_dict_or_env(  # type: ignore
             values,
-            "openai_proxy",
+            "openai_proxy",  # type: ignore
             "OPENAI_PROXY",
             default="",
         )

@@ -125,7 +125,7 @@ class DallEAPIWrapper(BaseModel):
             client_params = {
                 "api_key": values["openai_api_key"],
                 "organization": values["openai_organization"],
-                "base_url": values["openai_api_base"],
+                "base_url": values["openai_api_base"],  # type: ignore
                 "timeout": values["request_timeout"],
                 "max_retries": values["max_retries"],
                 "default_headers": values["default_headers"],

@@ -138,7 +138,7 @@ class DallEAPIWrapper(BaseModel):
         if not values.get("async_client"):
             values["async_client"] = openai.AsyncOpenAI(**client_params).images
         elif not values.get("client"):
-            values["client"] = openai.Image
+            values["client"] = openai.Image  # type: ignore
         else:
             pass
         return values
@@ -111,7 +111,7 @@ def get_client_info(module: Optional[str] = None) -> "ClientInfo":
 def load_image_from_gcs(path: str, project: Optional[str] = None) -> "Image":
     """Loads im Image from GCS."""
     try:
-        from google.cloud import storage
+        from google.cloud import storage  # type: ignore
     except ImportError:
         raise ImportError("Could not import google-cloud-storage python package.")
     from vertexai.preview.generative_models import Image
@@ -15,7 +15,6 @@ from typing import (
     Optional,
     Set,
     Tuple,
-    Type,
     TypeVar,
 )

@@ -33,7 +32,6 @@ if TYPE_CHECKING:
     from astrapy.db import AstraDB as LibAstraDB
     from astrapy.db import AsyncAstraDB

-ADBVST = TypeVar("ADBVST", bound="AstraDB")
 T = TypeVar("T")
 U = TypeVar("U")
 DocDict = Dict[str, Any]  # dicts expressing entries to insert

@@ -1142,10 +1140,10 @@ class AstraDB(VectorStore):

     @classmethod
     def _from_kwargs(
-        cls: Type[ADBVST],
+        cls,
         embedding: Embeddings,
         **kwargs: Any,
-    ) -> ADBVST:
+    ) -> "AstraDB":
         known_kwargs = {
             "collection_name",
             "token",

@@ -1197,13 +1195,13 @@ class AstraDB(VectorStore):

     @classmethod
     def from_texts(
-        cls: Type[ADBVST],
+        cls,
         texts: List[str],
         embedding: Embeddings,
         metadatas: Optional[List[dict]] = None,
         ids: Optional[List[str]] = None,
         **kwargs: Any,
-    ) -> ADBVST:
+    ) -> "AstraDB":
         """Create an Astra DB vectorstore from raw texts.

         Args:

@@ -1232,13 +1230,13 @@ class AstraDB(VectorStore):

     @classmethod
     async def afrom_texts(
-        cls: Type[ADBVST],
+        cls,
         texts: List[str],
         embedding: Embeddings,
         metadatas: Optional[List[dict]] = None,
         ids: Optional[List[str]] = None,
         **kwargs: Any,
-    ) -> ADBVST:
+    ) -> "AstraDB":
         """Create an Astra DB vectorstore from raw texts.

         Args:

@@ -1267,11 +1265,11 @@ class AstraDB(VectorStore):

     @classmethod
     def from_documents(
-        cls: Type[ADBVST],
+        cls,
         documents: List[Document],
         embedding: Embeddings,
         **kwargs: Any,
-    ) -> ADBVST:
+    ) -> "AstraDB":
         """Create an Astra DB vectorstore from a document list.

         Utility method that defers to 'from_texts' (see that one).
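The AstraDB hunks trade the bound-`TypeVar` classmethod signature for a plain `cls` and a literal `"AstraDB"` return annotation. The TypeVar form is what preserves subclass types at call sites; the literal form is simpler but types every subclass call as the base class, a trade-off the change evidently accepts. Both styles side by side:

```python
from typing import List, Type, TypeVar

ST = TypeVar("ST", bound="Store")

class Store:
    @classmethod
    def precise(cls: Type[ST], texts: List[str]) -> ST:
        return cls()  # Subclass.precise(...) is typed as the subclass

    @classmethod
    def simple(cls, texts: List[str]) -> "Store":
        return cls()  # Subclass.simple(...) is typed as Store

class MyStore(Store): ...

s1 = MyStore.precise([])  # inferred type: MyStore
s2 = MyStore.simple([])   # inferred type: Store
```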
@@ -13,7 +13,7 @@ from langchain_core.vectorstores import VectorStore
 from langchain_community.utilities.vertexai import get_client_info

 if TYPE_CHECKING:
-    from google.cloud import storage
+    from google.cloud import storage  # type: ignore
     from google.cloud.aiplatform import MatchingEngineIndex, MatchingEngineIndexEndpoint
     from google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint import (
         Namespace,

@@ -103,7 +103,7 @@ class MatchingEngine(VectorStore):
     def _validate_google_libraries_installation(self) -> None:
         """Validates that Google libraries that are needed are installed."""
         try:
-            from google.cloud import aiplatform, storage  # noqa: F401
+            from google.cloud import aiplatform, storage  # type: ignore # noqa: F401
             from google.oauth2 import service_account  # noqa: F401
         except ImportError:
             raise ImportError(

@@ -545,7 +545,7 @@ class MatchingEngine(VectorStore):
             A configured GCS client.
         """

-        from google.cloud import storage
+        from google.cloud import storage  # type: ignore

         return storage.Client(
             credentials=credentials,
@@ -6,8 +6,8 @@ from langchain_community.adapters import openai as lcopenai
|
||||
def _test_no_stream(**kwargs: Any) -> None:
|
||||
import openai
|
||||
|
||||
result = openai.ChatCompletion.create(**kwargs)
|
||||
lc_result = lcopenai.ChatCompletion.create(**kwargs)
|
||||
result = openai.ChatCompletion.create(**kwargs) # type: ignore
|
||||
lc_result = lcopenai.ChatCompletion.create(**kwargs) # type: ignore
|
||||
if isinstance(lc_result, dict):
|
||||
if isinstance(result, dict):
|
||||
result_dict = result["choices"][0]["message"].to_dict_recursive()
|
||||
@@ -20,11 +20,11 @@ def _test_stream(**kwargs: Any) -> None:
|
||||
import openai
|
||||
|
||||
result = []
|
||||
for c in openai.ChatCompletion.create(**kwargs):
|
||||
for c in openai.ChatCompletion.create(**kwargs): # type: ignore
|
||||
result.append(c["choices"][0]["delta"].to_dict_recursive())
|
||||
|
||||
lc_result = []
|
||||
for c in lcopenai.ChatCompletion.create(**kwargs):
|
||||
for c in lcopenai.ChatCompletion.create(**kwargs): # type: ignore
|
||||
lc_result.append(c["choices"][0]["delta"])
|
||||
assert result == lc_result
|
||||
|
||||
@@ -32,8 +32,8 @@ def _test_stream(**kwargs: Any) -> None:
|
||||
async def _test_async(**kwargs: Any) -> None:
|
||||
import openai
|
||||
|
||||
result = await openai.ChatCompletion.acreate(**kwargs)
|
||||
lc_result = await lcopenai.ChatCompletion.acreate(**kwargs)
|
||||
result = await openai.ChatCompletion.acreate(**kwargs) # type: ignore
|
||||
lc_result = await lcopenai.ChatCompletion.acreate(**kwargs) # type: ignore
|
||||
if isinstance(lc_result, dict):
|
||||
if isinstance(result, dict):
|
||||
result_dict = result["choices"][0]["message"].to_dict_recursive()
|
||||
@@ -46,11 +46,11 @@ async def _test_astream(**kwargs: Any) -> None:
|
||||
import openai
|
||||
|
||||
result = []
|
||||
async for c in await openai.ChatCompletion.acreate(**kwargs):
|
||||
async for c in await openai.ChatCompletion.acreate(**kwargs): # type: ignore
|
||||
result.append(c["choices"][0]["delta"].to_dict_recursive())
|
||||
|
||||
lc_result = []
|
||||
async for c in await lcopenai.ChatCompletion.acreate(**kwargs):
|
||||
async for c in await lcopenai.ChatCompletion.acreate(**kwargs): # type: ignore
|
||||
lc_result.append(c["choices"][0]["delta"])
|
||||
assert result == lc_result
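
Note: the "# type: ignore" comments added throughout these adapter tests exist because openai>=1.0 removed the module-level ChatCompletion endpoint, so current stubs no longer declare it even though the tests exercise the pre-1.0 API. A hedged sketch of making that assumption explicit at runtime (helper is illustrative, not part of the diff):

    from typing import Any

    def legacy_chat_create(**kwargs: Any) -> Any:
        import openai

        # openai>=1.0 removed this endpoint entirely; fail loudly rather
        # than with an AttributeError deep inside a test.
        if not hasattr(openai, "ChatCompletion"):
            raise RuntimeError("these adapter tests require openai<1.0")
        return openai.ChatCompletion.create(**kwargs)  # type: ignore[attr-defined]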

@@ -23,7 +23,7 @@ def test_chat_openai() -> None:
temperature=0.7,
base_url=None,
organization=None,
openai_proxy=None,
openai_proxy=None, # type: ignore
timeout=10.0,
max_retries=3,
http_client=None,

@@ -20,7 +20,7 @@ def _get_embeddings(**kwargs: Any) -> AzureOpenAIEmbeddings:
return AzureOpenAIEmbeddings(
azure_deployment=DEPLOYMENT_NAME,
api_version=OPENAI_API_VERSION,
openai_api_base=OPENAI_API_BASE,
openai_api_base=OPENAI_API_BASE, # type: ignore
openai_api_key=OPENAI_API_KEY,
**kwargs,
)
@@ -104,7 +104,7 @@ def test_azure_openai_embedding_with_empty_string() -> None:
output = embedding.embed_documents(document)
assert len(output) == 2
assert len(output[0]) == 1536
expected_output = openai.Embedding.create(input="", model="text-embedding-ada-002")[
expected_output = openai.Embedding.create(input="", model="text-embedding-ada-002")[ # type: ignore
"data"
][0]["embedding"]
assert np.allclose(output[0], expected_output)

@@ -70,7 +70,7 @@ def test_openai_embedding_with_empty_string() -> None:
output = embedding.embed_documents(document)
assert len(output) == 2
assert len(output[0]) == 1536
expected_output = openai.Embedding.create(input="", model="text-embedding-ada-002")[
expected_output = openai.Embedding.create(input="", model="text-embedding-ada-002")[ # type: ignore
"data"
][0]["embedding"]
assert np.allclose(output[0], expected_output)

@@ -22,7 +22,7 @@ def _get_llm(**kwargs: Any) -> AzureOpenAI:
return AzureOpenAI(
deployment_name=DEPLOYMENT_NAME,
openai_api_version=OPENAI_API_VERSION,
openai_api_base=OPENAI_API_BASE,
openai_api_base=OPENAI_API_BASE, # type: ignore
openai_api_key=OPENAI_API_KEY,
**kwargs,
)

@@ -1,10 +1,14 @@
"""Implement integration tests for AstraDB storage."""
import os
from typing import TYPE_CHECKING

import pytest

from langchain_community.storage.astradb import AstraDBByteStore, AstraDBStore

if TYPE_CHECKING:
from astrapy.db import AstraDB


def _has_env_vars() -> bool:
return all(
@@ -16,7 +20,7 @@ def _has_env_vars() -> bool:


@pytest.fixture
def astra_db():
def astra_db() -> AstraDB:
from astrapy.db import AstraDB

return AstraDB(
@@ -26,14 +30,14 @@ def astra_db():
)


def init_store(astra_db, collection_name: str):
def init_store(astra_db: AstraDB, collection_name: str) -> AstraDBStore:
astra_db.create_collection(collection_name)
store = AstraDBStore(collection_name=collection_name, astra_db_client=astra_db)
store.mset([("key1", [0.1, 0.2]), ("key2", "value2")])
return store


def init_bytestore(astra_db, collection_name: str):
def init_bytestore(astra_db: AstraDB, collection_name: str) -> AstraDBByteStore:
astra_db.create_collection(collection_name)
store = AstraDBByteStore(collection_name=collection_name, astra_db_client=astra_db)
store.mset([("key1", b"value1"), ("key2", b"value2")])

@@ -43,7 +47,7 @@ def init_bytestore(astra_db, collection_name: str):
@pytest.mark.requires("astrapy")
@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
class TestAstraDBStore:
def test_mget(self, astra_db) -> None:
def test_mget(self, astra_db: AstraDB) -> None:
"""Test AstraDBStore mget method."""
collection_name = "lc_test_store_mget"
try:
@@ -52,7 +56,7 @@ class TestAstraDBStore:
finally:
astra_db.delete_collection(collection_name)

def test_mset(self, astra_db) -> None:
def test_mset(self, astra_db: AstraDB) -> None:
"""Test that multiple keys can be set with AstraDBStore."""
collection_name = "lc_test_store_mset"
try:
@@ -64,7 +68,7 @@ class TestAstraDBStore:
finally:
astra_db.delete_collection(collection_name)

def test_mdelete(self, astra_db) -> None:
def test_mdelete(self, astra_db: AstraDB) -> None:
"""Test that deletion works as expected."""
collection_name = "lc_test_store_mdelete"
try:
@@ -75,7 +79,7 @@ class TestAstraDBStore:
finally:
astra_db.delete_collection(collection_name)

def test_yield_keys(self, astra_db) -> None:
def test_yield_keys(self, astra_db: AstraDB) -> None:
collection_name = "lc_test_store_yield_keys"
try:
store = init_store(astra_db, collection_name)
@@ -85,7 +89,7 @@ class TestAstraDBStore:
finally:
astra_db.delete_collection(collection_name)

def test_bytestore_mget(self, astra_db) -> None:
def test_bytestore_mget(self, astra_db: AstraDB) -> None:
"""Test AstraDBByteStore mget method."""
collection_name = "lc_test_bytestore_mget"
try:
@@ -94,7 +98,7 @@ class TestAstraDBStore:
finally:
astra_db.delete_collection(collection_name)

def test_bytestore_mset(self, astra_db) -> None:
def test_bytestore_mset(self, astra_db: AstraDB) -> None:
"""Test that multiple keys can be set with AstraDBByteStore."""
collection_name = "lc_test_bytestore_mset"
try:

@@ -42,7 +42,7 @@ lint lint_diff lint_package lint_tests:
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff --select I $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || poetry run mypy $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || mkdir $(MYPY_CACHE) || poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)

format format_diff:
poetry run ruff format $(PYTHON_FILES)

@@ -62,7 +62,7 @@ lint lint_diff lint_package lint_tests:
poetry run ruff .
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff --select I $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || mkdir $(MYPY_CACHE) || poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)

format format_diff:
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES)