update model card
This commit is contained in:
@@ -6,6 +6,7 @@ from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from jose import jwt, JWTError
|
||||
from pydantic import BaseModel, Field
|
||||
from app.core.security import SECRET_KEY, ALGORITHM
|
||||
from litellm import completion
|
||||
|
||||
router = APIRouter()
|
||||
security = HTTPBearer()
|
||||
@@ -56,6 +57,13 @@ class LLMConfigUpdate(BaseModel):
|
||||
extra_headers: Optional[Dict[str, str]] = None
|
||||
is_active: Optional[bool] = None
|
||||
|
||||
class TestConnectionRequest(BaseModel):
|
||||
provider: str
|
||||
model: str
|
||||
api_key: Optional[str] = None
|
||||
api_base: Optional[str] = None
|
||||
extra_headers: Optional[Dict[str, str]] = None
|
||||
|
||||
|
||||
def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)) -> CurrentUser:
|
||||
unauthorized = HTTPException(
|
||||
@@ -152,3 +160,54 @@ def delete_llm_config(config_id: str, _: CurrentUser = Depends(get_admin_user)):
|
||||
raise HTTPException(status_code=404, detail="LLM configuration not found")
|
||||
_save_data(data)
|
||||
return {"message": "LLM configuration deleted successfully"}
|
||||
|
||||
@router.post("/llm/test")
|
||||
def test_connection(request: TestConnectionRequest, _: CurrentUser = Depends(get_admin_user)):
|
||||
try:
|
||||
# Use litellm to test connection
|
||||
# litellm handles many providers
|
||||
kwargs = {
|
||||
"model": request.model,
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"max_tokens": 5
|
||||
}
|
||||
|
||||
if request.api_key:
|
||||
kwargs["api_key"] = request.api_key
|
||||
|
||||
if request.api_base:
|
||||
kwargs["api_base"] = request.api_base
|
||||
|
||||
if request.extra_headers:
|
||||
kwargs["extra_headers"] = request.extra_headers
|
||||
|
||||
# For OpenAI-compatible endpoints that are not standard OpenAI (like Local, vLLM etc)
|
||||
# usually user sets provider to "openai" and api_base to their custom URL.
|
||||
# litellm usually works well if we pass custom_llm_provider="openai" if provider is openai but custom url
|
||||
|
||||
# If provider is "local" or "openai", we generally use "openai" format
|
||||
if request.provider == "local":
|
||||
kwargs["custom_llm_provider"] = "openai"
|
||||
elif request.provider:
|
||||
kwargs["custom_llm_provider"] = request.provider
|
||||
|
||||
# If user explicitly selected provider in UI, we might want to respect that
|
||||
# But litellm completion main arg is 'model'.
|
||||
# If the UI 'model' input doesn't have prefix, we might need to add it or pass custom_llm_provider.
|
||||
|
||||
# Simple heuristic: if provider is set, try to pass it if litellm supports it or just rely on env vars/args
|
||||
# For this simple test, we just try to call it.
|
||||
|
||||
try:
|
||||
response = completion(**kwargs)
|
||||
except Exception as first_error:
|
||||
error_text = str(first_error)
|
||||
if request.provider and "Provider NOT provided" in error_text and "/" not in request.model:
|
||||
retry_kwargs = kwargs.copy()
|
||||
retry_kwargs["model"] = f"{request.provider}/{request.model}"
|
||||
response = completion(**retry_kwargs)
|
||||
else:
|
||||
raise first_error
|
||||
return {"success": True, "message": "Connection successful", "details": str(response)}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"Connection failed: {str(e)}")
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
[
|
||||
{
|
||||
"id": "m1773487590",
|
||||
"provider": "zhipuai",
|
||||
"provider": "openai",
|
||||
"model": "glm-4-7-251222",
|
||||
"api_key": "secret",
|
||||
"api_key": "4a54896c-dac5-4aa6-b618-558dfbd89e4a",
|
||||
"api_base": "https://ark.cn-beijing.volces.com/api/v3",
|
||||
"extra_headers": null,
|
||||
"is_active": true
|
||||
"extra_headers": {},
|
||||
"is_active": true,
|
||||
"name": "glm-4-7-251222",
|
||||
"model_type": "\u5927\u8bed\u8a00\u6a21\u578b",
|
||||
"base_model": "glm-4-7-251222",
|
||||
"protocol_type": "OpenAI"
|
||||
},
|
||||
{
|
||||
"id": "deny1",
|
||||
|
||||
@@ -2,9 +2,12 @@ import { useState, useRef, useEffect } from "react";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { ScrollArea } from "@/components/ui/scroll-area";
|
||||
import { User, Loader2, Sparkles, Search, ArrowUp, ChevronDown, Table, Paperclip } from "lucide-react";
|
||||
import { User, Loader2, Sparkles, Search, ArrowUp, ChevronDown, Table, Paperclip, Check } from "lucide-react";
|
||||
import { api } from "@/lib/api";
|
||||
import { useVisualizationStore } from "@/store/visualizationStore";
|
||||
import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover";
|
||||
import { Command, CommandEmpty, CommandGroup, CommandInput, CommandItem, CommandList } from "@/components/ui/command";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
interface Message {
|
||||
id: string;
|
||||
@@ -12,6 +15,14 @@ interface Message {
|
||||
content: string;
|
||||
}
|
||||
|
||||
interface ModelConfig {
|
||||
id: string;
|
||||
name?: string;
|
||||
model: string;
|
||||
provider: string;
|
||||
is_active: boolean;
|
||||
}
|
||||
|
||||
export function ChatInterface() {
|
||||
const [messages, setMessages] = useState<Message[]>([
|
||||
{ id: '1', role: 'assistant', content: 'Hello! I am DataClaw. How can I help you analyze your data today?' }
|
||||
@@ -23,6 +34,33 @@ export function ChatInterface() {
|
||||
const scrollRef = useRef<HTMLDivElement>(null);
|
||||
const { setVisualization, setLoading: setVizLoading, setError: setVizError } = useVisualizationStore();
|
||||
|
||||
// Model selection state
|
||||
const [models, setModels] = useState<ModelConfig[]>([]);
|
||||
const [selectedModelId, setSelectedModelId] = useState<string>("");
|
||||
const [modelOpen, setModelOpen] = useState(false);
|
||||
|
||||
useEffect(() => {
|
||||
fetchModels();
|
||||
}, []);
|
||||
|
||||
const fetchModels = async () => {
|
||||
try {
|
||||
const data = await api.get<ModelConfig[]>("/api/v1/llm");
|
||||
setModels(data);
|
||||
// Set default model if available
|
||||
const active = data.find(m => m.is_active);
|
||||
if (active) {
|
||||
setSelectedModelId(active.id);
|
||||
} else if (data.length > 0) {
|
||||
setSelectedModelId(data[0].id);
|
||||
}
|
||||
} catch (e) {
|
||||
console.error("Failed to fetch models", e);
|
||||
}
|
||||
};
|
||||
|
||||
const currentModel = models.find(m => m.id === selectedModelId);
|
||||
|
||||
const capabilities = [
|
||||
{ icon: Sparkles, label: "智能问答", color: "text-purple-500", bg: "bg-purple-50" },
|
||||
{ icon: Table, label: "表格问答", color: "text-orange-500", bg: "bg-orange-50" },
|
||||
@@ -51,7 +89,8 @@ export function ChatInterface() {
|
||||
const source = selectedDataSource.split('-')[0]; // postgres-main -> postgres
|
||||
const response = await api.post<{sql?: string, result?: unknown, error?: string}>('/api/v1/agent/nl2sql', {
|
||||
query: newMessage.content,
|
||||
source: source
|
||||
source: source,
|
||||
model_id: selectedModelId // Pass selected model ID if backend supports it
|
||||
});
|
||||
|
||||
if (response.error) {
|
||||
@@ -76,7 +115,8 @@ export function ChatInterface() {
|
||||
// General Chat
|
||||
const response = await api.post<{response: string}>('/nanobot/chat', {
|
||||
message: newMessage.content,
|
||||
skill_ids: [selectedSkill]
|
||||
skill_ids: [selectedSkill],
|
||||
model_id: selectedModelId
|
||||
});
|
||||
|
||||
setMessages(prev => [...prev, {
|
||||
@@ -101,14 +141,49 @@ export function ChatInterface() {
|
||||
return (
|
||||
<div className="h-full bg-white relative flex flex-col">
|
||||
{/* Top Bar */}
|
||||
<div className="absolute top-0 left-0 w-full px-6 py-4 z-10">
|
||||
<button className="flex items-center gap-2 text-sm font-medium text-zinc-600 hover:bg-zinc-100 px-3 py-1.5 rounded-lg transition-colors">
|
||||
glm-4-7-251222
|
||||
<ChevronDown className="h-4 w-4 text-zinc-400" />
|
||||
</button>
|
||||
<div className="absolute top-0 left-0 w-full px-6 py-4 z-10 flex justify-between items-center">
|
||||
<Popover open={modelOpen} onOpenChange={setModelOpen}>
|
||||
<PopoverTrigger className="w-[200px] flex justify-between items-center bg-white/80 backdrop-blur-sm border border-zinc-200 rounded-md px-3 py-2 text-sm hover:bg-zinc-50 hover:text-zinc-900 text-zinc-700 font-medium shadow-sm transition-all">
|
||||
{currentModel ? (currentModel.name || currentModel.model) : "选择模型..."}
|
||||
<ChevronDown className="ml-2 h-4 w-4 shrink-0 opacity-50" />
|
||||
</PopoverTrigger>
|
||||
<PopoverContent className="w-[240px] p-0" align="start">
|
||||
<Command>
|
||||
<CommandInput placeholder="搜索模型..." className="h-9" />
|
||||
<CommandList>
|
||||
<CommandEmpty>未找到模型</CommandEmpty>
|
||||
<CommandGroup heading="可用模型">
|
||||
{models.map((model) => (
|
||||
<CommandItem
|
||||
key={model.id}
|
||||
value={model.name || model.model}
|
||||
onSelect={() => {
|
||||
setSelectedModelId(model.id);
|
||||
setModelOpen(false);
|
||||
}}
|
||||
className="cursor-pointer"
|
||||
>
|
||||
<div className="flex flex-col">
|
||||
<span className="font-medium">{model.name || model.model}</span>
|
||||
<span className="text-xs text-zinc-400">{model.provider}</span>
|
||||
</div>
|
||||
<Check
|
||||
className={cn(
|
||||
"ml-auto h-4 w-4",
|
||||
selectedModelId === model.id ? "opacity-100" : "opacity-0"
|
||||
)}
|
||||
/>
|
||||
</CommandItem>
|
||||
))}
|
||||
</CommandGroup>
|
||||
</CommandList>
|
||||
</Command>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
</div>
|
||||
|
||||
<ScrollArea className="flex-1">
|
||||
|
||||
<div className="min-h-full">
|
||||
{messages.length <= 1 ? (
|
||||
<div className="h-full flex flex-col items-center justify-center pt-[20vh] px-4 pb-32">
|
||||
|
||||
@@ -104,6 +104,45 @@ export function ModelConfigs() {
|
||||
setDialogOpen(true);
|
||||
};
|
||||
|
||||
const [isTesting, setIsTesting] = useState(false);
|
||||
|
||||
const handleTestConnection = async () => {
|
||||
if (!form.model || !form.provider || !form.api_base) {
|
||||
setError("请先填写必要信息(供应商、模型ID、API域名)");
|
||||
return;
|
||||
}
|
||||
setIsTesting(true);
|
||||
setError("");
|
||||
try {
|
||||
let extraHeaders: Record<string, string> = {};
|
||||
if (extraConfigText.trim()) {
|
||||
try {
|
||||
const parsed = JSON.parse(extraConfigText);
|
||||
if (parsed && typeof parsed === "object") extraHeaders = parsed;
|
||||
} catch (err) {
|
||||
setError("额外配置必须是有效的JSON");
|
||||
setIsTesting(false);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
const payload = {
|
||||
provider: form.provider,
|
||||
model: form.model,
|
||||
api_key: form.api_key,
|
||||
api_base: form.api_base,
|
||||
extra_headers: extraHeaders
|
||||
};
|
||||
|
||||
await api.post("/api/v1/llm/test", payload);
|
||||
alert("连接测试成功!");
|
||||
} catch (e: any) {
|
||||
setError(e.message || "连接测试失败");
|
||||
} finally {
|
||||
setIsTesting(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleSave = async (e?: React.FormEvent) => {
|
||||
if (e) e.preventDefault();
|
||||
if (!form.model || !form.provider || !form.api_base) {
|
||||
@@ -285,12 +324,25 @@ export function ModelConfigs() {
|
||||
<Label>供应商 *</Label>
|
||||
<Select value={form.provider} onValueChange={(v) => setForm((p) => ({ ...p, provider: v || "openai" }))}>
|
||||
<SelectTrigger><SelectValue /></SelectTrigger>
|
||||
<SelectContent>
|
||||
<SelectContent className="max-h-[300px]">
|
||||
<SelectItem value="openai">OpenAI</SelectItem>
|
||||
<SelectItem value="zhipuai">ZhipuAI</SelectItem>
|
||||
<SelectItem value="anthropic">Anthropic</SelectItem>
|
||||
<SelectItem value="azure">Azure OpenAI</SelectItem>
|
||||
<SelectItem value="local">Local</SelectItem>
|
||||
<SelectItem value="anthropic">Anthropic</SelectItem>
|
||||
<SelectItem value="vertex_ai">Google Vertex AI</SelectItem>
|
||||
<SelectItem value="gemini">Google AI Studio (Gemini)</SelectItem>
|
||||
<SelectItem value="bedrock">AWS Bedrock</SelectItem>
|
||||
<SelectItem value="deepseek">DeepSeek</SelectItem>
|
||||
<SelectItem value="zhipuai">ZhipuAI (智谱)</SelectItem>
|
||||
<SelectItem value="moonshot">Moonshot (Kimi)</SelectItem>
|
||||
<SelectItem value="dashscope">DashScope (通义千问)</SelectItem>
|
||||
<SelectItem value="volcengine">Volcengine (火山引擎)</SelectItem>
|
||||
<SelectItem value="groq">Groq</SelectItem>
|
||||
<SelectItem value="cohere">Cohere</SelectItem>
|
||||
<SelectItem value="mistral">Mistral</SelectItem>
|
||||
<SelectItem value="openrouter">OpenRouter</SelectItem>
|
||||
<SelectItem value="ollama">Ollama</SelectItem>
|
||||
<SelectItem value="huggingface">HuggingFace</SelectItem>
|
||||
<SelectItem value="local">Local (OpenAI Compatible)</SelectItem>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</div>
|
||||
@@ -298,26 +350,23 @@ export function ModelConfigs() {
|
||||
|
||||
<div className="grid grid-cols-2 gap-4">
|
||||
<div className="space-y-2">
|
||||
<Label>模型标识 *</Label>
|
||||
<Label>模型ID *</Label>
|
||||
<Input value={form.model || ""} onChange={(e) => setForm((p) => ({ ...p, model: e.target.value }))} placeholder="如:gpt-4-turbo" required />
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<Label>基础模型</Label>
|
||||
<Input value={form.base_model || ""} onChange={(e) => setForm((p) => ({ ...p, base_model: e.target.value }))} placeholder="可选" />
|
||||
<Label>模型类型</Label>
|
||||
<Select value={form.model_type || "LLM"} onValueChange={(v) => setForm((p) => ({ ...p, model_type: v || "LLM" }))}>
|
||||
<SelectTrigger><SelectValue /></SelectTrigger>
|
||||
<SelectContent>
|
||||
<SelectItem value="LLM">LLM</SelectItem>
|
||||
<SelectItem value="Embedding">Embedding</SelectItem>
|
||||
<SelectItem value="Rerank">Rerank</SelectItem>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="grid grid-cols-2 gap-4">
|
||||
<div className="space-y-2">
|
||||
<Label>模型类型</Label>
|
||||
<Select value={form.model_type || "大语言模型"} onValueChange={(v) => setForm((p) => ({ ...p, model_type: v || "大语言模型" }))}>
|
||||
<SelectTrigger><SelectValue /></SelectTrigger>
|
||||
<SelectContent>
|
||||
<SelectItem value="大语言模型">大语言模型</SelectItem>
|
||||
<SelectItem value="多模态模型">多模态模型</SelectItem>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<Label>协议类型</Label>
|
||||
<Select value={form.protocol_type || "OpenAI"} onValueChange={(v) => setForm((p) => ({ ...p, protocol_type: v || "OpenAI" }))}>
|
||||
@@ -360,12 +409,18 @@ export function ModelConfigs() {
|
||||
<Textarea value={extraConfigText} onChange={(e) => setExtraConfigText(e.target.value)} className="min-h-[80px] font-mono text-xs" placeholder='{"timeout": "60"}' />
|
||||
</div>
|
||||
</div>
|
||||
<DialogFooter>
|
||||
<Button type="button" variant="outline" onClick={() => setDialogOpen(false)}>取消</Button>
|
||||
<Button type="submit" disabled={isSaving} className="bg-indigo-600 hover:bg-indigo-700 text-white">
|
||||
{isSaving ? <Loader2 className="h-4 w-4 animate-spin mr-2" /> : null}
|
||||
保存
|
||||
<DialogFooter className="flex items-center justify-between gap-2">
|
||||
<Button type="button" variant="outline" onClick={handleTestConnection} disabled={isTesting}>
|
||||
{isTesting ? <Loader2 className="h-4 w-4 animate-spin mr-2" /> : null}
|
||||
测试连接
|
||||
</Button>
|
||||
<div className="flex items-center gap-2">
|
||||
<Button type="button" variant="outline" onClick={() => setDialogOpen(false)}>取消</Button>
|
||||
<Button type="submit" disabled={isSaving} className="bg-indigo-600 hover:bg-indigo-700 text-white">
|
||||
{isSaving ? <Loader2 className="h-4 w-4 animate-spin mr-2" /> : null}
|
||||
保存
|
||||
</Button>
|
||||
</div>
|
||||
</DialogFooter>
|
||||
</form>
|
||||
</DialogContent>
|
||||
|
||||
Reference in New Issue
Block a user