From 7ce338201cc4bc0cba5f387d613582114217d5c8 Mon Sep 17 00:00:00 2001 From: chyroc Date: Sat, 30 Dec 2023 05:44:19 +0800 Subject: [PATCH] Patch: improve check openai version (#15301) --- .../langchain_community/embeddings/openai.py | 17 ++++++----------- libs/core/langchain_core/runnables/config.py | 2 +- 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/libs/community/langchain_community/embeddings/openai.py b/libs/community/langchain_community/embeddings/openai.py index a44b20527e1..99f41e62856 100644 --- a/libs/community/langchain_community/embeddings/openai.py +++ b/libs/community/langchain_community/embeddings/openai.py @@ -3,7 +3,6 @@ from __future__ import annotations import logging import os import warnings -from importlib.metadata import version from typing import ( Any, Callable, @@ -23,7 +22,6 @@ import numpy as np from langchain_core.embeddings import Embeddings from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator from langchain_core.utils import get_from_dict_or_env, get_pydantic_field_names -from packaging.version import Version, parse from tenacity import ( AsyncRetrying, before_sleep_log, @@ -33,6 +31,8 @@ from tenacity import ( wait_exponential, ) +from langchain_community.utils.openai import is_openai_v1 + logger = logging.getLogger(__name__) @@ -111,7 +111,7 @@ def _check_response(response: dict, skip_empty: bool = False) -> dict: def embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any: """Use tenacity to retry the embedding call.""" - if _is_openai_v1(): + if is_openai_v1(): return embeddings.client.create(**kwargs) retry_decorator = _create_retry_decorator(embeddings) @@ -126,7 +126,7 @@ def embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any: async def async_embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any: """Use tenacity to retry the embedding call.""" - if _is_openai_v1(): + if is_openai_v1(): return await embeddings.async_client.create(**kwargs) @_async_retry_decorator(embeddings) @@ -137,11 +137,6 @@ async def async_embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> return await _async_embed_with_retry(**kwargs) -def _is_openai_v1() -> bool: - _version = parse(version("openai")) - return _version >= Version("1.0.0") - - class OpenAIEmbeddings(BaseModel, Embeddings): """OpenAI embedding models. @@ -330,7 +325,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): "Please install it with `pip install openai`." ) else: - if _is_openai_v1(): + if is_openai_v1(): if values["openai_api_type"] in ("azure", "azure_ad", "azuread"): warnings.warn( "If you have openai>=1.0.0 installed and are using Azure, " @@ -360,7 +355,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): @property def _invocation_params(self) -> Dict[str, Any]: - if _is_openai_v1(): + if is_openai_v1(): openai_args: Dict = {"model": self.model, **self.model_kwargs} else: openai_args = { diff --git a/libs/core/langchain_core/runnables/config.py b/libs/core/langchain_core/runnables/config.py index 5672a60fa2d..c803f52b295 100644 --- a/libs/core/langchain_core/runnables/config.py +++ b/libs/core/langchain_core/runnables/config.py @@ -395,7 +395,7 @@ def _set_context(context: Context) -> None: @contextmanager def get_executor_for_config( - config: Optional[RunnableConfig] + config: Optional[RunnableConfig], ) -> Generator[Executor, None, None]: """Get an executor for a config.