package com.kecalek.chat.crypto

import org.bouncycastle.crypto.params.X25519PrivateKeyParameters
import org.bouncycastle.crypto.params.X25519PublicKeyParameters
import org.json.JSONObject
import java.security.SecureRandom

/**
 * Double Ratchet algorithm for end-to-end encrypted messaging.
 * Provides forward secrecy and break-in recovery.
 *
 * Compatible with Python DoubleRatchet class from crypto_utils.py.
 *
 * State:
 * - dh_pair: current ratchet X25519 keypair
 * - dh_remote: remote's current ratchet public key
 * - root_key: 32-byte root key
 * - send_chain_key / recv_chain_key: current chain keys
 * - send_n / recv_n: message counters
 * - prev_send_n: previous sending chain length
 * - skipped: map of "dh_hex:n" -> message_key for out-of-order delivery
 *
 * Instances are created only through the companion factories ([initAlice], [initBob],
 * [importState]). The class performs no internal synchronization.
 * NOTE(review): callers presumably serialize access per session — confirm.
 */
class DoubleRatchet private constructor() {
    private lateinit var dhPrivate: X25519PrivateKeyParameters
    private lateinit var dhPublic: X25519PublicKeyParameters
    private var dhRemote: X25519PublicKeyParameters? = null
    private lateinit var rootKey: ByteArray
    private var sendChainKey: ByteArray? = null
    private var recvChainKey: ByteArray? = null
    private var sendN: Int = 0
    private var recvN: Int = 0
    private var prevSendN: Int = 0

    // skipped["<remotePublicHex>:<messageNumber>"] = messageKey (see makeSkippedKey).
    // NOTE(review): only the per-call MAX_SKIP bound is enforced; the total size of this
    // map is not capped across DH ratchet steps — consider evicting old entries.
    private val skipped = mutableMapOf<String, ByteArray>()

