mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-01 17:13:22 +00:00
refactor(langchain): refactor unit test stub classes (#32209)
See https://github.com/langchain-ai/langchain/pull/32098#discussion_r2225961563
This commit is contained in:
parent
6f3169eb49
commit
0b34be4ce5
@ -1,48 +1,23 @@
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class AnyStr(str):
|
||||
__slots__ = ()
|
||||
|
||||
class _AnyIDMixin(BaseModel):
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return isinstance(other, str)
|
||||
if isinstance(other, BaseModel):
|
||||
dump = self.model_dump()
|
||||
dump.pop("id")
|
||||
other_dump = other.model_dump()
|
||||
other_dump.pop("id")
|
||||
return dump == other_dump
|
||||
return False
|
||||
|
||||
__hash__ = str.__hash__
|
||||
__hash__ = None # type: ignore[assignment]
|
||||
|
||||
|
||||
# The code below creates version of pydantic models
|
||||
# that will work in unit tests with AnyStr as id field
|
||||
# Please note that the `id` field is assigned AFTER the model is created
|
||||
# to workaround an issue with pydantic ignoring the __eq__ method on
|
||||
# subclassed strings.
|
||||
class _AnyIdAIMessage(AIMessage, _AnyIDMixin):
|
||||
"""AIMessage with any ID."""
|
||||
|
||||
|
||||
def _AnyIdDocument(**kwargs: Any) -> Document:
|
||||
"""Create a document with an id field."""
|
||||
message = Document(**kwargs)
|
||||
message.id = AnyStr()
|
||||
return message
|
||||
|
||||
|
||||
def _AnyIdAIMessage(**kwargs: Any) -> AIMessage:
|
||||
"""Create ai message with an any id field."""
|
||||
message = AIMessage(**kwargs)
|
||||
message.id = AnyStr()
|
||||
return message
|
||||
|
||||
|
||||
def _AnyIdAIMessageChunk(**kwargs: Any) -> AIMessageChunk:
|
||||
"""Create ai message with an any id field."""
|
||||
message = AIMessageChunk(**kwargs)
|
||||
message.id = AnyStr()
|
||||
return message
|
||||
|
||||
|
||||
def _AnyIdHumanMessage(**kwargs: Any) -> HumanMessage:
|
||||
"""Create a human with an any id field."""
|
||||
message = HumanMessage(**kwargs)
|
||||
message.id = AnyStr()
|
||||
return message
|
||||
class _AnyIdAIMessageChunk(AIMessageChunk, _AnyIDMixin):
|
||||
"""AIMessageChunk with any ID."""
|
||||
|
Loading…
Reference in New Issue
Block a user