mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-21 14:18:52 +00:00
Allowing additional params for OpenAIEmbeddings. (#7752)
(#7654) --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
862268175e
commit
0d23c0c82a
@ -36,7 +36,7 @@ from langchain.schema import (
|
|||||||
HumanMessage,
|
HumanMessage,
|
||||||
SystemMessage,
|
SystemMessage,
|
||||||
)
|
)
|
||||||
from langchain.utils import get_from_dict_or_env
|
from langchain.utils import get_from_dict_or_env, get_pydantic_field_names
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -155,7 +155,7 @@ class JinaChat(BaseChatModel):
|
|||||||
@root_validator(pre=True)
|
@root_validator(pre=True)
|
||||||
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""Build extra kwargs from additional params that were passed in."""
|
"""Build extra kwargs from additional params that were passed in."""
|
||||||
all_required_field_names = cls._all_required_field_names()
|
all_required_field_names = get_pydantic_field_names(cls)
|
||||||
extra = values.get("model_kwargs", {})
|
extra = values.get("model_kwargs", {})
|
||||||
for field_name in list(values):
|
for field_name in list(values):
|
||||||
if field_name in extra:
|
if field_name in extra:
|
||||||
|
@ -41,7 +41,7 @@ from langchain.schema.messages import (
|
|||||||
HumanMessage,
|
HumanMessage,
|
||||||
SystemMessage,
|
SystemMessage,
|
||||||
)
|
)
|
||||||
from langchain.utils import get_from_dict_or_env
|
from langchain.utils import get_from_dict_or_env, get_pydantic_field_names
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import tiktoken
|
import tiktoken
|
||||||
@ -205,7 +205,7 @@ class ChatOpenAI(BaseChatModel):
|
|||||||
@root_validator(pre=True)
|
@root_validator(pre=True)
|
||||||
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""Build extra kwargs from additional params that were passed in."""
|
"""Build extra kwargs from additional params that were passed in."""
|
||||||
all_required_field_names = cls._all_required_field_names()
|
all_required_field_names = get_pydantic_field_names(cls)
|
||||||
extra = values.get("model_kwargs", {})
|
extra = values.get("model_kwargs", {})
|
||||||
for field_name in list(values):
|
for field_name in list(values):
|
||||||
if field_name in extra:
|
if field_name in extra:
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import warnings
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
Callable,
|
Callable,
|
||||||
@ -16,7 +17,7 @@ from typing import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pydantic import BaseModel, Extra, root_validator
|
from pydantic import BaseModel, Extra, Field, root_validator
|
||||||
from tenacity import (
|
from tenacity import (
|
||||||
AsyncRetrying,
|
AsyncRetrying,
|
||||||
before_sleep_log,
|
before_sleep_log,
|
||||||
@ -27,7 +28,7 @@ from tenacity import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from langchain.embeddings.base import Embeddings
|
from langchain.embeddings.base import Embeddings
|
||||||
from langchain.utils import get_from_dict_or_env
|
from langchain.utils import get_from_dict_or_env, get_pydantic_field_names
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -193,12 +194,40 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
|||||||
when tiktoken is called, you can specify a model name to use here."""
|
when tiktoken is called, you can specify a model name to use here."""
|
||||||
show_progress_bar: bool = False
|
show_progress_bar: bool = False
|
||||||
"""Whether to show a progress bar when embedding."""
|
"""Whether to show a progress bar when embedding."""
|
||||||
|
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||||
|
"""Holds any model parameters valid for `create` call not explicitly specified."""
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
"""Configuration for this pydantic object."""
|
"""Configuration for this pydantic object."""
|
||||||
|
|
||||||
extra = Extra.forbid
|
extra = Extra.forbid
|
||||||
|
|
||||||
|
@root_validator(pre=True)
|
||||||
|
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""Build extra kwargs from additional params that were passed in."""
|
||||||
|
all_required_field_names = get_pydantic_field_names(cls)
|
||||||
|
extra = values.get("model_kwargs", {})
|
||||||
|
for field_name in list(values):
|
||||||
|
if field_name in extra:
|
||||||
|
raise ValueError(f"Found {field_name} supplied twice.")
|
||||||
|
if field_name not in all_required_field_names:
|
||||||
|
warnings.warn(
|
||||||
|
f"""WARNING! {field_name} is not default parameter.
|
||||||
|
{field_name} was transferred to model_kwargs.
|
||||||
|
Please confirm that {field_name} is what you intended."""
|
||||||
|
)
|
||||||
|
extra[field_name] = values.pop(field_name)
|
||||||
|
|
||||||
|
invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
|
||||||
|
if invalid_model_kwargs:
|
||||||
|
raise ValueError(
|
||||||
|
f"Parameters {invalid_model_kwargs} should be specified explicitly. "
|
||||||
|
f"Instead they were passed in as part of `model_kwargs` parameter."
|
||||||
|
)
|
||||||
|
|
||||||
|
values["model_kwargs"] = extra
|
||||||
|
return values
|
||||||
|
|
||||||
@root_validator()
|
@root_validator()
|
||||||
def validate_environment(cls, values: Dict) -> Dict:
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
"""Validate that api key and python package exists in environment."""
|
"""Validate that api key and python package exists in environment."""
|
||||||
@ -261,6 +290,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
|||||||
"api_base": self.openai_api_base,
|
"api_base": self.openai_api_base,
|
||||||
"api_type": self.openai_api_type,
|
"api_type": self.openai_api_type,
|
||||||
"api_version": self.openai_api_version,
|
"api_version": self.openai_api_version,
|
||||||
|
**self.model_kwargs,
|
||||||
}
|
}
|
||||||
if self.openai_api_type in ("azure", "azure_ad", "azuread"):
|
if self.openai_api_type in ("azure", "azure_ad", "azuread"):
|
||||||
openai_args["engine"] = self.deployment
|
openai_args["engine"] = self.deployment
|
||||||
|
@ -28,7 +28,7 @@ from langchain.callbacks.manager import (
|
|||||||
)
|
)
|
||||||
from langchain.llms.base import BaseLLM, create_base_retry_decorator
|
from langchain.llms.base import BaseLLM, create_base_retry_decorator
|
||||||
from langchain.schema import Generation, LLMResult
|
from langchain.schema import Generation, LLMResult
|
||||||
from langchain.utils import get_from_dict_or_env
|
from langchain.utils import get_from_dict_or_env, get_pydantic_field_names
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -186,13 +186,13 @@ class BaseOpenAI(BaseLLM):
|
|||||||
@root_validator(pre=True)
|
@root_validator(pre=True)
|
||||||
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""Build extra kwargs from additional params that were passed in."""
|
"""Build extra kwargs from additional params that were passed in."""
|
||||||
all_required_field_names = cls._all_required_field_names()
|
all_required_field_names = get_pydantic_field_names(cls)
|
||||||
extra = values.get("model_kwargs", {})
|
extra = values.get("model_kwargs", {})
|
||||||
for field_name in list(values):
|
for field_name in list(values):
|
||||||
if field_name in extra:
|
if field_name in extra:
|
||||||
raise ValueError(f"Found {field_name} supplied twice.")
|
raise ValueError(f"Found {field_name} supplied twice.")
|
||||||
if field_name not in all_required_field_names:
|
if field_name not in all_required_field_names:
|
||||||
logger.warning(
|
warnings.warn(
|
||||||
f"""WARNING! {field_name} is not default parameter.
|
f"""WARNING! {field_name} is not default parameter.
|
||||||
{field_name} was transferred to model_kwargs.
|
{field_name} was transferred to model_kwargs.
|
||||||
Please confirm that {field_name} is what you intended."""
|
Please confirm that {field_name} is what you intended."""
|
||||||
|
@ -7,6 +7,7 @@ from langchain.load.serializable import Serializable
|
|||||||
from langchain.schema.messages import BaseMessage, get_buffer_string
|
from langchain.schema.messages import BaseMessage, get_buffer_string
|
||||||
from langchain.schema.output import LLMResult
|
from langchain.schema.output import LLMResult
|
||||||
from langchain.schema.prompt import PromptValue
|
from langchain.schema.prompt import PromptValue
|
||||||
|
from langchain.utils import get_pydantic_field_names
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from langchain.callbacks.manager import Callbacks
|
from langchain.callbacks.manager import Callbacks
|
||||||
@ -246,9 +247,8 @@ class BaseLanguageModel(Serializable, ABC):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _all_required_field_names(cls) -> Set:
|
def _all_required_field_names(cls) -> Set:
|
||||||
all_required_field_names = set()
|
"""DEPRECATED: Kept for backwards compatibility.
|
||||||
for field in cls.__fields__.values():
|
|
||||||
all_required_field_names.add(field.name)
|
Use get_pydantic_field_names.
|
||||||
if field.has_alias:
|
"""
|
||||||
all_required_field_names.add(field.alias)
|
return get_pydantic_field_names(cls)
|
||||||
return all_required_field_names
|
|
||||||
|
@ -4,7 +4,7 @@ import datetime
|
|||||||
import importlib
|
import importlib
|
||||||
import os
|
import os
|
||||||
from importlib.metadata import version
|
from importlib.metadata import version
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
|
||||||
|
|
||||||
from packaging.version import parse
|
from packaging.version import parse
|
||||||
from requests import HTTPError, Response
|
from requests import HTTPError, Response
|
||||||
@ -183,3 +183,16 @@ def check_package_version(
|
|||||||
f"Expected {package} version to be >= {gte_version}. Received "
|
f"Expected {package} version to be >= {gte_version}. Received "
|
||||||
f"{imported_version}."
|
f"{imported_version}."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_pydantic_field_names(pydantic_cls: Any) -> Set:
|
||||||
|
"""Get field names, including aliases, for a pydantic class.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pydantic_cls: Pydantic class."""
|
||||||
|
all_required_field_names = set()
|
||||||
|
for field in pydantic_cls.__fields__.values():
|
||||||
|
all_required_field_names.add(field.name)
|
||||||
|
if field.has_alias:
|
||||||
|
all_required_field_names.add(field.alias)
|
||||||
|
return all_required_field_names
|
||||||
|
4
poetry.lock
generated
4
poetry.lock
generated
@ -12847,7 +12847,7 @@ clarifai = ["clarifai"]
|
|||||||
cohere = ["cohere"]
|
cohere = ["cohere"]
|
||||||
docarray = ["docarray"]
|
docarray = ["docarray"]
|
||||||
embeddings = ["sentence-transformers"]
|
embeddings = ["sentence-transformers"]
|
||||||
extended-testing = ["atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "esprima", "gql", "html2text", "jq", "lxml", "mwparserfromhell", "mwxml", "openai", "pandas", "pdfminer-six", "pgvector", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "requests-toolbelt", "scikit-learn", "streamlit", "sympy", "telethon", "tqdm", "zep-python"]
|
extended-testing = ["atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "esprima", "gql", "html2text", "jq", "lxml", "mwparserfromhell", "mwxml", "openai", "openai", "pandas", "pdfminer-six", "pgvector", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "requests-toolbelt", "scikit-learn", "streamlit", "sympy", "telethon", "tqdm", "zep-python"]
|
||||||
javascript = ["esprima"]
|
javascript = ["esprima"]
|
||||||
llms = ["anthropic", "clarifai", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "openllm", "openlm", "torch", "transformers"]
|
llms = ["anthropic", "clarifai", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "openllm", "openlm", "torch", "transformers"]
|
||||||
openai = ["openai", "tiktoken"]
|
openai = ["openai", "tiktoken"]
|
||||||
@ -12857,4 +12857,4 @@ text-helpers = ["chardet"]
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = ">=3.8.1,<4.0"
|
python-versions = ">=3.8.1,<4.0"
|
||||||
content-hash = "cae082b5f45fe5564de8320fd1f39370f5e59389bf3aaa72291be531bce2e705"
|
content-hash = "f322b36103013bd59c34dddadf84209292ea61ed73bd26fbfa355d372011238b"
|
||||||
|
@ -362,6 +362,7 @@ extended_testing = [
|
|||||||
"openai",
|
"openai",
|
||||||
"sympy",
|
"sympy",
|
||||||
"rapidfuzz",
|
"rapidfuzz",
|
||||||
|
"openai",
|
||||||
"rank_bm25",
|
"rank_bm25",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
0
tests/unit_tests/embeddings/__init__.py
Normal file
0
tests/unit_tests/embeddings/__init__.py
Normal file
20
tests/unit_tests/embeddings/test_openai.py
Normal file
20
tests/unit_tests/embeddings/test_openai.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||||
|
|
||||||
|
os.environ["OPENAI_API_KEY"] = "foo"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("openai")
|
||||||
|
def test_openai_invalid_model_kwargs() -> None:
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
OpenAIEmbeddings(model_kwargs={"model": "foo"})
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("openai")
|
||||||
|
def test_openai_incorrect_field() -> None:
|
||||||
|
with pytest.warns(match="not default parameter"):
|
||||||
|
llm = OpenAIEmbeddings(foo="bar")
|
||||||
|
assert llm.model_kwargs == {"foo": "bar"}
|
28
tests/unit_tests/llms/test_openai.py
Normal file
28
tests/unit_tests/llms/test_openai.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain.llms.openai import OpenAI
|
||||||
|
|
||||||
|
os.environ["OPENAI_API_KEY"] = "foo"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("openai")
|
||||||
|
def test_openai_model_param() -> None:
|
||||||
|
llm = OpenAI(model="foo")
|
||||||
|
assert llm.model_name == "foo"
|
||||||
|
llm = OpenAI(model_name="foo")
|
||||||
|
assert llm.model_name == "foo"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("openai")
|
||||||
|
def test_openai_invalid_model_kwargs() -> None:
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
OpenAI(model_kwargs={"model_name": "foo"})
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("openai")
|
||||||
|
def test_openai_incorrect_field() -> None:
|
||||||
|
with pytest.warns(match="not default parameter"):
|
||||||
|
llm = OpenAI(foo="bar")
|
||||||
|
assert llm.model_kwargs == {"foo": "bar"}
|
Loading…
Reference in New Issue
Block a user