import Foundation
import CryptoKit

/// Ratchet header sent with each message.
struct RatchetHeader {
    /// Sender's current ratchet public key (32 bytes).
    let dhPub: Data
    /// Message number in the current sending chain.
    let n: Int
    /// Number of messages in the previous sending chain.
    let pn: Int

    /// Serialize the header to JSON bytes for use as AAD.
    ///
    /// Matches Python: `RatchetHeader.serialize()`.
    /// IMPORTANT: the output must match Python's `json.dumps()` format exactly,
    /// i.e. `{"dh_pub": "...", "n": 0, "pn": 0}` with a space after every colon
    /// and comma. Any byte difference makes the AAD (and thus decryption) fail.
    ///
    /// - Returns: UTF-8 bytes of the JSON string.
    func serialize() -> Data {
        let jsonString = "{\"dh_pub\": \"\(dhPub.hexString)\", \"n\": \(n), \"pn\": \(pn)}"
        // `String.utf8` cannot fail, unlike `data(using:)`, so no force unwrap is needed.
        return Data(jsonString.utf8)
    }

    /// Convert to a dictionary for the wire protocol.
    ///
    /// Matches Python: `RatchetHeader.to_dict()`.
    ///
    /// - Returns: Dictionary with keys `dh_pub` (hex string), `n` and `pn`.
    func toDict() -> [String: Any] {
        [
            "dh_pub": dhPub.hexString,
            "n": n,
            "pn": pn,
        ]
    }

    /// Parse a header from a dictionary.
    ///
    /// Matches Python: `RatchetHeader.from_dict(d)`.
    ///
    /// - Parameter d: Dictionary with `dh_pub` (hex string), `n` and `pn` (integers).
    /// - Returns: The parsed header.
    /// - Throws: `CryptoError.invalidHeader` when a field is missing, has the wrong
    ///   type, is not valid hex, or when a counter is negative.
    static func fromDict(_ d: [String: Any]) throws -> RatchetHeader {
        guard let dhPubHex = d["dh_pub"] as? String,
              let dhPub = Data(hexString: dhPubHex),
              let n = d["n"] as? Int,
              let pn = d["pn"] as? Int,
              // Counters come from the peer; a negative value is never legitimate.
              n >= 0,
              pn >= 0
        else {
            throw CryptoError.invalidHeader("Missing or invalid header fields")
        }
        return RatchetHeader(dhPub: dhPub, n: n, pn: pn)
    }
}

/// Signal Double Ratchet implementation.
/// Matches Python: DoubleRatchet class in crypto_utils.py
class DoubleRatchet {
    private(set) var dhPair: (privateKey: Curve25519.KeyAgreement.PrivateKey, publicKey: Curve25519.KeyAgreement.PublicKey)?
    private(set) var dhRemote: Curve25519.KeyAgreement.PublicKey?
    private(set) var rootKey: Data = Data()
    private(set) var sendChainKey: Data?
    private(set) var recvChainKey: Data?
    private(set) var sendN: Int = 0
    private(set) var recvN: Int = 0
    private(set) var prevSendN: Int = 0

    /// Skipped message keys: "dh_pub_hex:n" → message_key.
    private(set) var skipped: [String: Data] = [:]

    /// Attached X3DH header — set when creating a new session, consumed on first send.
    /// Matches Python: ratchet._x3dh_header
    var x3dhHeader: [String: Any]?

    init() {}

    // MARK: - Initialization

