Compare commits

2 Commits

Author       SHA1        Message                        Date
vowelparrot  9dc40823a3  Add Anthropic Token Counting   2023-06-06 17:10:28 -07:00
vowelparrot  036ade7ca9  Add token usage for anthropic  2023-06-06 07:33:58 -07:00
2 changed files with 33 additions and 5 deletions

langchain/chat_models/anthropic.py

@@ -89,6 +89,19 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
             text.rstrip()
         )  # trim off the trailing ' ' that might come from the "Assistant: "
 
+    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
+        overall_token_usage: dict = {}
+        for output in llm_outputs:
+            if output is None:
+                continue
+            token_usage = output["token_usage"]
+            for k, v in token_usage.items():
+                if k in overall_token_usage:
+                    overall_token_usage[k] += v
+                else:
+                    overall_token_usage[k] = v
+        return {"token_usage": overall_token_usage, "model_name": self.model}
+
     def _generate(
         self,
         messages: List[BaseMessage],
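
The new _combine_llm_outputs hook merges the per-generation usage dicts by summing counters key by key, skipping generations that report nothing. A standalone sketch of that merge (the input dicts are made up, not from this diff):

    # Mirrors the key-by-key summing in _combine_llm_outputs above;
    # inputs are hypothetical per-generation usage dicts.
    from typing import List, Optional

    def combine(llm_outputs: List[Optional[dict]]) -> dict:
        overall: dict = {}
        for output in llm_outputs:
            if output is None:
                continue  # generations without usage info are skipped
            for k, v in output["token_usage"].items():
                overall[k] = overall.get(k, 0) + v
        return {"token_usage": overall}

    print(combine([
        {"token_usage": {"prompt_tokens": 10, "completion_tokens": 4, "total_tokens": 14}},
        None,
        {"token_usage": {"prompt_tokens": 7, "completion_tokens": 3, "total_tokens": 10}},
    ]))
    # {'token_usage': {'prompt_tokens': 17, 'completion_tokens': 7, 'total_tokens': 24}}
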
@@ -99,7 +112,6 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
         params: Dict[str, Any] = {"prompt": prompt, **self._default_params}
         if stop:
             params["stop_sequences"] = stop
-
         if self.streaming:
             completion = ""
             stream_resp = self.client.completion_stream(**params)
@@ -110,11 +122,20 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
                     run_manager.on_llm_new_token(
                         delta,
                     )
         else:
             response = self.client.completion(**params)
             completion = response["completion"]
+
+        token_usage = {
+            "prompt_tokens": self.get_num_tokens(prompt),
+            "completion_tokens": self.get_num_tokens(completion),
+        }
+        token_usage["total_tokens"] = sum(token_usage.values())
         message = AIMessage(content=completion)
-        return ChatResult(generations=[ChatGeneration(message=message)])
+        return ChatResult(
+            generations=[ChatGeneration(message=message)],
+            llm_output={"token_usage": token_usage},
+        )
 
     async def _agenerate(
         self,
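
Both counts in this hunk come from get_num_tokens, i.e. they are computed client-side rather than read back from the API response, and the total is simply their sum. A hypothetical trace with made-up counts:

    # Made-up counts standing in for get_num_tokens(prompt) and
    # get_num_tokens(completion); total_tokens is set after, so
    # sum() sees only the two component counts.
    token_usage = {"prompt_tokens": 12, "completion_tokens": 30}
    token_usage["total_tokens"] = sum(token_usage.values())
    print(token_usage)  # {'prompt_tokens': 12, 'completion_tokens': 30, 'total_tokens': 42}
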
@@ -126,7 +147,6 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
         params: Dict[str, Any] = {"prompt": prompt, **self._default_params}
         if stop:
             params["stop_sequences"] = stop
-
         if self.streaming:
             completion = ""
             stream_resp = await self.client.acompletion_stream(**params)
@@ -140,8 +160,16 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
else:
response = await self.client.acompletion(**params)
completion = response["completion"]
token_usage = {
"prompt_tokens": self.get_num_tokens(prompt),
"completion_tokens": self.get_num_tokens(completion),
}
token_usage["total_tokens"] = sum(token_usage.values())
message = AIMessage(content=completion)
return ChatResult(generations=[ChatGeneration(message=message)])
return ChatResult(
generations=[ChatGeneration(message=message)],
llm_output={"token_usage": token_usage},
)
def get_num_tokens(self, text: str) -> int:
"""Calculate number of tokens."""

langchain/chat_models/base.py

@@ -116,7 +116,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
             "run_manager"
         )
         try:
-            results = await asyncio.gather(
+            results: List[ChatResult] = await asyncio.gather(
                 *[
                     self._agenerate(m, stop=stop, run_manager=run_manager)
                     if new_arg_supported
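
The base.py change is purely a type annotation: asyncio.gather is typed as returning a list of Any, so annotating results as List[ChatResult] restores the element type for downstream use (presumably the new llm_output handling). A minimal self-contained sketch of the pattern, where ChatResult and fetch are stand-ins for langchain's class and the real coroutine:

    import asyncio
    from dataclasses import dataclass
    from typing import List

    @dataclass
    class ChatResult:  # stand-in for langchain.schema.ChatResult
        text: str

    async def fetch(i: int) -> ChatResult:  # hypothetical coroutine
        return ChatResult(text=f"result {i}")

    async def main() -> None:
        # Without the annotation, type checkers infer List[Any] from gather.
        results: List[ChatResult] = await asyncio.gather(*(fetch(i) for i in range(3)))
        print([r.text for r in results])

    asyncio.run(main())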