Files
Hermes-ui/packages/server/src/routes/hermes/proxy-handler.ts
T

282 lines
8.4 KiB
TypeScript
Raw Normal View History

import type { Context } from 'koa'
import { config } from '../../config'
import { getGatewayManagerInstance } from '../../services/gateway-bootstrap'
import { updateUsage } from '../../db/hermes/usage-store'
function getGatewayManager() { return getGatewayManagerInstance() }
// --- run_id → session_id mapping (in-memory, ephemeral) ---
const runSessionMap = new Map<string, string>()
export function setRunSession(runId: string, sessionId: string): void {
runSessionMap.set(runId, sessionId)
// Auto-cleanup after 30 minutes
setTimeout(() => runSessionMap.delete(runId), 30 * 60 * 1000)
}
function getSessionForRun(runId: string): string | undefined {
return runSessionMap.get(runId)
}
// --- Helpers ---
function isTransientGatewayError(err: any): boolean {
const msg = String(err?.message || '')
const causeCode = String(err?.cause?.code || '')
return (
causeCode === 'ECONNREFUSED' ||
causeCode === 'ECONNRESET' ||
/ECONNREFUSED|ECONNRESET|fetch failed|socket hang up/i.test(msg)
)
}
async function waitForGatewayReady(upstream: string, timeoutMs: number = 5000): Promise<boolean> {
const deadline = Date.now() + timeoutMs
const healthUrl = `${upstream}/health`
while (Date.now() < deadline) {
try {
const res = await fetch(healthUrl, {
method: 'GET',
signal: AbortSignal.timeout(1200),
})
if (res.ok) return true
} catch { }
await new Promise(resolve => setTimeout(resolve, 250))
}
return false
}
/** Resolve profile name from request */
function resolveProfile(ctx: Context): string {
return ctx.get('x-hermes-profile') || (ctx.query.profile as string) || 'default'
}
/** Resolve upstream URL for a request based on profile header/query */
function resolveUpstream(ctx: Context): string {
const mgr = getGatewayManager()
if (mgr) {
const profile = resolveProfile(ctx)
if (profile && profile !== 'default') {
return mgr.getUpstream(profile)
}
return mgr.getUpstream()
}
return config.upstream.replace(/\/$/, '')
}
function buildProxyHeaders(ctx: Context, upstream: string): Record<string, string> {
const headers: Record<string, string> = {}
for (const [key, value] of Object.entries(ctx.headers)) {
if (value == null) continue
const lower = key.toLowerCase()
if (lower === 'host') {
headers['host'] = new URL(upstream).host
} else if (lower === 'origin' || lower === 'referer' || lower === 'connection' || lower === 'authorization') {
continue
} else {
const v = Array.isArray(value) ? value[0] : value
if (v) headers[key] = v
}
}
const mgr = getGatewayManager()
if (mgr) {
const apiKey = mgr.getApiKey(resolveProfile(ctx))
if (apiKey) {
headers['authorization'] = `Bearer ${apiKey}`
}
}
return headers
}
// --- SSE stream interception ---
const SSE_EVENTS_PATH = /^\/v1\/runs\/([^/]+)\/events$/
/**
* Parse one complete SSE event block and record usage for run.completed.
* The public stream is forwarded elsewhere; parser failures are accounting-only
* and must never abort the client stream.
*/
function extractRunCompletedFromBlock(block: string): string | null {
const dataLines: string[] = []
for (const rawLine of block.split(/\r?\n/)) {
if (!rawLine.startsWith('data:')) continue
let data = rawLine.slice(5)
if (data.startsWith(' ')) data = data.slice(1)
dataLines.push(data)
}
if (dataLines.length === 0) return null
try {
const data = JSON.parse(dataLines.join('\n'))
if (data.event === 'run.completed' && data.usage && data.run_id) {
const sessionId = getSessionForRun(data.run_id)
if (sessionId) {
updateUsage(sessionId, data.usage.input_tokens, data.usage.output_tokens)
return data.run_id
}
}
} catch { /* not JSON or usage accounting failed; skip */ }
return null
}
function takeSSEBlock(buffer: string): { block: string; rest: string } | null {
const lf = buffer.indexOf('\n\n')
const crlf = buffer.indexOf('\r\n\r\n')
if (lf === -1 && crlf === -1) return null
if (crlf !== -1 && (lf === -1 || crlf < lf)) {
return { block: buffer.slice(0, crlf), rest: buffer.slice(crlf + 4) }
}
return { block: buffer.slice(0, lf), rest: buffer.slice(lf + 2) }
}
/**
* Stream an SSE response while intercepting run.completed events.
*/
async function streamSSE(ctx: Context, res: Response): Promise<void> {
if (!res.body) {
ctx.res.end()
return
}
const reader = res.body.getReader()
const decoder = new TextDecoder()
let buffer = ''
try {
while (true) {
const { done, value } = await reader.read()
if (done) break
// Forward raw bytes to client immediately
ctx.res.write(value)
// Also decode for interception
buffer += decoder.decode(value, { stream: true })
// Process complete SSE event blocks (LF or CRLF blank-line delimiters).
let next: { block: string; rest: string } | null
while ((next = takeSSEBlock(buffer)) !== null) {
buffer = next.rest
extractRunCompletedFromBlock(next.block)
}
}
buffer += decoder.decode()
// Process remaining buffer
if (buffer.trim()) {
extractRunCompletedFromBlock(buffer)
}
} finally {
ctx.res.end()
}
}
// --- Main proxy function ---
export async function proxy(ctx: Context) {
const profile = resolveProfile(ctx)
const upstream = resolveUpstream(ctx)
const upstreamPath = ctx.path.replace(/^\/api\/hermes\/v1/, '/v1').replace(/^\/api\/hermes/, '/api')
const params = new URLSearchParams(ctx.search || '')
params.delete('token')
const search = params.toString()
const url = `${upstream}${upstreamPath}${search ? `?${search}` : ''}`
const headers = buildProxyHeaders(ctx, upstream)
try {
let body: string | undefined
if (ctx.req.method !== 'GET' && ctx.req.method !== 'HEAD') {
// @koa/bodyparser parses JSON into ctx.request.body but doesn't store rawBody
// by default. Re-serialize the parsed body to get the string form.
const parsed = (ctx as any).request.body
if (typeof parsed === 'string') {
body = parsed
} else if (parsed && typeof parsed === 'object') {
body = JSON.stringify(parsed)
}
}
const requestInit: RequestInit = { method: ctx.req.method, headers, body }
let res: Response
try {
res = await fetch(url, requestInit)
} catch (err: any) {
if (isTransientGatewayError(err) && await waitForGatewayReady(upstream)) {
res = await fetch(url, requestInit)
} else {
throw err
}
}
// Set response headers
res.headers.forEach((value, key) => {
const lower = key.toLowerCase()
if (lower !== 'transfer-encoding' && lower !== 'connection') {
ctx.set(key, value)
}
})
ctx.status = res.status
// Intercept POST /v1/runs to capture run_id → session_id mapping
if (ctx.req.method === 'POST' && /\/v1\/runs$/.test(upstreamPath) && body) {
try {
const parsed = JSON.parse(body)
if (parsed.session_id) {
const resBody = await res.text()
ctx.res.write(resBody)
ctx.res.end()
try {
const result = JSON.parse(resBody)
if (result.run_id) {
setRunSession(result.run_id, parsed.session_id)
}
} catch { /* response not JSON, ignore */ }
return
}
} catch { /* body not JSON, fall through to normal stream */ }
// No session_id in body — fall through to normal response handling below
}
// Intercept SSE streams for /v1/runs/{id}/events
const sseMatch = upstreamPath.match(SSE_EVENTS_PATH)
if (sseMatch) {
ctx.set('Content-Type', 'text/event-stream')
ctx.set('Cache-Control', 'no-cache, no-transform')
ctx.set('X-Accel-Buffering', 'no')
await streamSSE(ctx, res)
return
}
// Default: pipe response body directly
if (res.body) {
const reader = res.body.getReader()
const pump = async () => {
while (true) {
const { done, value } = await reader.read()
if (done) break
ctx.res.write(value)
}
ctx.res.end()
}
await pump()
} else {
ctx.res.end()
}
} catch (err: any) {
if (!ctx.res.headersSent) {
ctx.status = 502
ctx.set('Content-Type', 'application/json')
ctx.body = { error: { message: `Proxy error: ${err.message}` } }
} else {
ctx.res.end()
}
}
}