mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-20 03:21:33 +00:00
Add type to Generation and sub-classes, handle root validator (#12220)
* Add a type literal for the generation and sub-classes for serialization purposes. * Fix the root validator of ChatGeneration to return ValueError instead of KeyError or Attribute error if intialized improperly. * This change is done for langserve to make sure that llm related callbacks can be serialized/deserialized properly.
This commit is contained in:
parent
81052ee18e
commit
583dc49477
@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Literal, Optional
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from langchain.load.serializable import Serializable
|
from langchain.load.serializable import Serializable
|
||||||
@ -19,6 +19,8 @@ class Generation(Serializable):
|
|||||||
"""Raw response from the provider. May include things like the
|
"""Raw response from the provider. May include things like the
|
||||||
reason for finishing or token log probabilities.
|
reason for finishing or token log probabilities.
|
||||||
"""
|
"""
|
||||||
|
type: Literal["Generation"] = "Generation"
|
||||||
|
"""Type is used exclusively for serialization purposes."""
|
||||||
# TODO: add log probs as separate attribute
|
# TODO: add log probs as separate attribute
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -54,11 +56,17 @@ class ChatGeneration(Generation):
|
|||||||
"""*SHOULD NOT BE SET DIRECTLY* The text contents of the output message."""
|
"""*SHOULD NOT BE SET DIRECTLY* The text contents of the output message."""
|
||||||
message: BaseMessage
|
message: BaseMessage
|
||||||
"""The message output by the chat model."""
|
"""The message output by the chat model."""
|
||||||
|
# Override type to be ChatGeneration, ignore mypy error as this is intentional
|
||||||
|
type: Literal["ChatGeneration"] = "ChatGeneration" # type: ignore[assignment]
|
||||||
|
"""Type is used exclusively for serialization purposes."""
|
||||||
|
|
||||||
@root_validator
|
@root_validator
|
||||||
def set_text(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
def set_text(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""Set the text attribute to be the contents of the message."""
|
"""Set the text attribute to be the contents of the message."""
|
||||||
values["text"] = values["message"].content
|
try:
|
||||||
|
values["text"] = values["message"].content
|
||||||
|
except (KeyError, AttributeError) as e:
|
||||||
|
raise ValueError("Error while initializing ChatGeneration") from e
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
|
||||||
@ -71,6 +79,9 @@ class ChatGenerationChunk(ChatGeneration):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
message: BaseMessageChunk
|
message: BaseMessageChunk
|
||||||
|
# Override type to be ChatGeneration, ignore mypy error as this is intentional
|
||||||
|
type: Literal["ChatGenerationChunk"] = "ChatGenerationChunk" # type: ignore[assignment] # noqa: E501
|
||||||
|
"""Type is used exclusively for serialization purposes."""
|
||||||
|
|
||||||
def __add__(self, other: ChatGenerationChunk) -> ChatGenerationChunk:
|
def __add__(self, other: ChatGenerationChunk) -> ChatGenerationChunk:
|
||||||
if isinstance(other, ChatGenerationChunk):
|
if isinstance(other, ChatGenerationChunk):
|
||||||
|
File diff suppressed because one or more lines are too long
@ -1719,7 +1719,9 @@ async def test_prompt_with_llm(
|
|||||||
"op": "add",
|
"op": "add",
|
||||||
"path": "/logs/FakeListLLM/final_output",
|
"path": "/logs/FakeListLLM/final_output",
|
||||||
"value": {
|
"value": {
|
||||||
"generations": [[{"generation_info": None, "text": "foo"}]],
|
"generations": [
|
||||||
|
[{"generation_info": None, "text": "foo", "type": "Generation"}]
|
||||||
|
],
|
||||||
"llm_output": None,
|
"llm_output": None,
|
||||||
"run": None,
|
"run": None,
|
||||||
},
|
},
|
||||||
|
@ -1,12 +1,19 @@
|
|||||||
"""Test formatting functionality."""
|
"""Test formatting functionality."""
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from langchain.prompts.base import StringPromptValue
|
from langchain.prompts.base import StringPromptValue
|
||||||
from langchain.prompts.chat import ChatPromptValueConcrete
|
from langchain.prompts.chat import ChatPromptValueConcrete
|
||||||
from langchain.pydantic_v1 import BaseModel
|
from langchain.pydantic_v1 import BaseModel, ValidationError
|
||||||
from langchain.schema import AgentAction, AgentFinish, Document
|
from langchain.schema import (
|
||||||
|
AgentAction,
|
||||||
|
AgentFinish,
|
||||||
|
ChatGeneration,
|
||||||
|
Document,
|
||||||
|
Generation,
|
||||||
|
)
|
||||||
from langchain.schema.agent import AgentActionMessageLog
|
from langchain.schema.agent import AgentActionMessageLog
|
||||||
from langchain.schema.messages import (
|
from langchain.schema.messages import (
|
||||||
AIMessage,
|
AIMessage,
|
||||||
@ -23,6 +30,7 @@ from langchain.schema.messages import (
|
|||||||
messages_from_dict,
|
messages_from_dict,
|
||||||
messages_to_dict,
|
messages_to_dict,
|
||||||
)
|
)
|
||||||
|
from langchain.schema.output import ChatGenerationChunk
|
||||||
|
|
||||||
|
|
||||||
class TestGetBufferString(unittest.TestCase):
|
class TestGetBufferString(unittest.TestCase):
|
||||||
@ -108,6 +116,9 @@ def test_serialization_of_wellknown_objects() -> None:
|
|||||||
AgentFinish,
|
AgentFinish,
|
||||||
AgentAction,
|
AgentAction,
|
||||||
AgentActionMessageLog,
|
AgentActionMessageLog,
|
||||||
|
ChatGeneration,
|
||||||
|
Generation,
|
||||||
|
ChatGenerationChunk,
|
||||||
]
|
]
|
||||||
|
|
||||||
lc_objects = [
|
lc_objects = [
|
||||||
@ -144,6 +155,16 @@ def test_serialization_of_wellknown_objects() -> None:
|
|||||||
log="",
|
log="",
|
||||||
message_log=[HumanMessage(content="human")],
|
message_log=[HumanMessage(content="human")],
|
||||||
),
|
),
|
||||||
|
Generation(
|
||||||
|
text="hello",
|
||||||
|
generation_info={"info": "info"},
|
||||||
|
),
|
||||||
|
ChatGeneration(
|
||||||
|
message=HumanMessage(content="human"),
|
||||||
|
),
|
||||||
|
ChatGenerationChunk(
|
||||||
|
message=HumanMessageChunk(content="cat"),
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
for lc_object in lc_objects:
|
for lc_object in lc_objects:
|
||||||
@ -151,3 +172,7 @@ def test_serialization_of_wellknown_objects() -> None:
|
|||||||
assert "type" in d, f"Missing key `type` for {type(lc_object)}"
|
assert "type" in d, f"Missing key `type` for {type(lc_object)}"
|
||||||
obj1 = WellKnownLCObject.parse_obj(d)
|
obj1 = WellKnownLCObject.parse_obj(d)
|
||||||
assert type(obj1.__root__) == type(lc_object), f"failed for {type(lc_object)}"
|
assert type(obj1.__root__) == type(lc_object), f"failed for {type(lc_object)}"
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
# Make sure that specifically validation error is raised
|
||||||
|
WellKnownLCObject.parse_obj({})
|
||||||
|
Loading…
Reference in New Issue
Block a user