372 lines
13 KiB
Swift
372 lines
13 KiB
Swift
import Foundation
|
|
import CryptoKit
|
|
|
|
/// Ratchet header sent with each message.
///
/// Carries the sender's current ratchet public key plus the two counters the
/// receiver needs to locate (or skip to) the right message key.
struct RatchetHeader {
    /// Sender's current ratchet public key (expected 32 bytes; not length-checked here).
    let dhPub: Data
    /// Message number in the sender's current sending chain.
    let n: Int
    /// Number of messages in the sender's previous sending chain.
    let pn: Int

    /// Serialize header to JSON bytes for use as AAD.
    /// Matches Python: RatchetHeader.serialize()
    ///
    /// Uses the same fields as `toDict()` with sorted keys, so the byte output
    /// is deterministic for a given header.
    ///
    /// NOTE(review): Foundation emits compact JSON (no spaces after `,`/`:`).
    /// Python's `json.dumps(..., sort_keys=True)` inserts ", " and ": " unless
    /// `separators=(",", ":")` is passed — confirm the Python side is compact,
    /// otherwise the AAD bytes differ and cross-platform decryption fails.
    func serialize() -> Data {
        // The dictionary holds only a String and two Ints, which JSONSerialization
        // always accepts; a failure here would be a programming error, not input.
        return try! JSONSerialization.data(withJSONObject: toDict(), options: .sortedKeys)
    }

    /// Convert to dictionary for protocol.
    /// Matches Python: RatchetHeader.to_dict()
    ///
    /// - Returns: `["dh_pub": <hex string>, "n": Int, "pn": Int]`.
    func toDict() -> [String: Any] {
        [
            "dh_pub": dhPub.hexString,
            "n": n,
            "pn": pn,
        ]
    }

    /// Parse from dictionary.
    /// Matches Python: RatchetHeader.from_dict(d)
    ///
    /// - Parameter d: Header dictionary as produced by `toDict()` (typically
    ///   received from the peer, so treated as untrusted).
    /// - Returns: The parsed header.
    /// - Throws: `CryptoError.invalidHeader` if a field is missing, has the wrong
    ///   type, `dh_pub` is not valid hex, or a counter is negative.
    static func fromDict(_ d: [String: Any]) throws -> RatchetHeader {
        guard let dhPubHex = d["dh_pub"] as? String,
              let dhPub = Data(hexString: dhPubHex),
              let n = d["n"] as? Int,
              let pn = d["pn"] as? Int else {
            throw CryptoError.invalidHeader("Missing or invalid header fields")
        }
        // Senders count up from 0, so counters are never negative. Rejecting
        // negatives also keeps later counter arithmetic (e.g. `until - recvN`
        // when skipping messages) from overflowing on values like Int.min.
        guard n >= 0, pn >= 0 else {
            throw CryptoError.invalidHeader("Negative message counter in header")
        }
        return RatchetHeader(dhPub: dhPub, n: n, pn: pn)
    }
}
|
|
|
|
/// Signal Double Ratchet implementation.
/// Matches Python: DoubleRatchet class in crypto_utils.py
///
/// The class performs no internal locking; presumably callers serialize access
/// to a given instance (TODO confirm against call sites).
class DoubleRatchet {

    /// Our current ratchet key pair; its public half goes into every header we send.
    private(set) var dhPair: (privateKey: Curve25519.KeyAgreement.PrivateKey,
                              publicKey: Curve25519.KeyAgreement.PublicKey)?

    /// The peer's most recent ratchet public key, or nil until one is known.
    private(set) var dhRemote: Curve25519.KeyAgreement.PublicKey?

    /// Root key; replaced via `CryptoUtils.kdfRK` on every DH ratchet step.
    private(set) var rootKey: Data = Data()

    /// Current sending chain key; nil until a DH step has produced one.
    private(set) var sendChainKey: Data?

    /// Current receiving chain key; nil until the first DH ratchet on receipt.
    private(set) var recvChainKey: Data?

