This commit is contained in:
Bagatur
2024-09-03 16:48:53 -07:00
parent 615f8b0d47
commit 5f5287c3b0
8 changed files with 60 additions and 92 deletions

View File

@@ -31,16 +31,15 @@ from langchain_core.output_parsers.openai_tools import (
PydanticToolsParser, PydanticToolsParser,
) )
from langchain_core.outputs import ChatResult from langchain_core.outputs import ChatResult
from pydantic import BaseModel, Field, SecretStr, model_validator
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
from langchain_core.utils import from_env, secret_from_env from langchain_core.utils import from_env, secret_from_env
from langchain_core.utils.function_calling import convert_to_openai_tool from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.utils.pydantic import is_basemodel_subclass from langchain_core.utils.pydantic import is_basemodel_subclass
from pydantic import BaseModel, Field, SecretStr, model_validator
from langchain_openai.chat_models.base import BaseChatOpenAI
from typing_extensions import Self from typing_extensions import Self
from langchain_openai.chat_models.base import BaseChatOpenAI
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -604,19 +603,15 @@ class AzureChatOpenAI(BaseChatOpenAI):
"Or you can equivalently specify:\n\n" "Or you can equivalently specify:\n\n"
'base_url="https://xxx.openai.azure.com/openai/deployments/my-deployment"' 'base_url="https://xxx.openai.azure.com/openai/deployments/my-deployment"'
) )
client_params = { client_params: dict = {
"api_version": self.openai_api_version, "api_version": self.openai_api_version,
"azure_endpoint": self.azure_endpoint, "azure_endpoint": self.azure_endpoint,
"azure_deployment": self.deployment_name, "azure_deployment": self.deployment_name,
"api_key": ( "api_key": (
self.openai_api_key.get_secret_value() self.openai_api_key.get_secret_value() if self.openai_api_key else None
if self.openai_api_key
else None
), ),
"azure_ad_token": ( "azure_ad_token": (
self.azure_ad_token.get_secret_value() self.azure_ad_token.get_secret_value() if self.azure_ad_token else None
if self.azure_ad_token
else None
), ),
"azure_ad_token_provider": self.azure_ad_token_provider, "azure_ad_token_provider": self.azure_ad_token_provider,
"organization": self.openai_organization, "organization": self.openai_organization,
@@ -628,12 +623,13 @@ class AzureChatOpenAI(BaseChatOpenAI):
} }
if not (self.client or None): if not (self.client or None):
sync_specific = {"http_client": self.http_client} sync_specific = {"http_client": self.http_client}
self.root_client = openai.AzureOpenAI(**client_params, **sync_specific) self.root_client = openai.AzureOpenAI(**client_params, **sync_specific) # type: ignore[arg-type]
self.client = self.root_client.chat.completions self.client = self.root_client.chat.completions
if not (self.async_client or None): if not (self.async_client or None):
async_specific = {"http_client": self.http_async_client} async_specific = {"http_client": self.http_async_client}
self.root_async_client = openai.AsyncAzureOpenAI( self.root_async_client = openai.AsyncAzureOpenAI(
**client_params, **async_specific **client_params,
**async_specific, # type: ignore[arg-type]
) )
self.async_client = self.root_async_client.chat.completions self.async_client = self.root_async_client.chat.completions
return self return self

View File

