diff --git a/.dockerignore b/.dockerignore index 823dcbd59..7e4266596 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,3 +1,6 @@ +.env +.git/ +./.mypy_cache/ models/ plugins/ pilot/data @@ -5,6 +8,8 @@ pilot/message logs/ venv/ web/node_modules/ +web/.next/ +web/.env docs/node_modules/ build/ docs/build/ diff --git a/dbgpt/_private/config.py b/dbgpt/_private/config.py index cd7936a60..eaf1e953c 100644 --- a/dbgpt/_private/config.py +++ b/dbgpt/_private/config.py @@ -332,6 +332,10 @@ class Config(metaclass=Singleton): os.getenv("MULTI_INSTANCE", "False").lower() == "true" ) + self.SCHEDULER_ENABLED = ( + os.getenv("SCHEDULER_ENABLED", "True").lower() == "true" + ) + @property def local_db_manager(self) -> "ConnectorManager": from dbgpt.datasource.manages import ConnectorManager diff --git a/dbgpt/app/initialization/scheduler.py b/dbgpt/app/initialization/scheduler.py index 70b7bb71a..36a3107db 100644 --- a/dbgpt/app/initialization/scheduler.py +++ b/dbgpt/app/initialization/scheduler.py @@ -19,12 +19,14 @@ class DefaultScheduler(BaseComponent): system_app: SystemApp, scheduler_delay_ms: int = 5000, scheduler_interval_ms: int = 1000, + scheduler_enable: bool = True, ): super().__init__(system_app) self.system_app = system_app self._scheduler_interval_ms = scheduler_interval_ms self._scheduler_delay_ms = scheduler_delay_ms self._stop_event = threading.Event() + self._scheduler_enable = scheduler_enable def init_app(self, system_app: SystemApp): self.system_app = system_app @@ -39,7 +41,7 @@ class DefaultScheduler(BaseComponent): def _scheduler(self): time.sleep(self._scheduler_delay_ms / 1000) - while not self._stop_event.is_set(): + while self._scheduler_enable and not self._stop_event.is_set(): try: schedule.run_pending() except Exception as e: diff --git a/dbgpt/core/awel/dag/base.py b/dbgpt/core/awel/dag/base.py index 2f3521d24..2c12b3bad 100644 --- a/dbgpt/core/awel/dag/base.py +++ b/dbgpt/core/awel/dag/base.py @@ -145,6 +145,9 @@ class DAGVar: _executor: Optional[Executor] = None _variables_provider: Optional["VariablesProvider"] = None + # Whether check serializable for AWEL, it will be set to True when running AWEL + # operator in remote environment + _check_serializable: Optional[bool] = None @classmethod def enter_dag(cls, dag) -> None: @@ -257,6 +260,24 @@ class DAGVar: """ cls._variables_provider = variables_provider + @classmethod + def get_check_serializable(cls) -> Optional[bool]: + """Get the check serializable flag. + + Returns: + Optional[bool]: The check serializable flag + """ + return cls._check_serializable + + @classmethod + def set_check_serializable(cls, check_serializable: bool) -> None: + """Set the check serializable flag. + + Args: + check_serializable (bool): The check serializable flag to set + """ + cls._check_serializable = check_serializable + class DAGLifecycle: """The lifecycle of DAG.""" @@ -286,6 +307,7 @@ class DAGNode(DAGLifecycle, DependencyMixin, ViewMixin, ABC): node_name: Optional[str] = None, system_app: Optional[SystemApp] = None, executor: Optional[Executor] = None, + check_serializable: Optional[bool] = None, **kwargs, ) -> None: """Initialize a DAGNode. @@ -311,6 +333,7 @@ class DAGNode(DAGLifecycle, DependencyMixin, ViewMixin, ABC): node_id = self._dag._new_node_id() self._node_id: Optional[str] = node_id self._node_name: Optional[str] = node_name + self._check_serializable = check_serializable if self._dag: self._dag._append_node(self) @@ -486,6 +509,20 @@ class DAGNode(DAGLifecycle, DependencyMixin, ViewMixin, ABC): """Return the string of current DAGNode.""" return self.__repr__() + @classmethod + def _do_check_serializable(cls, obj: Any, obj_name: str = "Object"): + """Check whether the current DAGNode is serializable.""" + from dbgpt.util.serialization.check import check_serializable + + check_serializable(obj, obj_name) + + @property + def check_serializable(self) -> bool: + """Whether check serializable for current DAGNode.""" + if self._check_serializable is not None: + return self._check_serializable or False + return DAGVar.get_check_serializable() or False + def _build_task_key(task_name: str, key: str) -> str: return f"{task_name}___$$$$$$___{key}" diff --git a/dbgpt/core/awel/operators/base.py b/dbgpt/core/awel/operators/base.py index da82d2856..7c66c0adc 100644 --- a/dbgpt/core/awel/operators/base.py +++ b/dbgpt/core/awel/operators/base.py @@ -193,12 +193,29 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta): self.incremental_output = bool(kwargs["incremental_output"]) if "output_format" in kwargs: self.output_format = kwargs["output_format"] - self._runner: WorkflowRunner = runner self._dag_ctx: Optional[DAGContext] = None self._can_skip_in_branch = can_skip_in_branch self._variables_provider = variables_provider + def __getstate__(self): + """Customize the pickling process.""" + state = self.__dict__.copy() + if "_runner" in state: + del state["_runner"] + if "_executor" in state: + del state["_executor"] + if "_system_app" in state: + del state["_system_app"] + return state + + def __setstate__(self, state): + """Customize the unpickling process.""" + self.__dict__.update(state) + self._runner = default_runner + self._system_app = DAGVar.get_current_system_app() + self._executor = DAGVar.get_executor() + @property def current_dag_context(self) -> DAGContext: """Return the current DAG context.""" diff --git a/dbgpt/core/awel/operators/common_operator.py b/dbgpt/core/awel/operators/common_operator.py index f8bc25370..763992323 100644 --- a/dbgpt/core/awel/operators/common_operator.py +++ b/dbgpt/core/awel/operators/common_operator.py @@ -41,6 +41,12 @@ class JoinOperator(BaseOperator, Generic[OUT]): super().__init__(can_skip_in_branch=can_skip_in_branch, **kwargs) if not callable(combine_function): raise ValueError("combine_function must be callable") + + if self.check_serializable: + super()._do_check_serializable( + combine_function, + f"JoinOperator: {self}, combine_function: {combine_function}", + ) self.combine_function = combine_function async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]: @@ -83,6 +89,11 @@ class ReduceStreamOperator(BaseOperator, Generic[IN, OUT]): super().__init__(**kwargs) if reduce_function and not callable(reduce_function): raise ValueError("reduce_function must be callable") + if reduce_function and self.check_serializable: + super()._do_check_serializable( + reduce_function, f"Operator: {self}, reduce_function: {reduce_function}" + ) + self.reduce_function = reduce_function async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]: @@ -133,6 +144,12 @@ class MapOperator(BaseOperator, Generic[IN, OUT]): super().__init__(**kwargs) if map_function and not callable(map_function): raise ValueError("map_function must be callable") + + if map_function and self.check_serializable: + super()._do_check_serializable( + map_function, f"Operator: {self}, map_function: {map_function}" + ) + self.map_function = map_function async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]: diff --git a/dbgpt/model/proxy/base.py b/dbgpt/model/proxy/base.py index 129dcf11e..2a1a3b6b8 100644 --- a/dbgpt/model/proxy/base.py +++ b/dbgpt/model/proxy/base.py @@ -94,6 +94,17 @@ class ProxyLLMClient(LLMClient): self.executor = executor or ThreadPoolExecutor() self.proxy_tokenizer = proxy_tokenizer or TiktokenProxyTokenizer() + def __getstate__(self): + """Customize the serialization of the object""" + state = self.__dict__.copy() + state.pop("executor") + return state + + def __setstate__(self, state): + """Customize the deserialization of the object""" + self.__dict__.update(state) + self.executor = ThreadPoolExecutor() + @classmethod @abstractmethod def new_client( diff --git a/dbgpt/serve/flow/api/variables_provider.py b/dbgpt/serve/flow/api/variables_provider.py index 27ed63bf5..10bb8a656 100644 --- a/dbgpt/serve/flow/api/variables_provider.py +++ b/dbgpt/serve/flow/api/variables_provider.py @@ -341,7 +341,7 @@ class BuiltinAgentsVariablesProvider(BuiltinVariablesProvider): StorageVariables( key=key, name=agent["name"], - label=agent["desc"], + label=agent["name"], value=agent["name"], scope=scope, scope_key=scope_key, diff --git a/dbgpt/storage/metadata/_base_dao.py b/dbgpt/storage/metadata/_base_dao.py index 04b74e81c..199b3eabc 100644 --- a/dbgpt/storage/metadata/_base_dao.py +++ b/dbgpt/storage/metadata/_base_dao.py @@ -285,6 +285,9 @@ class BaseDao(Generic[T, REQ, RES]): else model_to_dict(query_request) ) for key, value in query_dict.items(): + if value and isinstance(value, (list, tuple, dict, set)): + # Skip the list, tuple, dict, set + continue if value is not None and hasattr(model_cls, key): if isinstance(value, list): if len(value) > 0: diff --git a/dbgpt/util/net_utils.py b/dbgpt/util/net_utils.py index fc9fb3f86..ce41ba781 100644 --- a/dbgpt/util/net_utils.py +++ b/dbgpt/util/net_utils.py @@ -1,5 +1,6 @@ import errno import socket +from typing import Set, Tuple def _get_ip_address(address: str = "10.254.254.254:1") -> str: @@ -22,3 +23,34 @@ def _get_ip_address(address: str = "10.254.254.254:1") -> str: finally: s.close() return curr_address + + +async def _async_get_free_port( + port_range: Tuple[int, int], timeout: int, used_ports: Set[int] +): + import asyncio + + loop = asyncio.get_running_loop() + return await loop.run_in_executor( + None, _get_free_port, port_range, timeout, used_ports + ) + + +def _get_free_port(port_range: Tuple[int, int], timeout: int, used_ports: Set[int]): + import random + + available_ports = set(range(port_range[0], port_range[1] + 1)) - used_ports + if not available_ports: + raise RuntimeError("No available ports in the specified range") + + while available_ports: + port = random.choice(list(available_ports)) + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", port)) + used_ports.add(port) + return port + except OSError: + available_ports.remove(port) + + raise RuntimeError("No available ports in the specified range") diff --git a/dbgpt/util/serialization/check.py b/dbgpt/util/serialization/check.py new file mode 100644 index 000000000..10a86edb2 --- /dev/null +++ b/dbgpt/util/serialization/check.py @@ -0,0 +1,85 @@ +import inspect +from io import StringIO +from typing import Any, Dict, Optional, TextIO + +import cloudpickle + + +def check_serializable( + obj: Any, obj_name: str = "Object", error_msg: str = "Object is not serializable" +): + try: + cloudpickle.dumps(obj) + except Exception as e: + inspect_info = inspect_serializability(obj, obj_name) + msg = f"{error_msg}\n{inspect_info['report']}" + raise TypeError(msg) from e + + +class SerializabilityInspector: + def __init__(self, stream: Optional[TextIO] = None): + self.stream = stream or StringIO() + self.failures = {} + self.indent_level = 0 + + def log(self, message: str): + indent = " " * self.indent_level + self.stream.write(f"{indent}{message}\n") + + def inspect(self, obj: Any, name: str, depth: int = 3) -> bool: + self.log(f"Inspecting '{name}'") + self.indent_level += 1 + + try: + cloudpickle.dumps(obj) + self.indent_level -= 1 + return True + except Exception as e: + self.failures[name] = str(e) + self.log(f"Failure: {str(e)}") + + if depth > 0: + if inspect.isfunction(obj) or inspect.ismethod(obj): + self._inspect_function(obj, depth - 1) + elif hasattr(obj, "__dict__"): + self._inspect_object(obj, depth - 1) + + self.indent_level -= 1 + return False + + def _inspect_function(self, func, depth): + closure = inspect.getclosurevars(func) + for name, value in closure.nonlocals.items(): + self.inspect(value, f"{func.__name__}.{name}", depth) + for name, value in closure.globals.items(): + self.inspect(value, f"global:{name}", depth) + + def _inspect_object(self, obj, depth): + for name, value in inspect.getmembers(obj): + if not name.startswith("__"): + self.inspect(value, f"{type(obj).__name__}.{name}", depth) + + def get_report(self) -> str: + summary = "\nSummary of Serialization Failures:\n" + if not self.failures: + summary += "All components are serializable.\n" + else: + for name, error in self.failures.items(): + summary += f" - {name}: {error}\n" + + return self.stream.getvalue() + summary + + +def inspect_serializability( + obj: Any, + name: Optional[str] = None, + depth: int = 5, + stream: Optional[TextIO] = None, +) -> Dict[str, Any]: + inspector = SerializabilityInspector(stream) + success = inspector.inspect(obj, name or type(obj).__name__, depth) + return { + "success": success, + "failures": inspector.failures, + "report": inspector.get_report(), + } diff --git a/docker/base/build_image.sh b/docker/base/build_image.sh index 028dcc809..08cd0b549 100755 --- a/docker/base/build_image.sh +++ b/docker/base/build_image.sh @@ -20,16 +20,21 @@ LOAD_EXAMPLES="true" BUILD_NETWORK="" DB_GPT_INSTALL_MODEL="default" +DOCKERFILE="Dockerfile" +IMAGE_NAME_SUFFIX="" + usage () { echo "USAGE: $0 [--base-image nvidia/cuda:12.1.0-runtime-ubuntu22.04] [--image-name db-gpt]" echo " [-b|--base-image base image name] Base image name" echo " [-n|--image-name image name] Current image name, default: db-gpt" + echo " [--image-name-suffix image name suffix] Image name suffix" echo " [-i|--pip-index-url pip index url] Pip index url, default: https://pypi.org/simple" echo " [--language en or zh] You language, default: en" echo " [--build-local-code true or false] Whether to use the local project code to package the image, default: true" echo " [--load-examples true or false] Whether to load examples to default database default: true" echo " [--network network name] The network of docker build" echo " [--install-mode mode name] Installation mode name, default: default, If you completely use openai's service, you can set the mode name to 'openai'" + echo " [-f|--dockerfile dockerfile] Dockerfile name, default: Dockerfile" echo " [-h|--help] Usage message" } @@ -46,6 +51,11 @@ while [[ $# -gt 0 ]]; do shift # past argument shift # past value ;; + --image-name-suffix) + IMAGE_NAME_SUFFIX="$2" + shift # past argument + shift # past value + ;; -i|--pip-index-url) PIP_INDEX_URL="$2" shift @@ -80,6 +90,11 @@ while [[ $# -gt 0 ]]; do shift # past argument shift # past value ;; + -f|--dockerfile) + DOCKERFILE="$2" + shift # past argument + shift # past value + ;; -h|--help) help="true" shift @@ -111,6 +126,10 @@ else BASE_IMAGE=$IMAGE_NAME_ARGS fi +if [ -n "$IMAGE_NAME_SUFFIX" ]; then + IMAGE_NAME="$IMAGE_NAME-$IMAGE_NAME_SUFFIX" +fi + echo "Begin build docker image, base image: ${BASE_IMAGE}, target image name: ${IMAGE_NAME}" docker build $BUILD_NETWORK \ @@ -120,5 +139,5 @@ docker build $BUILD_NETWORK \ --build-arg BUILD_LOCAL_CODE=$BUILD_LOCAL_CODE \ --build-arg LOAD_EXAMPLES=$LOAD_EXAMPLES \ --build-arg DB_GPT_INSTALL_MODEL=$DB_GPT_INSTALL_MODEL \ - -f Dockerfile \ + -f $DOCKERFILE \ -t $IMAGE_NAME $WORK_DIR/../../ diff --git a/web/client/api/flow/index.ts b/web/client/api/flow/index.ts index 25bdd98be..2d5d2d1dd 100644 --- a/web/client/api/flow/index.ts +++ b/web/client/api/flow/index.ts @@ -6,6 +6,10 @@ import { IFlowRefreshParams, IFlowResponse, IFlowUpdateParam, + IGetKeysRequestParams, + IGetKeysResponseData, + IGetVariablesByKeyRequestParams, + IGetVariablesByKeyResponseData, IUploadFileRequestParams, IUploadFileResponse, } from '@/types/flow'; @@ -63,17 +67,22 @@ export const downloadFile = (fileId: string) => { return GET(`/api/v2/serve/file/files/dbgpt/${fileId}`); }; -// TODO:wait for interface update -export const getFlowTemplateList = () => { - return GET>('/api/v2/serve/awel/flow/templates'); -}; - export const getFlowTemplateById = (id: string) => { return GET(`/api/v2/serve/awel/flow/templates/${id}`); }; + export const getFlowTemplates = () => { return GET(`/api/v2/serve/awel/flow/templates`); }; + +export const getKeys = (data?: IGetKeysRequestParams) => { + return GET>('/api/v2/serve/awel/variables/keys', data); +}; + +export const getVariablesByKey = (data: IGetVariablesByKeyRequestParams) => { + return GET('/api/v2/serve/awel/variables', data); +}; + export const metadataBatch = (data: IUploadFileRequestParams) => { return POST>('/api/v2/serve/file/files/metadata/batch', data); -}; \ No newline at end of file +}; diff --git a/web/components/flow/add-nodes-sider.tsx b/web/components/flow/add-nodes-sider.tsx index d71a12707..2db8f45c5 100644 --- a/web/components/flow/add-nodes-sider.tsx +++ b/web/components/flow/add-nodes-sider.tsx @@ -4,7 +4,8 @@ import { IFlowNode } from '@/types/flow'; import { FLOW_NODES_KEY } from '@/utils'; import { CaretLeftOutlined, CaretRightOutlined } from '@ant-design/icons'; import type { CollapseProps } from 'antd'; -import { Badge, Collapse, Input, Layout, Space, Tag } from 'antd'; +import { Badge, Collapse, Input, Layout, Space, Switch } from 'antd'; +import classnames from 'classnames'; import React, { useContext, useEffect, useMemo, useState } from 'react'; import { useTranslation } from 'react-i18next'; import StaticNodes from './static-nodes'; @@ -199,9 +200,13 @@ const AddNodesSider: React.FC = () => { {t('add_node')}

