Complete Android client for encrypted chat platform. 78+ Kotlin files: crypto (X3DH, Double Ratchet, AES-GCM, Ed25519, X25519, RSA-PSS), network (TCP/TLS, 50 endpoints), Hilt DI, Room+SQLCipher DB, Jetpack Compose UI with Catppuccin Mocha theme. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
397 lines
13 KiB
Kotlin
397 lines
13 KiB
Kotlin
package com.kecalek.chat.crypto
|
|
|
|
import org.bouncycastle.crypto.params.X25519PrivateKeyParameters
|
|
import org.bouncycastle.crypto.params.X25519PublicKeyParameters
|
|
import org.json.JSONObject
|
|
import java.security.SecureRandom
|
|
|
|
/**
 * Double Ratchet algorithm for end-to-end encrypted messaging.
 * Provides forward secrecy and break-in recovery.
 *
 * Intended to be compatible with the Python `DoubleRatchet` class from `crypto_utils.py`
 * (same exported state keys and header fields).
 *
 * State:
 * - `dh_pair`: current ratchet X25519 keypair ([dhPrivate] / [dhPublic])
 * - `dh_remote`: remote party's current ratchet public key (null for Bob until the first message arrives)
 * - `root_key`: current root key
 * - `send_chain_key` / `recv_chain_key`: current chain keys (null until that chain has been derived)
 * - `send_n` / `recv_n`: message counters within the current chains
 * - `prev_send_n`: length of the previous sending chain
 * - `skipped`: map of "dh_hex:n" -> message key, kept for out-of-order delivery
 *
 * Instances are mutable and perform no internal synchronization.
 */
class DoubleRatchet private constructor() {

    private lateinit var dhPrivate: X25519PrivateKeyParameters
    private lateinit var dhPublic: X25519PublicKeyParameters
    private var dhRemote: X25519PublicKeyParameters? = null
    private lateinit var rootKey: ByteArray
    private var sendChainKey: ByteArray? = null
    private var recvChainKey: ByteArray? = null
    private var sendN: Int = 0
    private var recvN: Int = 0
    private var prevSendN: Int = 0

    // skipped["<remotePublicHex>:<messageNumber>"] = messageKey (see makeSkippedKey)
    private val skipped = mutableMapOf<String, ByteArray>()

    /**
     * Encrypts [plaintext] with the next key of the sending chain.
     *
     * The header (current ratchet public key, message number, previous chain length) is
     * serialized and bound to the ciphertext as AES-GCM associated data.
     *
     * The sending chain key and counter are only advanced after encryption succeeded, so a
     * failure inside the cipher leaves the ratchet in sync with the peer.
     *
     * @return [RatchetMessage] with the header, ciphertext+tag and nonce
     * @throws CryptoException.DecryptionFailed if no sending chain exists yet (e.g. Bob before
     *   he has received a message). The exception type is kept as-is for caller compatibility.
     */
    fun encrypt(plaintext: ByteArray): RatchetMessage {
        val ck = sendChainKey ?: throw CryptoException.DecryptionFailed("Send chain not initialized")
        val (nextChainKey, messageKey) = HkdfUtils.kdfCk(ck)

        val header = RatchetHeader(
            dhPub = X25519Crypto.serializePublic(dhPublic),
            n = sendN,
            pn = prevSendN,
        )
        val (nonce, ctWithTag) = AesGcmCrypto.encryptCombined(
            plaintext = plaintext,
            key = messageKey,
            aad = header.serialize(),
        )

        // Commit only once the cipher succeeded: advancing the chain key without
        // incrementing sendN would make every later message undecryptable.
        sendChainKey = nextChainKey
        sendN++

        return RatchetMessage(header = header, ciphertext = ctWithTag, nonce = nonce)
    }