    /// Messages sent in the current sending chain (the next header's `n`).
    private(set) var sendN: Int = 0

    /// Messages consumed (decrypted or skipped) in the current receiving chain.
    private(set) var recvN: Int = 0

    /// Length of our previous sending chain (sent as the header's `pn`).
    private(set) var prevSendN: Int = 0

    /// Message keys for messages that were skipped over, keyed "dh_pub_hex:n".
    /// NOTE(review): only each individual skip is bounded by `Constants.maxSkip`;
    /// the dictionary as a whole can keep growing across ratchet steps.
    private(set) var skipped: [String: Data] = [:]

    /// Attached X3DH header — set when creating a new session, consumed on first send.
    /// Matches Python: ratchet._x3dh_header
    /// Not read or cleared anywhere in this class; the caller manages its lifecycle.
    var x3dhHeader: [String: Any]?

    init() {}

    // MARK: - Initialization

    /// Initialize as initiator (Alice) after X3DH.
    /// Matches Python: DoubleRatchet.init_alice(shared_secret, bob_spk_pub)
    ///
    /// Generates a fresh ratchet key pair and immediately performs a DH step
    /// against Bob's signed prekey, so Alice has a sending chain right away.
    /// The receiving chain stays nil until Bob's first reply carries a new
    /// ratchet key and `decrypt` performs a DH ratchet.
    /// - Parameters:
    ///   - sharedSecret: X3DH output, used as the input root key.
    ///   - bobSpkPub: Bob's signed prekey, treated as his first ratchet key.
    /// - Throws: Errors from `X25519Crypto.dh`.
    static func initAlice(sharedSecret: Data, bobSpkPub: Curve25519.KeyAgreement.PublicKey) throws -> DoubleRatchet {
        let ratchet = DoubleRatchet()
        let (priv, pub) = X25519Crypto.generateKeypair()
        ratchet.dhPair = (priv, pub)
        ratchet.dhRemote = bobSpkPub

        // Perform DH ratchet to derive send chain
        let dhOutput = try X25519Crypto.dh(priv, bobSpkPub)
        let (newRK, sendCK) = CryptoUtils.kdfRK(rootKey: sharedSecret, dhOutput: dhOutput)
        ratchet.rootKey = newRK
        ratchet.sendChainKey = sendCK
        // recvChainKey and all counters keep their property defaults (nil / 0).
        return ratchet
    }

    /// Initialize as responder (Bob) after X3DH.
    /// Matches Python: DoubleRatchet.init_bob(shared_secret, spk_pair)
    ///
    /// Bob starts with no chains and no known remote key, so Alice's first
    /// message always triggers a DH ratchet in `decrypt`, which derives both
    /// of Bob's chains.
    /// - Parameters:
    ///   - sharedSecret: X3DH output, used directly as the root key.
    ///   - spkPair: Bob's signed prekey pair, used as his first ratchet pair.
    static func initBob(
        sharedSecret: Data,
        spkPair: (privateKey: Curve25519.KeyAgreement.PrivateKey, publicKey: Curve25519.KeyAgreement.PublicKey)
    ) -> DoubleRatchet {
        let ratchet = DoubleRatchet()
        ratchet.dhPair = spkPair
        ratchet.rootKey = sharedSecret
        // Chains stay nil and counters stay 0 (property defaults).
        return ratchet
    }

    // MARK: - Encrypt

