diff --git a/backend/app/api/llm.py b/backend/app/api/llm.py index 28e0313..ea8661d 100644 --- a/backend/app/api/llm.py +++ b/backend/app/api/llm.py @@ -6,6 +6,7 @@ from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from jose import jwt, JWTError from pydantic import BaseModel, Field from app.core.security import SECRET_KEY, ALGORITHM +from litellm import completion router = APIRouter() security = HTTPBearer() @@ -56,6 +57,13 @@ class LLMConfigUpdate(BaseModel): extra_headers: Optional[Dict[str, str]] = None is_active: Optional[bool] = None +class TestConnectionRequest(BaseModel): + provider: str + model: str + api_key: Optional[str] = None + api_base: Optional[str] = None + extra_headers: Optional[Dict[str, str]] = None + def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)) -> CurrentUser: unauthorized = HTTPException( @@ -152,3 +160,54 @@ def delete_llm_config(config_id: str, _: CurrentUser = Depends(get_admin_user)): raise HTTPException(status_code=404, detail="LLM configuration not found") _save_data(data) return {"message": "LLM configuration deleted successfully"} + +@router.post("/llm/test") +def test_connection(request: TestConnectionRequest, _: CurrentUser = Depends(get_admin_user)): + try: + # Use litellm to test connection + # litellm handles many providers + kwargs = { + "model": request.model, + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 5 + } + + if request.api_key: + kwargs["api_key"] = request.api_key + + if request.api_base: + kwargs["api_base"] = request.api_base + + if request.extra_headers: + kwargs["extra_headers"] = request.extra_headers + + # For OpenAI-compatible endpoints that are not standard OpenAI (like Local, vLLM etc) + # usually user sets provider to "openai" and api_base to their custom URL. + # litellm usually works well if we pass custom_llm_provider="openai" if provider is openai but custom url + + # If provider is "local" or "openai", we generally use "openai" format + if request.provider == "local": + kwargs["custom_llm_provider"] = "openai" + elif request.provider: + kwargs["custom_llm_provider"] = request.provider + + # If user explicitly selected provider in UI, we might want to respect that + # But litellm completion main arg is 'model'. + # If the UI 'model' input doesn't have prefix, we might need to add it or pass custom_llm_provider. + + # Simple heuristic: if provider is set, try to pass it if litellm supports it or just rely on env vars/args + # For this simple test, we just try to call it. + + try: + response = completion(**kwargs) + except Exception as first_error: + error_text = str(first_error) + if request.provider and "Provider NOT provided" in error_text and "/" not in request.model: + retry_kwargs = kwargs.copy() + retry_kwargs["model"] = f"{request.provider}/{request.model}" + response = completion(**retry_kwargs) + else: + raise first_error + return {"success": True, "message": "Connection successful", "details": str(response)} + except Exception as e: + raise HTTPException(status_code=400, detail=f"Connection failed: {str(e)}") diff --git a/backend/data/llm_config.json b/backend/data/llm_config.json index e91756d..4683a01 100644 --- a/backend/data/llm_config.json +++ b/backend/data/llm_config.json @@ -1,12 +1,16 @@ [ { "id": "m1773487590", - "provider": "zhipuai", + "provider": "openai", "model": "glm-4-7-251222", - "api_key": "secret", + "api_key": "4a54896c-dac5-4aa6-b618-558dfbd89e4a", "api_base": "https://ark.cn-beijing.volces.com/api/v3", - "extra_headers": null, - "is_active": true + "extra_headers": {}, + "is_active": true, + "name": "glm-4-7-251222", + "model_type": "\u5927\u8bed\u8a00\u6a21\u578b", + "base_model": "glm-4-7-251222", + "protocol_type": "OpenAI" }, { "id": "deny1", diff --git a/frontend/src/components/ChatInterface.tsx b/frontend/src/components/ChatInterface.tsx index d6e25c0..9fd264a 100644 --- a/frontend/src/components/ChatInterface.tsx +++ b/frontend/src/components/ChatInterface.tsx @@ -2,9 +2,12 @@ import { useState, useRef, useEffect } from "react"; import { Button } from "@/components/ui/button"; import { Input } from "@/components/ui/input"; import { ScrollArea } from "@/components/ui/scroll-area"; -import { User, Loader2, Sparkles, Search, ArrowUp, ChevronDown, Table, Paperclip } from "lucide-react"; +import { User, Loader2, Sparkles, Search, ArrowUp, ChevronDown, Table, Paperclip, Check } from "lucide-react"; import { api } from "@/lib/api"; import { useVisualizationStore } from "@/store/visualizationStore"; +import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; +import { Command, CommandEmpty, CommandGroup, CommandInput, CommandItem, CommandList } from "@/components/ui/command"; +import { cn } from "@/lib/utils"; interface Message { id: string; @@ -12,6 +15,14 @@ interface Message { content: string; } +interface ModelConfig { + id: string; + name?: string; + model: string; + provider: string; + is_active: boolean; +} + export function ChatInterface() { const [messages, setMessages] = useState([ { id: '1', role: 'assistant', content: 'Hello! I am DataClaw. How can I help you analyze your data today?' } @@ -23,6 +34,33 @@ export function ChatInterface() { const scrollRef = useRef(null); const { setVisualization, setLoading: setVizLoading, setError: setVizError } = useVisualizationStore(); + // Model selection state + const [models, setModels] = useState([]); + const [selectedModelId, setSelectedModelId] = useState(""); + const [modelOpen, setModelOpen] = useState(false); + + useEffect(() => { + fetchModels(); + }, []); + + const fetchModels = async () => { + try { + const data = await api.get("/api/v1/llm"); + setModels(data); + // Set default model if available + const active = data.find(m => m.is_active); + if (active) { + setSelectedModelId(active.id); + } else if (data.length > 0) { + setSelectedModelId(data[0].id); + } + } catch (e) { + console.error("Failed to fetch models", e); + } + }; + + const currentModel = models.find(m => m.id === selectedModelId); + const capabilities = [ { icon: Sparkles, label: "智能问答", color: "text-purple-500", bg: "bg-purple-50" }, { icon: Table, label: "表格问答", color: "text-orange-500", bg: "bg-orange-50" }, @@ -51,7 +89,8 @@ export function ChatInterface() { const source = selectedDataSource.split('-')[0]; // postgres-main -> postgres const response = await api.post<{sql?: string, result?: unknown, error?: string}>('/api/v1/agent/nl2sql', { query: newMessage.content, - source: source + source: source, + model_id: selectedModelId // Pass selected model ID if backend supports it }); if (response.error) { @@ -76,7 +115,8 @@ export function ChatInterface() { // General Chat const response = await api.post<{response: string}>('/nanobot/chat', { message: newMessage.content, - skill_ids: [selectedSkill] + skill_ids: [selectedSkill], + model_id: selectedModelId }); setMessages(prev => [...prev, { @@ -101,14 +141,49 @@ export function ChatInterface() { return (
{/* Top Bar */} -
- +
+ + + {currentModel ? (currentModel.name || currentModel.model) : "选择模型..."} + + + + + + + 未找到模型 + + {models.map((model) => ( + { + setSelectedModelId(model.id); + setModelOpen(false); + }} + className="cursor-pointer" + > +
+ {model.name || model.model} + {model.provider} +
+ +
+ ))} +
+
+
+
+
+
{messages.length <= 1 ? (
diff --git a/frontend/src/pages/ModelConfigs.tsx b/frontend/src/pages/ModelConfigs.tsx index 267e92b..e91e211 100644 --- a/frontend/src/pages/ModelConfigs.tsx +++ b/frontend/src/pages/ModelConfigs.tsx @@ -104,6 +104,45 @@ export function ModelConfigs() { setDialogOpen(true); }; + const [isTesting, setIsTesting] = useState(false); + + const handleTestConnection = async () => { + if (!form.model || !form.provider || !form.api_base) { + setError("请先填写必要信息(供应商、模型ID、API域名)"); + return; + } + setIsTesting(true); + setError(""); + try { + let extraHeaders: Record = {}; + if (extraConfigText.trim()) { + try { + const parsed = JSON.parse(extraConfigText); + if (parsed && typeof parsed === "object") extraHeaders = parsed; + } catch (err) { + setError("额外配置必须是有效的JSON"); + setIsTesting(false); + return; + } + } + + const payload = { + provider: form.provider, + model: form.model, + api_key: form.api_key, + api_base: form.api_base, + extra_headers: extraHeaders + }; + + await api.post("/api/v1/llm/test", payload); + alert("连接测试成功!"); + } catch (e: any) { + setError(e.message || "连接测试失败"); + } finally { + setIsTesting(false); + } + }; + const handleSave = async (e?: React.FormEvent) => { if (e) e.preventDefault(); if (!form.model || !form.provider || !form.api_base) { @@ -285,12 +324,25 @@ export function ModelConfigs() {
@@ -298,26 +350,23 @@ export function ModelConfigs() {
- + setForm((p) => ({ ...p, model: e.target.value }))} placeholder="如:gpt-4-turbo" required />
- - setForm((p) => ({ ...p, base_model: e.target.value }))} placeholder="可选" /> + +
-
- - -