217 lines
7.5 KiB
Python
217 lines
7.5 KiB
Python
import base64
|
||
import hashlib
|
||
import os
|
||
import re
|
||
import unicodedata
|
||
from dataclasses import dataclass
|
||
from typing import Optional
|
||
|
||
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
||
from cryptography.hazmat.primitives.kdf.scrypt import Scrypt
|
||
|
||
|
||
_LEGACY_INVISIBLE_CHARS = dict.fromkeys(map(ord, "\u200c\u200d\u200e\u200f\ufeff"), None)
|
||
_ARABIC_VARIANT_TRANSLATION = str.maketrans(
|
||
{
|
||
"ك": "ک",
|
||
"ي": "ی",
|
||
"ى": "ی",
|
||
"ة": "ه",
|
||
"ۀ": "ه",
|
||
"٠": "0",
|
||
"١": "1",
|
||
"٢": "2",
|
||
"٣": "3",
|
||
"٤": "4",
|
||
"٥": "5",
|
||
"٦": "6",
|
||
"٧": "7",
|
||
"٨": "8",
|
||
"٩": "9",
|
||
"۰": "0",
|
||
"۱": "1",
|
||
"۲": "2",
|
||
"۳": "3",
|
||
"۴": "4",
|
||
"۵": "5",
|
||
"۶": "6",
|
||
"۷": "7",
|
||
"۸": "8",
|
||
"۹": "9",
|
||
}
|
||
)
|
||
|
||
|
||
def b64u_encode(data: bytes) -> str:
    """Encode *data* as URL-safe Base64 (``-``/``_`` alphabet) without padding.

    Trailing ``=`` characters are removed so the text matches the Flutter
    client's encoding.
    """
    encoded = base64.urlsafe_b64encode(data)
    return encoded.rstrip(b"=").decode("ascii")
|
||
|
||
|
||
def _decode_transport_payload(value: str) -> bytes:
|
||
"""Decode either legacy Base64URL payloads or the new hex transport format."""
|
||
if value.startswith("h1:"):
|
||
clean_hex = re.sub(r"[^0-9A-Fa-f]", "", value[3:])
|
||
if len(clean_hex) % 2 != 0:
|
||
clean_hex = clean_hex[:-1]
|
||
return bytes.fromhex(clean_hex) if clean_hex else b""
|
||
return b64u_decode(value)
|
||
|
||
|
||
def b64u_decode(value: str) -> bytes:
    """Leniently decode Base64 text in either the standard or URL-safe alphabet.

    Input may have passed through lossy transports, so it is cleaned before
    decoding:

    1. all whitespace is removed;
    2. URL-safe characters (``-``/``_``) are mapped to standard (``+``/``/``);
    3. everything from the first ``=`` onward is discarded;
    4. any remaining character outside ``[A-Za-z0-9+/]`` is dropped;
    5. the correct amount of ``=`` padding is re-added.

    Returns:
        The decoded bytes, or ``b""`` if the cleaned text is still not valid
        Base64 (e.g. its length is 1 modulo 4). The failure is printed, not
        raised.
    """
    # 1-2. Strip whitespace and normalise the URL-safe alphabet to standard.
    clean = "".join(value.split()).replace("-", "+").replace("_", "/")
    # 3. Drop existing padding (and anything after it) so padding is never doubled.
    clean = clean.split("=", 1)[0]
    # 4. Keep only characters of the standard Base64 alphabet.
    clean = re.sub(r"[^A-Za-z0-9+/]", "", clean)
    # 5. Re-pad to a multiple of four characters.
    padded = clean + "=" * (-len(clean) % 4)

    try:
        return base64.b64decode(padded.encode("ascii"))
    except ValueError as e:  # binascii.Error is a subclass of ValueError
        print(f"[B64] ERROR decoding: {e}")
        return b""
|
||
|
||
|
||
@dataclass
class SecurityMetadata:
    """Record used to check a password without storing the password itself.

    Produced by ``PasswordManager.create_metadata`` and consumed by
    ``PasswordManager.verify_password``.
    """

    # Per-password random salt as unpadded URL-safe Base64 text.
    salt: str
    # Hex SHA-256 digest of the key derived from the password and ``salt``.
    verifier: str