    companion object {
        /** Maximum number of message keys a single [skipMessages] call may derive and store. */
        private const val MAX_SKIP = 256

        /**
         * Initialize as Alice (initiator).
         * Called after X3DH produces a shared secret.
         *
         * Generates a fresh ratchet keypair and performs the initial DH ratchet step against
         * Bob's signed pre-key, so Alice can send immediately. Her receiving chain stays
         * uninitialized until a message with a new remote ratchet key arrives.
         *
         * @param sharedSecret X3DH shared secret
         * @param bobSpkPub Bob's signed pre-key public (used as initial remote ratchet key)
         */
        fun initAlice(sharedSecret: ByteArray, bobSpkPub: X25519PublicKeyParameters): DoubleRatchet {
            val ratchet = DoubleRatchet()

            // Generate initial ratchet keypair
            val (dhPriv, dhPub) = X25519Crypto.generateKeypair()
            ratchet.dhPrivate = dhPriv
            ratchet.dhPublic = dhPub
            ratchet.dhRemote = bobSpkPub

            // Initial DH ratchet step: derive the first sending chain from the X3DH secret.
            val dhOutput = X25519Crypto.dh(dhPriv, bobSpkPub)
            val (newRootKey, sendChainKey) = HkdfUtils.kdfRk(sharedSecret, dhOutput)
            ratchet.rootKey = newRootKey
            ratchet.sendChainKey = sendChainKey
            ratchet.recvChainKey = null
            ratchet.sendN = 0
            ratchet.recvN = 0
            ratchet.prevSendN = 0
            return ratchet
        }

        /**
         * Initialize as Bob (responder).
         * Uses SPK pair as initial ratchet key.
         *
         * Both chains start uninitialized: Bob can only [encrypt] after [decrypt] of Alice's
         * first message has performed a DH ratchet step.
         *
         * @param sharedSecret X3DH shared secret (copied; the caller may wipe its buffer)
         * @param spkPair Bob's signed pre-key pair (private, public)
         */
        fun initBob(
            sharedSecret: ByteArray,
            spkPair: Pair<X25519PrivateKeyParameters, X25519PublicKeyParameters>,
        ): DoubleRatchet {
            val ratchet = DoubleRatchet()
            ratchet.dhPrivate = spkPair.first
            ratchet.dhPublic = spkPair.second
            // Defensive copy so later mutation/zeroing of the caller's array cannot corrupt state.
            ratchet.rootKey = sharedSecret.copyOf()
            ratchet.sendChainKey = null
            ratchet.recvChainKey = null
            ratchet.sendN = 0
            ratchet.recvN = 0
            ratchet.prevSendN = 0
            return ratchet
        }

        /**
         * Import ratchet state from UTF-8 JSON bytes, as produced by [exportState].
         *
         * The optional hex fields `dh_remote`, `send_ck` and `recv_ck` may be absent, JSON
         * `null`, or an empty string; all three are read as "not set". `skipped` may be
         * absent or `null`.
         *
         * @throws org.json.JSONException if a required field is missing or has the wrong type
         */
        fun importState(data: ByteArray): DoubleRatchet {
            val json = JSONObject(String(data, Charsets.UTF_8))
            val ratchet = DoubleRatchet()

            ratchet.dhPrivate = X25519Crypto.loadPrivate(json.getString("dh_priv").hexToBytes())
            ratchet.dhPublic = X25519Crypto.loadPublic(json.getString("dh_pub").hexToBytes())
            ratchet.dhRemote = json.optHex("dh_remote")?.let { X25519Crypto.loadPublic(it) }
            ratchet.rootKey = json.getString("root_key").hexToBytes()
            ratchet.sendChainKey = json.optHex("send_ck")
            ratchet.recvChainKey = json.optHex("recv_ck")
            ratchet.sendN = json.getInt("send_n")
            ratchet.recvN = json.getInt("recv_n")
            ratchet.prevSendN = json.getInt("prev_send_n")

            // Import skipped keys
            if (!json.isNull("skipped")) {
                val skippedJson = json.getJSONObject("skipped")
                for (key in skippedJson.keys()) {
                    ratchet.skipped[key] = skippedJson.getString(key).hexToBytes()
                }
            }
            return ratchet
        }

        /**
         * Reads an optional hex-encoded field; an absent key, JSON `null` or "" yields `null`.
         *
         * [JSONObject.isNull] is checked first because Android's org.json `optString` returns
         * the literal text "null" for a JSON null value, which would then fail hex decoding.
         */
        private fun JSONObject.optHex(key: String): ByteArray? =
            if (isNull(key)) null else getString(key).takeIf { it.isNotEmpty() }?.hexToBytes()
    }

    /**
     * Encrypt plaintext message with the next message key of the sending chain.
     *
     * The sending chain key and counter are committed only after encryption succeeds, so a
     * failure cannot leave the chain key and the header's message number out of step.
     *
     * @return RatchetMessage with header, ciphertext+tag, nonce
     * @throws CryptoException.DecryptionFailed if the sending chain is not initialized
     *   (e.g. Bob before decrypting Alice's first message). The type is kept for
     *   compatibility with existing callers despite its name.
     */
    fun encrypt(plaintext: ByteArray): RatchetMessage {
        val ck = sendChainKey ?: throw CryptoException.DecryptionFailed("Send chain not initialized")
        val (nextChainKey, messageKey) = HkdfUtils.kdfCk(ck)

        val header = RatchetHeader(
            dhPub = X25519Crypto.serializePublic(dhPublic),
            n = sendN,
            pn = prevSendN,
        )
        // The serialized header is bound to the ciphertext as AEAD associated data.
        val (nonce, ctWithTag) = AesGcmCrypto.encryptCombined(
            plaintext = plaintext,
            key = messageKey,
            aad = header.serialize(),
        )

        // Commit ratchet state only now that encryption has succeeded.
        sendChainKey = nextChainKey
        sendN++

        return RatchetMessage(
            header = header,
            ciphertext = ctWithTag,
            nonce = nonce,
        )
    }

    /**
     * Decrypt received message.
     * Handles out-of-order delivery via skipped message keys.
     *
     * Full state rollback on failure: if anything throws (including authentication failure),
     * the ratchet state, including the skipped-key map, is left exactly as before the call.
     *
     * @throws CryptoException on any failure; non-crypto exceptions are wrapped in
     *   [CryptoException.DecryptionFailed]. Skipping more than [MAX_SKIP] messages in one
     *   chain raises [CryptoException.MaxSkipExceeded].
     */
    fun decrypt(header: RatchetHeader, ciphertext: ByteArray, nonce: ByteArray): ByteArray {
        val aad = header.serialize()
        val remoteHex = header.dhPub.toHex()

        // Check skipped message keys first. The key is consumed only after the message has
        // authenticated, so a forged or corrupted message cannot destroy a legitimate key.
        val skippedKey = makeSkippedKey(remoteHex, header.n)
        val storedMessageKey = skipped[skippedKey]
        if (storedMessageKey != null) {
            val plaintext = try {
                AesGcmCrypto.decryptCombined(
                    key = storedMessageKey,
                    nonce = nonce,
                    ctWithTag = ciphertext,
                    aad = aad,
                )
            } catch (e: Exception) {
                throw asDecryptFailure(e)
            }
            skipped.remove(skippedKey)
            return plaintext
        }

        // Take snapshot for rollback
        val snapshot = snapshot()
        try {
            // New DH ratchet step if remote key changed
            val currentRemoteHex = dhRemote?.let { X25519Crypto.serializePublic(it).toHex() }
            if (remoteHex != currentRemoteHex) {
                // Store keys for messages still outstanding on the previous receiving chain.
                skipMessages(header.pn)
                dhRatchet(X25519Crypto.loadPublic(header.dhPub))
            }

            skipMessages(header.n)

            val ck = recvChainKey ?: throw CryptoException.DecryptionFailed("Receive chain not initialized")
            val (newChainKey, messageKey) = HkdfUtils.kdfCk(ck)
            recvChainKey = newChainKey
            recvN++

            return AesGcmCrypto.decryptCombined(
                key = messageKey,
                nonce = nonce,
                ctWithTag = ciphertext,
                aad = aad,
            )
        } catch (e: Exception) {
            // Rollback on any failure
            restore(snapshot)
            throw asDecryptFailure(e)
        }
    }

    /**
     * Export full ratchet state as UTF-8 JSON bytes (readable by [importState]).
     *
     * Null chain keys / remote key are passed to `JSONObject.put` as null, which removes
     * the key from the object, so they are omitted from the output.
     * NOTE(review): confirm the Python importer tolerates omitted keys.
     */
    fun exportState(): ByteArray {
        val json = JSONObject()
        json.put("dh_priv", X25519Crypto.serializePrivate(dhPrivate).toHex())
        json.put("dh_pub", X25519Crypto.serializePublic(dhPublic).toHex())
        json.put("dh_remote", dhRemote?.let { X25519Crypto.serializePublic(it).toHex() })
        json.put("root_key", rootKey.toHex())
        json.put("send_ck", sendChainKey?.toHex())
        json.put("recv_ck", recvChainKey?.toHex())
        json.put("send_n", sendN)
        json.put("recv_n", recvN)
        json.put("prev_send_n", prevSendN)

        val skippedJson = JSONObject()
        for ((key, messageKey) in skipped) {
            skippedJson.put(key, messageKey.toHex())
        }
        json.put("skipped", skippedJson)

        return json.toString().toByteArray(Charsets.UTF_8)
    }

    // --- Private helpers ---

    /** Passes [CryptoException]s through unchanged; wraps anything else as DecryptionFailed. */
    private fun asDecryptFailure(e: Exception): Throwable =
        if (e is CryptoException) e else CryptoException.DecryptionFailed("Decryption failed", e)

