mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-03 05:34:01 +00:00
experimental[patch]: Pass enum only to openai in llm graph transformer (#21860)
Some models like Groq return bad request if you pass in `enum` parameter in tool definition
This commit is contained in:
parent
aab9cb666f
commit
a43515ca65
@ -150,16 +150,24 @@ def optional_enum_field(
|
||||
enum_values: Optional[List[str]] = None,
|
||||
description: str = "",
|
||||
input_type: str = "node",
|
||||
llm_type: Optional[str] = None,
|
||||
**field_kwargs: Any,
|
||||
) -> Any:
|
||||
"""Utility function to conditionally create a field with an enum constraint."""
|
||||
if enum_values:
|
||||
# Only openai supports enum param
|
||||
if enum_values and llm_type == "openai-chat":
|
||||
return Field(
|
||||
...,
|
||||
enum=enum_values,
|
||||
description=f"{description}. Available options are {enum_values}",
|
||||
**field_kwargs,
|
||||
)
|
||||
elif enum_values:
|
||||
return Field(
|
||||
...,
|
||||
description=f"{description}. Available options are {enum_values}",
|
||||
**field_kwargs,
|
||||
)
|
||||
else:
|
||||
additional_info = _get_additional_info(input_type)
|
||||
return Field(..., description=description + additional_info, **field_kwargs)
|
||||
@ -271,6 +279,7 @@ def create_simple_model(
|
||||
node_labels: Optional[List[str]] = None,
|
||||
rel_types: Optional[List[str]] = None,
|
||||
node_properties: Union[bool, List[str]] = False,
|
||||
llm_type: Optional[str] = None,
|
||||
) -> Type[_Graph]:
|
||||
"""
|
||||
Simple model allows to limit node and/or relationship types.
|
||||
@ -288,6 +297,7 @@ def create_simple_model(
|
||||
node_labels,
|
||||
description="The type or label of the node.",
|
||||
input_type="node",
|
||||
llm_type=llm_type,
|
||||
),
|
||||
),
|
||||
}
|
||||
@ -325,6 +335,7 @@ def create_simple_model(
|
||||
node_labels,
|
||||
description="The type or label of the source node.",
|
||||
input_type="node",
|
||||
llm_type=llm_type,
|
||||
)
|
||||
target_node_id: str = Field(
|
||||
description="Name or human-readable unique identifier of target node"
|
||||
@ -333,11 +344,13 @@ def create_simple_model(
|
||||
node_labels,
|
||||
description="The type or label of the target node.",
|
||||
input_type="node",
|
||||
llm_type=llm_type,
|
||||
)
|
||||
type: str = optional_enum_field(
|
||||
rel_types,
|
||||
description="The type of the relationship.",
|
||||
input_type="relationship",
|
||||
llm_type=llm_type,
|
||||
)
|
||||
|
||||
class DynamicGraph(_Graph):
|
||||
@ -572,8 +585,12 @@ class LLMGraphTransformer:
|
||||
self.chain = prompt | llm
|
||||
else:
|
||||
# Define chain
|
||||
try:
|
||||
llm_type = llm._llm_type # type: ignore
|
||||
except AttributeError:
|
||||
llm_type = None
|
||||
schema = create_simple_model(
|
||||
allowed_nodes, allowed_relationships, node_properties
|
||||
allowed_nodes, allowed_relationships, node_properties, llm_type
|
||||
)
|
||||
structured_llm = llm.with_structured_output(schema, include_raw=True)
|
||||
prompt = prompt or default_prompt
|
||||
|
Loading…
Reference in New Issue
Block a user