|
||
|
||
|
||
class PasswordManager:
    """Create and check scrypt-based password verifiers (``SecurityMetadata``)."""

    def create_metadata(self, password: str) -> SecurityMetadata:
        """Build a fresh ``SecurityMetadata`` record for *password*.

        A new 16-byte random salt is generated on every call. The verifier is
        the hex SHA-256 digest of the scrypt key derived from the password
        and that salt.
        """
        salt_b64 = b64u_encode(os.urandom(16))
        key = self.derive_key(password, salt_b64)
        return SecurityMetadata(
            salt=salt_b64,
            verifier=hashlib.sha256(key).hexdigest(),
        )

    def derive_key(self, password: str, salt_b64: str) -> bytes:
        """Derive a 32-byte key from *password* with scrypt.

        Args:
            password: User password; encoded as UTF-8 with no normalisation.
            salt_b64: Salt as Base64 text, as stored in ``SecurityMetadata.salt``.

        Returns:
            The 32-byte derived key.

        NOTE: the cost parameters (n=2**14, r=8, p=1) feed into every stored
        verifier. Changing them makes existing metadata fail verification.
        """
        kdf = Scrypt(
            salt=b64u_decode(salt_b64),
            length=32,
            n=2**14,
            r=8,
            p=1,
        )
        return kdf.derive(password.encode("utf-8"))

    def verify_password(self, password: str, meta: SecurityMetadata) -> bool:
        """Return True if *password* reproduces ``meta.verifier``.

        The digests are compared with ``hmac.compare_digest``, so the
        comparison time does not depend on how many leading characters match.
        Both sides are compared as UTF-8 bytes, which gives the same result as
        string equality but never raises on a corrupted non-ASCII verifier.
        """
        import hmac  # local import keeps this block self-contained

        key = self.derive_key(password, meta.salt)
        candidate = hashlib.sha256(key).hexdigest()
        return hmac.compare_digest(candidate.encode("utf-8"), meta.verifier.encode("utf-8"))
|
||
|
||
|
||
class StorageCipher:
    """AES-GCM helper for encrypting individual text values.

    Encrypted values have the form ``"enc1:" + b64u(nonce || ciphertext)``,
    where the nonce is 12 random bytes. ``decrypt_text`` returns any value
    that lacks the ``enc1:`` prefix unchanged.
    """

    def __init__(self, key: bytes):
        self._aes = AESGCM(key)

    def encrypt_text(self, value: Optional[str]) -> Optional[str]:
        """Encrypt *value* under a fresh random nonce; ``None`` maps to ``None``."""
        if value is not None:
            iv = os.urandom(12)
            sealed = self._aes.encrypt(iv, value.encode("utf-8"), None)
            return "enc1:" + b64u_encode(iv + sealed)
        return None

    def decrypt_text(self, value: Optional[str]) -> Optional[str]:
        """Reverse ``encrypt_text``.

        ``None``, the empty string and strings without the ``enc1:`` prefix
        are returned as-is. Any error raised by AESGCM decryption propagates.
        """
        if value is None or value == "" or not value.startswith("enc1:"):
            return value
        blob = b64u_decode(value[len("enc1:"):])
        iv, sealed = blob[:12], blob[12:]
        return self._aes.decrypt(iv, sealed, None).decode("utf-8")
