import os
import sqlite3
from contextlib import contextmanager
from datetime import datetime, timezone
from pathlib import Path
from typing import Iterator, Optional

from secure_sms.core.security import SecurityMetadata, StorageCipher

DB_FILE = "secure_sms_v2.db"

# Baud rate used when no value is stored or the stored value is not an integer.
_DEFAULT_BAUDRATE = 9600

# Statements that create the schema; each is safe to re-run on an existing DB.
_SCHEMA_STATEMENTS = (
    """
    CREATE TABLE IF NOT EXISTS app_config (
        key TEXT PRIMARY KEY,
        value TEXT NOT NULL
    )
    """,
    """
    CREATE TABLE IF NOT EXISTS identity (
        id INTEGER PRIMARY KEY CHECK(id = 1),
        private_key_enc TEXT NOT NULL,
        public_key_enc TEXT NOT NULL,
        fingerprint TEXT NOT NULL,
        created_at TEXT NOT NULL
    )
    """,
    """
    CREATE TABLE IF NOT EXISTS contacts (
        phone TEXT PRIMARY KEY,
        name_enc TEXT NOT NULL,
        mode TEXT NOT NULL DEFAULT 'normal',
        secure_state TEXT NOT NULL DEFAULT 'none',
        peer_public_key_enc TEXT,
        peer_fingerprint TEXT,
        symmetric_key_enc TEXT,
        last_secure_at TEXT,
        created_at TEXT NOT NULL,
        updated_at TEXT NOT NULL
    )
    """,
    """
    CREATE TABLE IF NOT EXISTS messages (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        phone TEXT NOT NULL,
        direction TEXT NOT NULL,
        body_enc TEXT NOT NULL,
        mode TEXT NOT NULL,
        transport_state TEXT NOT NULL,
        metadata_enc TEXT,
        created_at TEXT NOT NULL
    )
    """,
    """
    CREATE TABLE IF NOT EXISTS packet_fragments (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        phone TEXT NOT NULL,
        packet_id TEXT NOT NULL,
        packet_kind TEXT NOT NULL,
        packet_mode TEXT,
        part_no INTEGER NOT NULL,
        total_parts INTEGER NOT NULL,
        chunk TEXT NOT NULL,
        created_at TEXT NOT NULL,
        UNIQUE(phone, packet_id, part_no)
    )
    """,
    """
    CREATE TABLE IF NOT EXISTS secure_events (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        phone TEXT,
        event_type TEXT NOT NULL,
        details_enc TEXT,
        created_at TEXT NOT NULL
    )
    """,
)

_UPSERT_CONFIG_SQL = (
    "INSERT INTO app_config(key, value) VALUES(?, ?) "
    "ON CONFLICT(key) DO UPDATE SET value = excluded.value"
)


def utc_now() -> str:
    """Return the current UTC time as ISO-8601 text, second precision, ``Z`` suffix."""
    # datetime.utcnow() is deprecated; take an aware UTC time and drop the
    # tzinfo so isoformat() emits no "+00:00" before the literal "Z".
    now = datetime.now(timezone.utc).replace(microsecond=0, tzinfo=None)
    return now.isoformat() + "Z"


class Database:
    """SQLite persistence for config, identity, contacts, messages, packet
    fragments and secure events.

    Values in ``*_enc`` columns are stored exactly as passed in; this class
    only decrypts/encrypts them in :meth:`rotate_encrypted_payloads`.
    Every public method opens its own short-lived connection.
    """

    def __init__(self, db_path: str = DB_FILE):
        """Remember the database location and create/migrate the schema."""
        self.db_path = Path(db_path)
        self._initialize()

    def _connect(self) -> sqlite3.Connection:
        """Open a new connection whose rows are ``sqlite3.Row`` objects.

        The caller is responsible for closing the returned connection.
        """
        conn = sqlite3.connect(self.db_path)
        conn.row_factory = sqlite3.Row
        return conn

    @contextmanager
    def _session(self) -> Iterator[sqlite3.Connection]:
        """Yield a connection; commit on success, roll back on error, always close.

        ``with connection:`` on its own only ends the transaction and leaves
        the connection (and its file handle) open, so it is wrapped here.
        """
        conn = self._connect()
        try:
            with conn:
                yield conn
        finally:
            conn.close()

    @staticmethod
    def _ensure_columns(conn: sqlite3.Connection, table: str, columns: dict[str, str]):
        """Add each missing column (name -> SQL type) to an existing table."""
        present = {info["name"] for info in conn.execute(f"PRAGMA table_info({table})")}
        for name, declaration in columns.items():
            if name not in present:
                conn.execute(f"ALTER TABLE {table} ADD COLUMN {name} {declaration}")

    def _initialize(self):
        """Create all tables if absent, then add columns missing from older databases."""
        with self._session() as conn:
            for statement in _SCHEMA_STATEMENTS:
                conn.execute(statement)

            # --- Migrations ---
            # Ensure columns exist in older databases.
            self._ensure_columns(
                conn,
                "contacts",
                {"symmetric_key_enc": "TEXT", "last_secure_at": "TEXT"},
            )
            self._ensure_columns(conn, "messages", {"metadata_enc": "TEXT"})

    def is_bootstrapped(self) -> bool:
        """Return True once both password salt and verifier are stored."""
        return self.get_security_metadata() is not None

    def get_security_metadata(self) -> Optional[SecurityMetadata]:
        """Return the stored salt/verifier pair, or None if either is missing."""
        with self._session() as conn:
            salt_row = conn.execute(
                "SELECT value FROM app_config WHERE key = 'password_salt'"
            ).fetchone()
            verifier_row = conn.execute(
                "SELECT value FROM app_config WHERE key = 'password_verifier'"
            ).fetchone()
        if salt_row is None or verifier_row is None:
            return None
        return SecurityMetadata(salt=salt_row["value"], verifier=verifier_row["value"])

    def set_security_metadata(self, meta: SecurityMetadata):
        """Store ``meta.salt`` and ``meta.verifier`` under their config keys.

        Both values are written in one transaction so a failure cannot leave
        a new salt paired with an old verifier.
        """
        with self._session() as conn:
            conn.execute(_UPSERT_CONFIG_SQL, ("password_salt", meta.salt))
            conn.execute(_UPSERT_CONFIG_SQL, ("password_verifier", meta.verifier))

    def set_config(self, key: str, value: str):
        """Insert or overwrite a single ``app_config`` entry."""
        with self._session() as conn:
            conn.execute(_UPSERT_CONFIG_SQL, (key, value))

    def get_config(self, key: str, default: Optional[str] = None) -> Optional[str]:
        """Return the stored value for ``key``, or ``default`` when absent."""
        with self._session() as conn:
            row = conn.execute(
                "SELECT value FROM app_config WHERE key = ?", (key,)
            ).fetchone()
        return row["value"] if row else default

    def get_connection_settings(self) -> tuple[str, int]:
        """Return ``(port, baudrate)`` for the GSM connection.

        The port falls back to a platform default (``COM1`` on Windows,
        ``/dev/serial0`` elsewhere) when unset or stored as ``COM1``.
        The baud rate falls back to 9600 when unset, empty, or not an integer.
        """
        fallback_port = "/dev/serial0" if os.name != "nt" else "COM1"
        port = self.get_config("gsm_port")
        if not port or port == "COM1":
            port = fallback_port

        # Strictly default to 9600, including when the stored text is garbage.
        raw_baud = self.get_config("gsm_baudrate", "9600")
        try:
            baudrate = int(raw_baud) if raw_baud else _DEFAULT_BAUDRATE
        except ValueError:
            baudrate = _DEFAULT_BAUDRATE
        return port, baudrate

    def set_connection_settings(self, port: str, baudrate: int):
        """Persist the GSM port and baud rate (the latter as text)."""
        self.set_config("gsm_port", port)
        self.set_config("gsm_baudrate", str(baudrate))

    def save_identity(self, private_key_enc: str, public_key_enc: str, fingerprint: str):
        """Insert or replace the single identity row (id = 1).

        ``created_at`` is set on first insert only; updates leave it untouched.
        """
        with self._session() as conn:
            conn.execute(
                """
                INSERT INTO identity(id, private_key_enc, public_key_enc, fingerprint, created_at)
                VALUES(1, ?, ?, ?, ?)
                ON CONFLICT(id) DO UPDATE SET
                    private_key_enc = excluded.private_key_enc,
                    public_key_enc = excluded.public_key_enc,
                    fingerprint = excluded.fingerprint
                """,
                (private_key_enc, public_key_enc, fingerprint, utc_now()),
            )

    def get_identity_row(self) -> Optional[sqlite3.Row]:
        """Return the identity row, or None if no identity has been saved."""
        with self._session() as conn:
            return conn.execute("SELECT * FROM identity WHERE id = 1").fetchone()

    def upsert_contact(self, phone: str, name_enc: str):
        """Create a contact with default mode/state, or update name and ``updated_at``."""
        now = utc_now()
        with self._session() as conn:
            conn.execute(
                """
                INSERT INTO contacts(phone, name_enc, mode, secure_state, created_at, updated_at)
                VALUES(?, ?, 'normal', 'none', ?, ?)
                ON CONFLICT(phone) DO UPDATE SET
                    name_enc = excluded.name_enc,
                    updated_at = excluded.updated_at
                """,
                (phone, name_enc, now, now),
            )

    def ensure_contact_exists(self, phone: str, name_enc: str):
        """Create the contact if absent; leave an existing contact unchanged."""
        now = utc_now()
        with self._session() as conn:
            conn.execute(
                """
                INSERT INTO contacts(phone, name_enc, mode, secure_state, created_at, updated_at)
                VALUES(?, ?, 'normal', 'none', ?, ?)
                ON CONFLICT(phone) DO NOTHING
                """,
                (phone, name_enc, now, now),
            )

    def get_contact_row(self, phone: str) -> Optional[sqlite3.Row]:
        """Return the contact row for ``phone``, or None."""
        with self._session() as conn:
            return conn.execute(
                "SELECT * FROM contacts WHERE phone = ?", (phone,)
            ).fetchone()

    def delete_contact(self, phone: str):
        """Delete the contact and its messages, fragments and events in one transaction."""
        with self._session() as conn:
            for table in ("contacts", "messages", "packet_fragments", "secure_events"):
                conn.execute(f"DELETE FROM {table} WHERE phone = ?", (phone,))

    def rename_contact_phone(self, old_phone: str, new_phone: str):
        """Change a contact's phone number (contacts table only)."""
        with self._session() as conn:
            conn.execute(
                "UPDATE contacts SET phone = ? WHERE phone = ?", (new_phone, old_phone)
            )

    def migrate_messages_phone(self, old_phone: str, new_phone: str):
        """Re-assign all messages from ``old_phone`` to ``new_phone``."""
        with self._session() as conn:
            conn.execute(
                "UPDATE messages SET phone = ? WHERE phone = ?", (new_phone, old_phone)
            )

    def migrate_fragments_phone(self, old_phone: str, new_phone: str):
        """Re-assign all packet fragments from ``old_phone`` to ``new_phone``."""
        with self._session() as conn:
            conn.execute(
                "UPDATE packet_fragments SET phone = ? WHERE phone = ?",
                (new_phone, old_phone),
            )

    def migrate_events_phone(self, old_phone: str, new_phone: str):
        """Re-assign all secure events from ``old_phone`` to ``new_phone``."""
        with self._session() as conn:
            conn.execute(
                "UPDATE secure_events SET phone = ? WHERE phone = ?",
                (new_phone, old_phone),
            )

    def list_contact_rows(self) -> list[sqlite3.Row]:
        """Return all contacts plus ``last_body_enc`` of each one's newest message.

        Contacts with messages come first (newest message first); the rest
        follow ordered by ``updated_at`` descending.
        """
        with self._session() as conn:
            return conn.execute(
                """
                SELECT c.*, m.body_enc AS last_body_enc
                FROM contacts c
                LEFT JOIN messages m ON m.id = (
                    SELECT id FROM messages
                    WHERE phone = c.phone
                    ORDER BY id DESC
                    LIMIT 1
                )
                ORDER BY COALESCE(m.id, 0) DESC, c.updated_at DESC
                """
            ).fetchall()

    def update_contact_security(
        self,
        phone: str,
        *,
        mode: Optional[str] = None,
        secure_state: Optional[str] = None,
        peer_public_key_enc: Optional[str] = None,
        peer_fingerprint: Optional[str] = None,
        symmetric_key_enc: Optional[str] = None,
        last_secure_at: Optional[str] = None,
    ):
        """Update the given security fields of a contact.

        Arguments left as None are not touched (so a column cannot be reset
        to NULL through this method). ``updated_at`` is always refreshed.
        """
        candidates = {
            "mode": mode,
            "secure_state": secure_state,
            "peer_public_key_enc": peer_public_key_enc,
            "peer_fingerprint": peer_fingerprint,
            "symmetric_key_enc": symmetric_key_enc,
            "last_secure_at": last_secure_at,
        }
        changes = {col: val for col, val in candidates.items() if val is not None}
        changes["updated_at"] = utc_now()

        # Column names come from the fixed dict above, never from callers.
        set_clause = ", ".join(f"{col} = ?" for col in changes)
        with self._session() as conn:
            conn.execute(
                f"UPDATE contacts SET {set_clause} WHERE phone = ?",
                [*changes.values(), phone],
            )

    def add_message(
        self,
        phone: str,
        direction: str,
        body_enc: str,
        mode: str,
        transport_state: str,
        metadata_enc: Optional[str] = None,
    ) -> int:
        """Insert a message stamped with the current time and return its row id."""
        with self._session() as conn:
            cursor = conn.execute(
                """
                INSERT INTO messages(phone, direction, body_enc, mode, transport_state, metadata_enc, created_at)
                VALUES(?, ?, ?, ?, ?, ?, ?)
                """,
                (phone, direction, body_enc, mode, transport_state, metadata_enc, utc_now()),
            )
            message_id = int(cursor.lastrowid)
        return message_id

    def update_message_transport_state(self, message_id: int, transport_state: str):
        """Set the transport state of one message."""
        with self._session() as conn:
            conn.execute(
                "UPDATE messages SET transport_state = ? WHERE id = ?",
                (transport_state, message_id),
            )

    def update_message_body(self, message_id: int, body_enc: str, transport_state: str):
        """Replace the body and transport state of one message."""
        with self._session() as conn:
            conn.execute(
                "UPDATE messages SET body_enc = ?, transport_state = ? WHERE id = ?",
                (body_enc, transport_state, message_id),
            )

    def get_message_row(self, message_id: int) -> Optional[sqlite3.Row]:
        """Return the message with the given id, or None."""
        with self._session() as conn:
            return conn.execute(
                "SELECT * FROM messages WHERE id = ?", (message_id,)
            ).fetchone()

    def list_message_rows(self, phone: str) -> list[sqlite3.Row]:
        """Return all messages for ``phone`` in insertion (id) order."""
        with self._session() as conn:
            return conn.execute(
                "SELECT * FROM messages WHERE phone = ? ORDER BY id ASC", (phone,)
            ).fetchall()

    def log_secure_event(self, phone: Optional[str], event_type: str, details_enc: Optional[str]):
        """Append a secure event stamped with the current time."""
        with self._session() as conn:
            conn.execute(
                "INSERT INTO secure_events(phone, event_type, details_enc, created_at) VALUES(?, ?, ?, ?)",
                (phone, event_type, details_enc, utc_now()),
            )

    def list_secure_event_rows(self, limit: int = 50) -> list[sqlite3.Row]:
        """Return up to ``limit`` secure events, newest first."""
        with self._session() as conn:
            return conn.execute(
                "SELECT * FROM secure_events ORDER BY id DESC LIMIT ?", (limit,)
            ).fetchall()

    def save_fragment(
        self,
        phone: str,
        packet_id: str,
        packet_kind: str,
        packet_mode: Optional[str],
        part_no: int,
        total_parts: int,
        chunk: str,
    ):
        """Store one packet fragment; a duplicate (phone, packet_id, part_no) is ignored."""
        with self._session() as conn:
            conn.execute(
                """
                INSERT INTO packet_fragments(phone, packet_id, packet_kind, packet_mode, part_no, total_parts, chunk, created_at)
                VALUES(?, ?, ?, ?, ?, ?, ?, ?)
                ON CONFLICT(phone, packet_id, part_no) DO NOTHING
                """,
                (phone, packet_id, packet_kind, packet_mode, part_no, total_parts, chunk, utc_now()),
            )

    def get_packet_fragments(self, phone: str, packet_id: str) -> list[sqlite3.Row]:
        """Return the stored fragments of one packet ordered by part number."""
        with self._session() as conn:
            return conn.execute(
                """
                SELECT * FROM packet_fragments
                WHERE phone = ? AND packet_id = ?
                ORDER BY part_no ASC
                """,
                (phone, packet_id),
            ).fetchall()

    def delete_packet_fragments(self, phone: str, packet_id: str):
        """Remove every stored fragment of one packet."""
        with self._session() as conn:
            conn.execute(
                "DELETE FROM packet_fragments WHERE phone = ? AND packet_id = ?",
                (phone, packet_id),
            )

    def list_pending_packets(self) -> list[sqlite3.Row]:
        """Summarise packets that still have stored fragments.

        Each row carries ``received_parts``, ``total_parts`` and ``first_seen``;
        rows are ordered by ``first_seen``, most recent first.
        """
        with self._session() as conn:
            return conn.execute(
                """
                SELECT
                    phone,
                    packet_id,
                    packet_kind,
                    packet_mode,
                    COUNT(*) AS received_parts,
                    MAX(total_parts) AS total_parts,
                    MIN(created_at) AS first_seen
                FROM packet_fragments
                GROUP BY phone, packet_id, packet_kind, packet_mode
                ORDER BY MIN(created_at) DESC
                """
            ).fetchall()

    def collect_stats(self) -> dict:
        """Return row counts keyed by statistic name."""
        queries = {
            "contacts": "SELECT COUNT(*) FROM contacts",
            "secure_contacts": "SELECT COUNT(*) FROM contacts WHERE mode = 'secure'",
            "pending_contacts": "SELECT COUNT(*) FROM contacts WHERE secure_state = 'pending'",
            "messages": "SELECT COUNT(*) FROM messages",
            "secure_messages": "SELECT COUNT(*) FROM messages WHERE mode = 'secure'",
            # A packet is identified by (phone, packet_id) everywhere else in
            # this class, so two senders reusing an id count as two packets.
            "incomplete_packets": (
                "SELECT COUNT(*) FROM "
                "(SELECT 1 FROM packet_fragments GROUP BY phone, packet_id)"
            ),
            "secure_connections": (
                "SELECT COUNT(*) FROM secure_events "
                "WHERE event_type = 'secure_established'"
            ),
        }
        with self._session() as conn:
            return {name: conn.execute(sql).fetchone()[0] for name, sql in queries.items()}

    def rotate_encrypted_payloads(self, old_cipher: StorageCipher, new_cipher: StorageCipher):
        """Re-encrypt every stored ``*_enc`` payload from ``old_cipher`` to ``new_cipher``.

        NULL values are skipped. All updates run on one connection and are
        committed together when the session exits; if decryption or
        encryption raises, the uncommitted updates are rolled back and the
        exception propagates.
        """
        # table -> (primary-key column, columns holding encrypted text)
        table_map = {
            "contacts": ("phone", ["name_enc", "peer_public_key_enc", "symmetric_key_enc"]),
            "messages": ("id", ["body_enc", "metadata_enc"]),
            "secure_events": ("id", ["details_enc"]),
            "identity": ("id", ["private_key_enc", "public_key_enc"]),
        }
        with self._session() as conn:
            for table_name, (pk_column, encrypted_columns) in table_map.items():
                # Materialise the rows first so updating does not disturb the scan.
                rows = conn.execute(f"SELECT * FROM {table_name}").fetchall()
                for row in rows:
                    rotated = {
                        column: new_cipher.encrypt_text(old_cipher.decrypt_text(row[column]))
                        for column in encrypted_columns
                        if row[column] is not None
                    }
                    if not rotated:
                        continue
                    set_clause = ", ".join(f"{column} = ?" for column in rotated)
                    conn.execute(
                        f"UPDATE {table_name} SET {set_clause} WHERE {pk_column} = ?",
                        [*rotated.values(), row[pk_column]],
                    )