mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-23 15:19:33 +00:00
mistralai[patch]: add missing _combine_llm_outputs implementation in ChatMistralAI (#18603)
# Description Implementing `_combine_llm_outputs` to `ChatMistralAI` to override the default implementation in `BaseChatModel` returning `{}`. The implementation is inspired by the one in `ChatOpenAI` from package `langchain-openai`. # Issue None # Dependencies None # Twitter handle None --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
0175906437
commit
ace7b66261
@ -250,6 +250,22 @@ class ChatMistralAI(BaseChatModel):
|
||||
rtn = _completion_with_retry(**kwargs)
|
||||
return rtn
|
||||
|
||||
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
|
||||
overall_token_usage: dict = {}
|
||||
for output in llm_outputs:
|
||||
if output is None:
|
||||
# Happens in streaming
|
||||
continue
|
||||
token_usage = output["token_usage"]
|
||||
if token_usage is not None:
|
||||
for k, v in token_usage.items():
|
||||
if k in overall_token_usage:
|
||||
overall_token_usage[k] += v
|
||||
else:
|
||||
overall_token_usage[k] = v
|
||||
combined = {"token_usage": overall_token_usage, "model_name": self.model}
|
||||
return combined
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate api key, python package exists, temperature, and top_p."""
|
||||
|
@ -3,7 +3,7 @@
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import AIMessageChunk
|
||||
from langchain_core.messages import AIMessageChunk, HumanMessage
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
|
||||
from langchain_mistralai.chat_models import ChatMistralAI
|
||||
@ -70,6 +70,50 @@ def test_invoke() -> None:
|
||||
assert isinstance(result.content, str)
|
||||
|
||||
|
||||
def test_chat_mistralai_llm_output_contains_model_name() -> None:
|
||||
"""Test llm_output contains model_name."""
|
||||
chat = ChatMistralAI(max_tokens=10)
|
||||
message = HumanMessage(content="Hello")
|
||||
llm_result = chat.generate([[message]])
|
||||
assert llm_result.llm_output is not None
|
||||
assert llm_result.llm_output["model_name"] == chat.model
|
||||
|
||||
|
||||
def test_chat_mistralai_streaming_llm_output_contains_model_name() -> None:
|
||||
"""Test llm_output contains model_name."""
|
||||
chat = ChatMistralAI(max_tokens=10, streaming=True)
|
||||
message = HumanMessage(content="Hello")
|
||||
llm_result = chat.generate([[message]])
|
||||
assert llm_result.llm_output is not None
|
||||
assert llm_result.llm_output["model_name"] == chat.model
|
||||
|
||||
|
||||
def test_chat_mistralai_llm_output_contains_token_usage() -> None:
|
||||
"""Test llm_output contains model_name."""
|
||||
chat = ChatMistralAI(max_tokens=10)
|
||||
message = HumanMessage(content="Hello")
|
||||
llm_result = chat.generate([[message]])
|
||||
assert llm_result.llm_output is not None
|
||||
assert "token_usage" in llm_result.llm_output
|
||||
token_usage = llm_result.llm_output["token_usage"]
|
||||
assert "prompt_tokens" in token_usage
|
||||
assert "completion_tokens" in token_usage
|
||||
assert "total_tokens" in token_usage
|
||||
|
||||
|
||||
def test_chat_mistralai_streaming_llm_output_contains_token_usage() -> None:
|
||||
"""Test llm_output contains model_name."""
|
||||
chat = ChatMistralAI(max_tokens=10, streaming=True)
|
||||
message = HumanMessage(content="Hello")
|
||||
llm_result = chat.generate([[message]])
|
||||
assert llm_result.llm_output is not None
|
||||
assert "token_usage" in llm_result.llm_output
|
||||
token_usage = llm_result.llm_output["token_usage"]
|
||||
assert "prompt_tokens" in token_usage
|
||||
assert "completion_tokens" in token_usage
|
||||
assert "total_tokens" in token_usage
|
||||
|
||||
|
||||
def test_structured_output() -> None:
|
||||
llm = ChatMistralAI(model="mistral-large-latest", temperature=0)
|
||||
schema = {
|
||||
|
Loading…
Reference in New Issue
Block a user