From 98779797fefa670db94e8e699bd9d52cf0920a18 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Thu, 8 Aug 2024 14:43:09 -0400 Subject: [PATCH] community[patch]: Use get_fields adapter for pydantic (#25191) Replace all usages of __fields__ with the get_fields adapter merged into langchain_core. Code mod generated using the following grit pattern: ``` engine marzano(0.1) language python `$X.__fields__` => `get_fields($X)` where { add_import(source="langchain_core.utils.pydantic", name="get_fields") } ``` --- .../agent_toolkits/file_management/toolkit.py | 3 ++- .../chat_models/baidu_qianfan_endpoint.py | 4 ++-- libs/community/langchain_community/chat_models/minimax.py | 3 ++- libs/community/langchain_community/chat_models/sparkllm.py | 3 ++- libs/community/langchain_community/embeddings/gigachat.py | 3 ++- libs/community/langchain_community/llms/bananadev.py | 3 ++- libs/community/langchain_community/llms/beam.py | 3 ++- libs/community/langchain_community/llms/cerebriumai.py | 3 ++- libs/community/langchain_community/llms/edenai.py | 3 ++- libs/community/langchain_community/llms/gigachat.py | 3 ++- libs/community/langchain_community/llms/gooseai.py | 3 ++- libs/community/langchain_community/llms/loading.py | 3 ++- libs/community/langchain_community/llms/modal.py | 3 ++- libs/community/langchain_community/llms/openai.py | 3 ++- libs/community/langchain_community/llms/petals.py | 3 ++- libs/community/langchain_community/llms/pipelineai.py | 3 ++- libs/community/langchain_community/llms/replicate.py | 3 ++- libs/community/langchain_community/llms/stochasticai.py | 3 ++- libs/community/langchain_community/retrievers/docarray.py | 3 ++- .../langchain_community/vectorstores/redis/schema.py | 5 +++-- libs/community/tests/integration_tests/llms/utils.py | 3 ++- .../tests/unit_tests/callbacks/test_callback_manager.py | 5 +++-- .../community/tests/unit_tests/callbacks/test_openai_info.py | 3 ++- libs/community/tests/unit_tests/tools/test_exported.py | 5 +++-- 24 files changed, 
51 insertions(+), 28 deletions(-) diff --git a/libs/community/langchain_community/agent_toolkits/file_management/toolkit.py b/libs/community/langchain_community/agent_toolkits/file_management/toolkit.py index bfcb77ba2c0..298f6646129 100644 --- a/libs/community/langchain_community/agent_toolkits/file_management/toolkit.py +++ b/libs/community/langchain_community/agent_toolkits/file_management/toolkit.py @@ -4,6 +4,7 @@ from typing import Dict, List, Optional, Type from langchain_core.pydantic_v1 import root_validator from langchain_core.tools import BaseToolkit +from langchain_core.utils.pydantic import get_fields from langchain_community.tools import BaseTool from langchain_community.tools.file_management.copy import CopyFileTool @@ -24,7 +25,7 @@ _FILE_TOOLS: List[Type[BaseTool]] = [ ListDirectoryTool, ] _FILE_TOOLS_MAP: Dict[str, Type[BaseTool]] = { - tool_cls.__fields__["name"].default: tool_cls for tool_cls in _FILE_TOOLS + get_fields(tool_cls)["name"].default: tool_cls for tool_cls in _FILE_TOOLS } diff --git a/libs/community/langchain_community/chat_models/baidu_qianfan_endpoint.py b/libs/community/langchain_community/chat_models/baidu_qianfan_endpoint.py index 4d450142621..2a9cc8cec29 100644 --- a/libs/community/langchain_community/chat_models/baidu_qianfan_endpoint.py +++ b/libs/community/langchain_community/chat_models/baidu_qianfan_endpoint.py @@ -51,7 +51,7 @@ from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough from langchain_core.tools import BaseTool from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env from langchain_core.utils.function_calling import convert_to_openai_tool -from langchain_core.utils.pydantic import is_basemodel_subclass +from langchain_core.utils.pydantic import get_fields, is_basemodel_subclass logger = logging.getLogger(__name__) @@ -397,7 +397,7 @@ class QianfanChatEndpoint(BaseChatModel): default_values = { name: field.default - for name, field in cls.__fields__.items() + for 
name, field in get_fields(cls).items() if field.default is not None } default_values.update(values) diff --git a/libs/community/langchain_community/chat_models/minimax.py b/libs/community/langchain_community/chat_models/minimax.py index 8bbe0d9e603..0b88c2f6c4d 100644 --- a/libs/community/langchain_community/chat_models/minimax.py +++ b/libs/community/langchain_community/chat_models/minimax.py @@ -49,6 +49,7 @@ from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough from langchain_core.tools import BaseTool from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env from langchain_core.utils.function_calling import convert_to_openai_tool +from langchain_core.utils.pydantic import get_fields logger = logging.getLogger(__name__) @@ -399,7 +400,7 @@ class MiniMaxChat(BaseChatModel): default_values = { name: field.default - for name, field in cls.__fields__.items() + for name, field in get_fields(cls).items() if field.default is not None } default_values.update(values) diff --git a/libs/community/langchain_community/chat_models/sparkllm.py b/libs/community/langchain_community/chat_models/sparkllm.py index 899463752db..75927d61535 100644 --- a/libs/community/langchain_community/chat_models/sparkllm.py +++ b/libs/community/langchain_community/chat_models/sparkllm.py @@ -40,6 +40,7 @@ from langchain_core.utils import ( get_from_dict_or_env, get_pydantic_field_names, ) +from langchain_core.utils.pydantic import get_fields logger = logging.getLogger(__name__) @@ -308,7 +309,7 @@ class ChatSparkLLM(BaseChatModel): # put extra params into model_kwargs default_values = { name: field.default - for name, field in cls.__fields__.items() + for name, field in get_fields(cls).items() if field.default is not None } values["model_kwargs"]["temperature"] = default_values.get("temperature") diff --git a/libs/community/langchain_community/embeddings/gigachat.py b/libs/community/langchain_community/embeddings/gigachat.py index 
473878103cf..09ac5b37578 100644 --- a/libs/community/langchain_community/embeddings/gigachat.py +++ b/libs/community/langchain_community/embeddings/gigachat.py @@ -7,6 +7,7 @@ from typing import Any, Dict, List, Optional from langchain_core.embeddings import Embeddings from langchain_core.pydantic_v1 import BaseModel from langchain_core.utils import pre_init +from langchain_core.utils.pydantic import get_fields logger = logging.getLogger(__name__) @@ -88,7 +89,7 @@ class GigaChatEmbeddings(BaseModel, Embeddings): "Could not import gigachat python package. " "Please install it with `pip install gigachat`." ) - fields = set(cls.__fields__.keys()) + fields = set(get_fields(cls).keys()) diff = set(values.keys()) - fields if diff: logger.warning(f"Extra fields {diff} in GigaChat class") diff --git a/libs/community/langchain_community/llms/bananadev.py b/libs/community/langchain_community/llms/bananadev.py index 2e85ef7aee0..c1be0bb1fb6 100644 --- a/libs/community/langchain_community/llms/bananadev.py +++ b/libs/community/langchain_community/llms/bananadev.py @@ -5,6 +5,7 @@ from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.llms import LLM from langchain_core.pydantic_v1 import Field, SecretStr, root_validator from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init +from langchain_core.utils.pydantic import get_fields from langchain_community.llms.utils import enforce_stop_tokens @@ -46,7 +47,7 @@ class Banana(LLM): @root_validator(pre=True) def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: """Build extra kwargs from additional params that were passed in.""" - all_required_field_names = {field.alias for field in cls.__fields__.values()} + all_required_field_names = {field.alias for field in get_fields(cls).values()} extra = values.get("model_kwargs", {}) for field_name in list(values): diff --git a/libs/community/langchain_community/llms/beam.py 
b/libs/community/langchain_community/llms/beam.py index 12f150493ca..63d51499e9a 100644 --- a/libs/community/langchain_community/llms/beam.py +++ b/libs/community/langchain_community/llms/beam.py @@ -11,6 +11,7 @@ from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.llms import LLM from langchain_core.pydantic_v1 import Field, root_validator from langchain_core.utils import get_from_dict_or_env, pre_init +from langchain_core.utils.pydantic import get_fields logger = logging.getLogger(__name__) @@ -78,7 +79,7 @@ class Beam(LLM): @root_validator(pre=True) def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: """Build extra kwargs from additional params that were passed in.""" - all_required_field_names = {field.alias for field in cls.__fields__.values()} + all_required_field_names = {field.alias for field in get_fields(cls).values()} extra = values.get("model_kwargs", {}) for field_name in list(values): diff --git a/libs/community/langchain_community/llms/cerebriumai.py b/libs/community/langchain_community/llms/cerebriumai.py index d5b3d41a7ae..9e69b026c9c 100644 --- a/libs/community/langchain_community/llms/cerebriumai.py +++ b/libs/community/langchain_community/llms/cerebriumai.py @@ -6,6 +6,7 @@ from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.llms import LLM from langchain_core.pydantic_v1 import Field, SecretStr, root_validator from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init +from langchain_core.utils.pydantic import get_fields from langchain_community.llms.utils import enforce_stop_tokens @@ -45,7 +46,7 @@ class CerebriumAI(LLM): @root_validator(pre=True) def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: """Build extra kwargs from additional params that were passed in.""" - all_required_field_names = {field.alias for field in cls.__fields__.values()} + all_required_field_names = {field.alias for field in 
get_fields(cls).values()} extra = values.get("model_kwargs", {}) for field_name in list(values): diff --git a/libs/community/langchain_community/llms/edenai.py b/libs/community/langchain_community/llms/edenai.py index 6b8bbe23326..31cdaee92ae 100644 --- a/libs/community/langchain_community/llms/edenai.py +++ b/libs/community/langchain_community/llms/edenai.py @@ -11,6 +11,7 @@ from langchain_core.callbacks import ( from langchain_core.language_models.llms import LLM from langchain_core.pydantic_v1 import Field, root_validator from langchain_core.utils import get_from_dict_or_env, pre_init +from langchain_core.utils.pydantic import get_fields from langchain_community.llms.utils import enforce_stop_tokens from langchain_community.utilities.requests import Requests @@ -82,7 +83,7 @@ class EdenAI(LLM): @root_validator(pre=True) def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: """Build extra kwargs from additional params that were passed in.""" - all_required_field_names = {field.alias for field in cls.__fields__.values()} + all_required_field_names = {field.alias for field in get_fields(cls).values()} extra = values.get("model_kwargs", {}) for field_name in list(values): diff --git a/libs/community/langchain_community/llms/gigachat.py b/libs/community/langchain_community/llms/gigachat.py index a8fbb1654c4..4172081dac4 100644 --- a/libs/community/langchain_community/llms/gigachat.py +++ b/libs/community/langchain_community/llms/gigachat.py @@ -12,6 +12,7 @@ from langchain_core.language_models.llms import BaseLLM from langchain_core.load.serializable import Serializable from langchain_core.outputs import Generation, GenerationChunk, LLMResult from langchain_core.utils import pre_init +from langchain_core.utils.pydantic import get_fields if TYPE_CHECKING: import gigachat @@ -123,7 +124,7 @@ class _BaseGigaChat(Serializable): "Could not import gigachat python package. " "Please install it with `pip install gigachat`." 
) - fields = set(cls.__fields__.keys()) + fields = set(get_fields(cls).keys()) diff = set(values.keys()) - fields if diff: logger.warning(f"Extra fields {diff} in GigaChat class") diff --git a/libs/community/langchain_community/llms/gooseai.py b/libs/community/langchain_community/llms/gooseai.py index b9eae22ae73..255908a20c8 100644 --- a/libs/community/langchain_community/llms/gooseai.py +++ b/libs/community/langchain_community/llms/gooseai.py @@ -5,6 +5,7 @@ from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.llms import LLM from langchain_core.pydantic_v1 import Field, SecretStr, root_validator from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init +from langchain_core.utils.pydantic import get_fields logger = logging.getLogger(__name__) @@ -68,7 +69,7 @@ class GooseAI(LLM): @root_validator(pre=True) def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: """Build extra kwargs from additional params that were passed in.""" - all_required_field_names = {field.alias for field in cls.__fields__.values()} + all_required_field_names = {field.alias for field in get_fields(cls).values()} extra = values.get("model_kwargs", {}) for field_name in list(values): diff --git a/libs/community/langchain_community/llms/loading.py b/libs/community/langchain_community/llms/loading.py index 67fc020a162..4e97587b1b8 100644 --- a/libs/community/langchain_community/llms/loading.py +++ b/libs/community/langchain_community/llms/loading.py @@ -6,6 +6,7 @@ from typing import Any, Union import yaml from langchain_core.language_models.llms import BaseLLM +from langchain_core.utils.pydantic import get_fields from langchain_community.llms import get_type_to_cls_dict @@ -26,7 +27,7 @@ def load_llm_from_config(config: dict, **kwargs: Any) -> BaseLLM: llm_cls = type_to_cls_dict[config_type]() load_kwargs = {} - if _ALLOW_DANGEROUS_DESERIALIZATION_ARG in llm_cls.__fields__: + if 
_ALLOW_DANGEROUS_DESERIALIZATION_ARG in get_fields(llm_cls): load_kwargs[_ALLOW_DANGEROUS_DESERIALIZATION_ARG] = kwargs.get( _ALLOW_DANGEROUS_DESERIALIZATION_ARG, False ) diff --git a/libs/community/langchain_community/llms/modal.py b/libs/community/langchain_community/llms/modal.py index b8751394b38..010127d4d54 100644 --- a/libs/community/langchain_community/llms/modal.py +++ b/libs/community/langchain_community/llms/modal.py @@ -5,6 +5,7 @@ import requests from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.llms import LLM from langchain_core.pydantic_v1 import Field, root_validator +from langchain_core.utils.pydantic import get_fields from langchain_community.llms.utils import enforce_stop_tokens @@ -40,7 +41,7 @@ class Modal(LLM): @root_validator(pre=True) def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: """Build extra kwargs from additional params that were passed in.""" - all_required_field_names = {field.alias for field in cls.__fields__.values()} + all_required_field_names = {field.alias for field in get_fields(cls).values()} extra = values.get("model_kwargs", {}) for field_name in list(values): diff --git a/libs/community/langchain_community/llms/openai.py b/libs/community/langchain_community/llms/openai.py index d0d377f5d59..6a9cd5d9e45 100644 --- a/libs/community/langchain_community/llms/openai.py +++ b/libs/community/langchain_community/llms/openai.py @@ -34,6 +34,7 @@ from langchain_core.utils import ( get_pydantic_field_names, pre_init, ) +from langchain_core.utils.pydantic import get_fields from langchain_core.utils.utils import build_extra_kwargs from langchain_community.utils.openai import is_openai_v1 @@ -1016,7 +1017,7 @@ class OpenAIChat(BaseLLM): @root_validator(pre=True) def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: """Build extra kwargs from additional params that were passed in.""" - all_required_field_names = {field.alias for field in cls.__fields__.values()} + 
all_required_field_names = {field.alias for field in get_fields(cls).values()} extra = values.get("model_kwargs", {}) for field_name in list(values): diff --git a/libs/community/langchain_community/llms/petals.py b/libs/community/langchain_community/llms/petals.py index 2149034428e..7506e3bf5c8 100644 --- a/libs/community/langchain_community/llms/petals.py +++ b/libs/community/langchain_community/llms/petals.py @@ -5,6 +5,7 @@ from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.llms import LLM from langchain_core.pydantic_v1 import Field, SecretStr, root_validator from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init +from langchain_core.utils.pydantic import get_fields from langchain_community.llms.utils import enforce_stop_tokens @@ -68,7 +69,7 @@ class Petals(LLM): @root_validator(pre=True) def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: """Build extra kwargs from additional params that were passed in.""" - all_required_field_names = {field.alias for field in cls.__fields__.values()} + all_required_field_names = {field.alias for field in get_fields(cls).values()} extra = values.get("model_kwargs", {}) for field_name in list(values): diff --git a/libs/community/langchain_community/llms/pipelineai.py b/libs/community/langchain_community/llms/pipelineai.py index 110b7f8232c..d3616a4f60d 100644 --- a/libs/community/langchain_community/llms/pipelineai.py +++ b/libs/community/langchain_community/llms/pipelineai.py @@ -10,6 +10,7 @@ from langchain_core.pydantic_v1 import ( root_validator, ) from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init +from langchain_core.utils.pydantic import get_fields from langchain_community.llms.utils import enforce_stop_tokens @@ -47,7 +48,7 @@ class PipelineAI(LLM, BaseModel): @root_validator(pre=True) def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: """Build extra kwargs from additional params 
that were passed in.""" - all_required_field_names = {field.alias for field in cls.__fields__.values()} + all_required_field_names = {field.alias for field in get_fields(cls).values()} extra = values.get("pipeline_kwargs", {}) for field_name in list(values): diff --git a/libs/community/langchain_community/llms/replicate.py b/libs/community/langchain_community/llms/replicate.py index 1f9a4b3f852..7235e57d7f4 100644 --- a/libs/community/langchain_community/llms/replicate.py +++ b/libs/community/langchain_community/llms/replicate.py @@ -8,6 +8,7 @@ from langchain_core.language_models.llms import LLM from langchain_core.outputs import GenerationChunk from langchain_core.pydantic_v1 import Field, root_validator from langchain_core.utils import get_from_dict_or_env, pre_init +from langchain_core.utils.pydantic import get_fields if TYPE_CHECKING: from replicate.prediction import Prediction @@ -75,7 +76,7 @@ class Replicate(LLM): @root_validator(pre=True) def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: """Build extra kwargs from additional params that were passed in.""" - all_required_field_names = {field.alias for field in cls.__fields__.values()} + all_required_field_names = {field.alias for field in get_fields(cls).values()} input = values.pop("input", {}) if input: diff --git a/libs/community/langchain_community/llms/stochasticai.py b/libs/community/langchain_community/llms/stochasticai.py index fdd8b122d68..d97fd1e5ee2 100644 --- a/libs/community/langchain_community/llms/stochasticai.py +++ b/libs/community/langchain_community/llms/stochasticai.py @@ -7,6 +7,7 @@ from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.llms import LLM from langchain_core.pydantic_v1 import Field, SecretStr, root_validator from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init +from langchain_core.utils.pydantic import get_fields from langchain_community.llms.utils import enforce_stop_tokens @@ 
-41,7 +42,7 @@ class StochasticAI(LLM): @root_validator(pre=True) def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: """Build extra kwargs from additional params that were passed in.""" - all_required_field_names = {field.alias for field in cls.__fields__.values()} + all_required_field_names = {field.alias for field in get_fields(cls).values()} extra = values.get("model_kwargs", {}) for field_name in list(values): diff --git a/libs/community/langchain_community/retrievers/docarray.py b/libs/community/langchain_community/retrievers/docarray.py index c3a791a2599..e258735be2c 100644 --- a/libs/community/langchain_community/retrievers/docarray.py +++ b/libs/community/langchain_community/retrievers/docarray.py @@ -6,6 +6,7 @@ from langchain_core.callbacks import CallbackManagerForRetrieverRun from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_core.retrievers import BaseRetriever +from langchain_core.utils.pydantic import get_fields from langchain_community.vectorstores.utils import maximal_marginal_relevance @@ -182,7 +183,7 @@ class DocArrayRetriever(BaseRetriever): ValueError: If the document doesn't contain the content field """ - fields = doc.keys() if isinstance(doc, dict) else doc.__fields__ + fields = doc.keys() if isinstance(doc, dict) else get_fields(doc) if self.content_field not in fields: raise ValueError( diff --git a/libs/community/langchain_community/vectorstores/redis/schema.py b/libs/community/langchain_community/vectorstores/redis/schema.py index 5b8618797eb..50b20245fe2 100644 --- a/libs/community/langchain_community/vectorstores/redis/schema.py +++ b/libs/community/langchain_community/vectorstores/redis/schema.py @@ -8,6 +8,7 @@ from typing import Any, Dict, List, Optional, Union import numpy as np import yaml from langchain_core.pydantic_v1 import BaseModel, Field, validator +from langchain_core.utils.pydantic import get_fields from typing_extensions import TYPE_CHECKING, 
Literal from langchain_community.vectorstores.redis.constants import REDIS_VECTOR_DTYPE_MAP @@ -255,7 +256,7 @@ class RedisModel(BaseModel): if self.is_empty: return redis_fields - for field_name in self.__fields__.keys(): + for field_name in get_fields(self).keys(): if field_name not in ["content_key", "content_vector_key", "extra"]: field_group = getattr(self, field_name) if field_group is not None: @@ -269,7 +270,7 @@ class RedisModel(BaseModel): if self.is_empty: return keys - for field_name in self.__fields__.keys(): + for field_name in get_fields(self).keys(): field_group = getattr(self, field_name) if field_group is not None: for field in field_group: diff --git a/libs/community/tests/integration_tests/llms/utils.py b/libs/community/tests/integration_tests/llms/utils.py index 064cd9025d8..2793f5a71e1 100644 --- a/libs/community/tests/integration_tests/llms/utils.py +++ b/libs/community/tests/integration_tests/llms/utils.py @@ -1,6 +1,7 @@ """Utils for LLM Tests.""" from langchain_core.language_models.llms import BaseLLM +from langchain_core.utils.pydantic import get_fields def assert_llm_equality(llm: BaseLLM, loaded_llm: BaseLLM) -> None: @@ -9,7 +10,7 @@ def assert_llm_equality(llm: BaseLLM, loaded_llm: BaseLLM) -> None: assert type(llm) is type(loaded_llm) # Client field can be session based, so hash is different despite # all other values being the same, so just assess all other fields - for field in llm.__fields__.keys(): + for field in get_fields(llm).keys(): if field != "client" and field != "pipeline": val = getattr(llm, field) new_val = getattr(loaded_llm, field) diff --git a/libs/community/tests/unit_tests/callbacks/test_callback_manager.py b/libs/community/tests/unit_tests/callbacks/test_callback_manager.py index cc6e29410f3..69821d2de6d 100644 --- a/libs/community/tests/unit_tests/callbacks/test_callback_manager.py +++ b/libs/community/tests/unit_tests/callbacks/test_callback_manager.py @@ -6,6 +6,7 @@ import pytest from 
langchain_core.callbacks.manager import CallbackManager, trace_as_chain_group from langchain_core.outputs import LLMResult from langchain_core.tracers.langchain import LangChainTracer, wait_for_all_tracers +from langchain_core.utils.pydantic import get_fields from langchain_community.callbacks import get_openai_callback from langchain_community.callbacks.manager import get_bedrock_anthropic_callback @@ -44,7 +45,7 @@ def test_callback_manager_configure_context_vars( "completion_tokens": 1, "total_tokens": 3, }, - "model_name": BaseOpenAI.__fields__["model_name"].default, + "model_name": get_fields(BaseOpenAI)["model_name"].default, }, ) mngr.on_llm_start({}, ["prompt"])[0].on_llm_end(response) @@ -74,7 +75,7 @@ def test_callback_manager_configure_context_vars( "completion_tokens": 1, "total_tokens": 3, }, - "model_name": BaseOpenAI.__fields__["model_name"].default, + "model_name": get_fields(BaseOpenAI)["model_name"].default, }, ) mngr.on_llm_start({}, ["prompt"])[0].on_llm_end(response) diff --git a/libs/community/tests/unit_tests/callbacks/test_openai_info.py b/libs/community/tests/unit_tests/callbacks/test_openai_info.py index c2caa67fb3b..48ab5fd1a9a 100644 --- a/libs/community/tests/unit_tests/callbacks/test_openai_info.py +++ b/libs/community/tests/unit_tests/callbacks/test_openai_info.py @@ -4,6 +4,7 @@ from uuid import uuid4 import numpy as np import pytest from langchain_core.outputs import LLMResult +from langchain_core.utils.pydantic import get_fields from langchain_community.callbacks import OpenAICallbackHandler from langchain_community.llms.openai import BaseOpenAI @@ -23,7 +24,7 @@ def test_on_llm_end(handler: OpenAICallbackHandler) -> None: "completion_tokens": 1, "total_tokens": 3, }, - "model_name": BaseOpenAI.__fields__["model_name"].default, + "model_name": get_fields(BaseOpenAI)["model_name"].default, }, ) handler.on_llm_end(response) diff --git a/libs/community/tests/unit_tests/tools/test_exported.py 
b/libs/community/tests/unit_tests/tools/test_exported.py index 5ccf8eca892..6dd98bd0d77 100644 --- a/libs/community/tests/unit_tests/tools/test_exported.py +++ b/libs/community/tests/unit_tests/tools/test_exported.py @@ -1,6 +1,7 @@ from typing import List, Type from langchain_core.tools import BaseTool, StructuredTool +from langchain_core.utils.pydantic import get_fields import langchain_community.tools from langchain_community.tools import _DEPRECATED_TOOLS @@ -22,7 +23,7 @@ def _get_tool_classes(skip_tools_without_default_names: bool) -> List[Type[BaseT if isinstance(tool_class, type) and issubclass(tool_class, BaseTool): if tool_class in _EXCLUDE: continue - if skip_tools_without_default_names and tool_class.__fields__[ + if skip_tools_without_default_names and get_fields(tool_class)[ "name" ].default in [ # type: ignore None, @@ -36,6 +37,6 @@ def _get_tool_classes(skip_tools_without_default_names: bool) -> List[Type[BaseT def test_tool_names_unique() -> None: """Test that the default names for our core tools are unique.""" tool_classes = _get_tool_classes(skip_tools_without_default_names=True) - names = sorted([tool_cls.__fields__["name"].default for tool_cls in tool_classes]) + names = sorted([get_fields(tool_cls)["name"].default for tool_cls in tool_classes]) duplicated_names = [name for name in names if names.count(name) > 1] assert not duplicated_names