diff --git a/langchain/chat_models/jinachat.py b/langchain/chat_models/jinachat.py
index bbf7c511888..5c7d89b4023 100644
--- a/langchain/chat_models/jinachat.py
+++ b/langchain/chat_models/jinachat.py
@@ -36,7 +36,7 @@ from langchain.schema import (
     HumanMessage,
     SystemMessage,
 )
-from langchain.utils import get_from_dict_or_env
+from langchain.utils import get_from_dict_or_env, get_pydantic_field_names
 
 logger = logging.getLogger(__name__)
 
@@ -155,7 +155,7 @@ class JinaChat(BaseChatModel):
     @root_validator(pre=True)
     def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
         """Build extra kwargs from additional params that were passed in."""
-        all_required_field_names = cls._all_required_field_names()
+        all_required_field_names = get_pydantic_field_names(cls)
         extra = values.get("model_kwargs", {})
         for field_name in list(values):
             if field_name in extra:
diff --git a/langchain/chat_models/openai.py b/langchain/chat_models/openai.py
index 66bf2761cf1..ebab719e87a 100644
--- a/langchain/chat_models/openai.py
+++ b/langchain/chat_models/openai.py
@@ -41,7 +41,7 @@ from langchain.schema.messages import (
     HumanMessage,
     SystemMessage,
 )
-from langchain.utils import get_from_dict_or_env
+from langchain.utils import get_from_dict_or_env, get_pydantic_field_names
 
 if TYPE_CHECKING:
     import tiktoken
@@ -205,7 +205,7 @@ class ChatOpenAI(BaseChatModel):
     @root_validator(pre=True)
     def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
         """Build extra kwargs from additional params that were passed in."""
-        all_required_field_names = cls._all_required_field_names()
+        all_required_field_names = get_pydantic_field_names(cls)
         extra = values.get("model_kwargs", {})
         for field_name in list(values):
             if field_name in extra:
diff --git a/langchain/embeddings/openai.py b/langchain/embeddings/openai.py
index 8db55f935d0..e9b3f7f6406 100644
--- a/langchain/embeddings/openai.py
+++ b/langchain/embeddings/openai.py
@@ -2,6 +2,7 @@
 from __future__ import annotations
 
 import logging
