This commit is contained in:
Eugene Yurtsev
2024-08-16 13:22:11 -04:00
parent 1645340680
commit f3a075df9f
12 changed files with 38 additions and 35 deletions

View File

@@ -31,12 +31,12 @@ from langchain_core.output_parsers.openai_tools import (
PydanticToolsParser,
)
from langchain_core.outputs import ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool
from langchain_core.utils import from_env, secret_from_env
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.utils.pydantic import is_basemodel_subclass
from pydantic import BaseModel, Field, SecretStr, root_validator
from langchain_openai.chat_models.base import BaseChatOpenAI

View File

@@ -73,7 +73,6 @@ from langchain_core.output_parsers.openai_tools import (
parse_tool_call,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough, chain
from langchain_core.runnables.config import run_in_executor
from langchain_core.tools import BaseTool
@@ -92,6 +91,14 @@ from langchain_core.utils.pydantic import (
is_basemodel_subclass,
)
from langchain_core.utils.utils import build_extra_kwargs
from pydantic import (
BaseModel,
ConfigDict,
Field,
SecretStr,
model_validator,
root_validator,
)
logger = logging.getLogger(__name__)
@@ -377,13 +384,11 @@ class BaseChatOpenAI(BaseChatModel):
include_response_headers: bool = False
"""Whether to include response headers in the output message response_metadata."""
class Config:
"""Configuration for this pydantic object."""
model_config = ConfigDict(populate_by_name=True)
allow_population_by_field_name = True
@root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
@model_validator(mode="before")
@classmethod
def build_extra(cls, values: Dict[str, Any]) -> Any:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
extra = values.get("model_kwargs", {})

View File

@@ -5,8 +5,8 @@ from __future__ import annotations
from typing import Callable, Dict, Optional, Union
import openai
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import from_env, secret_from_env
from pydantic import Field, SecretStr, root_validator
from langchain_openai.embeddings.base import OpenAIEmbeddings

View File

@@ -21,8 +21,15 @@ from typing import (
import openai
import tiktoken
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.utils import from_env, get_pydantic_field_names, secret_from_env
from pydantic import (
BaseModel,
ConfigDict,
Field,
SecretStr,
model_validator,
root_validator,
)
logger = logging.getLogger(__name__)
@@ -261,14 +268,11 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
"""Whether to check the token length of inputs and automatically split inputs
longer than embedding_ctx_length."""
class Config:
"""Configuration for this pydantic object."""
model_config = ConfigDict(extra="forbid", populate_by_name=True)
extra = "forbid"
allow_population_by_field_name = True
@root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
@model_validator(mode="before")
@classmethod
def build_extra(cls, values: Dict[str, Any]) -> Any:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
extra = values.get("model_kwargs", {})

View File

@@ -4,8 +4,8 @@ import logging
from typing import Any, Callable, Dict, List, Mapping, Optional, Union
import openai
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import from_env, secret_from_env
from pydantic import Field, SecretStr, root_validator
from langchain_openai.llms.base import BaseOpenAI

View File

@@ -26,9 +26,9 @@ from langchain_core.callbacks import (
)
from langchain_core.language_models.llms import BaseLLM
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import get_pydantic_field_names
from langchain_core.utils.utils import build_extra_kwargs, from_env, secret_from_env
from pydantic import ConfigDict, Field, SecretStr, model_validator, root_validator
logger = logging.getLogger(__name__)
@@ -152,13 +152,11 @@ class BaseOpenAI(BaseLLM):
"""Optional additional JSON properties to include in the request parameters when
making requests to OpenAI compatible APIs, such as vLLM."""
class Config:
"""Configuration for this pydantic object."""
model_config = ConfigDict(populate_by_name=True)
allow_population_by_field_name = True
@root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
@model_validator(mode="before")
@classmethod
def build_extra(cls, values: Dict[str, Any]) -> Any:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
extra = values.get("model_kwargs", {})

View File

@@ -19,13 +19,13 @@ from langchain_core.messages import (
)
from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_standard_tests.integration_tests.chat_models import (
_validate_tool_call_message,
)
from langchain_standard_tests.integration_tests.chat_models import (
magic_function as invalid_magic_function,
)
from pydantic import BaseModel, Field
from langchain_openai import ChatOpenAI
from tests.unit_tests.fake.callbacks import FakeCallbackHandler

View File

@@ -40,9 +40,7 @@ def test_initialize_more() -> None:
def test_initialize_azure_openai_with_openai_api_base_set() -> None:
with mock.patch.dict(
os.environ, {"OPENAI_API_BASE": "https://api.openai.com"}
):
with mock.patch.dict(os.environ, {"OPENAI_API_BASE": "https://api.openai.com"}):
llm = AzureChatOpenAI( # type: ignore[call-arg, call-arg]
api_key="xyz", # type: ignore[arg-type]
azure_endpoint="my-base-url",

View File

@@ -14,7 +14,7 @@ from langchain_core.messages import (
ToolCall,
ToolMessage,
)
from langchain_core.pydantic_v1 import BaseModel
from pydantic import BaseModel
from langchain_openai import ChatOpenAI
from langchain_openai.chat_models.base import (

View File

@@ -16,9 +16,7 @@ def test_initialize_azure_openai() -> None:
def test_intialize_azure_openai_with_base_set() -> None:
with mock.patch.dict(
os.environ, {"OPENAI_API_BASE": "https://api.openai.com"}
):
with mock.patch.dict(os.environ, {"OPENAI_API_BASE": "https://api.openai.com"}):
embeddings = AzureOpenAIEmbeddings( # type: ignore[call-arg, call-arg]
model="text-embedding-large",
api_key="xyz", # type: ignore[arg-type]

View File

@@ -6,7 +6,7 @@ from uuid import UUID
from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
from langchain_core.messages import BaseMessage
from langchain_core.pydantic_v1 import BaseModel
from pydantic import BaseModel
class BaseFakeCallbackHandler(BaseModel):

View File

@@ -2,7 +2,7 @@ from typing import Type, cast
import pytest
from langchain_core.load import dumpd
from langchain_core.pydantic_v1 import SecretStr
from pydantic import SecretStr
from pytest import CaptureFixture, MonkeyPatch
from langchain_openai import (