From 42648061ad30c28afbe7cc0533e92f365ea9a07f Mon Sep 17 00:00:00 2001
From: Erick Friis
Date: Mon, 12 Feb 2024 12:36:12 -0800
Subject: [PATCH] openai[patch]: code cleaning (#17355)

h/t @tdene for finding cleanup op in #17047
---
 .../langchain_openai/chat_models/base.py | 33 ++++-------
 .../openai/langchain_openai/llms/base.py | 56 ++++++++++++-------
 2 files changed, 45 insertions(+), 44 deletions(-)

diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py
index fc1430e2425..57e265a6ef4 100644
--- a/libs/partners/openai/langchain_openai/chat_models/base.py
+++ b/libs/partners/openai/langchain_openai/chat_models/base.py
@@ -1,10 +1,10 @@
 """OpenAI chat wrapper."""
+
 from __future__ import annotations
 
 import logging
 import os
 import sys
-import warnings
 from typing import (
     Any,
     AsyncIterator,
@@ -64,6 +64,7 @@ from langchain_core.utils.function_calling import (
     convert_to_openai_function,
     convert_to_openai_tool,
 )
+from langchain_core.utils.utils import build_extra_kwargs
 
 logger = logging.getLogger(__name__)
 
@@ -290,25 +291,9 @@ class ChatOpenAI(BaseChatModel):
         """Build extra kwargs from additional params that were passed in."""
         all_required_field_names = get_pydantic_field_names(cls)
         extra = values.get("model_kwargs", {})
-        for field_name in list(values):
-            if field_name in extra:
-                raise ValueError(f"Found {field_name} supplied twice.")
-            if field_name not in all_required_field_names:
-                warnings.warn(
-                    f"""WARNING! {field_name} is not default parameter.
-                    {field_name} was transferred to model_kwargs.
-                    Please confirm that {field_name} is what you intended."""
-                )
-                extra[field_name] = values.pop(field_name)
-
-        invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
-        if invalid_model_kwargs:
-            raise ValueError(
-                f"Parameters {invalid_model_kwargs} should be specified explicitly. "
-                f"Instead they were passed in as part of `model_kwargs` parameter."
-            )
-
-        values["model_kwargs"] = extra
+        values["model_kwargs"] = build_extra_kwargs(
+            extra, values, all_required_field_names
+        )
         return values
 
     @root_validator()
@@ -339,9 +324,11 @@ class ChatOpenAI(BaseChatModel):
         )
 
         client_params = {
-            "api_key": values["openai_api_key"].get_secret_value()
-            if values["openai_api_key"]
-            else None,
+            "api_key": (
+                values["openai_api_key"].get_secret_value()
+                if values["openai_api_key"]
+                else None
+            ),
             "organization": values["openai_organization"],
             "base_url": values["openai_api_base"],
             "timeout": values["request_timeout"],
diff --git a/libs/partners/openai/langchain_openai/llms/base.py b/libs/partners/openai/langchain_openai/llms/base.py
index 9b0ba87d1f4..a298b048d0e 100644
--- a/libs/partners/openai/langchain_openai/llms/base.py
+++ b/libs/partners/openai/langchain_openai/llms/base.py
@@ -198,9 +198,11 @@ class BaseOpenAI(BaseLLM):
         )
 
         client_params = {
-            "api_key": values["openai_api_key"].get_secret_value()
-            if values["openai_api_key"]
-            else None,
+            "api_key": (
+                values["openai_api_key"].get_secret_value()
+                if values["openai_api_key"]
+                else None
+            ),
             "organization": values["openai_organization"],
             "base_url": values["openai_api_base"],
             "timeout": values["request_timeout"],
@@ -257,9 +259,11 @@
                 chunk.text,
                 chunk=chunk,
                 verbose=self.verbose,
-                logprobs=chunk.generation_info["logprobs"]
-                if chunk.generation_info
-                else None,
+                logprobs=(
+                    chunk.generation_info["logprobs"]
+                    if chunk.generation_info
+                    else None
+                ),
             )
 
     async def _astream(
@@ -283,9 +287,11 @@
                 chunk.text,
                 chunk=chunk,
                 verbose=self.verbose,
-                logprobs=chunk.generation_info["logprobs"]
-                if chunk.generation_info
-                else None,
+                logprobs=(
+                    chunk.generation_info["logprobs"]
+                    if chunk.generation_info
+                    else None
+                ),
             )
 
     def _generate(
@@ -334,12 +340,16 @@
                 choices.append(
                     {
                         "text": generation.text,
-                        "finish_reason": generation.generation_info.get("finish_reason")
-                        if generation.generation_info
-                        else None,
-                        "logprobs": generation.generation_info.get("logprobs")
-                        if generation.generation_info
-                        else None,
+                        "finish_reason": (
+                            generation.generation_info.get("finish_reason")
+                            if generation.generation_info
+                            else None
+                        ),
+                        "logprobs": (
+                            generation.generation_info.get("logprobs")
+                            if generation.generation_info
+                            else None
+                        ),
                     }
                 )
             else:
@@ -395,12 +405,16 @@
                 choices.append(
                     {
                         "text": generation.text,
-                        "finish_reason": generation.generation_info.get("finish_reason")
-                        if generation.generation_info
-                        else None,
-                        "logprobs": generation.generation_info.get("logprobs")
-                        if generation.generation_info
-                        else None,
+                        "finish_reason": (
+                            generation.generation_info.get("finish_reason")
+                            if generation.generation_info
+                            else None
+                        ),
+                        "logprobs": (
+                            generation.generation_info.get("logprobs")
+                            if generation.generation_info
+                            else None
+                        ),
                     }
                 )
             else:
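
For reviewers: the inline validation deleted from ChatOpenAI.build_extra is now
delegated to the shared build_extra_kwargs helper imported above from
langchain_core.utils.utils. A minimal sketch of the helper's behavior,
reconstructed from the inlined code this patch removes (the canonical
implementation lives in langchain_core and may differ in detail):

    import warnings
    from typing import Any, Dict, Set

    def build_extra_kwargs(
        extra_kwargs: Dict[str, Any],
        values: Dict[str, Any],
        all_required_field_names: Set[str],
    ) -> Dict[str, Any]:
        """Move unknown top-level params into extra_kwargs, rejecting duplicates."""
        for field_name in list(values):
            if field_name in extra_kwargs:
                raise ValueError(f"Found {field_name} supplied twice.")
            if field_name not in all_required_field_names:
                warnings.warn(
                    f"WARNING! {field_name} is not default parameter. "
                    f"{field_name} was transferred to model_kwargs. "
                    f"Please confirm that {field_name} is what you intended."
                )
                extra_kwargs[field_name] = values.pop(field_name)

        # Explicit model fields may not be smuggled in through model_kwargs.
        invalid_model_kwargs = all_required_field_names.intersection(extra_kwargs)
        if invalid_model_kwargs:
            raise ValueError(
                f"Parameters {invalid_model_kwargs} should be specified explicitly. "
                f"Instead they were passed in as part of `model_kwargs` parameter."
            )
        return extra_kwargs

    # Example: an unknown "foo" kwarg is moved into model_kwargs with a warning.
    # build_extra_kwargs({}, {"model": "gpt-4", "foo": 1}, {"model"}) == {"foo": 1}

Centralizing this in langchain_core keeps the ChatOpenAI and BaseOpenAI
validators from drifting apart: both now apply identical duplicate- and
unknown-parameter handling. The remaining hunks are purely cosmetic, wrapping
multi-line conditional expressions in parentheses for the formatter.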