    /**
     * Advances the receiving chain up to (but excluding) message number [until], storing
     * each derived message key in [skipped] under the current remote ratchet key.
     * Does nothing while the receiving chain is uninitialized.
     *
     * @throws CryptoException.MaxSkipExceeded if more than [MAX_SKIP] keys would be derived
     */
    private fun skipMessages(until: Int) {
        var ck = recvChainKey ?: return
        // Equivalent to `until - recvN > MAX_SKIP`, written additively so an extreme
        // (attacker-supplied) `until` cannot overflow the subtraction.
        if (recvN + MAX_SKIP < until) {
            throw CryptoException.MaxSkipExceeded("Cannot skip more than $MAX_SKIP messages")
        }

        // dhRemote does not change inside the loop, so serialize it once.
        val remoteHex = dhRemote?.let { X25519Crypto.serializePublic(it).toHex() } ?: ""
        while (recvN < until) {
            val (nextCk, messageKey) = HkdfUtils.kdfCk(ck)
            ck = nextCk
            skipped[makeSkippedKey(remoteHex, recvN)] = messageKey
            recvN++
        }
        recvChainKey = ck
    }

    /**
     * DH ratchet step on receipt of a new remote ratchet public key: derives a new
     * receiving chain, then rotates our own keypair and derives a new sending chain.
     */
    private fun dhRatchet(remotePublic: X25519PublicKeyParameters) {
        prevSendN = sendN
        sendN = 0
        recvN = 0
        dhRemote = remotePublic

        // Derive receive chain
        val dhOutput1 = X25519Crypto.dh(dhPrivate, remotePublic)
        val (rk1, recvCk) = HkdfUtils.kdfRk(rootKey, dhOutput1)
        rootKey = rk1
        recvChainKey = recvCk

        // Generate new DH keypair and derive send chain
        val (newPriv, newPub) = X25519Crypto.generateKeypair()
        dhPrivate = newPriv
        dhPublic = newPub
        val dhOutput2 = X25519Crypto.dh(newPriv, remotePublic)
        val (rk2, sendCk) = HkdfUtils.kdfRk(rootKey, dhOutput2)
        rootKey = rk2
        sendChainKey = sendCk
    }

    /**
     * Copy of all mutable ratchet state, used to roll back a failed [decrypt].
     * A plain class (not `data`): structural equality over ByteArray fields would be misleading.
     */
    private class Snapshot(
        val dhPriv: ByteArray,
        val dhPub: ByteArray,
        val dhRemote: ByteArray?,
        val rootKey: ByteArray,
        val sendCk: ByteArray?,
        val recvCk: ByteArray?,
        val sendN: Int,
        val recvN: Int,
        val prevSendN: Int,
        val skipped: Map<String, ByteArray>,
    )

    private fun snapshot(): Snapshot {
        return Snapshot(
            dhPriv = X25519Crypto.serializePrivate(dhPrivate),
            dhPub = X25519Crypto.serializePublic(dhPublic),
            dhRemote = dhRemote?.let { X25519Crypto.serializePublic(it) },
            rootKey = rootKey.copyOf(),
            sendCk = sendChainKey?.copyOf(),
            recvCk = recvChainKey?.copyOf(),
            sendN = sendN,
            recvN = recvN,
            prevSendN = prevSendN,
            skipped = skipped.toMap(),
        )
    }

    private fun restore(s: Snapshot) {
        dhPrivate = X25519Crypto.loadPrivate(s.dhPriv)
        dhPublic = X25519Crypto.loadPublic(s.dhPub)
        dhRemote = s.dhRemote?.let { X25519Crypto.loadPublic(it) }
        rootKey = s.rootKey
        sendChainKey = s.sendCk
        recvChainKey = s.recvCk
        sendN = s.sendN
        recvN = s.recvN
        prevSendN = s.prevSendN
        skipped.clear()
        skipped.putAll(s.skipped)
    }

    /** Key format for [skipped]: "<remote ratchet public key hex>:<message number>". */
    private fun makeSkippedKey(dhHex: String, n: Int): String = "$dhHex:$n"
}

