Compare commits

...

1 Commits

Author SHA1 Message Date
Bagatur
03e96b363f rfc: stream tools pydantic 2024-03-11 22:03:05 -07:00
2 changed files with 20 additions and 6 deletions

View File

@@ -135,7 +135,8 @@ class PydanticToolsParser(JsonOutputToolsParser):
f"Tool arguments must be specified as a dict, received: "
f"{res['args']}"
)
pydantic_objects.append(name_dict[res["type"]](**res["args"]))
pydantic_cls = name_dict[res["type"]]
pydantic_objects.append(pydantic_cls.construct(**res["args"]))
except (ValidationError, ValueError) as e:
if partial:
continue

View File

@@ -6,7 +6,7 @@ from langchain_core.output_parsers.openai_tools import (
JsonOutputToolsParser,
PydanticToolsParser,
)
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.pydantic_v1 import BaseModel
STREAMED_MESSAGES: list = [
AIMessageChunk(content=""),
@@ -434,14 +434,27 @@ class Person(BaseModel):
class NameCollector(BaseModel):
"""record names of all people mentioned"""
names: List[str] = Field(..., description="all names mentioned")
person: Person = Field(..., description="info about the main subject")
names: List[str]
person: Person
# Expected to change when we support more granular pydantic streaming.
EXPECTED_STREAMED_PYDANTIC = [
NameCollector.construct(),
NameCollector.construct(names=["suz"]),
NameCollector.construct(names=["suzy"]),
NameCollector.construct(names=["suzy", "jerm"]),
NameCollector.construct(names=["suzy", "jermaine"]),
NameCollector.construct(names=["suzy", "jermaine", "al"]),
NameCollector.construct(names=["suzy", "jermaine", "alex"]),
NameCollector.construct(names=["suzy", "jermaine", "alex"], person={}),
NameCollector.construct(names=["suzy", "jermaine", "alex"], person={"age": 39}),
NameCollector.construct(
names=["suzy", "jermaine", "alex"], person={"age": 39, "hair_color": "br"}
),
NameCollector.construct(
names=["suzy", "jermaine", "alex"], person={"age": 39, "hair_color": "brown"}
),
NameCollector(
names=["suzy", "jermaine", "alex"],
person=Person(age=39, hair_color="brown", job="c"),