Saba-python/secure_sms/core/security.py

134 lines
4.8 KiB
Python

import base64
import hashlib
import hmac
import os
from dataclasses import dataclass
from typing import Optional

from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import x25519
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives.kdf.scrypt import Scrypt
def b64u_encode(data: bytes) -> str:
    """Encode *data* as URL-safe base64 text with trailing ``=`` padding removed."""
    encoded = base64.urlsafe_b64encode(data)
    # Padding is implied by the length, so it is stripped for compactness;
    # b64u_decode restores it.
    return encoded.rstrip(b"=").decode("ascii")


def b64u_decode(value: str) -> bytes:
    """Decode URL-safe base64 text, accepting input with or without ``=`` padding."""
    remainder = len(value) % 4
    if remainder:
        # Re-add the padding that b64u_encode strips; the stdlib decoder needs it.
        value = value + "=" * (4 - remainder)
    return base64.urlsafe_b64decode(value.encode("ascii"))


@dataclass
class SecurityMetadata:
    """Salt and verifier that ``PasswordManager`` needs to check a password later."""

    # Random scrypt salt, as unpadded URL-safe base64 text.
    salt: str
    # Hex SHA-256 digest of the scrypt-derived key; recomputed by verify_password.
    verifier: str


class PasswordManager:
    """Derive keys from passwords with scrypt and verify passwords against metadata.

    The password itself is never kept. ``create_metadata`` records a random salt
    and a SHA-256 digest of the scrypt output. ``verify_password`` re-derives the
    key from a candidate password and compares digests.
    """

    def create_metadata(self, password: str) -> SecurityMetadata:
        """Create a fresh random salt and matching verifier for *password*.

        Returns:
            SecurityMetadata whose ``salt`` is 16 random bytes as unpadded
            URL-safe base64 and whose ``verifier`` is the hex SHA-256 digest of
            the key derived from *password* and that salt.
        """
        salt_b64 = b64u_encode(os.urandom(16))
        key = self.derive_key(password, salt_b64)
        return SecurityMetadata(
            salt=salt_b64,
            verifier=hashlib.sha256(key).hexdigest(),
        )

    def derive_key(self, password: str, salt_b64: str) -> bytes:
        """Derive a 32-byte key from *password* and a base64url-encoded salt.

        Raises:
            ValueError (or a subclass) if *salt_b64* is not valid base64 text.
        """
        # A new Scrypt object is built per call. Changing any of these cost
        # parameters changes every derived key, which would invalidate all
        # previously stored verifiers.
        kdf = Scrypt(
            salt=b64u_decode(salt_b64),
            length=32,
            n=2**14,
            r=8,
            p=1,
        )
        return kdf.derive(password.encode("utf-8"))

    def verify_password(self, password: str, meta: SecurityMetadata) -> bool:
        """Return True when *password* reproduces the verifier stored in *meta*."""
        key = self.derive_key(password, meta.salt)
        candidate = hashlib.sha256(key).hexdigest()
        # Constant-time comparison, so response timing does not reveal how many
        # leading characters of the digest matched. Comparing bytes (rather than
        # str) avoids TypeError if a stored verifier contains non-ASCII text.
        return hmac.compare_digest(
            candidate.encode("ascii"),
            meta.verifier.encode("utf-8"),
        )


class StorageCipher:
    """Encrypt and decrypt individual text values with AES-GCM under one key.

    An encrypted value is the literal tag ``enc1:`` followed by unpadded
    URL-safe base64 of ``nonce (12 bytes) || AES-GCM output``.
    """

    def __init__(self, key: bytes):
        self._aes = AESGCM(key)

    def encrypt_text(self, value: Optional[str]) -> Optional[str]:
        """Encrypt *value* under a fresh random nonce; None is passed through."""
        if value is None:
            return None
        nonce = os.urandom(12)
        sealed = self._aes.encrypt(nonce, value.encode("utf-8"), None)
        token = b64u_encode(nonce + sealed)
        return "enc1:" + token

    def decrypt_text(self, value: Optional[str]) -> Optional[str]:
        """Reverse ``encrypt_text``.

        None, the empty string, and any value that does not start with ``enc1:``
        are returned unchanged. Presumably this lets untagged plaintext written
        before encryption was introduced still be read; confirm with callers.
        """
        if value is None or value == "" or not value.startswith("enc1:"):
            return value
        blob = b64u_decode(value.removeprefix("enc1:"))
        nonce, sealed = blob[:12], blob[12:]
        return self._aes.decrypt(nonce, sealed, None).decode("utf-8")