    /**
     * Decrypts a received message.
     *
     * Out-of-order messages are served from stored skipped keys. A stored key is removed only
     * after it has successfully authenticated the message, so a forged message reusing a
     * header cannot destroy the key needed for the genuine one.
     *
     * Otherwise, a DH ratchet step is performed when the header's public key differs from the
     * current remote key. Intermediate message keys are then stored and the next
     * receiving-chain key is used. Any failure on this path restores the full ratchet state.
     *
     * @throws CryptoException on failure. Existing CryptoException instances propagate
     *   unchanged; any other exception is wrapped in [CryptoException.DecryptionFailed].
     */
    fun decrypt(header: RatchetHeader, ciphertext: ByteArray, nonce: ByteArray): ByteArray {
        val aad = header.serialize()
        val remoteHex = header.dhPub.toHex()

        // 1) Previously skipped message: ratchet state is not touched.
        val skippedKey = makeSkippedKey(remoteHex, header.n)
        skipped[skippedKey]?.let { messageKey ->
            val plaintext = try {
                AesGcmCrypto.decryptCombined(
                    key = messageKey,
                    nonce = nonce,
                    ctWithTag = ciphertext,
                    aad = aad,
                )
            } catch (e: Exception) {
                throw toCryptoFailure(e)
            }
            // Consume the key only after the message authenticated.
            skipped.remove(skippedKey)
            return plaintext
        }

        // 2) Current or future message: may advance the ratchet, so snapshot for rollback.
        val snapshot = snapshot()
        try {
            val currentRemoteHex = dhRemote?.let { X25519Crypto.serializePublic(it).toHex() }
            if (remoteHex != currentRemoteHex) {
                // Store the remaining keys of the old receiving chain, then ratchet forward.
                skipMessages(header.pn)
                dhRatchet(X25519Crypto.loadPublic(header.dhPub))
            }
            skipMessages(header.n)

            val ck = recvChainKey ?: throw CryptoException.DecryptionFailed("Receive chain not initialized")
            val (nextChainKey, messageKey) = HkdfUtils.kdfCk(ck)
            recvChainKey = nextChainKey
            recvN++

            return AesGcmCrypto.decryptCombined(
                key = messageKey,
                nonce = nonce,
                ctWithTag = ciphertext,
                aad = aad,
            )
        } catch (e: Exception) {
            // Rollback on any failure
            restore(snapshot)
            throw toCryptoFailure(e)
        }
    }

    /**
     * Exports the full ratchet state as JSON bytes that [importState] can read back.
     *
     * The output contains secret material as plain hex: `dh_priv`, `root_key`, the chain keys
     * and all skipped message keys. Callers are responsible for protecting it at rest.
     * Null fields (`dh_remote`, `send_ck`, `recv_ck`) are passed to `JSONObject.put` as null.
     * [importState] accepts both an absent key and an explicit JSON null for them.
     */
    fun exportState(): ByteArray {
        val json = JSONObject()
        json.put("dh_priv", X25519Crypto.serializePrivate(dhPrivate).toHex())
        json.put("dh_pub", X25519Crypto.serializePublic(dhPublic).toHex())
        json.put("dh_remote", dhRemote?.let { X25519Crypto.serializePublic(it).toHex() })
        json.put("root_key", rootKey.toHex())
        json.put("send_ck", sendChainKey?.toHex())
        json.put("recv_ck", recvChainKey?.toHex())
        json.put("send_n", sendN)
        json.put("recv_n", recvN)
        json.put("prev_send_n", prevSendN)

        val skippedJson = JSONObject()
        for ((key, value) in skipped) {
            skippedJson.put(key, value.toHex())
        }
        json.put("skipped", skippedJson)

        return json.toString().toByteArray()
    }

    // --- Private helpers ---

    /**
     * Derives and stores message keys of the current receiving chain for message numbers
     * from recvN up to [until] (exclusive), keyed by the current remote ratchet key.
     * Does nothing when no receiving chain exists yet.
     *
     * @throws CryptoException.MaxSkipExceeded if more than MAX_SKIP keys would be derived
     */
    private fun skipMessages(until: Int) {
        var ck = recvChainKey ?: return
        if (until - recvN > MAX_SKIP) {
            throw CryptoException.MaxSkipExceeded("Cannot skip more than $MAX_SKIP messages")
        }

        // Loop-invariant: every key skipped here belongs to the same remote ratchet key.
        val remoteHex = dhRemote?.let { X25519Crypto.serializePublic(it).toHex() } ?: ""
        while (recvN < until) {
            val (nextCk, messageKey) = HkdfUtils.kdfCk(ck)
            ck = nextCk
            skipped[makeSkippedKey(remoteHex, recvN)] = messageKey
            recvN++
        }
        recvChainKey = ck
    }

