176 lines
6.5 KiB
Swift
176 lines
6.5 KiB
Swift
import Foundation
|
|
import CryptoKit
|
|
|
|
/// Sender key chain for group messaging.
///
/// Each sender in a group has their own chain. Others receive the initial key via pairwise ratchet.
///
/// Matches Python: SenderKeyState class in crypto_utils.py
class SenderKeyState {

    /// Root secret of the chain; this is the value `exportKey()` distributes.
    let senderKey: Data
    /// SHA-256 digest of `senderKey`; identifies the chain and is bound into every message's AAD.
    let chainId: Data
    /// Current chain key, advanced via `CryptoUtils.kdfCK` once per message index.
    private(set) var chainKey: Data
    /// Index of the next message key this chain will derive.
    private(set) var n: Int
    /// Message keys derived while fast-forwarding but not yet consumed, keyed by message index.
    // NOTE(review): an entry is only removed when its own index is decrypted, so keys for
    // messages that never arrive accumulate here indefinitely — consider a cap for long-lived chains.
    private var knownKeys: [Int: Data]

    /// Initialize with optional sender key (generates random 32B if nil).
    /// Matches Python: SenderKeyState.__init__(sender_key=None)
    init(senderKey: Data? = nil) {
        let key = senderKey ?? Data.randomBytes(32)
        self.senderKey = key
        self.chainId = Data(SHA256.hash(data: key))
        self.chainKey = CryptoUtils.hkdfDerive(
            inputKey: key,
            salt: Data(repeating: 0x00, count: 32),
            info: Data(Constants.senderKeyChainInfo.utf8),
            length: 32
        )
        self.n = 0
        self.knownKeys = [:]
    }

    /// Memberwise initializer used by `importState(_:)` to rebuild a persisted chain.
    private init(senderKey: Data, chainId: Data, chainKey: Data, n: Int, knownKeys: [Int: Data]) {
        self.senderKey = senderKey
        self.chainId = chainId
        self.chainKey = chainKey
        self.n = n
        self.knownKeys = knownKeys
    }

    // MARK: - AAD

    /// Builds the AEAD associated data `chainId || bigEndian(UInt32(index))`.
    ///
    /// - Throws: `CryptoError.senderKeyError` when `index` does not fit in a `UInt32`,
    ///   instead of trapping the way `UInt32(index)` would.
    private func associatedData(for index: Int) throws -> Data {
        guard let index32 = UInt32(exactly: index) else {
            throw CryptoError.senderKeyError("Message index \(index) out of range")
        }
        return chainId + index32.bigEndianData
    }

    // MARK: - Encrypt

    /// Encrypt with the current chain key, then advance the chain by one step.
    ///
    /// - Returns: (chainId hex, message index n, ciphertext with tag, 12-byte random nonce).
    /// - Throws: Any error from `CryptoUtils.aesGcmEncrypt`, or `CryptoError.senderKeyError`
    ///   if `n` no longer fits the 32-bit AAD counter. On throw, `chainKey` and `n` are unchanged.
    /// Matches Python: SenderKeyState.encrypt(plaintext)
    func encrypt(_ plaintext: Data) throws -> (chainIdHex: String, n: Int, ciphertext: Data, nonce: Data) {
        let (nextChainKey, messageKey) = CryptoUtils.kdfCK(chainKey: chainKey)
        let nonce = Data.randomBytes(12)
        let aad = try associatedData(for: n)
        let ctWithTag = try CryptoUtils.aesGcmEncrypt(plaintext, key: messageKey, nonce: nonce, aad: aad)

        // Commit only after encryption succeeded: advancing chainKey without advancing n
        // would make every later message use a key receivers derive for a different index.
        let result = (chainIdHex: chainId.hexString, n: n, ciphertext: ctWithTag, nonce: nonce)
        chainKey = nextChainKey
        n += 1
        return result
    }

    // MARK: - Decrypt

