Mirror of https://github.com/hwchase17/langchain.git, synced 2025-06-05 06:33:20 +00:00
openai[patch]: code cleaning (#17355)
h/t @tdene for finding the cleanup opportunity in #17047
parent a9d6da609a
commit 42648061ad
@@ -1,10 +1,10 @@
 """OpenAI chat wrapper."""
 
 from __future__ import annotations
 
 import logging
 import os
 import sys
-import warnings
 from typing import (
     Any,
     AsyncIterator,
@@ -64,6 +64,7 @@ from langchain_core.utils.function_calling import (
     convert_to_openai_function,
     convert_to_openai_tool,
 )
+from langchain_core.utils.utils import build_extra_kwargs
 
 logger = logging.getLogger(__name__)
 
@@ -290,25 +291,9 @@ class ChatOpenAI(BaseChatModel):
         """Build extra kwargs from additional params that were passed in."""
         all_required_field_names = get_pydantic_field_names(cls)
         extra = values.get("model_kwargs", {})
-        for field_name in list(values):
-            if field_name in extra:
-                raise ValueError(f"Found {field_name} supplied twice.")
-            if field_name not in all_required_field_names:
-                warnings.warn(
-                    f"""WARNING! {field_name} is not default parameter.
-                    {field_name} was transferred to model_kwargs.
-                    Please confirm that {field_name} is what you intended."""
-                )
-                extra[field_name] = values.pop(field_name)
-
-        invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
-        if invalid_model_kwargs:
-            raise ValueError(
-                f"Parameters {invalid_model_kwargs} should be specified explicitly. "
-                f"Instead they were passed in as part of `model_kwargs` parameter."
-            )
-
-        values["model_kwargs"] = extra
+        values["model_kwargs"] = build_extra_kwargs(
+            extra, values, all_required_field_names
+        )
         return values
 
     @root_validator()
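The deleted loop is what `build_extra_kwargs` now centralizes in langchain_core, so the validator shrinks to a single call. A sketch of the behavior being delegated, reconstructed from the removed lines (not necessarily the exact `langchain_core.utils.utils` implementation):

import warnings
from typing import Any, Dict, Set


def build_extra_kwargs_sketch(
    extra: Dict[str, Any],
    values: Dict[str, Any],
    all_required_field_names: Set[str],
) -> Dict[str, Any]:
    """Move unrecognized params into model_kwargs, as the deleted loop did."""
    for field_name in list(values):
        if field_name in extra:
            raise ValueError(f"Found {field_name} supplied twice.")
        if field_name not in all_required_field_names:
            warnings.warn(
                f"WARNING! {field_name} is not default parameter. "
                f"{field_name} was transferred to model_kwargs. "
                f"Please confirm that {field_name} is what you intended."
            )
            extra[field_name] = values.pop(field_name)

    invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
    if invalid_model_kwargs:
        raise ValueError(
            f"Parameters {invalid_model_kwargs} should be specified explicitly. "
            f"Instead they were passed in as part of `model_kwargs` parameter."
        )
    return extra

For example, build_extra_kwargs_sketch({}, {"model_name": "x", "echo": True}, {"model_name"}) warns about `echo` and returns {"echo": True}, which the caller stores back into values["model_kwargs"].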
@@ -339,9 +324,11 @@ class ChatOpenAI(BaseChatModel):
         )
 
         client_params = {
-            "api_key": values["openai_api_key"].get_secret_value()
-            if values["openai_api_key"]
-            else None,
+            "api_key": (
+                values["openai_api_key"].get_secret_value()
+                if values["openai_api_key"]
+                else None
+            ),
             "organization": values["openai_organization"],
             "base_url": values["openai_api_base"],
             "timeout": values["request_timeout"],
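This hunk (and the matching one in BaseOpenAI just below) is a pure formatting change: the conditional expression is wrapped in parentheses so its extent inside the dict literal is unambiguous. A minimal illustration with a hypothetical stand-in for the SecretStr value, showing the two spellings are equivalent at runtime:

api_key = None  # stand-in for values["openai_api_key"], an Optional[SecretStr]

# Old layout: the ternary hangs off the dict value unparenthesized.
old = {
    "api_key": api_key.get_secret_value()
    if api_key
    else None,
}

# New layout: parentheses group the whole conditional expression.
new = {
    "api_key": (api_key.get_secret_value() if api_key else None),
}

assert old == new  # both are {"api_key": None}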
@@ -198,9 +198,11 @@ class BaseOpenAI(BaseLLM):
         )
 
         client_params = {
-            "api_key": values["openai_api_key"].get_secret_value()
-            if values["openai_api_key"]
-            else None,
+            "api_key": (
+                values["openai_api_key"].get_secret_value()
+                if values["openai_api_key"]
+                else None
+            ),
             "organization": values["openai_organization"],
             "base_url": values["openai_api_base"],
             "timeout": values["request_timeout"],
@@ -257,9 +259,11 @@ class BaseOpenAI(BaseLLM):
                     chunk.text,
                     chunk=chunk,
                     verbose=self.verbose,
-                    logprobs=chunk.generation_info["logprobs"]
-                    if chunk.generation_info
-                    else None,
+                    logprobs=(
+                        chunk.generation_info["logprobs"]
+                        if chunk.generation_info
+                        else None
+                    ),
                 )
 
     async def _astream(
@@ -283,9 +287,11 @@ class BaseOpenAI(BaseLLM):
                     chunk.text,
                     chunk=chunk,
                     verbose=self.verbose,
-                    logprobs=chunk.generation_info["logprobs"]
-                    if chunk.generation_info
-                    else None,
+                    logprobs=(
+                        chunk.generation_info["logprobs"]
+                        if chunk.generation_info
+                        else None
+                    ),
                 )
 
     def _generate(
@@ -334,12 +340,16 @@ class BaseOpenAI(BaseLLM):
                 choices.append(
                     {
                         "text": generation.text,
-                        "finish_reason": generation.generation_info.get("finish_reason")
-                        if generation.generation_info
-                        else None,
-                        "logprobs": generation.generation_info.get("logprobs")
-                        if generation.generation_info
-                        else None,
+                        "finish_reason": (
+                            generation.generation_info.get("finish_reason")
+                            if generation.generation_info
+                            else None
+                        ),
+                        "logprobs": (
+                            generation.generation_info.get("logprobs")
+                            if generation.generation_info
+                            else None
+                        ),
                     }
                 )
             else:
@@ -395,12 +405,16 @@ class BaseOpenAI(BaseLLM):
                 choices.append(
                     {
                         "text": generation.text,
-                        "finish_reason": generation.generation_info.get("finish_reason")
-                        if generation.generation_info
-                        else None,
-                        "logprobs": generation.generation_info.get("logprobs")
-                        if generation.generation_info
-                        else None,
+                        "finish_reason": (
+                            generation.generation_info.get("finish_reason")
+                            if generation.generation_info
+                            else None
+                        ),
+                        "logprobs": (
+                            generation.generation_info.get("logprobs")
+                            if generation.generation_info
+                            else None
+                        ),
                     }
                 )
             else:
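The four streaming/generation hunks above all reformat the same pattern: `generation_info` is optional, so each lookup is guarded by a conditional expression, and the cleanup only adds grouping parentheses without changing behavior. A self-contained sketch of that guarded-access pattern (names are illustrative, not taken from the diff):

from typing import Any, Dict, Optional


def choice_from_generation(
    text: str, generation_info: Optional[Dict[str, Any]]
) -> Dict[str, Any]:
    # generation_info may be None, so guard each lookup before calling .get().
    return {
        "text": text,
        "finish_reason": (
            generation_info.get("finish_reason") if generation_info else None
        ),
        "logprobs": (
            generation_info.get("logprobs") if generation_info else None
        ),
    }


print(choice_from_generation("hi", None))
# -> {'text': 'hi', 'finish_reason': None, 'logprobs': None}
print(choice_from_generation("hi", {"finish_reason": "stop"}))
# -> {'text': 'hi', 'finish_reason': 'stop', 'logprobs': None}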