    /**
     * DH ratchet step for a new remote ratchet key.
     *
     * Derives a new receiving chain from the current key pair, then generates a fresh key
     * pair and derives a new sending chain. Counters are reset, and the old sending count is
     * kept in prevSendN.
     */
    private fun dhRatchet(remotePublic: X25519PublicKeyParameters) {
        prevSendN = sendN
        sendN = 0
        recvN = 0
        dhRemote = remotePublic

        // Derive receive chain
        val dhOutput1 = X25519Crypto.dh(dhPrivate, remotePublic)
        val (rk1, recvCk) = HkdfUtils.kdfRk(rootKey, dhOutput1)
        rootKey = rk1
        recvChainKey = recvCk

        // Generate new DH keypair and derive send chain
        val (newPriv, newPub) = X25519Crypto.generateKeypair()
        dhPrivate = newPriv
        dhPublic = newPub

        val dhOutput2 = X25519Crypto.dh(newPriv, remotePublic)
        val (rk2, sendCk) = HkdfUtils.kdfRk(rootKey, dhOutput2)
        rootKey = rk2
        sendChainKey = sendCk
    }

    /** Passes CryptoException through unchanged and wraps anything else in DecryptionFailed. */
    private fun toCryptoFailure(e: Exception): Throwable =
        if (e is CryptoException) e else CryptoException.DecryptionFailed("Decryption failed", e)

    /**
     * Copy of the complete mutable state, used to roll back a failed [decrypt].
     * A plain class (not a data class) because structural equality over arrays is not meaningful.
     */
    private class Snapshot(
        val dhPriv: ByteArray,
        val dhPub: ByteArray,
        val dhRemote: ByteArray?,
        val rootKey: ByteArray,
        val sendCk: ByteArray?,
        val recvCk: ByteArray?,
        val sendN: Int,
        val recvN: Int,
        val prevSendN: Int,
        val skipped: Map<String, ByteArray>,
    )

    private fun snapshot(): Snapshot = Snapshot(
        dhPriv = X25519Crypto.serializePrivate(dhPrivate),
        dhPub = X25519Crypto.serializePublic(dhPublic),
        dhRemote = dhRemote?.let { X25519Crypto.serializePublic(it) },
        rootKey = rootKey.copyOf(),
        sendCk = sendChainKey?.copyOf(),
        recvCk = recvChainKey?.copyOf(),
        sendN = sendN,
        recvN = recvN,
        prevSendN = prevSendN,
        skipped = skipped.toMap(),
    )

    private fun restore(s: Snapshot) {
        dhPrivate = X25519Crypto.loadPrivate(s.dhPriv)
        dhPublic = X25519Crypto.loadPublic(s.dhPub)
        dhRemote = s.dhRemote?.let { X25519Crypto.loadPublic(it) }
        rootKey = s.rootKey
        sendChainKey = s.sendCk
        recvChainKey = s.recvCk
        sendN = s.sendN
        recvN = s.recvN
        prevSendN = s.prevSendN
        skipped.clear()
        skipped.putAll(s.skipped)
    }

    private fun makeSkippedKey(dhHex: String, n: Int): String = "$dhHex:$n"

