import { apiInterceptors, createModel, getSupportModels } from '@/client/api'; import { renderModelIcon } from '@/components/chat/header/model-selector'; import { ConfigurableParams } from '@/types/common'; import { StartModelParams, SupportModel } from '@/types/model'; import { AutoComplete, Button, Form, Select, Tooltip, message } from 'antd'; import { useEffect, useState } from 'react'; import { useTranslation } from 'react-i18next'; import ReactMarkdown from 'react-markdown'; import ConfigurableForm from '../common/configurable-form'; const { Option } = Select; const FormItem = Form.Item; // The supported worker types const WORKER_TYPES = ['llm', 'text2vec', 'reranker']; function ModelForm({ onCancel, onSuccess }: { onCancel: () => void; onSuccess: () => void }) { const { t } = useTranslation(); const [_, setModels] = useState | null>([]); const [selectedWorkerType, setSelectedWorkerType] = useState(); const [selectedProvider, setSelectedProvider] = useState(); const [params, setParams] = useState | null>(null); const [loading, setLoading] = useState(false); const [form] = Form.useForm(); const [groupedModels, setGroupedModels] = useState<{ [key: string]: SupportModel[] }>({}); const [providers, setProviders] = useState([]); async function getModels() { const [, res] = await apiInterceptors(getSupportModels()); if (res && res.length) { const sortedModels = res.sort((a: SupportModel, b: SupportModel) => { if (a.enabled && !b.enabled) return -1; if (!a.enabled && b.enabled) return 1; return a.model.localeCompare(b.model); }); setModels(sortedModels); const grouped = sortedModels.reduce((acc: { [key: string]: SupportModel[] }, model) => { const provider = model.provider; if (!acc[provider]) acc[provider] = []; acc[provider].push(model); return acc; }, {}); setGroupedModels(grouped); // Note: Initially do not set providers, wait for worker_type selection before setting setProviders([]); } } useEffect(() => { getModels(); }, []); // Filter and set available providers based on worker_type function updateProvidersByWorkerType(workerType: string) { const availableProviders = new Set(); Object.entries(groupedModels).forEach(([provider, models]) => { if (models.some(model => model.worker_type === workerType)) { availableProviders.add(provider); } }); setProviders(Array.from(availableProviders).sort()); } function handleWorkerTypeChange(value: string) { setSelectedWorkerType(value); setSelectedProvider(undefined); form.resetFields(); form.setFieldValue('worker_type', value); updateProvidersByWorkerType(value); } function handleProviderChange(value: string) { setSelectedProvider(value); form.setFieldValue('provider', value); // Get the params of the first model that matches the selected worker_type under the current provider as the default params const providerModels = groupedModels[value] || []; const filteredModels = providerModels.filter(m => m.worker_type === selectedWorkerType); if (filteredModels.length > 0) { const firstModel = filteredModels[0]; if (firstModel?.params) { setParams(Array.isArray(firstModel.params) ? firstModel.params : [firstModel.params]); } } } async function onFinish(values: any) { if (!selectedProvider || !selectedWorkerType) return; const processFormValues = (formValues: any) => { const processed = { ...formValues }; params?.forEach(param => { if (param.nested_fields && processed[param.param_name]) { const nestedValue = processed[param.param_name]; // Make sure to keep all field values if (nestedValue.type) { const typeFields = param.nested_fields[nestedValue.type] || []; const fieldValues = {}; // Collect values of all fields typeFields.forEach(field => { if (nestedValue[field.param_name] !== undefined) { fieldValues[field.param_name] = nestedValue[field.param_name]; } }); processed[param.param_name] = { ...fieldValues, type: nestedValue.type, }; } } }); return processed; }; setLoading(true); try { const processedValues = processFormValues(values); const selectedModel = groupedModels[selectedProvider]?.find(m => m.model === processedValues.name); const params: StartModelParams = { host: selectedModel?.host || '', port: selectedModel?.port || 0, model: processedValues.name, worker_type: selectedWorkerType, params: processedValues, }; const [, , data] = await apiInterceptors(createModel(params)); if (data?.success) { message.success(t('start_model_success')); form.resetFields(); onSuccess?.(); } } catch (_error) { message.error(t('start_model_failed')); } finally { setLoading(false); } } const renderTooltipContent = (model: SupportModel) => (
{model.description || model.model}
{model.enabled ? `${model.host}:${model.port}` : t('download_model_tip')}
); return (
{selectedWorkerType && ( )} {selectedProvider && selectedWorkerType && params && ( <> model.worker_type === selectedWorkerType) .map(model => ({ value: model.model, label: (
{renderModelIcon(model.model)} {model.model}
), }))} filterOption={(inputValue, option) => option!.value.toUpperCase().indexOf(inputValue.toUpperCase()) !== -1 } />
p.param_name !== 'name')} form={form} /> )}
); } export default ModelForm;