import Foundation
import CryptoKit

/// Sender key chain for group messaging.
/// Each sender in a group has their own chain. Others receive the initial key via pairwise ratchet.
/// Matches Python: SenderKeyState class in crypto_utils.py
class SenderKeyState {
    /// Root secret the chain is seeded from; this is what `exportKey()` distributes.
    let senderKey: Data
    /// SHA-256 of `senderKey`; identifies the chain and is bound into every message's AAD.
    let chainId: Data
    /// Current chain key; replaced by `CryptoUtils.kdfCK` each time a message key is derived.
    private(set) var chainKey: Data
    /// Index of the next message key the chain will derive.
    private(set) var n: Int
    /// Message keys derived during fast-forward but not yet consumed, keyed by message index.
    private var knownKeys: [Int: Data]

    /// Initialize with optional sender key (generates random 32B if nil).
    /// Matches Python: SenderKeyState.__init__(sender_key=None)
    init(senderKey: Data? = nil) {
        let key = senderKey ?? Data.randomBytes(32)
        self.senderKey = key
        self.chainId = Data(SHA256.hash(data: key))
        self.chainKey = CryptoUtils.hkdfDerive(
            inputKey: key,
            salt: Data(repeating: 0x00, count: 32),
            info: Data(Constants.senderKeyChainInfo.utf8),
            length: 32
        )
        self.n = 0
        self.knownKeys = [:]
    }

    /// Private init for import
    private init(senderKey: Data, chainId: Data, chainKey: Data, n: Int, knownKeys: [Int: Data]) {
        self.senderKey = senderKey
        self.chainId = chainId
        self.chainKey = chainKey
        self.n = n
        self.knownKeys = knownKeys
    }

    // MARK: - Encrypt

    /// Encrypt with current chain key.
    /// Returns (chainId hex, n, ciphertext with tag, nonce).
    ///
    /// The chain state (`chainKey`, `n`) is advanced only after encryption succeeds, so a
    /// thrown error leaves the chain unchanged and still in step with receivers.
    /// - Throws: `CryptoError.senderKeyError` if `n` no longer fits the 32-bit AAD counter,
    ///   or any error thrown by `CryptoUtils.aesGcmEncrypt`.
    /// Matches Python: SenderKeyState.encrypt(plaintext)
    func encrypt(_ plaintext: Data) throws -> (chainIdHex: String, n: Int, ciphertext: Data, nonce: Data) {
        // The AAD encodes n as a UInt32; refuse to go past that instead of trapping.
        guard let counter = UInt32(exactly: n) else {
            throw CryptoError.senderKeyError("Sender key chain exhausted (n=\(n))")
        }
        let (nextChainKey, messageKey) = CryptoUtils.kdfCK(chainKey: chainKey)
        let nonce = Data.randomBytes(12)
        // AAD = chainId + bigEndian(UInt32(n))
        let aad = chainId + counter.bigEndianData
        let ctWithTag = try CryptoUtils.aesGcmEncrypt(plaintext, key: messageKey, nonce: nonce, aad: aad)

        // Commit state only now: advancing chainKey without n (or vice versa) would
        // label every later message with the wrong index.
        let messageIndex = n
        chainKey = nextChainKey
        n += 1
        return (chainIdHex: chainId.hexString, n: messageIndex, ciphertext: ctWithTag, nonce: nonce)
    }

    // MARK: - Decrypt

    /// Decrypt a group message. Fast-forwards the chain if needed.
    /// State is snapshotted before modification and restored on failure.
    ///
    /// Keys for indices skipped during fast-forward are kept in `knownKeys` so that
    /// out-of-order messages can still be decrypted later; each key is removed once used.
    /// - Throws: `CryptoError.senderKeyError` for a malformed/mismatched chain ID, an index
    ///   outside `0...UInt32.max`, a skip larger than `Constants.maxSenderKeySkip`, or an
    ///   unavailable message key; otherwise any error from `CryptoUtils.aesGcmDecrypt`.
    /// Matches Python: SenderKeyState.decrypt(chain_id_hex, n, ciphertext, nonce)
    func decrypt(chainIdHex: String, n: Int, ciphertext: Data, nonce: Data) throws -> Data {
        guard let expectedChainId = Data(hexString: chainIdHex) else {
            throw CryptoError.senderKeyError("Invalid chain ID hex")
        }
        guard expectedChainId == chainId else {
            throw CryptoError.senderKeyError("Chain ID mismatch")
        }
        // n comes from the sender; validate it before any arithmetic so that
        // `n - self.n` cannot overflow and the UInt32 AAD conversion cannot trap.
        guard let counter = UInt32(exactly: n) else {
            throw CryptoError.senderKeyError("Invalid message index n=\(n)")
        }
        if n - self.n > Constants.maxSenderKeySkip {
            throw CryptoError.senderKeyError("Sender key skip too large (\(n - self.n) > \(Constants.maxSenderKeySkip))")
        }

        // Snapshot before fast-forward so a failed (e.g. forged) message cannot
        // advance the chain or consume a legitimately stored skipped key.
        let snapChainKey = chainKey
        let snapN = self.n
        let snapKnown = knownKeys

        do {
            // Fast-forward the chain to reach message n
            while self.n <= n {
                let (newCK, mk) = CryptoUtils.kdfCK(chainKey: chainKey)
                chainKey = newCK
                knownKeys[self.n] = mk
                self.n += 1
            }

            guard let mk = knownKeys.removeValue(forKey: n) else {
                throw CryptoError.senderKeyError("Message key for n=\(n) not available")
            }

            let aad = chainId + counter.bigEndianData
            return try CryptoUtils.aesGcmDecrypt(ciphertext, key: mk, nonce: nonce, aad: aad)
        } catch {
            // Restore state on failure
            chainKey = snapChainKey
            self.n = snapN
            knownKeys = snapKnown
            throw error
        }
    }

    // MARK: - Key Export/Import

    /// Export sender key for distribution to group members.
    /// Matches Python: SenderKeyState.export_key()
    func exportKey() -> Data {
        let dict: [String: Any] = ["sender_key": senderKey.hexString]
        // Cannot fail: the object holds only a String value, which is always JSON-serializable.
        return try! JSONSerialization.data(withJSONObject: dict)
    }

    /// Initialize a receiving SenderKeyState from an exported key.
    /// - Throws: `CryptoError.stateImportFailed` if the payload lacks a valid hex
    ///   `sender_key`, or the JSON parse error if `exportedKey` is not JSON.
    /// Matches Python: SenderKeyState.from_key(exported_key)
    static func fromKey(_ exportedKey: Data) throws -> SenderKeyState {
        guard let dict = try JSONSerialization.jsonObject(with: exportedKey) as? [String: Any],
              let senderKeyHex = dict["sender_key"] as? String,
              let senderKey = Data(hexString: senderKeyHex)
        else {
            throw CryptoError.stateImportFailed("Invalid sender key export")
        }
        return SenderKeyState(senderKey: senderKey)
    }

    // MARK: - Full State Export/Import

    /// Serialize full state for persistent storage.
    /// Binary fields are hex-encoded; `known_keys` maps decimal index strings to hex keys.
    /// Matches Python: SenderKeyState.export_state()
    func exportState() -> Data {
        var knownKeysDict: [String: String] = [:]
        for (k, v) in knownKeys {
            knownKeysDict[String(k)] = v.hexString
        }
        let state: [String: Any] = [
            "sender_key": senderKey.hexString,
            "chain_id": chainId.hexString,
            "chain_key": chainKey.hexString,
            "n": n,
            "known_keys": knownKeysDict,
        ]
        // Cannot fail: every value is a String, Int or [String: String].
        return try! JSONSerialization.data(withJSONObject: state)
    }

    /// Deserialize full state.
    /// Malformed `known_keys` entries are skipped rather than failing the import.
    /// - Throws: `CryptoError.stateImportFailed` if a required field is missing or invalid
    ///   (including a negative `n`), or the JSON parse error if `data` is not JSON.
    /// Matches Python: SenderKeyState.import_state(data)
    static func importState(_ data: Data) throws -> SenderKeyState {
        guard let state = try JSONSerialization.jsonObject(with: data) as? [String: Any],
              let senderKeyHex = state["sender_key"] as? String,
              let senderKey = Data(hexString: senderKeyHex),
              let chainIdHex = state["chain_id"] as? String,
              let chainId = Data(hexString: chainIdHex),
              let chainKeyHex = state["chain_key"] as? String,
              let chainKey = Data(hexString: chainKeyHex),
              let n = state["n"] as? Int,
              n >= 0 // a negative index is never produced by this class and would break encrypt()
        else {
            throw CryptoError.stateImportFailed("Invalid sender key state")
        }

        var knownKeys: [Int: Data] = [:]
        if let knownKeysDict = state["known_keys"] as? [String: String] {
            for (k, v) in knownKeysDict {
                if let idx = Int(k), let data = Data(hexString: v) {
                    knownKeys[idx] = data
                }
            }
        }

        return SenderKeyState(
            senderKey: senderKey,
            chainId: chainId,
            chainKey: chainKey,
            n: n,
            knownKeys: knownKeys
        )
    }
}