mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-16 20:42:01 +00:00
anthropic[patch]: expose payload (#23291)

This commit is contained in:
parent
7a3d8e5a99
commit
ed200bf2c4
@ -636,31 +636,28 @@ class ChatAnthropic(BaseChatModel):
|
||||
values["_async_client"] = anthropic.AsyncClient(**client_params)
|
||||
return values
|
||||
|
||||
def _format_params(
|
||||
def _get_request_payload(
|
||||
self,
|
||||
input_: LanguageModelInput,
|
||||
*,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
**kwargs: Dict,
|
||||
) -> Dict:
|
||||
# get system prompt if any
|
||||
messages = self._convert_input(input_).to_messages()
|
||||
system, formatted_messages = _format_messages(messages)
|
||||
stop_sequences = stop or self.stop_sequences
|
||||
rtn = {
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"max_tokens": self.max_tokens,
|
||||
"messages": formatted_messages,
|
||||
"temperature": self.temperature,
|
||||
"top_k": self.top_k,
|
||||
"top_p": self.top_p,
|
||||
"stop_sequences": stop_sequences,
|
||||
"stop_sequences": stop or self.stop_sequences,
|
||||
"system": system,
|
||||
**self.model_kwargs,
|
||||
**kwargs,
|
||||
}
|
||||
rtn = {k: v for k, v in rtn.items() if v is not None}
|
||||
|
||||
return rtn
|
||||
return {k: v for k, v in payload.items() if v is not None}
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
@ -673,9 +670,10 @@ class ChatAnthropic(BaseChatModel):
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
if stream_usage is None:
|
||||
stream_usage = self.stream_usage
|
||||
params = self._format_params(messages=messages, stop=stop, **kwargs)
|
||||
stream = self._client.messages.create(**params, stream=True)
|
||||
coerce_content_to_string = not _tools_in_params(params)
|
||||
kwargs["stream"] = True
|
||||
payload = self._get_request_payload(messages, stop=stop, **kwargs)
|
||||
stream = self._client.messages.create(**payload)
|
||||
coerce_content_to_string = not _tools_in_params(payload)
|
||||
for event in stream:
|
||||
msg = _make_message_chunk_from_anthropic_event(
|
||||
event,
|
||||
@ -699,9 +697,10 @@ class ChatAnthropic(BaseChatModel):
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
if stream_usage is None:
|
||||
stream_usage = self.stream_usage
|
||||
params = self._format_params(messages=messages, stop=stop, **kwargs)
|
||||
stream = await self._async_client.messages.create(**params, stream=True)
|
||||
coerce_content_to_string = not _tools_in_params(params)
|
||||
kwargs["stream"] = True
|
||||
payload = self._get_request_payload(messages, stop=stop, **kwargs)
|
||||
stream = await self._async_client.messages.create(**payload)
|
||||
coerce_content_to_string = not _tools_in_params(payload)
|
||||
async for event in stream:
|
||||
msg = _make_message_chunk_from_anthropic_event(
|
||||
event,
|
||||
@ -748,13 +747,13 @@ class ChatAnthropic(BaseChatModel):
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
params = self._format_params(messages=messages, stop=stop, **kwargs)
|
||||
if self.streaming:
|
||||
stream_iter = self._stream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return generate_from_stream(stream_iter)
|
||||
data = self._client.messages.create(**params)
|
||||
payload = self._get_request_payload(messages, stop=stop, **kwargs)
|
||||
data = self._client.messages.create(**payload)
|
||||
return self._format_output(data, **kwargs)
|
||||
|
||||
async def _agenerate(
|
||||
@ -764,13 +763,13 @@ class ChatAnthropic(BaseChatModel):
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
params = self._format_params(messages=messages, stop=stop, **kwargs)
|
||||
if self.streaming:
|
||||
stream_iter = self._astream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return await agenerate_from_stream(stream_iter)
|
||||
data = await self._async_client.messages.create(**params)
|
||||
payload = self._get_request_payload(messages, stop=stop, **kwargs)
|
||||
data = await self._async_client.messages.create(**payload)
|
||||
return self._format_output(data, **kwargs)
|
||||
|
||||
def bind_tools(
|
||||
|
Loading…
Reference in New Issue
Block a user