diff --git a/langchain/llms/openai.py b/langchain/llms/openai.py index 879dafe8673..8ed2af8a1e2 100644 --- a/langchain/llms/openai.py +++ b/langchain/llms/openai.py @@ -82,7 +82,7 @@ class OpenAI(LLM, BaseModel): return values @property - def _default_params(self) -> Mapping[str, Any]: + def _default_params(self) -> Dict[str, Any]: """Get the default parameters for calling OpenAI API.""" normal_params = { "temperature": self.temperature, @@ -115,7 +115,10 @@ class OpenAI(LLM, BaseModel): response = openai("Tell me a joke.") """ - response = self.client.create( - model=self.model_name, prompt=prompt, stop=stop, **self._default_params - ) + params = self._default_params + if stop is not None: + if "stop" in params: + raise ValueError("`stop` found in both the input and default params.") + params["stop"] = stop + response = self.client.create(model=self.model_name, prompt=prompt, **params) return response["choices"][0]["text"] diff --git a/tests/integration_tests/llms/test_openai.py b/tests/integration_tests/llms/test_openai.py index 3519f928599..4050e42ed84 100644 --- a/tests/integration_tests/llms/test_openai.py +++ b/tests/integration_tests/llms/test_openai.py @@ -26,3 +26,21 @@ def test_openai_extra_kwargs() -> None: # Test that if provided twice it errors with pytest.raises(ValueError): OpenAI(foo=3, model_kwargs={"foo": 2}) + + +def test_openai_stop_valid() -> None: + """Test openai stop logic on valid configuration.""" + query = "write an ordered list of five items" + first_llm = OpenAI(stop="3", temperature=0) + first_output = first_llm(query) + second_llm = OpenAI(temperature=0) + second_output = second_llm(query, stop=["3"]) + # Because it stops on new lines, shouldn't return anything + assert first_output == second_output + + +def test_openai_stop_error() -> None: + """Test openai stop logic on bad configuration.""" + llm = OpenAI(stop="3", temperature=0) + with pytest.raises(ValueError): + llm("write an ordered list of five items", stop=["\n"])