initial commit
This commit is contained in:
371
ios_client/EncryptedChat/Crypto/DoubleRatchet.swift
Normal file
371
ios_client/EncryptedChat/Crypto/DoubleRatchet.swift
Normal file
@@ -0,0 +1,371 @@
|
||||
import Foundation
|
||||
import CryptoKit
|
||||
|
||||
/// Ratchet header sent with each message
|
||||
struct RatchetHeader {
|
||||
let dhPub: Data // sender's current ratchet public key (32 bytes)
|
||||
let n: Int // message number in current sending chain
|
||||
let pn: Int // number of messages in previous sending chain
|
||||
|
||||
/// Serialize header to JSON bytes for use as AAD.
|
||||
/// Matches Python: RatchetHeader.serialize()
|
||||
func serialize() -> Data {
|
||||
let dict: [String: Any] = [
|
||||
"dh_pub": dhPub.hexString,
|
||||
"n": n,
|
||||
"pn": pn,
|
||||
]
|
||||
// Must produce consistent JSON — sorted keys for determinism
|
||||
return try! JSONSerialization.data(withJSONObject: dict, options: .sortedKeys)
|
||||
}
|
||||
|
||||
/// Convert to dictionary for protocol.
|
||||
/// Matches Python: RatchetHeader.to_dict()
|
||||
func toDict() -> [String: Any] {
|
||||
[
|
||||
"dh_pub": dhPub.hexString,
|
||||
"n": n,
|
||||
"pn": pn,
|
||||
]
|
||||
}
|
||||
|
||||
/// Parse from dictionary.
|
||||
/// Matches Python: RatchetHeader.from_dict(d)
|
||||
static func fromDict(_ d: [String: Any]) throws -> RatchetHeader {
|
||||
guard let dhPubHex = d["dh_pub"] as? String,
|
||||
let dhPub = Data(hexString: dhPubHex),
|
||||
let n = d["n"] as? Int,
|
||||
let pn = d["pn"] as? Int else {
|
||||
throw CryptoError.invalidHeader("Missing or invalid header fields")
|
||||
}
|
||||
return RatchetHeader(dhPub: dhPub, n: n, pn: pn)
|
||||
}
|
||||
}
|
||||
|
||||
/// Signal Double Ratchet implementation.
|
||||
/// Matches Python: DoubleRatchet class in crypto_utils.py
|
||||
class DoubleRatchet {
|
||||
|
||||
private(set) var dhPair: (privateKey: Curve25519.KeyAgreement.PrivateKey,
|
||||
publicKey: Curve25519.KeyAgreement.PublicKey)?
|
||||
private(set) var dhRemote: Curve25519.KeyAgreement.PublicKey?
|
||||
private(set) var rootKey: Data = Data()
|
||||
private(set) var sendChainKey: Data?
|
||||
private(set) var recvChainKey: Data?
|
||||
private(set) var sendN: Int = 0
|
||||
private(set) var recvN: Int = 0
|
||||
private(set) var prevSendN: Int = 0
|
||||
// Skipped message keys: "dh_pub_hex:n" → message_key
|
||||
private(set) var skipped: [String: Data] = [:]
|
||||
|
||||
/// Attached X3DH header — set when creating a new session, consumed on first send.
|
||||
/// Matches Python: ratchet._x3dh_header
|
||||
var x3dhHeader: [String: Any]?
|
||||
|
||||
init() {}
|
||||
|
||||
// MARK: - Initialization
|
||||
|
||||
/// Initialize as initiator (Alice) after X3DH.
|
||||
/// Matches Python: DoubleRatchet.init_alice(shared_secret, bob_spk_pub)
|
||||
static func initAlice(sharedSecret: Data, bobSpkPub: Curve25519.KeyAgreement.PublicKey) throws -> DoubleRatchet {
|
||||
let ratchet = DoubleRatchet()
|
||||
let (priv, pub) = X25519Crypto.generateKeypair()
|
||||
ratchet.dhPair = (priv, pub)
|
||||
ratchet.dhRemote = bobSpkPub
|
||||
|
||||
// Perform DH ratchet to derive send chain
|
||||
let dhOutput = try X25519Crypto.dh(priv, bobSpkPub)
|
||||
let (newRK, sendCK) = CryptoUtils.kdfRK(rootKey: sharedSecret, dhOutput: dhOutput)
|
||||
ratchet.rootKey = newRK
|
||||
ratchet.sendChainKey = sendCK
|
||||
ratchet.recvChainKey = nil
|
||||
ratchet.sendN = 0
|
||||
ratchet.recvN = 0
|
||||
ratchet.prevSendN = 0
|
||||
return ratchet
|
||||
}
|
||||
|
||||
/// Initialize as responder (Bob) after X3DH.
|
||||
/// Matches Python: DoubleRatchet.init_bob(shared_secret, spk_pair)
|
||||
static func initBob(
|
||||
sharedSecret: Data,
|
||||
spkPair: (privateKey: Curve25519.KeyAgreement.PrivateKey, publicKey: Curve25519.KeyAgreement.PublicKey)
|
||||
) -> DoubleRatchet {
|
||||
let ratchet = DoubleRatchet()
|
||||
ratchet.dhPair = spkPair
|
||||
ratchet.rootKey = sharedSecret
|
||||
ratchet.sendChainKey = nil
|
||||
ratchet.recvChainKey = nil
|
||||
ratchet.sendN = 0
|
||||
ratchet.recvN = 0
|
||||
ratchet.prevSendN = 0
|
||||
return ratchet
|
||||
}
|
||||
|
||||
// MARK: - Encrypt
|
||||
|
||||
/// Encrypt a message.
|
||||
/// Returns (header dict, ciphertext with tag, nonce).
|
||||
/// Matches Python: DoubleRatchet.encrypt(plaintext)
|
||||
func encrypt(_ plaintext: Data) throws -> (header: [String: Any], ciphertext: Data, nonce: Data) {
|
||||
guard sendChainKey != nil else {
|
||||
throw CryptoError.ratchetError("Send chain not initialized")
|
||||
}
|
||||
guard let dhPair = dhPair else {
|
||||
throw CryptoError.ratchetError("DH pair not set")
|
||||
}
|
||||
|
||||
let (newCK, messageKey) = CryptoUtils.kdfCK(chainKey: sendChainKey!)
|
||||
sendChainKey = newCK
|
||||
|
||||
let header = RatchetHeader(
|
||||
dhPub: X25519Crypto.serializePublic(dhPair.publicKey),
|
||||
n: sendN,
|
||||
pn: prevSendN
|
||||
)
|
||||
|
||||
let nonce = Data.randomBytes(12)
|
||||
let aad = header.serialize()
|
||||
let ctWithTag = try CryptoUtils.aesGcmEncrypt(plaintext, key: messageKey, nonce: nonce, aad: aad)
|
||||
|
||||
sendN += 1
|
||||
|
||||
return (header.toDict(), ctWithTag, nonce)
|
||||
}
|
||||
|
||||
// MARK: - Decrypt
|
||||
|
||||
/// Decrypt a message. Handles DH ratchet step if new dh_pub.
|
||||
/// State is snapshotted before modification and restored on failure (M9 fix).
|
||||
/// Matches Python: DoubleRatchet.decrypt(header_dict, ciphertext, nonce)
|
||||
func decrypt(headerDict: [String: Any], ciphertext: Data, nonce: Data) throws -> Data {
|
||||
let header = try RatchetHeader.fromDict(headerDict)
|
||||
let remoteDhPubBytes = header.dhPub
|
||||
|
||||
// Check if this is from a skipped message
|
||||
let skipKey = "\(remoteDhPubBytes.hexString):\(header.n)"
|
||||
if let mk = skipped[skipKey] {
|
||||
skipped.removeValue(forKey: skipKey)
|
||||
let aad = header.serialize()
|
||||
do {
|
||||
return try CryptoUtils.aesGcmDecrypt(ciphertext, key: mk, nonce: nonce, aad: aad)
|
||||
} catch {
|
||||
// Restore skipped key on failure
|
||||
skipped[skipKey] = mk
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
// Snapshot state before modifications
|
||||
let snap = snapshot()
|
||||
|
||||
do {
|
||||
let remoteDhPub = try X25519Crypto.loadPublic(remoteDhPubBytes)
|
||||
let currentRemoteBytes: Data? = dhRemote.map { X25519Crypto.serializePublic($0) }
|
||||
|
||||
if currentRemoteBytes == nil || remoteDhPubBytes != currentRemoteBytes {
|
||||
// New DH ratchet step
|
||||
try skipMessages(until: header.pn)
|
||||
try dhRatchet(remoteDhPub: remoteDhPub)
|
||||
}
|
||||
|
||||
try skipMessages(until: header.n)
|
||||
|
||||
// Derive message key from receive chain
|
||||
guard recvChainKey != nil else {
|
||||
throw CryptoError.ratchetError("Receive chain key is nil")
|
||||
}
|
||||
let (newCK, mk) = CryptoUtils.kdfCK(chainKey: recvChainKey!)
|
||||
recvChainKey = newCK
|
||||
recvN += 1
|
||||
|
||||
let aad = header.serialize()
|
||||
return try CryptoUtils.aesGcmDecrypt(ciphertext, key: mk, nonce: nonce, aad: aad)
|
||||
} catch {
|
||||
restore(snap)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - State Snapshot/Restore (M9)
|
||||
|
||||
private struct Snapshot {
|
||||
let dhPairPriv: Data?
|
||||
let dhPairPub: Data?
|
||||
let dhRemote: Data?
|
||||
let rootKey: Data
|
||||
let sendChainKey: Data?
|
||||
let recvChainKey: Data?
|
||||
let sendN: Int
|
||||
let recvN: Int
|
||||
let prevSendN: Int
|
||||
let skipped: [String: Data]
|
||||
}
|
||||
|
||||
private func snapshot() -> Snapshot {
|
||||
Snapshot(
|
||||
dhPairPriv: dhPair.map { X25519Crypto.serializePrivate($0.privateKey) },
|
||||
dhPairPub: dhPair.map { X25519Crypto.serializePublic($0.publicKey) },
|
||||
dhRemote: dhRemote.map { X25519Crypto.serializePublic($0) },
|
||||
rootKey: rootKey,
|
||||
sendChainKey: sendChainKey,
|
||||
recvChainKey: recvChainKey,
|
||||
sendN: sendN,
|
||||
recvN: recvN,
|
||||
prevSendN: prevSendN,
|
||||
skipped: skipped
|
||||
)
|
||||
}
|
||||
|
||||
private func restore(_ snap: Snapshot) {
|
||||
if let privData = snap.dhPairPriv, let pubData = snap.dhPairPub,
|
||||
let priv = try? X25519Crypto.loadPrivate(privData),
|
||||
let pub = try? X25519Crypto.loadPublic(pubData) {
|
||||
dhPair = (priv, pub)
|
||||
} else {
|
||||
dhPair = nil
|
||||
}
|
||||
if let remoteData = snap.dhRemote, let remote = try? X25519Crypto.loadPublic(remoteData) {
|
||||
dhRemote = remote
|
||||
} else {
|
||||
dhRemote = nil
|
||||
}
|
||||
rootKey = snap.rootKey
|
||||
sendChainKey = snap.sendChainKey
|
||||
recvChainKey = snap.recvChainKey
|
||||
sendN = snap.sendN
|
||||
recvN = snap.recvN
|
||||
prevSendN = snap.prevSendN
|
||||
skipped = snap.skipped
|
||||
}
|
||||
|
||||
// MARK: - Internal Ratchet Operations
|
||||
|
||||
private func skipMessages(until: Int) throws {
|
||||
guard recvChainKey != nil else { return }
|
||||
if until - recvN > Constants.maxSkip {
|
||||
throw CryptoError.maxSkipExceeded
|
||||
}
|
||||
while recvN < until {
|
||||
let (newCK, mk) = CryptoUtils.kdfCK(chainKey: recvChainKey!)
|
||||
recvChainKey = newCK
|
||||
let remoteHex = dhRemote.map { X25519Crypto.serializePublic($0).hexString } ?? ""
|
||||
skipped["\(remoteHex):\(recvN)"] = mk
|
||||
recvN += 1
|
||||
}
|
||||
}
|
||||
|
||||
private func dhRatchet(remoteDhPub: Curve25519.KeyAgreement.PublicKey) throws {
|
||||
prevSendN = sendN
|
||||
sendN = 0
|
||||
recvN = 0
|
||||
dhRemote = remoteDhPub
|
||||
|
||||
// Derive new receive chain key
|
||||
guard let dhPair = dhPair else {
|
||||
throw CryptoError.ratchetError("DH pair not set")
|
||||
}
|
||||
let dhOutput1 = try X25519Crypto.dh(dhPair.privateKey, remoteDhPub)
|
||||
let (newRK1, recvCK) = CryptoUtils.kdfRK(rootKey: rootKey, dhOutput: dhOutput1)
|
||||
rootKey = newRK1
|
||||
recvChainKey = recvCK
|
||||
|
||||
// Generate new DH pair and derive new send chain key
|
||||
let (newPriv, newPub) = X25519Crypto.generateKeypair()
|
||||
self.dhPair = (newPriv, newPub)
|
||||
let dhOutput2 = try X25519Crypto.dh(newPriv, remoteDhPub)
|
||||
let (newRK2, sendCK) = CryptoUtils.kdfRK(rootKey: rootKey, dhOutput: dhOutput2)
|
||||
rootKey = newRK2
|
||||
sendChainKey = sendCK
|
||||
}
|
||||
|
||||
// MARK: - State Export/Import
|
||||
|
||||
/// Serialize full ratchet state for persistent storage.
|
||||
/// Produces JSON matching Python's DoubleRatchet.export_state() exactly.
|
||||
func exportState() throws -> Data {
|
||||
var state: [String: Any] = [:]
|
||||
|
||||
if let pair = dhPair {
|
||||
state["dh_priv"] = X25519Crypto.serializePrivate(pair.privateKey).hexString
|
||||
state["dh_pub"] = X25519Crypto.serializePublic(pair.publicKey).hexString
|
||||
} else {
|
||||
state["dh_priv"] = NSNull()
|
||||
state["dh_pub"] = NSNull()
|
||||
}
|
||||
|
||||
if let remote = dhRemote {
|
||||
state["dh_remote"] = X25519Crypto.serializePublic(remote).hexString
|
||||
} else {
|
||||
state["dh_remote"] = NSNull()
|
||||
}
|
||||
|
||||
state["root_key"] = rootKey.hexString
|
||||
state["send_ck"] = sendChainKey?.hexString ?? NSNull()
|
||||
state["recv_ck"] = recvChainKey?.hexString ?? NSNull()
|
||||
state["send_n"] = sendN
|
||||
state["recv_n"] = recvN
|
||||
state["prev_send_n"] = prevSendN
|
||||
|
||||
// Skipped keys: Python format is "dh_pub_hex:n" -> message_key_hex
|
||||
var skippedDict: [String: String] = [:]
|
||||
for (key, value) in skipped {
|
||||
skippedDict[key] = value.hexString
|
||||
}
|
||||
state["skipped"] = skippedDict
|
||||
|
||||
return try JSONSerialization.data(withJSONObject: state)
|
||||
}
|
||||
|
||||
/// Deserialize ratchet state.
|
||||
/// Matches Python: DoubleRatchet.import_state(data)
|
||||
static func importState(_ data: Data) throws -> DoubleRatchet {
|
||||
guard let state = try JSONSerialization.jsonObject(with: data) as? [String: Any] else {
|
||||
throw CryptoError.stateImportFailed("Invalid JSON")
|
||||
}
|
||||
|
||||
let r = DoubleRatchet()
|
||||
|
||||
if let dhPrivHex = state["dh_priv"] as? String,
|
||||
let dhPubHex = state["dh_pub"] as? String,
|
||||
let privData = Data(hexString: dhPrivHex),
|
||||
let pubData = Data(hexString: dhPubHex) {
|
||||
let priv = try X25519Crypto.loadPrivate(privData)
|
||||
let pub = try X25519Crypto.loadPublic(pubData)
|
||||
r.dhPair = (priv, pub)
|
||||
}
|
||||
|
||||
if let dhRemoteHex = state["dh_remote"] as? String,
|
||||
let remoteData = Data(hexString: dhRemoteHex) {
|
||||
r.dhRemote = try X25519Crypto.loadPublic(remoteData)
|
||||
}
|
||||
|
||||
guard let rootKeyHex = state["root_key"] as? String,
|
||||
let rootKey = Data(hexString: rootKeyHex) else {
|
||||
throw CryptoError.stateImportFailed("Missing root_key")
|
||||
}
|
||||
r.rootKey = rootKey
|
||||
|
||||
if let sendCKHex = state["send_ck"] as? String, let ck = Data(hexString: sendCKHex) {
|
||||
r.sendChainKey = ck
|
||||
}
|
||||
if let recvCKHex = state["recv_ck"] as? String, let ck = Data(hexString: recvCKHex) {
|
||||
r.recvChainKey = ck
|
||||
}
|
||||
|
||||
r.sendN = state["send_n"] as? Int ?? 0
|
||||
r.recvN = state["recv_n"] as? Int ?? 0
|
||||
r.prevSendN = state["prev_send_n"] as? Int ?? 0
|
||||
|
||||
if let skippedDict = state["skipped"] as? [String: String] {
|
||||
for (key, valueHex) in skippedDict {
|
||||
if let value = Data(hexString: valueHex) {
|
||||
r.skipped[key] = value
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user