    companion object {
        private const val MAX_SKIP = 256

        /**
         * Initialize as Alice (initiator), after X3DH produced a shared secret.
         *
         * Generates Alice's first ratchet key pair and immediately performs a DH step against
         * Bob's signed pre-key, so Alice can send before receiving anything.
         *
         * @param sharedSecret X3DH shared secret
         * @param bobSpkPub Bob's signed pre-key public (used as initial remote ratchet key)
         */
        fun initAlice(sharedSecret: ByteArray, bobSpkPub: X25519PublicKeyParameters): DoubleRatchet {
            val ratchet = DoubleRatchet()

            // Generate initial ratchet keypair
            val (dhPriv, dhPub) = X25519Crypto.generateKeypair()
            ratchet.dhPrivate = dhPriv
            ratchet.dhPublic = dhPub
            ratchet.dhRemote = bobSpkPub

            // Initial DH ratchet step
            val dhOutput = X25519Crypto.dh(dhPriv, bobSpkPub)
            val (firstRootKey, firstSendChainKey) = HkdfUtils.kdfRk(sharedSecret, dhOutput)
            ratchet.rootKey = firstRootKey
            ratchet.sendChainKey = firstSendChainKey
            ratchet.recvChainKey = null
            ratchet.sendN = 0
            ratchet.recvN = 0
            ratchet.prevSendN = 0

            return ratchet
        }

        /**
         * Initialize as Bob (responder), using the signed pre-key pair as the initial ratchet key.
         *
         * Bob has no chain keys until Alice's first message triggers a DH ratchet step, so he
         * cannot [encrypt] before that.
         *
         * @param sharedSecret X3DH shared secret (copied; the caller may wipe its buffer)
         * @param spkPair Bob's signed pre-key pair (private, public)
         */
        fun initBob(
            sharedSecret: ByteArray,
            spkPair: Pair<X25519PrivateKeyParameters, X25519PublicKeyParameters>,
        ): DoubleRatchet {
            val ratchet = DoubleRatchet()
            ratchet.dhPrivate = spkPair.first
            ratchet.dhPublic = spkPair.second
            // Defensive copy so later changes to the caller's array cannot corrupt the root key.
            ratchet.rootKey = sharedSecret.copyOf()
            ratchet.sendChainKey = null
            ratchet.recvChainKey = null
            ratchet.sendN = 0
            ratchet.recvN = 0
            ratchet.prevSendN = 0
            return ratchet
        }

        /**
         * Import ratchet state from JSON bytes as written by [exportState].
         *
         * Optional fields (`dh_remote`, `send_ck`, `recv_ck`, `skipped`) may be absent or an
         * explicit JSON null. Empty `send_ck`/`recv_ck` strings are also treated as absent.
         */
        fun importState(data: ByteArray): DoubleRatchet {
            val json = JSONObject(String(data, Charsets.UTF_8))
            val ratchet = DoubleRatchet()

            ratchet.dhPrivate = X25519Crypto.loadPrivate(json.getString("dh_priv").hexToBytes())
            ratchet.dhPublic = X25519Crypto.loadPublic(json.getString("dh_pub").hexToBytes())

            // isNull() is true both for a missing key and for an explicit JSON null.
            if (!json.isNull("dh_remote")) {
                ratchet.dhRemote = X25519Crypto.loadPublic(json.getString("dh_remote").hexToBytes())
            }

            ratchet.rootKey = json.getString("root_key").hexToBytes()
            ratchet.sendChainKey = json.optHexBytes("send_ck")
            ratchet.recvChainKey = json.optHexBytes("recv_ck")
            ratchet.sendN = json.getInt("send_n")
            ratchet.recvN = json.getInt("recv_n")
            ratchet.prevSendN = json.getInt("prev_send_n")

            // Import skipped keys
            if (!json.isNull("skipped")) {
                val skippedJson = json.getJSONObject("skipped")
                for (key in skippedJson.keys()) {
                    ratchet.skipped[key] = skippedJson.getString(key).hexToBytes()
                }
            }

            return ratchet
        }

        /**
         * Reads an optional hex-encoded field. A missing key, an explicit JSON null and an
         * empty string all map to null.
         *
         * `optString` is deliberately avoided: Android's org.json returns the literal string
         * "null" for an explicit JSON null, which would then fail hex decoding.
         */
        private fun JSONObject.optHexBytes(name: String): ByteArray? =
            if (isNull(name)) null else getString(name).takeIf { it.isNotEmpty() }?.hexToBytes()
    }
}
|
|
|
|
/**
 * Header sent with every ratchet message: the sender's current ratchet public key
 * ([dhPub]), the message number within the sender's current chain ([n]) and the length
 * of the sender's previous sending chain ([pn]).
 *
 * [serialize] produces a JSON object with keys `dh_pub` (lowercase hex), `n` and `pn`.
 * [toMap] and [fromMap] use the same keys.
 * NOTE(review): DoubleRatchet uses the serialized bytes as AES-GCM associated data, so their
 * exact form (whitespace and key order as emitted by org.json's JSONObject.toString()) must
 * match the peer implementation byte-for-byte. Confirm against the Python side.
 *
 * equals/hashCode compare [dhPub] by content rather than by array identity.
 */