+import warnings
 from typing import (
     Any,
     Callable,
@@ -16,7 +17,7 @@ from typing import (
 )
 
 import numpy as np
-from pydantic import BaseModel, Extra, root_validator
+from pydantic import BaseModel, Extra, Field, root_validator
 from tenacity import (
     AsyncRetrying,
     before_sleep_log,
@@ -27,7 +28,7 @@ from tenacity import (
 )
 
 from langchain.embeddings.base import Embeddings
-from langchain.utils import get_from_dict_or_env
+from langchain.utils import get_from_dict_or_env, get_pydantic_field_names
 
 logger = logging.getLogger(__name__)
 
@@ -193,12 +194,40 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
     when tiktoken is called, you can specify a model name to use here."""
     show_progress_bar: bool = False
     """Whether to show a progress bar when embedding."""
+    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+    """Holds any model parameters valid for `create` call not explicitly specified."""
 
     class Config:
         """Configuration for this pydantic object."""
 
         extra = Extra.forbid
 
+    @root_validator(pre=True)
+    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        """Build extra kwargs from additional params that were passed in."""
+        all_required_field_names = get_pydantic_field_names(cls)
+        extra = values.get("model_kwargs", {})
+        for field_name in list(values):
+            if field_name in extra:
+                raise ValueError(f"Found {field_name} supplied twice.")
+            if field_name not in all_required_field_names:
+                warnings.warn(
+                    f"""WARNING! {field_name} is not default parameter.
+                    {field_name} was transferred to model_kwargs.
+                    Please confirm that {field_name} is what you intended."""
+                )
+                extra[field_name] = values.pop(field_name)
+
+        invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
+        if invalid_model_kwargs:
+            raise ValueError(
+                f"Parameters {invalid_model_kwargs} should be specified explicitly. "
+                f"Instead they were passed in as part of `model_kwargs` parameter."
+            )
+
+        values["model_kwargs"] = extra
+        return values
+
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key and python package exists in environment."""
@@ -261,6 +290,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
             "api_base": self.openai_api_base,
             "api_type": self.openai_api_type,
             "api_version": self.openai_api_version,
+            **self.model_kwargs,
         }
         if self.openai_api_type in ("azure", "azure_ad", "azuread"):
             openai_args["engine"] = self.deployment
diff --git a/langchain/llms/openai.py b/langchain/llms/openai.py
index e241ed7fb84..dcdb2738a7b 100644
--- a/langchain/llms/openai.py
+++ b/langchain/llms/openai.py
@@ -28,7 +28,7 @@ from langchain.callbacks.manager import (
 )
 from langchain.llms.base import BaseLLM, create_base_retry_decorator
 from langchain.schema import Generation, LLMResult
-from langchain.utils import get_from_dict_or_env
+from langchain.utils import get_from_dict_or_env, get_pydantic_field_names
 
 logger = logging.getLogger(__name__)
 
@@ -186,13 +186,13 @@ class BaseOpenAI(BaseLLM):
     @root_validator(pre=True)
     def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
         """Build extra kwargs from additional params that were passed in."""
-        all_required_field_names = cls._all_required_field_names()
+        all_required_field_names = get_pydantic_field_names(cls)
         extra = values.get("model_kwargs", {})
         for field_name in list(values):
             if field_name in extra:
                 raise ValueError(f"Found {field_name} supplied twice.")
             if field_name not in all_required_field_names:
-                logger.warning(
+                warnings.warn(
                     f"""WARNING! {field_name} is not default parameter.
                     {field_name} was transferred to model_kwargs.
                     Please confirm that {field_name} is what you intended."""
diff --git a/langchain/schema/language_model.py b/langchain/schema/language_model.py
index 2c3727bc7e7..19b4de1ef73 100644
--- a/langchain/schema/language_model.py
+++ b/langchain/schema/language_model.py
@@ -7,6 +7,7 @@ from langchain.load.serializable import Serializable
 from langchain.schema.messages import BaseMessage, get_buffer_string
 from langchain.schema.output import LLMResult
 from langchain.schema.prompt import PromptValue
+from langchain.utils import get_pydantic_field_names
 
 if TYPE_CHECKING:
     from langchain.callbacks.manager import Callbacks
@@ -246,9 +247,8 @@ class BaseLanguageModel(Serializable, ABC):
 
     @classmethod
     def _all_required_field_names(cls) -> Set:
-        all_required_field_names = set()
-        for field in cls.__fields__.values():
-            all_required_field_names.add(field.name)
-            if field.has_alias:
-                all_required_field_names.add(field.alias)
-        return all_required_field_names
+        """DEPRECATED: Kept for backwards compatibility.
+
+        Use get_pydantic_field_names.
+        """
+        return get_pydantic_field_names(cls)
diff --git a/langchain/utils.py b/langchain/utils.py
index fd03fa6248e..f35a6f3aac2 100644
--- a/langchain/utils.py
+++ b/langchain/utils.py
@@ -4,7 +4,7 @@ import datetime
 import importlib
 import os
 from importlib.metadata import version
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple
 
 from packaging.version import parse
 from requests import HTTPError, Response
@@ -183,3 +183,20 @@ def check_package_version(
             f"Expected {package} version to be >= {gte_version}. Received "
             f"{imported_version}."
         )
+
+
+def get_pydantic_field_names(pydantic_cls: Any) -> Set:
+    """Get field names, including aliases, for a pydantic class.
+
+    Args:
+        pydantic_cls: Pydantic class.
+
+    Returns:
+        Set containing every field name plus the alias of each aliased field.
+    """
+    all_required_field_names = set()
+    for field in pydantic_cls.__fields__.values():
+        all_required_field_names.add(field.name)
+        if field.has_alias:
+            all_required_field_names.add(field.alias)
+    return all_required_field_names
diff --git a/tests/unit_tests/embeddings/__init__.py b/tests/unit_tests/embeddings/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/tests/unit_tests/embeddings/test_openai.py b/tests/unit_tests/embeddings/test_openai.py
new file mode 100644
index 00000000000..f3c48af22fd
--- /dev/null
+++ b/tests/unit_tests/embeddings/test_openai.py
@@ -0,0 +1,20 @@
+import os
+
+import pytest
+
+from langchain.embeddings.openai import OpenAIEmbeddings
+
+os.environ["OPENAI_API_KEY"] = "foo"
+
+
+@pytest.mark.requires("openai")
+def test_openai_invalid_model_kwargs() -> None:
+    with pytest.raises(ValueError):
+        OpenAIEmbeddings(model_kwargs={"model": "foo"})
+
+
+@pytest.mark.requires("openai")
+def test_openai_incorrect_field() -> None:
+    with pytest.warns(match="not default parameter"):
+        llm = OpenAIEmbeddings(foo="bar")
+    assert llm.model_kwargs == {"foo": "bar"}
diff --git a/tests/unit_tests/llms/test_openai.py b/tests/unit_tests/llms/test_openai.py
new file mode 100644
index 00000000000..ef311ea8788
--- /dev/null
+++ b/tests/unit_tests/llms/test_openai.py
@@ -0,0 +1,28 @@
+import os
+
+import pytest
+
+from langchain.llms.openai import OpenAI
+
+os.environ["OPENAI_API_KEY"] = "foo"
+
+
+@pytest.mark.requires("openai")
+def test_openai_model_param() -> None:
+    llm = OpenAI(model="foo")
+    assert llm.model_name == "foo"
+    llm = OpenAI(model_name="foo")
+    assert llm.model_name == "foo"
+
+
+@pytest.mark.requires("openai")
+def test_openai_invalid_model_kwargs() -> None:
+    with pytest.raises(ValueError):
+        OpenAI(model_kwargs={"model_name": "foo"})
+
+
+@pytest.mark.requires("openai")
+def test_openai_incorrect_field() -> None:
+    with pytest.warns(match="not default parameter"):
+        llm = OpenAI(foo="bar")
+    assert llm.model_kwargs == {"foo": "bar"}