fix: model arguments fixed

This commit is contained in:
qixinbo
2026-03-28 14:46:50 +08:00
parent 00e5587e75
commit bd731660ac
12 changed files with 361 additions and 35 deletions
+45
View File
@@ -98,3 +98,48 @@ def test_nanobot_chat_stream_syncs_project_id(monkeypatch) -> None:
assert "stream-complete" in content
assert calls == [{"session_key": "api:test-3", "project_id": 202}]
assert process_kwargs and process_kwargs[0]["project_id"] == 202
def test_nanobot_chat_stream_emits_reasoning_flags_and_final_reasoning(monkeypatch) -> None:
async def fake_process_message(*args, **kwargs):
on_progress = kwargs.get("on_progress")
on_stream = kwargs.get("on_stream")
if on_progress:
await on_progress("模型正在拆解问题", is_reasoning=True)
await on_progress("开始执行工具", tool_hint=True)
if on_stream:
await on_stream("answer-token")
return "final-answer"
class _DummySession:
def __init__(self):
self.metadata = {}
self.messages = [
{"role": "assistant", "content": "final-answer", "reasoning_content": "完整思考过程"}
]
class _DummySessions:
def get_or_create(self, _key):
return _DummySession()
class _DummyAgent:
def __init__(self):
self.sessions = _DummySessions()
async def collect_stream_chunks(response) -> list[str]:
chunks: list[str] = []
async for chunk in response.body_iterator:
chunks.append(chunk.decode("utf-8") if isinstance(chunk, bytes) else chunk)
return chunks
monkeypatch.setattr(main.nanobot_service, "process_message", fake_process_message)
monkeypatch.setattr(main.nanobot_service, "agent", _DummyAgent())
request = main.ChatRequest(message="hello", session_id="api:test-4", project_id=303)
response = asyncio.run(main.nanobot_chat_stream(request))
chunks = asyncio.run(collect_stream_chunks(response))
content = "".join(chunks)
assert '"type": "progress", "content": "模型正在拆解问题", "is_reasoning": true' in content
assert '"type": "progress", "content": "开始执行工具", "tool_hint": true' in content
assert '"type": "final", "content": "final-answer", "reasoning_content": "完整思考过程"' in content
@@ -0,0 +1,62 @@
import sys
from pathlib import Path
BACKEND_ROOT = Path(__file__).resolve().parents[1]
REPO_ROOT = BACKEND_ROOT.parent
NANOBOT_ROOT = REPO_ROOT / "nanobot"
if str(BACKEND_ROOT) not in sys.path:
sys.path.insert(0, str(BACKEND_ROOT))
if str(NANOBOT_ROOT) not in sys.path:
sys.path.insert(0, str(NANOBOT_ROOT))
from app.core.llm_provider import build_llm_provider
from app.core.nanobot import NanobotIntegration
from app.core.patched_openai_compat_provider import PatchedOpenAICompatProvider
def test_build_llm_provider_uses_max_completion_tokens_for_gpt5() -> None:
provider = build_llm_provider(
model="gpt-5.4-nano",
provider="openai",
api_key="test-key",
api_base="https://example.com/v1",
)
assert isinstance(provider, PatchedOpenAICompatProvider)
kwargs = provider._build_kwargs(
messages=[{"role": "user", "content": "hello"}],
tools=None,
model="gpt-5.4-nano",
max_tokens=5,
temperature=0,
reasoning_effort=None,
tool_choice=None,
)
assert kwargs["max_completion_tokens"] == 5
assert "max_tokens" not in kwargs
def test_nanobot_provider_keeps_max_tokens_for_legacy_models() -> None:
integration = NanobotIntegration()
provider = integration._build_provider(
model="gpt-4o-mini",
provider_name="openai",
api_key="test-key",
api_base="https://example.com/v1",
extra_headers=None,
)
assert isinstance(provider, PatchedOpenAICompatProvider)
kwargs = provider._build_kwargs(
messages=[{"role": "user", "content": "hello"}],
tools=None,
model="gpt-4o-mini",
max_tokens=5,
temperature=0,
reasoning_effort=None,
tool_choice=None,
)
assert kwargs["max_tokens"] == 5
assert "max_completion_tokens" not in kwargs