diff --git a/langchain/llms/openai.py b/langchain/llms/openai.py
index 2affb86de5e..c80bcbc0190 100644
--- a/langchain/llms/openai.py
+++ b/langchain/llms/openai.py
@@ -1,7 +1,7 @@
 """Wrapper around OpenAI APIs."""
 from typing import Any, Dict, List, Mapping, Optional
 
-from pydantic import BaseModel, Extra, root_validator
+from pydantic import BaseModel, Extra, Field, root_validator
 
 from langchain.llms.base import LLM
 from langchain.utils import get_from_dict_or_env
@@ -13,6 +13,9 @@ class OpenAI(LLM, BaseModel):
     To use, you should have the ``openai`` python package installed, and the
     environment variable ``OPENAI_API_KEY`` set with your API key.
 
+    Any parameters that are valid to be passed to the openai.create call can be passed
+    in, even if not explicitly saved on this class.
+
     Example:
         .. code-block:: python
 
@@ -37,7 +40,8 @@ class OpenAI(LLM, BaseModel):
     """How many completions to generate for each prompt."""
     best_of: int = 1
     """Generates best_of completions server-side and returns the "best"."""
-
+    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+    """Holds any model parameters valid for `create` call not explicitly specified."""
     openai_api_key: Optional[str] = None
 
     class Config:
@@ -45,6 +49,20 @@ class OpenAI(LLM, BaseModel):
 
         extra = Extra.forbid
 
+    @root_validator(pre=True)
+    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        """Build extra kwargs from additional params that were passed in."""
+        all_required_field_names = {field.alias for field in cls.__fields__.values()}
+
+        extra = values.get("model_kwargs", {})
+        for field_name in list(values):
+            if field_name not in all_required_field_names:
+                if field_name in extra:
+                    raise ValueError(f"Found {field_name} supplied twice.")
+                extra[field_name] = values.pop(field_name)
+        values["model_kwargs"] = extra
+        return values
+
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key and python package exists in environment."""
@@ -66,7 +84,7 @@ class OpenAI(LLM, BaseModel):
     @property
     def _default_params(self) -> Mapping[str, Any]:
         """Get the default parameters for calling OpenAI API."""
-        return {
+        normal_params = {
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "top_p": self.top_p,
@@ -75,6 +93,7 @@ class OpenAI(LLM, BaseModel):
            "n": self.n,
            "best_of": self.best_of,
        }
+        return {**normal_params, **self.model_kwargs}
 
     @property
     def _identifying_params(self) -> Mapping[str, Any]:
diff --git a/tests/integration_tests/llms/test_openai.py b/tests/integration_tests/llms/test_openai.py
index 8667c6f0bb3..3519f928599 100644
--- a/tests/integration_tests/llms/test_openai.py
+++ b/tests/integration_tests/llms/test_openai.py
@@ -1,5 +1,7 @@
 """Test OpenAI API wrapper."""
 
+import pytest
+
 from langchain.llms.openai import OpenAI
 
 
@@ -8,3 +10,19 @@ def test_openai_call() -> None:
     llm = OpenAI(max_tokens=10)
     output = llm("Say foo:")
     assert isinstance(output, str)
+
+
+def test_openai_extra_kwargs() -> None:
+    """Test extra kwargs to openai."""
+    # Check that foo is saved in extra_kwargs.
+    llm = OpenAI(foo=3, max_tokens=10)
+    assert llm.max_tokens == 10
+    assert llm.model_kwargs == {"foo": 3}
+
+    # Test that if extra_kwargs are provided, they are added to it.
+    llm = OpenAI(foo=3, model_kwargs={"bar": 2})
+    assert llm.model_kwargs == {"foo": 3, "bar": 2}
+
+    # Test that if provided twice it errors
+    with pytest.raises(ValueError):
+        OpenAI(foo=3, model_kwargs={"foo": 2})
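
A minimal usage sketch of the behavior this diff adds, assuming ``logit_bias`` is not declared as an explicit field on ``OpenAI`` (so the ``pre=True`` validator routes it into ``model_kwargs``), and that the ``openai`` package is installed and ``OPENAI_API_KEY`` is set so ``validate_environment`` passes on construction:

    from langchain.llms.openai import OpenAI

    # "logit_bias" (assumed not to be an explicit field on this class) is moved
    # into `model_kwargs` by the build_extra root validator instead of being
    # rejected by Extra.forbid.
    llm = OpenAI(max_tokens=10, logit_bias={"50256": -100})
    assert llm.model_kwargs == {"logit_bias": {"50256": -100}}

    # `_default_params` merges `model_kwargs` over the explicit fields, so the
    # extra parameter is forwarded on the underlying openai create call.
    print(llm._default_params)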