data class RatchetHeader(
    val dhPub: ByteArray,
    val n: Int,
    val pn: Int,
) {
    /** Encodes this header as JSON text and returns its bytes. */
    fun serialize(): ByteArray =
        JSONObject()
            .put("dh_pub", dhPub.toHex())
            .put("n", n)
            .put("pn", pn)
            .toString()
            .toByteArray()

    /** Map form of the header, keyed like the JSON produced by [serialize]. */
    fun toMap(): Map<String, Any> {
        val hexKey = dhPub.toHex()
        return mapOf("dh_pub" to hexKey, "n" to n, "pn" to pn)
    }

    override fun equals(other: Any?): Boolean = when {
        this === other -> true
        other !is RatchetHeader -> false
        else -> dhPub.contentEquals(other.dhPub) && n == other.n && pn == other.pn
    }

    override fun hashCode(): Int = 31 * (31 * dhPub.contentHashCode() + n) + pn

    companion object {
        /**
         * Rebuilds a header from the map shape produced by [toMap].
         * Missing or wrongly-typed entries make the casts throw.
         */
        fun fromMap(map: Map<String, Any>): RatchetHeader {
            val publicKey = (map["dh_pub"] as String).hexToBytes()
            val messageNumber = (map["n"] as Number).toInt()
            val previousChainLength = (map["pn"] as Number).toInt()
            return RatchetHeader(dhPub = publicKey, n = messageNumber, pn = previousChainLength)
        }
    }
}
|
|
|
|
/**
 * An encrypted ratchet message: the [header], the AES-GCM [ciphertext] including its
 * authentication tag (the `ctWithTag` value from encryption), and the [nonce].
 *
 * equals/hashCode compare the byte arrays by content rather than by identity.
 */
data class RatchetMessage(
    val header: RatchetHeader,
    val ciphertext: ByteArray,
    val nonce: ByteArray,
) {
    override fun equals(other: Any?): Boolean = when {
        this === other -> true
        other !is RatchetMessage -> false
        else -> header == other.header &&
            ciphertext.contentEquals(other.ciphertext) &&
            nonce.contentEquals(other.nonce)
    }

    override fun hashCode(): Int =
        (header.hashCode() * 31 + ciphertext.contentHashCode()) * 31 + nonce.contentHashCode()
}
|
|
|
|
// --- Hex extension functions ---
|
|
|
|
/** Encodes the bytes as lowercase hex, two digits per byte (e.g. -1 -> "ff"). */
internal fun ByteArray.toHex(): String {
    val digits = "0123456789abcdef"
    val out = StringBuilder(size * 2)
    for (b in this) {
        val unsigned = b.toInt() and 0xFF
        out.append(digits[unsigned ushr 4]).append(digits[unsigned and 0x0F])
    }
    return out.toString()
}
|
|
|
|
/**
 * Decodes a hex string (ASCII digits 0-9, a-f, A-F; no prefix, sign or separators) into bytes.
 *
 * Strict on purpose: this also decodes untrusted header fields, and `Integer.parseInt`-style
 * parsing would silently accept a leading '+'/'-' or non-ASCII Unicode digits.
 *
 * @throws IllegalArgumentException if the length is odd
 * @throws NumberFormatException if any character is not an ASCII hex digit
 */
internal fun String.hexToBytes(): ByteArray {
    require(length % 2 == 0) { "Hex string must have even length" }

    // Value of the hex digit at [index], rejecting anything outside the ASCII hex alphabet.
    fun nibble(index: Int): Int {
        val c = this[index]
        return when (c) {
            in '0'..'9' -> c - '0'
            in 'a'..'f' -> c - 'a' + 10
            in 'A'..'F' -> c - 'A' + 10
            else -> throw NumberFormatException("Invalid hex character '$c' at index $index")
        }
    }

    return ByteArray(length / 2) { i -> ((nibble(2 * i) shl 4) or nibble(2 * i + 1)).toByte() }
}
|