    /// Decrypt a group message. Fast-forwards the chain if needed.
    ///
    /// State is snapshotted before modification and restored on failure.
    ///
    /// - Parameters:
    ///   - chainIdHex: Hex chain identifier from the message; must equal this chain's `chainId`.
    ///   - n: Message index from the message (untrusted input).
    ///   - ciphertext: Ciphertext with the GCM tag appended.
    ///   - nonce: Nonce used for this message.
    /// - Throws: `CryptoError.senderKeyError` for a bad/mismatched chain ID, a skip larger than
    ///   `Constants.maxSenderKeySkip`, or an unavailable key; any error from `CryptoUtils.aesGcmDecrypt`.
    /// Matches Python: SenderKeyState.decrypt(chain_id_hex, n, ciphertext, nonce)
    func decrypt(chainIdHex: String, n: Int, ciphertext: Data, nonce: Data) throws -> Data {
        guard let expectedChainId = Data(hexString: chainIdHex) else {
            throw CryptoError.senderKeyError("Invalid chain ID hex")
        }
        guard expectedChainId == chainId else {
            throw CryptoError.senderKeyError("Chain ID mismatch")
        }

        // A negative index can never have a key. Rejecting it here also keeps `n - self.n`
        // below from overflowing (a runtime trap) for attacker-supplied values near Int.min.
        guard n >= 0 else {
            throw CryptoError.senderKeyError("Message key for n=\(n) not available")
        }

        if n - self.n > Constants.maxSenderKeySkip {
            throw CryptoError.senderKeyError("Sender key skip too large (\(n - self.n) > \(Constants.maxSenderKeySkip))")
        }

        // Snapshot before fast-forward
        let snapChainKey = chainKey
        let snapN = self.n
        let snapKnown = knownKeys

        do {
            // Fast-forward the chain to reach message n, remembering every key derived on the way
            // so out-of-order messages with smaller indices can still be decrypted later.
            while self.n <= n {
                let (newCK, mk) = CryptoUtils.kdfCK(chainKey: chainKey)
                chainKey = newCK
                knownKeys[self.n] = mk
                self.n += 1
            }

            // Each message key is single-use: remove it as it is consumed.
            guard let mk = knownKeys.removeValue(forKey: n) else {
                throw CryptoError.senderKeyError("Message key for n=\(n) not available")
            }

            let aad = try associatedData(for: n)
            return try CryptoUtils.aesGcmDecrypt(ciphertext, key: mk, nonce: nonce, aad: aad)
        } catch {
            // Restore state on failure so a forged or corrupt message cannot advance the chain.
            chainKey = snapChainKey
            self.n = snapN
            knownKeys = snapKnown
            throw error
        }
    }

    // MARK: - Key Export/Import

    /// Export sender key for distribution to group members.
    ///
    /// - Returns: JSON `{"sender_key": "<hex>"}`.
    /// Matches Python: SenderKeyState.export_key()
    func exportKey() -> Data {
        let dict: [String: Any] = ["sender_key": senderKey.hexString]
        // A dictionary with a single String value is always valid JSON, so this cannot fail.
        return try! JSONSerialization.data(withJSONObject: dict)
    }

    /// Initialize a receiving SenderKeyState from an exported key.
    ///
    /// - Throws: `CryptoError.stateImportFailed` if the payload is not the JSON produced by
    ///   `exportKey()`; any error from `JSONSerialization` for malformed JSON.
    /// Matches Python: SenderKeyState.from_key(exported_key)
    static func fromKey(_ exportedKey: Data) throws -> SenderKeyState {
        guard let dict = try JSONSerialization.jsonObject(with: exportedKey) as? [String: Any],
              let senderKeyHex = dict["sender_key"] as? String,
              let senderKey = Data(hexString: senderKeyHex) else {
            throw CryptoError.stateImportFailed("Invalid sender key export")
        }
        return SenderKeyState(senderKey: senderKey)
    }

    // MARK: - Full State Export/Import

    /// Serialize full state for persistent storage.
    ///
    /// Binary fields are hex-encoded; `known_keys` maps stringified indices to hex keys.
    /// Matches Python: SenderKeyState.export_state()
    func exportState() -> Data {
        let knownKeysDict = Dictionary(uniqueKeysWithValues: knownKeys.map { (String($0.key), $0.value.hexString) })
        let state: [String: Any] = [
            "sender_key": senderKey.hexString,
            "chain_id": chainId.hexString,
            "chain_key": chainKey.hexString,
            "n": n,
            "known_keys": knownKeysDict,
        ]
        // Only Strings, an Int, and a [String: String] are stored, so serialization cannot fail.
        return try! JSONSerialization.data(withJSONObject: state)
    }

    /// Deserialize full state.
    ///
    /// Malformed individual `known_keys` entries are skipped (best effort).
    /// - Throws: `CryptoError.stateImportFailed` if a required field is missing, not valid hex,
    ///   or `n` is negative; any error from `JSONSerialization` for malformed JSON.
    /// Matches Python: SenderKeyState.import_state(data)
    static func importState(_ data: Data) throws -> SenderKeyState {
        guard let state = try JSONSerialization.jsonObject(with: data) as? [String: Any],
              let senderKeyHex = state["sender_key"] as? String,
              let senderKey = Data(hexString: senderKeyHex),
              let chainIdHex = state["chain_id"] as? String,
              let chainId = Data(hexString: chainIdHex),
              let chainKeyHex = state["chain_key"] as? String,
              let chainKey = Data(hexString: chainKeyHex),
              let n = state["n"] as? Int,
              n >= 0 else {
            throw CryptoError.stateImportFailed("Invalid sender key state")
        }

        var knownKeys: [Int: Data] = [:]
        if let knownKeysDict = state["known_keys"] as? [String: String] {
            for (k, v) in knownKeysDict {
                if let idx = Int(k), let data = Data(hexString: v) {
                    knownKeys[idx] = data
                }
            }
        }

        return SenderKeyState(
            senderKey: senderKey,
            chainId: chainId,
            chainKey: chainKey,
            n: n,
            knownKeys: knownKeys
        )
    }
}
|