Add type to Generation and sub-classes, handle root validator (#12220)

* Add a type literal for the generation and sub-classes for serialization purposes.
* Fix the root validator of ChatGeneration to raise ValueError instead of KeyError or AttributeError if initialized improperly.
* This change is made for langserve to make sure that LLM-related callbacks can be serialized/deserialized properly.
This commit is contained in:
Eugene Yurtsev 2023-10-24 16:21:00 -04:00 committed by GitHub
parent 81052ee18e
commit 583dc49477
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 51 additions and 13 deletions

View File

@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from copy import deepcopy from copy import deepcopy
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Literal, Optional
from uuid import UUID from uuid import UUID
from langchain.load.serializable import Serializable from langchain.load.serializable import Serializable
@ -19,6 +19,8 @@ class Generation(Serializable):
"""Raw response from the provider. May include things like the """Raw response from the provider. May include things like the
reason for finishing or token log probabilities. reason for finishing or token log probabilities.
""" """
type: Literal["Generation"] = "Generation"
"""Type is used exclusively for serialization purposes."""
# TODO: add log probs as separate attribute # TODO: add log probs as separate attribute
@classmethod @classmethod
@ -54,11 +56,17 @@ class ChatGeneration(Generation):
"""*SHOULD NOT BE SET DIRECTLY* The text contents of the output message.""" """*SHOULD NOT BE SET DIRECTLY* The text contents of the output message."""
message: BaseMessage message: BaseMessage
"""The message output by the chat model.""" """The message output by the chat model."""
# Override type to be ChatGeneration, ignore mypy error as this is intentional
type: Literal["ChatGeneration"] = "ChatGeneration" # type: ignore[assignment]
"""Type is used exclusively for serialization purposes."""
@root_validator @root_validator
def set_text(cls, values: Dict[str, Any]) -> Dict[str, Any]: def set_text(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Set the text attribute to be the contents of the message.""" """Set the text attribute to be the contents of the message."""
values["text"] = values["message"].content try:
values["text"] = values["message"].content
except (KeyError, AttributeError) as e:
raise ValueError("Error while initializing ChatGeneration") from e
return values return values
@ -71,6 +79,9 @@ class ChatGenerationChunk(ChatGeneration):
""" """
message: BaseMessageChunk message: BaseMessageChunk
# Override type to be ChatGenerationChunk, ignore mypy error as this is intentional
type: Literal["ChatGenerationChunk"] = "ChatGenerationChunk" # type: ignore[assignment] # noqa: E501
"""Type is used exclusively for serialization purposes."""
def __add__(self, other: ChatGenerationChunk) -> ChatGenerationChunk: def __add__(self, other: ChatGenerationChunk) -> ChatGenerationChunk:
if isinstance(other, ChatGenerationChunk): if isinstance(other, ChatGenerationChunk):

File diff suppressed because one or more lines are too long

View File

@ -1719,7 +1719,9 @@ async def test_prompt_with_llm(
"op": "add", "op": "add",
"path": "/logs/FakeListLLM/final_output", "path": "/logs/FakeListLLM/final_output",
"value": { "value": {
"generations": [[{"generation_info": None, "text": "foo"}]], "generations": [
[{"generation_info": None, "text": "foo", "type": "Generation"}]
],
"llm_output": None, "llm_output": None,
"run": None, "run": None,
}, },

View File

@ -1,12 +1,19 @@
"""Test formatting functionality.""" """Test formatting functionality."""
import unittest import unittest
from typing import Union from typing import Union
import pytest
from langchain.prompts.base import StringPromptValue from langchain.prompts.base import StringPromptValue
from langchain.prompts.chat import ChatPromptValueConcrete from langchain.prompts.chat import ChatPromptValueConcrete
from langchain.pydantic_v1 import BaseModel from langchain.pydantic_v1 import BaseModel, ValidationError
from langchain.schema import AgentAction, AgentFinish, Document from langchain.schema import (
AgentAction,
AgentFinish,
ChatGeneration,
Document,
Generation,
)
from langchain.schema.agent import AgentActionMessageLog from langchain.schema.agent import AgentActionMessageLog
from langchain.schema.messages import ( from langchain.schema.messages import (
AIMessage, AIMessage,
@ -23,6 +30,7 @@ from langchain.schema.messages import (
messages_from_dict, messages_from_dict,
messages_to_dict, messages_to_dict,
) )
from langchain.schema.output import ChatGenerationChunk
class TestGetBufferString(unittest.TestCase): class TestGetBufferString(unittest.TestCase):
@ -108,6 +116,9 @@ def test_serialization_of_wellknown_objects() -> None:
AgentFinish, AgentFinish,
AgentAction, AgentAction,
AgentActionMessageLog, AgentActionMessageLog,
ChatGeneration,
Generation,
ChatGenerationChunk,
] ]
lc_objects = [ lc_objects = [
@ -144,6 +155,16 @@ def test_serialization_of_wellknown_objects() -> None:
log="", log="",
message_log=[HumanMessage(content="human")], message_log=[HumanMessage(content="human")],
), ),
Generation(
text="hello",
generation_info={"info": "info"},
),
ChatGeneration(
message=HumanMessage(content="human"),
),
ChatGenerationChunk(
message=HumanMessageChunk(content="cat"),
),
] ]
for lc_object in lc_objects: for lc_object in lc_objects:
@ -151,3 +172,7 @@ def test_serialization_of_wellknown_objects() -> None:
assert "type" in d, f"Missing key `type` for {type(lc_object)}" assert "type" in d, f"Missing key `type` for {type(lc_object)}"
obj1 = WellKnownLCObject.parse_obj(d) obj1 = WellKnownLCObject.parse_obj(d)
assert type(obj1.__root__) == type(lc_object), f"failed for {type(lc_object)}" assert type(obj1.__root__) == type(lc_object), f"failed for {type(lc_object)}"
with pytest.raises(ValidationError):
# Make sure that a ValidationError specifically is raised
WellKnownLCObject.parse_obj({})