final(?) unit tests
commit 2eca8240e2 (parent c9a38be6df)
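Moves `_get_usage_metadata_from_response` out of the chat model module and into the v1 message-conversion helpers, has `_convert_to_v1_from_ollama_format` and `_convert_chunk_to_v1` attach usage metadata directly, and reworks the unit-test mock to return tool-call, reasoning, and plain-text responses that carry token counts.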
@@ -6,6 +6,7 @@ from typing import Any, cast
 from uuid import uuid4

 from langchain_core.messages import content_blocks as types
+from langchain_core.messages.ai import UsageMetadata
 from langchain_core.messages.content_blocks import (
     ImageContentBlock,
     ReasoningContentBlock,
@@ -20,6 +21,21 @@ from langchain_core.messages.v1 import SystemMessage as SystemMessageV1
 from langchain_core.messages.v1 import ToolMessage as ToolMessageV1


+def _get_usage_metadata_from_response(
+    response: dict[str, Any],
+) -> UsageMetadata | None:
+    """Extract usage metadata from Ollama response."""
+    input_tokens = response.get("prompt_eval_count")
+    output_tokens = response.get("eval_count")
+    if input_tokens is not None and output_tokens is not None:
+        return UsageMetadata(
+            input_tokens=input_tokens,
+            output_tokens=output_tokens,
+            total_tokens=input_tokens + output_tokens,
+        )
+    return None
+
+
 def _convert_from_v1_to_ollama_format(message: MessageV1) -> dict[str, Any]:
     """Convert v1 message to Ollama API format."""
     if isinstance(message, HumanMessageV1):
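A minimal illustration of the helper added above (values are invented; since `UsageMetadata` is a TypedDict, the result compares equal to a plain dict):

```python
# Hypothetical Ollama response fragment containing the two token-count fields the helper reads.
response = {"prompt_eval_count": 10, "eval_count": 20}

assert _get_usage_metadata_from_response(response) == {
    "input_tokens": 10,
    "output_tokens": 20,
    "total_tokens": 30,
}

# If either count is missing, the helper returns None rather than guessing.
assert _get_usage_metadata_from_response({"eval_count": 20}) is None
```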
@@ -218,6 +234,7 @@ def _convert_to_v1_from_ollama_format(response: dict[str, Any]) -> AIMessageV1:
     return AIMessageV1(
         content=content,
         response_metadata=response_metadata,
+        usage_metadata=_get_usage_metadata_from_response(response),
     )


@@ -280,7 +297,12 @@ def _convert_chunk_to_v1(chunk: dict[str, Any]) -> AIMessageChunkV1:
     if "eval_duration" in chunk:
         response_metadata["eval_duration"] = chunk["eval_duration"]  # type: ignore[typeddict-unknown-key]

+    usage_metadata = None
+    if chunk.get("done") is True:
+        usage_metadata = _get_usage_metadata_from_response(chunk)
+
     return AIMessageChunkV1(
         content=content,
         response_metadata=response_metadata or ResponseMetadata(),
+        usage_metadata=usage_metadata,
     )
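A rough sketch of the streaming behavior this enables (chunk fields other than the token counts are assumptions about the Ollama payload, not taken from this diff): only the final chunk, which Ollama marks with `"done": True`, should come back from `_convert_chunk_to_v1` carrying usage metadata.

```python
# Hypothetical final streaming chunk; field names mirror those used elsewhere in this diff.
final_chunk = {
    "model": "llama3",
    "created_at": "2024-01-01T00:00:00Z",
    "message": {"role": "assistant", "content": ""},
    "done": True,
    "done_reason": "stop",
    "prompt_eval_count": 10,
    "eval_count": 20,
}
assert _convert_chunk_to_v1(final_chunk).usage_metadata == {
    "input_tokens": 10,
    "output_tokens": 20,
    "total_tokens": 30,
}

# An intermediate chunk ("done" not yet True) should carry no usage metadata.
partial_chunk = {
    "model": "llama3",
    "message": {"role": "assistant", "content": "Hel"},
    "done": False,
}
assert _convert_chunk_to_v1(partial_chunk).usage_metadata is None
```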
@@ -25,7 +25,6 @@ from langchain_core.language_models.v1.chat_models import (
     agenerate_from_stream,
     generate_from_stream,
 )
-from langchain_core.messages.ai import UsageMetadata
 from langchain_core.messages.v1 import AIMessage as AIMessageV1
 from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1
 from langchain_core.messages.v1 import MessageV1
@@ -58,21 +57,6 @@ from ._utils import validate_model
 log = logging.getLogger(__name__)


-def _get_usage_metadata_from_response(
-    response: dict[str, Any],
-) -> Optional[UsageMetadata]:
-    """Extract usage metadata from Ollama response."""
-    input_tokens = response.get("prompt_eval_count")
-    output_tokens = response.get("eval_count")
-    if input_tokens is not None and output_tokens is not None:
-        return UsageMetadata(
-            input_tokens=input_tokens,
-            output_tokens=output_tokens,
-            total_tokens=input_tokens + output_tokens,
-        )
-    return None
-
-
 def _parse_json_string(
     json_string: str,
     *,
@@ -578,12 +562,6 @@ class ChatOllamaV1(BaseChatModelV1):

             chunk = _convert_chunk_to_v1(part)

-            # Add usage metadata for final chunks
-            if part.get("done") is True:
-                usage_metadata = _get_usage_metadata_from_response(part)
-                if usage_metadata:
-                    chunk.usage_metadata = usage_metadata
-
             if run_manager:
                 text_content = "".join(
                     str(block.get("text", ""))
@@ -599,9 +577,6 @@ class ChatOllamaV1(BaseChatModelV1):
             # Non-streaming case
             response = self._client.chat(**chat_params)
             ai_message = _convert_to_v1_from_ollama_format(response)
-            usage_metadata = _get_usage_metadata_from_response(response)
-            if usage_metadata:
-                ai_message.usage_metadata = usage_metadata
             # Convert to chunk for yielding
             chunk = AIMessageChunkV1(
                 content=ai_message.content,
@@ -637,12 +612,6 @@ class ChatOllamaV1(BaseChatModelV1):

             chunk = _convert_chunk_to_v1(part)

-            # Add usage metadata for final chunks
-            if part.get("done") is True:
-                usage_metadata = _get_usage_metadata_from_response(part)
-                if usage_metadata:
-                    chunk.usage_metadata = usage_metadata
-
             if run_manager:
                 text_content = "".join(
                     str(block.get("text", ""))
@@ -658,9 +627,6 @@ class ChatOllamaV1(BaseChatModelV1):
             # Non-streaming case
             response = await self._async_client.chat(**chat_params)
             ai_message = _convert_to_v1_from_ollama_format(response)
-            usage_metadata = _get_usage_metadata_from_response(response)
-            if usage_metadata:
-                ai_message.usage_metadata = usage_metadata
             # Convert to chunk for yielding
             chunk = AIMessageChunkV1(
                 content=ai_message.content,
@@ -211,23 +211,63 @@ class TestChatOllamaV1(ChatModelV1UnitTests):
         mock_sync_client_class.return_value = mock_sync_client
         mock_async_client_class.return_value = mock_async_client

-        def mock_chat_response(*_args: Any, **_kwargs: Any) -> Iterator[dict[str, Any]]:
-            return iter(
-                [
-                    {
-                        "model": MODEL_NAME,
-                        "created_at": "2024-01-01T00:00:00Z",
-                        "message": {"role": "assistant", "content": "Test response"},
-                        "done": True,
-                        "done_reason": "stop",
-                    }
-                ]
-            )
+        def mock_chat_response(*args: Any, **kwargs: Any) -> Iterator[dict[str, Any]]:
+            # Check request characteristics
+            request_data = kwargs.get("messages", [])
+            has_tools = "tools" in kwargs
+
+            # Check if this is a reasoning request
+            is_reasoning_request = any(
+                isinstance(msg, dict)
+                and "Think step by step" in str(msg.get("content", ""))
+                for msg in request_data
+            )
+
+            # Basic response structure
+            base_response = {
+                "model": MODEL_NAME,
+                "created_at": "2024-01-01T00:00:00Z",
+                "done": True,
+                "done_reason": "stop",
+                "prompt_eval_count": 10,
+                "eval_count": 20,
+            }
+
+            # Generate appropriate response based on request type
+            if has_tools:
+                # Mock tool call response
+                base_response["message"] = {
+                    "role": "assistant",
+                    "content": "",
+                    "tool_calls": [
+                        {
+                            "function": {
+                                "name": "sample_tool",
+                                "arguments": '{"query": "test"}',
+                            }
+                        }
+                    ],
+                }
+            elif is_reasoning_request:
+                # Mock response with reasoning content block
+                base_response["message"] = {
+                    "role": "assistant",
+                    "content": "The answer is 4.",
+                    "thinking": "Let me think step by step: 2 + 2 = 4",
+                }
+            else:
+                # Regular text response
+                base_response["message"] = {
+                    "role": "assistant",
+                    "content": "Test response",
+                }
+
+            return iter([base_response])

         async def mock_async_chat_iterator(
-            *_args: Any, **_kwargs: Any
+            *args: Any, **kwargs: Any
         ) -> AsyncIterator[dict[str, Any]]:
-            for item in mock_chat_response(*_args, **_kwargs):
+            for item in mock_chat_response(*args, **kwargs):
                 yield item

         mock_sync_client.chat.side_effect = mock_chat_response
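Loosely, the reworked mock dispatches like this (call shapes are assumptions inferred from the checks inside `mock_chat_response`, not taken from the test suite itself):

```python
# Plain request -> regular text response that now includes token counts.
plain = next(mock_chat_response(messages=[{"role": "user", "content": "hi"}]))
assert plain["message"]["content"] == "Test response"
assert plain["prompt_eval_count"] == 10 and plain["eval_count"] == 20

# Passing a "tools" kwarg -> empty content plus a mocked sample_tool call.
tooled = next(mock_chat_response(messages=[], tools=[{"name": "sample_tool"}]))
assert tooled["message"]["tool_calls"][0]["function"]["name"] == "sample_tool"

# A prompt containing "Think step by step" -> response carrying a "thinking" field.
reasoned = next(
    mock_chat_response(messages=[{"role": "user", "content": "Think step by step: 2 + 2"}])
)
assert "thinking" in reasoned["message"]
```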