openai[patch]: expose model request payload (#23287)

![Screenshot 2024-06-21 at 3 12 12
PM](https://github.com/langchain-ai/langchain/assets/22008038/6243a01f-1ef6-4085-9160-2844d9f2b683)
This commit is contained in:
Bagatur 2024-07-02 17:43:55 -04:00 committed by GitHub
parent ed200bf2c4
commit cb9812593f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -481,11 +481,10 @@ class BaseChatOpenAI(BaseChatModel):
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> Iterator[ChatGenerationChunk]: ) -> Iterator[ChatGenerationChunk]:
message_dicts, params = self._create_message_dicts(messages, stop) kwargs["stream"] = True
params = {**params, **kwargs, "stream": True} payload = self._get_request_payload(messages, stop=stop, **kwargs)
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
with self.client.create(messages=message_dicts, **params) as response: with self.client.create(**payload) as response:
for chunk in response: for chunk in response:
if not isinstance(chunk, dict): if not isinstance(chunk, dict):
chunk = chunk.model_dump() chunk = chunk.model_dump()
@ -544,19 +543,25 @@ class BaseChatOpenAI(BaseChatModel):
messages, stop=stop, run_manager=run_manager, **kwargs messages, stop=stop, run_manager=run_manager, **kwargs
) )
return generate_from_stream(stream_iter) return generate_from_stream(stream_iter)
message_dicts, params = self._create_message_dicts(messages, stop) payload = self._get_request_payload(messages, stop=stop, **kwargs)
params = {**params, **kwargs} response = self.client.create(**payload)
response = self.client.create(messages=message_dicts, **params)
return self._create_chat_result(response) return self._create_chat_result(response)
def _create_message_dicts( def _get_request_payload(
self, messages: List[BaseMessage], stop: Optional[List[str]] self,
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: input_: LanguageModelInput,
params = self._default_params *,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> dict:
messages = self._convert_input(input_).to_messages()
if stop is not None: if stop is not None:
params["stop"] = stop kwargs["stop"] = stop
message_dicts = [_convert_message_to_dict(m) for m in messages] return {
return message_dicts, params "messages": [_convert_message_to_dict(m) for m in messages],
**self._default_params,
**kwargs,
}
def _create_chat_result( def _create_chat_result(
self, response: Union[dict, openai.BaseModel] self, response: Union[dict, openai.BaseModel]
@ -600,11 +605,10 @@ class BaseChatOpenAI(BaseChatModel):
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]: ) -> AsyncIterator[ChatGenerationChunk]:
message_dicts, params = self._create_message_dicts(messages, stop) kwargs["stream"] = True
params = {**params, **kwargs, "stream": True} payload = self._get_request_payload(messages, stop=stop, **kwargs)
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
response = await self.async_client.create(messages=message_dicts, **params) response = await self.async_client.create(**payload)
async with response: async with response:
async for chunk in response: async for chunk in response:
if not isinstance(chunk, dict): if not isinstance(chunk, dict):
@ -666,10 +670,8 @@ class BaseChatOpenAI(BaseChatModel):
messages, stop=stop, run_manager=run_manager, **kwargs messages, stop=stop, run_manager=run_manager, **kwargs
) )
return await agenerate_from_stream(stream_iter) return await agenerate_from_stream(stream_iter)
payload = self._get_request_payload(messages, stop=stop, **kwargs)
message_dicts, params = self._create_message_dicts(messages, stop) response = await self.async_client.create(**payload)
params = {**params, **kwargs}
response = await self.async_client.create(messages=message_dicts, **params)
return self._create_chat_result(response) return self._create_chat_result(response)
@property @property