- - {isAllNodesVisible ? t('All_Nodes') : t('Higher_Order_Nodes')} - + diff --git a/web/components/flow/canvas-modal/add-flow-variable-modal.tsx b/web/components/flow/canvas-modal/add-flow-variable-modal.tsx index 961ec76f7..0c76ff19b 100644 --- a/web/components/flow/canvas-modal/add-flow-variable-modal.tsx +++ b/web/components/flow/canvas-modal/add-flow-variable-modal.tsx @@ -1,70 +1,58 @@ -// import { IFlowNode } from '@/types/flow'; +import { apiInterceptors, getKeys, getVariablesByKey } from '@/client/api'; +import { IFlowUpdateParam, IGetKeysResponseData, IVariableItem } from '@/types/flow'; +import { buildVariableString } from '@/utils/flow'; import { MinusCircleOutlined, PlusOutlined } from '@ant-design/icons'; -import { Button, Form, Input, Modal, Select, Space } from 'antd'; -import React, { useState } from 'react'; +import { Button, Cascader, Form, Input, InputNumber, Modal, Select, Space } from 'antd'; +import { DefaultOptionType } from 'antd/es/cascader'; +import { uniqBy } from 'lodash'; +import React, { useEffect, useState } from 'react'; import { useTranslation } from 'react-i18next'; -// ype GroupType = { category: string; categoryLabel: string; nodes: IFlowNode[] }; -type ValueType = 'str' | 'int' | 'float' | 'bool' | 'ref'; - const { Option } = Select; +const VALUE_TYPES = ['str', 'int', 'float', 'bool', 'ref'] as const; -const DAG_PARAM_KEY = 'dbgpt.core.flow.params'; -const DAG_PARAM_SCOPE = 'flow_priv'; +type ValueType = (typeof VALUE_TYPES)[number]; +type Props = { + flowInfo?: IFlowUpdateParam; + setFlowInfo: React.Dispatch>; +}; -export const AddFlowVariableModal: React.FC = () => { +export const AddFlowVariableModal: React.FC = ({ flowInfo, setFlowInfo }) => { const { t } = useTranslation(); - // const [operators, setOperators] = useState>([]); - // const [resources, setResources] = useState>([]); - // const [operatorsGroup, setOperatorsGroup] = useState([]); - // const [resourcesGroup, setResourcesGroup] = useState([]); const [isModalOpen, setIsModalOpen] = useState(false); - const [form] = Form.useForm(); // const [form] = Form.useForm(); + const [form] = Form.useForm(); + const [controlTypes, setControlTypes] = useState(['str']); + const [refVariableOptions, setRefVariableOptions] = useState([]); - const showModal = () => { - setIsModalOpen(true); + useEffect(() => { + getKeysData(); + }, []); + + const getKeysData = async () => { + const [err, res] = await apiInterceptors(getKeys()); + + if (err) return; + + const keyOptions = res?.map(({ key, label, scope }: IGetKeysResponseData) => ({ + value: key, + label, + scope, + isLeaf: false, + })); + + setRefVariableOptions(keyOptions); }; - // TODO: get keys - // useEffect(() => { - // getNodes(); - // }, []); - - // async function getNodes() { - // const [_, data] = await apiInterceptors(getFlowNodes()); - // if (data && data.length > 0) { - // localStorage.setItem(FLOW_NODES_KEY, JSON.stringify(data)); - // const operatorNodes = data.filter(node => node.flow_type === 'operator'); - // const resourceNodes = data.filter(node => node.flow_type === 'resource'); - // setOperators(operatorNodes); - // setResources(resourceNodes); - // setOperatorsGroup(groupNodes(operatorNodes)); - // setResourcesGroup(groupNodes(resourceNodes)); - // } - // } - - // function groupNodes(data: IFlowNode[]) { - // const groups: GroupType[] = []; - // const categoryMap: Record = {}; - // data.forEach(item => { - // const { category, category_label } = item; - // if (!categoryMap[category]) { - // categoryMap[category] = { category, categoryLabel: category_label, nodes: [] }; - // groups.push(categoryMap[category]); - // } - // categoryMap[category].nodes.push(item); - // }); - // return groups; - // } - const onFinish = (values: any) => { - console.log('Received values of form:', values); + const newFlowInfo = { ...flowInfo, variables: values?.parameters || [] } as IFlowUpdateParam; + setFlowInfo(newFlowInfo); + setIsModalOpen(false); }; - function onNameChange(e: React.ChangeEvent, index: number) { + const onNameChange = (e: React.ChangeEvent, index: number) => { const name = e.target.value; - const result = name + const newValue = name ?.split('_') ?.map(word => word.charAt(0).toUpperCase() + word.slice(1)) ?.join(' '); @@ -72,43 +60,105 @@ export const AddFlowVariableModal: React.FC = () => { form.setFields([ { name: ['parameters', index, 'label'], - value: result, + value: newValue, }, ]); + }; - // change value to ref - const type = form.getFieldValue(['parameters', index, 'value_type']); + const onValueTypeChange = (type: ValueType, index: number) => { + const newControlTypes = [...controlTypes]; + newControlTypes[index] = type; + setControlTypes(newControlTypes); + }; - if (type === 'ref') { - const parameters = form.getFieldValue('parameters'); - const param = parameters?.[index]; + const loadData = (selectedOptions: DefaultOptionType[]) => { + const targetOption = selectedOptions[selectedOptions.length - 1]; + const { value, scope } = targetOption as DefaultOptionType & { scope: string }; - if (param) { - const { name = '' } = param; - param.value = `${DAG_PARAM_KEY}:${name}@scope:${DAG_PARAM_SCOPE}`; + setTimeout(async () => { + const [err, res] = await apiInterceptors(getVariablesByKey({ key: value as string, scope })); - form.setFieldsValue({ - parameters: [...parameters], - }); + if (err) return; + if (res?.total_count === 0) { + targetOption.isLeaf = true; + return; } + + const uniqueItems = uniqBy(res?.items, 'name'); + targetOption.children = uniqueItems?.map(item => ({ + value: item?.name, + label: item.label, + item: item, + })); + setRefVariableOptions([...refVariableOptions]); + }, 1000); + }; + + const onRefTypeValueChange = ( + value: (string | number | null)[], + selectedOptions: DefaultOptionType[], + index: number, + ) => { + // when select ref variable, must be select two options(key and variable) + if (value?.length !== 2) return; + + const [selectRefKey, selectedRefVariable] = selectedOptions as DefaultOptionType[]; + const selectedVariable = selectRefKey?.children?.find( + ({ value }) => value === selectedRefVariable?.value, + ) as DefaultOptionType & { item: IVariableItem }; + + // build variable string by rule + const variableStr = buildVariableString(selectedVariable?.item); + const parameters = form.getFieldValue('parameters'); + const param = parameters?.[index]; + if (param) { + param.value = variableStr; + param.category = selectedVariable?.item?.category; + param.value_type = selectedVariable?.item?.value_type; + + form.setFieldsValue({ + parameters: [...parameters], + }); } - } + }; - function onValueTypeChange(type: ValueType, index: number) { - if (type === 'ref') { - const parameters = form.getFieldValue('parameters'); - const param = parameters?.[index]; - - if (param) { - const { name = '' } = param; - param.value = `${DAG_PARAM_KEY}:${name}@scope:${DAG_PARAM_SCOPE}`; - - form.setFieldsValue({ - parameters: [...parameters], - }); - } + // Helper function to render the appropriate control component + const renderVariableValue = (type: string, index: number) => { + switch (type) { + case 'ref': + return ( + onRefTypeValueChange(value, selectedOptions, index)} + changeOnSelect + /> + ); + case 'str': + return ; + case 'int': + return ( + value?.replace(/[^\-?\d]/g, '') || 0} + style={{ width: '100%' }} + /> + ); + case 'float': + return ; + case 'bool': + return ( + + ); + default: + return ; } - } + }; return ( <> @@ -117,24 +167,32 @@ export const AddFlowVariableModal: React.FC = () => { className='flex items-center justify-center rounded-full left-4 top-4' style={{ zIndex: 1050 }} icon={} - onClick={showModal} + onClick={() => setIsModalOpen(true)} /> setIsModalOpen(false)} + onCancel={() => setIsModalOpen(false)} + footer={[ + , + , + ]} >
{ autoComplete='off' layout='vertical' className='mt-8' - initialValues={{ parameters: [{}] }} + initialValues={{ parameters: flowInfo?.variables || [{}] }} > {(fields, { add, remove }) => ( <> {fields.map(({ key, name, ...restField }, index) => ( - + { rules={[{ required: true, message: 'Missing parameter type' }]} > + {renderVariableValue(controlTypes[index], index)} @@ -207,6 +265,10 @@ export const AddFlowVariableModal: React.FC = () => { remove(name)} /> + + ))} @@ -218,15 +280,6 @@ export const AddFlowVariableModal: React.FC = () => { )} - - - - - - -
diff --git a/web/components/flow/canvas-modal/export-flow-modal.tsx b/web/components/flow/canvas-modal/export-flow-modal.tsx index 0d056abac..21f60efbe 100644 --- a/web/components/flow/canvas-modal/export-flow-modal.tsx +++ b/web/components/flow/canvas-modal/export-flow-modal.tsx @@ -1,5 +1,5 @@ import { IFlowData, IFlowUpdateParam } from '@/types/flow'; -import { Button, Form, Input, Modal, Radio, Space, message } from 'antd'; +import { Button, Form, Input, Modal, Radio, message } from 'antd'; import { useTranslation } from 'react-i18next'; import { ReactFlowInstance } from 'reactflow'; @@ -46,8 +46,14 @@ export const ExportFlowModal: React.FC = ({ title={t('Export_Flow')} open={isExportFlowModalOpen} onCancel={() => setIsExportFlowModalOpen(false)} - cancelButtonProps={{ className: 'hidden' }} - okButtonProps={{ className: 'hidden' }} + footer={[ + , + , + ]} >
= ({ - - - - - - -
diff --git a/web/components/flow/canvas-modal/import-flow-modal.tsx b/web/components/flow/canvas-modal/import-flow-modal.tsx index bb6e78f18..3711b78ed 100644 --- a/web/components/flow/canvas-modal/import-flow-modal.tsx +++ b/web/components/flow/canvas-modal/import-flow-modal.tsx @@ -1,11 +1,11 @@ import { apiInterceptors, importFlow } from '@/client/api'; +import CanvasWrapper from '@/pages/construct/flow/canvas/index'; import { UploadOutlined } from '@ant-design/icons'; -import { Button, Form, GetProp, Modal, Radio, Space, Upload, UploadFile, UploadProps, message } from 'antd'; +import { Button, Form, GetProp, Modal, Radio, Upload, UploadFile, UploadProps, message } from 'antd'; import { useEffect, useState } from 'react'; import { useTranslation } from 'react-i18next'; import { Edge, Node } from 'reactflow'; -import CanvasWrapper from '@/pages/construct/flow/canvas/index'; type Props = { isImportModalOpen: boolean; setNodes: React.Dispatch[]>>; @@ -40,10 +40,9 @@ export const ImportFlowModal: React.FC = ({ isImportModalOpen, setIsImpor if (res?.success) { messageApi.success(t('Import_Flow_Success')); localStorage.setItem('importFlowData', JSON.stringify(res?.data)); - CanvasWrapper() + CanvasWrapper(); } else if (res?.err_msg) { messageApi.error(res?.err_msg); - } setIsImportFlowModalOpen(false); }; @@ -68,8 +67,14 @@ export const ImportFlowModal: React.FC = ({ isImportModalOpen, setIsImpor title={t('Import_Flow')} open={isImportModalOpen} onCancel={() => setIsImportFlowModalOpen(false)} - cancelButtonProps={{ className: 'hidden' }} - okButtonProps={{ className: 'hidden' }} + footer={[ + , + , + ]} >
= ({ isImportModalOpen, setIsImpor - + - - - - - - - diff --git a/web/components/flow/canvas-modal/save-flow-modal.tsx b/web/components/flow/canvas-modal/save-flow-modal.tsx index 708f6f768..0434af90f 100644 --- a/web/components/flow/canvas-modal/save-flow-modal.tsx +++ b/web/components/flow/canvas-modal/save-flow-modal.tsx @@ -1,7 +1,7 @@ import { addFlow, apiInterceptors, updateFlowById } from '@/client/api'; import { IFlowData, IFlowUpdateParam } from '@/types/flow'; import { mapHumpToUnderline } from '@/utils/flow'; -import { Button, Checkbox, Form, Input, Modal, Space, message } from 'antd'; +import { Button, Checkbox, Form, Input, Modal, message } from 'antd'; import { useRouter } from 'next/router'; import { useEffect, useState } from 'react'; import { useTranslation } from 'react-i18next'; @@ -58,6 +58,7 @@ export const SaveFlowModal: React.FC = ({ uid: id.toString(), flow_data: reactFlowObject, state, + variables: flowInfo?.variables, }), ); @@ -75,6 +76,7 @@ export const SaveFlowModal: React.FC = ({ editable, flow_data: reactFlowObject, state, + variables: flowInfo?.variables, }), ); @@ -91,11 +93,15 @@ export const SaveFlowModal: React.FC = ({ { - setIsSaveFlowModalOpen(false); - }} - cancelButtonProps={{ className: 'hidden' }} - okButtonProps={{ className: 'hidden' }} + onCancel={() => setIsSaveFlowModalOpen(false)} + footer={[ + , + , + ]} >
= ({ }} /> - - - - - - -
diff --git a/web/locales/en/flow.ts b/web/locales/en/flow.ts index b24879616..45e90e21f 100644 --- a/web/locales/en/flow.ts +++ b/web/locales/en/flow.ts @@ -23,6 +23,6 @@ export const FlowEn = { Please_Add_Nodes_First: 'Please add nodes first', Add_Global_Variable_of_Flow: 'Add global variable of flow', Add_Parameter: 'Add Parameter', - Higher_Order_Nodes: 'Higher Order Nodes', - All_Nodes: 'All Nodes', + Higher_Order_Nodes: 'Higher Order', + All_Nodes: 'All', }; diff --git a/web/locales/zh/flow.ts b/web/locales/zh/flow.ts index f6eb22869..9238b6221 100644 --- a/web/locales/zh/flow.ts +++ b/web/locales/zh/flow.ts @@ -23,6 +23,6 @@ export const FlowZn = { Please_Add_Nodes_First: '请先添加节点', Add_Global_Variable_of_Flow: '添加 Flow 全局变量', Add_Parameter: '添加参数', - Higher_Order_Nodes: '高阶节点', - All_Nodes: '所有节点', + Higher_Order_Nodes: '高阶', + All_Nodes: '所有', }; diff --git a/web/pages/construct/app/index.tsx b/web/pages/construct/app/index.tsx index 1e67c989f..737753589 100644 --- a/web/pages/construct/app/index.tsx +++ b/web/pages/construct/app/index.tsx @@ -308,15 +308,14 @@ export default function AppContent() { className='w-[230px] h-[40px] border-1 border-white backdrop-filter backdrop-blur-lg bg-white bg-opacity-30 dark:border-[#6f7f95] dark:bg-[#6f7f95] dark:bg-opacity-60' /> -
- -
+ +
{apps.map(item => { diff --git a/web/pages/construct/flow/canvas/index.tsx b/web/pages/construct/flow/canvas/index.tsx index 0cd5b285d..370456a3e 100644 --- a/web/pages/construct/flow/canvas/index.tsx +++ b/web/pages/construct/flow/canvas/index.tsx @@ -12,7 +12,7 @@ import { import CanvasNode from '@/components/flow/canvas-node'; import { IFlowData, IFlowUpdateParam } from '@/types/flow'; import { checkFlowDataRequied, getUniqueNodeId, mapUnderlineToHump } from '@/utils/flow'; -import { ExportOutlined, FrownOutlined, ImportOutlined, FileAddOutlined,SaveOutlined } from '@ant-design/icons'; +import { ExportOutlined, FileAddOutlined, FrownOutlined, ImportOutlined, SaveOutlined } from '@ant-design/icons'; import { Divider, Space, Tooltip, message, notification } from 'antd'; import { useSearchParams } from 'next/navigation'; import React, { DragEvent, useCallback, useEffect, useRef, useState } from 'react'; @@ -32,27 +32,27 @@ import 'reactflow/dist/style.css'; const nodeTypes = { customNode: CanvasNode }; const edgeTypes = { buttonedge: ButtonEdge }; + const Canvas: React.FC = () => { - const { t } = useTranslation(); - const [messageApi, contextHolder] = message.useMessage(); - const searchParams = useSearchParams(); const id = searchParams?.get('id') || ''; const reactFlow = useReactFlow(); + const [messageApi, contextHolder] = message.useMessage(); + const reactFlowWrapper = useRef(null); - const [loading, setLoading] = useState(false); const [nodes, setNodes, onNodesChange] = useNodesState([]); const [edges, setEdges, onEdgesChange] = useEdgesState([]); - const reactFlowWrapper = useRef(null); + const [flowInfo, setFlowInfo] = useState(); + const [loading, setLoading] = useState(false); const [isSaveFlowModalOpen, setIsSaveFlowModalOpen] = useState(false); const [isExportFlowModalOpen, setIsExportFlowModalOpen] = useState(false); const [isImportModalOpen, setIsImportFlowModalOpen] = useState(false); const [isTemplateFlowModalOpen, setIsTemplateFlowModalOpen] = useState(false); if (localStorage.getItem('importFlowData')) { - const importFlowData = JSON.parse(localStorage.getItem('importFlowData')); + const importFlowData = JSON.parse(localStorage.getItem('importFlowData') || ''); localStorage.removeItem('importFlowData'); setLoading(true); const flowData = mapUnderlineToHump(importFlowData.flow_data); @@ -61,6 +61,7 @@ const Canvas: React.FC = () => { setEdges(flowData.edges); setLoading(false); } + async function getFlowData() { setLoading(true); const [_, data] = await apiInterceptors(getFlowById(id)); @@ -275,7 +276,7 @@ const Canvas: React.FC = () => { - +
diff --git a/web/pages/mobile/chat/components/InputContainer.tsx b/web/pages/mobile/chat/components/InputContainer.tsx index f9af2ea3e..59a545684 100644 --- a/web/pages/mobile/chat/components/InputContainer.tsx +++ b/web/pages/mobile/chat/components/InputContainer.tsx @@ -6,7 +6,7 @@ import { ClearOutlined, LoadingOutlined, PauseCircleOutlined, RedoOutlined, Send import { EventStreamContentType, fetchEventSource } from '@microsoft/fetch-event-source'; import { useRequest } from 'ahooks'; import { Button, Input, Popover, Spin, Tag } from 'antd'; -import cls from 'classnames'; +import classnames from 'classnames'; import { useSearchParams } from 'next/navigation'; import React, { useContext, useEffect, useMemo, useState } from 'react'; import { MobileChatContext } from '../'; @@ -245,7 +245,7 @@ const InputContainer: React.FC = () => {
{ { ) : ( {
{/* 输入框 */}
{