    /// Initialize as initiator (Alice) after X3DH.
    ///
    /// Matches Python: `DoubleRatchet.init_alice(shared_secret, bob_spk_pub)`.
    ///
    /// - Parameters:
    ///   - sharedSecret: Shared secret used as the initial root key.
    ///   - bobSpkPub: Bob's signed prekey public key, used as the first remote ratchet key.
    /// - Returns: A ratchet with a send chain and no receive chain.
    /// - Throws: Whatever `X25519Crypto.dh` throws.
    static func initAlice(sharedSecret: Data, bobSpkPub: Curve25519.KeyAgreement.PublicKey) throws -> DoubleRatchet {
        let ratchet = DoubleRatchet()
        let (priv, pub) = X25519Crypto.generateKeypair()
        ratchet.dhPair = (priv, pub)
        ratchet.dhRemote = bobSpkPub

        // Debug output is limited to public values; secrets (shared secret, DH output,
        // root/chain keys) are deliberately not logged, even in DEBUG builds.
        #if DEBUG
        print("DEBUG initAlice: my_dh_pub = \(X25519Crypto.serializePublic(pub).hexString)")
        print("DEBUG initAlice: remote_dh_pub (bob_spk) = \(X25519Crypto.serializePublic(bobSpkPub).hexString)")
        #endif

        // Perform the initial DH ratchet step to derive the send chain.
        let dhOutput = try X25519Crypto.dh(priv, bobSpkPub)
        let (newRK, sendCK) = CryptoUtils.kdfRK(rootKey: sharedSecret, dhOutput: dhOutput)

        ratchet.rootKey = newRK
        ratchet.sendChainKey = sendCK
        ratchet.recvChainKey = nil
        ratchet.sendN = 0
        ratchet.recvN = 0
        ratchet.prevSendN = 0
        return ratchet
    }

    /// Initialize as responder (Bob) after X3DH.
    ///
    /// Matches Python: `DoubleRatchet.init_bob(shared_secret, spk_pair)`.
    ///
    /// - Parameters:
    ///   - sharedSecret: Shared secret used as the initial root key.
    ///   - spkPair: Bob's signed prekey pair, used as the initial ratchet key pair.
    /// - Returns: A ratchet with neither chain set; chains are derived on first decrypt.
    static func initBob(
        sharedSecret: Data,
        spkPair: (privateKey: Curve25519.KeyAgreement.PrivateKey, publicKey: Curve25519.KeyAgreement.PublicKey)
    ) -> DoubleRatchet {
        let ratchet = DoubleRatchet()
        ratchet.dhPair = spkPair
        ratchet.rootKey = sharedSecret
        ratchet.sendChainKey = nil
        ratchet.recvChainKey = nil
        ratchet.sendN = 0
        ratchet.recvN = 0
        ratchet.prevSendN = 0
        return ratchet
    }

    // MARK: - Encrypt

    /// Encrypt a message.
    ///
    /// Matches Python: `DoubleRatchet.encrypt(plaintext)`.
    ///
    /// The send chain key and message counter are only advanced after encryption
    /// succeeds, so a thrown error leaves the ratchet state untouched.
    ///
    /// - Parameter plaintext: Bytes to encrypt.
    /// - Returns: The header dictionary, the ciphertext with tag, and the 12-byte nonce.
    /// - Throws: `CryptoError.ratchetError` when the send chain or DH pair is not set,
    ///   or whatever `CryptoUtils.aesGcmEncrypt` throws.
    func encrypt(_ plaintext: Data) throws -> (header: [String: Any], ciphertext: Data, nonce: Data) {
        guard let chainKey = sendChainKey else {
            throw CryptoError.ratchetError("Send chain not initialized")
        }
        guard let dhPair = dhPair else {
            throw CryptoError.ratchetError("DH pair not set")
        }

        let (nextChainKey, messageKey) = CryptoUtils.kdfCK(chainKey: chainKey)

        let header = RatchetHeader(
            dhPub: X25519Crypto.serializePublic(dhPair.publicKey),
            n: sendN,
            pn: prevSendN
        )
        let nonce = Data.randomBytes(12)
        let aad = header.serialize()

        // Debug output is limited to values that travel in the clear; the message
        // key is deliberately not logged.
        #if DEBUG
        print("DEBUG encrypt: aad = \(aad.hexString)")
        print("DEBUG encrypt: aad_str = \(String(data: aad, encoding: .utf8) ?? "nil")")
        print("DEBUG encrypt: nonce = \(nonce.hexString)")
        #endif

        let ctWithTag = try CryptoUtils.aesGcmEncrypt(plaintext, key: messageKey, nonce: nonce, aad: aad)

        #if DEBUG
        print("DEBUG encrypt: ciphertext_len = \(ctWithTag.count)")
        #endif

        // Commit the chain advance only now that encryption has succeeded, keeping
        // the chain key and the message number in lockstep.
        sendChainKey = nextChainKey
        sendN += 1
        return (header.toDict(), ctWithTag, nonce)
    }

