fix: model arguments fixed
This commit is contained in:
@@ -2,9 +2,10 @@ from typing import Optional, Dict
|
||||
|
||||
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
||||
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
|
||||
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
||||
from nanobot.providers.registry import find_by_name
|
||||
|
||||
from app.core.patched_openai_compat_provider import PatchedOpenAICompatProvider
|
||||
|
||||
|
||||
def normalize_provider_name(provider: Optional[str]) -> Optional[str]:
|
||||
if not provider:
|
||||
@@ -51,7 +52,7 @@ def build_llm_provider(
|
||||
extra_headers=extra_headers,
|
||||
)
|
||||
|
||||
return OpenAICompatProvider(
|
||||
return PatchedOpenAICompatProvider(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
default_model=model,
|
||||
|
||||
@@ -19,7 +19,6 @@ from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.config.loader import load_config
|
||||
from nanobot.cron.service import CronService
|
||||
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
||||
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
|
||||
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
||||
from nanobot.providers.base import GenerationSettings
|
||||
@@ -32,6 +31,7 @@ from nanobot.config.schema import Config
|
||||
# or just import here if we are confident.
|
||||
# Given the structure, importing here should be fine as long as skills.py doesn't import nanobot.py.
|
||||
from app.api.skills import load_skills
|
||||
from app.core.patched_openai_compat_provider import PatchedOpenAICompatProvider
|
||||
from app.services.llm_cache import get_llm_configs, get_active_llm_config
|
||||
|
||||
from app.core.data_root import get_workspace_root
|
||||
@@ -45,6 +45,7 @@ class NanobotIntegration:
|
||||
self._started = False
|
||||
self._model_agent_cache: Dict[tuple[str | None, int | None], AgentLoop] = {}
|
||||
self._model_agent_lock = asyncio.Lock()
|
||||
self._last_usage_by_session: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
@staticmethod
|
||||
def _normalize_config_value(value: Any) -> Any:
|
||||
@@ -80,6 +81,28 @@ class NanobotIntegration:
|
||||
return content
|
||||
return str(response)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_usage(usage: Any) -> Dict[str, int] | None:
|
||||
if not isinstance(usage, dict):
|
||||
return None
|
||||
normalized: Dict[str, int] = {}
|
||||
prompt = int(usage.get("prompt_tokens", 0) or 0)
|
||||
completion = int(usage.get("completion_tokens", 0) or 0)
|
||||
total = int(usage.get("total_tokens", 0) or 0)
|
||||
|
||||
# If total_tokens is missing or zero, calculate it
|
||||
if total == 0:
|
||||
total = prompt + completion
|
||||
|
||||
normalized["prompt_tokens"] = prompt
|
||||
normalized["completion_tokens"] = completion
|
||||
normalized["total_tokens"] = total
|
||||
return normalized if (prompt > 0 or completion > 0) else None
|
||||
|
||||
def get_last_usage(self, session_id: str) -> Dict[str, int] | None:
|
||||
usage = self._last_usage_by_session.get(session_id)
|
||||
return dict(usage) if usage else None
|
||||
|
||||
def _need_custom_agent_for_target(self, target_config: Dict[str, Any]) -> bool:
|
||||
if not self.agent:
|
||||
return False
|
||||
@@ -220,7 +243,7 @@ class NanobotIntegration:
|
||||
extra_headers=extra_headers,
|
||||
)
|
||||
|
||||
return OpenAICompatProvider(
|
||||
return PatchedOpenAICompatProvider(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
default_model=model,
|
||||
@@ -407,6 +430,9 @@ class NanobotIntegration:
|
||||
on_progress=on_progress,
|
||||
on_stream=on_stream,
|
||||
)
|
||||
usage = self._normalize_usage(getattr(agent_to_use, "_last_usage", None))
|
||||
if usage:
|
||||
self._last_usage_by_session[session_id] = usage
|
||||
return self._extract_response_text(response)
|
||||
|
||||
def _normalize_session_messages(self, messages: List[Any]) -> List[dict[str, Any]]:
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
||||
|
||||
|
||||
class PatchedOpenAICompatProvider(OpenAICompatProvider):
    """OpenAI-compatible provider that renames ``max_tokens`` for newer models.

    Request kwargs are built by the parent ``_build_kwargs``. Afterwards
    ``max_tokens`` is moved to ``max_completion_tokens`` when either the
    provider spec sets ``supports_max_completion_tokens`` or the model name
    contains one of the markers in ``_MAX_COMPLETION_TOKEN_MODELS``.
    """

    # Lower-case model-family markers whose requests should carry
    # ``max_completion_tokens`` instead of ``max_tokens``.
    _MAX_COMPLETION_TOKEN_MODELS = ("gpt-5", "o1", "o3", "o4")

    @classmethod
    def _matches_max_completion_token_model(cls, model_name: str) -> bool:
        """Return True if ``model_name`` contains a marker as a whole token.

        A marker only counts when it is not glued to another letter or digit
        on either side. A plain substring test would also fire on unrelated
        names such as ``gpt-35-turbo16k`` (which contains ``o1``).

        Args:
            model_name: Model identifier, expected to be lower-cased already.
        """
        import re  # local import: nothing else in this module needs ``re``

        markers = [
            re.escape(marker.lower())
            for marker in cls._MAX_COMPLETION_TOKEN_MODELS
            if marker
        ]
        if not markers:
            # An empty alternation would match everywhere; treat as no match.
            return False
        alternation = "|".join(markers)
        pattern = rf"(?<![a-z0-9])(?:{alternation})(?![a-z0-9])"
        return re.search(pattern, model_name) is not None

    def _build_kwargs(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None,
        model: str | None,
        max_tokens: int,
        temperature: float,
        reasoning_effort: str | None,
        tool_choice: str | dict[str, Any] | None,
    ) -> dict[str, Any]:
        """Build request kwargs, renaming ``max_tokens`` where required.

        All arguments are forwarded unchanged to the parent implementation.

        Returns:
            The parent's kwargs dict. When the spec or the model name asks
            for it and ``max_tokens`` is present, that entry is moved to
            ``max_completion_tokens``.
        """
        kwargs = super()._build_kwargs(
            messages=messages,
            tools=tools,
            model=model,
            max_tokens=max_tokens,
            temperature=temperature,
            reasoning_effort=reasoning_effort,
            tool_choice=tool_choice,
        )

        model_name = (model or self.default_model or "").lower()
        # ``_spec`` is a private attribute of the parent class; read it
        # defensively so a parent without it does not break request building.
        spec = getattr(self, "_spec", None)
        spec_opt_in = bool(
            spec and getattr(spec, "supports_max_completion_tokens", False)
        )
        use_max_completion_tokens = (
            spec_opt_in or self._matches_max_completion_token_model(model_name)
        )

        if use_max_completion_tokens and "max_tokens" in kwargs:
            kwargs["max_completion_tokens"] = kwargs.pop("max_tokens")

        return kwargs
|
||||
+31
-2
@@ -268,6 +268,7 @@ def _persist_assistant_enrichment(
|
||||
session_id: str,
|
||||
viz_payload: Optional[Dict[str, Any]] = None,
|
||||
artifacts: Optional[List[Dict[str, Any]]] = None,
|
||||
usage: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
if not nanobot_service.agent:
|
||||
return
|
||||
@@ -281,9 +282,25 @@ def _persist_assistant_enrichment(
|
||||
if artifacts:
|
||||
session.messages[-1]["artifacts"] = artifacts
|
||||
changed = True
|
||||
if usage:
|
||||
session.messages[-1]["usage"] = usage
|
||||
changed = True
|
||||
if changed:
|
||||
nanobot_service.agent.sessions.save(session)
|
||||
|
||||
|
||||
def _extract_reasoning_content(session_messages: List[Dict[str, Any]]) -> str:
    """Return the reasoning text of the most recent assistant message.

    Only the last dict entry whose ``role`` is ``"assistant"`` is inspected;
    earlier assistant messages are never consulted. An empty string is
    returned when there is no such message or when its ``reasoning_content``
    is missing, not a string, or blank.
    """
    latest_assistant = next(
        (
            entry
            for entry in reversed(session_messages)
            if isinstance(entry, dict) and entry.get("role") == "assistant"
        ),
        None,
    )
    if latest_assistant is None:
        return ""

    reasoning = latest_assistant.get("reasoning_content")
    if isinstance(reasoning, str) and reasoning.strip():
        return reasoning
    return ""
|
||||
|
||||
@app.post("/nanobot/chat")
|
||||
async def nanobot_chat(request: ChatRequest):
|
||||
try:
|
||||
@@ -321,10 +338,12 @@ async def nanobot_chat(request: ChatRequest):
|
||||
artifacts = extract_artifacts(text, session_messages)
|
||||
|
||||
viz_payload = current_viz_data.get()
|
||||
usage = nanobot_service.get_last_usage(request.session_id)
|
||||
_persist_assistant_enrichment(
|
||||
session_id=request.session_id,
|
||||
viz_payload=viz_payload if isinstance(viz_payload, dict) else None,
|
||||
artifacts=artifacts,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
payload = {
|
||||
@@ -334,6 +353,8 @@ async def nanobot_chat(request: ChatRequest):
|
||||
}
|
||||
if artifacts:
|
||||
payload["artifacts"] = artifacts
|
||||
if usage:
|
||||
payload["usage"] = usage
|
||||
return payload
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@@ -356,7 +377,9 @@ async def nanobot_chat_stream(request: ChatRequest):
|
||||
|
||||
async def _on_progress(content: str, **kwargs: Any) -> None:
|
||||
if content:
|
||||
await progress_queue.put(content)
|
||||
payload: Dict[str, Any] = {"type": "progress", "content": content}
|
||||
payload.update(kwargs)
|
||||
await progress_queue.put(payload)
|
||||
|
||||
async def _on_stream(delta: str) -> None:
|
||||
if delta:
|
||||
@@ -427,9 +450,10 @@ async def nanobot_chat_stream(request: ChatRequest):
|
||||
session = nanobot_service.agent.sessions.get_or_create(request.session_id)
|
||||
session_messages = session.messages
|
||||
artifacts = extract_artifacts(text, session_messages)
|
||||
reasoning_content = _extract_reasoning_content(session_messages)
|
||||
|
||||
# Check again for viz payload after task completes if not sent yet
|
||||
viz_payload = current_viz_data.get()
|
||||
usage = nanobot_service.get_last_usage(request.session_id)
|
||||
if viz_payload:
|
||||
try:
|
||||
current_hash = hash((
|
||||
@@ -447,11 +471,16 @@ async def nanobot_chat_stream(request: ChatRequest):
|
||||
session_id=request.session_id,
|
||||
viz_payload=viz_payload if isinstance(viz_payload, dict) else None,
|
||||
artifacts=artifacts,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
final_payload = {"type": "final", "content": text}
|
||||
if reasoning_content:
|
||||
final_payload["reasoning_content"] = reasoning_content
|
||||
if artifacts:
|
||||
final_payload["artifacts"] = artifacts
|
||||
if usage:
|
||||
final_payload["usage"] = usage
|
||||
yield f"data: {json.dumps(final_payload, ensure_ascii=False)}\n\n"
|
||||
yield f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n"
|
||||
except asyncio.CancelledError:
|
||||
|
||||
@@ -98,3 +98,48 @@ def test_nanobot_chat_stream_syncs_project_id(monkeypatch) -> None:
|
||||
assert "stream-complete" in content
|
||||
assert calls == [{"session_key": "api:test-3", "project_id": 202}]
|
||||
assert process_kwargs and process_kwargs[0]["project_id"] == 202
|
||||
|
||||
|
||||
def test_nanobot_chat_stream_emits_reasoning_flags_and_final_reasoning(monkeypatch) -> None:
    """Progress kwargs and the stored reasoning text must reach the stream.

    The message processor is replaced with a stub that reports two progress
    updates (one flagged ``is_reasoning``, one flagged ``tool_hint``) and one
    stream delta. The agent is replaced with a stub whose session holds an
    assistant message carrying ``reasoning_content``.
    """

    class _DummySession:
        def __init__(self):
            self.metadata = {}
            self.messages = [
                {"role": "assistant", "content": "final-answer", "reasoning_content": "完整思考过程"}
            ]

    class _DummySessions:
        def get_or_create(self, _key):
            return _DummySession()

    class _DummyAgent:
        def __init__(self):
            self.sessions = _DummySessions()

    async def fake_process_message(*args, **kwargs):
        progress_cb = kwargs.get("on_progress")
        stream_cb = kwargs.get("on_stream")
        if progress_cb:
            await progress_cb("模型正在拆解问题", is_reasoning=True)
            await progress_cb("开始执行工具", tool_hint=True)
        if stream_cb:
            await stream_cb("answer-token")
        return "final-answer"

    async def collect_stream_chunks(response) -> list[str]:
        # Body chunks may arrive as bytes or str; normalise to str.
        return [
            piece.decode("utf-8") if isinstance(piece, bytes) else piece
            async for piece in response.body_iterator
        ]

    service = main.nanobot_service
    monkeypatch.setattr(service, "process_message", fake_process_message)
    monkeypatch.setattr(service, "agent", _DummyAgent())

    request = main.ChatRequest(message="hello", session_id="api:test-4", project_id=303)
    response = asyncio.run(main.nanobot_chat_stream(request))
    content = "".join(asyncio.run(collect_stream_chunks(response)))

    expected_fragments = (
        '"type": "progress", "content": "模型正在拆解问题", "is_reasoning": true',
        '"type": "progress", "content": "开始执行工具", "tool_hint": true',
        '"type": "final", "content": "final-answer", "reasoning_content": "完整思考过程"',
    )
    for fragment in expected_fragments:
        assert fragment in content
|
||||
|
||||
@@ -0,0 +1,62 @@
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Locate the backend package, the repository root and the nanobot checkout
# relative to this test file.
BACKEND_ROOT = Path(__file__).resolve().parents[1]
REPO_ROOT = BACKEND_ROOT.parent
NANOBOT_ROOT = REPO_ROOT / "nanobot"

# Put each root at the front of ``sys.path`` unless it is already listed.
# Order matters: the backend root is handled first, the nanobot root second.
for _root in map(str, (BACKEND_ROOT, NANOBOT_ROOT)):
    if _root not in sys.path:
        sys.path.insert(0, _root)
del _root
|
||||
|
||||
from app.core.llm_provider import build_llm_provider
|
||||
from app.core.nanobot import NanobotIntegration
|
||||
from app.core.patched_openai_compat_provider import PatchedOpenAICompatProvider
|
||||
|
||||
|
||||
def test_build_llm_provider_uses_max_completion_tokens_for_gpt5() -> None:
    """A gpt-5 model built via ``build_llm_provider`` sends ``max_completion_tokens``."""
    model_name = "gpt-5.4-nano"
    provider = build_llm_provider(
        model=model_name,
        provider="openai",
        api_key="test-key",
        api_base="https://example.com/v1",
    )
    assert isinstance(provider, PatchedOpenAICompatProvider)

    request_kwargs = provider._build_kwargs(
        messages=[{"role": "user", "content": "hello"}],
        tools=None,
        model=model_name,
        max_tokens=5,
        temperature=0,
        reasoning_effort=None,
        tool_choice=None,
    )

    # The limit must have been moved, not duplicated.
    assert request_kwargs["max_completion_tokens"] == 5
    assert "max_tokens" not in request_kwargs
|
||||
|
||||
|
||||
def test_nanobot_provider_keeps_max_tokens_for_legacy_models() -> None:
    """A gpt-4o model built via ``NanobotIntegration`` keeps plain ``max_tokens``."""
    model_name = "gpt-4o-mini"
    provider = NanobotIntegration()._build_provider(
        model=model_name,
        provider_name="openai",
        api_key="test-key",
        api_base="https://example.com/v1",
        extra_headers=None,
    )
    assert isinstance(provider, PatchedOpenAICompatProvider)

    request_kwargs = provider._build_kwargs(
        messages=[{"role": "user", "content": "hello"}],
        tools=None,
        model=model_name,
        max_tokens=5,
        temperature=0,
        reasoning_effort=None,
        tool_choice=None,
    )

    # Legacy models must not be switched to the newer parameter name.
    assert request_kwargs["max_tokens"] == 5
    assert "max_completion_tokens" not in request_kwargs
|
||||
Reference in New Issue
Block a user