openai[patch]: code cleaning (#17355)

h/t @tdene for finding cleanup op in #17047
This commit is contained in:
Erick Friis 2024-02-12 12:36:12 -08:00 committed by GitHub
parent a9d6da609a
commit 42648061ad
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 45 additions and 44 deletions

View File

@ -1,10 +1,10 @@
"""OpenAI chat wrapper."""
from __future__ import annotations
import logging
import os
import sys
import warnings
from typing import (
Any,
AsyncIterator,
@ -64,6 +64,7 @@ from langchain_core.utils.function_calling import (
convert_to_openai_function,
convert_to_openai_tool,
)
from langchain_core.utils.utils import build_extra_kwargs
logger = logging.getLogger(__name__)
@ -290,25 +291,9 @@ class ChatOpenAI(BaseChatModel):
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
extra = values.get("model_kwargs", {})
for field_name in list(values):
if field_name in extra:
raise ValueError(f"Found {field_name} supplied twice.")
if field_name not in all_required_field_names:
warnings.warn(
f"""WARNING! {field_name} is not default parameter.
{field_name} was transferred to model_kwargs.
Please confirm that {field_name} is what you intended."""
)
extra[field_name] = values.pop(field_name)
invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
if invalid_model_kwargs:
raise ValueError(
f"Parameters {invalid_model_kwargs} should be specified explicitly. "
f"Instead they were passed in as part of `model_kwargs` parameter."
)
values["model_kwargs"] = extra
values["model_kwargs"] = build_extra_kwargs(
extra, values, all_required_field_names
)
return values
@root_validator()
@ -339,9 +324,11 @@ class ChatOpenAI(BaseChatModel):
)
client_params = {
"api_key": values["openai_api_key"].get_secret_value()
if values["openai_api_key"]
else None,
"api_key": (
values["openai_api_key"].get_secret_value()
if values["openai_api_key"]
else None
),
"organization": values["openai_organization"],
"base_url": values["openai_api_base"],
"timeout": values["request_timeout"],

View File

@ -198,9 +198,11 @@ class BaseOpenAI(BaseLLM):
)
client_params = {
"api_key": values["openai_api_key"].get_secret_value()
if values["openai_api_key"]
else None,
"api_key": (
values["openai_api_key"].get_secret_value()
if values["openai_api_key"]
else None
),
"organization": values["openai_organization"],
"base_url": values["openai_api_base"],
"timeout": values["request_timeout"],
@ -257,9 +259,11 @@ class BaseOpenAI(BaseLLM):
chunk.text,
chunk=chunk,
verbose=self.verbose,
logprobs=chunk.generation_info["logprobs"]
if chunk.generation_info
else None,
logprobs=(
chunk.generation_info["logprobs"]
if chunk.generation_info
else None
),
)
async def _astream(
@ -283,9 +287,11 @@ class BaseOpenAI(BaseLLM):
chunk.text,
chunk=chunk,
verbose=self.verbose,
logprobs=chunk.generation_info["logprobs"]
if chunk.generation_info
else None,
logprobs=(
chunk.generation_info["logprobs"]
if chunk.generation_info
else None
),
)
def _generate(
@ -334,12 +340,16 @@ class BaseOpenAI(BaseLLM):
choices.append(
{
"text": generation.text,
"finish_reason": generation.generation_info.get("finish_reason")
if generation.generation_info
else None,
"logprobs": generation.generation_info.get("logprobs")
if generation.generation_info
else None,
"finish_reason": (
generation.generation_info.get("finish_reason")
if generation.generation_info
else None
),
"logprobs": (
generation.generation_info.get("logprobs")
if generation.generation_info
else None
),
}
)
else:
@ -395,12 +405,16 @@ class BaseOpenAI(BaseLLM):
choices.append(
{
"text": generation.text,
"finish_reason": generation.generation_info.get("finish_reason")
if generation.generation_info
else None,
"logprobs": generation.generation_info.get("logprobs")
if generation.generation_info
else None,
"finish_reason": (
generation.generation_info.get("finish_reason")
if generation.generation_info
else None
),
"logprobs": (
generation.generation_info.get("logprobs")
if generation.generation_info
else None
),
}
)
else: