mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-22 19:09:57 +00:00
Add CI check that integration tests compile (#12090)
This commit is contained in:
@@ -1,10 +1,8 @@
|
||||
from __future__ import annotations # allows pydantic model to reference itself
|
||||
|
||||
import re
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
import duckdb
|
||||
import pandas as pd
|
||||
from langchain.graphs.networkx_graph import NetworkxEntityGraph
|
||||
|
||||
from langchain_experimental.cpal.constants import Constant
|
||||
@@ -38,7 +36,7 @@ class EntityModel(BaseModel):
|
||||
name: str = Field(description="entity name")
|
||||
code: str = Field(description="entity actions")
|
||||
value: float = Field(description="entity initial value")
|
||||
depends_on: list[str] = Field(default=[], description="ancestor entities")
|
||||
depends_on: List[str] = Field(default=[], description="ancestor entities")
|
||||
|
||||
# TODO: generalize to multivariate math
|
||||
# TODO: acyclic graph
|
||||
@@ -54,7 +52,7 @@ class EntityModel(BaseModel):
|
||||
|
||||
class CausalModel(BaseModel):
|
||||
attribute: str = Field(description="name of the attribute to be calculated")
|
||||
entities: list[EntityModel] = Field(description="entities in the story")
|
||||
entities: List[EntityModel] = Field(description="entities in the story")
|
||||
|
||||
# TODO: root validate each `entity.depends_on` using system's entity names
|
||||
|
||||
@@ -101,8 +99,8 @@ class InterventionModel(BaseModel):
|
||||
}
|
||||
"""
|
||||
|
||||
entity_settings: list[EntitySettingModel]
|
||||
system_settings: Optional[list[SystemSettingModel]] = None
|
||||
entity_settings: List[EntitySettingModel]
|
||||
system_settings: Optional[List[SystemSettingModel]] = None
|
||||
|
||||
@validator("system_settings")
|
||||
def lower_case_name(cls, v: str) -> Union[str, None]:
|
||||
@@ -129,7 +127,7 @@ class StoryModel(BaseModel):
|
||||
causal_operations: Any = Field(required=True)
|
||||
intervention: Any = Field(required=True)
|
||||
query: Any = Field(required=True)
|
||||
_outcome_table: pd.DataFrame = PrivateAttr(default=None)
|
||||
_outcome_table: Any = PrivateAttr(default=None)
|
||||
_networkx_wrapper: Any = PrivateAttr(default=None)
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
@@ -190,6 +188,12 @@ class StoryModel(BaseModel):
|
||||
self.causal_operations.entities.sort(key=lambda x: sorted_nodes.index(x.name))
|
||||
|
||||
def _forward_propagate(self) -> None:
|
||||
try:
|
||||
import pandas as pd
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Unable to import pandas, please install with `pip install pandas`."
|
||||
) from e
|
||||
entity_scope = {
|
||||
entity.name: entity for entity in self.causal_operations.entities
|
||||
}
|
||||
@@ -217,11 +221,17 @@ class StoryModel(BaseModel):
|
||||
|
||||
if self.query.llm_error_msg == "":
|
||||
try:
|
||||
import duckdb
|
||||
|
||||
df = self._outcome_table # noqa
|
||||
query_result = duckdb.sql(self.query.expression).df()
|
||||
self.query._result_table = query_result
|
||||
except duckdb.BinderException as e:
|
||||
self.query._result_table = humanize_sql_error_msg(str(e))
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Unable to import duckdb, please install with `pip install duckdb`."
|
||||
) from e
|
||||
except Exception as e:
|
||||
self.query._result_table = str(e)
|
||||
else:
|
||||
|
Reference in New Issue
Block a user