Initial commit: Kecalek Android client
Complete Android client for encrypted chat platform. 78+ Kotlin files: crypto (X3DH, Double Ratchet, AES-GCM, Ed25519, X25519, RSA-PSS), network (TCP/TLS, 50 endpoints), Hilt DI, Room+SQLCipher DB, Jetpack Compose UI with Catppuccin Mocha theme. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
396
app/src/main/java/com/kecalek/chat/crypto/DoubleRatchet.kt
Normal file
396
app/src/main/java/com/kecalek/chat/crypto/DoubleRatchet.kt
Normal file
@@ -0,0 +1,396 @@
|
||||
package com.kecalek.chat.crypto
|
||||
|
||||
import org.bouncycastle.crypto.params.X25519PrivateKeyParameters
|
||||
import org.bouncycastle.crypto.params.X25519PublicKeyParameters
|
||||
import org.json.JSONObject
|
||||
import java.security.SecureRandom
|
||||
|
||||
/**
|
||||
* Double Ratchet algorithm for end-to-end encrypted messaging.
|
||||
* Provides forward secrecy and break-in recovery.
|
||||
*
|
||||
* Compatible with Python DoubleRatchet class from crypto_utils.py.
|
||||
*
|
||||
* State:
|
||||
* - dh_pair: current ratchet X25519 keypair
|
||||
* - dh_remote: remote's current ratchet public key
|
||||
* - root_key: 32-byte root key
|
||||
* - send_chain_key / recv_chain_key: current chain keys
|
||||
* - send_n / recv_n: message counters
|
||||
* - prev_send_n: previous sending chain length
|
||||
* - skipped: map of (dh_hex, n) -> message_key for out-of-order delivery
|
||||
*/
|
||||
class DoubleRatchet private constructor() {
|
||||
|
||||
private lateinit var dhPrivate: X25519PrivateKeyParameters
|
||||
private lateinit var dhPublic: X25519PublicKeyParameters
|
||||
private var dhRemote: X25519PublicKeyParameters? = null
|
||||
private lateinit var rootKey: ByteArray
|
||||
private var sendChainKey: ByteArray? = null
|
||||
private var recvChainKey: ByteArray? = null
|
||||
private var sendN: Int = 0
|
||||
private var recvN: Int = 0
|
||||
private var prevSendN: Int = 0
|
||||
|
||||
// skipped[(remotePublicHex, messageNumber)] = messageKey
|
||||
private val skipped = mutableMapOf<String, ByteArray>()
|
||||
|
||||
companion object {
|
||||
private const val MAX_SKIP = 256
|
||||
|
||||
/**
|
||||
* Initialize as Alice (initiator).
|
||||
* Called after X3DH produces a shared secret.
|
||||
*
|
||||
* @param sharedSecret X3DH shared secret
|
||||
* @param bobSpkPub Bob's signed pre-key public (used as initial remote ratchet key)
|
||||
*/
|
||||
fun initAlice(sharedSecret: ByteArray, bobSpkPub: X25519PublicKeyParameters): DoubleRatchet {
|
||||
val ratchet = DoubleRatchet()
|
||||
|
||||
// Generate initial ratchet keypair
|
||||
val (dhPriv, dhPub) = X25519Crypto.generateKeypair()
|
||||
ratchet.dhPrivate = dhPriv
|
||||
ratchet.dhPublic = dhPub
|
||||
ratchet.dhRemote = bobSpkPub
|
||||
|
||||
// Initial DH ratchet step
|
||||
val dhOutput = X25519Crypto.dh(dhPriv, bobSpkPub)
|
||||
val (newRootKey, sendChainKey) = HkdfUtils.kdfRk(sharedSecret, dhOutput)
|
||||
ratchet.rootKey = newRootKey
|
||||
ratchet.sendChainKey = sendChainKey
|
||||
ratchet.recvChainKey = null
|
||||
ratchet.sendN = 0
|
||||
ratchet.recvN = 0
|
||||
ratchet.prevSendN = 0
|
||||
|
||||
return ratchet
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialize as Bob (responder).
|
||||
* Uses SPK pair as initial ratchet key.
|
||||
*
|
||||
* @param sharedSecret X3DH shared secret
|
||||
* @param spkPair Bob's signed pre-key pair (private, public)
|
||||
*/
|
||||
fun initBob(
|
||||
sharedSecret: ByteArray,
|
||||
spkPair: Pair<X25519PrivateKeyParameters, X25519PublicKeyParameters>,
|
||||
): DoubleRatchet {
|
||||
val ratchet = DoubleRatchet()
|
||||
ratchet.dhPrivate = spkPair.first
|
||||
ratchet.dhPublic = spkPair.second
|
||||
ratchet.rootKey = sharedSecret
|
||||
ratchet.sendChainKey = null
|
||||
ratchet.recvChainKey = null
|
||||
ratchet.sendN = 0
|
||||
ratchet.recvN = 0
|
||||
ratchet.prevSendN = 0
|
||||
return ratchet
|
||||
}
|
||||
|
||||
/**
|
||||
* Import ratchet state from JSON bytes.
|
||||
*/
|
||||
fun importState(data: ByteArray): DoubleRatchet {
|
||||
val json = JSONObject(String(data))
|
||||
val ratchet = DoubleRatchet()
|
||||
|
||||
ratchet.dhPrivate = X25519Crypto.loadPrivate(json.getString("dh_priv").hexToBytes())
|
||||
ratchet.dhPublic = X25519Crypto.loadPublic(json.getString("dh_pub").hexToBytes())
|
||||
|
||||
if (json.has("dh_remote") && !json.isNull("dh_remote")) {
|
||||
ratchet.dhRemote = X25519Crypto.loadPublic(json.getString("dh_remote").hexToBytes())
|
||||
}
|
||||
|
||||
ratchet.rootKey = json.getString("root_key").hexToBytes()
|
||||
ratchet.sendChainKey = json.optString("send_ck", "").takeIf { it.isNotEmpty() }?.hexToBytes()
|
||||
ratchet.recvChainKey = json.optString("recv_ck", "").takeIf { it.isNotEmpty() }?.hexToBytes()
|
||||
ratchet.sendN = json.getInt("send_n")
|
||||
ratchet.recvN = json.getInt("recv_n")
|
||||
ratchet.prevSendN = json.getInt("prev_send_n")
|
||||
|
||||
// Import skipped keys
|
||||
if (json.has("skipped")) {
|
||||
val skippedJson = json.getJSONObject("skipped")
|
||||
for (key in skippedJson.keys()) {
|
||||
ratchet.skipped[key] = skippedJson.getString(key).hexToBytes()
|
||||
}
|
||||
}
|
||||
|
||||
return ratchet
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Encrypt plaintext message.
|
||||
* @return RatchetMessage with header dict, ciphertext+tag, nonce
|
||||
*/
|
||||
fun encrypt(plaintext: ByteArray): RatchetMessage {
|
||||
val ck = sendChainKey ?: throw CryptoException.DecryptionFailed("Send chain not initialized")
|
||||
|
||||
val (newChainKey, messageKey) = HkdfUtils.kdfCk(ck)
|
||||
sendChainKey = newChainKey
|
||||
|
||||
val header = RatchetHeader(
|
||||
dhPub = X25519Crypto.serializePublic(dhPublic),
|
||||
n = sendN,
|
||||
pn = prevSendN,
|
||||
)
|
||||
|
||||
val aad = header.serialize()
|
||||
val (nonce, ctWithTag) = AesGcmCrypto.encryptCombined(
|
||||
plaintext = plaintext,
|
||||
key = messageKey,
|
||||
aad = aad,
|
||||
)
|
||||
|
||||
sendN++
|
||||
|
||||
return RatchetMessage(
|
||||
header = header,
|
||||
ciphertext = ctWithTag,
|
||||
nonce = nonce,
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Decrypt received message.
|
||||
* Handles out-of-order delivery via skipped message keys.
|
||||
* Full state rollback on failure.
|
||||
*/
|
||||
fun decrypt(header: RatchetHeader, ciphertext: ByteArray, nonce: ByteArray): ByteArray {
|
||||
val aad = header.serialize()
|
||||
|
||||
// Check skipped message keys first (no state change)
|
||||
val skippedKey = makeSkippedKey(header.dhPub.toHex(), header.n)
|
||||
skipped.remove(skippedKey)?.let { messageKey ->
|
||||
return AesGcmCrypto.decryptCombined(
|
||||
key = messageKey,
|
||||
nonce = nonce,
|
||||
ctWithTag = ciphertext,
|
||||
aad = aad,
|
||||
)
|
||||
}
|
||||
|
||||
// Take snapshot for rollback
|
||||
val snapshot = snapshot()
|
||||
|
||||
try {
|
||||
// New DH ratchet step if remote key changed
|
||||
val remoteHex = header.dhPub.toHex()
|
||||
val currentRemoteHex = dhRemote?.let { X25519Crypto.serializePublic(it).toHex() }
|
||||
|
||||
if (remoteHex != currentRemoteHex) {
|
||||
skipMessages(header.pn)
|
||||
dhRatchet(X25519Crypto.loadPublic(header.dhPub))
|
||||
}
|
||||
|
||||
skipMessages(header.n)
|
||||
|
||||
val ck = recvChainKey ?: throw CryptoException.DecryptionFailed("Receive chain not initialized")
|
||||
val (newChainKey, messageKey) = HkdfUtils.kdfCk(ck)
|
||||
recvChainKey = newChainKey
|
||||
recvN++
|
||||
|
||||
return AesGcmCrypto.decryptCombined(
|
||||
key = messageKey,
|
||||
nonce = nonce,
|
||||
ctWithTag = ciphertext,
|
||||
aad = aad,
|
||||
)
|
||||
} catch (e: Exception) {
|
||||
// Rollback on any failure
|
||||
restore(snapshot)
|
||||
throw if (e is CryptoException) e
|
||||
else CryptoException.DecryptionFailed("Decryption failed", e)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Export full ratchet state as JSON bytes.
|
||||
*/
|
||||
fun exportState(): ByteArray {
|
||||
val json = JSONObject()
|
||||
json.put("dh_priv", X25519Crypto.serializePrivate(dhPrivate).toHex())
|
||||
json.put("dh_pub", X25519Crypto.serializePublic(dhPublic).toHex())
|
||||
json.put("dh_remote", dhRemote?.let { X25519Crypto.serializePublic(it).toHex() })
|
||||
json.put("root_key", rootKey.toHex())
|
||||
json.put("send_ck", sendChainKey?.toHex())
|
||||
json.put("recv_ck", recvChainKey?.toHex())
|
||||
json.put("send_n", sendN)
|
||||
json.put("recv_n", recvN)
|
||||
json.put("prev_send_n", prevSendN)
|
||||
|
||||
val skippedJson = JSONObject()
|
||||
for ((key, value) in skipped) {
|
||||
skippedJson.put(key, value.toHex())
|
||||
}
|
||||
json.put("skipped", skippedJson)
|
||||
|
||||
return json.toString().toByteArray()
|
||||
}
|
||||
|
||||
// --- Private helpers ---
|
||||
|
||||
private fun skipMessages(until: Int) {
|
||||
if (recvChainKey == null) return
|
||||
if (until - recvN > MAX_SKIP) {
|
||||
throw CryptoException.MaxSkipExceeded("Cannot skip more than $MAX_SKIP messages")
|
||||
}
|
||||
|
||||
var ck = recvChainKey!!
|
||||
while (recvN < until) {
|
||||
val (newCk, messageKey) = HkdfUtils.kdfCk(ck)
|
||||
ck = newCk
|
||||
val remoteHex = dhRemote?.let { X25519Crypto.serializePublic(it).toHex() } ?: ""
|
||||
skipped[makeSkippedKey(remoteHex, recvN)] = messageKey
|
||||
recvN++
|
||||
}
|
||||
recvChainKey = ck
|
||||
}
|
||||
|
||||
private fun dhRatchet(remotePublic: X25519PublicKeyParameters) {
|
||||
prevSendN = sendN
|
||||
sendN = 0
|
||||
recvN = 0
|
||||
dhRemote = remotePublic
|
||||
|
||||
// Derive receive chain
|
||||
val dhOutput1 = X25519Crypto.dh(dhPrivate, remotePublic)
|
||||
val (rk1, recvCk) = HkdfUtils.kdfRk(rootKey, dhOutput1)
|
||||
rootKey = rk1
|
||||
recvChainKey = recvCk
|
||||
|
||||
// Generate new DH keypair and derive send chain
|
||||
val (newPriv, newPub) = X25519Crypto.generateKeypair()
|
||||
dhPrivate = newPriv
|
||||
dhPublic = newPub
|
||||
|
||||
val dhOutput2 = X25519Crypto.dh(newPriv, remotePublic)
|
||||
val (rk2, sendCk) = HkdfUtils.kdfRk(rootKey, dhOutput2)
|
||||
rootKey = rk2
|
||||
sendChainKey = sendCk
|
||||
}
|
||||
|
||||
private data class Snapshot(
|
||||
val dhPriv: ByteArray,
|
||||
val dhPub: ByteArray,
|
||||
val dhRemote: ByteArray?,
|
||||
val rootKey: ByteArray,
|
||||
val sendCk: ByteArray?,
|
||||
val recvCk: ByteArray?,
|
||||
val sendN: Int,
|
||||
val recvN: Int,
|
||||
val prevSendN: Int,
|
||||
val skipped: Map<String, ByteArray>,
|
||||
)
|
||||
|
||||
private fun snapshot(): Snapshot {
|
||||
return Snapshot(
|
||||
dhPriv = X25519Crypto.serializePrivate(dhPrivate),
|
||||
dhPub = X25519Crypto.serializePublic(dhPublic),
|
||||
dhRemote = dhRemote?.let { X25519Crypto.serializePublic(it) },
|
||||
rootKey = rootKey.copyOf(),
|
||||
sendCk = sendChainKey?.copyOf(),
|
||||
recvCk = recvChainKey?.copyOf(),
|
||||
sendN = sendN,
|
||||
recvN = recvN,
|
||||
prevSendN = prevSendN,
|
||||
skipped = skipped.toMap(),
|
||||
)
|
||||
}
|
||||
|
||||
private fun restore(s: Snapshot) {
|
||||
dhPrivate = X25519Crypto.loadPrivate(s.dhPriv)
|
||||
dhPublic = X25519Crypto.loadPublic(s.dhPub)
|
||||
dhRemote = s.dhRemote?.let { X25519Crypto.loadPublic(it) }
|
||||
rootKey = s.rootKey
|
||||
sendChainKey = s.sendCk
|
||||
recvChainKey = s.recvCk
|
||||
sendN = s.sendN
|
||||
recvN = s.recvN
|
||||
prevSendN = s.prevSendN
|
||||
skipped.clear()
|
||||
skipped.putAll(s.skipped)
|
||||
}
|
||||
|
||||
private fun makeSkippedKey(dhHex: String, n: Int): String = "$dhHex:$n"
|
||||
}
|
||||
|
||||
/**
|
||||
* Ratchet message header.
|
||||
* Serialized as JSON: {"dh_pub": hex, "n": int, "pn": int}
|
||||
*/
|
||||
data class RatchetHeader(
|
||||
val dhPub: ByteArray,
|
||||
val n: Int,
|
||||
val pn: Int,
|
||||
) {
|
||||
fun serialize(): ByteArray {
|
||||
val json = JSONObject()
|
||||
json.put("dh_pub", dhPub.toHex())
|
||||
json.put("n", n)
|
||||
json.put("pn", pn)
|
||||
return json.toString().toByteArray()
|
||||
}
|
||||
|
||||
fun toMap(): Map<String, Any> = mapOf(
|
||||
"dh_pub" to dhPub.toHex(),
|
||||
"n" to n,
|
||||
"pn" to pn,
|
||||
)
|
||||
|
||||
companion object {
|
||||
fun fromMap(map: Map<String, Any>): RatchetHeader {
|
||||
return RatchetHeader(
|
||||
dhPub = (map["dh_pub"] as String).hexToBytes(),
|
||||
n = (map["n"] as Number).toInt(),
|
||||
pn = (map["pn"] as Number).toInt(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
override fun equals(other: Any?): Boolean {
|
||||
if (this === other) return true
|
||||
if (other !is RatchetHeader) return false
|
||||
return dhPub.contentEquals(other.dhPub) && n == other.n && pn == other.pn
|
||||
}
|
||||
|
||||
override fun hashCode(): Int {
|
||||
var result = dhPub.contentHashCode()
|
||||
result = 31 * result + n
|
||||
result = 31 * result + pn
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
data class RatchetMessage(
|
||||
val header: RatchetHeader,
|
||||
val ciphertext: ByteArray,
|
||||
val nonce: ByteArray,
|
||||
) {
|
||||
override fun equals(other: Any?): Boolean {
|
||||
if (this === other) return true
|
||||
if (other !is RatchetMessage) return false
|
||||
return header == other.header && ciphertext.contentEquals(other.ciphertext) &&
|
||||
nonce.contentEquals(other.nonce)
|
||||
}
|
||||
|
||||
override fun hashCode(): Int {
|
||||
var result = header.hashCode()
|
||||
result = 31 * result + ciphertext.contentHashCode()
|
||||
result = 31 * result + nonce.contentHashCode()
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
// --- Hex extension functions ---
|
||||
|
||||
internal fun ByteArray.toHex(): String = joinToString("") { "%02x".format(it) }
|
||||
|
||||
internal fun String.hexToBytes(): ByteArray {
|
||||
require(length % 2 == 0) { "Hex string must have even length" }
|
||||
return chunked(2).map { it.toInt(16).toByte() }.toByteArray()
|
||||
}
|
||||
Reference in New Issue
Block a user