anthropic[patch]: expose payload (#23291)

![Screenshot 2024-06-21 at 4 56 02 PM](https://github.com/langchain-ai/langchain/assets/22008038/a2c6224f-3741-4502-9607-1a726a0551c9)
Bagatur committed 2024-07-02 17:43:47 -04:00 (committed by GitHub)
parent 7a3d8e5a99 · commit ed200bf2c4


```diff
@@ -636,31 +636,28 @@ class ChatAnthropic(BaseChatModel):
         values["_async_client"] = anthropic.AsyncClient(**client_params)
         return values
 
-    def _format_params(
+    def _get_request_payload(
         self,
+        input_: LanguageModelInput,
         *,
-        messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         **kwargs: Dict,
     ) -> Dict:
-        # get system prompt if any
+        messages = self._convert_input(input_).to_messages()
         system, formatted_messages = _format_messages(messages)
-        stop_sequences = stop or self.stop_sequences
-        rtn = {
+        payload = {
             "model": self.model,
             "max_tokens": self.max_tokens,
             "messages": formatted_messages,
             "temperature": self.temperature,
             "top_k": self.top_k,
             "top_p": self.top_p,
-            "stop_sequences": stop_sequences,
+            "stop_sequences": stop or self.stop_sequences,
             "system": system,
             **self.model_kwargs,
             **kwargs,
         }
-        rtn = {k: v for k, v in rtn.items() if v is not None}
-        return rtn
+        return {k: v for k, v in payload.items() if v is not None}
 
     def _stream(
         self,
```
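The net effect of this hunk is that the full request body is now built in one inspectable place: `_get_request_payload` accepts any `LanguageModelInput` (a string, a list of `BaseMessage`s, or a `PromptValue`) and returns the exact dict handed to the Anthropic client. A minimal sketch of how that can be used, assuming `langchain-anthropic` at this commit and an `ANTHROPIC_API_KEY` in the environment; the model name and prompt are arbitrary examples:

```python
from langchain_anthropic import ChatAnthropic

llm = ChatAnthropic(model="claude-3-5-sonnet-20240620")

# _get_request_payload returns the dict that will be passed to the
# Anthropic client's messages.create(...), with None values filtered out.
payload = llm._get_request_payload("Hello!", stop=["\n"])

print(payload["model"])           # claude-3-5-sonnet-20240620
print(payload["messages"])        # [{'role': 'user', 'content': 'Hello!'}]
print(payload["stop_sequences"])  # ['\n']
```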
```diff
@@ -673,9 +670,10 @@ class ChatAnthropic(BaseChatModel):
     ) -> Iterator[ChatGenerationChunk]:
         if stream_usage is None:
             stream_usage = self.stream_usage
-        params = self._format_params(messages=messages, stop=stop, **kwargs)
-        stream = self._client.messages.create(**params, stream=True)
-        coerce_content_to_string = not _tools_in_params(params)
+        kwargs["stream"] = True
+        payload = self._get_request_payload(messages, stop=stop, **kwargs)
+        stream = self._client.messages.create(**payload)
+        coerce_content_to_string = not _tools_in_params(payload)
         for event in stream:
             msg = _make_message_chunk_from_anthropic_event(
                 event,
@@ -699,9 +697,10 @@ class ChatAnthropic(BaseChatModel):
    ) -> AsyncIterator[ChatGenerationChunk]:
         if stream_usage is None:
             stream_usage = self.stream_usage
-        params = self._format_params(messages=messages, stop=stop, **kwargs)
-        stream = await self._async_client.messages.create(**params, stream=True)
-        coerce_content_to_string = not _tools_in_params(params)
+        kwargs["stream"] = True
+        payload = self._get_request_payload(messages, stop=stop, **kwargs)
+        stream = await self._async_client.messages.create(**payload)
+        coerce_content_to_string = not _tools_in_params(payload)
         async for event in stream:
             msg = _make_message_chunk_from_anthropic_event(
                 event,
```
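In both streaming paths, `stream=True` is now threaded through `kwargs` into the payload rather than passed as an extra argument to `messages.create`, so `_tools_in_params` inspects exactly the dict that goes over the wire. An illustrative check, using the same hypothetical setup as the sketch above:

```python
from langchain_anthropic import ChatAnthropic

llm = ChatAnthropic(model="claude-3-5-sonnet-20240620")

# The stream flag now rides inside the payload itself, so the dict given
# to messages.create(**payload) is self-describing.
payload = llm._get_request_payload("Hello!", stream=True)
assert payload["stream"] is True  # previously passed separately as stream=True
```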
```diff
@@ -748,13 +747,13 @@ class ChatAnthropic(BaseChatModel):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        params = self._format_params(messages=messages, stop=stop, **kwargs)
         if self.streaming:
             stream_iter = self._stream(
                 messages, stop=stop, run_manager=run_manager, **kwargs
             )
             return generate_from_stream(stream_iter)
-        data = self._client.messages.create(**params)
+        payload = self._get_request_payload(messages, stop=stop, **kwargs)
+        data = self._client.messages.create(**payload)
         return self._format_output(data, **kwargs)
 
     async def _agenerate(
@@ -764,13 +763,13 @@ class ChatAnthropic(BaseChatModel):
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        params = self._format_params(messages=messages, stop=stop, **kwargs)
         if self.streaming:
             stream_iter = self._astream(
                 messages, stop=stop, run_manager=run_manager, **kwargs
             )
             return await agenerate_from_stream(stream_iter)
-        data = await self._async_client.messages.create(**params)
+        payload = self._get_request_payload(messages, stop=stop, **kwargs)
+        data = await self._async_client.messages.create(**payload)
         return self._format_output(data, **kwargs)
 
     def bind_tools(
```
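In `_generate` and `_agenerate`, payload construction also moves below the streaming short-circuit, so no request dict is built when the call is delegated to `_stream`/`_astream`. Since every call path (sync, async, streaming, non-streaming) now funnels through `_get_request_payload`, overriding that one method is enough to observe or tweak every outgoing request. A hedged sketch of that pattern; `LoggingChatAnthropic` is a hypothetical name, not part of this PR:

```python
import json

from langchain_anthropic import ChatAnthropic

class LoggingChatAnthropic(ChatAnthropic):
    """Hypothetical subclass that logs every request body before sending."""

    def _get_request_payload(self, input_, *, stop=None, **kwargs):
        # All sync/async, streaming/non-streaming calls pass through here.
        payload = super()._get_request_payload(input_, stop=stop, **kwargs)
        print(json.dumps(payload, indent=2, default=str))
        return payload
```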