Update 2026-05-13 16:43:53
This commit is contained in:
@@ -0,0 +1,704 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import ssl
|
||||
import time
|
||||
import uuid
|
||||
from collections import defaultdict, deque
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
import httpx
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.a2a import (
|
||||
A2AAuditLog,
|
||||
A2AProjectConfig,
|
||||
A2ARemoteAgent,
|
||||
A2ATask,
|
||||
A2ATaskEvent,
|
||||
A2ATaskWebhook,
|
||||
A2AWebhookDelivery,
|
||||
)
|
||||
from app.models.project import Project
|
||||
from app.schemas.a2a import (
|
||||
A2AArtifactSchema,
|
||||
A2APartSchema,
|
||||
A2ATaskStatusSchema,
|
||||
TaskArtifactUpdateEvent,
|
||||
TaskStatusUpdateEvent,
|
||||
)
|
||||
from app.trace import build_error_attributes, trace_service
|
||||
|
||||
|
||||
def _json_loads(raw: Optional[str], default: Any) -> Any:
|
||||
if not raw:
|
||||
return default
|
||||
try:
|
||||
return json.loads(raw)
|
||||
except Exception:
|
||||
return default
|
||||
|
||||
|
||||
def _json_dumps(raw: Any) -> str:
|
||||
return json.dumps(raw, ensure_ascii=False)
|
||||
|
||||
|
||||
def _utc_now() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def _mask_error(message: str) -> str:
|
||||
if not message:
|
||||
return "internal_error"
|
||||
return "request_failed"
|
||||
|
||||
|
||||
class SharedSecretAuth:
|
||||
@staticmethod
|
||||
def generate_signature(secret: str, payload: bytes, timestamp: Optional[int] = None) -> Tuple[str, int]:
|
||||
if timestamp is None:
|
||||
timestamp = int(time.time())
|
||||
message = f"{timestamp}".encode() + payload
|
||||
signature = hmac.new(secret.encode(), message, hashlib.sha256).hexdigest()
|
||||
return f"sha256={signature}", timestamp
|
||||
|
||||
@staticmethod
|
||||
def verify_signature(secret: str, payload: bytes, signature: str, timestamp: int, max_age_seconds: int = 300) -> bool:
|
||||
if abs(time.time() - timestamp) > max_age_seconds:
|
||||
return False
|
||||
expected_sig, _ = SharedSecretAuth.generate_signature(secret, payload, timestamp)
|
||||
return hmac.compare_digest(signature, expected_sig)
|
||||
|
||||
@staticmethod
|
||||
def sign_request(secret: str, method: str, path: str, body: Optional[bytes] = None) -> Dict[str, str]:
|
||||
timestamp = int(time.time())
|
||||
payload = body or b""
|
||||
message = f"{timestamp}.{method.upper()}.{path}".encode() + payload
|
||||
signature = hmac.new(secret.encode(), message, hashlib.sha256).hexdigest()
|
||||
return {
|
||||
"X-A2A-Signature": f"sha256={signature}",
|
||||
"X-A2A-Timestamp": str(timestamp),
|
||||
}
|
||||
|
||||
|
||||
class MtlsConfig:
|
||||
def __init__(
|
||||
self,
|
||||
ca_cert: Optional[str] = None,
|
||||
client_cert: Optional[str] = None,
|
||||
client_key: Optional[str] = None,
|
||||
):
|
||||
self.ca_cert = ca_cert
|
||||
self.client_cert = client_cert
|
||||
self.client_key = client_key
|
||||
|
||||
def create_ssl_context(self) -> Optional[ssl.SSLContext]:
|
||||
if not self.client_cert or not self.client_key:
|
||||
return None
|
||||
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||
ctx.load_cert_chain(self.client_cert, self.client_key)
|
||||
if self.ca_cert:
|
||||
ctx.load_verify_locations(self.ca_cert)
|
||||
ctx.verify_mode = ssl.CERT_REQUIRED
|
||||
else:
|
||||
ctx.verify_mode = ssl.CERT_NONE
|
||||
return ctx
|
||||
|
||||
|
||||
class OAuth2TokenStore:
|
||||
def __init__(self):
|
||||
self._tokens: Dict[str, Tuple[str, datetime]] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def get_token(self, key: str) -> Optional[str]:
|
||||
async with self._lock:
|
||||
if key in self._tokens:
|
||||
token, expires_at = self._tokens[key]
|
||||
if expires_at > _utc_now() + timedelta(minutes=1):
|
||||
return token
|
||||
return None
|
||||
|
||||
async def set_token(self, key: str, token: str, expires_in: int = 3600) -> None:
|
||||
async with self._lock:
|
||||
self._tokens[key] = (token, _utc_now() + timedelta(seconds=expires_in))
|
||||
|
||||
|
||||
class OAuth2Auth:
|
||||
def __init__(
|
||||
self,
|
||||
client_id: str,
|
||||
client_secret: str,
|
||||
token_url: str,
|
||||
scopes: Optional[List[str]] = None,
|
||||
):
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self.token_url = token_url
|
||||
self.scopes = scopes or []
|
||||
self._token_store = OAuth2TokenStore()
|
||||
|
||||
def _get_cache_key(self) -> str:
|
||||
return f"{self.client_id}:{self.token_url}:{':'.join(self.scopes)}"
|
||||
|
||||
async def get_access_token(self, grant_type: str = "client_credentials") -> str:
|
||||
cache_key = self._get_cache_key()
|
||||
cached = await self._token_store.get_token(cache_key)
|
||||
if cached:
|
||||
return cached
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
data = {
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
"grant_type": grant_type,
|
||||
}
|
||||
if self.scopes:
|
||||
data["scope"] = " ".join(self.scopes)
|
||||
resp = await client.post(self.token_url, data=data)
|
||||
resp.raise_for_status()
|
||||
token_data = resp.json()
|
||||
token = token_data["access_token"]
|
||||
expires_in = token_data.get("expires_in", 3600)
|
||||
await self._token_store.set_token(cache_key, token, expires_in)
|
||||
return token
|
||||
|
||||
async def authorize_request(self, method: str, url: str, **kwargs) -> Dict[str, str]:
|
||||
token = await self.get_access_token()
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
|
||||
class OIDCAuth:
|
||||
def __init__(
|
||||
self,
|
||||
issuer_url: str,
|
||||
client_id: str,
|
||||
client_secret: Optional[str] = None,
|
||||
scopes: Optional[List[str]] = None,
|
||||
):
|
||||
self.issuer_url = issuer_url.rstrip("/")
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self.scopes = scopes or ["openid", "profile"]
|
||||
self._oauth2: Optional[OAuth2Auth] = None
|
||||
self._discovery_cache: Optional[Dict[str, Any]] = None
|
||||
|
||||
async def _get_discovery(self) -> Dict[str, Any]:
|
||||
if self._discovery_cache:
|
||||
return self._discovery_cache
|
||||
discovery_url = f"{self.issuer_url}/.well-known/openid-configuration"
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.get(discovery_url)
|
||||
resp.raise_for_status()
|
||||
self._discovery_cache = resp.json()
|
||||
return self._discovery_cache
|
||||
|
||||
async def get_access_token(self) -> str:
|
||||
discovery = await self._get_discovery()
|
||||
token_url = discovery.get("token_endpoint")
|
||||
if not token_url:
|
||||
raise RuntimeError("OIDC discovery missing token_endpoint")
|
||||
if not self._oauth2:
|
||||
self._oauth2 = OAuth2Auth(
|
||||
client_id=self.client_id,
|
||||
client_secret=self.client_secret or "",
|
||||
token_url=token_url,
|
||||
scopes=self.scopes,
|
||||
)
|
||||
return await self._oauth2.get_access_token()
|
||||
|
||||
async def authorize_request(self, method: str, url: str, **kwargs) -> Dict[str, str]:
|
||||
token = await self.get_access_token()
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
|
||||
class RemoteAgentSecuritySelector:
|
||||
def __init__(self, agent: A2ARemoteAgent):
|
||||
self.agent = agent
|
||||
self._card_security_schemes: Optional[Dict[str, Any]] = None
|
||||
|
||||
def load_security_from_card(self) -> None:
|
||||
card = _json_loads(self.agent.card_json, {})
|
||||
if card:
|
||||
self._card_security_schemes = card.get("securitySchemes", {})
|
||||
|
||||
def get_preferred_auth_scheme(self) -> str:
|
||||
card = _json_loads(self.agent.card_json, {})
|
||||
security_reqs = card.get("security", [])
|
||||
if security_reqs:
|
||||
first_req = security_reqs[0]
|
||||
if isinstance(first_req, dict):
|
||||
for scheme_name in first_req.keys():
|
||||
return scheme_name
|
||||
return self.agent.auth_scheme or "bearer"
|
||||
|
||||
def get_auth_headers(self, user_token: Optional[str] = None) -> Dict[str, str]:
|
||||
headers: Dict[str, str] = {}
|
||||
preferred = self.get_preferred_auth_scheme()
|
||||
|
||||
if preferred == "bearer" or self.agent.auth_scheme == "bearer":
|
||||
if self.agent.auth_token:
|
||||
headers["Authorization"] = f"Bearer {self.agent.auth_token}"
|
||||
elif user_token:
|
||||
headers["Authorization"] = f"Bearer {user_token}"
|
||||
|
||||
elif preferred == "shared_secret" or self.agent.auth_scheme == "shared_secret":
|
||||
pass
|
||||
|
||||
elif preferred in ("oauth2", "oauth2_authorizationcode", "oauth2_clientcredentials"):
|
||||
pass
|
||||
|
||||
elif preferred == "openIdConnect":
|
||||
pass
|
||||
|
||||
elif preferred == "mutualTLS":
|
||||
pass
|
||||
|
||||
return headers
|
||||
|
||||
def get_mtls_context(self) -> Optional[ssl.SSLContext]:
|
||||
if self.agent.auth_scheme == "mutualTLS" or self.get_preferred_auth_scheme() == "mutualTLS":
|
||||
if self.agent.mtls_client_cert and self.agent.mtls_client_key:
|
||||
config = MtlsConfig(
|
||||
ca_cert=self.agent.mtls_ca_cert,
|
||||
client_cert=self.agent.mtls_client_cert,
|
||||
client_key=self.agent.mtls_client_key,
|
||||
)
|
||||
return config.create_ssl_context()
|
||||
return None
|
||||
|
||||
def create_signed_request_headers(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
body: Optional[bytes] = None,
|
||||
) -> Dict[str, str]:
|
||||
headers: Dict[str, str] = {}
|
||||
preferred = self.get_preferred_auth_scheme()
|
||||
|
||||
if preferred == "shared_secret" and self.agent.shared_secret:
|
||||
sig_headers = SharedSecretAuth.sign_request(
|
||||
self.agent.shared_secret,
|
||||
method,
|
||||
path,
|
||||
body,
|
||||
)
|
||||
headers.update(sig_headers)
|
||||
|
||||
elif self.agent.auth_scheme == "bearer" and self.agent.auth_token:
|
||||
headers["Authorization"] = f"Bearer {self.agent.auth_token}"
|
||||
|
||||
return headers
|
||||
|
||||
async def get_oauth2_auth(self) -> Optional[OAuth2Auth]:
|
||||
if self.agent.oauth2_client_id and self.agent.oauth2_token_url:
|
||||
scopes = self.agent.oauth2_scopes.split() if self.agent.oauth2_scopes else []
|
||||
return OAuth2Auth(
|
||||
client_id=self.agent.oauth2_client_id,
|
||||
client_secret=self.agent.oauth2_client_secret or "",
|
||||
token_url=self.agent.oauth2_token_url,
|
||||
scopes=scopes,
|
||||
)
|
||||
return None
|
||||
|
||||
async def get_oidc_auth(self) -> Optional[OIDCAuth]:
|
||||
if self.agent.oidc_issuer_url:
|
||||
return OIDCAuth(
|
||||
issuer_url=self.agent.oidc_issuer_url,
|
||||
client_id=self.agent.oidc_client_id or "",
|
||||
client_secret=self.agent.oidc_client_secret,
|
||||
scopes=self.agent.oauth2_scopes.split() if self.agent.oauth2_scopes else [],
|
||||
)
|
||||
return None
|
||||
|
||||
async def authorize_request(self, method: str, url: str, user_token: Optional[str] = None) -> Dict[str, str]:
|
||||
headers = self.get_auth_headers(user_token)
|
||||
preferred = self.get_preferred_auth_scheme()
|
||||
|
||||
if preferred in ("oauth2", "oauth2_authorizationcode", "oauth2_clientcredentials"):
|
||||
oauth2_auth = await self.get_oauth2_auth()
|
||||
if oauth2_auth:
|
||||
headers.update(await oauth2_auth.authorize_request(method, url))
|
||||
|
||||
elif preferred == "openIdConnect":
|
||||
oidc_auth = await self.get_oidc_auth()
|
||||
if oidc_auth:
|
||||
headers.update(await oidc_auth.authorize_request(method, url))
|
||||
|
||||
return headers
|
||||
|
||||
_STATE_TRANSITIONS = {
|
||||
"SUBMITTED": {"WORKING", "FAILED", "CANCELED", "REJECTED", "AUTH_REQUIRED", "INPUT_REQUIRED", "COMPLETED"},
|
||||
"WORKING": {"COMPLETED", "FAILED", "CANCELED", "INPUT_REQUIRED", "AUTH_REQUIRED"},
|
||||
"INPUT_REQUIRED": {"WORKING", "FAILED", "CANCELED"},
|
||||
"AUTH_REQUIRED": {"WORKING", "FAILED", "CANCELED", "REJECTED"},
|
||||
"REJECTED": set(),
|
||||
"FAILED": set(),
|
||||
"COMPLETED": set(),
|
||||
"CANCELED": set(),
|
||||
}
|
||||
_TERMINAL_STATES = {"COMPLETED", "FAILED", "CANCELED", "REJECTED"}
|
||||
|
||||
|
||||
@dataclass
|
||||
class A2AResolvedRoute:
|
||||
selected: str
|
||||
fallback_chain: List[str]
|
||||
canary_hit: bool
|
||||
reason: str
|
||||
|
||||
|
||||
class A2AMetrics:
|
||||
def __init__(self) -> None:
|
||||
self._lock = asyncio.Lock()
|
||||
self._counters: Dict[str, int] = defaultdict(int)
|
||||
self._latency_ms: Dict[str, deque[float]] = defaultdict(lambda: deque(maxlen=2000))
|
||||
|
||||
async def incr(self, key: str, value: int = 1) -> None:
|
||||
async with self._lock:
|
||||
self._counters[key] += value
|
||||
|
||||
async def observe_latency(self, key: str, elapsed_ms: float) -> None:
|
||||
async with self._lock:
|
||||
self._latency_ms[key].append(float(elapsed_ms))
|
||||
|
||||
async def snapshot(self) -> Dict[str, Any]:
|
||||
async with self._lock:
|
||||
counters = dict(self._counters)
|
||||
p95 = {}
|
||||
for key, values in self._latency_ms.items():
|
||||
series = sorted(values)
|
||||
if not series:
|
||||
p95[f"{key}.p95_ms"] = 0.0
|
||||
continue
|
||||
idx = int(0.95 * (len(series) - 1))
|
||||
p95[f"{key}.p95_ms"] = round(series[idx], 2)
|
||||
total = counters.get("a2a.requests.total", 0)
|
||||
errors = counters.get("a2a.requests.error", 0)
|
||||
retries = counters.get("a2a.requests.retry", 0)
|
||||
breakers = counters.get("a2a.circuit.open", 0)
|
||||
return {
|
||||
"counters": counters,
|
||||
"derived": {
|
||||
"error_rate": round(errors / total, 4) if total else 0.0,
|
||||
"retry_rate": round(retries / total, 4) if total else 0.0,
|
||||
"circuit_open_rate": round(breakers / total, 4) if total else 0.0,
|
||||
},
|
||||
"latency": p95,
|
||||
}
|
||||
|
||||
|
||||
class A2ARuntime:
|
||||
def __init__(self) -> None:
|
||||
self._subscribers: Dict[str, List[asyncio.Queue[Dict[str, Any]]]] = defaultdict(list)
|
||||
self.metrics = A2AMetrics()
|
||||
self.protocol_version = "1.0"
|
||||
self._circuit_state: Dict[int, datetime] = {}
|
||||
|
||||
async def publish(self, task_id: str, event: Dict[str, Any]) -> None:
|
||||
queues = list(self._subscribers.get(task_id, []))
|
||||
for queue in queues:
|
||||
await queue.put(event)
|
||||
|
||||
async def subscribe(self, task_id: str) -> AsyncIterator[Dict[str, Any]]:
|
||||
queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue(maxsize=200)
|
||||
self._subscribers[task_id].append(queue)
|
||||
try:
|
||||
while True:
|
||||
payload = await queue.get()
|
||||
yield payload
|
||||
finally:
|
||||
self._subscribers[task_id] = [q for q in self._subscribers.get(task_id, []) if q is not queue]
|
||||
if not self._subscribers[task_id]:
|
||||
self._subscribers.pop(task_id, None)
|
||||
|
||||
def get_project_config(self, db: Session, project_id: int, user_id: int) -> A2AProjectConfig:
|
||||
item = db.query(A2AProjectConfig).filter(A2AProjectConfig.project_id == project_id).first()
|
||||
if item:
|
||||
return item
|
||||
config = A2AProjectConfig(project_id=project_id, updated_by=user_id)
|
||||
db.add(config)
|
||||
db.commit()
|
||||
db.refresh(config)
|
||||
return config
|
||||
|
||||
def resolve_route(self, *, project_config: A2AProjectConfig, session_id: str, requested_mode: str, requested_fallback: Optional[List[str]]) -> A2AResolvedRoute:
|
||||
selected = requested_mode or project_config.route_mode_default or "local_first"
|
||||
fallback = requested_fallback or _json_loads(project_config.fallback_chain_json, ["local"])
|
||||
fallback_chain = [item for item in fallback if item in {"a2a", "local", "mcp"}]
|
||||
if not fallback_chain:
|
||||
fallback_chain = ["local"]
|
||||
canary_hit = False
|
||||
if project_config.canary_enabled and project_config.canary_percent > 0:
|
||||
digest = hashlib.sha256(f"{project_config.project_id}:{session_id}".encode()).hexdigest()
|
||||
bucket = int(digest[:8], 16) % 100
|
||||
canary_hit = bucket < project_config.canary_percent
|
||||
if selected in {"a2a_first", "a2a"} and not canary_hit:
|
||||
return A2AResolvedRoute(
|
||||
selected="local",
|
||||
fallback_chain=fallback_chain,
|
||||
canary_hit=False,
|
||||
reason="canary_not_hit_fallback_local",
|
||||
)
|
||||
if selected in {"a2a_first", "a2a"}:
|
||||
return A2AResolvedRoute(selected="a2a", fallback_chain=fallback_chain, canary_hit=canary_hit, reason="a2a_selected")
|
||||
if selected in {"mcp_first", "mcp"}:
|
||||
return A2AResolvedRoute(selected="mcp", fallback_chain=fallback_chain, canary_hit=canary_hit, reason="mcp_selected")
|
||||
return A2AResolvedRoute(selected="local", fallback_chain=fallback_chain, canary_hit=canary_hit, reason="local_selected")
|
||||
|
||||
def can_transition(self, from_state: str, to_state: str) -> bool:
|
||||
if from_state == to_state:
|
||||
return True
|
||||
return to_state in _STATE_TRANSITIONS.get(from_state, set())
|
||||
|
||||
def create_task(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
project_id: int,
|
||||
tenant_id: int,
|
||||
source: str,
|
||||
input_text: str,
|
||||
idempotency_key: Optional[str],
|
||||
remote_agent_id: Optional[int],
|
||||
compatibility_mode: bool,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
context_id: Optional[str] = None,
|
||||
) -> A2ATask:
|
||||
if idempotency_key:
|
||||
existing = (
|
||||
db.query(A2ATask)
|
||||
.filter(
|
||||
A2ATask.project_id == project_id,
|
||||
A2ATask.tenant_id == tenant_id,
|
||||
A2ATask.idempotency_key == idempotency_key,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
return existing
|
||||
task = A2ATask(
|
||||
id=f"task_{uuid.uuid4().hex}",
|
||||
project_id=project_id,
|
||||
tenant_id=tenant_id,
|
||||
source=source,
|
||||
remote_agent_id=remote_agent_id,
|
||||
state="SUBMITTED",
|
||||
input_text=input_text,
|
||||
idempotency_key=idempotency_key,
|
||||
compatibility_mode=compatibility_mode,
|
||||
metadata_json=_json_dumps(metadata or {}),
|
||||
context_id=context_id,
|
||||
)
|
||||
db.add(task)
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
return task
|
||||
|
||||
def append_event(self, db: Session, task: A2ATask, event_type: str, payload: Dict[str, Any]) -> A2ATaskEvent:
|
||||
event = A2ATaskEvent(task_id=task.id, event_type=event_type, payload_json=_json_dumps(payload))
|
||||
db.add(event)
|
||||
db.commit()
|
||||
db.refresh(event)
|
||||
return event
|
||||
|
||||
def transition_task(
|
||||
self,
|
||||
db: Session,
|
||||
task: A2ATask,
|
||||
*,
|
||||
to_state: str,
|
||||
output_text: Optional[str] = None,
|
||||
error_message: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> A2ATask:
|
||||
if not self.can_transition(task.state, to_state):
|
||||
raise ValueError(f"Invalid task transition: {task.state} -> {to_state}")
|
||||
task.state = to_state
|
||||
if output_text is not None:
|
||||
task.output_text = output_text
|
||||
if error_message is not None:
|
||||
task.error_message = error_message
|
||||
if metadata is not None:
|
||||
task.metadata_json = _json_dumps(metadata)
|
||||
if to_state in _TERMINAL_STATES:
|
||||
task.finished_at = _utc_now()
|
||||
db.add(task)
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
return task
|
||||
|
||||
def record_audit(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
actor_user_id: int,
|
||||
action: str,
|
||||
target_type: str,
|
||||
target_id: str,
|
||||
result: str,
|
||||
project_id: Optional[int] = None,
|
||||
task_id: Optional[str] = None,
|
||||
detail: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
audit = A2AAuditLog(
|
||||
actor_user_id=actor_user_id,
|
||||
action=action,
|
||||
target_type=target_type,
|
||||
target_id=target_id,
|
||||
result=result,
|
||||
project_id=project_id,
|
||||
task_id=task_id,
|
||||
detail_json=_json_dumps(detail or {}),
|
||||
)
|
||||
db.add(audit)
|
||||
db.commit()
|
||||
|
||||
async def fetch_agent_card(self, db: Session, agent: A2ARemoteAgent, *, timeout_s: float = 10.0) -> Dict[str, Any]:
|
||||
if agent.id in self._circuit_state and self._circuit_state[agent.id] > _utc_now():
|
||||
raise RuntimeError("circuit_open")
|
||||
started = time.perf_counter()
|
||||
await self.metrics.incr("a2a.requests.total")
|
||||
headers = {}
|
||||
if agent.auth_scheme == "bearer" and agent.auth_token:
|
||||
headers["Authorization"] = f"Bearer {agent.auth_token}"
|
||||
url = f"{agent.base_url.rstrip('/')}/api/v1/a2a/agent-card"
|
||||
with trace_service.start_span("a2a.card.fetch", attributes={"agent_id": agent.id, "url": url}) as span:
|
||||
for attempt in range(3):
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=timeout_s, verify=True) as client:
|
||||
resp = await client.get(url, headers=headers)
|
||||
if resp.status_code >= 400:
|
||||
raise RuntimeError(f"http_{resp.status_code}")
|
||||
payload = resp.json()
|
||||
elapsed_ms = (time.perf_counter() - started) * 1000
|
||||
await self.metrics.observe_latency("a2a.card.fetch", elapsed_ms)
|
||||
agent.card_json = _json_dumps(payload)
|
||||
agent.protocol_version = str(payload.get("protocol_version") or "")
|
||||
agent.capabilities_json = _json_dumps(payload.get("capabilities") or [])
|
||||
agent.card_fetched_at = _utc_now()
|
||||
agent.healthy = True
|
||||
agent.failure_count = 0
|
||||
agent.circuit_open_until = None
|
||||
db.add(agent)
|
||||
db.commit()
|
||||
db.refresh(agent)
|
||||
return payload
|
||||
except Exception as exc:
|
||||
span.set_attributes(build_error_attributes(exc, stage="a2a_card_fetch"))
|
||||
await self.metrics.incr("a2a.requests.error")
|
||||
if attempt < 2:
|
||||
await self.metrics.incr("a2a.requests.retry")
|
||||
await asyncio.sleep(0.2 * (2 ** attempt))
|
||||
continue
|
||||
agent.failure_count = (agent.failure_count or 0) + 1
|
||||
if agent.failure_count >= 3:
|
||||
reopen_at = _utc_now() + timedelta(seconds=90)
|
||||
agent.circuit_open_until = reopen_at
|
||||
self._circuit_state[agent.id] = reopen_at
|
||||
await self.metrics.incr("a2a.circuit.open")
|
||||
agent.healthy = False
|
||||
db.add(agent)
|
||||
db.commit()
|
||||
raise
|
||||
|
||||
async def notify_webhooks(self, db: Session, task: A2ATask, event: A2ATaskEvent) -> None:
|
||||
webhooks = db.query(A2ATaskWebhook).filter(A2ATaskWebhook.task_id == task.id, A2ATaskWebhook.enabled == True).all()
|
||||
if not webhooks:
|
||||
return
|
||||
for hook in webhooks:
|
||||
delivery = A2AWebhookDelivery(task_id=task.id, webhook_id=hook.id, event_id=event.id, attempt=0, status="PENDING")
|
||||
db.add(delivery)
|
||||
db.commit()
|
||||
db.refresh(delivery)
|
||||
await self._deliver_once(db, hook, event, delivery)
|
||||
|
||||
async def _deliver_once(self, db: Session, hook: A2ATaskWebhook, event: A2ATaskEvent, delivery: A2AWebhookDelivery) -> None:
|
||||
event_payload = _json_loads(event.payload_json, {})
|
||||
stream_response_payload = self._build_stream_response_payload(event, event_payload)
|
||||
body = _json_dumps(stream_response_payload).encode("utf-8")
|
||||
|
||||
for attempt in range(1, 5):
|
||||
delivery.attempt = attempt
|
||||
db.add(delivery)
|
||||
db.commit()
|
||||
|
||||
headers = {"Content-Type": "application/json", "X-A2A-Event-Id": str(event.id)}
|
||||
if hook.secret:
|
||||
digest = hmac.new(hook.secret.encode("utf-8"), body, hashlib.sha256).hexdigest()
|
||||
headers["X-A2A-Signature"] = f"sha256={digest}"
|
||||
if hook.auth_header:
|
||||
headers["Authorization"] = hook.auth_header
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=8.0, verify=True) as client:
|
||||
resp = await client.post(hook.target_url, content=body, headers=headers)
|
||||
delivery.response_code = resp.status_code
|
||||
delivery.response_body = (resp.text or "")[:1000]
|
||||
if 200 <= resp.status_code < 300:
|
||||
delivery.status = "DELIVERED"
|
||||
delivery.dead_letter = False
|
||||
delivery.delivered_at = _utc_now()
|
||||
db.add(delivery)
|
||||
db.commit()
|
||||
return
|
||||
raise RuntimeError(f"http_{resp.status_code}")
|
||||
except Exception as exc:
|
||||
delivery.error_message = str(exc)[:500]
|
||||
if attempt < 4:
|
||||
backoff_seconds = 2 ** attempt
|
||||
delivery.status = "RETRYING"
|
||||
delivery.next_retry_at = _utc_now() + timedelta(seconds=backoff_seconds)
|
||||
db.add(delivery)
|
||||
db.commit()
|
||||
await asyncio.sleep(backoff_seconds)
|
||||
continue
|
||||
delivery.status = "FAILED"
|
||||
delivery.dead_letter = True
|
||||
db.add(delivery)
|
||||
db.commit()
|
||||
return
|
||||
|
||||
def _build_stream_response_payload(self, event: A2ATaskEvent, event_payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
event_type = event.event_type
|
||||
task_id = event_payload.get("task_id", event.task_id)
|
||||
|
||||
if event_type == "TaskStatusUpdateEvent":
|
||||
status_state = event_payload.get("task_status", "WORKING")
|
||||
status_timestamp = event_payload.get("timestamp", _utc_now().isoformat())
|
||||
status_schema = A2ATaskStatusSchema(
|
||||
state=status_state,
|
||||
timestamp=datetime.fromisoformat(status_timestamp) if isinstance(status_timestamp, str) else status_timestamp,
|
||||
)
|
||||
return {
|
||||
"statusUpdate": TaskStatusUpdateEvent(
|
||||
taskId=task_id,
|
||||
contextId=event_payload.get("context_id"),
|
||||
status=status_schema,
|
||||
metadata=event_payload.get("metadata", {}),
|
||||
).model_dump()
|
||||
}
|
||||
elif event_type == "TaskArtifactUpdateEvent":
|
||||
artifact_content = event_payload.get("artifact", {}).get("content", "")
|
||||
artifact_schema = A2AArtifactSchema(
|
||||
artifactId=f"artifact-{event.id}",
|
||||
parts=[A2APartSchema(part_type="text", text=artifact_content)],
|
||||
)
|
||||
return {
|
||||
"artifactUpdate": TaskArtifactUpdateEvent(
|
||||
taskId=task_id,
|
||||
contextId=event_payload.get("context_id"),
|
||||
artifact=artifact_schema,
|
||||
append=False,
|
||||
lastChunk=True,
|
||||
).model_dump()
|
||||
}
|
||||
else:
|
||||
return {"message": event_payload}
|
||||
|
||||
|
||||
a2a_runtime = A2ARuntime()
|
||||
@@ -0,0 +1,77 @@
|
||||
import json
|
||||
import threading
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.core.data_root import get_data_root
|
||||
|
||||
class EmbeddingModelStore:
|
||||
def __init__(self) -> None:
|
||||
self._lock = threading.RLock()
|
||||
|
||||
@staticmethod
|
||||
def _file_path() -> Path:
|
||||
return get_data_root() / "embedding_models.json"
|
||||
|
||||
def _read(self) -> List[Dict[str, Any]]:
|
||||
file_path = self._file_path()
|
||||
if not file_path.exists():
|
||||
return []
|
||||
try:
|
||||
with file_path.open("r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return []
|
||||
if not isinstance(data, list):
|
||||
return []
|
||||
return data
|
||||
|
||||
def _write(self, data: List[Dict[str, Any]]) -> None:
|
||||
file_path = self._file_path()
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with file_path.open("w", encoding="utf-8") as f:
|
||||
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
def list_models(self) -> List[Dict[str, Any]]:
|
||||
with self._lock:
|
||||
return self._read()
|
||||
|
||||
def get_model(self, model_id: str) -> Optional[Dict[str, Any]]:
|
||||
with self._lock:
|
||||
data = self._read()
|
||||
for item in data:
|
||||
if item.get("id") == model_id:
|
||||
return item
|
||||
return None
|
||||
|
||||
def create_model(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
with self._lock:
|
||||
data = self._read()
|
||||
new_model = payload.copy()
|
||||
new_model["id"] = uuid.uuid4().hex
|
||||
data.append(new_model)
|
||||
self._write(data)
|
||||
return new_model
|
||||
|
||||
def update_model(self, model_id: str, payload: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
with self._lock:
|
||||
data = self._read()
|
||||
for item in data:
|
||||
if item.get("id") == model_id:
|
||||
item.update(payload)
|
||||
self._write(data)
|
||||
return item
|
||||
return None
|
||||
|
||||
def delete_model(self, model_id: str) -> bool:
|
||||
with self._lock:
|
||||
data = self._read()
|
||||
initial_len = len(data)
|
||||
data = [item for item in data if item.get("id") != model_id]
|
||||
if len(data) < initial_len:
|
||||
self._write(data)
|
||||
return True
|
||||
return False
|
||||
|
||||
embedding_model_store = EmbeddingModelStore()
|
||||
@@ -0,0 +1,188 @@
|
||||
import json
|
||||
import threading
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.core.data_root import get_data_root
|
||||
|
||||
|
||||
def _utcnow_iso() -> str:
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
||||
class KnowledgeBaseStore:
|
||||
def __init__(self) -> None:
|
||||
self._lock = threading.RLock()
|
||||
|
||||
@staticmethod
|
||||
def _file_path() -> Path:
|
||||
return get_data_root() / "knowledge_bases.json"
|
||||
|
||||
def _read(self) -> List[Dict[str, Any]]:
|
||||
file_path = self._file_path()
|
||||
if not file_path.exists():
|
||||
return []
|
||||
try:
|
||||
with file_path.open("r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
except (json.JSONDecodeError, OSError):
|
||||
return []
|
||||
if not isinstance(data, list):
|
||||
return []
|
||||
return data
|
||||
|
||||
def _write(self, data: List[Dict[str, Any]]) -> None:
|
||||
file_path = self._file_path()
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with file_path.open("w", encoding="utf-8") as f:
|
||||
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_documents(item: Dict[str, Any]) -> None:
|
||||
docs = item.get("documents")
|
||||
if not isinstance(docs, list):
|
||||
item["documents"] = []
|
||||
return
|
||||
normalized: List[Dict[str, Any]] = []
|
||||
for doc in docs:
|
||||
if not isinstance(doc, dict):
|
||||
continue
|
||||
if not doc.get("id"):
|
||||
doc["id"] = str(uuid.uuid4())
|
||||
now = _utcnow_iso()
|
||||
doc.setdefault("created_at", now)
|
||||
doc.setdefault("updated_at", now)
|
||||
doc.setdefault("metadata", {})
|
||||
normalized.append(doc)
|
||||
item["documents"] = normalized
|
||||
|
||||
def list(self, project_id: Optional[int] = None) -> List[Dict[str, Any]]:
|
||||
with self._lock:
|
||||
data = self._read()
|
||||
for item in data:
|
||||
self._normalize_documents(item)
|
||||
if project_id is None:
|
||||
return data
|
||||
return [item for item in data if item.get("project_id") == project_id]
|
||||
|
||||
def get(self, kb_id: str) -> Optional[Dict[str, Any]]:
|
||||
with self._lock:
|
||||
for item in self._read():
|
||||
if item.get("id") == kb_id:
|
||||
self._normalize_documents(item)
|
||||
return item
|
||||
return None
|
||||
|
||||
def create(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
with self._lock:
|
||||
data = self._read()
|
||||
now = _utcnow_iso()
|
||||
item = {
|
||||
"id": str(uuid.uuid4()),
|
||||
"name": payload["name"],
|
||||
"description": payload.get("description"),
|
||||
"project_id": payload.get("project_id"),
|
||||
"embedding_model": payload.get("embedding_model"),
|
||||
"chunk_size": payload.get("chunk_size", 512),
|
||||
"chunk_overlap": payload.get("chunk_overlap", 50),
|
||||
"top_k": payload.get("top_k", 3),
|
||||
"is_active": payload.get("is_active", True),
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"documents": [],
|
||||
}
|
||||
data.append(item)
|
||||
self._write(data)
|
||||
return item
|
||||
|
||||
def update(self, kb_id: str, payload: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
with self._lock:
|
||||
data = self._read()
|
||||
for idx, item in enumerate(data):
|
||||
if item.get("id") != kb_id:
|
||||
continue
|
||||
for key, value in payload.items():
|
||||
item[key] = value
|
||||
item["updated_at"] = _utcnow_iso()
|
||||
self._normalize_documents(item)
|
||||
data[idx] = item
|
||||
self._write(data)
|
||||
return item
|
||||
return None
|
||||
|
||||
def delete(self, kb_id: str) -> bool:
|
||||
with self._lock:
|
||||
data = self._read()
|
||||
filtered = [item for item in data if item.get("id") != kb_id]
|
||||
if len(filtered) == len(data):
|
||||
return False
|
||||
self._write(filtered)
|
||||
return True
|
||||
|
||||
def create_document(self, kb_id: str, payload: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
with self._lock:
|
||||
data = self._read()
|
||||
for idx, item in enumerate(data):
|
||||
if item.get("id") != kb_id:
|
||||
continue
|
||||
now = _utcnow_iso()
|
||||
doc = {
|
||||
"id": str(uuid.uuid4()),
|
||||
"title": payload["title"],
|
||||
"content": payload["content"],
|
||||
"metadata": payload.get("metadata", {}),
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
}
|
||||
self._normalize_documents(item)
|
||||
item["documents"].append(doc)
|
||||
item["updated_at"] = now
|
||||
data[idx] = item
|
||||
self._write(data)
|
||||
return doc
|
||||
return None
|
||||
|
||||
def update_document(self, kb_id: str, doc_id: str, payload: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
with self._lock:
|
||||
data = self._read()
|
||||
for kb_idx, item in enumerate(data):
|
||||
if item.get("id") != kb_id:
|
||||
continue
|
||||
self._normalize_documents(item)
|
||||
docs = item["documents"]
|
||||
for doc_idx, doc in enumerate(docs):
|
||||
if doc.get("id") != doc_id:
|
||||
continue
|
||||
for key, value in payload.items():
|
||||
doc[key] = value
|
||||
doc["updated_at"] = _utcnow_iso()
|
||||
docs[doc_idx] = doc
|
||||
item["updated_at"] = _utcnow_iso()
|
||||
data[kb_idx] = item
|
||||
self._write(data)
|
||||
return doc
|
||||
return None
|
||||
return None
|
||||
|
||||
def delete_document(self, kb_id: str, doc_id: str) -> bool:
|
||||
with self._lock:
|
||||
data = self._read()
|
||||
for kb_idx, item in enumerate(data):
|
||||
if item.get("id") != kb_id:
|
||||
continue
|
||||
self._normalize_documents(item)
|
||||
docs = item["documents"]
|
||||
filtered = [doc for doc in docs if doc.get("id") != doc_id]
|
||||
if len(filtered) == len(docs):
|
||||
return False
|
||||
item["documents"] = filtered
|
||||
item["updated_at"] = _utcnow_iso()
|
||||
data[kb_idx] = item
|
||||
self._write(data)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
knowledge_base_store = KnowledgeBaseStore()
|
||||
@@ -0,0 +1,58 @@
|
||||
import json
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
from app.core.data_root import get_data_root
|
||||
|
||||
|
||||
class KnowledgeGlobalConfigStore:
|
||||
def __init__(self) -> None:
|
||||
self._lock = threading.RLock()
|
||||
|
||||
@staticmethod
|
||||
def _file_path() -> Path:
|
||||
return get_data_root() / "knowledge_global_config.json"
|
||||
|
||||
def _read(self) -> Dict[str, Any]:
|
||||
file_path = self._file_path()
|
||||
if not file_path.exists():
|
||||
return {}
|
||||
try:
|
||||
with file_path.open("r", encoding="utf-8") as file_obj:
|
||||
data = json.load(file_obj)
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return {}
|
||||
if not isinstance(data, dict):
|
||||
return {}
|
||||
return data
|
||||
|
||||
def _write(self, data: Dict[str, Any]) -> None:
|
||||
file_path = self._file_path()
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with file_path.open("w", encoding="utf-8") as file_obj:
|
||||
json.dump(data, file_obj, indent=2, ensure_ascii=False)
|
||||
|
||||
def get(self) -> Dict[str, Any]:
|
||||
with self._lock:
|
||||
data = self._read()
|
||||
return {
|
||||
"api_base": data.get("api_base"),
|
||||
"api_key": data.get("api_key"),
|
||||
"default_embedding_model": data.get("default_embedding_model"),
|
||||
}
|
||||
|
||||
def update(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
with self._lock:
|
||||
current = self.get()
|
||||
if "api_base" in payload:
|
||||
current["api_base"] = payload.get("api_base")
|
||||
if "api_key" in payload:
|
||||
current["api_key"] = payload.get("api_key")
|
||||
if "default_embedding_model" in payload:
|
||||
current["default_embedding_model"] = payload.get("default_embedding_model")
|
||||
self._write(current)
|
||||
return current
|
||||
|
||||
|
||||
knowledge_global_config_store = KnowledgeGlobalConfigStore()
|
||||
@@ -0,0 +1,267 @@
|
||||
import math
|
||||
import re
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from app.services.knowledge_base_store import knowledge_base_store
|
||||
from app.services.knowledge_global_config_store import knowledge_global_config_store
|
||||
from app.services.openai_compat import normalize_openai_base_url
|
||||
|
||||
try:
|
||||
from llama_index.core import Document, VectorStoreIndex
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
LLAMAINDEX_AVAILABLE = True
|
||||
except Exception:
|
||||
Document = Any
|
||||
VectorStoreIndex = Any
|
||||
SentenceSplitter = Any
|
||||
LLAMAINDEX_AVAILABLE = False
|
||||
|
||||
|
||||
def _tokenize(text: str) -> List[str]:
|
||||
return re.findall(r"[a-zA-Z0-9]+|[\u4e00-\u9fff]", (text or "").lower())
|
||||
|
||||
|
||||
def _normalize_embedding_api_base(api_base: str) -> str:
|
||||
return normalize_openai_base_url(api_base)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SearchHit:
|
||||
doc_id: str
|
||||
title: str
|
||||
chunk: str
|
||||
score: float
|
||||
metadata: Dict[str, Any]
|
||||
|
||||
|
||||
class KnowledgeIndexService:
|
||||
def __init__(self) -> None:
|
||||
self._lock = threading.RLock()
|
||||
self._cache: Dict[str, Tuple[str, Any, List[Dict[str, Any]]]] = {}
|
||||
|
||||
@staticmethod
|
||||
def _signature(kb: Dict[str, Any]) -> str:
|
||||
doc_parts = []
|
||||
for doc in kb.get("documents", []):
|
||||
doc_parts.append(f"{doc.get('id')}:{doc.get('updated_at')}:{len(doc.get('content', ''))}")
|
||||
return "|".join(
|
||||
[
|
||||
str(kb.get("updated_at")),
|
||||
str(kb.get("chunk_size")),
|
||||
str(kb.get("chunk_overlap")),
|
||||
*doc_parts,
|
||||
]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _fallback_chunks(kb: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
chunks: List[Dict[str, Any]] = []
|
||||
chunk_size = int(kb.get("chunk_size") or 512)
|
||||
overlap = int(kb.get("chunk_overlap") or 50)
|
||||
step = max(1, chunk_size - overlap)
|
||||
for doc in kb.get("documents", []):
|
||||
text = doc.get("content") or ""
|
||||
if not text:
|
||||
continue
|
||||
if len(text) <= chunk_size:
|
||||
chunks.append(
|
||||
{
|
||||
"doc_id": doc.get("id", ""),
|
||||
"title": doc.get("title", ""),
|
||||
"chunk": text,
|
||||
"metadata": doc.get("metadata") or {},
|
||||
}
|
||||
)
|
||||
continue
|
||||
for start in range(0, len(text), step):
|
||||
piece = text[start : start + chunk_size]
|
||||
if not piece:
|
||||
continue
|
||||
chunks.append(
|
||||
{
|
||||
"doc_id": doc.get("id", ""),
|
||||
"title": doc.get("title", ""),
|
||||
"chunk": piece,
|
||||
"metadata": doc.get("metadata") or {},
|
||||
}
|
||||
)
|
||||
return chunks
|
||||
|
||||
def _build_index(self, kb: Dict[str, Any]) -> Tuple[Any, List[Dict[str, Any]]]:
|
||||
fallback_chunks = self._fallback_chunks(kb)
|
||||
if not LLAMAINDEX_AVAILABLE:
|
||||
return None, fallback_chunks
|
||||
chunk_size = int(kb.get("chunk_size") or 512)
|
||||
overlap = int(kb.get("chunk_overlap") or 50)
|
||||
splitter = SentenceSplitter(chunk_size=chunk_size, chunk_overlap=overlap)
|
||||
docs = [
|
||||
Document(
|
||||
text=(doc.get("content") or ""),
|
||||
metadata={
|
||||
"doc_id": doc.get("id", ""),
|
||||
"title": doc.get("title", ""),
|
||||
**(doc.get("metadata") or {}),
|
||||
},
|
||||
)
|
||||
for doc in kb.get("documents", [])
|
||||
if (doc.get("content") or "").strip()
|
||||
]
|
||||
if not docs:
|
||||
return None, fallback_chunks
|
||||
embed_model = self._build_embed_model(kb)
|
||||
if embed_model is not None:
|
||||
index = VectorStoreIndex.from_documents(
|
||||
docs,
|
||||
transformations=[splitter],
|
||||
embed_model=embed_model,
|
||||
)
|
||||
else:
|
||||
index = VectorStoreIndex.from_documents(docs, transformations=[splitter])
|
||||
return index, fallback_chunks
|
||||
|
||||
@staticmethod
|
||||
def _build_embed_model(kb: Dict[str, Any]) -> Any:
|
||||
from app.services.embedding_model_store import embedding_model_store
|
||||
models = embedding_model_store.list_models()
|
||||
if not models:
|
||||
return None
|
||||
|
||||
target_model = None
|
||||
kb_model_val = kb.get("embedding_model")
|
||||
if kb_model_val:
|
||||
# Try matching by ID first, then by model name
|
||||
target_model = next((m for m in models if m.get("id") == kb_model_val), None)
|
||||
if not target_model:
|
||||
target_model = next((m for m in models if m.get("model") == kb_model_val), None)
|
||||
|
||||
if not target_model:
|
||||
# Fallback to the first model
|
||||
target_model = models[0]
|
||||
|
||||
api_base = target_model.get("api_base")
|
||||
api_key = target_model.get("api_key")
|
||||
model_name = target_model.get("model")
|
||||
|
||||
if not api_base or not api_key or not model_name:
|
||||
return None
|
||||
api_base = _normalize_embedding_api_base(api_base)
|
||||
try:
|
||||
from llama_index.embeddings.openai_like import OpenAILikeEmbedding
|
||||
|
||||
return OpenAILikeEmbedding(
|
||||
model_name=model_name,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
embed_batch_size=10,
|
||||
)
|
||||
except Exception:
|
||||
try:
|
||||
from llama_index.embeddings.openai import OpenAIEmbedding
|
||||
|
||||
return OpenAIEmbedding(
|
||||
model_name=model_name,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
embed_batch_size=10,
|
||||
)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def reindex(self, kb_id: str) -> Dict[str, Any]:
|
||||
kb = knowledge_base_store.get(kb_id)
|
||||
if not kb:
|
||||
raise ValueError("Knowledge base not found")
|
||||
with self._lock:
|
||||
signature = self._signature(kb)
|
||||
index, fallback_chunks = self._build_index(kb)
|
||||
self._cache[kb_id] = (signature, index, fallback_chunks)
|
||||
return {
|
||||
"kb_id": kb_id,
|
||||
"status": "ok",
|
||||
"documents": len(kb.get("documents", [])),
|
||||
"engine": "llamaindex" if LLAMAINDEX_AVAILABLE and index is not None else "fallback",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _fallback_search(query: str, chunks: List[Dict[str, Any]], top_k: int) -> List[SearchHit]:
|
||||
q_tokens = _tokenize(query)
|
||||
if not q_tokens:
|
||||
return []
|
||||
q_set = set(q_tokens)
|
||||
scored: List[SearchHit] = []
|
||||
for chunk_item in chunks:
|
||||
c_tokens = _tokenize(chunk_item.get("chunk", ""))
|
||||
if not c_tokens:
|
||||
continue
|
||||
overlap = sum(1 for t in c_tokens if t in q_set)
|
||||
if overlap == 0:
|
||||
continue
|
||||
score = overlap / math.sqrt(len(c_tokens))
|
||||
scored.append(
|
||||
SearchHit(
|
||||
doc_id=chunk_item.get("doc_id", ""),
|
||||
title=chunk_item.get("title", ""),
|
||||
chunk=chunk_item.get("chunk", ""),
|
||||
score=float(score),
|
||||
metadata=chunk_item.get("metadata") or {},
|
||||
)
|
||||
)
|
||||
scored.sort(key=lambda x: x.score, reverse=True)
|
||||
return scored[:top_k]
|
||||
|
||||
def search(self, kb_id: str, query: str, top_k: int | None = None) -> Dict[str, Any]:
|
||||
kb = knowledge_base_store.get(kb_id)
|
||||
if not kb:
|
||||
raise ValueError("Knowledge base not found")
|
||||
if not kb.get("documents"):
|
||||
return {"answer": "", "hits": []}
|
||||
effective_top_k = int(top_k or kb.get("top_k") or 3)
|
||||
with self._lock:
|
||||
signature = self._signature(kb)
|
||||
cached = self._cache.get(kb_id)
|
||||
if not cached or cached[0] != signature:
|
||||
index, fallback_chunks = self._build_index(kb)
|
||||
cached = (signature, index, fallback_chunks)
|
||||
self._cache[kb_id] = cached
|
||||
_, index, fallback_chunks = cached
|
||||
if index is None:
|
||||
hits = self._fallback_search(query=query, chunks=fallback_chunks, top_k=effective_top_k)
|
||||
answer = "\n\n".join(hit.chunk for hit in hits)
|
||||
return {
|
||||
"answer": answer,
|
||||
"hits": [hit.__dict__ for hit in hits],
|
||||
}
|
||||
retriever = index.as_retriever(similarity_top_k=effective_top_k)
|
||||
response_nodes = retriever.retrieve(query)
|
||||
hits: List[Dict[str, Any]] = []
|
||||
for node_with_score in response_nodes:
|
||||
node = getattr(node_with_score, "node", None)
|
||||
metadata = getattr(node, "metadata", {}) if node is not None else {}
|
||||
chunk_text = ""
|
||||
if node is not None and hasattr(node, "get_content"):
|
||||
chunk_text = node.get_content()
|
||||
elif node is not None:
|
||||
chunk_text = str(getattr(node, "text", ""))
|
||||
hits.append(
|
||||
{
|
||||
"doc_id": metadata.get("doc_id", ""),
|
||||
"title": metadata.get("title", ""),
|
||||
"chunk": chunk_text,
|
||||
"score": float(getattr(node_with_score, "score", 0.0) or 0.0),
|
||||
"metadata": metadata,
|
||||
}
|
||||
)
|
||||
if not hits:
|
||||
fallback_hits = self._fallback_search(query=query, chunks=fallback_chunks, top_k=effective_top_k)
|
||||
return {
|
||||
"answer": "\n\n".join(hit.chunk for hit in fallback_hits),
|
||||
"hits": [hit.__dict__ for hit in fallback_hits],
|
||||
}
|
||||
answer = "\n\n".join(item.get("chunk", "") for item in hits if item.get("chunk"))
|
||||
return {"answer": answer, "hits": hits}
|
||||
|
||||
|
||||
knowledge_index_service = KnowledgeIndexService()
|
||||
@@ -0,0 +1,24 @@
|
||||
import os
|
||||
import threading
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.api.llm import DATA_FILE, _load_data
|
||||
|
||||
_cache_lock = threading.RLock()
|
||||
_cache_mtime: float = -1.0
|
||||
_cache_data: List[Dict[str, Any]] = []
|
||||
|
||||
|
||||
def get_llm_configs() -> List[Dict[str, Any]]:
|
||||
global _cache_mtime, _cache_data
|
||||
current_mtime = os.path.getmtime(DATA_FILE) if os.path.exists(DATA_FILE) else -1.0
|
||||
with _cache_lock:
|
||||
if current_mtime != _cache_mtime:
|
||||
_cache_data = _load_data()
|
||||
_cache_mtime = current_mtime
|
||||
return list(_cache_data)
|
||||
|
||||
|
||||
def get_active_llm_config() -> Optional[Dict[str, Any]]:
|
||||
configs = get_llm_configs()
|
||||
return next((c for c in configs if c.get("is_active")), None)
|
||||
@@ -0,0 +1,169 @@
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any, List
|
||||
from app.models.datasource import DataSource
|
||||
from app.schemas.mdl import MDLManifest, Model, Column, TableReference
|
||||
from app.connectors.factory import get_connector
|
||||
from app.database import SessionLocal
|
||||
from app.core.data_root import get_data_root
|
||||
|
||||
MDL_STORAGE_PATH = get_data_root() / "mdl"
|
||||
|
||||
class MDLService:
|
||||
@staticmethod
|
||||
def _get_mdl_path(datasource_id: int) -> Path:
|
||||
MDL_STORAGE_PATH.mkdir(parents=True, exist_ok=True)
|
||||
return MDL_STORAGE_PATH / f"{datasource_id}.json"
|
||||
|
||||
@staticmethod
|
||||
def get_raw_schema(datasource: DataSource) -> Dict[str, List[Dict[str, str]]]:
|
||||
connector = get_connector(datasource)
|
||||
try:
|
||||
return connector.get_schema()
|
||||
except Exception as e:
|
||||
print(f"Error fetching schema for DS {datasource.id}: {e}")
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def generate_default_mdl(
|
||||
datasource: DataSource,
|
||||
selected_tables: Optional[List[str]] = None,
|
||||
selected_columns: Optional[Dict[str, List[str]]] = None,
|
||||
) -> MDLManifest:
|
||||
raw_schema = MDLService.get_raw_schema(datasource)
|
||||
|
||||
models = []
|
||||
relationships = []
|
||||
from app.schemas.mdl import Relationship
|
||||
|
||||
# Helper to get columns for a table from the raw schema (which could be a list or a dict)
|
||||
def get_table_info(t_name):
|
||||
data = raw_schema.get(t_name, [])
|
||||
if isinstance(data, dict) and "columns" in data:
|
||||
return data
|
||||
return {"columns": data, "primary_keys": [], "foreign_keys": []}
|
||||
|
||||
for table_name in raw_schema.keys():
|
||||
if selected_tables is not None and table_name not in selected_tables:
|
||||
continue
|
||||
|
||||
table_info = get_table_info(table_name)
|
||||
columns = table_info["columns"]
|
||||
pks = table_info.get("primary_keys", [])
|
||||
|
||||
model_cols = []
|
||||
for col_info in columns:
|
||||
if isinstance(col_info, dict):
|
||||
name = col_info.get("name", "UNKNOWN")
|
||||
type_ = col_info.get("type", "UNKNOWN")
|
||||
elif isinstance(col_info, str):
|
||||
# Fallback for old string format "name (type)"
|
||||
if "(" in col_info and col_info.endswith(")"):
|
||||
parts = col_info.rsplit(" (", 1)
|
||||
if len(parts) == 2:
|
||||
name = parts[0]
|
||||
type_ = parts[1][:-1]
|
||||
else:
|
||||
name = col_info
|
||||
type_ = "UNKNOWN"
|
||||
else:
|
||||
name = col_info
|
||||
type_ = "UNKNOWN"
|
||||
else:
|
||||
name = str(col_info)
|
||||
type_ = "UNKNOWN"
|
||||
|
||||
if selected_columns is not None:
|
||||
allowed = selected_columns.get(table_name, [])
|
||||
if allowed and name not in allowed:
|
||||
continue
|
||||
|
||||
is_pk = name in pks
|
||||
model_cols.append(Column(name=name, type=type_, properties={"is_primary_key": is_pk}))
|
||||
|
||||
if not model_cols:
|
||||
continue
|
||||
|
||||
models.append(Model(
|
||||
name=table_name,
|
||||
tableReference=TableReference(table=table_name),
|
||||
columns=model_cols,
|
||||
primaryKey=pks[0] if pks else None
|
||||
))
|
||||
|
||||
# Extract relationships from foreign keys
|
||||
fks = table_info.get("foreign_keys", [])
|
||||
for fk in fks:
|
||||
referred_table = fk.get("referred_table")
|
||||
if not referred_table:
|
||||
continue
|
||||
# Skip if the referred table is not selected
|
||||
if selected_tables is not None and referred_table not in selected_tables:
|
||||
continue
|
||||
|
||||
constrained_cols = fk.get("constrained_columns", [])
|
||||
referred_cols = fk.get("referred_columns", [])
|
||||
|
||||
if len(constrained_cols) == 1 and len(referred_cols) == 1:
|
||||
# Update column properties for FK
|
||||
fk_col_name = constrained_cols[0]
|
||||
for col in model_cols:
|
||||
if col.name == fk_col_name:
|
||||
col.properties["is_foreign_key"] = True
|
||||
|
||||
# Simple single-column foreign key
|
||||
condition = f"{table_name}.{constrained_cols[0]} = {referred_table}.{referred_cols[0]}"
|
||||
rel_name = f"{table_name}_{constrained_cols[0]}_to_{referred_table}"
|
||||
relationships.append(Relationship(
|
||||
name=rel_name,
|
||||
models=[table_name, referred_table],
|
||||
joinType="MANY_TO_ONE", # typically a foreign key represents many-to-one
|
||||
condition=condition
|
||||
))
|
||||
|
||||
return MDLManifest(
|
||||
catalog="default",
|
||||
schema="public", # Default schema, might need adjustment based on datasource config
|
||||
dataSource=datasource.type.upper(),
|
||||
models=models,
|
||||
relationships=relationships
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_mdl(datasource_id: int) -> Optional[MDLManifest]:
|
||||
path = MDLService._get_mdl_path(datasource_id)
|
||||
if path.exists():
|
||||
try:
|
||||
with open(path, "r") as f:
|
||||
data = json.load(f)
|
||||
# Pydantic v2 compatible
|
||||
return MDLManifest.model_validate(data)
|
||||
except Exception as e:
|
||||
print(f"Error loading MDL for {datasource_id}: {e}")
|
||||
return None
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def save_mdl(datasource_id: int, mdl: MDLManifest):
|
||||
path = MDLService._get_mdl_path(datasource_id)
|
||||
with open(path, "w") as f:
|
||||
f.write(mdl.model_dump_json(indent=2, by_alias=True))
|
||||
|
||||
@staticmethod
|
||||
def get_or_create_mdl(datasource_id: int) -> MDLManifest:
|
||||
mdl = MDLService.get_mdl(datasource_id)
|
||||
if mdl:
|
||||
return mdl
|
||||
|
||||
# Generate new
|
||||
db = SessionLocal()
|
||||
try:
|
||||
ds = db.query(DataSource).filter(DataSource.id == datasource_id).first()
|
||||
if not ds:
|
||||
raise ValueError(f"DataSource {datasource_id} not found")
|
||||
mdl = MDLService.generate_default_mdl(ds)
|
||||
MDLService.save_mdl(datasource_id, mdl)
|
||||
return mdl
|
||||
finally:
|
||||
db.close()
|
||||
@@ -0,0 +1,5 @@
|
||||
def normalize_openai_base_url(api_base: str) -> str:
|
||||
normalized = (api_base or "").strip().rstrip("/")
|
||||
if normalized.lower().endswith("/embeddings"):
|
||||
normalized = normalized[: -len("/embeddings")]
|
||||
return normalized
|
||||
@@ -0,0 +1,51 @@
|
||||
import os
|
||||
import json
|
||||
import threading
|
||||
from typing import Any, Dict
|
||||
|
||||
from app.core.data_root import get_data_root
|
||||
|
||||
_cache_lock = threading.RLock()
|
||||
_cache_mtime: float = -1.0
|
||||
_cache_data: Dict[str, Any] = {}
|
||||
|
||||
def get_config_file_path() -> str:
|
||||
return str(get_data_root() / "web_search_config.json")
|
||||
|
||||
def get_web_search_config() -> Dict[str, Any]:
|
||||
global _cache_mtime, _cache_data
|
||||
config_file = get_config_file_path()
|
||||
current_mtime = os.path.getmtime(config_file) if os.path.exists(config_file) else -1.0
|
||||
|
||||
with _cache_lock:
|
||||
if current_mtime != _cache_mtime:
|
||||
if not os.path.exists(config_file):
|
||||
_cache_data = {
|
||||
"provider": "duckduckgo",
|
||||
"api_key": "",
|
||||
"base_url": "",
|
||||
"max_results": 5
|
||||
}
|
||||
else:
|
||||
try:
|
||||
with open(config_file, "r") as f:
|
||||
_cache_data = json.load(f)
|
||||
except json.JSONDecodeError:
|
||||
_cache_data = {
|
||||
"provider": "duckduckgo",
|
||||
"api_key": "",
|
||||
"base_url": "",
|
||||
"max_results": 5
|
||||
}
|
||||
_cache_mtime = current_mtime
|
||||
return dict(_cache_data)
|
||||
|
||||
def save_web_search_config(config: Dict[str, Any]) -> None:
|
||||
global _cache_mtime, _cache_data
|
||||
config_file = get_config_file_path()
|
||||
os.makedirs(os.path.dirname(config_file), exist_ok=True)
|
||||
with _cache_lock:
|
||||
with open(config_file, "w") as f:
|
||||
json.dump(config, f, indent=2)
|
||||
_cache_data = dict(config)
|
||||
_cache_mtime = os.path.getmtime(config_file)
|
||||
Reference in New Issue
Block a user