code-review-agent/tests/test_llm_streaming.py

"""
Тест стриминга LLM messages от LangGraph
"""
import asyncio
import operator
from typing import Annotated, TypedDict

from langchain_ollama import OllamaLLM
from langgraph.graph import END, StateGraph


class TestState(TypedDict):
    # messages accumulate across node updates via the operator.add reducer
    messages: Annotated[list, operator.add]
    result: str


async def llm_node(state: TestState) -> TestState:
    """Node that makes an LLM call."""
    print("  [LLM NODE] Calling LLM...")
    llm = OllamaLLM(
        model="qwen2.5-coder:3b",
        base_url="http://localhost:11434",
        temperature=0.7,
    )
    # Simple prompt to get a quick response
    prompt = "Write a short Python code check (no more than 100 characters)"
    response = await llm.ainvoke(prompt)
    print(f"  [LLM NODE] Response received: {response[:50]}...")
    return {
        "messages": [{"role": "ai", "content": response}],
        "result": response,
    }


def create_test_graph():
    """Build a minimal test graph with a single LLM node."""
    workflow = StateGraph(TestState)
    workflow.add_node("llm_call", llm_node)
    workflow.set_entry_point("llm_call")
    workflow.add_edge("llm_call", END)
    return workflow.compile()


async def test_with_llm():
    """Streaming test with an LLM call."""
    print("\n" + "=" * 80)
    print("LLM MESSAGES STREAMING TEST")
    print("=" * 80)
    graph = create_test_graph()
    initial_state: TestState = {
        "messages": [],
        "result": "",
    }
    # Test: updates + messages
    print("\n🔍 Test: stream_mode=['updates', 'messages']")
    print("-" * 80)
    event_count = 0
    messages_count = 0
    # With a list of stream modes, astream yields (mode, payload) tuples
    async for event in graph.astream(initial_state, stream_mode=["updates", "messages"]):
        event_count += 1
        if isinstance(event, tuple) and len(event) >= 2:
            event_type, event_data = event[0], event[1]
            print(f"\n📨 Event #{event_count}")
            print(f"   Type: {event_type}")
            print(f"   Data type: {type(event_data)}")
            if event_type == "updates":
                print("   ✅ Node update")
                if isinstance(event_data, dict):
                    for node_name in event_data.keys():
                        print(f"      Node: {node_name}")
            elif event_type == "messages":
                messages_count += 1
                print(f"   💬 LLM Messages (#{messages_count})")
                if isinstance(event_data, (list, tuple)):
                    for i, msg in enumerate(event_data):
                        print(f"      Message {i + 1}:")
                        # Extract the content
                        if hasattr(msg, "content"):
                            content = msg.content
                            print(f"         Content: {content[:100]}...")
                        elif isinstance(msg, dict):
                            print(f"         Dict: {msg}")
                        else:
                            print(f"         Type: {type(msg)}")
                            print(f"         Str: {str(msg)[:100]}...")
    print("\n" + "=" * 80)
    print(f"Total events: {event_count}")
    print(f"✅ 'messages' events: {messages_count}")
    print("=" * 80)


async def main():
    print("\n" + "=" * 80)
    print("TESTING LLM STREAMING IN LANGGRAPH")
    print("=" * 80)
    print("\nChecking Ollama...")
    try:
        # Verify that Ollama is reachable before running the streaming test
        test_llm = OllamaLLM(model="qwen2.5-coder:3b", base_url="http://localhost:11434")
        await test_llm.ainvoke("test")
        print("✅ Ollama is up!")
    except Exception as e:
        print(f"❌ Could not connect to Ollama: {e}")
        print("\n⚠️ Make sure Ollama is running: ollama serve")
        print("⚠️ And the model has been pulled: ollama pull qwen2.5-coder:3b\n")
        return
    await test_with_llm()
    print("\n✅ Testing finished\n")
if __name__ == "__main__":
asyncio.run(main())
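
# To run this check manually (assumes Ollama is serving on localhost:11434
# and the qwen2.5-coder:3b model has been pulled):
#   python tests/test_llm_streaming.py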