    /// Encrypt a message.
    /// Matches Python: DoubleRatchet.encrypt(plaintext)
    ///
    /// The header (serialized by `RatchetHeader.serialize()`) is bound as AAD,
    /// and a fresh 12-byte random nonce is used per message.
    /// - Parameter plaintext: Bytes to encrypt.
    /// - Returns: (header dict, ciphertext with tag, nonce).
    /// - Throws: `CryptoError.ratchetError` if the sending chain or key pair is
    ///   not set up; errors from `CryptoUtils.aesGcmEncrypt`.
    func encrypt(_ plaintext: Data) throws -> (header: [String: Any], ciphertext: Data, nonce: Data) {
        guard let chainKey = sendChainKey else {
            throw CryptoError.ratchetError("Send chain not initialized")
        }
        guard let pair = dhPair else {
            throw CryptoError.ratchetError("DH pair not set")
        }

        let (nextChainKey, messageKey) = CryptoUtils.kdfCK(chainKey: chainKey)
        let header = RatchetHeader(
            dhPub: X25519Crypto.serializePublic(pair.publicKey),
            n: sendN,
            pn: prevSendN
        )
        let nonce = Data.randomBytes(12)
        let ctWithTag = try CryptoUtils.aesGcmEncrypt(plaintext, key: messageKey, nonce: nonce, aad: header.serialize())

        // Commit the chain step and the counter together, and only after
        // encryption succeeded: advancing the chain without bumping `sendN`
        // would desynchronise us from the receiver for the rest of the chain.
        sendChainKey = nextChainKey
        sendN += 1

        return (header.toDict(), ctWithTag, nonce)
    }

    // MARK: - Decrypt

    /// Decrypt a message. Handles DH ratchet step if new dh_pub.
    /// State is snapshotted before modification and restored on failure (M9 fix).
    /// Matches Python: DoubleRatchet.decrypt(header_dict, ciphertext, nonce)
    ///
    /// - Parameters:
    ///   - headerDict: Header as produced by the sender's `encrypt`.
    ///   - ciphertext: Ciphertext with tag.
    ///   - nonce: Nonce used by the sender.
    /// - Returns: The decrypted plaintext.
    /// - Throws: `CryptoError.invalidHeader`, `CryptoError.maxSkipExceeded`,
    ///   `CryptoError.ratchetError`, or errors from the crypto helpers. On any
    ///   error the ratchet state is left as it was before the call.
    func decrypt(headerDict: [String: Any], ciphertext: Data, nonce: Data) throws -> Data {
        let header = try RatchetHeader.fromDict(headerDict)
        let aad = header.serialize()

        // Out-of-order message whose key was stored by an earlier skip.
        let skipKey = "\(header.dhPub.hexString):\(header.n)"
        if let storedKey = skipped.removeValue(forKey: skipKey) {
            do {
                return try CryptoUtils.aesGcmDecrypt(ciphertext, key: storedKey, nonce: nonce, aad: aad)
            } catch {
                // Keep the key so a genuine copy of this message can still decrypt.
                skipped[skipKey] = storedKey
                throw error
            }
        }

        let snap = snapshot()
        do {
            let remoteDhPub = try X25519Crypto.loadPublic(header.dhPub)
            let currentRemoteBytes = dhRemote.map { X25519Crypto.serializePublic($0) }

            // A different (or first-seen) remote key starts a new receiving chain:
            // stash the remaining keys of the old chain, then ratchet.
            if currentRemoteBytes != header.dhPub {
                try skipMessages(until: header.pn)
                try dhRatchet(remoteDhPub: remoteDhPub)
            }

            try skipMessages(until: header.n)

            guard let chainKey = recvChainKey else {
                throw CryptoError.ratchetError("Receive chain key is nil")
            }
            let (nextChainKey, messageKey) = CryptoUtils.kdfCK(chainKey: chainKey)
            recvChainKey = nextChainKey
            recvN += 1

            return try CryptoUtils.aesGcmDecrypt(ciphertext, key: messageKey, nonce: nonce, aad: aad)
        } catch {
            restore(snap)
            throw error
        }
    }

    // MARK: - State Snapshot/Restore (M9)

    /// Copy of every field `decrypt` may mutate. CryptoKit key types are structs,
    /// so they are held directly instead of being round-tripped through bytes
    /// (which could fail and silently drop the key pair on restore).
    private struct Snapshot {
        let dhPair: (privateKey: Curve25519.KeyAgreement.PrivateKey,
                     publicKey: Curve25519.KeyAgreement.PublicKey)?
        let dhRemote: Curve25519.KeyAgreement.PublicKey?
        let rootKey: Data
        let sendChainKey: Data?
        let recvChainKey: Data?
        let sendN: Int
        let recvN: Int
        let prevSendN: Int
        let skipped: [String: Data]
    }

    /// Captures the current ratchet state.
    private func snapshot() -> Snapshot {
        Snapshot(
            dhPair: dhPair,
            dhRemote: dhRemote,
            rootKey: rootKey,
            sendChainKey: sendChainKey,
            recvChainKey: recvChainKey,
            sendN: sendN,
            recvN: recvN,
            prevSendN: prevSendN,
            skipped: skipped
        )
    }

    /// Rolls the ratchet state back to `snap`.
    private func restore(_ snap: Snapshot) {
        dhPair = snap.dhPair
        dhRemote = snap.dhRemote
        rootKey = snap.rootKey
        sendChainKey = snap.sendChainKey
        recvChainKey = snap.recvChainKey
        sendN = snap.sendN
        recvN = snap.recvN
        prevSendN = snap.prevSendN
        skipped = snap.skipped
    }

    // MARK: - Internal Ratchet Operations

    /// Advances the receiving chain up to (not including) message `until`,
    /// storing each derived message key in `skipped`.
    /// No-op when there is no receiving chain yet or `until <= recvN`.
    /// - Throws: `CryptoError.maxSkipExceeded` if more than `Constants.maxSkip`
    ///   keys would have to be derived.
    private func skipMessages(until: Int) throws {
        guard var chainKey = recvChainKey else { return }
        // Checking the order first keeps `until - recvN` from overflowing.
        guard until > recvN else { return }
        if until - recvN > Constants.maxSkip {
            throw CryptoError.maxSkipExceeded
        }

        // dhRemote does not change inside the loop, so encode it once.
        let remoteHex = dhRemote.map { X25519Crypto.serializePublic($0).hexString } ?? ""
        while recvN < until {
            let (nextChainKey, messageKey) = CryptoUtils.kdfCK(chainKey: chainKey)
            chainKey = nextChainKey
            skipped["\(remoteHex):\(recvN)"] = messageKey
            recvN += 1
        }
        recvChainKey = chainKey
    }

    /// Performs a full DH ratchet step on receipt of a new remote ratchet key:
    /// derives a new receiving chain from our current pair, then generates a
    /// fresh pair and derives a new sending chain from it.
    /// - Throws: `CryptoError.ratchetError` if no key pair is set; errors from
    ///   `X25519Crypto.dh`.
    private func dhRatchet(remoteDhPub: Curve25519.KeyAgreement.PublicKey) throws {
        guard let currentPair = dhPair else {
            throw CryptoError.ratchetError("DH pair not set")
        }

        prevSendN = sendN
        sendN = 0
        recvN = 0
        dhRemote = remoteDhPub

        // Receiving chain: DH(our current private key, their new public key).
        let recvDhOutput = try X25519Crypto.dh(currentPair.privateKey, remoteDhPub)
        let (rootAfterRecv, newRecvChainKey) = CryptoUtils.kdfRK(rootKey: rootKey, dhOutput: recvDhOutput)
        rootKey = rootAfterRecv
        recvChainKey = newRecvChainKey

        // Sending chain: DH(fresh private key, their new public key).
        let (newPriv, newPub) = X25519Crypto.generateKeypair()
        dhPair = (newPriv, newPub)
        let sendDhOutput = try X25519Crypto.dh(newPriv, remoteDhPub)
        let (rootAfterSend, newSendChainKey) = CryptoUtils.kdfRK(rootKey: rootKey, dhOutput: sendDhOutput)
        rootKey = rootAfterSend
        sendChainKey = newSendChainKey
    }

    // MARK: - State Export/Import

    /// Hex-encodes `bytes`, or returns `NSNull` so the JSON field is an explicit `null`.
    private static func hexOrNull(_ bytes: Data?) -> Any {
        guard let bytes = bytes else { return NSNull() }
        return bytes.hexString
    }

