From ed200bf2c46c2fa43b8329e966f93899f9e12ae2 Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Tue, 2 Jul 2024 17:43:47 -0400 Subject: [PATCH] anthropic[patch]: expose payload (#23291) ![Screenshot 2024-06-21 at 4 56 02 PM](https://github.com/langchain-ai/langchain/assets/22008038/a2c6224f-3741-4502-9607-1a726a0551c9) --- .../langchain_anthropic/chat_models.py | 37 +++++++++---------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/libs/partners/anthropic/langchain_anthropic/chat_models.py b/libs/partners/anthropic/langchain_anthropic/chat_models.py index bd18af6c670..61377694be5 100644 --- a/libs/partners/anthropic/langchain_anthropic/chat_models.py +++ b/libs/partners/anthropic/langchain_anthropic/chat_models.py @@ -636,31 +636,28 @@ class ChatAnthropic(BaseChatModel): values["_async_client"] = anthropic.AsyncClient(**client_params) return values - def _format_params( + def _get_request_payload( self, + input_: LanguageModelInput, *, - messages: List[BaseMessage], stop: Optional[List[str]] = None, **kwargs: Dict, ) -> Dict: - # get system prompt if any + messages = self._convert_input(input_).to_messages() system, formatted_messages = _format_messages(messages) - stop_sequences = stop or self.stop_sequences - rtn = { + payload = { "model": self.model, "max_tokens": self.max_tokens, "messages": formatted_messages, "temperature": self.temperature, "top_k": self.top_k, "top_p": self.top_p, - "stop_sequences": stop_sequences, + "stop_sequences": stop or self.stop_sequences, "system": system, **self.model_kwargs, **kwargs, } - rtn = {k: v for k, v in rtn.items() if v is not None} - - return rtn + return {k: v for k, v in payload.items() if v is not None} def _stream( self, @@ -673,9 +670,10 @@ class ChatAnthropic(BaseChatModel): ) -> Iterator[ChatGenerationChunk]: if stream_usage is None: stream_usage = self.stream_usage - params = self._format_params(messages=messages, stop=stop, **kwargs) - stream = self._client.messages.create(**params, stream=True) - coerce_content_to_string = not _tools_in_params(params) + kwargs["stream"] = True + payload = self._get_request_payload(messages, stop=stop, **kwargs) + stream = self._client.messages.create(**payload) + coerce_content_to_string = not _tools_in_params(payload) for event in stream: msg = _make_message_chunk_from_anthropic_event( event, @@ -699,9 +697,10 @@ class ChatAnthropic(BaseChatModel): ) -> AsyncIterator[ChatGenerationChunk]: if stream_usage is None: stream_usage = self.stream_usage - params = self._format_params(messages=messages, stop=stop, **kwargs) - stream = await self._async_client.messages.create(**params, stream=True) - coerce_content_to_string = not _tools_in_params(params) + kwargs["stream"] = True + payload = self._get_request_payload(messages, stop=stop, **kwargs) + stream = await self._async_client.messages.create(**payload) + coerce_content_to_string = not _tools_in_params(payload) async for event in stream: msg = _make_message_chunk_from_anthropic_event( event, @@ -748,13 +747,13 @@ class ChatAnthropic(BaseChatModel): run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: - params = self._format_params(messages=messages, stop=stop, **kwargs) if self.streaming: stream_iter = self._stream( messages, stop=stop, run_manager=run_manager, **kwargs ) return generate_from_stream(stream_iter) - data = self._client.messages.create(**params) + payload = self._get_request_payload(messages, stop=stop, **kwargs) + data = self._client.messages.create(**payload) return self._format_output(data, **kwargs) async def _agenerate( @@ -764,13 +763,13 @@ class ChatAnthropic(BaseChatModel): run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: - params = self._format_params(messages=messages, stop=stop, **kwargs) if self.streaming: stream_iter = self._astream( messages, stop=stop, run_manager=run_manager, **kwargs ) return await agenerate_from_stream(stream_iter) - data = await self._async_client.messages.create(**params) + payload = self._get_request_payload(messages, stop=stop, **kwargs) + data = await self._async_client.messages.create(**payload) return self._format_output(data, **kwargs) def bind_tools(