import base64
import os
import tempfile
from typing import Optional

from cryptography.exceptions import InvalidTag
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import x25519
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from cryptography.hazmat.primitives.kdf.hkdf import HKDF

# Constants
PRIVATE_KEY_FILE = "private_key.pem"
PUBLIC_KEY_FILE = "public_key.pem"

# Wire-format constants. Changing any of these breaks compatibility with
# payloads produced by earlier versions of this module.
_MARKER = "SEC:"
_HKDF_INFO = b"sms-secure-encryption"
_PUB_KEY_LEN = 32  # raw X25519 public key
_NONCE_LEN = 12  # AES-GCM nonce
_TAG_LEN = 16  # AES-GCM authentication tag appended to the ciphertext


def _derive_key(shared_key: bytes) -> bytes:
    """Derive the 32-byte AES-GCM key from an X25519 shared secret.

    Shared by encryption and decryption so both sides can never drift apart.

    NOTE(review): the KDF uses no salt and a fixed ``info``; it does not bind
    the ephemeral or recipient public keys. Kept unchanged so existing
    payloads remain decryptable; a future format version could add them.
    """
    return HKDF(
        algorithm=hashes.SHA256(),
        length=32,
        salt=None,
        info=_HKDF_INFO,
    ).derive(shared_key)


def _public_pem(public_key: x25519.X25519PublicKey) -> bytes:
    """Serialize *public_key* as SubjectPublicKeyInfo PEM bytes."""
    return public_key.public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )


def _write_owner_only(path: str, data: bytes) -> None:
    """Write *data* to *path* so that only the owning user can read it (0o600).

    The private key is stored unencrypted, so it must not be created with the
    process umask (which commonly yields a world-readable 0o644 file).
    """
    flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC | getattr(os, "O_BINARY", 0)
    fd = os.open(path, flags, 0o600)
    with os.fdopen(fd, "wb") as f:
        if hasattr(os, "fchmod"):
            # os.open's mode only applies to newly created files; also
            # tighten a pre-existing file that is being overwritten.
            os.fchmod(f.fileno(), 0o600)
        f.write(data)


