mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-21 06:14:37 +00:00
Allowing additional params for OpenAIEmbeddings. (#7752)
(#7654) --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
862268175e
commit
0d23c0c82a
@ -36,7 +36,7 @@ from langchain.schema import (
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
from langchain.utils import get_from_dict_or_env, get_pydantic_field_names
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -155,7 +155,7 @@ class JinaChat(BaseChatModel):
|
||||
@root_validator(pre=True)
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = cls._all_required_field_names()
|
||||
all_required_field_names = get_pydantic_field_names(cls)
|
||||
extra = values.get("model_kwargs", {})
|
||||
for field_name in list(values):
|
||||
if field_name in extra:
|
||||
|
@ -41,7 +41,7 @@ from langchain.schema.messages import (
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
from langchain.utils import get_from_dict_or_env, get_pydantic_field_names
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import tiktoken
|
||||
@ -205,7 +205,7 @@ class ChatOpenAI(BaseChatModel):
|
||||
@root_validator(pre=True)
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = cls._all_required_field_names()
|
||||
all_required_field_names = get_pydantic_field_names(cls)
|
||||
extra = values.get("model_kwargs", {})
|
||||
for field_name in list(values):
|
||||
if field_name in extra:
|
||||
|
@ -2,6 +2,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import warnings
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
@ -16,7 +17,7 @@ from typing import (
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
from pydantic import BaseModel, Extra, Field, root_validator
|
||||
from tenacity import (
|
||||
AsyncRetrying,
|
||||
before_sleep_log,
|
||||
@ -27,7 +28,7 @@ from tenacity import (
|
||||
)
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
from langchain.utils import get_from_dict_or_env, get_pydantic_field_names
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -193,12 +194,40 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
when tiktoken is called, you can specify a model name to use here."""
|
||||
show_progress_bar: bool = False
|
||||
"""Whether to show a progress bar when embedding."""
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""Holds any model parameters valid for `create` call not explicitly specified."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
@root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
    """Build extra kwargs from additional params that were passed in.

    Any key in ``values`` that is not a declared field name (or alias) of
    the class is removed from ``values`` and stored under ``model_kwargs``,
    with a warning.

    Args:
        values: Raw keyword arguments supplied to the constructor.

    Returns:
        ``values`` with unknown keys moved into ``values["model_kwargs"]``.

    Raises:
        ValueError: If a key is present both at the top level and inside
            ``model_kwargs``, or if ``model_kwargs`` contains a name that
            is a declared field and must be passed explicitly.
    """
    all_required_field_names = get_pydantic_field_names(cls)
    # Work on a shallow copy: the dict under "model_kwargs" is the object the
    # caller passed in, and writing into it would leak params into any other
    # instance built from the same dict (and could trigger a spurious
    # "supplied twice" error on reuse).
    extra = dict(values.get("model_kwargs", {}))
    # Iterate over a snapshot of the keys because entries are popped below.
    for field_name in list(values):
        if field_name in extra:
            raise ValueError(f"Found {field_name} supplied twice.")
        if field_name not in all_required_field_names:
            warnings.warn(
                f"""WARNING! {field_name} is not default parameter.
                {field_name} was transferred to model_kwargs.
                Please confirm that {field_name} is what you intended."""
            )
            extra[field_name] = values.pop(field_name)

    # Declared fields must not be smuggled in through model_kwargs.
    invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
    if invalid_model_kwargs:
        raise ValueError(
            f"Parameters {invalid_model_kwargs} should be specified explicitly. "
            f"Instead they were passed in as part of `model_kwargs` parameter."
        )

    values["model_kwargs"] = extra
    return values
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
@ -261,6 +290,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
"api_base": self.openai_api_base,
|
||||
"api_type": self.openai_api_type,
|
||||
"api_version": self.openai_api_version,
|
||||
**self.model_kwargs,
|
||||
}
|
||||
if self.openai_api_type in ("azure", "azure_ad", "azuread"):
|
||||
openai_args["engine"] = self.deployment
|
||||
|
@ -28,7 +28,7 @@ from langchain.callbacks.manager import (
|
||||
)
|
||||
from langchain.llms.base import BaseLLM, create_base_retry_decorator
|
||||
from langchain.schema import Generation, LLMResult
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
from langchain.utils import get_from_dict_or_env, get_pydantic_field_names
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -186,13 +186,13 @@ class BaseOpenAI(BaseLLM):
|
||||
@root_validator(pre=True)
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = cls._all_required_field_names()
|
||||
all_required_field_names = get_pydantic_field_names(cls)
|
||||
extra = values.get("model_kwargs", {})
|
||||
for field_name in list(values):
|
||||
if field_name in extra:
|
||||
raise ValueError(f"Found {field_name} supplied twice.")
|
||||
if field_name not in all_required_field_names:
|
||||
logger.warning(
|
||||
warnings.warn(
|
||||
f"""WARNING! {field_name} is not default parameter.
|
||||
{field_name} was transferred to model_kwargs.
|
||||
Please confirm that {field_name} is what you intended."""
|
||||
|
@ -7,6 +7,7 @@ from langchain.load.serializable import Serializable
|
||||
from langchain.schema.messages import BaseMessage, get_buffer_string
|
||||
from langchain.schema.output import LLMResult
|
||||
from langchain.schema.prompt import PromptValue
|
||||
from langchain.utils import get_pydantic_field_names
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
@ -246,9 +247,8 @@ class BaseLanguageModel(Serializable, ABC):
|
||||
|
||||
@classmethod
|
||||
def _all_required_field_names(cls) -> Set:
|
||||
all_required_field_names = set()
|
||||
for field in cls.__fields__.values():
|
||||
all_required_field_names.add(field.name)
|
||||
if field.has_alias:
|
||||
all_required_field_names.add(field.alias)
|
||||
return all_required_field_names
|
||||
"""DEPRECATED: Kept for backwards compatibility.
|
||||
|
||||
Use get_pydantic_field_names.
|
||||
"""
|
||||
return get_pydantic_field_names(cls)
|
||||
|
@ -4,7 +4,7 @@ import datetime
|
||||
import importlib
|
||||
import os
|
||||
from importlib.metadata import version
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
|
||||
|
||||
from packaging.version import parse
|
||||
from requests import HTTPError, Response
|
||||
@ -183,3 +183,16 @@ def check_package_version(
|
||||
f"Expected {package} version to be >= {gte_version}. Received "
|
||||
f"{imported_version}."
|
||||
)
|
||||
|
||||
|
||||
def get_pydantic_field_names(pydantic_cls: Any) -> Set:
    """Collect every field name, plus any alias, declared on a pydantic class.

    Args:
        pydantic_cls: Pydantic class whose ``__fields__`` mapping is read.

    Returns:
        A set holding each field's ``name`` and, for fields whose
        ``has_alias`` is true, their ``alias`` as well.
    """
    model_fields = list(pydantic_cls.__fields__.values())
    names = {model_field.name for model_field in model_fields}
    # Aliases are only added for fields that actually declare one.
    names |= {
        model_field.alias for model_field in model_fields if model_field.has_alias
    }
    return names
|
||||
|
4
poetry.lock
generated
4
poetry.lock
generated
@ -12847,7 +12847,7 @@ clarifai = ["clarifai"]
|
||||
cohere = ["cohere"]
|
||||
docarray = ["docarray"]
|
||||
embeddings = ["sentence-transformers"]
|
||||
extended-testing = ["atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "esprima", "gql", "html2text", "jq", "lxml", "mwparserfromhell", "mwxml", "openai", "pandas", "pdfminer-six", "pgvector", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "requests-toolbelt", "scikit-learn", "streamlit", "sympy", "telethon", "tqdm", "zep-python"]
|
||||
extended-testing = ["atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "esprima", "gql", "html2text", "jq", "lxml", "mwparserfromhell", "mwxml", "openai", "openai", "pandas", "pdfminer-six", "pgvector", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "requests-toolbelt", "scikit-learn", "streamlit", "sympy", "telethon", "tqdm", "zep-python"]
|
||||
javascript = ["esprima"]
|
||||
llms = ["anthropic", "clarifai", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "openllm", "openlm", "torch", "transformers"]
|
||||
openai = ["openai", "tiktoken"]
|
||||
@ -12857,4 +12857,4 @@ text-helpers = ["chardet"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
content-hash = "cae082b5f45fe5564de8320fd1f39370f5e59389bf3aaa72291be531bce2e705"
|
||||
content-hash = "f322b36103013bd59c34dddadf84209292ea61ed73bd26fbfa355d372011238b"
|
||||
|
@ -362,6 +362,7 @@ extended_testing = [
|
||||
"openai",
|
||||
"sympy",
|
||||
"rapidfuzz",
|
||||
"openai",
|
||||
"rank_bm25",
|
||||
]
|
||||
|
||||
|
0
tests/unit_tests/embeddings/__init__.py
Normal file
0
tests/unit_tests/embeddings/__init__.py
Normal file
20
tests/unit_tests/embeddings/test_openai.py
Normal file
20
tests/unit_tests/embeddings/test_openai.py
Normal file
@ -0,0 +1,20 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||
|
||||
os.environ["OPENAI_API_KEY"] = "foo"
|
||||
|
||||
|
||||
@pytest.mark.requires("openai")
def test_openai_invalid_model_kwargs() -> None:
    """A declared field passed through ``model_kwargs`` raises ValueError."""
    smuggled_field = {"model": "foo"}
    with pytest.raises(ValueError):
        OpenAIEmbeddings(model_kwargs=smuggled_field)
|
||||
|
||||
|
||||
@pytest.mark.requires("openai")
def test_openai_incorrect_field() -> None:
    """An unknown constructor kwarg warns and ends up in ``model_kwargs``."""
    unknown_params = {"foo": "bar"}
    with pytest.warns(match="not default parameter"):
        embedder = OpenAIEmbeddings(**unknown_params)
    assert embedder.model_kwargs == {"foo": "bar"}
|
28
tests/unit_tests/llms/test_openai.py
Normal file
28
tests/unit_tests/llms/test_openai.py
Normal file
@ -0,0 +1,28 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.llms.openai import OpenAI
|
||||
|
||||
os.environ["OPENAI_API_KEY"] = "foo"
|
||||
|
||||
|
||||
@pytest.mark.requires("openai")
def test_openai_model_param() -> None:
    """Both ``model`` and ``model_name`` populate the ``model_name`` attribute."""
    # Same checks as before, in the same order: first `model`, then `model_name`.
    for init_kwargs in ({"model": "foo"}, {"model_name": "foo"}):
        llm = OpenAI(**init_kwargs)
        assert llm.model_name == "foo"
|
||||
|
||||
|
||||
@pytest.mark.requires("openai")
def test_openai_invalid_model_kwargs() -> None:
    """A declared field passed through ``model_kwargs`` raises ValueError."""
    smuggled_field = {"model_name": "foo"}
    with pytest.raises(ValueError):
        OpenAI(model_kwargs=smuggled_field)
|
||||
|
||||
|
||||
@pytest.mark.requires("openai")
def test_openai_incorrect_field() -> None:
    """An unknown constructor kwarg warns and ends up in ``model_kwargs``."""
    unknown_params = {"foo": "bar"}
    with pytest.warns(match="not default parameter"):
        model = OpenAI(**unknown_params)
    assert model.model_kwargs == {"foo": "bar"}
|
Loading…
Reference in New Issue
Block a user