/**
 * Ratchet message header.
 * Serialized as JSON: {"dh_pub": hex, "n": int, "pn": int}
 *
 * [serialize] output is used as AEAD associated data, so both peers must produce
 * byte-identical bytes for the same header.
 */
data class RatchetHeader(
    val dhPub: ByteArray,
    val n: Int,
    val pn: Int,
) {
    /**
     * Serializes the header to UTF-8 JSON.
     *
     * NOTE(review): this relies on `JSONObject.toString()` emitting keys in insertion order
     * without whitespace (presumably true for Android's org.json). Confirm it matches the
     * Python serializer byte-for-byte, since any difference breaks AEAD authentication.
     */
    fun serialize(): ByteArray {
        val json = JSONObject()
        json.put("dh_pub", dhPub.toHex())
        json.put("n", n)
        json.put("pn", pn)
        return json.toString().toByteArray(Charsets.UTF_8)
    }

    /** Header as a plain map with `dh_pub` hex-encoded. */
    fun toMap(): Map<String, Any> = mapOf(
        "dh_pub" to dhPub.toHex(),
        "n" to n,
        "pn" to pn,
    )

    // Custom equals/hashCode: compare dhPub by content rather than by array identity.
    override fun equals(other: Any?): Boolean {
        if (this === other) return true
        if (other !is RatchetHeader) return false
        return dhPub.contentEquals(other.dhPub) && n == other.n && pn == other.pn
    }

    override fun hashCode(): Int {
        var result = dhPub.contentHashCode()
        result = 31 * result + n
        result = 31 * result + pn
        return result
    }

    companion object {
        /**
         * Inverse of [toMap]. Expects `dh_pub` as a hex String, and `n`/`pn` as any [Number].
         * A missing or mistyped entry fails the cast and throws.
         */
        fun fromMap(map: Map<String, Any?>): RatchetHeader {
            return RatchetHeader(
                dhPub = (map["dh_pub"] as String).hexToBytes(),
                n = (map["n"] as Number).toInt(),
                pn = (map["pn"] as Number).toInt(),
            )
        }
    }
}

/** Encrypted message: header, AES-GCM ciphertext with appended tag, and the nonce used. */
data class RatchetMessage(
    val header: RatchetHeader,
    val ciphertext: ByteArray,
    val nonce: ByteArray,
) {
    // Custom equals/hashCode: compare the byte arrays by content rather than identity.
    override fun equals(other: Any?): Boolean {
        if (this === other) return true
        if (other !is RatchetMessage) return false
        return header == other.header &&
            ciphertext.contentEquals(other.ciphertext) &&
            nonce.contentEquals(other.nonce)
    }

    override fun hashCode(): Int {
        var result = header.hashCode()
        result = 31 * result + ciphertext.contentHashCode()
        result = 31 * result + nonce.contentHashCode()
        return result
    }
}

// --- Hex extension functions ---

private const val HEX_DIGITS = "0123456789abcdef"

/** Lowercase hex encoding, two characters per byte. */
internal fun ByteArray.toHex(): String = buildString(size * 2) {
    for (byte in this@toHex) {
        val v = byte.toInt() and 0xff
        append(HEX_DIGITS[v ushr 4])
        append(HEX_DIGITS[v and 0x0f])
    }
}

/**
 * Decodes a hex string (either case) to bytes.
 *
 * @throws IllegalArgumentException if the length is odd
 * @throws NumberFormatException if a character is not a hex digit. Sign characters such as
 *   "+f" or "-1" are rejected, unlike a per-pair `toInt(16)` parse. The message deliberately
 *   omits the input, which may be key material.
 */
internal fun String.hexToBytes(): ByteArray {
    require(length % 2 == 0) { "Hex string must have even length" }
    return ByteArray(length / 2) { i ->
        val hi = Character.digit(this[2 * i], 16)
        val lo = Character.digit(this[2 * i + 1], 16)
        if (hi < 0 || lo < 0) {
            throw NumberFormatException("Invalid hex digit at offset ${2 * i}")
        }
        ((hi shl 4) or lo).toByte()
    }
}