    // MARK: - Decrypt

    /// Decrypt a message. Performs a DH ratchet step if the header carries a new dh_pub.
    ///
    /// State is snapshotted before modification and restored on failure (M9 fix).
    /// Matches Python: `DoubleRatchet.decrypt(header_dict, ciphertext, nonce)`.
    ///
    /// - Parameters:
    ///   - headerDict: Header dictionary as produced by `RatchetHeader.toDict()`.
    ///   - ciphertext: Ciphertext with tag.
    ///   - nonce: Nonce used for encryption.
    /// - Returns: The decrypted plaintext.
    /// - Throws: `CryptoError.invalidHeader`, `CryptoError.ratchetError`,
    ///   `CryptoError.maxSkipExceeded`, or whatever the underlying crypto helpers throw.
    func decrypt(headerDict: [String: Any], ciphertext: Data, nonce: Data) throws -> Data {
        let header = try RatchetHeader.fromDict(headerDict)
        let remoteDhPubBytes = header.dhPub

        // Out-of-order message whose key was stored earlier: the stored key is
        // removed only after a successful decrypt, so a failure keeps it available.
        let skipKey = "\(remoteDhPubBytes.hexString):\(header.n)"
        if let storedKey = skipped[skipKey] {
            let plaintext = try CryptoUtils.aesGcmDecrypt(
                ciphertext, key: storedKey, nonce: nonce, aad: header.serialize()
            )
            skipped.removeValue(forKey: skipKey)
            return plaintext
        }

        // Snapshot state before modifications so any failure can be rolled back.
        let snap = snapshot()
        do {
            let remoteDhPub = try X25519Crypto.loadPublic(remoteDhPubBytes)
            let currentRemoteBytes: Data? = dhRemote.map { X25519Crypto.serializePublic($0) }

            // A nil current remote key also compares unequal, so this covers the
            // "no remote key yet" case as well as "remote key changed".
            if currentRemoteBytes != remoteDhPubBytes {
                // New DH ratchet step: store remaining keys of the old chain first.
                try skipMessages(until: header.pn)
                try dhRatchet(remoteDhPub: remoteDhPub)
            }

            try skipMessages(until: header.n)

            // Derive the message key from the receive chain.
            guard let chainKey = recvChainKey else {
                throw CryptoError.ratchetError("Receive chain key is nil")
            }
            let (nextChainKey, messageKey) = CryptoUtils.kdfCK(chainKey: chainKey)
            recvChainKey = nextChainKey
            recvN += 1

            return try CryptoUtils.aesGcmDecrypt(
                ciphertext, key: messageKey, nonce: nonce, aad: header.serialize()
            )
        } catch {
            restore(snap)
            throw error
        }
    }

    // MARK: - State Snapshot/Restore (M9)

    /// Copy of all mutable ratchet state, taken before a decrypt attempt.
    /// Keys are stored as values directly, so restoring cannot fail.
    private struct Snapshot {
        let dhPair: (privateKey: Curve25519.KeyAgreement.PrivateKey, publicKey: Curve25519.KeyAgreement.PublicKey)?
        let dhRemote: Curve25519.KeyAgreement.PublicKey?
        let rootKey: Data
        let sendChainKey: Data?
        let recvChainKey: Data?
        let sendN: Int
        let recvN: Int
        let prevSendN: Int
        let skipped: [String: Data]
    }

    /// Capture the current ratchet state.
    private func snapshot() -> Snapshot {
        Snapshot(
            dhPair: dhPair,
            dhRemote: dhRemote,
            rootKey: rootKey,
            sendChainKey: sendChainKey,
            recvChainKey: recvChainKey,
            sendN: sendN,
            recvN: recvN,
            prevSendN: prevSendN,
            skipped: skipped
        )
    }

    /// Put the ratchet back into the state captured by `snapshot()`.
    private func restore(_ snap: Snapshot) {
        dhPair = snap.dhPair
        dhRemote = snap.dhRemote
        rootKey = snap.rootKey
        sendChainKey = snap.sendChainKey
        recvChainKey = snap.recvChainKey
        sendN = snap.sendN
        recvN = snap.recvN
        prevSendN = snap.prevSendN
        skipped = snap.skipped
    }

    // MARK: - Internal Ratchet Operations

    /// Advance the receive chain up to (not including) message number `until`,
    /// storing each derived message key in `skipped`.
    ///
    /// Does nothing when there is no receive chain or when `until` is not ahead
    /// of `recvN`.
    ///
    /// - Throws: `CryptoError.maxSkipExceeded` when more than `Constants.maxSkip`
    ///   keys would have to be derived.
    private func skipMessages(until: Int) throws {
        guard let startingKey = recvChainKey else { return }
        guard until > recvN else { return }

        // `until` originates from a peer-supplied header; plain subtraction could
        // trap on overflow, so use the overflow-reporting form.
        let (gap, overflowed) = until.subtractingReportingOverflow(recvN)
        if overflowed || gap > Constants.maxSkip {
            throw CryptoError.maxSkipExceeded
        }

        // The remote key does not change while skipping, so compute its hex once.
        let remoteHex = dhRemote.map { X25519Crypto.serializePublic($0).hexString } ?? ""
        var chainKey = startingKey
        while recvN < until {
            let (nextChainKey, messageKey) = CryptoUtils.kdfCK(chainKey: chainKey)
            skipped["\(remoteHex):\(recvN)"] = messageKey
            chainKey = nextChainKey
            recvN += 1
        }
        recvChainKey = chainKey
    }

    /// Perform a DH ratchet step for a newly seen remote ratchet key.
    ///
    /// Derives a new receive chain from the current key pair, then generates a
    /// fresh key pair and derives a new send chain. All fallible work is done
    /// before any state is written.
    ///
    /// - Throws: `CryptoError.ratchetError` when no DH pair is set, or whatever
    ///   `X25519Crypto.dh` throws.
    private func dhRatchet(remoteDhPub: Curve25519.KeyAgreement.PublicKey) throws {
        guard let currentPair = dhPair else {
            throw CryptoError.ratchetError("DH pair not set")
        }

        // Receive chain: current private key x new remote public key.
        let recvDH = try X25519Crypto.dh(currentPair.privateKey, remoteDhPub)
        let (rootAfterRecv, newRecvCK) = CryptoUtils.kdfRK(rootKey: rootKey, dhOutput: recvDH)

        // Send chain: fresh private key x new remote public key.
        let (newPriv, newPub) = X25519Crypto.generateKeypair()
        let sendDH = try X25519Crypto.dh(newPriv, remoteDhPub)
        let (rootAfterSend, newSendCK) = CryptoUtils.kdfRK(rootKey: rootAfterRecv, dhOutput: sendDH)

        // Commit.
        prevSendN = sendN
        sendN = 0
        recvN = 0
        dhRemote = remoteDhPub
        dhPair = (newPriv, newPub)
        rootKey = rootAfterSend
        recvChainKey = newRecvCK
        sendChainKey = newSendCK
    }

    // MARK: - State Export/Import

    /// Hex string for `data`, or `NSNull` when it is nil (serialized as JSON null).
    private static func jsonHex(_ data: Data?) -> Any {
        guard let data = data else { return NSNull() }
        return data.hexString
    }

    /// Read an optional hex-encoded field from an imported state dictionary.
    ///
    /// - Returns: nil when the key is absent or JSON null, otherwise the decoded bytes.
    /// - Throws: `CryptoError.stateImportFailed` when the value is present but is
    ///   not a valid hex string.
    private static func hexField(_ state: [String: Any], _ key: String) throws -> Data? {
        guard let raw = state[key], !(raw is NSNull) else { return nil }
        guard let hex = raw as? String, let bytes = Data(hexString: hex) else {
            throw CryptoError.stateImportFailed("Invalid \(key)")
        }
        return bytes
    }

    /// Serialize the full ratchet state for persistent storage.
    ///
    /// Uses the same keys as Python's `DoubleRatchet.export_state()`:
    /// `dh_priv`, `dh_pub`, `dh_remote`, `root_key`, `send_ck`, `recv_ck`,
    /// `send_n`, `recv_n`, `prev_send_n` and `skipped`. Unset keys are written
    /// as JSON null. Note that `x3dhHeader` is not part of the exported state.
    ///
    /// - Returns: JSON-encoded state.
    /// - Throws: Whatever `JSONSerialization.data(withJSONObject:)` throws.
    func exportState() throws -> Data {
        let state: [String: Any] = [
            "dh_priv": Self.jsonHex(dhPair.map { X25519Crypto.serializePrivate($0.privateKey) }),
            "dh_pub": Self.jsonHex(dhPair.map { X25519Crypto.serializePublic($0.publicKey) }),
            "dh_remote": Self.jsonHex(dhRemote.map { X25519Crypto.serializePublic($0) }),
            "root_key": rootKey.hexString,
            "send_ck": Self.jsonHex(sendChainKey),
            "recv_ck": Self.jsonHex(recvChainKey),
            "send_n": sendN,
            "recv_n": recvN,
            "prev_send_n": prevSendN,
            // Skipped keys: Python format is "dh_pub_hex:n" -> message_key_hex
            "skipped": skipped.mapValues { $0.hexString },
        ]
        return try JSONSerialization.data(withJSONObject: state)
    }

    /// Deserialize ratchet state.
    ///
    /// Matches Python: `DoubleRatchet.import_state(data)`.
    ///
    /// Key fields that are absent or JSON null are left unset. A key field that is
    /// present but malformed is reported as an error rather than silently dropped.
    /// Counters default to 0 when missing; malformed `skipped` entries are ignored.
    ///
    /// - Parameter data: JSON produced by `exportState()`.
    /// - Returns: A ratchet populated from the stored state.
    /// - Throws: `CryptoError.stateImportFailed` for invalid JSON, a missing
    ///   `root_key` or a malformed key field; also whatever `JSONSerialization`
    ///   or the `X25519Crypto` key loaders throw.
    static func importState(_ data: Data) throws -> DoubleRatchet {
        guard let state = try JSONSerialization.jsonObject(with: data) as? [String: Any] else {
            throw CryptoError.stateImportFailed("Invalid JSON")
        }

        let r = DoubleRatchet()

        let privData = try hexField(state, "dh_priv")
        let pubData = try hexField(state, "dh_pub")
        // The key pair is only set when both halves are present.
        if let privData = privData, let pubData = pubData {
            let priv = try X25519Crypto.loadPrivate(privData)
            let pub = try X25519Crypto.loadPublic(pubData)
            r.dhPair = (priv, pub)
        }

        if let remoteData = try hexField(state, "dh_remote") {
            r.dhRemote = try X25519Crypto.loadPublic(remoteData)
        }

        guard let rootKeyHex = state["root_key"] as? String,
              let rootKey = Data(hexString: rootKeyHex)
        else {
            throw CryptoError.stateImportFailed("Missing root_key")
        }
        r.rootKey = rootKey

        r.sendChainKey = try hexField(state, "send_ck")
        r.recvChainKey = try hexField(state, "recv_ck")

        r.sendN = state["send_n"] as? Int ?? 0
        r.recvN = state["recv_n"] as? Int ?? 0
        r.prevSendN = state["prev_send_n"] as? Int ?? 0

        if let skippedDict = state["skipped"] as? [String: String] {
            r.skipped = skippedDict.compactMapValues { Data(hexString: $0) }
        }

        return r
    }
}