Add CI check that integration tests compile (#12090)

This commit is contained in:
Bagatur
2023-10-21 10:52:18 -04:00
committed by GitHub
parent 5dbe456aae
commit 85302a9ec1
10 changed files with 82 additions and 18 deletions

View File

@@ -1,10 +1,8 @@
from __future__ import annotations # allows pydantic model to reference itself
import re
from typing import Any, Optional, Union
from typing import Any, List, Optional, Union
import duckdb
import pandas as pd
from langchain.graphs.networkx_graph import NetworkxEntityGraph
from langchain_experimental.cpal.constants import Constant
@@ -38,7 +36,7 @@ class EntityModel(BaseModel):
name: str = Field(description="entity name")
code: str = Field(description="entity actions")
value: float = Field(description="entity initial value")
depends_on: list[str] = Field(default=[], description="ancestor entities")
depends_on: List[str] = Field(default=[], description="ancestor entities")
# TODO: generalize to multivariate math
# TODO: acyclic graph
@@ -54,7 +52,7 @@ class EntityModel(BaseModel):
class CausalModel(BaseModel):
attribute: str = Field(description="name of the attribute to be calculated")
entities: list[EntityModel] = Field(description="entities in the story")
entities: List[EntityModel] = Field(description="entities in the story")
# TODO: root validate each `entity.depends_on` using system's entity names
@@ -101,8 +99,8 @@ class InterventionModel(BaseModel):
}
"""
entity_settings: list[EntitySettingModel]
system_settings: Optional[list[SystemSettingModel]] = None
entity_settings: List[EntitySettingModel]
system_settings: Optional[List[SystemSettingModel]] = None
@validator("system_settings")
def lower_case_name(cls, v: str) -> Union[str, None]:
@@ -129,7 +127,7 @@ class StoryModel(BaseModel):
causal_operations: Any = Field(required=True)
intervention: Any = Field(required=True)
query: Any = Field(required=True)
_outcome_table: pd.DataFrame = PrivateAttr(default=None)
_outcome_table: Any = PrivateAttr(default=None)
_networkx_wrapper: Any = PrivateAttr(default=None)
def __init__(self, **kwargs: Any):
@@ -190,6 +188,12 @@ class StoryModel(BaseModel):
self.causal_operations.entities.sort(key=lambda x: sorted_nodes.index(x.name))
def _forward_propagate(self) -> None:
try:
import pandas as pd
except ImportError as e:
raise ImportError(
"Unable to import pandas, please install with `pip install pandas`."
) from e
entity_scope = {
entity.name: entity for entity in self.causal_operations.entities
}
@@ -217,11 +221,17 @@ class StoryModel(BaseModel):
if self.query.llm_error_msg == "":
try:
import duckdb
df = self._outcome_table # noqa
query_result = duckdb.sql(self.query.expression).df()
self.query._result_table = query_result
except duckdb.BinderException as e:
self.query._result_table = humanize_sql_error_msg(str(e))
except ImportError as e:
raise ImportError(
"Unable to import duckdb, please install with `pip install duckdb`."
) from e
except Exception as e:
self.query._result_table = str(e)
else: