mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-22 19:09:57 +00:00
Harrison/move experimental (#8084)
This commit is contained in:
244
libs/experimental/langchain_experimental/cpal/models.py
Normal file
244
libs/experimental/langchain_experimental/cpal/models.py
Normal file
@@ -0,0 +1,244 @@
|
||||
from __future__ import annotations # allows pydantic model to reference itself
|
||||
|
||||
import re
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import duckdb
|
||||
import pandas as pd
|
||||
from langchain.experimental.cpal.constants import Constant
|
||||
from langchain.graphs.networkx_graph import NetworkxEntityGraph
|
||||
from pydantic import BaseModel, Field, PrivateAttr, root_validator, validator
|
||||
|
||||
|
||||
class NarrativeModel(BaseModel):
    """
    Represent the narrative input as three story elements.
    """

    # NOTE: the class docstring above is kept verbatim on purpose; pydantic
    # copies it into the model's JSON-schema ``description``.

    story_outcome_question: str
    story_hypothetical: str
    story_plot: str  # causal stack of operations

    @validator("*", pre=True)
    def empty_str_to_none(cls, v: str) -> Union[str, None]:
        """Replace an empty string with ``None``; pass anything else through.

        Registered for every field (``"*"``) and run before the regular field
        validation (``pre=True``) -- presumably so that the non-optional
        ``str`` fields then reject the ``None``, i.e. empty strings are not
        allowed.
        """
        return None if v == "" else v
|
||||
|
||||
|
||||
class EntityModel(BaseModel):
    # One entity of the causal story: its name, the code describing its
    # action, its current value and the names of the entities it depends on.
    # (Documented with comments rather than a docstring: pydantic would copy
    # a class docstring into the JSON-schema ``description``.)

    name: str = Field(description="entity name")
    code: str = Field(description="entity actions")
    value: float = Field(description="entity initial value")
    depends_on: list[str] = Field(default=[], description="ancestor entities")

    # TODO: generalize to multivariate math
    # TODO: acyclic graph

    class Config:
        # re-run field validation whenever an attribute is assigned
        validate_assignment = True

    @validator("name")
    def lower_case_name(cls, v: str) -> str:
        """Normalise the entity name to lower case."""
        return v.lower()
|
||||
|
||||
|
||||
class CausalModel(BaseModel):
    # The attribute being calculated plus the list of entities taking part in
    # the story. Kept as comments (no class docstring) so the JSON schema
    # pydantic generates for this model is unchanged.

    attribute: str = Field(description="name of the attribute to be calculated")
    entities: list[EntityModel] = Field(description="entities in the story")

    # TODO: root validate each `entity.depends_on` using system's entity names
|
||||
|
||||
|
||||
class EntitySettingModel(BaseModel):
    """
    Initial conditions for an entity

    {"name": "bud", "attribute": "pet_count", "value": 12}
    """

    # NOTE: the class docstring above is kept verbatim on purpose; pydantic
    # copies it into the model's JSON-schema ``description``.

    name: str = Field(description="name of the entity")
    attribute: str = Field(description="name of the attribute to be calculated")
    value: float = Field(description="entity's attribute value (calculated)")

    @validator("name")
    def lower_case_transform(cls, v: str) -> str:
        """Normalise the entity name to lower case."""
        return v.lower()
|
||||
|
||||
|
||||
class SystemSettingModel(BaseModel):
    """
    Initial global conditions for the system.

    {"parameter": "interest_rate", "value": .05}
    """

    parameter: str  # name of the global parameter, e.g. "interest_rate"
    value: float  # value given to that parameter
|
||||
|
||||
|
||||
class InterventionModel(BaseModel):
    """
    aka initial conditions

    >>> intervention.dict()
    {
        entity_settings: [
            {"name": "bud", "attribute": "pet_count", "value": 12},
            {"name": "pat", "attribute": "pet_count", "value": 0},
        ],
        system_settings: None,
    }
    """

    # NOTE: the class docstring above is kept verbatim on purpose; pydantic
    # copies it into the model's JSON-schema ``description``.

    entity_settings: list[EntitySettingModel]
    system_settings: Optional[list[SystemSettingModel]] = None

    # NOTE(review): despite its name this validator does no lower-casing; the
    # name looks copy-pasted from the entity models. It is left unchanged so
    # that anything referring to the method by name keeps working.
    @validator("system_settings")
    def lower_case_name(
        cls, v: Optional[list[SystemSettingModel]]
    ) -> Optional[list[SystemSettingModel]]:
        """Reject any supplied system settings; they are not supported yet.

        Args:
            v: the value given for ``system_settings``.

        Returns:
            ``v`` unchanged when it is ``None``.

        Raises:
            NotImplementedError: if ``v`` is anything other than ``None``.
        """
        if v is not None:
            raise NotImplementedError("system_setting is not implemented yet")
        return v
|
||||
|
||||
|
||||
class QueryModel(BaseModel):
    """translate a question about the story outcome into a programmatic expression"""

    # `question` is populated through the alias taken from
    # `Constant.narrative_input.value`.
    question: str = Field(alias=Constant.narrative_input.value)  # input
    expression: str  # output, part of llm completion
    llm_error_msg: str  # output, part of llm completion
    # NOTE(review): annotated as `str`, but `StoryModel._run_query` stores the
    # `.df()` result of the duckdb query here on success and an error string
    # on failure.
    _result_table: str = PrivateAttr()  # result of the executed query
|
||||
|
||||
|
||||
class ResultModel(BaseModel):
    # Pairs the input question with the table produced by the executed query.
    # (Comment instead of a class docstring so the generated JSON schema is
    # unchanged.)

    # `question` is populated through the alias taken from
    # `Constant.narrative_input.value`.
    question: str = Field(alias=Constant.narrative_input.value)  # input
    _result_table: str = PrivateAttr()  # result of the executed query
|
||||
|
||||
|
||||
class StoryModel(BaseModel):
    # Ties together the causal model, the intervention (initial conditions)
    # and the query, and computes the outcome as soon as it is constructed:
    #   1. entities named in the intervention are cut off from their parents
    #   2. the intervention values are written onto the entities
    #   3. a dependency graph is built and the entities are ordered by it
    #   4. each entity's code is executed in that order
    #   5. the query expression is run against the resulting table
    # (Comments instead of a class docstring so the JSON schema pydantic
    # generates for this model is unchanged.)

    causal_operations: Any = Field(required=True)
    intervention: Any = Field(required=True)
    query: Any = Field(required=True)
    _outcome_table: pd.DataFrame = PrivateAttr(default=None)
    _networkx_wrapper: Any = PrivateAttr(default=None)

    def __init__(self, **kwargs: Any):
        """Validate the fields, then immediately run the whole computation."""
        super().__init__(**kwargs)
        self._compute()

    # TODO: when langchain adopts pydantic.v2 replace w/ `__post_init__`
    # misses hints github.com/pydantic/pydantic/issues/1729#issuecomment-1300576214

    @root_validator
    def check_intervention_is_valid(cls, values: dict) -> dict:
        """Ensure every intervention setting names an entity of the causal model.

        Raises:
            ValueError: if `causal_operations` or `intervention` is missing,
                or if a setting refers to an unknown entity name.
        """
        causal_operations = values.get("causal_operations")
        intervention = values.get("intervention")
        # NOTE(review): `Field(required=True)` on an `Any` field does not
        # appear to make pydantic v1 reject a missing value, so guard here
        # and report a validation error instead of failing with an
        # AttributeError/KeyError on the lookups below.
        if causal_operations is None or intervention is None:
            raise ValueError(
                "`causal_operations` and `intervention` must both be provided"
            )

        valid_names = [e.name for e in causal_operations.entities]
        for setting in intervention.entity_settings:
            if setting.name not in valid_names:
                error_msg = f"""
                    Hypothetical question has an invalid entity name.
                    `{setting.name}` not in `{valid_names}`
                """
                raise ValueError(error_msg)
        return values

    def _block_back_door_paths(self) -> None:
        """Detach intervention entities from their parents and neutralise their code."""
        # stop intervention entities from depending on others
        intervention_entities = {
            entity_setting.name for entity_setting in self.intervention.entity_settings
        }
        for entity in self.causal_operations.entities:
            if entity.name in intervention_entities:
                entity.depends_on = []
                # "pass" is the marker `_forward_propagate` skips
                entity.code = "pass"

    def _set_initial_conditions(self) -> None:
        """Copy each intervention value onto the entity with the same name."""
        for entity_setting in self.intervention.entity_settings:
            for entity in self.causal_operations.entities:
                if entity.name == entity_setting.name:
                    entity.value = entity_setting.value

    def _make_graph(self) -> None:
        """Build the dependency graph and keep only entities that appear in it."""
        self._networkx_wrapper = NetworkxEntityGraph()
        for entity in self.causal_operations.entities:
            for parent_name in entity.depends_on:
                self._networkx_wrapper._graph.add_edge(
                    parent_name, entity.name, relation=entity.code
                )

        # TODO: is it correct to drop entities with no impact on the outcome (?)
        # Sort once and test membership against a set; previously the sort
        # was recomputed for every entity inside the comprehension.
        connected_names = set(self._networkx_wrapper.get_topological_sort())
        self.causal_operations.entities = [
            entity
            for entity in self.causal_operations.entities
            if entity.name in connected_names
        ]

    def _sort_entities(self) -> None:
        """Order the entities in place according to the graph's topological sort."""
        # order the sequence of causal actions
        sorted_nodes = self._networkx_wrapper.get_topological_sort()
        # name -> position lookup, instead of a linear `.index()` per key
        position = {name: index for index, name in enumerate(sorted_nodes)}
        self.causal_operations.entities.sort(key=lambda x: position[x.name])

    def _forward_propagate(self) -> None:
        """Execute each entity's code in order and tabulate the resulting entities."""
        # every entity is reachable by its name from inside the executed code
        entity_scope = {
            entity.name: entity for entity in self.causal_operations.entities
        }
        for entity in self.causal_operations.entities:
            if entity.code == "pass":
                continue
            # SECURITY: `entity.code` is executed as arbitrary Python with this
            # module's globals. NOTE(review): it presumably originates from an
            # LLM completion -- only use with trusted input or in a sandbox.
            # gist.github.com/dean0x7d/df5ce97e4a1a05be4d56d1378726ff92
            exec(entity.code, globals(), entity_scope)
        row_values = [entity.dict() for entity in entity_scope.values()]
        self._outcome_table = pd.DataFrame(row_values)

    def _run_query(self) -> None:
        """Run the query expression with duckdb and store the result on the query.

        On failure the error text is stored in place of the result table.

        Raises:
            ValueError: if the query carries a non-empty `llm_error_msg`.
        """

        def humanize_sql_error_msg(error: str) -> str:
            # Turn duckdb's "column ... not found" message into a story-level
            # hint; any other message is returned unchanged.
            pattern = r"column\s+(.*?)\s+not found"
            col_match = re.search(pattern, error)
            if col_match:
                return (
                    "SQL error: "
                    + col_match.group(1)
                    + " is not an attribute in your story!"
                )
            return str(error)

        if self.query.llm_error_msg != "":
            msg = "LLM maybe failed to translate question to SQL query."
            raise ValueError(
                {
                    "question": self.query.question,
                    "llm_error_msg": self.query.llm_error_msg,
                    "msg": msg,
                }
            )

        try:
            # NOTE: `df` looks unused but must stay a local of this frame:
            # duckdb can resolve a table name in the SQL text to a pandas
            # DataFrame variable of the calling scope (presumably the
            # expression selects from `df`).
            df = self._outcome_table  # noqa
            query_result = duckdb.sql(self.query.expression).df()
            self.query._result_table = query_result
        except duckdb.BinderException as e:
            self.query._result_table = humanize_sql_error_msg(str(e))
        except Exception as e:
            # deliberate best effort: keep the error text as the "result"
            self.query._result_table = str(e)

    def _compute(self) -> None:
        """Run the full pipeline; the order of these steps matters."""
        self._block_back_door_paths()
        self._set_initial_conditions()
        self._make_graph()
        self._sort_entities()
        self._forward_propagate()
        self._run_query()

    def print_debug_report(self) -> None:
        """Pretty-print the outcome table, the query fields and the query result."""
        report = {
            "outcome": self._outcome_table,
            "query": self.query.dict(),
            "result": self.query._result_table,
        }
        from pprint import pprint

        pprint(report)
|
Reference in New Issue
Block a user