Compare commits

...

18 Commits

Author SHA1 Message Date
Erick Friis
058239cdd2 community only 2024-02-04 17:41:14 -08:00
Erick Friis
c732db1173 x 2024-02-04 17:39:22 -08:00
Erick Friis
c935b9e025 openai module done 2024-02-04 17:31:32 -08:00
Erick Friis
a7a1690845 line length 2024-02-04 17:29:46 -08:00
Erick Friis
5964630e7b fmt 2024-02-04 17:29:24 -08:00
Erick Friis
0903f1d3e0 openai Completion proxy api_base 2024-02-04 17:29:18 -08:00
Erick Friis
bd9d6d314a gcs storage 2024-02-04 17:29:08 -08:00
Erick Friis
f62ac4e98c openai ChatCompletion 2024-02-04 17:27:46 -08:00
Erick Friis
ace0a27473 x 2024-02-04 17:27:10 -08:00
Erick Friis
f451e2dfc2 x 2024-02-04 17:26:55 -08:00
Erick Friis
126cb87b30 ruff 2024-02-04 17:26:36 -08:00
Erick Friis
9979a36392 openai error and Embedding 2024-02-04 17:26:23 -08:00
Erick Friis
1484c356e9 x 2024-02-04 17:24:40 -08:00
Erick Friis
3c9e133c70 sql/base 2024-02-04 17:16:04 -08:00
Erick Friis
35e3d3fa39 astradb 2024-02-04 17:08:59 -08:00
Harrison Chase
ff2f6f398d cr 2024-02-04 15:42:43 -08:00
Harrison Chase
0ede50c344 add -p flag 2024-02-04 15:39:09 -08:00
Harrison Chase
983c2a2f88 debug stuff 2024-02-04 15:19:23 -08:00
34 changed files with 267 additions and 245 deletions

View File

@@ -41,7 +41,7 @@ lint lint_diff lint_package lint_tests:
poetry run ruff .
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff --select I $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) || poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
format format_diff:
poetry run ruff format $(PYTHON_FILES)

View File

