mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-02 04:58:46 +00:00
Check OpenAI model kwargs (#4366)
Handle duplicate and incorrectly specified OpenAI params Thanks @PawelFaron for the fix! Made small update Closes #4331 --------- Co-authored-by: PawelFaron <42373772+PawelFaron@users.noreply.github.com> Co-authored-by: Pawel Faron <ext-pawel.faron@vaisala.com>
This commit is contained in:
parent
02ebb15c4a
commit
ba0057c077
@ -143,10 +143,24 @@ class ChatOpenAI(BaseChatModel):
|
||||
|
||||
extra = values.get("model_kwargs", {})
|
||||
for field_name in list(values):
|
||||
if field_name in extra:
|
||||
raise ValueError(f"Found {field_name} supplied twice.")
|
||||
if field_name not in all_required_field_names:
|
||||
if field_name in extra:
|
||||
raise ValueError(f"Found {field_name} supplied twice.")
|
||||
logger.warning(
|
||||
f"""WARNING! {field_name} is not default parameter.
|
||||
{field_name} was transferred to model_kwargs.
|
||||
Please confirm that {field_name} is what you intended."""
|
||||
)
|
||||
extra[field_name] = values.pop(field_name)
|
||||
|
||||
disallowed_model_kwargs = all_required_field_names | {"model"}
|
||||
invalid_model_kwargs = disallowed_model_kwargs.intersection(extra.keys())
|
||||
if invalid_model_kwargs:
|
||||
raise ValueError(
|
||||
f"Parameters {invalid_model_kwargs} should be specified explicitly. "
|
||||
f"Instead they were passed in as part of `model_kwargs` parameter."
|
||||
)
|
||||
|
||||
values["model_kwargs"] = extra
|
||||
return values
|
||||
|
||||
|
@ -186,15 +186,24 @@ class BaseOpenAI(BaseLLM):
|
||||
|
||||
extra = values.get("model_kwargs", {})
|
||||
for field_name in list(values):
|
||||
if field_name in extra:
|
||||
raise ValueError(f"Found {field_name} supplied twice.")
|
||||
if field_name not in all_required_field_names:
|
||||
if field_name in extra:
|
||||
raise ValueError(f"Found {field_name} supplied twice.")
|
||||
logger.warning(
|
||||
f"""WARNING! {field_name} is not default parameter.
|
||||
{field_name} was transfered to model_kwargs.
|
||||
{field_name} was transferred to model_kwargs.
|
||||
Please confirm that {field_name} is what you intended."""
|
||||
)
|
||||
extra[field_name] = values.pop(field_name)
|
||||
|
||||
disallowed_model_kwargs = all_required_field_names | {"model"}
|
||||
invalid_model_kwargs = disallowed_model_kwargs.intersection(extra.keys())
|
||||
if invalid_model_kwargs:
|
||||
raise ValueError(
|
||||
f"Parameters {invalid_model_kwargs} should be specified explicitly. "
|
||||
f"Instead they were passed in as part of `model_kwargs` parameter."
|
||||
)
|
||||
|
||||
values["model_kwargs"] = extra
|
||||
return values
|
||||
|
||||
@ -422,7 +431,7 @@ class BaseOpenAI(BaseLLM):
|
||||
def prep_streaming_params(self, stop: Optional[List[str]] = None) -> Dict[str, Any]:
|
||||
"""Prepare the params for streaming."""
|
||||
params = self._invocation_params
|
||||
if params["best_of"] != 1:
|
||||
if "best_of" in params and params["best_of"] != 1:
|
||||
raise ValueError("OpenAI only supports best_of == 1 for streaming")
|
||||
if stop is not None:
|
||||
if "stop" in params:
|
||||
|
@ -147,3 +147,27 @@ async def test_async_chat_openai_streaming() -> None:
|
||||
assert isinstance(generation, ChatGeneration)
|
||||
assert isinstance(generation.text, str)
|
||||
assert generation.text == generation.message.content
|
||||
|
||||
|
||||
def test_chat_openai_extra_kwargs() -> None:
|
||||
"""Test extra kwargs to chat openai."""
|
||||
# Check that foo is saved in extra_kwargs.
|
||||
llm = ChatOpenAI(foo=3, max_tokens=10)
|
||||
assert llm.max_tokens == 10
|
||||
assert llm.model_kwargs == {"foo": 3}
|
||||
|
||||
# Test that if extra_kwargs are provided, they are added to it.
|
||||
llm = ChatOpenAI(foo=3, model_kwargs={"bar": 2})
|
||||
assert llm.model_kwargs == {"foo": 3, "bar": 2}
|
||||
|
||||
# Test that if provided twice it errors
|
||||
with pytest.raises(ValueError):
|
||||
ChatOpenAI(foo=3, model_kwargs={"foo": 2})
|
||||
|
||||
# Test that if explicit param is specified in kwargs it errors
|
||||
with pytest.raises(ValueError):
|
||||
ChatOpenAI(model_kwargs={"temperature": 0.2})
|
||||
|
||||
# Test that "model" cannot be specified in kwargs
|
||||
with pytest.raises(ValueError):
|
||||
ChatOpenAI(model_kwargs={"model": "text-davinci-003"})
|
||||
|
@ -34,6 +34,14 @@ def test_openai_extra_kwargs() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
OpenAI(foo=3, model_kwargs={"foo": 2})
|
||||
|
||||
# Test that if explicit param is specified in kwargs it errors
|
||||
with pytest.raises(ValueError):
|
||||
OpenAI(model_kwargs={"temperature": 0.2})
|
||||
|
||||
# Test that "model" cannot be specified in kwargs
|
||||
with pytest.raises(ValueError):
|
||||
OpenAI(model_kwargs={"model": "text-davinci-003"})
|
||||
|
||||
|
||||
def test_openai_llm_output_contains_model_name() -> None:
|
||||
"""Test llm_output contains model_name."""
|
||||
|
Loading…
Reference in New Issue
Block a user