openai[patch]: expose model request payload (#23287)

![Screenshot 2024-06-21 at 3 12 12
PM](https://github.com/langchain-ai/langchain/assets/22008038/6243a01f-1ef6-4085-9160-2844d9f2b683)
This commit is contained in:
Bagatur 2024-07-02 17:43:55 -04:00 committed by GitHub
parent ed200bf2c4
commit cb9812593f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -481,11 +481,10 @@ class BaseChatOpenAI(BaseChatModel):
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True}
kwargs["stream"] = True
payload = self._get_request_payload(messages, stop=stop, **kwargs)
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
with self.client.create(messages=message_dicts, **params) as response:
with self.client.create(**payload) as response:
for chunk in response:
if not isinstance(chunk, dict):
chunk = chunk.model_dump()
@ -544,19 +543,25 @@ class BaseChatOpenAI(BaseChatModel):
messages, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
response = self.client.create(messages=message_dicts, **params)
payload = self._get_request_payload(messages, stop=stop, **kwargs)
response = self.client.create(**payload)
return self._create_chat_result(response)
def _create_message_dicts(
self, messages: List[BaseMessage], stop: Optional[List[str]]
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
params = self._default_params
def _get_request_payload(
self,
input_: LanguageModelInput,
*,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> dict:
messages = self._convert_input(input_).to_messages()
if stop is not None:
params["stop"] = stop
message_dicts = [_convert_message_to_dict(m) for m in messages]
return message_dicts, params
kwargs["stop"] = stop
return {
"messages": [_convert_message_to_dict(m) for m in messages],
**self._default_params,
**kwargs,
}
def _create_chat_result(
self, response: Union[dict, openai.BaseModel]
@ -600,11 +605,10 @@ class BaseChatOpenAI(BaseChatModel):
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True}
kwargs["stream"] = True
payload = self._get_request_payload(messages, stop=stop, **kwargs)
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
response = await self.async_client.create(messages=message_dicts, **params)
response = await self.async_client.create(**payload)
async with response:
async for chunk in response:
if not isinstance(chunk, dict):
@ -666,10 +670,8 @@ class BaseChatOpenAI(BaseChatModel):
messages, stop=stop, run_manager=run_manager, **kwargs
)
return await agenerate_from_stream(stream_iter)
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
response = await self.async_client.create(messages=message_dicts, **params)
payload = self._get_request_payload(messages, stop=stop, **kwargs)
response = await self.async_client.create(**payload)
return self._create_chat_result(response)
@property