mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-12 14:23:58 +00:00
core: Fix issue 31035 alias fields in base tool langchain core (#31112)
**Description**: The 'inspect' package in python skips over the aliases set in the schema of a pydantic model. This is a workound to include the aliases from the original input. **issue**: #31035 Cc: @ccurme @eyurtsev --------- Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
parent
92af7b0933
commit
1e56c66f86
@ -1077,16 +1077,18 @@ def get_all_basemodel_annotations(
|
||||
"""
|
||||
# cls has no subscript: cls = FooBar
|
||||
if isinstance(cls, type):
|
||||
# Gather pydantic field objects (v2: model_fields / v1: __fields__)
|
||||
fields = getattr(cls, "model_fields", {}) or getattr(cls, "__fields__", {})
|
||||
alias_map = {field.alias: name for name, field in fields.items() if field.alias}
|
||||
|
||||
annotations: dict[str, type] = {}
|
||||
for name, param in inspect.signature(cls).parameters.items():
|
||||
# Exclude hidden init args added by pydantic Config. For example if
|
||||
# BaseModel(extra="allow") then "extra_data" will part of init sig.
|
||||
if (
|
||||
fields := getattr(cls, "model_fields", {}) # pydantic v2+
|
||||
or getattr(cls, "__fields__", {}) # pydantic v1
|
||||
) and name not in fields:
|
||||
if fields and name not in fields and name not in alias_map:
|
||||
continue
|
||||
annotations[name] = param.annotation
|
||||
field_name = alias_map.get(name, name)
|
||||
annotations[field_name] = param.annotation
|
||||
orig_bases: tuple = getattr(cls, "__orig_bases__", ())
|
||||
# cls has subscript: cls = FooBar[int]
|
||||
else:
|
||||
|
@ -2146,6 +2146,15 @@ def test__get_all_basemodel_annotations_v1() -> None:
|
||||
assert actual == expected
|
||||
|
||||
|
||||
def test_get_all_basemodel_annotations_aliases() -> None:
|
||||
class CalculatorInput(BaseModel):
|
||||
a: int = Field(description="first number", alias="A")
|
||||
b: int = Field(description="second number")
|
||||
|
||||
actual = get_all_basemodel_annotations(CalculatorInput)
|
||||
assert actual == {"a": int, "b": int}
|
||||
|
||||
|
||||
def test_tool_annotations_preserved() -> None:
|
||||
"""Test that annotations are preserved when creating a tool."""
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user