|
||
|
||
|
||
class SymmetricCryptoService:
    """
    Symmetric encryption service using AES-GCM-256.

    Matches the Flutter implementation while keeping compatibility with
    older Python builds that normalized keys before hashing.

    The AES key is SHA-256 of the stripped password text. NOTE(review): that
    is an unsalted, fast hash rather than a password KDF. It is kept only
    because the Flutter peer derives its key the same way.
    """

    def _clean_key_variants(self, password: str) -> list[tuple[str, str]]:
        """Return the distinct ``(label, key_text)`` candidates to try when decrypting.

        The Flutter-compatible stripped text always comes first, followed by
        legacy Python normalisations (NFC and/or removal of invisible
        characters plus Arabic-variant folding). A candidate whose text equals
        an earlier one is skipped, so no key is tried twice.
        """
        raw = password.strip()
        legacy_nfc = unicodedata.normalize("NFC", raw)
        visual_safe = raw.translate(_LEGACY_INVISIBLE_CHARS).translate(_ARABIC_VARIANT_TRANSLATION)
        legacy_nfc_visual_safe = legacy_nfc.translate(_LEGACY_INVISIBLE_CHARS).translate(
            _ARABIC_VARIANT_TRANSLATION
        )

        variants: list[tuple[str, str]] = []
        seen: set[str] = set()
        for label, value in (
            ("flutter_raw", raw),
            ("legacy_python_nfc", legacy_nfc),
            ("visual_safe", visual_safe),
            ("legacy_python_nfc_visual_safe", legacy_nfc_visual_safe),
        ):
            if value not in seen:
                variants.append((label, value))
                seen.add(value)
        return variants

    def _derive_symmetric_key_from_text(self, key_text: str, *, debug: bool = False) -> bytes:
        """Return SHA-256 of *key_text* (UTF-8) as a 32-byte AES key.

        When *debug* is true, a diagnostic line is printed. It deliberately
        contains no key bytes. The key is an unsalted hash of the password, so
        even a short key prefix in the logs would let a reader filter password
        guesses offline.
        """
        key = hashlib.sha256(key_text.encode("utf-8")).digest()
        if debug:
            print(f"[Crypto] Derived {len(key) * 8}-bit symmetric key")
        return key

    def _derive_symmetric_key(self, password: str) -> bytes:
        """Derive the primary 32-byte key exactly like Flutter: SHA-256(trimmed text)."""
        return self._derive_symmetric_key_from_text(password.strip(), debug=True)

    def encrypt_symmetric(self, message: str, password: str) -> str:
        """Encrypt message using AES-GCM with a password-derived key.

        Returns:
            ``"h1:" + hex(nonce || ciphertext_with_tag)`` using a random
            12-byte nonce. ``decrypt_symmetric`` also still accepts the
            legacy Base64 form.
        """
        key = self._derive_symmetric_key(password)
        nonce = os.urandom(12)
        ciphertext = AESGCM(key).encrypt(nonce, message.encode("utf-8"), None)
        # SMS-safe transport: explicit hex payload, still backward-compatible on decode.
        return "h1:" + (nonce + ciphertext).hex()

    def decrypt_symmetric(self, payload_b64: str, password: str) -> str:
        """Decrypt message using AES-GCM with a password-derived key.

        Args:
            payload_b64: ``"h1:"``-prefixed hex or legacy Base64 text holding
                ``nonce (12 bytes) || ciphertext || tag (16 bytes)``. The
                parameter name predates the hex format and is kept for
                keyword callers.
            password: Shared secret. Each candidate from
                ``_clean_key_variants`` is tried in order.

        Returns:
            The decrypted UTF-8 text.

        Raises:
            ValueError: if the decoded payload is shorter than nonce + tag.
            cryptography.exceptions.InvalidTag: if no key variant
                authenticates the payload.
            UnicodeDecodeError: if authentication succeeds but the plaintext
                is not valid UTF-8.
        """
        from cryptography.exceptions import InvalidTag

        payload = _decode_transport_payload(payload_b64)

        # 12-byte nonce + 16-byte GCM tag is the smallest well-formed payload.
        if len(payload) < 28:
            print(f"[Symmetric] ERROR: Payload too short ({len(payload)})")
            raise ValueError("Symmetric payload is too short.")

        nonce = payload[:12]
        ciphertext_with_tag = payload[12:]

        last_error: Optional[Exception] = None
        tried_labels: list[str] = []
        for label, key_text in self._clean_key_variants(password):
            tried_labels.append(label)
            key = self._derive_symmetric_key_from_text(key_text, debug=(label == "flutter_raw"))
            try:
                plaintext = AESGCM(key).decrypt(nonce, ciphertext_with_tag, None)
            except InvalidTag as exc:
                # Wrong key for this variant; try the next normalisation.
                last_error = exc
                continue
            if label != "flutter_raw":
                print(f"[Symmetric] Compatibility decrypt succeeded using: {label}")
            # Authentication succeeded, so this key is correct. A UTF-8 error
            # here is a real data problem and must not be masked by the
            # InvalidTag that later variants would raise.
            return plaintext.decode("utf-8")

        print(f"[Symmetric] Decryption FAILED after trying: {', '.join(tried_labels)}")
        print("[Symmetric] Check: Key and payload must match bit-perfectly.")
        if last_error is not None:
            raise last_error
        raise ValueError("Symmetric decryption failed.")
|