@@ -73,11 +73,10 @@ from langchain_core.output_parsers.openai_tools import (
parse_tool_call, parse_tool_call,
) )
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from pydantic import BaseModel, Field, model_validator, SecretStr
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough, chain from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough, chain
from langchain_core.runnables.config import run_in_executor from langchain_core.runnables.config import run_in_executor
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
from langchain_core.utils import get_from_dict_or_env, get_pydantic_field_names from langchain_core.utils import get_pydantic_field_names
from langchain_core.utils.function_calling import ( from langchain_core.utils.function_calling import (
convert_to_openai_function, convert_to_openai_function,
convert_to_openai_tool, convert_to_openai_tool,
@@ -88,10 +87,9 @@ from langchain_core.utils.pydantic import (
is_basemodel_subclass, is_basemodel_subclass,
) )
from langchain_core.utils.utils import build_extra_kwargs, from_env, secret_from_env from langchain_core.utils.utils import build_extra_kwargs, from_env, secret_from_env
from pydantic import ConfigDict from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self from typing_extensions import Self
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -361,8 +359,7 @@ class BaseChatOpenAI(BaseChatModel):
model_kwargs: Dict[str, Any] = Field(default_factory=dict) model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified.""" """Holds any model parameters valid for `create` call not explicitly specified."""
openai_api_key: Optional[SecretStr] = Field( openai_api_key: Optional[SecretStr] = Field(
alias="api_key", alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", default=None)
default_factory=secret_from_env("OPENAI_API_KEY", default=None),
) )
openai_api_base: Optional[str] = Field(default=None, alias="base_url") openai_api_base: Optional[str] = Field(default=None, alias="base_url")
"""Base URL path for API requests, leave blank if not using a proxy or service """Base URL path for API requests, leave blank if not using a proxy or service
@@ -431,7 +428,7 @@ class BaseChatOpenAI(BaseChatModel):
include_response_headers: bool = False include_response_headers: bool = False
"""Whether to include response headers in the output message response_metadata.""" """Whether to include response headers in the output message response_metadata."""
model_config = ConfigDict(populate_by_name=True,) model_config = ConfigDict(populate_by_name=True)
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
@@ -458,14 +455,10 @@ class BaseChatOpenAI(BaseChatModel):
or os.getenv("OPENAI_ORG_ID") or os.getenv("OPENAI_ORG_ID")
or os.getenv("OPENAI_ORGANIZATION") or os.getenv("OPENAI_ORGANIZATION")
) )
self.openai_api_base = self.openai_api_base or os.getenv( self.openai_api_base = self.openai_api_base or os.getenv("OPENAI_API_BASE")
"OPENAI_API_BASE" client_params: dict = {
)
client_params = {
"api_key": ( "api_key": (
self.openai_api_key.get_secret_value() self.openai_api_key.get_secret_value() if self.openai_api_key else None
if self.openai_api_key
else None
), ),
"organization": self.openai_organization, "organization": self.openai_organization,
"base_url": self.openai_api_base, "base_url": self.openai_api_base,
@@ -474,9 +467,7 @@ class BaseChatOpenAI(BaseChatModel):
"default_headers": self.default_headers, "default_headers": self.default_headers,
"default_query": self.default_query, "default_query": self.default_query,
} }
if self.openai_proxy and ( if self.openai_proxy and (self.http_client or self.http_async_client):
self.http_client or self.http_async_client
):
openai_proxy = self.openai_proxy openai_proxy = self.openai_proxy
http_client = self.http_client http_client = self.http_client
http_async_client = self.http_async_client http_async_client = self.http_async_client
@@ -496,7 +487,7 @@ class BaseChatOpenAI(BaseChatModel):
) from e ) from e
self.http_client = httpx.Client(proxy=self.openai_proxy) self.http_client = httpx.Client(proxy=self.openai_proxy)
sync_specific = {"http_client": self.http_client} sync_specific = {"http_client": self.http_client}
self.root_client = openai.OpenAI(**client_params, **sync_specific) self.root_client = openai.OpenAI(**client_params, **sync_specific) # type: ignore[arg-type]
self.client = self.root_client.chat.completions self.client = self.root_client.chat.completions
if not (self.async_client or None): if not (self.async_client or None):
if self.openai_proxy and not self.http_async_client: if self.openai_proxy and not self.http_async_client:
@@ -507,12 +498,11 @@ class BaseChatOpenAI(BaseChatModel):
"Could not import httpx python package. " "Could not import httpx python package. "
"Please install it with `pip install httpx`." "Please install it with `pip install httpx`."
) from e ) from e
self.http_async_client = httpx.AsyncClient( self.http_async_client = httpx.AsyncClient(proxy=self.openai_proxy)
proxy=self.openai_proxy
)
async_specific = {"http_client": self.http_async_client} async_specific = {"http_client": self.http_async_client}
self.root_async_client = openai.AsyncOpenAI( self.root_async_client = openai.AsyncOpenAI(
**client_params, **async_specific **client_params,
**async_specific, # type: ignore[arg-type]
) )
self.async_client = self.root_async_client.chat.completions self.async_client = self.root_async_client.chat.completions
return self return self

View File

@@ -2,15 +2,14 @@
from __future__ import annotations from __future__ import annotations
from typing import Callable, Dict, Optional, Union from typing import Callable, Optional, Union
import openai import openai
from pydantic import Field, SecretStr, root_validator, model_validator
from langchain_core.utils import from_env, secret_from_env from langchain_core.utils import from_env, secret_from_env
from pydantic import Field, SecretStr, model_validator
from typing_extensions import Self, cast
from langchain_openai.embeddings.base import OpenAIEmbeddings from langchain_openai.embeddings.base import OpenAIEmbeddings
from typing_extensions import Self
class AzureOpenAIEmbeddings(OpenAIEmbeddings): class AzureOpenAIEmbeddings(OpenAIEmbeddings):
@@ -163,7 +162,7 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
openai_api_base = self.openai_api_base openai_api_base = self.openai_api_base
if openai_api_base and self.validate_base_url: if openai_api_base and self.validate_base_url:
if "/openai" not in openai_api_base: if "/openai" not in openai_api_base:
self.openai_api_base += "/openai" self.openai_api_base = cast(str, self.openai_api_base) + "/openai"
raise ValueError( raise ValueError(
"As of openai>=1.0.0, Azure endpoints should be specified via " "As of openai>=1.0.0, Azure endpoints should be specified via "
"the `azure_endpoint` param not `openai_api_base` " "the `azure_endpoint` param not `openai_api_base` "
@@ -177,19 +176,15 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
"Instead use `deployment` (or alias `azure_deployment`) " "Instead use `deployment` (or alias `azure_deployment`) "
"and `azure_endpoint`." "and `azure_endpoint`."
) )
client_params = { client_params: dict = {
"api_version": self.openai_api_version, "api_version": self.openai_api_version,
"azure_endpoint": self.azure_endpoint, "azure_endpoint": self.azure_endpoint,
"azure_deployment": self.deployment, "azure_deployment": self.deployment,
"api_key": ( "api_key": (
self.openai_api_key.get_secret_value() self.openai_api_key.get_secret_value() if self.openai_api_key else None
if self.openai_api_key
else None
), ),
"azure_ad_token": ( "azure_ad_token": (
self.azure_ad_token.get_secret_value() self.azure_ad_token.get_secret_value() if self.azure_ad_token else None
if self.azure_ad_token
else None
), ),
"azure_ad_token_provider": self.azure_ad_token_provider, "azure_ad_token_provider": self.azure_ad_token_provider,
"organization": self.openai_organization, "organization": self.openai_organization,
@@ -200,14 +195,16 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
"default_query": self.default_query, "default_query": self.default_query,
} }
if not (self.client or None): if not (self.client or None):
sync_specific = {"http_client": self.http_client} sync_specific: dict = {"http_client": self.http_client}
self.client = openai.AzureOpenAI( self.client = openai.AzureOpenAI(
**client_params, **sync_specific **client_params, # type: ignore[arg-type]
**sync_specific,
).embeddings ).embeddings
if not (self.async_client or None): if not (self.async_client or None):
async_specific = {"http_client": self.http_async_client} async_specific: dict = {"http_client": self.http_async_client}
self.async_client = openai.AsyncAzureOpenAI( self.async_client = openai.AsyncAzureOpenAI(
**client_params, **async_specific **client_params, # type: ignore[arg-type]
**async_specific,
).embeddings ).embeddings
return self return self

View File

@@ -20,13 +20,10 @@ from typing import (
import openai import openai
import tiktoken import tiktoken
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from pydantic import BaseModel, Field, SecretStr, root_validator, model_validator
from langchain_core.utils import from_env, get_pydantic_field_names, secret_from_env from langchain_core.utils import from_env, get_pydantic_field_names, secret_from_env
from pydantic import ConfigDict from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self from typing_extensions import Self
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -267,7 +264,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
"""Whether to check the token length of inputs and automatically split inputs """Whether to check the token length of inputs and automatically split inputs
longer than embedding_ctx_length.""" longer than embedding_ctx_length."""
model_config = ConfigDict(extra="forbid",populate_by_name=True,) model_config = ConfigDict(extra="forbid", populate_by_name=True)
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
@@ -304,11 +301,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
"If you are using Azure, " "If you are using Azure, "
"please use the `AzureOpenAIEmbeddings` class." "please use the `AzureOpenAIEmbeddings` class."
) )
client_params = { client_params: dict = {
"api_key": ( "api_key": (
self.openai_api_key.get_secret_value() self.openai_api_key.get_secret_value() if self.openai_api_key else None
if self.openai_api_key
else None
), ),
"organization": self.openai_organization, "organization": self.openai_organization,
"base_url": self.openai_api_base, "base_url": self.openai_api_base,
@@ -318,9 +313,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
"default_query": self.default_query, "default_query": self.default_query,
} }
if self.openai_proxy and ( if self.openai_proxy and (self.http_client or self.http_async_client):
self.http_client or self.http_async_client
):
openai_proxy = self.openai_proxy openai_proxy = self.openai_proxy
http_client = self.http_client http_client = self.http_client
http_async_client = self.http_async_client http_async_client = self.http_async_client
@@ -340,9 +333,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
) from e ) from e
self.http_client = httpx.Client(proxy=self.openai_proxy) self.http_client = httpx.Client(proxy=self.openai_proxy)
sync_specific = {"http_client": self.http_client} sync_specific = {"http_client": self.http_client}
self.client = openai.OpenAI( self.client = openai.OpenAI(**client_params, **sync_specific).embeddings # type: ignore[arg-type]
**client_params, **sync_specific
).embeddings
if not (self.async_client or None): if not (self.async_client or None):
if self.openai_proxy and not self.http_async_client: if self.openai_proxy and not self.http_async_client:
try: try:
@@ -352,12 +343,11 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
"Could not import httpx python package. " "Could not import httpx python package. "
"Please install it with `pip install httpx`." "Please install it with `pip install httpx`."
) from e ) from e
self.http_async_client = httpx.AsyncClient( self.http_async_client = httpx.AsyncClient(proxy=self.openai_proxy)
proxy=self.openai_proxy
)
async_specific = {"http_client": self.http_async_client} async_specific = {"http_client": self.http_async_client}
self.async_client = openai.AsyncOpenAI( self.async_client = openai.AsyncOpenAI(
**client_params, **async_specific **client_params,
**async_specific, # type: ignore[arg-type]
).embeddings ).embeddings
return self return self

View File

@@ -5,12 +5,11 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Union
import openai import openai
from langchain_core.language_models import LangSmithParams from langchain_core.language_models import LangSmithParams
from pydantic import Field, SecretStr, root_validator, model_validator
from langchain_core.utils import from_env, secret_from_env from langchain_core.utils import from_env, secret_from_env
from pydantic import Field, SecretStr, model_validator
from typing_extensions import Self, cast
from langchain_openai.llms.base import BaseOpenAI from langchain_openai.llms.base import BaseOpenAI
from typing_extensions import Self
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -117,7 +116,7 @@ class AzureOpenAI(BaseOpenAI):
if openai_api_base and self.validate_base_url: if openai_api_base and self.validate_base_url:
if "/openai" not in openai_api_base: if "/openai" not in openai_api_base:
self.openai_api_base = ( self.openai_api_base = (
self.openai_api_base.rstrip("/") + "/openai" cast(str, self.openai_api_base).rstrip("/") + "/openai"
) )
raise ValueError( raise ValueError(
"As of openai>=1.0.0, Azure endpoints should be specified via " "As of openai>=1.0.0, Azure endpoints should be specified via "
@@ -133,7 +132,7 @@ class AzureOpenAI(BaseOpenAI):
"and `azure_endpoint`." "and `azure_endpoint`."
) )
self.deployment_name = None self.deployment_name = None
client_params = { client_params: dict = {
"api_version": self.openai_api_version, "api_version": self.openai_api_version,
"azure_endpoint": self.azure_endpoint, "azure_endpoint": self.azure_endpoint,
"azure_deployment": self.deployment_name, "azure_deployment": self.deployment_name,
@@ -154,12 +153,14 @@ class AzureOpenAI(BaseOpenAI):
if not (self.client or None): if not (self.client or None):
sync_specific = {"http_client": self.http_client} sync_specific = {"http_client": self.http_client}
self.client = openai.AzureOpenAI( self.client = openai.AzureOpenAI(
**client_params, **sync_specific **client_params,
**sync_specific, # type: ignore[arg-type]
).completions ).completions
if not (self.async_client or None): if not (self.async_client or None):
async_specific = {"http_client": self.http_async_client} async_specific = {"http_client": self.http_async_client}
self.async_client = openai.AsyncAzureOpenAI( self.async_client = openai.AsyncAzureOpenAI(
**client_params, **async_specific **client_params,
**async_specific, # type: ignore[arg-type]
).completions ).completions
return self return self

View File

@@ -26,14 +26,11 @@ from langchain_core.callbacks import (
) )
from langchain_core.language_models.llms import BaseLLM from langchain_core.language_models.llms import BaseLLM
from langchain_core.outputs import Generation, GenerationChunk, LLMResult from langchain_core.outputs import Generation, GenerationChunk, LLMResult
from pydantic import Field, SecretStr, root_validator, model_validator
from langchain_core.utils import get_pydantic_field_names from langchain_core.utils import get_pydantic_field_names
from langchain_core.utils.utils import build_extra_kwargs, from_env, secret_from_env from langchain_core.utils.utils import build_extra_kwargs, from_env, secret_from_env
from pydantic import ConfigDict from pydantic import ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self from typing_extensions import Self
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -156,7 +153,7 @@ class BaseOpenAI(BaseLLM):
"""Optional additional JSON properties to include in the request parameters when """Optional additional JSON properties to include in the request parameters when
making requests to OpenAI compatible APIs, such as vLLM.""" making requests to OpenAI compatible APIs, such as vLLM."""
model_config = ConfigDict(populate_by_name=True,) model_config = ConfigDict(populate_by_name=True)
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
@@ -179,11 +176,9 @@ class BaseOpenAI(BaseLLM):
if self.streaming and self.best_of > 1: if self.streaming and self.best_of > 1:
raise ValueError("Cannot stream results when best_of > 1.") raise ValueError("Cannot stream results when best_of > 1.")
client_params = { client_params: dict = {
"api_key": ( "api_key": (
self.openai_api_key.get_secret_value() self.openai_api_key.get_secret_value() if self.openai_api_key else None
if self.openai_api_key
else None
), ),
"organization": self.openai_organization, "organization": self.openai_organization,
"base_url": self.openai_api_base, "base_url": self.openai_api_base,
@@ -194,13 +189,12 @@ class BaseOpenAI(BaseLLM):
} }
if not (self.client or None): if not (self.client or None):
sync_specific = {"http_client": self.http_client} sync_specific = {"http_client": self.http_client}
self.client = openai.OpenAI( self.client = openai.OpenAI(**client_params, **sync_specific).completions # type: ignore[arg-type]
**client_params, **sync_specific
).completions
if not (self.async_client or None): if not (self.async_client or None):
async_specific = {"http_client": self.http_async_client} async_specific = {"http_client": self.http_async_client}
self.async_client = openai.AsyncOpenAI( self.async_client = openai.AsyncOpenAI(
**client_params, **async_specific **client_params,
**async_specific, # type: ignore[arg-type]
).completions ).completions
return self return self

View File

@@ -20,13 +20,13 @@ from langchain_core.messages import (
) )
from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult
from langchain_core.prompts import ChatPromptTemplate from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel, Field
from langchain_standard_tests.integration_tests.chat_models import ( from langchain_standard_tests.integration_tests.chat_models import (
_validate_tool_call_message, _validate_tool_call_message,
) )
from langchain_standard_tests.integration_tests.chat_models import ( from langchain_standard_tests.integration_tests.chat_models import (
magic_function as invalid_magic_function, magic_function as invalid_magic_function,
) )
from pydantic import BaseModel, Field
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI
from tests.unit_tests.fake.callbacks import FakeCallbackHandler from tests.unit_tests.fake.callbacks import FakeCallbackHandler

View File

@@ -188,7 +188,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
def on_retriever_error(self, *args: Any, **kwargs: Any) -> Any: def on_retriever_error(self, *args: Any, **kwargs: Any) -> Any:
self.on_retriever_error_common() self.on_retriever_error_common()
def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler": def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler": # type: ignore[override]
return self return self
@@ -266,5 +266,5 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
async def on_text(self, *args: Any, **kwargs: Any) -> None: async def on_text(self, *args: Any, **kwargs: Any) -> None:
self.on_text_common() self.on_text_common()
def __deepcopy__(self, memo: dict) -> "FakeAsyncCallbackHandler": def __deepcopy__(self, memo: dict) -> "FakeAsyncCallbackHandler": # type: ignore[override]
return self return self