"""
|
||
Тест стриминга LLM messages от LangGraph
|
||
"""
|
||
|
||
import asyncio
from langgraph.graph import StateGraph, END
from typing import TypedDict, Annotated
import operator
from langchain_ollama import OllamaLLM


class TestState(TypedDict):
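    # `operator.add` is the reducer for this channel: LangGraph appends each
    # node's "messages" update onto the accumulated list instead of
    # overwriting it.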
    messages: Annotated[list, operator.add]
    result: str


async def llm_node(state: TestState) -> TestState:
    """Node with an LLM call"""
    print(" [LLM NODE] Calling the LLM...")

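    # Note: the "messages" stream mode relies on LangChain's callback system
    # to surface LLM output from inside nodes, so events can appear in
    # graph.astream() even though this node just awaits ainvoke(). Whether
    # tokens arrive chunk-by-chunk depends on the model integration.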
    llm = OllamaLLM(
        model="qwen2.5-coder:3b",
        base_url="http://localhost:11434",
        temperature=0.7
    )

    # A simple prompt for a quick response
    prompt = "Write a short Python code check (no more than 100 characters)"

    response = await llm.ainvoke(prompt)

    print(f" [LLM NODE] Response received: {response[:50]}...")

    return {
        "messages": [{"role": "ai", "content": response}],
        "result": response
    }


def create_test_graph():
    """Creates a test graph with an LLM"""
    workflow = StateGraph(TestState)

    workflow.add_node("llm_call", llm_node)

    workflow.set_entry_point("llm_call")
    workflow.add_edge("llm_call", END)

    return workflow.compile()


async def test_with_llm():
    """Streaming test with the LLM"""
    print("\n" + "="*80)
    print("LLM MESSAGES STREAMING TEST")
    print("="*80)

    graph = create_test_graph()

    initial_state: TestState = {
        "messages": [],
        "result": ""
    }

    # Test: updates + messages
    print("\n🔍 Test: stream_mode=['updates', 'messages']")
    print("-" * 80)

    event_count = 0
    messages_count = 0

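    # When stream_mode is a list, astream() yields (mode, payload) tuples:
    # ("updates", {node_name: state_delta}) for node updates and, for
    # "messages", typically a (message_chunk, metadata) pair.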
    async for event in graph.astream(initial_state, stream_mode=["updates", "messages"]):
        event_count += 1

        if isinstance(event, tuple) and len(event) >= 2:
            event_type, event_data = event[0], event[1]

            print(f"\n📨 Event #{event_count}")
            print(f" Type: {event_type}")
            print(f" Data type: {type(event_data)}")

            if event_type == 'updates':
                print(" ✅ Node update")
                if isinstance(event_data, dict):
                    for node_name in event_data.keys():
                        print(f" Node: {node_name}")

            elif event_type == 'messages':
                messages_count += 1
                print(f" 💬 LLM Messages (#{messages_count})")
                if isinstance(event_data, (list, tuple)):
                    for i, msg in enumerate(event_data):
                        print(f" Message {i+1}:")

                        # Extract the content
                        if hasattr(msg, 'content'):
                            content = msg.content
                            print(f" Content: {content[:100]}...")
                        elif isinstance(msg, dict):
                            print(f" Dict: {msg}")
                        else:
                            print(f" Type: {type(msg)}")
                            print(f" Str: {str(msg)[:100]}...")

print(f"\n" + "="*80)
|
||
print(f"✅ Всего событий: {event_count}")
|
||
print(f"✅ Messages событий: {messages_count}")
|
||
print("="*80)
|
||
|
||
|
||
async def main():
    print("\n" + "="*80)
    print("TESTING LLM STREAMING IN LANGGRAPH")
    print("="*80)
    print("\nChecking Ollama...")

    try:
        # Check that Ollama is reachable (OllamaLLM is already imported at
        # the top of the module)
        test_llm = OllamaLLM(model="qwen2.5-coder:3b", base_url="http://localhost:11434")
        await test_llm.ainvoke("test")
        print("✅ Ollama is up!")
    except Exception as e:
        print(f"❌ Error connecting to Ollama: {e}")
        print("\n⚠️ Make sure Ollama is running: ollama serve")
        print("⚠️ And that the model is pulled: ollama pull qwen2.5-coder:3b\n")
        return

    await test_with_llm()

    print("\n✅ Testing complete\n")


if __name__ == "__main__":
    asyncio.run(main())