"""SQLite chat memory for Telegram users."""

from __future__ import annotations

import sqlite3
import time
from collections.abc import Iterator
from contextlib import closing, contextmanager
from pathlib import Path
from typing import TypedDict


class ChatMessage(TypedDict):
    """One stored chat turn: who spoke (``role``) and what was said (``content``)."""

    role: str
    content: str


class SQLiteMemory:
    """Per-user chat history persisted in a SQLite database.

    Every public method opens its own short-lived connection and closes it
    before returning, so an instance holds no open database handle between
    calls.
    """

    def __init__(
        self,
        db_path: str | Path = "/workspace/.services/bot/data.db",
        max_stored: int = 50,
    ) -> None:
        """Create the store and make sure its schema exists.

        Args:
            db_path: Location of the SQLite file; missing parent directories
                are created.
            max_stored: How many of the newest messages to keep per user;
                older ones are deleted on every ``save``. Must be >= 1.

        Raises:
            ValueError: If ``max_stored`` is less than 1.
        """
        if max_stored < 1:
            # 0 would delete the message just inserted; a negative value
            # becomes "LIMIT -1" in SQLite, which silently disables trimming.
            raise ValueError("max_stored must be at least 1")
        self.db_path = Path(db_path)
        self.max_stored = max_stored
        self.init_db()

    @contextmanager
    def _connect(self) -> Iterator[sqlite3.Connection]:
        """Yield a connection whose work runs in one transaction, then close it.

        ``sqlite3.Connection`` used as a context manager only commits (or rolls
        back on error); it does not close the connection, so ``closing`` is
        layered around it to release the file handle deterministically.
        """
        with closing(sqlite3.connect(self.db_path)) as conn:
            with conn:
                yield conn

    def init_db(self) -> None:
        """Create the ``messages`` table and its lookup index if they are missing."""
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        with self._connect() as conn:
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS messages (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    user_id INTEGER NOT NULL,
                    role TEXT NOT NULL CHECK(role IN ('user', 'assistant')),
                    content TEXT NOT NULL,
                    ts INTEGER NOT NULL
                )
                """
            )
            # Matches the "WHERE user_id = ? ORDER BY ts DESC, id DESC" queries below.
            conn.execute(
                "CREATE INDEX IF NOT EXISTS idx_messages_user_ts "
                "ON messages(user_id, ts DESC, id DESC)"
            )

    def save(self, user_id: int, role: str, content: str) -> None:
        """Store one message for ``user_id`` and trim that user's stored history.

        Leading/trailing whitespace is stripped; empty or ``None`` content is
        ignored. ``role`` must be ``'user'`` or ``'assistant'``; any other value
        violates the table's CHECK constraint and raises
        ``sqlite3.IntegrityError`` (nothing is written in that case).
        """
        content = (content or "").strip()
        if not content:
            return
        with self._connect() as conn:
            # Insert and trim run in the same transaction.
            conn.execute(
                "INSERT INTO messages(user_id, role, content, ts) VALUES (?, ?, ?, ?)",
                (user_id, role, content, int(time.time())),
            )
            # Storage trim: keep a small bounded tail; prompt context uses the
            # last 10 only (see recent_context's default limit).
            conn.execute(
                """
                DELETE FROM messages
                WHERE user_id = ?
                  AND id NOT IN (
                      SELECT id FROM messages
                      WHERE user_id = ?
                      ORDER BY ts DESC, id DESC
                      LIMIT ?
                  )
                """,
                (user_id, user_id, self.max_stored),
            )

    def recent_context(self, user_id: int, limit: int = 10) -> list[ChatMessage]:
        """Return the newest ``limit`` messages in chronological order for LLM context.

        A non-positive ``limit`` returns an empty list. (SQLite would treat a
        negative LIMIT as "no limit" and return the whole history.)
        """
        if limit <= 0:
            return []
        with self._connect() as conn:
            rows = conn.execute(
                "SELECT role, content FROM messages WHERE user_id = ? "
                "ORDER BY ts DESC, id DESC LIMIT ?",
                (user_id, limit),
            ).fetchall()
        # Rows arrive newest-first; flip them so the oldest message comes first.
        return [{"role": role, "content": content} for role, content in reversed(rows)]

    def reset(self, user_id: int) -> int:
        """Delete all stored messages for a user and return affected row count."""
        with self._connect() as conn:
            cursor = conn.execute("DELETE FROM messages WHERE user_id = ?", (user_id,))
            deleted = cursor.rowcount
        return deleted