    /// Serialize full ratchet state for persistent storage.
    ///
    /// Field names and hex encodings mirror Python's `DoubleRatchet.export_state()`.
    /// Keys are sorted so the output bytes are deterministic for a given state.
    /// The output contains the private ratchet key and chain keys in hex; store it
    /// accordingly.
    /// - Throws: Errors from `JSONSerialization`.
    func exportState() throws -> Data {
        let state: [String: Any] = [
            "dh_priv": DoubleRatchet.hexOrNull(dhPair.map { X25519Crypto.serializePrivate($0.privateKey) }),
            "dh_pub": DoubleRatchet.hexOrNull(dhPair.map { X25519Crypto.serializePublic($0.publicKey) }),
            "dh_remote": DoubleRatchet.hexOrNull(dhRemote.map { X25519Crypto.serializePublic($0) }),
            "root_key": rootKey.hexString,
            "send_ck": DoubleRatchet.hexOrNull(sendChainKey),
            "recv_ck": DoubleRatchet.hexOrNull(recvChainKey),
            "send_n": sendN,
            "recv_n": recvN,
            "prev_send_n": prevSendN,
            // Python format: "dh_pub_hex:n" -> message_key_hex
            "skipped": skipped.mapValues { $0.hexString },
        ]
        return try JSONSerialization.data(withJSONObject: state, options: .sortedKeys)
    }

    /// Reads an optional hex field: missing, `null`, or non-string yields nil;
    /// a string that is not valid hex throws instead of being silently dropped.
    private static func optionalHexField(_ state: [String: Any], _ key: String) throws -> Data? {
        guard let hex = state[key] as? String else { return nil }
        guard let bytes = Data(hexString: hex) else {
            throw CryptoError.stateImportFailed("Invalid hex in \(key)")
        }
        return bytes
    }

    /// Deserialize ratchet state.
    /// Matches Python: DoubleRatchet.import_state(data)
    ///
    /// - Parameter data: JSON produced by `exportState()`.
    /// - Returns: A ratchet restored to the exported state.
    /// - Throws: `CryptoError.stateImportFailed` if the JSON is not an object,
    ///   `root_key` is missing or invalid, or any key field holds malformed hex;
    ///   errors from `JSONSerialization` and `X25519Crypto` key loading.
    static func importState(_ data: Data) throws -> DoubleRatchet {
        guard let state = try JSONSerialization.jsonObject(with: data) as? [String: Any] else {
            throw CryptoError.stateImportFailed("Invalid JSON")
        }

        let r = DoubleRatchet()

        // The key pair is restored only when both halves are present.
        if let privData = try optionalHexField(state, "dh_priv"),
           let pubData = try optionalHexField(state, "dh_pub") {
            let priv = try X25519Crypto.loadPrivate(privData)
            let pub = try X25519Crypto.loadPublic(pubData)
            r.dhPair = (priv, pub)
        }

        if let remoteData = try optionalHexField(state, "dh_remote") {
            r.dhRemote = try X25519Crypto.loadPublic(remoteData)
        }

        guard let rootKeyHex = state["root_key"] as? String,
              let rootKey = Data(hexString: rootKeyHex) else {
            throw CryptoError.stateImportFailed("Missing root_key")
        }
        r.rootKey = rootKey

        r.sendChainKey = try optionalHexField(state, "send_ck")
        r.recvChainKey = try optionalHexField(state, "recv_ck")

        r.sendN = state["send_n"] as? Int ?? 0
        r.recvN = state["recv_n"] as? Int ?? 0
        r.prevSendN = state["prev_send_n"] as? Int ?? 0

        if let skippedHex = state["skipped"] as? [String: String] {
            for (key, valueHex) in skippedHex {
                guard let value = Data(hexString: valueHex) else {
                    throw CryptoError.stateImportFailed("Invalid hex in skipped")
                }
                r.skipped[key] = value
            }
        }

        return r
    }
}
|