mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-17 04:52:00 +00:00
anthropic[patch]: expose payload (#23291)

This commit is contained in:
parent
7a3d8e5a99
commit
ed200bf2c4
@ -636,31 +636,28 @@ class ChatAnthropic(BaseChatModel):
|
|||||||
values["_async_client"] = anthropic.AsyncClient(**client_params)
|
values["_async_client"] = anthropic.AsyncClient(**client_params)
|
||||||
return values
|
return values
|
||||||
|
|
||||||
def _format_params(
|
def _get_request_payload(
|
||||||
self,
|
self,
|
||||||
|
input_: LanguageModelInput,
|
||||||
*,
|
*,
|
||||||
messages: List[BaseMessage],
|
|
||||||
stop: Optional[List[str]] = None,
|
stop: Optional[List[str]] = None,
|
||||||
**kwargs: Dict,
|
**kwargs: Dict,
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
# get system prompt if any
|
messages = self._convert_input(input_).to_messages()
|
||||||
system, formatted_messages = _format_messages(messages)
|
system, formatted_messages = _format_messages(messages)
|
||||||
stop_sequences = stop or self.stop_sequences
|
payload = {
|
||||||
rtn = {
|
|
||||||
"model": self.model,
|
"model": self.model,
|
||||||
"max_tokens": self.max_tokens,
|
"max_tokens": self.max_tokens,
|
||||||
"messages": formatted_messages,
|
"messages": formatted_messages,
|
||||||
"temperature": self.temperature,
|
"temperature": self.temperature,
|
||||||
"top_k": self.top_k,
|
"top_k": self.top_k,
|
||||||
"top_p": self.top_p,
|
"top_p": self.top_p,
|
||||||
"stop_sequences": stop_sequences,
|
"stop_sequences": stop or self.stop_sequences,
|
||||||
"system": system,
|
"system": system,
|
||||||
**self.model_kwargs,
|
**self.model_kwargs,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
}
|
}
|
||||||
rtn = {k: v for k, v in rtn.items() if v is not None}
|
return {k: v for k, v in payload.items() if v is not None}
|
||||||
|
|
||||||
return rtn
|
|
||||||
|
|
||||||
def _stream(
|
def _stream(
|
||||||
self,
|
self,
|
||||||
@ -673,9 +670,10 @@ class ChatAnthropic(BaseChatModel):
|
|||||||
) -> Iterator[ChatGenerationChunk]:
|
) -> Iterator[ChatGenerationChunk]:
|
||||||
if stream_usage is None:
|
if stream_usage is None:
|
||||||
stream_usage = self.stream_usage
|
stream_usage = self.stream_usage
|
||||||
params = self._format_params(messages=messages, stop=stop, **kwargs)
|
kwargs["stream"] = True
|
||||||
stream = self._client.messages.create(**params, stream=True)
|
payload = self._get_request_payload(messages, stop=stop, **kwargs)
|
||||||
coerce_content_to_string = not _tools_in_params(params)
|
stream = self._client.messages.create(**payload)
|
||||||
|
coerce_content_to_string = not _tools_in_params(payload)
|
||||||
for event in stream:
|
for event in stream:
|
||||||
msg = _make_message_chunk_from_anthropic_event(
|
msg = _make_message_chunk_from_anthropic_event(
|
||||||
event,
|
event,
|
||||||
@ -699,9 +697,10 @@ class ChatAnthropic(BaseChatModel):
|
|||||||
) -> AsyncIterator[ChatGenerationChunk]:
|
) -> AsyncIterator[ChatGenerationChunk]:
|
||||||
if stream_usage is None:
|
if stream_usage is None:
|
||||||
stream_usage = self.stream_usage
|
stream_usage = self.stream_usage
|
||||||
params = self._format_params(messages=messages, stop=stop, **kwargs)
|
kwargs["stream"] = True
|
||||||
stream = await self._async_client.messages.create(**params, stream=True)
|
payload = self._get_request_payload(messages, stop=stop, **kwargs)
|
||||||
coerce_content_to_string = not _tools_in_params(params)
|
stream = await self._async_client.messages.create(**payload)
|
||||||
|
coerce_content_to_string = not _tools_in_params(payload)
|
||||||
async for event in stream:
|
async for event in stream:
|
||||||
msg = _make_message_chunk_from_anthropic_event(
|
msg = _make_message_chunk_from_anthropic_event(
|
||||||
event,
|
event,
|
||||||
@ -748,13 +747,13 @@ class ChatAnthropic(BaseChatModel):
|
|||||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> ChatResult:
|
) -> ChatResult:
|
||||||
params = self._format_params(messages=messages, stop=stop, **kwargs)
|
|
||||||
if self.streaming:
|
if self.streaming:
|
||||||
stream_iter = self._stream(
|
stream_iter = self._stream(
|
||||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||||
)
|
)
|
||||||
return generate_from_stream(stream_iter)
|
return generate_from_stream(stream_iter)
|
||||||
data = self._client.messages.create(**params)
|
payload = self._get_request_payload(messages, stop=stop, **kwargs)
|
||||||
|
data = self._client.messages.create(**payload)
|
||||||
return self._format_output(data, **kwargs)
|
return self._format_output(data, **kwargs)
|
||||||
|
|
||||||
async def _agenerate(
|
async def _agenerate(
|
||||||
@ -764,13 +763,13 @@ class ChatAnthropic(BaseChatModel):
|
|||||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> ChatResult:
|
) -> ChatResult:
|
||||||
params = self._format_params(messages=messages, stop=stop, **kwargs)
|
|
||||||
if self.streaming:
|
if self.streaming:
|
||||||
stream_iter = self._astream(
|
stream_iter = self._astream(
|
||||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||||
)
|
)
|
||||||
return await agenerate_from_stream(stream_iter)
|
return await agenerate_from_stream(stream_iter)
|
||||||
data = await self._async_client.messages.create(**params)
|
payload = self._get_request_payload(messages, stop=stop, **kwargs)
|
||||||
|
data = await self._async_client.messages.create(**payload)
|
||||||
return self._format_output(data, **kwargs)
|
return self._format_output(data, **kwargs)
|
||||||
|
|
||||||
def bind_tools(
|
def bind_tools(
|
||||||
|
Loading…
Reference in New Issue
Block a user