commit 2eca8240e2 (parent c9a38be6df)
Repository: mirror of https://github.com/hwchase17/langchain.git

    final(?) unit tests
@@ -6,6 +6,7 @@ from typing import Any, cast
 from uuid import uuid4
 
 from langchain_core.messages import content_blocks as types
+from langchain_core.messages.ai import UsageMetadata
 from langchain_core.messages.content_blocks import (
     ImageContentBlock,
     ReasoningContentBlock,
@@ -20,6 +21,21 @@ from langchain_core.messages.v1 import SystemMessage as SystemMessageV1
 from langchain_core.messages.v1 import ToolMessage as ToolMessageV1
 
 
+def _get_usage_metadata_from_response(
+    response: dict[str, Any],
+) -> UsageMetadata | None:
+    """Extract usage metadata from Ollama response."""
+    input_tokens = response.get("prompt_eval_count")
+    output_tokens = response.get("eval_count")
+    if input_tokens is not None and output_tokens is not None:
+        return UsageMetadata(
+            input_tokens=input_tokens,
+            output_tokens=output_tokens,
+            total_tokens=input_tokens + output_tokens,
+        )
+    return None
+
+
 def _convert_from_v1_to_ollama_format(message: MessageV1) -> dict[str, Any]:
     """Convert v1 message to Ollama API format."""
     if isinstance(message, HumanMessageV1):
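For reference, a minimal standalone sketch (not part of the commit) of what the new helper does with a raw Ollama response dict. It assumes `langchain_core` is installed; the free-standing name `get_usage` is illustrative only and mirrors the logic added above.

```python
# Standalone sketch of the helper added above; assumes langchain_core is installed.
from typing import Any

from langchain_core.messages.ai import UsageMetadata


def get_usage(response: dict[str, Any]) -> UsageMetadata | None:
    """Mirror of the diff's logic: map Ollama token counters to UsageMetadata."""
    input_tokens = response.get("prompt_eval_count")
    output_tokens = response.get("eval_count")
    if input_tokens is not None and output_tokens is not None:
        return UsageMetadata(
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            total_tokens=input_tokens + output_tokens,
        )
    return None


# A final ("done") Ollama response carries both counters:
print(get_usage({"prompt_eval_count": 10, "eval_count": 20}))
# -> {'input_tokens': 10, 'output_tokens': 20, 'total_tokens': 30}

# Responses without the counters (e.g. intermediate chunks) yield None.
print(get_usage({"message": {"role": "assistant", "content": "partial"}}))
```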
@@ -218,6 +234,7 @@ def _convert_to_v1_from_ollama_format(response: dict[str, Any]) -> AIMessageV1:
     return AIMessageV1(
         content=content,
         response_metadata=response_metadata,
+        usage_metadata=_get_usage_metadata_from_response(response),
     )
 
 
@@ -280,7 +297,12 @@ def _convert_chunk_to_v1(chunk: dict[str, Any]) -> AIMessageChunkV1:
     if "eval_duration" in chunk:
         response_metadata["eval_duration"] = chunk["eval_duration"]  # type: ignore[typeddict-unknown-key]
 
+    usage_metadata = None
+    if chunk.get("done") is True:
+        usage_metadata = _get_usage_metadata_from_response(chunk)
+
     return AIMessageChunkV1(
         content=content,
         response_metadata=response_metadata or ResponseMetadata(),
+        usage_metadata=usage_metadata,
     )
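The chunk converter now attaches usage metadata only when Ollama marks the chunk as final (`done: true`), so intermediate chunks carry `usage_metadata=None` and the aggregated message ends up with a single set of token counts. A small sketch of that gate, using plain dicts and a hypothetical `usage_for_chunk` stand-in:

```python
# Sketch of the gating added in _convert_chunk_to_v1: only a final chunk
# (done == True) contributes usage metadata. Chunk dicts are illustrative.
from typing import Any


def usage_for_chunk(chunk: dict[str, Any]) -> dict[str, int] | None:
    """Hypothetical stand-in for _get_usage_metadata_from_response plus the done gate."""
    if chunk.get("done") is not True:
        return None
    input_tokens = chunk.get("prompt_eval_count")
    output_tokens = chunk.get("eval_count")
    if input_tokens is None or output_tokens is None:
        return None
    return {
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "total_tokens": input_tokens + output_tokens,
    }


streamed = [
    {"message": {"content": "Hel"}, "done": False},
    {"message": {"content": "lo"}, "done": True, "prompt_eval_count": 5, "eval_count": 2},
]
print([usage_for_chunk(c) for c in streamed])
# [None, {'input_tokens': 5, 'output_tokens': 2, 'total_tokens': 7}]
```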
@@ -25,7 +25,6 @@ from langchain_core.language_models.v1.chat_models import (
     agenerate_from_stream,
     generate_from_stream,
 )
-from langchain_core.messages.ai import UsageMetadata
 from langchain_core.messages.v1 import AIMessage as AIMessageV1
 from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1
 from langchain_core.messages.v1 import MessageV1
@@ -58,21 +57,6 @@ from ._utils import validate_model
 log = logging.getLogger(__name__)
 
 
-def _get_usage_metadata_from_response(
-    response: dict[str, Any],
-) -> Optional[UsageMetadata]:
-    """Extract usage metadata from Ollama response."""
-    input_tokens = response.get("prompt_eval_count")
-    output_tokens = response.get("eval_count")
-    if input_tokens is not None and output_tokens is not None:
-        return UsageMetadata(
-            input_tokens=input_tokens,
-            output_tokens=output_tokens,
-            total_tokens=input_tokens + output_tokens,
-        )
-    return None
-
-
 def _parse_json_string(
     json_string: str,
     *,
@@ -578,12 +562,6 @@ class ChatOllamaV1(BaseChatModelV1):
 
             chunk = _convert_chunk_to_v1(part)
 
-            # Add usage metadata for final chunks
-            if part.get("done") is True:
-                usage_metadata = _get_usage_metadata_from_response(part)
-                if usage_metadata:
-                    chunk.usage_metadata = usage_metadata
-
             if run_manager:
                 text_content = "".join(
                     str(block.get("text", ""))
@@ -599,9 +577,6 @@ class ChatOllamaV1(BaseChatModelV1):
             # Non-streaming case
             response = self._client.chat(**chat_params)
             ai_message = _convert_to_v1_from_ollama_format(response)
-            usage_metadata = _get_usage_metadata_from_response(response)
-            if usage_metadata:
-                ai_message.usage_metadata = usage_metadata
             # Convert to chunk for yielding
             chunk = AIMessageChunkV1(
                 content=ai_message.content,
@@ -637,12 +612,6 @@ class ChatOllamaV1(BaseChatModelV1):
 
             chunk = _convert_chunk_to_v1(part)
 
-            # Add usage metadata for final chunks
-            if part.get("done") is True:
-                usage_metadata = _get_usage_metadata_from_response(part)
-                if usage_metadata:
-                    chunk.usage_metadata = usage_metadata
-
             if run_manager:
                 text_content = "".join(
                     str(block.get("text", ""))
@@ -658,9 +627,6 @@ class ChatOllamaV1(BaseChatModelV1):
             # Non-streaming case
             response = await self._async_client.chat(**chat_params)
             ai_message = _convert_to_v1_from_ollama_format(response)
-            usage_metadata = _get_usage_metadata_from_response(response)
-            if usage_metadata:
-                ai_message.usage_metadata = usage_metadata
             # Convert to chunk for yielding
             chunk = AIMessageChunkV1(
                 content=ai_message.content,
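With `_convert_to_v1_from_ollama_format` and `_convert_chunk_to_v1` now populating `usage_metadata` themselves, the manual re-attachment in both the sync and async non-streaming paths is redundant and is dropped. Callers keep reading token counts off the returned message as before; the snippet below is an illustrative sketch only (the import path, model name, and a locally running Ollama server are all assumptions, not taken from this commit):

```python
# Illustrative only: assumes langchain-ollama exposes ChatOllamaV1 at this path
# and that an Ollama server with the named model is running locally.
from langchain_ollama import ChatOllamaV1  # assumed import path

llm = ChatOllamaV1(model="llama3.1")  # model name is an assumption
msg = llm.invoke("Say hi")

# usage_metadata is populated by the conversion helpers patched above.
if msg.usage_metadata is not None:
    print(msg.usage_metadata["input_tokens"], msg.usage_metadata["total_tokens"])
```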
@@ -211,23 +211,63 @@ class TestChatOllamaV1(ChatModelV1UnitTests):
         mock_sync_client_class.return_value = mock_sync_client
         mock_async_client_class.return_value = mock_async_client
 
-        def mock_chat_response(*_args: Any, **_kwargs: Any) -> Iterator[dict[str, Any]]:
-            return iter(
-                [
-                    {
-                        "model": MODEL_NAME,
-                        "created_at": "2024-01-01T00:00:00Z",
-                        "message": {"role": "assistant", "content": "Test response"},
-                        "done": True,
-                        "done_reason": "stop",
-                    }
-                ]
-            )
+        def mock_chat_response(*args: Any, **kwargs: Any) -> Iterator[dict[str, Any]]:
+            # Check request characteristics
+            request_data = kwargs.get("messages", [])
+            has_tools = "tools" in kwargs
+
+            # Check if this is a reasoning request
+            is_reasoning_request = any(
+                isinstance(msg, dict)
+                and "Think step by step" in str(msg.get("content", ""))
+                for msg in request_data
+            )
+
+            # Basic response structure
+            base_response = {
+                "model": MODEL_NAME,
+                "created_at": "2024-01-01T00:00:00Z",
+                "done": True,
+                "done_reason": "stop",
+                "prompt_eval_count": 10,
+                "eval_count": 20,
+            }
+
+            # Generate appropriate response based on request type
+            if has_tools:
+                # Mock tool call response
+                base_response["message"] = {
+                    "role": "assistant",
+                    "content": "",
+                    "tool_calls": [
+                        {
+                            "function": {
+                                "name": "sample_tool",
+                                "arguments": '{"query": "test"}',
+                            }
+                        }
+                    ],
+                }
+            elif is_reasoning_request:
+                # Mock response with reasoning content block
+                base_response["message"] = {
+                    "role": "assistant",
+                    "content": "The answer is 4.",
+                    "thinking": "Let me think step by step: 2 + 2 = 4",
+                }
+            else:
+                # Regular text response
+                base_response["message"] = {
+                    "role": "assistant",
+                    "content": "Test response",
+                }
+
+            return iter([base_response])
 
         async def mock_async_chat_iterator(
-            *_args: Any, **_kwargs: Any
+            *args: Any, **kwargs: Any
         ) -> AsyncIterator[dict[str, Any]]:
-            for item in mock_chat_response(*_args, **_kwargs):
+            for item in mock_chat_response(*args, **kwargs):
                 yield item
 
         mock_sync_client.chat.side_effect = mock_chat_response
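The rewritten mock dispatches on the request instead of always returning the same canned reply: tool-bearing requests get a `tool_calls` message, prompts containing "Think step by step" get a `thinking` field, and everything else gets plain text, with `prompt_eval_count` / `eval_count` always present so the new usage-metadata path is exercised. A condensed, self-contained sketch of that dispatch (the `pick_message` name is illustrative; the real helper is nested inside the fixture above):

```python
# Condensed sketch of the new mock's dispatch logic; names and call sites are
# illustrative, not the actual test fixture.
from typing import Any


def pick_message(**kwargs: Any) -> dict[str, Any]:
    messages = kwargs.get("messages", [])
    if "tools" in kwargs:
        # Tool-call branch
        return {
            "role": "assistant",
            "content": "",
            "tool_calls": [
                {"function": {"name": "sample_tool", "arguments": '{"query": "test"}'}}
            ],
        }
    if any(
        isinstance(m, dict) and "Think step by step" in str(m.get("content", ""))
        for m in messages
    ):
        # Reasoning branch
        return {
            "role": "assistant",
            "content": "The answer is 4.",
            "thinking": "Let me think step by step: 2 + 2 = 4",
        }
    # Plain-text branch
    return {"role": "assistant", "content": "Test response"}


print(pick_message(messages=[{"role": "user", "content": "hi"}]))
print(pick_message(messages=[], tools=[{"name": "sample_tool"}]))
print(pick_message(messages=[{"role": "user", "content": "Think step by step: 2+2?"}]))
```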