Mirror of https://github.com/hwchase17/langchain.git (synced 2025-06-20 13:54:48 +00:00)

Harrison/chat token usage (#1785)

Chat models now report token usage: ChatResult and LLMResult gain an llm_output dict, and ChatOpenAI sums the OpenAI usage stats across all calls in a batch.

parent 7de2ada3ea
commit ef4945af6b
In the chat getting-started notebook:

@@ -12,7 +12,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 1,
    "id": "522686de",
    "metadata": {
     "tags": []
@@ -36,7 +36,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 2,
    "id": "62e0dbc3",
    "metadata": {
     "tags": []
@@ -56,7 +56,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 3,
    "id": "76a6e7b0-e927-4bfb-a414-1332a4149106",
    "metadata": {
     "tags": []
@@ -68,7 +68,7 @@
        "AIMessage(content=\"J'aime programmer.\", additional_kwargs={})"
       ]
      },
-     "execution_count": 4,
+     "execution_count": 3,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -87,7 +87,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 4,
    "id": "ce16ad78-8e6f-48cd-954e-98be75eb5836",
    "metadata": {
     "tags": []
@@ -99,7 +99,7 @@
        "AIMessage(content=\"J'aime programmer.\", additional_kwargs={})"
       ]
      },
-     "execution_count": 5,
+     "execution_count": 4,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -122,7 +122,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 5,
    "id": "2b21fc52-74b6-4950-ab78-45d12c68fb4d",
    "metadata": {
     "tags": []
@@ -131,10 +131,10 @@
    {
     "data": {
      "text/plain": [
-      "LLMResult(generations=[[ChatGeneration(text=\"J'aime programmer.\", generation_info=None, message=AIMessage(content=\"J'aime programmer.\", additional_kwargs={}))], [ChatGeneration(text=\"J'aime l'intelligence artificielle.\", generation_info=None, message=AIMessage(content=\"J'aime l'intelligence artificielle.\", additional_kwargs={}))]], llm_output=None)"
+      "LLMResult(generations=[[ChatGeneration(text=\"J'aime programmer.\", generation_info=None, message=AIMessage(content=\"J'aime programmer.\", additional_kwargs={}))], [ChatGeneration(text=\"J'aime l'intelligence artificielle.\", generation_info=None, message=AIMessage(content=\"J'aime l'intelligence artificielle.\", additional_kwargs={}))]], llm_output={'token_usage': {'prompt_tokens': 71, 'completion_tokens': 18, 'total_tokens': 89}})"
      ]
     },
-    "execution_count": 6,
+    "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
@@ -150,7 +150,39 @@
     " HumanMessage(content=\"Translate this sentence from English to French. I love artificial intelligence.\")\n",
     " ],\n",
     "]\n",
-    "chat.generate(batch_messages)"
+    "result = chat.generate(batch_messages)\n",
+    "result"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "2960f50f",
+   "metadata": {},
+   "source": [
+    "You can recover things like token usage from this LLMResult"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "id": "a6186bee",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "{'token_usage': {'prompt_tokens': 71,\n",
+       " 'completion_tokens': 18,\n",
+       " 'total_tokens': 89}}"
+      ]
+     },
+     "execution_count": 6,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "result.llm_output"
    ]
   },
   {
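Net effect in the notebook: generate() batches conversations, and the returned LLMResult now exposes provider metadata under llm_output. A minimal end-to-end sketch of the same flow (assumes OPENAI_API_KEY is set in the environment; the SystemMessage wording mirrors the notebook's earlier cells, which this hunk does not show):

    from langchain.chat_models import ChatOpenAI
    from langchain.schema import HumanMessage, SystemMessage

    chat = ChatOpenAI(temperature=0)

    # Each inner list is one conversation; generate() runs the whole batch.
    batch_messages = [
        [
            SystemMessage(content="You are a helpful assistant that translates English to French."),
            HumanMessage(content="Translate this sentence from English to French. I love programming."),
        ],
        [
            SystemMessage(content="You are a helpful assistant that translates English to French."),
            HumanMessage(content="Translate this sentence from English to French. I love artificial intelligence."),
        ],
    ]

    result = chat.generate(batch_messages)

    # New in this commit: llm_output carries the summed token usage.
    print(result.llm_output)
    # {'token_usage': {'prompt_tokens': 71, 'completion_tokens': 18, 'total_tokens': 89}}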
In langchain/chat_models/base.py:

@@ -43,19 +43,26 @@ class BaseChatModel(BaseLanguageModel, BaseModel, ABC):
         """
         return callback_manager or get_callback_manager()
 
+    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
+        return {}
+
     def generate(
         self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
     ) -> LLMResult:
         """Top Level call"""
         results = [self._generate(m, stop=stop) for m in messages]
-        return LLMResult(generations=[res.generations for res in results])
+        llm_output = self._combine_llm_outputs([res.llm_output for res in results])
+        generations = [res.generations for res in results]
+        return LLMResult(generations=generations, llm_output=llm_output)
 
     async def agenerate(
         self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
     ) -> LLMResult:
         """Top Level call"""
         results = [await self._agenerate(m, stop=stop) for m in messages]
-        return LLMResult(generations=[res.generations for res in results])
+        llm_output = self._combine_llm_outputs([res.llm_output for res in results])
+        generations = [res.generations for res in results]
+        return LLMResult(generations=generations, llm_output=llm_output)
 
     def generate_prompt(
         self, prompts: List[PromptValue], stop: Optional[List[str]] = None
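The new _combine_llm_outputs hook is the extension point here: generate() and agenerate() collect each sub-call's llm_output and hand the list to the subclass for merging, with the base class defaulting to an empty dict. A hypothetical subclass (illustration only, not part of this commit; the class name and latency_ms field are invented, and _generate/_agenerate are elided) might merge a provider-specific field like this:

    from typing import List, Optional

    from langchain.chat_models.base import BaseChatModel


    class LatencyTrackingChatModel(BaseChatModel):
        """Hypothetical wrapper for a provider that reports per-call latency."""

        # _generate / _agenerate omitted; a concrete subclass must implement them.

        def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
            # Sum the per-call latencies so LLMResult.llm_output covers the batch.
            total_ms = sum(o.get("latency_ms", 0) for o in llm_outputs if o)
            return {"latency_ms": total_ms}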
In langchain/chat_models/openai.py:

@@ -97,7 +97,8 @@ def _create_chat_result(response: Mapping[str, Any]) -> ChatResult:
         message = _convert_dict_to_message(res["message"])
         gen = ChatGeneration(message=message)
         generations.append(gen)
-    return ChatResult(generations=generations)
+    llm_output = {"token_usage": response["usage"]}
+    return ChatResult(generations=generations, llm_output=llm_output)
 
 
 class ChatOpenAI(BaseChatModel, BaseModel):
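The raw OpenAI chat-completions payload already includes a top-level usage object, so _create_chat_result only needs to forward it under the token_usage key. A self-contained sketch of that mapping with a canned response (field names follow the public OpenAI API; the inline extraction is a simplified stand-in for the library helper, not the helper itself):

    # Canned OpenAI-style chat completion response (shape per the public API).
    response = {
        "choices": [
            {"message": {"role": "assistant", "content": "J'aime programmer."}}
        ],
        "usage": {"prompt_tokens": 26, "completion_tokens": 8, "total_tokens": 34},
    }

    # Simplified stand-in for _create_chat_result: pull out text and usage.
    generations = [c["message"]["content"] for c in response["choices"]]
    llm_output = {"token_usage": response["usage"]}

    print(generations)  # ["J'aime programmer."]
    print(llm_output)   # {'token_usage': {'prompt_tokens': 26, ...}}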
@@ -221,6 +222,19 @@ class ChatOpenAI(BaseChatModel, BaseModel):
 
         return _completion_with_retry(**kwargs)
 
+    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
+        overall_token_usage: dict = {}
+        for output in llm_outputs:
+            if output is None:
+                raise ValueError("Should always be something for OpenAI.")
+            token_usage = output["token_usage"]
+            for k, v in token_usage.items():
+                if k in overall_token_usage:
+                    overall_token_usage[k] += v
+                else:
+                    overall_token_usage[k] = v
+        return {"token_usage": overall_token_usage}
+
     def _generate(
         self, messages: List[BaseMessage], stop: Optional[List[str]] = None
     ) -> ChatResult:
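Since generate() issues one API call per conversation in the batch, ChatOpenAI receives one usage dict per call and sums them key-wise. The merge is equivalent to this standalone snippet (the per-call numbers are illustrative, chosen to add up to the notebook output above):

    from collections import Counter

    # Per-call token usage, as each ChatResult.llm_output would report it.
    llm_outputs = [
        {"token_usage": {"prompt_tokens": 36, "completion_tokens": 8, "total_tokens": 44}},
        {"token_usage": {"prompt_tokens": 35, "completion_tokens": 10, "total_tokens": 45}},
    ]

    # Key-wise sum; same result as ChatOpenAI._combine_llm_outputs above.
    overall = Counter()
    for output in llm_outputs:
        overall.update(output["token_usage"])

    print({"token_usage": dict(overall)})
    # {'token_usage': {'prompt_tokens': 71, 'completion_tokens': 18, 'total_tokens': 89}}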