anthropic[patch]: expose payload (#23291)

![Screenshot 2024-06-21 at 4 56 02
PM](https://github.com/langchain-ai/langchain/assets/22008038/a2c6224f-3741-4502-9607-1a726a0551c9)
This commit is contained in:
Bagatur 2024-07-02 17:43:47 -04:00 committed by GitHub
parent 7a3d8e5a99
commit ed200bf2c4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -636,31 +636,28 @@ class ChatAnthropic(BaseChatModel):
values["_async_client"] = anthropic.AsyncClient(**client_params)
return values
def _format_params(
def _get_request_payload(
self,
input_: LanguageModelInput,
*,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
**kwargs: Dict,
) -> Dict:
# get system prompt if any
messages = self._convert_input(input_).to_messages()
system, formatted_messages = _format_messages(messages)
stop_sequences = stop or self.stop_sequences
rtn = {
payload = {
"model": self.model,
"max_tokens": self.max_tokens,
"messages": formatted_messages,
"temperature": self.temperature,
"top_k": self.top_k,
"top_p": self.top_p,
"stop_sequences": stop_sequences,
"stop_sequences": stop or self.stop_sequences,
"system": system,
**self.model_kwargs,
**kwargs,
}
rtn = {k: v for k, v in rtn.items() if v is not None}
return rtn
return {k: v for k, v in payload.items() if v is not None}
def _stream(
self,
@ -673,9 +670,10 @@ class ChatAnthropic(BaseChatModel):
) -> Iterator[ChatGenerationChunk]:
if stream_usage is None:
stream_usage = self.stream_usage
params = self._format_params(messages=messages, stop=stop, **kwargs)
stream = self._client.messages.create(**params, stream=True)
coerce_content_to_string = not _tools_in_params(params)
kwargs["stream"] = True
payload = self._get_request_payload(messages, stop=stop, **kwargs)
stream = self._client.messages.create(**payload)
coerce_content_to_string = not _tools_in_params(payload)
for event in stream:
msg = _make_message_chunk_from_anthropic_event(
event,
@ -699,9 +697,10 @@ class ChatAnthropic(BaseChatModel):
) -> AsyncIterator[ChatGenerationChunk]:
if stream_usage is None:
stream_usage = self.stream_usage
params = self._format_params(messages=messages, stop=stop, **kwargs)
stream = await self._async_client.messages.create(**params, stream=True)
coerce_content_to_string = not _tools_in_params(params)
kwargs["stream"] = True
payload = self._get_request_payload(messages, stop=stop, **kwargs)
stream = await self._async_client.messages.create(**payload)
coerce_content_to_string = not _tools_in_params(payload)
async for event in stream:
msg = _make_message_chunk_from_anthropic_event(
event,
@ -748,13 +747,13 @@ class ChatAnthropic(BaseChatModel):
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
params = self._format_params(messages=messages, stop=stop, **kwargs)
if self.streaming:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
data = self._client.messages.create(**params)
payload = self._get_request_payload(messages, stop=stop, **kwargs)
data = self._client.messages.create(**payload)
return self._format_output(data, **kwargs)
async def _agenerate(
@ -764,13 +763,13 @@ class ChatAnthropic(BaseChatModel):
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
params = self._format_params(messages=messages, stop=stop, **kwargs)
if self.streaming:
stream_iter = self._astream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return await agenerate_from_stream(stream_iter)
data = await self._async_client.messages.create(**params)
payload = self._get_request_payload(messages, stop=stop, **kwargs)
data = await self._async_client.messages.create(**payload)
return self._format_output(data, **kwargs)
def bind_tools(