@@ -4,9 +4,17 @@ from __future__ import annotations
import warnings
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Sequence, Union
from langchain_core.messages import AIMessage, SystemMessage
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
from langchain_core.messages import (
AIMessage,
BaseMessage,
SystemMessage,
)
from langchain_core.prompts import (
BasePromptTemplate,
PromptTemplate,
)
from langchain_core.prompts.chat import (
BaseMessagePromptTemplate,
ChatPromptTemplate,
HumanMessagePromptTemplate,
MessagesPlaceholder,
@@ -132,6 +140,7 @@ def create_sql_agent(
toolkit = toolkit or SQLDatabaseToolkit(llm=llm, db=db)
agent_type = agent_type or AgentType.ZERO_SHOT_REACT_DESCRIPTION
tools = toolkit.get_tools() + list(extra_tools)
prefix = prefix or ""
if prompt is None:
prefix = prefix or SQL_PREFIX
prefix = prefix.format(dialect=toolkit.dialect, top_k=top_k)
@@ -152,6 +161,8 @@ def create_sql_agent(
tool for tool in tools if not isinstance(tool, ListSQLDatabaseTool)
]
messages: List[Union[BaseMessage, BaseMessagePromptTemplate]] = []
if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION:
if prompt is None:
from langchain.agents.mrkl import prompt as react_prompt

View File

@@ -931,7 +931,7 @@ class CassandraCache(BaseCache):
self.table_name = table_name
self.ttl_seconds = ttl_seconds
self.kv_cache = ElasticCassandraTable(
self.kv_cache: Any = ElasticCassandraTable(
session=self.session,
keyspace=self.keyspace,
table=self.table_name,
@@ -1067,7 +1067,7 @@ class CassandraSemanticCache(BaseCache):
self._get_embedding = _cache_embedding
self.embedding_dimension = self._get_embedding_dimension()
self.table = MetadataVectorCassandraTable(
self.table: Any = MetadataVectorCassandraTable(
session=self.session,
keyspace=self.keyspace,
table=self.table_name,

View File

@@ -117,13 +117,13 @@ class ChatAnyscale(ChatOpenAI):
"ANYSCALE_API_KEY",
)
)
values["openai_api_base"] = get_from_dict_or_env(
values["openai_api_base"] = get_from_dict_or_env( # type: ignore
values,
"anyscale_api_base",
"ANYSCALE_API_BASE",
default=DEFAULT_API_BASE,
)
values["openai_proxy"] = get_from_dict_or_env(
values["openai_proxy"] = get_from_dict_or_env( # type: ignore
values,
"anyscale_proxy",
"ANYSCALE_PROXY",
@@ -141,7 +141,7 @@ class ChatAnyscale(ChatOpenAI):
if is_openai_v1():
client_params = {
"api_key": values["openai_api_key"],
"base_url": values["openai_api_base"],
"base_url": values["openai_api_base"], # type: ignore
# To do: future support
# "organization": values["openai_organization"],
# "timeout": values["request_timeout"],
@@ -152,7 +152,7 @@ class ChatAnyscale(ChatOpenAI):
}
values["client"] = openai.OpenAI(**client_params).chat.completions
else:
values["client"] = openai.ChatCompletion
values["client"] = openai.ChatCompletion # type: ignore
except AttributeError as exc:
raise ValueError(
"`openai` has no `ChatCompletion` attribute, this is likely "
@@ -167,7 +167,7 @@ class ChatAnyscale(ChatOpenAI):
available_models = cls.get_available_models(
values["openai_api_key"],
values["openai_api_base"],
values["openai_api_base"], # type: ignore
)
if model_name not in available_models:

View File

@@ -96,7 +96,7 @@ class AzureChatOpenAI(ChatOpenAI):
openai_api_type: str = ""
"""Legacy, for openai<1.0.0 support."""
validate_base_url: bool = True
"""For backwards compatibility. If legacy val openai_api_base is passed in, try to
"""For backwards compatibility. If legacy val openai_api_base is passed in, try to
infer if it is a base_url or azure_endpoint and update accordingly.
"""
@@ -121,7 +121,7 @@ class AzureChatOpenAI(ChatOpenAI):
or os.getenv("AZURE_OPENAI_API_KEY")
or os.getenv("OPENAI_API_KEY")
)
values["openai_api_base"] = values["openai_api_base"] or os.getenv(
values["openai_api_base"] = values["openai_api_base"] or os.getenv( # type: ignore
"OPENAI_API_BASE"
)
values["openai_api_version"] = values["openai_api_version"] or os.getenv(
@@ -143,8 +143,11 @@ class AzureChatOpenAI(ChatOpenAI):
values["openai_api_type"] = get_from_dict_or_env(
values, "openai_api_type", "OPENAI_API_TYPE", default="azure"
)
values["openai_proxy"] = get_from_dict_or_env(
values, "openai_proxy", "OPENAI_PROXY", default=""
values["openai_proxy"] = get_from_dict_or_env( # type: ignore
values,
"openai_proxy",
"OPENAI_PROXY",
default="", # type: ignore
)
try:
@@ -157,37 +160,37 @@ class AzureChatOpenAI(ChatOpenAI):
)
if is_openai_v1():
# For backwards compatibility. Before openai v1, no distinction was made
# between azure_endpoint and base_url (openai_api_base).
openai_api_base = values["openai_api_base"]
if openai_api_base and values["validate_base_url"]:
if "/openai" not in openai_api_base:
values["openai_api_base"] = (
values["openai_api_base"].rstrip("/") + "/openai"
# between azure_endpoint and base_url (openai_api_base). # type: ignore
openai_api_base = values["openai_api_base"] # type: ignore
if openai_api_base and values["validate_base_url"]: # type: ignore
if "/openai" not in openai_api_base: # type: ignore
values["openai_api_base"] = ( # type: ignore
values["openai_api_base"].rstrip("/") + "/openai" # type: ignore
)
warnings.warn(
"As of openai>=1.0.0, Azure endpoints should be specified via "
f"the `azure_endpoint` param not `openai_api_base` "
f"(or alias `base_url`). Updating `openai_api_base` from "
f"{openai_api_base} to {values['openai_api_base']}."
f"the `azure_endpoint` param not `openai_api_base` " # type: ignore
f"(or alias `base_url`). Updating `openai_api_base` from " # type: ignore
f"{openai_api_base} to {values['openai_api_base']}." # type: ignore
)
if values["deployment_name"]:
warnings.warn(
"As of openai>=1.0.0, if `deployment_name` (or alias "
"`azure_deployment`) is specified then "
"`openai_api_base` (or alias `base_url`) should not be. "
"`openai_api_base` (or alias `base_url`) should not be. " # type: ignore
"Instead use `deployment_name` (or alias `azure_deployment`) "
"and `azure_endpoint`."
)
if values["deployment_name"] not in values["openai_api_base"]:
if values["deployment_name"] not in values["openai_api_base"]: # type: ignore
warnings.warn(
"As of openai>=1.0.0, if `openai_api_base` "
"As of openai>=1.0.0, if `openai_api_base` " # type: ignore
"(or alias `base_url`) is specified it is expected to be "
"of the form "
"https://example-resource.azure.openai.com/openai/deployments/example-deployment. " # noqa: E501
f"Updating {openai_api_base} to "
f"{values['openai_api_base']}."
f"Updating {openai_api_base} to " # type: ignore
f"{values['openai_api_base']}." # type: ignore
)
values["openai_api_base"] += (
values["openai_api_base"] += ( # type: ignore
"/deployments/" + values["deployment_name"]
)
values["deployment_name"] = None
@@ -199,7 +202,7 @@ class AzureChatOpenAI(ChatOpenAI):
"azure_ad_token": values["azure_ad_token"],
"azure_ad_token_provider": values["azure_ad_token_provider"],
"organization": values["openai_organization"],
"base_url": values["openai_api_base"],
"base_url": values["openai_api_base"], # type: ignore
"timeout": values["request_timeout"],
"max_retries": values["max_retries"],
"default_headers": values["default_headers"],
@@ -211,7 +214,7 @@ class AzureChatOpenAI(ChatOpenAI):
**client_params
).chat.completions
else:
values["client"] = openai.ChatCompletion
values["client"] = openai.ChatCompletion # type: ignore
return values
@property

View File

@@ -83,7 +83,7 @@ class ChatEverlyAI(ChatOpenAI):
"everlyai_api_key",
"EVERLYAI_API_KEY",
)
values["openai_api_base"] = DEFAULT_API_BASE
values["openai_api_base"] = DEFAULT_API_BASE # type: ignore
try:
import openai
@@ -94,7 +94,7 @@ class ChatEverlyAI(ChatOpenAI):
"Please install it with `pip install openai`.",
) from e
try:
values["client"] = openai.ChatCompletion
values["client"] = openai.ChatCompletion # type: ignore
except AttributeError as exc:
raise ValueError(
"`openai` has no `ChatCompletion` attribute, this is likely "

View File

@@ -39,7 +39,11 @@ from langchain_community.adapters.openai import (
from langchain_community.chat_models.openai import _convert_delta_to_message_chunk
if TYPE_CHECKING:
from gpt_router.models import ChunkedGenerationResponse, GenerationResponse
from gpt_router.models import (
ChunkedGenerationResponse,
GenerationResponse,
ModelGenerationRequest,
)
logger = logging.getLogger(__name__)
@@ -57,8 +61,8 @@ class GPTRouterModel(BaseModel):
def get_ordered_generation_requests(
models_priority_list: List[GPTRouterModel], **kwargs
):
models_priority_list: List[GPTRouterModel], **kwargs: Any
) -> List[ModelGenerationRequest]:
"""
Return the body for the model router input.
"""
@@ -100,7 +104,7 @@ def completion_with_retry(
models_priority_list: List[GPTRouterModel],
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Union[GenerationResponse, Generator[ChunkedGenerationResponse]]:
) -> Union[GenerationResponse, Generator[ChunkedGenerationResponse, None, None]]:
"""Use tenacity to retry the completion call."""
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
@@ -122,7 +126,7 @@ async def acompletion_with_retry(
models_priority_list: List[GPTRouterModel],
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Union[GenerationResponse, AsyncGenerator[ChunkedGenerationResponse]]:
) -> Union[GenerationResponse, AsyncGenerator[ChunkedGenerationResponse, None]]:
"""Use tenacity to retry the async completion call."""
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
@@ -283,8 +287,8 @@ class GPTRouter(BaseChatModel):
return self._create_chat_result(response)
def _create_chat_generation_chunk(
self, data: Mapping[str, Any], default_chunk_class
):
self, data: Mapping[str, Any], default_chunk_class: Any
) -> Tuple[ChatGenerationChunk, Any]:
chunk = _convert_delta_to_message_chunk(
{"content": data.get("text", "")}, default_chunk_class
)
@@ -293,8 +297,8 @@ class GPTRouter(BaseChatModel):
dict(finish_reason=finish_reason) if finish_reason is not None else None
)
default_chunk_class = chunk.__class__
chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)
return chunk, default_chunk_class
chunk_ = ChatGenerationChunk(message=chunk, generation_info=generation_info)
return chunk_, default_chunk_class
def _stream(
self,

View File

@@ -68,11 +68,11 @@ def _create_retry_decorator(llm: JinaChat) -> Callable[[Any], Any]:
stop=stop_after_attempt(llm.max_retries),
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
retry=(
retry_if_exception_type(openai.error.Timeout)
| retry_if_exception_type(openai.error.APIError)
| retry_if_exception_type(openai.error.APIConnectionError)
| retry_if_exception_type(openai.error.RateLimitError)
| retry_if_exception_type(openai.error.ServiceUnavailableError)
retry_if_exception_type(openai.error.Timeout) # type: ignore
| retry_if_exception_type(openai.error.APIError) # type: ignore
| retry_if_exception_type(openai.error.APIConnectionError) # type: ignore
| retry_if_exception_type(openai.error.RateLimitError) # type: ignore
| retry_if_exception_type(openai.error.ServiceUnavailableError) # type: ignore
),
before_sleep=before_sleep_log(logger, logging.WARNING),
)
@@ -232,7 +232,7 @@ class JinaChat(BaseChatModel):
"Please install it with `pip install openai`."
)
try:
values["client"] = openai.ChatCompletion
values["client"] = openai.ChatCompletion # type: ignore
except AttributeError:
raise ValueError(
"`openai` has no `ChatCompletion` attribute, this is likely "
@@ -264,11 +264,11 @@ class JinaChat(BaseChatModel):
stop=stop_after_attempt(self.max_retries),
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
retry=(
retry_if_exception_type(openai.error.Timeout)
| retry_if_exception_type(openai.error.APIError)
| retry_if_exception_type(openai.error.APIConnectionError)
| retry_if_exception_type(openai.error.RateLimitError)
| retry_if_exception_type(openai.error.ServiceUnavailableError)
retry_if_exception_type(openai.error.Timeout) # type: ignore
| retry_if_exception_type(openai.error.APIError) # type: ignore
| retry_if_exception_type(openai.error.APIConnectionError) # type: ignore
| retry_if_exception_type(openai.error.RateLimitError) # type: ignore
| retry_if_exception_type(openai.error.ServiceUnavailableError) # type: ignore
),
before_sleep=before_sleep_log(logger, logging.WARNING),
)

View File

@@ -84,11 +84,11 @@ def _create_retry_decorator(
import openai
errors = [
openai.error.Timeout,
openai.error.APIError,
openai.error.APIConnectionError,
openai.error.RateLimitError,
openai.error.ServiceUnavailableError,
openai.error.Timeout, # type: ignore
openai.error.APIError, # type: ignore
openai.error.APIConnectionError, # type: ignore
openai.error.RateLimitError, # type: ignore
openai.error.ServiceUnavailableError, # type: ignore
]
return create_base_retry_decorator(
error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
@@ -179,11 +179,11 @@ class ChatOpenAI(BaseChatModel):
if self.openai_organization:
attributes["openai_organization"] = self.openai_organization
if self.openai_api_base:
attributes["openai_api_base"] = self.openai_api_base
if self.openai_api_base: # type: ignore
attributes["openai_api_base"] = self.openai_api_base # type: ignore
if self.openai_proxy:
attributes["openai_proxy"] = self.openai_proxy
if self.openai_proxy: # type: ignore
attributes["openai_proxy"] = self.openai_proxy # type: ignore
return attributes
@@ -205,13 +205,13 @@ class ChatOpenAI(BaseChatModel):
# may assume openai_api_key is a str)
openai_api_key: Optional[str] = Field(default=None, alias="api_key")
"""Automatically inferred from env var `OPENAI_API_KEY` if not provided."""
openai_api_base: Optional[str] = Field(default=None, alias="base_url")
openai_api_base: Optional[str] = Field(default=None, alias="base_url") # type: ignore
"""Base URL path for API requests, leave blank if not using a proxy or service
emulator."""
openai_organization: Optional[str] = Field(default=None, alias="organization")
"""Automatically inferred from env var `OPENAI_ORG_ID` if not provided."""
# to support explicit proxy for OpenAI
openai_proxy: Optional[str] = None
openai_proxy: Optional[str] = None # type: ignore
request_timeout: Union[float, Tuple[float, float], Any, None] = Field(
default=None, alias="timeout"
)
@@ -290,12 +290,12 @@ class ChatOpenAI(BaseChatModel):
or os.getenv("OPENAI_ORG_ID")
or os.getenv("OPENAI_ORGANIZATION")
)
values["openai_api_base"] = values["openai_api_base"] or os.getenv(
values["openai_api_base"] = values["openai_api_base"] or os.getenv( # type: ignore
"OPENAI_API_BASE"
)
values["openai_proxy"] = get_from_dict_or_env(
values["openai_proxy"] = get_from_dict_or_env( # type: ignore
values,
"openai_proxy",
"openai_proxy", # type: ignore
"OPENAI_PROXY",
default="",
)
@@ -312,7 +312,7 @@ class ChatOpenAI(BaseChatModel):
client_params = {
"api_key": values["openai_api_key"],
"organization": values["openai_organization"],
"base_url": values["openai_api_base"],
"base_url": values["openai_api_base"], # type: ignore
"timeout": values["request_timeout"],
"max_retries": values["max_retries"],
"default_headers": values["default_headers"],
@@ -327,7 +327,7 @@ class ChatOpenAI(BaseChatModel):
**client_params
).chat.completions
elif not values.get("client"):
values["client"] = openai.ChatCompletion
values["client"] = openai.ChatCompletion # type: ignore
else:
pass
return values
@@ -547,14 +547,14 @@ class ChatOpenAI(BaseChatModel):
openai_creds.update(
{
"api_key": self.openai_api_key,
"api_base": self.openai_api_base,
"api_base": self.openai_api_base, # type: ignore
"organization": self.openai_organization,
}
)
if self.openai_proxy:
if self.openai_proxy: # type: ignore
import openai
openai.proxy = {"http": self.openai_proxy, "https": self.openai_proxy} # type: ignore[assignment] # noqa: E501
openai.proxy = {"http": self.openai_proxy, "https": self.openai_proxy} # type: ignore # noqa: E501
return {**self._default_params, **openai_creds}
def _get_invocation_params(

View File

@@ -35,7 +35,7 @@ class GCSDirectoryLoader(BaseLoader):
def load(self) -> List[Document]:
"""Load documents."""
try:
from google.cloud import storage
from google.cloud import storage # type: ignore
except ImportError:
raise ImportError(
"Could not import google-cloud-storage python package. "

View File

@@ -51,7 +51,7 @@ class GCSFileLoader(BaseLoader):
def load(self) -> List[Document]:
"""Load documents."""
try:
from google.cloud import storage
from google.cloud import storage # type: ignore
except ImportError:
raise ImportError(
"Could not import google-cloud-storage python package. "

View File

@@ -73,7 +73,7 @@ class OpenAIWhisperParser(BaseBlobParser):
model="whisper-1", file=file_obj
)
else:
transcript = openai.Audio.transcribe("whisper-1", file_obj)
transcript = openai.Audio.transcribe("whisper-1", file_obj) # type: ignore
break
except Exception as e:
attempts += 1

View File

@@ -27,7 +27,7 @@ class GoogleTranslateTransformer(BaseDocumentTransformer):
"""
try:
from google.api_core.client_options import ClientOptions
from google.cloud import translate
from google.cloud import translate # type: ignore
except ImportError as exc:
raise ImportError(
"Install Google Cloud Translate to use this parser."
@@ -70,7 +70,7 @@ class GoogleTranslateTransformer(BaseDocumentTransformer):
Options: `text/plain`, `text/html`
"""
try:
from google.cloud import translate
from google.cloud import translate # type: ignore
except ImportError as exc:
raise ImportError(
"Install Google Cloud Translate to use this parser."

View File

@@ -64,7 +64,7 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
or os.getenv("AZURE_OPENAI_API_KEY")
or os.getenv("OPENAI_API_KEY")
)
values["openai_api_base"] = values["openai_api_base"] or os.getenv(
values["openai_api_base"] = values["openai_api_base"] or os.getenv( # type: ignore
"OPENAI_API_BASE"
)
values["openai_api_version"] = values["openai_api_version"] or os.getenv(
@@ -78,9 +78,9 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
or os.getenv("OPENAI_ORG_ID")
or os.getenv("OPENAI_ORGANIZATION")
)
values["openai_proxy"] = get_from_dict_or_env(
values["openai_proxy"] = get_from_dict_or_env( # type: ignore
values,
"openai_proxy",
"openai_proxy", # type: ignore
"OPENAI_PROXY",
default="",
)
@@ -104,35 +104,35 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
)
if is_openai_v1():
# For backwards compatibility. Before openai v1, no distinction was made
# between azure_endpoint and base_url (openai_api_base).
openai_api_base = values["openai_api_base"]
if openai_api_base and values["validate_base_url"]:
if "/openai" not in openai_api_base:
values["openai_api_base"] += "/openai"
# between azure_endpoint and base_url (openai_api_base). # type: ignore
openai_api_base = values["openai_api_base"] # type: ignore
if openai_api_base and values["validate_base_url"]: # type: ignore
if "/openai" not in openai_api_base: # type: ignore
values["openai_api_base"] += "/openai" # type: ignore
warnings.warn(
"As of openai>=1.0.0, Azure endpoints should be specified via "
f"the `azure_endpoint` param not `openai_api_base` "
f"(or alias `base_url`). Updating `openai_api_base` from "
f"{openai_api_base} to {values['openai_api_base']}."
f"the `azure_endpoint` param not `openai_api_base` " # type: ignore
f"(or alias `base_url`). Updating `openai_api_base` from " # type: ignore
f"{openai_api_base} to {values['openai_api_base']}." # type: ignore
)
if values["deployment"]:
warnings.warn(
"As of openai>=1.0.0, if `deployment` (or alias "
"`azure_deployment`) is specified then "
"`openai_api_base` (or alias `base_url`) should not be. "
"`openai_api_base` (or alias `base_url`) should not be. " # type: ignore
"Instead use `deployment` (or alias `azure_deployment`) "
"and `azure_endpoint`."
)
if values["deployment"] not in values["openai_api_base"]:
if values["deployment"] not in values["openai_api_base"]: # type: ignore
warnings.warn(
"As of openai>=1.0.0, if `openai_api_base` "
"As of openai>=1.0.0, if `openai_api_base` " # type: ignore
"(or alias `base_url`) is specified it is expected to be "
"of the form "
"https://example-resource.azure.openai.com/openai/deployments/example-deployment. " # noqa: E501
f"Updating {openai_api_base} to "
f"{values['openai_api_base']}."
f"Updating {openai_api_base} to " # type: ignore
f"{values['openai_api_base']}." # type: ignore
)
values["openai_api_base"] += (
values["openai_api_base"] += ( # type: ignore
"/deployments/" + values["deployment"]
)
values["deployment"] = None
@@ -144,7 +144,7 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
"azure_ad_token": values["azure_ad_token"],
"azure_ad_token_provider": values["azure_ad_token_provider"],
"organization": values["openai_organization"],
"base_url": values["openai_api_base"],
"base_url": values["openai_api_base"], # type: ignore
"timeout": values["request_timeout"],
"max_retries": values["max_retries"],
"default_headers": values["default_headers"],
@@ -154,7 +154,7 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
values["client"] = openai.AzureOpenAI(**client_params).embeddings
values["async_client"] = openai.AsyncAzureOpenAI(**client_params).embeddings
else:
values["client"] = openai.Embedding
values["client"] = openai.Embedding # type: ignore
return values
@property

View File

@@ -42,11 +42,11 @@ def _create_retry_decorator(embeddings: LocalAIEmbeddings) -> Callable[[Any], An
stop=stop_after_attempt(embeddings.max_retries),
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
retry=(
retry_if_exception_type(openai.error.Timeout)
| retry_if_exception_type(openai.error.APIError)
| retry_if_exception_type(openai.error.APIConnectionError)
| retry_if_exception_type(openai.error.RateLimitError)
| retry_if_exception_type(openai.error.ServiceUnavailableError)
retry_if_exception_type(openai.error.Timeout) # type: ignore
| retry_if_exception_type(openai.error.APIError) # type: ignore
| retry_if_exception_type(openai.error.APIConnectionError) # type: ignore
| retry_if_exception_type(openai.error.RateLimitError) # type: ignore
| retry_if_exception_type(openai.error.ServiceUnavailableError) # type: ignore
),
before_sleep=before_sleep_log(logger, logging.WARNING),
)
@@ -64,11 +64,11 @@ def _async_retry_decorator(embeddings: LocalAIEmbeddings) -> Any:
stop=stop_after_attempt(embeddings.max_retries),
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
retry=(
retry_if_exception_type(openai.error.Timeout)
| retry_if_exception_type(openai.error.APIError)
| retry_if_exception_type(openai.error.APIConnectionError)
| retry_if_exception_type(openai.error.RateLimitError)
| retry_if_exception_type(openai.error.ServiceUnavailableError)
retry_if_exception_type(openai.error.Timeout) # type: ignore
| retry_if_exception_type(openai.error.APIError) # type: ignore
| retry_if_exception_type(openai.error.APIConnectionError) # type: ignore
| retry_if_exception_type(openai.error.RateLimitError) # type: ignore
| retry_if_exception_type(openai.error.ServiceUnavailableError) # type: ignore
),
before_sleep=before_sleep_log(logger, logging.WARNING),
)
@@ -89,7 +89,7 @@ def _check_response(response: dict) -> dict:
if any(len(d["embedding"]) == 1 for d in response["data"]):
import openai
raise openai.error.APIError("LocalAI API returned an empty embedding")
raise openai.error.APIError("LocalAI API returned an empty embedding") # type: ignore
return response
@@ -132,7 +132,7 @@ class LocalAIEmbeddings(BaseModel, Embeddings):
from langchain_community.embeddings import LocalAIEmbeddings
openai = LocalAIEmbeddings(
openai_api_key="random-string",
openai_api_base="http://localhost:8080"
openai_api_base="http://localhost:8080" # type: ignore
)
"""
@@ -141,9 +141,9 @@ class LocalAIEmbeddings(BaseModel, Embeddings):
model: str = "text-embedding-ada-002"
deployment: str = model
openai_api_version: Optional[str] = None
openai_api_base: Optional[str] = None
openai_api_base: Optional[str] = None # type: ignore
# to support explicit proxy for LocalAI
openai_proxy: Optional[str] = None
openai_proxy: Optional[str] = None # type: ignore
embedding_ctx_length: int = 8191
"""The maximum number of tokens to embed at once."""
openai_api_key: Optional[str] = None
@@ -199,15 +199,15 @@ class LocalAIEmbeddings(BaseModel, Embeddings):
values["openai_api_key"] = get_from_dict_or_env(
values, "openai_api_key", "OPENAI_API_KEY"
)
values["openai_api_base"] = get_from_dict_or_env(
values["openai_api_base"] = get_from_dict_or_env( # type: ignore
values,
"openai_api_base",
"openai_api_base", # type: ignore
"OPENAI_API_BASE",
default="",
)
values["openai_proxy"] = get_from_dict_or_env(
values["openai_proxy"] = get_from_dict_or_env( # type: ignore
values,
"openai_proxy",
"openai_proxy", # type: ignore
"OPENAI_PROXY",
default="",
)
@@ -228,7 +228,7 @@ class LocalAIEmbeddings(BaseModel, Embeddings):
try:
import openai
values["client"] = openai.Embedding
values["client"] = openai.Embedding # type: ignore
except ImportError:
raise ImportError(
"Could not import openai python package. "
@@ -244,16 +244,16 @@ class LocalAIEmbeddings(BaseModel, Embeddings):
"headers": self.headers,
"api_key": self.openai_api_key,
"organization": self.openai_organization,
"api_base": self.openai_api_base,
"api_base": self.openai_api_base, # type: ignore
"api_version": self.openai_api_version,
**self.model_kwargs,
}
if self.openai_proxy:
if self.openai_proxy: # type: ignore
import openai
openai.proxy = {
"http": self.openai_proxy,
"https": self.openai_proxy,
openai.proxy = { # type: ignore
"http": self.openai_proxy, # type: ignore
"https": self.openai_proxy, # type: ignore
} # type: ignore[assignment] # noqa: E501
return openai_args

View File

@@ -54,11 +54,11 @@ def _create_retry_decorator(embeddings: OpenAIEmbeddings) -> Callable[[Any], Any
max=embeddings.retry_max_seconds,
),
retry=(
retry_if_exception_type(openai.error.Timeout)
| retry_if_exception_type(openai.error.APIError)
| retry_if_exception_type(openai.error.APIConnectionError)
| retry_if_exception_type(openai.error.RateLimitError)
| retry_if_exception_type(openai.error.ServiceUnavailableError)
retry_if_exception_type(openai.error.Timeout) # type: ignore
| retry_if_exception_type(openai.error.APIError) # type: ignore
| retry_if_exception_type(openai.error.APIConnectionError) # type: ignore
| retry_if_exception_type(openai.error.RateLimitError) # type: ignore
| retry_if_exception_type(openai.error.ServiceUnavailableError) # type: ignore
),
before_sleep=before_sleep_log(logger, logging.WARNING),
)
@@ -81,11 +81,11 @@ def _async_retry_decorator(embeddings: OpenAIEmbeddings) -> Any:
max=embeddings.retry_max_seconds,
),
retry=(
retry_if_exception_type(openai.error.Timeout)
| retry_if_exception_type(openai.error.APIError)
| retry_if_exception_type(openai.error.APIConnectionError)
| retry_if_exception_type(openai.error.RateLimitError)
| retry_if_exception_type(openai.error.ServiceUnavailableError)
retry_if_exception_type(openai.error.Timeout) # type: ignore
| retry_if_exception_type(openai.error.APIError) # type: ignore
| retry_if_exception_type(openai.error.APIConnectionError) # type: ignore
| retry_if_exception_type(openai.error.RateLimitError) # type: ignore
| retry_if_exception_type(openai.error.ServiceUnavailableError) # type: ignore
),
before_sleep=before_sleep_log(logger, logging.WARNING),
)
@@ -106,7 +106,7 @@ def _check_response(response: dict, skip_empty: bool = False) -> dict:
if any(len(d["embedding"]) == 1 for d in response["data"]) and not skip_empty:
import openai
raise openai.error.APIError("OpenAI API returned an empty embedding")
raise openai.error.APIError("OpenAI API returned an empty embedding") # type: ignore
return response
@@ -194,13 +194,13 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
openai_api_version: Optional[str] = Field(default=None, alias="api_version")
"""Automatically inferred from env var `OPENAI_API_VERSION` if not provided."""
# to support Azure OpenAI Service custom endpoints
openai_api_base: Optional[str] = Field(default=None, alias="base_url")
openai_api_base: Optional[str] = Field(default=None, alias="base_url") # type: ignore
"""Base URL path for API requests, leave blank if not using a proxy or service
emulator."""
# to support Azure OpenAI Service custom endpoints
openai_api_type: Optional[str] = None
# to support explicit proxy for OpenAI
openai_proxy: Optional[str] = None
openai_proxy: Optional[str] = None # type: ignore
embedding_ctx_length: int = 8191
"""The maximum number of tokens to embed at once."""
openai_api_key: Optional[str] = Field(default=None, alias="api_key")
@@ -288,7 +288,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
values["openai_api_key"] = get_from_dict_or_env(
values, "openai_api_key", "OPENAI_API_KEY"
)
values["openai_api_base"] = values["openai_api_base"] or os.getenv(
values["openai_api_base"] = values["openai_api_base"] or os.getenv( # type: ignore
"OPENAI_API_BASE"
)
values["openai_api_type"] = get_from_dict_or_env(
@@ -297,9 +297,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
"OPENAI_API_TYPE",
default="",
)
values["openai_proxy"] = get_from_dict_or_env(
values["openai_proxy"] = get_from_dict_or_env( # type: ignore
values,
"openai_proxy",
"openai_proxy", # type: ignore
"OPENAI_PROXY",
default="",
)
@@ -340,7 +340,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
client_params = {
"api_key": values["openai_api_key"],
"organization": values["openai_organization"],
"base_url": values["openai_api_base"],
"base_url": values["openai_api_base"], # type: ignore
"timeout": values["request_timeout"],
"max_retries": values["max_retries"],
"default_headers": values["default_headers"],
@@ -354,7 +354,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
**client_params
).embeddings
elif not values.get("client"):
values["client"] = openai.Embedding
values["client"] = openai.Embedding # type: ignore
else:
pass
return values
@@ -370,7 +370,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
"headers": self.headers,
"api_key": self.openai_api_key,
"organization": self.openai_organization,
"api_base": self.openai_api_base,
"api_base": self.openai_api_base, # type: ignore
"api_type": self.openai_api_type,
"api_version": self.openai_api_version,
**self.model_kwargs,
@@ -378,7 +378,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
if self.openai_api_type in ("azure", "azure_ad", "azuread"):
openai_args["engine"] = self.deployment
# TODO: Look into proxy with openai v1.
if self.openai_proxy:
if self.openai_proxy: # type: ignore
try:
import openai
except ImportError:
@@ -387,9 +387,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
"Please install it with `pip install openai`."
)
openai.proxy = {
"http": self.openai_proxy,
"https": self.openai_proxy,
openai.proxy = { # type: ignore
"http": self.openai_proxy, # type: ignore
"https": self.openai_proxy, # type: ignore
} # type: ignore[assignment] # noqa: E501
return openai_args

View File

@@ -107,7 +107,7 @@ class Anyscale(BaseOpenAI):
import openai
## Always create ChatComplete client, replacing the legacy Complete client
values["client"] = openai.ChatCompletion
values["client"] = openai.ChatCompletion # type: ignore
except ImportError:
raise ImportError(
"Could not import openai python package. "

View File

@@ -97,8 +97,8 @@ class GooseAI(LLM):
import openai
openai.api_key = gooseai_api_key.get_secret_value()
openai.api_base = "https://api.goose.ai/v1"
values["client"] = openai.Completion
openai.api_base = "https://api.goose.ai/v1" # type: ignore
values["client"] = openai.Completion # type: ignore
except ImportError:
raise ImportError(
"Could not import openai python package. "

View File

@@ -94,11 +94,11 @@ def _create_retry_decorator(
import openai
errors = [
openai.error.Timeout,
openai.error.APIError,
openai.error.APIConnectionError,
openai.error.RateLimitError,
openai.error.ServiceUnavailableError,
openai.error.Timeout, # type: ignore
openai.error.APIError, # type: ignore
openai.error.APIConnectionError, # type: ignore
openai.error.RateLimitError, # type: ignore
openai.error.ServiceUnavailableError, # type: ignore
]
return create_base_retry_decorator(
error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
@@ -157,14 +157,14 @@ class BaseOpenAI(BaseLLM):
@property
def lc_attributes(self) -> Dict[str, Any]:
attributes: Dict[str, Any] = {}
if self.openai_api_base:
attributes["openai_api_base"] = self.openai_api_base
if self.openai_api_base: # type: ignore
attributes["openai_api_base"] = self.openai_api_base # type: ignore
if self.openai_organization:
attributes["openai_organization"] = self.openai_organization
if self.openai_proxy:
attributes["openai_proxy"] = self.openai_proxy
if self.openai_proxy: # type: ignore
attributes["openai_proxy"] = self.openai_proxy # type: ignore
return attributes
@@ -199,13 +199,13 @@ class BaseOpenAI(BaseLLM):
# may assume openai_api_key is a str)
openai_api_key: Optional[str] = Field(default=None, alias="api_key")
"""Automatically inferred from env var `OPENAI_API_KEY` if not provided."""
openai_api_base: Optional[str] = Field(default=None, alias="base_url")
openai_api_base: Optional[str] = Field(default=None, alias="base_url") # type: ignore
"""Base URL path for API requests, leave blank if not using a proxy or service
emulator."""
openai_organization: Optional[str] = Field(default=None, alias="organization")
"""Automatically inferred from env var `OPENAI_ORG_ID` if not provided."""
# to support explicit proxy for OpenAI
openai_proxy: Optional[str] = None
openai_proxy: Optional[str] = None # type: ignore
batch_size: int = 20
"""Batch size to use when passing multiple documents to generate."""
request_timeout: Union[float, Tuple[float, float], Any, None] = Field(
@@ -282,12 +282,12 @@ class BaseOpenAI(BaseLLM):
values["openai_api_key"] = get_from_dict_or_env(
values, "openai_api_key", "OPENAI_API_KEY"
)
values["openai_api_base"] = values["openai_api_base"] or os.getenv(
values["openai_api_base"] = values["openai_api_base"] or os.getenv( # type: ignore
"OPENAI_API_BASE"
)
values["openai_proxy"] = get_from_dict_or_env(
values["openai_proxy"] = get_from_dict_or_env( # type: ignore
values,
"openai_proxy",
"openai_proxy", # type: ignore
"OPENAI_PROXY",
default="",
)
@@ -308,7 +308,7 @@ class BaseOpenAI(BaseLLM):
client_params = {
"api_key": values["openai_api_key"],
"organization": values["openai_organization"],
"base_url": values["openai_api_base"],
"base_url": values["openai_api_base"], # type: ignore
"timeout": values["request_timeout"],
"max_retries": values["max_retries"],
"default_headers": values["default_headers"],
@@ -320,7 +320,7 @@ class BaseOpenAI(BaseLLM):
if not values.get("async_client"):
values["async_client"] = openai.AsyncOpenAI(**client_params).completions
elif not values.get("client"):
values["client"] = openai.Completion
values["client"] = openai.Completion # type: ignore
else:
pass
@@ -597,14 +597,14 @@ class BaseOpenAI(BaseLLM):
openai_creds.update(
{
"api_key": self.openai_api_key,
"api_base": self.openai_api_base,
"api_base": self.openai_api_base, # type: ignore
"organization": self.openai_organization,
}
)
if self.openai_proxy:
if self.openai_proxy: # type: ignore
import openai
openai.proxy = {"http": self.openai_proxy, "https": self.openai_proxy} # type: ignore[assignment] # noqa: E501
openai.proxy = {"http": self.openai_proxy, "https": self.openai_proxy} # type: ignore # noqa: E501
return {**openai_creds, **self._default_params}
@property
@@ -807,7 +807,7 @@ class AzureOpenAI(BaseOpenAI):
openai_api_type: str = ""
"""Legacy, for openai<1.0.0 support."""
validate_base_url: bool = True
"""For backwards compatibility. If legacy val openai_api_base is passed in, try to
"""For backwards compatibility. If legacy val openai_api_base is passed in, try to
infer if it is a base_url or azure_endpoint and update accordingly.
"""
@@ -841,12 +841,12 @@ class AzureOpenAI(BaseOpenAI):
values["azure_ad_token"] = values["azure_ad_token"] or os.getenv(
"AZURE_OPENAI_AD_TOKEN"
)
values["openai_api_base"] = values["openai_api_base"] or os.getenv(
values["openai_api_base"] = values["openai_api_base"] or os.getenv( # type: ignore
"OPENAI_API_BASE"
)
values["openai_proxy"] = get_from_dict_or_env(
values["openai_proxy"] = get_from_dict_or_env( # type: ignore
values,
"openai_proxy",
"openai_proxy", # type: ignore
"OPENAI_PROXY",
default="",
)
@@ -870,37 +870,37 @@ class AzureOpenAI(BaseOpenAI):
)
if is_openai_v1():
# For backwards compatibility. Before openai v1, no distinction was made
# between azure_endpoint and base_url (openai_api_base).
openai_api_base = values["openai_api_base"]
if openai_api_base and values["validate_base_url"]:
if "/openai" not in openai_api_base:
values["openai_api_base"] = (
values["openai_api_base"].rstrip("/") + "/openai"
# between azure_endpoint and base_url (openai_api_base). # type: ignore
openai_api_base = values["openai_api_base"] # type: ignore
if openai_api_base and values["validate_base_url"]: # type: ignore
if "/openai" not in openai_api_base: # type: ignore
values["openai_api_base"] = ( # type: ignore
values["openai_api_base"].rstrip("/") + "/openai" # type: ignore
)
warnings.warn(
"As of openai>=1.0.0, Azure endpoints should be specified via "
f"the `azure_endpoint` param not `openai_api_base` "
f"(or alias `base_url`). Updating `openai_api_base` from "
f"{openai_api_base} to {values['openai_api_base']}."
f"the `azure_endpoint` param not `openai_api_base` " # type: ignore
f"(or alias `base_url`). Updating `openai_api_base` from " # type: ignore
f"{openai_api_base} to {values['openai_api_base']}." # type: ignore
)
if values["deployment_name"]:
warnings.warn(
"As of openai>=1.0.0, if `deployment_name` (or alias "
"`azure_deployment`) is specified then "
"`openai_api_base` (or alias `base_url`) should not be. "
"`openai_api_base` (or alias `base_url`) should not be. " # type: ignore
"Instead use `deployment_name` (or alias `azure_deployment`) "
"and `azure_endpoint`."
)
if values["deployment_name"] not in values["openai_api_base"]:
if values["deployment_name"] not in values["openai_api_base"]: # type: ignore
warnings.warn(
"As of openai>=1.0.0, if `openai_api_base` "
"As of openai>=1.0.0, if `openai_api_base` " # type: ignore
"(or alias `base_url`) is specified it is expected to be "
"of the form "
"https://example-resource.azure.openai.com/openai/deployments/example-deployment. " # noqa: E501
f"Updating {openai_api_base} to "
f"{values['openai_api_base']}."
f"Updating {openai_api_base} to " # type: ignore
f"{values['openai_api_base']}." # type: ignore
)
values["openai_api_base"] += (
values["openai_api_base"] += ( # type: ignore
"/deployments/" + values["deployment_name"]
)
values["deployment_name"] = None
@@ -912,7 +912,7 @@ class AzureOpenAI(BaseOpenAI):
"azure_ad_token": values["azure_ad_token"],
"azure_ad_token_provider": values["azure_ad_token_provider"],
"organization": values["openai_organization"],
"base_url": values["openai_api_base"],
"base_url": values["openai_api_base"], # type: ignore
"timeout": values["request_timeout"],
"max_retries": values["max_retries"],
"default_headers": values["default_headers"],
@@ -925,7 +925,7 @@ class AzureOpenAI(BaseOpenAI):
).completions
else:
values["client"] = openai.Completion
values["client"] = openai.Completion # type: ignore
return values
@@ -993,11 +993,11 @@ class OpenAIChat(BaseLLM):
# may assume openai_api_key is a str)
openai_api_key: Optional[str] = Field(default=None, alias="api_key")
"""Automatically inferred from env var `OPENAI_API_KEY` if not provided."""
openai_api_base: Optional[str] = Field(default=None, alias="base_url")
openai_api_base: Optional[str] = Field(default=None, alias="base_url") # type: ignore
"""Base URL path for API requests, leave blank if not using a proxy or service
emulator."""
# to support explicit proxy for OpenAI
openai_proxy: Optional[str] = None
openai_proxy: Optional[str] = None # type: ignore
max_retries: int = 6
"""Maximum number of retries to make when generating."""
prefix_messages: List = Field(default_factory=list)
@@ -1029,15 +1029,15 @@ class OpenAIChat(BaseLLM):
openai_api_key = get_from_dict_or_env(
values, "openai_api_key", "OPENAI_API_KEY"
)
openai_api_base = get_from_dict_or_env(
openai_api_base = get_from_dict_or_env( # type: ignore
values,
"openai_api_base",
"openai_api_base", # type: ignore
"OPENAI_API_BASE",
default="",
)
openai_proxy = get_from_dict_or_env(
openai_proxy = get_from_dict_or_env( # type: ignore
values,
"openai_proxy",
"openai_proxy", # type: ignore
"OPENAI_PROXY",
default="",
)
@@ -1048,19 +1048,19 @@ class OpenAIChat(BaseLLM):
import openai
openai.api_key = openai_api_key
if openai_api_base:
openai.api_base = openai_api_base
if openai_api_base: # type: ignore
openai.api_base = openai_api_base # type: ignore
if openai_organization:
openai.organization = openai_organization
if openai_proxy:
openai.proxy = {"http": openai_proxy, "https": openai_proxy} # type: ignore[assignment] # noqa: E501
if openai_proxy: # type: ignore
openai.proxy = {"http": openai_proxy, "https": openai_proxy} # type: ignore # noqa: E501
except ImportError:
raise ImportError(
"Could not import openai python package. "
"Please install it with `pip install openai`."
)
try:
values["client"] = openai.ChatCompletion
values["client"] = openai.ChatCompletion # type: ignore
except AttributeError:
raise ValueError(
"`openai` has no `ChatCompletion` attribute, this is likely "

View File

@@ -164,7 +164,7 @@ class VLLMOpenAI(BaseOpenAI):
params.update(
{
"api_key": self.openai_api_key,
"api_base": self.openai_api_base,
"api_base": self.openai_api_base, # type: ignore
}
)

View File

@@ -51,6 +51,8 @@ class AmadeusClosestAirport(AmadeusBaseTool):
location: str,
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> str:
if not self.llm:
raise ValueError("No language model (llm) has been set for this tool.")
content = (
f" What is the nearest airport to {location}? Please respond with the "
" airport's International Air Transport Association (IATA) Location "

View File

@@ -9,12 +9,12 @@ from langchain_core.tools import BaseTool
from langchain_community.utilities.vertexai import get_client_info
if TYPE_CHECKING:
from google.cloud import texttospeech
from google.cloud import texttospeech # type: ignore
def _import_google_cloud_texttospeech() -> Any:
try:
from google.cloud import texttospeech
from google.cloud import texttospeech # type: ignore
except ImportError as e:
raise ImportError(
"Cannot import google.cloud.texttospeech, please install "

View File

@@ -31,13 +31,13 @@ class DallEAPIWrapper(BaseModel):
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
openai_api_key: Optional[str] = Field(default=None, alias="api_key")
"""Automatically inferred from env var `OPENAI_API_KEY` if not provided."""
openai_api_base: Optional[str] = Field(default=None, alias="base_url")
openai_api_base: Optional[str] = Field(default=None, alias="base_url") # type: ignore
"""Base URL path for API requests, leave blank if not using a proxy or service
emulator."""
openai_organization: Optional[str] = Field(default=None, alias="organization")
"""Automatically inferred from env var `OPENAI_ORG_ID` if not provided."""
# to support explicit proxy for OpenAI
openai_proxy: Optional[str] = None
openai_proxy: Optional[str] = None # type: ignore
request_timeout: Union[float, Tuple[float, float], Any, None] = Field(
default=None, alias="timeout"
)
@@ -102,12 +102,12 @@ class DallEAPIWrapper(BaseModel):
or os.getenv("OPENAI_ORGANIZATION")
or None
)
values["openai_api_base"] = values["openai_api_base"] or os.getenv(
values["openai_api_base"] = values["openai_api_base"] or os.getenv( # type: ignore
"OPENAI_API_BASE"
)
values["openai_proxy"] = get_from_dict_or_env(
values["openai_proxy"] = get_from_dict_or_env( # type: ignore
values,
"openai_proxy",
"openai_proxy", # type: ignore
"OPENAI_PROXY",
default="",
)
@@ -125,7 +125,7 @@ class DallEAPIWrapper(BaseModel):
client_params = {
"api_key": values["openai_api_key"],
"organization": values["openai_organization"],
"base_url": values["openai_api_base"],
"base_url": values["openai_api_base"], # type: ignore
"timeout": values["request_timeout"],
"max_retries": values["max_retries"],
"default_headers": values["default_headers"],
@@ -138,7 +138,7 @@ class DallEAPIWrapper(BaseModel):
if not values.get("async_client"):
values["async_client"] = openai.AsyncOpenAI(**client_params).images
elif not values.get("client"):
values["client"] = openai.Image
values["client"] = openai.Image # type: ignore
else:
pass
return values

View File

@@ -111,7 +111,7 @@ def get_client_info(module: Optional[str] = None) -> "ClientInfo":
def load_image_from_gcs(path: str, project: Optional[str] = None) -> "Image":
"""Loads im Image from GCS."""
try:
from google.cloud import storage
from google.cloud import storage # type: ignore
except ImportError:
raise ImportError("Could not import google-cloud-storage python package.")
from vertexai.preview.generative_models import Image

View File

@@ -15,7 +15,6 @@ from typing import (
Optional,
Set,
Tuple,
Type,
TypeVar,
)
@@ -33,7 +32,6 @@ if TYPE_CHECKING:
from astrapy.db import AstraDB as LibAstraDB
from astrapy.db import AsyncAstraDB
ADBVST = TypeVar("ADBVST", bound="AstraDB")
T = TypeVar("T")
U = TypeVar("U")
DocDict = Dict[str, Any] # dicts expressing entries to insert
@@ -1142,10 +1140,10 @@ class AstraDB(VectorStore):
@classmethod
def _from_kwargs(
cls: Type[ADBVST],
cls,
embedding: Embeddings,
**kwargs: Any,
) -> ADBVST:
) -> "AstraDB":
known_kwargs = {
"collection_name",
"token",
@@ -1197,13 +1195,13 @@ class AstraDB(VectorStore):
@classmethod
def from_texts(
cls: Type[ADBVST],
cls,
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
**kwargs: Any,
) -> ADBVST:
) -> "AstraDB":
"""Create an Astra DB vectorstore from raw texts.
Args:
@@ -1232,13 +1230,13 @@ class AstraDB(VectorStore):
@classmethod
async def afrom_texts(
cls: Type[ADBVST],
cls,
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
**kwargs: Any,
) -> ADBVST:
) -> "AstraDB":
"""Create an Astra DB vectorstore from raw texts.
Args:
@@ -1267,11 +1265,11 @@ class AstraDB(VectorStore):
@classmethod
def from_documents(
cls: Type[ADBVST],
cls,
documents: List[Document],
embedding: Embeddings,
**kwargs: Any,
) -> ADBVST:
) -> "AstraDB":
"""Create an Astra DB vectorstore from a document list.
Utility method that defers to 'from_texts' (see that one).

View File

@@ -13,7 +13,7 @@ from langchain_core.vectorstores import VectorStore
from langchain_community.utilities.vertexai import get_client_info
if TYPE_CHECKING:
from google.cloud import storage
from google.cloud import storage # type: ignore
from google.cloud.aiplatform import MatchingEngineIndex, MatchingEngineIndexEndpoint
from google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint import (
Namespace,
@@ -103,7 +103,7 @@ class MatchingEngine(VectorStore):
def _validate_google_libraries_installation(self) -> None:
"""Validates that Google libraries that are needed are installed."""
try:
from google.cloud import aiplatform, storage # noqa: F401
from google.cloud import aiplatform, storage # type: ignore # noqa: F401
from google.oauth2 import service_account # noqa: F401
except ImportError:
raise ImportError(
@@ -545,7 +545,7 @@ class MatchingEngine(VectorStore):
A configured GCS client.
"""
from google.cloud import storage
from google.cloud import storage # type: ignore
return storage.Client(
credentials=credentials,

View File

@@ -6,8 +6,8 @@ from langchain_community.adapters import openai as lcopenai
def _test_no_stream(**kwargs: Any) -> None:
import openai
result = openai.ChatCompletion.create(**kwargs)
lc_result = lcopenai.ChatCompletion.create(**kwargs)
result = openai.ChatCompletion.create(**kwargs) # type: ignore
lc_result = lcopenai.ChatCompletion.create(**kwargs) # type: ignore
if isinstance(lc_result, dict):
if isinstance(result, dict):
result_dict = result["choices"][0]["message"].to_dict_recursive()
@@ -20,11 +20,11 @@ def _test_stream(**kwargs: Any) -> None:
import openai
result = []
for c in openai.ChatCompletion.create(**kwargs):
for c in openai.ChatCompletion.create(**kwargs): # type: ignore
result.append(c["choices"][0]["delta"].to_dict_recursive())
lc_result = []
for c in lcopenai.ChatCompletion.create(**kwargs):
for c in lcopenai.ChatCompletion.create(**kwargs): # type: ignore
lc_result.append(c["choices"][0]["delta"])
assert result == lc_result
@@ -32,8 +32,8 @@ def _test_stream(**kwargs: Any) -> None:
async def _test_async(**kwargs: Any) -> None:
import openai
result = await openai.ChatCompletion.acreate(**kwargs)
lc_result = await lcopenai.ChatCompletion.acreate(**kwargs)
result = await openai.ChatCompletion.acreate(**kwargs) # type: ignore
lc_result = await lcopenai.ChatCompletion.acreate(**kwargs) # type: ignore
if isinstance(lc_result, dict):
if isinstance(result, dict):
result_dict = result["choices"][0]["message"].to_dict_recursive()
@@ -46,11 +46,11 @@ async def _test_astream(**kwargs: Any) -> None:
import openai
result = []
async for c in await openai.ChatCompletion.acreate(**kwargs):
async for c in await openai.ChatCompletion.acreate(**kwargs): # type: ignore
result.append(c["choices"][0]["delta"].to_dict_recursive())
lc_result = []
async for c in await lcopenai.ChatCompletion.acreate(**kwargs):
async for c in await lcopenai.ChatCompletion.acreate(**kwargs): # type: ignore
lc_result.append(c["choices"][0]["delta"])
assert result == lc_result

View File

@@ -23,7 +23,7 @@ def test_chat_openai() -> None:
temperature=0.7,
base_url=None,
organization=None,
openai_proxy=None,
openai_proxy=None, # type: ignore
timeout=10.0,
max_retries=3,
http_client=None,

View File

@@ -20,7 +20,7 @@ def _get_embeddings(**kwargs: Any) -> AzureOpenAIEmbeddings:
return AzureOpenAIEmbeddings(
azure_deployment=DEPLOYMENT_NAME,
api_version=OPENAI_API_VERSION,
openai_api_base=OPENAI_API_BASE,
openai_api_base=OPENAI_API_BASE, # type: ignore
openai_api_key=OPENAI_API_KEY,
**kwargs,
)
@@ -104,7 +104,7 @@ def test_azure_openai_embedding_with_empty_string() -> None:
output = embedding.embed_documents(document)
assert len(output) == 2
assert len(output[0]) == 1536
expected_output = openai.Embedding.create(input="", model="text-embedding-ada-002")[
expected_output = openai.Embedding.create(input="", model="text-embedding-ada-002")[ # type: ignore
"data"
][0]["embedding"]
assert np.allclose(output[0], expected_output)

View File

@@ -70,7 +70,7 @@ def test_openai_embedding_with_empty_string() -> None:
output = embedding.embed_documents(document)
assert len(output) == 2
assert len(output[0]) == 1536
expected_output = openai.Embedding.create(input="", model="text-embedding-ada-002")[
expected_output = openai.Embedding.create(input="", model="text-embedding-ada-002")[ # type: ignore
"data"
][0]["embedding"]
assert np.allclose(output[0], expected_output)

View File

@@ -22,7 +22,7 @@ def _get_llm(**kwargs: Any) -> AzureOpenAI:
return AzureOpenAI(
deployment_name=DEPLOYMENT_NAME,
openai_api_version=OPENAI_API_VERSION,
openai_api_base=OPENAI_API_BASE,
openai_api_base=OPENAI_API_BASE, # type: ignore
openai_api_key=OPENAI_API_KEY,
**kwargs,
)

View File

@@ -1,10 +1,14 @@
"""Implement integration tests for AstraDB storage."""
import os
from typing import TYPE_CHECKING
import pytest
from langchain_community.storage.astradb import AstraDBByteStore, AstraDBStore
if TYPE_CHECKING:
from astrapy.db import AstraDB
def _has_env_vars() -> bool:
return all(
@@ -16,7 +20,7 @@ def _has_env_vars() -> bool:
@pytest.fixture
def astra_db():
def astra_db() -> AstraDB:
from astrapy.db import AstraDB
return AstraDB(
@@ -26,14 +30,14 @@ def astra_db():
)
def init_store(astra_db, collection_name: str):
def init_store(astra_db: AstraDB, collection_name: str) -> AstraDBStore:
astra_db.create_collection(collection_name)
store = AstraDBStore(collection_name=collection_name, astra_db_client=astra_db)
store.mset([("key1", [0.1, 0.2]), ("key2", "value2")])
return store
def init_bytestore(astra_db, collection_name: str):
def init_bytestore(astra_db: AstraDB, collection_name: str) -> AstraDBByteStore:
astra_db.create_collection(collection_name)
store = AstraDBByteStore(collection_name=collection_name, astra_db_client=astra_db)
store.mset([("key1", b"value1"), ("key2", b"value2")])
@@ -43,7 +47,7 @@ def init_bytestore(astra_db, collection_name: str):
@pytest.mark.requires("astrapy")
@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
class TestAstraDBStore:
def test_mget(self, astra_db) -> None:
def test_mget(self, astra_db: AstraDB) -> None:
"""Test AstraDBStore mget method."""
collection_name = "lc_test_store_mget"
try:
@@ -52,7 +56,7 @@ class TestAstraDBStore:
finally:
astra_db.delete_collection(collection_name)
def test_mset(self, astra_db) -> None:
def test_mset(self, astra_db: AstraDB) -> None:
"""Test that multiple keys can be set with AstraDBStore."""
collection_name = "lc_test_store_mset"
try:
@@ -64,7 +68,7 @@ class TestAstraDBStore:
finally:
astra_db.delete_collection(collection_name)
def test_mdelete(self, astra_db) -> None:
def test_mdelete(self, astra_db: AstraDB) -> None:
"""Test that deletion works as expected."""
collection_name = "lc_test_store_mdelete"
try:
@@ -75,7 +79,7 @@ class TestAstraDBStore:
finally:
astra_db.delete_collection(collection_name)
def test_yield_keys(self, astra_db) -> None:
def test_yield_keys(self, astra_db: AstraDB) -> None:
collection_name = "lc_test_store_yield_keys"
try:
store = init_store(astra_db, collection_name)
@@ -85,7 +89,7 @@ class TestAstraDBStore:
finally:
astra_db.delete_collection(collection_name)
def test_bytestore_mget(self, astra_db) -> None:
def test_bytestore_mget(self, astra_db: AstraDB) -> None:
"""Test AstraDBByteStore mget method."""
collection_name = "lc_test_bytestore_mget"
try:
@@ -94,7 +98,7 @@ class TestAstraDBStore:
finally:
astra_db.delete_collection(collection_name)
def test_bytestore_mset(self, astra_db) -> None:
def test_bytestore_mset(self, astra_db: AstraDB) -> None:
"""Test that multiple keys can be set with AstraDBByteStore."""
collection_name = "lc_test_bytestore_mset"
try:

View File

@@ -42,7 +42,7 @@ lint lint_diff lint_package lint_tests:
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff --select I $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || poetry run mypy $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || mkdir $(MYPY_CACHE) || poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
format format_diff:
poetry run ruff format $(PYTHON_FILES)

View File

@@ -62,7 +62,7 @@ lint lint_diff lint_package lint_tests:
poetry run ruff .
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff --select I $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || mkdir $(MYPY_CACHE) || poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
format format_diff:
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES)