Files
Kecalek_python/ios_client/EncryptedChat/Crypto/DoubleRatchet.swift
2026-03-11 16:54:14 +01:00

372 lines
13 KiB
Swift

import Foundation
import CryptoKit
/// Ratchet header sent with each message.
///
/// Carries the sender's current ratchet public key together with the two
/// counters the receiver uses to locate the matching message key.
struct RatchetHeader {
    /// Sender's current ratchet public key (32 bytes).
    let dhPub: Data
    /// Message number in the current sending chain.
    let n: Int
    /// Number of messages in the previous sending chain.
    let pn: Int

    /// Serialize header to JSON bytes for use as AAD.
    /// Matches Python: RatchetHeader.serialize()
    ///
    /// - Returns: The JSON encoding of `toDict()` with keys emitted in sorted order.
    func serialize() -> Data {
        // Keys are sorted so that the same header always yields the same bytes.
        // `try!` is deliberate: the object is a flat dictionary of one string and
        // two integers, which JSONSerialization can always encode.
        return try! JSONSerialization.data(withJSONObject: toDict(), options: .sortedKeys)
    }

    /// Convert to dictionary for protocol.
    /// Matches Python: RatchetHeader.to_dict()
    ///
    /// - Returns: A dictionary with the keys `dh_pub` (hex string), `n` and `pn`.
    func toDict() -> [String: Any] {
        [
            "dh_pub": dhPub.hexString,
            "n": n,
            "pn": pn,
        ]
    }

    /// Parse from dictionary.
    /// Matches Python: RatchetHeader.from_dict(d)
    ///
    /// - Parameter d: Dictionary expected to hold `dh_pub` (hex string), `n` and `pn` (integers).
    /// - Returns: The parsed header.
    /// - Throws: `CryptoError.invalidHeader` when a field is missing, has the wrong
    ///   type, is not valid hex, or when a counter is negative.
    static func fromDict(_ d: [String: Any]) throws -> RatchetHeader {
        guard let dhPubHex = d["dh_pub"] as? String,
              let dhPub = Data(hexString: dhPubHex),
              let n = d["n"] as? Int,
              let pn = d["pn"] as? Int else {
            throw CryptoError.invalidHeader("Missing or invalid header fields")
        }
        // Counters are message indices and can never be negative. Rejecting them
        // here keeps extreme values (e.g. Int.min) out of the index arithmetic
        // performed on these fields further down the receive path.
        guard n >= 0, pn >= 0 else {
            throw CryptoError.invalidHeader("Negative message counter in header")
        }
        return RatchetHeader(dhPub: dhPub, n: n, pn: pn)
    }
}
/// Signal Double Ratchet implementation.
/// Matches Python: DoubleRatchet class in crypto_utils.py
///
/// Holds the local ratchet key pair, the remote ratchet public key, the root
/// key, the sending/receiving chain keys with their counters, and the message
/// keys of messages that were skipped and may still arrive.
class DoubleRatchet {
    /// Local ratchet key pair; `nil` until one of the initializers sets it.
    private(set) var dhPair: (privateKey: Curve25519.KeyAgreement.PrivateKey,
                              publicKey: Curve25519.KeyAgreement.PublicKey)?
    /// Ratchet public key last accepted from the remote side.
    private(set) var dhRemote: Curve25519.KeyAgreement.PublicKey?
    private(set) var rootKey: Data = Data()
    private(set) var sendChainKey: Data?
    private(set) var recvChainKey: Data?
    /// Number of messages sent in the current sending chain.
    private(set) var sendN: Int = 0
    /// Number of messages consumed (received or skipped) in the current receiving chain.
    private(set) var recvN: Int = 0
    /// Length of the previous sending chain, sent as `pn` in headers.
    private(set) var prevSendN: Int = 0
    // Skipped message keys: "dh_pub_hex:n" -> message_key
    private(set) var skipped: [String: Data] = [:]

    /// Attached X3DH header set when creating a new session, consumed on first send.
    /// Matches Python: ratchet._x3dh_header
    var x3dhHeader: [String: Any]?

    init() {}

    // MARK: - Initialization

    /// Initialize as initiator (Alice) after X3DH.
    /// Matches Python: DoubleRatchet.init_alice(shared_secret, bob_spk_pub)
    ///
    /// - Parameters:
    ///   - sharedSecret: Shared secret used as the initial root key.
    ///   - bobSpkPub: Bob's signed pre-key, used as the first remote ratchet key.
    /// - Returns: A ratchet with a sending chain and no receiving chain yet.
    /// - Throws: Whatever `X25519Crypto.dh` throws.
    static func initAlice(sharedSecret: Data, bobSpkPub: Curve25519.KeyAgreement.PublicKey) throws -> DoubleRatchet {
        let ratchet = DoubleRatchet()
        let (priv, pub) = X25519Crypto.generateKeypair()
        ratchet.dhPair = (priv, pub)
        ratchet.dhRemote = bobSpkPub
        // Perform DH ratchet to derive send chain
        let dhOutput = try X25519Crypto.dh(priv, bobSpkPub)
        let (newRK, sendCK) = CryptoUtils.kdfRK(rootKey: sharedSecret, dhOutput: dhOutput)
        ratchet.rootKey = newRK
        ratchet.sendChainKey = sendCK
        ratchet.recvChainKey = nil
        ratchet.sendN = 0
        ratchet.recvN = 0
        ratchet.prevSendN = 0
        return ratchet
    }

    /// Initialize as responder (Bob) after X3DH.
    /// Matches Python: DoubleRatchet.init_bob(shared_secret, spk_pair)
    ///
    /// - Parameters:
    ///   - sharedSecret: Shared secret used as the initial root key.
    ///   - spkPair: Bob's signed pre-key pair, used as the initial ratchet key pair.
    /// - Returns: A ratchet with neither chain set; both are derived when the
    ///   first message is decrypted.
    static func initBob(
        sharedSecret: Data,
        spkPair: (privateKey: Curve25519.KeyAgreement.PrivateKey, publicKey: Curve25519.KeyAgreement.PublicKey)
    ) -> DoubleRatchet {
        let ratchet = DoubleRatchet()
        ratchet.dhPair = spkPair
        ratchet.rootKey = sharedSecret
        ratchet.sendChainKey = nil
        ratchet.recvChainKey = nil
        ratchet.sendN = 0
        ratchet.recvN = 0
        ratchet.prevSendN = 0
        return ratchet
    }

    // MARK: - Encrypt

    /// Encrypt a message.
    /// Returns (header dict, ciphertext with tag, nonce).
    /// Matches Python: DoubleRatchet.encrypt(plaintext)
    ///
    /// The sending chain key and `sendN` are only advanced once encryption has
    /// succeeded, so a thrown error leaves the ratchet exactly as it was.
    ///
    /// - Parameter plaintext: Bytes to encrypt.
    /// - Throws: `CryptoError.ratchetError` when the send chain or the DH pair
    ///   is not set, or whatever `CryptoUtils.aesGcmEncrypt` throws.
    func encrypt(_ plaintext: Data) throws -> (header: [String: Any], ciphertext: Data, nonce: Data) {
        guard let chainKey = sendChainKey else {
            throw CryptoError.ratchetError("Send chain not initialized")
        }
        guard let dhPair = dhPair else {
            throw CryptoError.ratchetError("DH pair not set")
        }
        let (nextChainKey, messageKey) = CryptoUtils.kdfCK(chainKey: chainKey)
        let header = RatchetHeader(
            dhPub: X25519Crypto.serializePublic(dhPair.publicKey),
            n: sendN,
            pn: prevSendN
        )
        let nonce = Data.randomBytes(12)
        let aad = header.serialize()
        let ctWithTag = try CryptoUtils.aesGcmEncrypt(plaintext, key: messageKey, nonce: nonce, aad: aad)
        // Commit only after success: advancing the chain key without also
        // advancing sendN would make every later message use a key the receiver
        // cannot derive for the advertised message number.
        sendChainKey = nextChainKey
        sendN += 1
        return (header.toDict(), ctWithTag, nonce)
    }

    // MARK: - Decrypt

    /// Decrypt a message. Handles DH ratchet step if new dh_pub.
    /// State is snapshotted before modification and restored on failure (M9 fix).
    /// Matches Python: DoubleRatchet.decrypt(header_dict, ciphertext, nonce)
    ///
    /// - Parameters:
    ///   - headerDict: Header dictionary as produced by `RatchetHeader.toDict()`.
    ///   - ciphertext: Ciphertext with tag.
    ///   - nonce: Nonce that was used for encryption.
    /// - Returns: The decrypted plaintext.
    /// - Throws: Header parsing errors, `CryptoError.maxSkipExceeded`,
    ///   `CryptoError.ratchetError`, or whatever the key loading / AES-GCM
    ///   helpers throw. The ratchet state is unchanged when an error is thrown.
    func decrypt(headerDict: [String: Any], ciphertext: Data, nonce: Data) throws -> Data {
        let header = try RatchetHeader.fromDict(headerDict)
        let remoteDhPubBytes = header.dhPub
        let aad = header.serialize()

        // Check if this is from a skipped message
        let skipKey = "\(remoteDhPubBytes.hexString):\(header.n)"
        if let mk = skipped[skipKey] {
            // The stored key is dropped only after a successful decrypt, so a
            // forged message cannot burn the key of the genuine one.
            let plaintext = try CryptoUtils.aesGcmDecrypt(ciphertext, key: mk, nonce: nonce, aad: aad)
            skipped.removeValue(forKey: skipKey)
            return plaintext
        }

        // Snapshot state before modifications
        let snap = snapshot()
        do {
            let remoteDhPub = try X25519Crypto.loadPublic(remoteDhPubBytes)
            let currentRemoteBytes: Data? = dhRemote.map { X25519Crypto.serializePublic($0) }
            if currentRemoteBytes == nil || remoteDhPubBytes != currentRemoteBytes {
                // New DH ratchet step: first store the keys still missing from
                // the old receiving chain, then switch to the new chains.
                try skipMessages(until: header.pn)
                try dhRatchet(remoteDhPub: remoteDhPub)
            }
            try skipMessages(until: header.n)
            // Derive message key from receive chain
            guard let chainKey = recvChainKey else {
                throw CryptoError.ratchetError("Receive chain key is nil")
            }
            let (nextChainKey, mk) = CryptoUtils.kdfCK(chainKey: chainKey)
            recvChainKey = nextChainKey
            recvN += 1
            return try CryptoUtils.aesGcmDecrypt(ciphertext, key: mk, nonce: nonce, aad: aad)
        } catch {
            restore(snap)
            throw error
        }
    }

    // MARK: - State Snapshot/Restore (M9)

    /// Copy of every mutable ratchet field, taken before a decrypt attempt.
    /// Keys are held as values rather than re-encoded bytes, so restoring
    /// cannot fail or lose a key.
    private struct Snapshot {
        let dhPair: (privateKey: Curve25519.KeyAgreement.PrivateKey,
                     publicKey: Curve25519.KeyAgreement.PublicKey)?
        let dhRemote: Curve25519.KeyAgreement.PublicKey?
        let rootKey: Data
        let sendChainKey: Data?
        let recvChainKey: Data?
        let sendN: Int
        let recvN: Int
        let prevSendN: Int
        let skipped: [String: Data]
    }

    /// Capture the current ratchet state.
    private func snapshot() -> Snapshot {
        Snapshot(
            dhPair: dhPair,
            dhRemote: dhRemote,
            rootKey: rootKey,
            sendChainKey: sendChainKey,
            recvChainKey: recvChainKey,
            sendN: sendN,
            recvN: recvN,
            prevSendN: prevSendN,
            skipped: skipped
        )
    }

    /// Put the ratchet back into the state captured by `snapshot()`.
    private func restore(_ snap: Snapshot) {
        dhPair = snap.dhPair
        dhRemote = snap.dhRemote
        rootKey = snap.rootKey
        sendChainKey = snap.sendChainKey
        recvChainKey = snap.recvChainKey
        sendN = snap.sendN
        recvN = snap.recvN
        prevSendN = snap.prevSendN
        skipped = snap.skipped
    }

    // MARK: - Internal Ratchet Operations

    /// Advance the receiving chain up to (not including) message number
    /// `until`, storing each derived message key in `skipped`.
    ///
    /// Does nothing when there is no receiving chain or when `until` is not
    /// ahead of `recvN`.
    ///
    /// - Throws: `CryptoError.maxSkipExceeded` when more than
    ///   `Constants.maxSkip` keys would have to be derived.
    private func skipMessages(until: Int) throws {
        guard recvChainKey != nil, until > recvN else { return }
        // `until` comes from a header that has not been authenticated yet, so
        // the gap is computed without trapping on overflow.
        let (gap, overflowed) = until.subtractingReportingOverflow(recvN)
        if overflowed || gap > Constants.maxSkip {
            throw CryptoError.maxSkipExceeded
        }
        // The remote key does not change while skipping; encode it once.
        let remoteHex = dhRemote.map { X25519Crypto.serializePublic($0).hexString } ?? ""
        while recvN < until, let chainKey = recvChainKey {
            let (nextChainKey, mk) = CryptoUtils.kdfCK(chainKey: chainKey)
            recvChainKey = nextChainKey
            skipped["\(remoteHex):\(recvN)"] = mk
            recvN += 1
        }
    }

    /// Perform a DH ratchet step for a newly seen remote ratchet key: derive a
    /// new receiving chain with the current key pair, then generate a fresh key
    /// pair and derive a new sending chain from it.
    ///
    /// - Throws: `CryptoError.ratchetError` when no DH pair is set, or whatever
    ///   `X25519Crypto.dh` throws.
    private func dhRatchet(remoteDhPub: Curve25519.KeyAgreement.PublicKey) throws {
        guard let currentPair = dhPair else {
            throw CryptoError.ratchetError("DH pair not set")
        }
        prevSendN = sendN
        sendN = 0
        recvN = 0
        dhRemote = remoteDhPub
        // Derive new receive chain key
        let dhOutput1 = try X25519Crypto.dh(currentPair.privateKey, remoteDhPub)
        let (newRK1, recvCK) = CryptoUtils.kdfRK(rootKey: rootKey, dhOutput: dhOutput1)
        rootKey = newRK1
        recvChainKey = recvCK
        // Generate new DH pair and derive new send chain key
        let (newPriv, newPub) = X25519Crypto.generateKeypair()
        dhPair = (newPriv, newPub)
        let dhOutput2 = try X25519Crypto.dh(newPriv, remoteDhPub)
        let (newRK2, sendCK) = CryptoUtils.kdfRK(rootKey: rootKey, dhOutput: dhOutput2)
        rootKey = newRK2
        sendChainKey = sendCK
    }

    // MARK: - State Export/Import

    /// Map an optional hex string to a JSON value, using JSON null for `nil`.
    private static func jsonValue(_ hex: String?) -> Any {
        if let hex = hex {
            return hex
        }
        return NSNull()
    }

    /// Serialize full ratchet state for persistent storage.
    /// Produces JSON matching Python's DoubleRatchet.export_state() exactly.
    ///
    /// The output contains the private ratchet key and all chain keys in hex,
    /// so it must be stored as secret material.
    ///
    /// - Throws: Whatever `JSONSerialization.data(withJSONObject:)` throws.
    func exportState() throws -> Data {
        var state: [String: Any] = [:]
        state["dh_priv"] = DoubleRatchet.jsonValue(
            dhPair.map { X25519Crypto.serializePrivate($0.privateKey).hexString })
        state["dh_pub"] = DoubleRatchet.jsonValue(
            dhPair.map { X25519Crypto.serializePublic($0.publicKey).hexString })
        state["dh_remote"] = DoubleRatchet.jsonValue(
            dhRemote.map { X25519Crypto.serializePublic($0).hexString })
        state["root_key"] = rootKey.hexString
        state["send_ck"] = DoubleRatchet.jsonValue(sendChainKey?.hexString)
        state["recv_ck"] = DoubleRatchet.jsonValue(recvChainKey?.hexString)
        state["send_n"] = sendN
        state["recv_n"] = recvN
        state["prev_send_n"] = prevSendN
        // Skipped keys: Python format is "dh_pub_hex:n" -> message_key_hex
        var skippedDict: [String: String] = [:]
        for (key, value) in skipped {
            skippedDict[key] = value.hexString
        }
        state["skipped"] = skippedDict
        return try JSONSerialization.data(withJSONObject: state)
    }

    /// Deserialize ratchet state.
    /// Matches Python: DoubleRatchet.import_state(data)
    ///
    /// Optional fields that are absent, null, or not valid hex are left at
    /// their defaults (`nil` keys, zero counters, no skipped keys); only the
    /// root key is mandatory.
    ///
    /// - Parameter data: JSON as produced by `exportState()`.
    /// - Throws: `CryptoError.stateImportFailed` when the JSON is not an object
    ///   or lacks a valid `root_key`; JSON parsing and key loading errors are
    ///   propagated.
    static func importState(_ data: Data) throws -> DoubleRatchet {
        guard let state = try JSONSerialization.jsonObject(with: data) as? [String: Any] else {
            throw CryptoError.stateImportFailed("Invalid JSON")
        }
        let r = DoubleRatchet()
        if let dhPrivHex = state["dh_priv"] as? String,
           let dhPubHex = state["dh_pub"] as? String,
           let privData = Data(hexString: dhPrivHex),
           let pubData = Data(hexString: dhPubHex) {
            let priv = try X25519Crypto.loadPrivate(privData)
            let pub = try X25519Crypto.loadPublic(pubData)
            r.dhPair = (priv, pub)
        }
        if let dhRemoteHex = state["dh_remote"] as? String,
           let remoteData = Data(hexString: dhRemoteHex) {
            r.dhRemote = try X25519Crypto.loadPublic(remoteData)
        }
        guard let rootKeyHex = state["root_key"] as? String,
              let rootKey = Data(hexString: rootKeyHex) else {
            throw CryptoError.stateImportFailed("Missing root_key")
        }
        r.rootKey = rootKey
        if let sendCKHex = state["send_ck"] as? String, let ck = Data(hexString: sendCKHex) {
            r.sendChainKey = ck
        }
        if let recvCKHex = state["recv_ck"] as? String, let ck = Data(hexString: recvCKHex) {
            r.recvChainKey = ck
        }
        r.sendN = state["send_n"] as? Int ?? 0
        r.recvN = state["recv_n"] as? Int ?? 0
        r.prevSendN = state["prev_send_n"] as? Int ?? 0
        if let skippedDict = state["skipped"] as? [String: String] {
            for (key, valueHex) in skippedDict {
                if let value = Data(hexString: valueHex) {
                    r.skipped[key] = value
                }
            }
        }
        return r
    }
}