class ECCCryptoService:
    """Public-key encryption of text messages to X25519 identities.

    Each message gets a one-time X25519 key pair. Its Diffie-Hellman output with
    the recipient's public key is expanded by HKDF-SHA256 into a 32-byte
    AES-GCM key. The payload is
    ``ephemeral_public (32) || nonce (12) || ciphertext+tag (>= 16)``,
    encoded as unpadded URL-safe base64.
    """

    INFO = b"sms-secure-channel-v2"

    # Fixed field widths of the payload layout described above.
    _KEY_BYTES = 32
    _NONCE_BYTES = 12
    _TAG_BYTES = 16

    def generate_identity(self) -> tuple[str, str, str]:
        """Create a new X25519 identity.

        Returns:
            ``(private_key_b64, public_key_b64, fingerprint)``: both keys are raw
            32-byte encodings as unpadded URL-safe base64, and the fingerprint
            comes from ``fingerprint_public_key``.
        """
        private_key = x25519.X25519PrivateKey.generate()
        public_key = private_key.public_key()
        private_raw = private_key.private_bytes(
            encoding=serialization.Encoding.Raw,
            format=serialization.PrivateFormat.Raw,
            encryption_algorithm=serialization.NoEncryption(),
        )
        public_raw = public_key.public_bytes(
            encoding=serialization.Encoding.Raw,
            format=serialization.PublicFormat.Raw,
        )
        public_b64 = b64u_encode(public_raw)
        return b64u_encode(private_raw), public_b64, self.fingerprint_public_key(public_b64)

    def fingerprint_public_key(self, public_key_b64: str) -> str:
        """Return the first 16 hex chars (uppercased) of SHA-256 over the raw key."""
        digest = hashlib.sha256(b64u_decode(public_key_b64)).hexdigest()
        return digest[:16].upper()

    def _derive_message_key(self, shared_secret: bytes) -> bytes:
        """Expand an X25519 shared secret into a 32-byte AES-GCM key.

        Encryption and decryption both go through this helper so their HKDF
        parameters cannot drift apart, which would make every message fail to
        decrypt.
        """
        # NOTE(review): the info string does not bind the ephemeral or recipient
        # public keys. Changing that would break decryption of existing payloads,
        # so it is left as-is.
        return HKDF(
            algorithm=hashes.SHA256(),
            length=32,
            salt=None,
            info=self.INFO,
        ).derive(shared_secret)

    def encrypt_for_peer(self, message: str, peer_public_key_b64: str) -> str:
        """Encrypt *message* so only the holder of the matching private key can read it.

        Raises:
            ValueError if *peer_public_key_b64* does not decode to a valid
            32-byte X25519 public key.
        """
        peer_public = x25519.X25519PublicKey.from_public_bytes(b64u_decode(peer_public_key_b64))
        ephemeral_private = x25519.X25519PrivateKey.generate()
        ephemeral_public = ephemeral_private.public_key().public_bytes(
            encoding=serialization.Encoding.Raw,
            format=serialization.PublicFormat.Raw,
        )
        key = self._derive_message_key(ephemeral_private.exchange(peer_public))
        nonce = os.urandom(self._NONCE_BYTES)
        ciphertext = AESGCM(key).encrypt(nonce, message.encode("utf-8"), None)
        return b64u_encode(ephemeral_public + nonce + ciphertext)

    def decrypt_from_peer(self, payload_b64: str, private_key_b64: str) -> str:
        """Decrypt a payload produced by ``encrypt_for_peer`` with our private key.

        Raises:
            ValueError: the payload is shorter than its fixed header plus the
                GCM tag, or a key is malformed.
            cryptography.exceptions.InvalidTag: authentication failed (wrong key
                or tampered payload), as raised by ``AESGCM.decrypt``.
        """
        payload = b64u_decode(payload_b64)
        header_len = self._KEY_BYTES + self._NONCE_BYTES
        if len(payload) < header_len + self._TAG_BYTES:
            raise ValueError("Secure payload is too short.")
        ephemeral_public_raw = payload[:self._KEY_BYTES]
        nonce = payload[self._KEY_BYTES:header_len]
        ciphertext = payload[header_len:]
        private_key = x25519.X25519PrivateKey.from_private_bytes(b64u_decode(private_key_b64))
        ephemeral_public = x25519.X25519PublicKey.from_public_bytes(ephemeral_public_raw)
        key = self._derive_message_key(private_key.exchange(ephemeral_public))
        plaintext = AESGCM(key).decrypt(nonce, ciphertext, None)
        return plaintext.decode("utf-8")