class CryptoEngine:
    """Hybrid X25519 + HKDF-SHA256 + AES-GCM message encryption.

    The engine owns a long-term X25519 keypair persisted as PEM files. Messages
    are encrypted to a peer's public key with a fresh ephemeral keypair, and
    serialized as ``"SEC:" + base64(payload)`` where payload is::

        [ephemeral public key (32)] [nonce (12)] [ciphertext + tag (>= 16)]
    """

    def __init__(
        self,
        private_key_file: Optional[str] = None,
        public_key_file: Optional[str] = None,
    ):
        """Load this engine's keypair from disk, generating one if needed.

        Args:
            private_key_file: Path of the PEM (PKCS#8, unencrypted) private
                key. Defaults to ``PRIVATE_KEY_FILE``.
            public_key_file: Path of the PEM public key. Defaults to
                ``PUBLIC_KEY_FILE``.

        Raises:
            ValueError: the private key file exists but is unreadable as PEM
                or does not contain an X25519 key.
        """
        self.private_key_file = (
            PRIVATE_KEY_FILE if private_key_file is None else private_key_file
        )
        self.public_key_file = (
            PUBLIC_KEY_FILE if public_key_file is None else public_key_file
        )
        self.private_key = None
        self.public_key = None
        self.load_or_generate_keys()

    def load_or_generate_keys(self):
        """Load the private key from disk, or generate a new keypair if absent.

        The public key is always derived from the private key, never trusted
        from its own file. A missing or stale public key file is rewritten
        rather than triggering regeneration, which would otherwise destroy the
        existing private key and make all previously received messages
        undecryptable.

        Raises:
            ValueError: the private key file is not a valid X25519 PEM key.
        """
        if not os.path.exists(self.private_key_file):
            self.generate_keypair()
            return

        with open(self.private_key_file, "rb") as f:
            key = serialization.load_pem_private_key(f.read(), password=None)
        if not isinstance(key, x25519.X25519PrivateKey):
            raise ValueError(
                f"{self.private_key_file} does not contain an X25519 private key."
            )
        self.private_key = key
        self.public_key = key.public_key()

        try:
            with open(self.public_key_file, "rb") as f:
                on_disk = f.read()
        except FileNotFoundError:
            on_disk = None
        if on_disk != _public_pem(self.public_key):
            self._save_public_key()

    def generate_keypair(self):
        """Generate a new X25519 keypair and save it to disk.

        Overwrites any existing key files. The private key file is written
        with owner-only permissions.
        """
        self.private_key = x25519.X25519PrivateKey.generate()
        self.public_key = self.private_key.public_key()

        _write_owner_only(
            self.private_key_file,
            self.private_key.private_bytes(
                encoding=serialization.Encoding.PEM,
                format=serialization.PrivateFormat.PKCS8,
                encryption_algorithm=serialization.NoEncryption(),
            ),
        )
        self._save_public_key()

    def _save_public_key(self):
        """Write the current public key to ``self.public_key_file`` as PEM."""
        with open(self.public_key_file, "wb") as f:
            f.write(_public_pem(self.public_key))

    def get_my_public_key_pem(self):
        """Return this engine's public key as a PEM string."""
        return _public_pem(self.public_key).decode("utf-8")

    def encrypt_message(self, message: str, peer_public_key_pem: str) -> str:
        """Encrypt *message* for the holder of *peer_public_key_pem*.

        A fresh ephemeral X25519 keypair is generated per message; its public
        half is sent in the payload so the recipient can repeat the exchange.

        Args:
            message: Plaintext to encrypt (encoded as UTF-8).
            peer_public_key_pem: Recipient's X25519 public key in PEM form.

        Returns:
            ``"SEC:"`` followed by base64 of ephemeral key, nonce and
            ciphertext-with-tag.

        Raises:
            ValueError: the peer key is empty, malformed, or not X25519.
        """
        if not peer_public_key_pem:
            raise ValueError("Peer public key is empty.")
        peer_public_key = serialization.load_pem_public_key(
            peer_public_key_pem.encode("utf-8")
        )
        if not isinstance(peer_public_key, x25519.X25519PublicKey):
            # Otherwise exchange() would fail later with an opaque TypeError.
            raise ValueError("Peer public key is not an X25519 key.")

        ephemeral_private_key = x25519.X25519PrivateKey.generate()
        ephemeral_pub_bytes = ephemeral_private_key.public_key().public_bytes(
            encoding=serialization.Encoding.Raw,
            format=serialization.PublicFormat.Raw,
        )

        shared_key = ephemeral_private_key.exchange(peer_public_key)
        aesgcm = AESGCM(_derive_key(shared_key))
        # Random nonce is safe: every message uses a freshly derived key.
        nonce = os.urandom(_NONCE_LEN)
        ciphertext = aesgcm.encrypt(nonce, message.encode("utf-8"), None)

        payload = ephemeral_pub_bytes + nonce + ciphertext
        return _MARKER + base64.b64encode(payload).decode("ascii")

    def decrypt_message(self, secure_payload: str) -> str:
        """Decrypt a payload produced by :meth:`encrypt_message`.

        Args:
            secure_payload: String starting with ``"SEC:"``.

        Returns:
            The decrypted UTF-8 plaintext.

        Raises:
            ValueError: wrong prefix, invalid base64, payload too short, or
                authentication/decoding failure (wrong key or tampering).
        """
        if not secure_payload.startswith(_MARKER):
            raise ValueError("Not a secure message format.")
        b64_payload = secure_payload[len(_MARKER):]

        try:
            # Non-validating decode (as before): stray non-alphabet characters
            # are discarded. binascii.Error is a subclass of ValueError.
            payload = base64.b64decode(b64_payload)
        except ValueError as e:
            raise ValueError("Invalid Base64 payload.") from e

        header_len = _PUB_KEY_LEN + _NONCE_LEN
        if len(payload) < header_len + _TAG_LEN:
            raise ValueError("Payload too short.")

        ephemeral_pub_bytes = payload[:_PUB_KEY_LEN]
        nonce = payload[_PUB_KEY_LEN:header_len]
        ciphertext = payload[header_len:]
        ephemeral_public_key = x25519.X25519PublicKey.from_public_bytes(
            ephemeral_pub_bytes
        )

        try:
            # exchange() raises ValueError for degenerate (low-order) points;
            # decrypt() raises InvalidTag on wrong key or tampered data.
            shared_key = self.private_key.exchange(ephemeral_public_key)
            plaintext = AESGCM(_derive_key(shared_key)).decrypt(
                nonce, ciphertext, None
            )
            return plaintext.decode("utf-8")
        except (InvalidTag, ValueError) as e:
            raise ValueError(
                "Decryption failed. Invalid key or message tampered."
            ) from e


if __name__ == "__main__":
    # Self-test. Each party gets its own temporary key directory: with the
    # default paths both engines would load the SAME keypair (so the round
    # trip would prove nothing) and real key files in the working directory
    # would be created or overwritten.
    with tempfile.TemporaryDirectory() as alice_dir, \
            tempfile.TemporaryDirectory() as bob_dir:
        alice = CryptoEngine(
            os.path.join(alice_dir, PRIVATE_KEY_FILE),
            os.path.join(alice_dir, PUBLIC_KEY_FILE),
        )
        bob_paths = (
            os.path.join(bob_dir, PRIVATE_KEY_FILE),
            os.path.join(bob_dir, PUBLIC_KEY_FILE),
        )
        bob = CryptoEngine(*bob_paths)
        bob_pub = bob.get_my_public_key_pem()
        assert alice.get_my_public_key_pem() != bob_pub

        msg = "This is a highly secret message!"

        # Alice sends to Bob
        encrypted = alice.encrypt_message(msg, bob_pub)
        print("Encrypted:", encrypted)

        # Bob decrypts
        decrypted = bob.decrypt_message(encrypted)
        print("Decrypted:", decrypted)
        assert msg == decrypted

        # Only the intended recipient can decrypt.
        try:
            alice.decrypt_message(encrypted)
        except ValueError:
            pass
        else:
            raise AssertionError("Non-recipient decrypted the message.")

        # Reloading from disk yields the same keypair.
        assert CryptoEngine(*bob_paths).get_my_public_key_pem() == bob_pub

    print("Crypto Engine OK.")