Initial commit: Kecalek Android client
Complete Android client for encrypted chat platform. 78+ Kotlin files: crypto (X3DH, Double Ratchet, AES-GCM, Ed25519, X25519, RSA-PSS), network (TCP/TLS, 50 endpoints), Hilt DI, Room+SQLCipher DB, Jetpack Compose UI with Catppuccin Mocha theme. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
156
app/src/main/java/com/kecalek/chat/crypto/AesGcmCrypto.kt
Normal file
156
app/src/main/java/com/kecalek/chat/crypto/AesGcmCrypto.kt
Normal file
@@ -0,0 +1,156 @@
|
||||
package com.kecalek.chat.crypto
|
||||
|
||||
import java.security.SecureRandom
|
||||
import javax.crypto.Cipher
|
||||
import javax.crypto.spec.GCMParameterSpec
|
||||
import javax.crypto.spec.SecretKeySpec
|
||||
|
||||
/**
|
||||
* AES-256-GCM encryption/decryption.
|
||||
* Nonce: 12 bytes (96 bits), Tag: 128 bits.
|
||||
* Compatible with Python's AESGCM from cryptography library.
|
||||
*/
|
||||
object AesGcmCrypto {
|
||||
|
||||
private const val KEY_SIZE = 32
|
||||
private const val NONCE_SIZE = 12
|
||||
private const val TAG_BITS = 128
|
||||
private const val ALGORITHM = "AES/GCM/NoPadding"
|
||||
|
||||
private val secureRandom = SecureRandom()
|
||||
|
||||
/**
|
||||
* Encrypt plaintext with AES-256-GCM.
|
||||
* @param plaintext data to encrypt
|
||||
* @param key 32-byte AES key (generated if null)
|
||||
* @param aad optional additional authenticated data
|
||||
* @return AesGcmResult with key, nonce, ciphertext (without tag), tag (16 bytes)
|
||||
*/
|
||||
fun encrypt(
|
||||
plaintext: ByteArray,
|
||||
key: ByteArray? = null,
|
||||
aad: ByteArray? = null,
|
||||
): AesGcmResult {
|
||||
val aesKey = key ?: generateKey()
|
||||
require(aesKey.size == KEY_SIZE) { "AES key must be $KEY_SIZE bytes" }
|
||||
|
||||
val nonce = ByteArray(NONCE_SIZE).also { secureRandom.nextBytes(it) }
|
||||
val cipher = Cipher.getInstance(ALGORITHM)
|
||||
cipher.init(
|
||||
Cipher.ENCRYPT_MODE,
|
||||
SecretKeySpec(aesKey, "AES"),
|
||||
GCMParameterSpec(TAG_BITS, nonce),
|
||||
)
|
||||
if (aad != null) cipher.updateAAD(aad)
|
||||
|
||||
// Java GCM appends tag to ciphertext
|
||||
val ctWithTag = cipher.doFinal(plaintext)
|
||||
val ciphertext = ctWithTag.copyOfRange(0, ctWithTag.size - 16)
|
||||
val tag = ctWithTag.copyOfRange(ctWithTag.size - 16, ctWithTag.size)
|
||||
|
||||
return AesGcmResult(aesKey, nonce, ciphertext, tag)
|
||||
}
|
||||
|
||||
/**
|
||||
* Decrypt AES-256-GCM ciphertext.
|
||||
* @param key 32-byte AES key
|
||||
* @param nonce 12-byte nonce
|
||||
* @param ciphertext encrypted data (without tag)
|
||||
* @param tag 16-byte authentication tag
|
||||
* @param aad optional additional authenticated data
|
||||
* @return decrypted plaintext
|
||||
*/
|
||||
fun decrypt(
|
||||
key: ByteArray,
|
||||
nonce: ByteArray,
|
||||
ciphertext: ByteArray,
|
||||
tag: ByteArray,
|
||||
aad: ByteArray? = null,
|
||||
): ByteArray {
|
||||
require(key.size == KEY_SIZE) { "AES key must be $KEY_SIZE bytes" }
|
||||
require(nonce.size == NONCE_SIZE) { "Nonce must be $NONCE_SIZE bytes" }
|
||||
require(tag.size == 16) { "Tag must be 16 bytes" }
|
||||
|
||||
val cipher = Cipher.getInstance(ALGORITHM)
|
||||
cipher.init(
|
||||
Cipher.DECRYPT_MODE,
|
||||
SecretKeySpec(key, "AES"),
|
||||
GCMParameterSpec(TAG_BITS, nonce),
|
||||
)
|
||||
if (aad != null) cipher.updateAAD(aad)
|
||||
|
||||
// Java expects ciphertext + tag concatenated
|
||||
val ctWithTag = ciphertext + tag
|
||||
return cipher.doFinal(ctWithTag)
|
||||
}
|
||||
|
||||
/**
|
||||
* Encrypt returning ciphertext+tag combined (for internal use by ECP1, Double Ratchet).
|
||||
*/
|
||||
fun encryptCombined(
|
||||
plaintext: ByteArray,
|
||||
key: ByteArray,
|
||||
aad: ByteArray? = null,
|
||||
): Pair<ByteArray, ByteArray> {
|
||||
require(key.size == KEY_SIZE) { "AES key must be $KEY_SIZE bytes" }
|
||||
|
||||
val nonce = ByteArray(NONCE_SIZE).also { secureRandom.nextBytes(it) }
|
||||
val cipher = Cipher.getInstance(ALGORITHM)
|
||||
cipher.init(
|
||||
Cipher.ENCRYPT_MODE,
|
||||
SecretKeySpec(key, "AES"),
|
||||
GCMParameterSpec(TAG_BITS, nonce),
|
||||
)
|
||||
if (aad != null) cipher.updateAAD(aad)
|
||||
|
||||
val ctWithTag = cipher.doFinal(plaintext)
|
||||
return Pair(nonce, ctWithTag)
|
||||
}
|
||||
|
||||
/**
|
||||
* Decrypt ciphertext+tag combined.
|
||||
*/
|
||||
fun decryptCombined(
|
||||
key: ByteArray,
|
||||
nonce: ByteArray,
|
||||
ctWithTag: ByteArray,
|
||||
aad: ByteArray? = null,
|
||||
): ByteArray {
|
||||
require(key.size == KEY_SIZE) { "AES key must be $KEY_SIZE bytes" }
|
||||
require(nonce.size == NONCE_SIZE) { "Nonce must be $NONCE_SIZE bytes" }
|
||||
|
||||
val cipher = Cipher.getInstance(ALGORITHM)
|
||||
cipher.init(
|
||||
Cipher.DECRYPT_MODE,
|
||||
SecretKeySpec(key, "AES"),
|
||||
GCMParameterSpec(TAG_BITS, nonce),
|
||||
)
|
||||
if (aad != null) cipher.updateAAD(aad)
|
||||
|
||||
return cipher.doFinal(ctWithTag)
|
||||
}
|
||||
|
||||
fun generateKey(): ByteArray = ByteArray(KEY_SIZE).also { secureRandom.nextBytes(it) }
|
||||
}
|
||||
|
||||
data class AesGcmResult(
|
||||
val key: ByteArray,
|
||||
val nonce: ByteArray,
|
||||
val ciphertext: ByteArray,
|
||||
val tag: ByteArray,
|
||||
) {
|
||||
override fun equals(other: Any?): Boolean {
|
||||
if (this === other) return true
|
||||
if (other !is AesGcmResult) return false
|
||||
return key.contentEquals(other.key) && nonce.contentEquals(other.nonce) &&
|
||||
ciphertext.contentEquals(other.ciphertext) && tag.contentEquals(other.tag)
|
||||
}
|
||||
|
||||
override fun hashCode(): Int {
|
||||
var result = key.contentHashCode()
|
||||
result = 31 * result + nonce.contentHashCode()
|
||||
result = 31 * result + ciphertext.contentHashCode()
|
||||
result = 31 * result + tag.contentHashCode()
|
||||
return result
|
||||
}
|
||||
}
|
||||
130
app/src/main/java/com/kecalek/chat/crypto/ContactVerification.kt
Normal file
130
app/src/main/java/com/kecalek/chat/crypto/ContactVerification.kt
Normal file
@@ -0,0 +1,130 @@
|
||||
package com.kecalek.chat.crypto
|
||||
|
||||
import java.math.BigInteger
|
||||
import java.nio.ByteBuffer
|
||||
import java.security.MessageDigest
|
||||
|
||||
/**
|
||||
* Contact key verification: fingerprints, safety numbers, QR code payloads.
|
||||
* Compatible with Python compute_fingerprint, compute_safety_number,
|
||||
* encode_verification_qr, decode_verification_qr.
|
||||
*
|
||||
* Fingerprint: SHA-512 iterated 5200x over (version + identity_key + user_id).
|
||||
* Safety number: 60 digits (12 groups of 5), derived from both users' fingerprints.
|
||||
*/
|
||||
object ContactVerification {
|
||||
|
||||
private const val FINGERPRINT_VERSION = 0
|
||||
private const val FINGERPRINT_ITERATIONS = 5200
|
||||
private const val QR_VERSION: Byte = 0x01
|
||||
|
||||
/**
|
||||
* Compute fingerprint for a user's identity key.
|
||||
* @param userId user ID string
|
||||
* @param identityKeyBytes 32-byte Ed25519 public key
|
||||
* @param iterations number of SHA-512 iterations (default 5200)
|
||||
* @return 32 bytes (first 32 of final SHA-512 hash)
|
||||
*/
|
||||
fun computeFingerprint(
|
||||
userId: String,
|
||||
identityKeyBytes: ByteArray,
|
||||
iterations: Int = FINGERPRINT_ITERATIONS,
|
||||
): ByteArray {
|
||||
// Seed: version(2B big-endian) + identity_key(32B) + user_id(UTF-8)
|
||||
val versionBytes = ByteBuffer.allocate(2).putShort(FINGERPRINT_VERSION.toShort()).array()
|
||||
val userIdBytes = userId.toByteArray(Charsets.UTF_8)
|
||||
var data = versionBytes + identityKeyBytes + userIdBytes
|
||||
|
||||
val digest = MessageDigest.getInstance("SHA-512")
|
||||
for (i in 0 until iterations) {
|
||||
digest.reset()
|
||||
digest.update(data)
|
||||
digest.update(identityKeyBytes)
|
||||
data = digest.digest()
|
||||
}
|
||||
|
||||
return data.copyOfRange(0, 32)
|
||||
}
|
||||
|
||||
/**
|
||||
* Format fingerprint bytes as 6 groups of 5 digits.
|
||||
* Each group: int.from_bytes(5 bytes, "big") % 100_000, zero-padded.
|
||||
* @return "XXXXX XXXXX XXXXX\nXXXXX XXXXX XXXXX"
|
||||
*/
|
||||
fun formatFingerprint(fpBytes: ByteArray): String {
|
||||
val groups = (0 until 6).map { i ->
|
||||
val chunk = fpBytes.copyOfRange(i * 5, (i + 1) * 5)
|
||||
val num = BigInteger(1, chunk).mod(BigInteger.valueOf(100_000)).toInt()
|
||||
"%05d".format(num)
|
||||
}
|
||||
return "${groups[0]} ${groups[1]} ${groups[2]}\n${groups[3]} ${groups[4]} ${groups[5]}"
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute safety number between two users.
|
||||
* Deterministic ordering: lower user_id fingerprint comes first.
|
||||
* @return 60 digits as "XXXXX XXXXX XXXXX XXXXX\n..." (3 lines of 4 groups)
|
||||
*/
|
||||
fun computeSafetyNumber(
|
||||
myUserId: String,
|
||||
myIdentityKey: ByteArray,
|
||||
theirUserId: String,
|
||||
theirIdentityKey: ByteArray,
|
||||
): String {
|
||||
val myFp = computeFingerprint(myUserId, myIdentityKey)
|
||||
val theirFp = computeFingerprint(theirUserId, theirIdentityKey)
|
||||
|
||||
// Deterministic ordering: lower user_id first
|
||||
val combined = if (myUserId < theirUserId) {
|
||||
myFp + theirFp
|
||||
} else {
|
||||
theirFp + myFp
|
||||
}
|
||||
|
||||
// 12 groups of 5 digits from 64 bytes
|
||||
val groups = (0 until 12).map { i ->
|
||||
val chunk = combined.copyOfRange(i * 5, (i + 1) * 5)
|
||||
val num = BigInteger(1, chunk).mod(BigInteger.valueOf(100_000)).toInt()
|
||||
"%05d".format(num)
|
||||
}
|
||||
|
||||
return "${groups[0]} ${groups[1]} ${groups[2]} ${groups[3]}\n" +
|
||||
"${groups[4]} ${groups[5]} ${groups[6]} ${groups[7]}\n" +
|
||||
"${groups[8]} ${groups[9]} ${groups[10]} ${groups[11]}"
|
||||
}
|
||||
|
||||
/**
|
||||
* Encode verification QR code payload.
|
||||
* Format: 0x01 + uid_len(1B) + uid(UTF-8) + identity_key(32B)
|
||||
*/
|
||||
fun encodeVerificationQR(userId: String, identityKeyBytes: ByteArray): ByteArray {
|
||||
val uidBytes = userId.toByteArray(Charsets.UTF_8)
|
||||
require(uidBytes.size <= 255) { "User ID too long for QR encoding" }
|
||||
|
||||
val result = ByteArray(1 + 1 + uidBytes.size + identityKeyBytes.size)
|
||||
result[0] = QR_VERSION
|
||||
result[1] = uidBytes.size.toByte()
|
||||
System.arraycopy(uidBytes, 0, result, 2, uidBytes.size)
|
||||
System.arraycopy(identityKeyBytes, 0, result, 2 + uidBytes.size, identityKeyBytes.size)
|
||||
return result
|
||||
}
|
||||
|
||||
/**
|
||||
* Decode verification QR code payload.
|
||||
* @return Pair(userId, identityKeyBytes)
|
||||
* @throws CryptoException.InvalidQRCode on invalid format
|
||||
*/
|
||||
fun decodeVerificationQR(data: ByteArray): Pair<String, ByteArray> {
|
||||
if (data.size < 3) throw CryptoException.InvalidQRCode("QR data too short")
|
||||
if (data[0] != QR_VERSION) throw CryptoException.InvalidQRCode("Unknown QR version: ${data[0]}")
|
||||
|
||||
val uidLen = data[1].toInt() and 0xFF
|
||||
if (data.size < 2 + uidLen + 32) {
|
||||
throw CryptoException.InvalidQRCode("QR data incomplete")
|
||||
}
|
||||
|
||||
val userId = String(data, 2, uidLen, Charsets.UTF_8)
|
||||
val identityKey = data.copyOfRange(2 + uidLen, 2 + uidLen + 32)
|
||||
return Pair(userId, identityKey)
|
||||
}
|
||||
}
|
||||
33
app/src/main/java/com/kecalek/chat/crypto/CryptoErrors.kt
Normal file
33
app/src/main/java/com/kecalek/chat/crypto/CryptoErrors.kt
Normal file
@@ -0,0 +1,33 @@
|
||||
package com.kecalek.chat.crypto
|
||||
|
||||
/**
|
||||
* Error types for cryptographic operations.
|
||||
*/
|
||||
sealed class CryptoException(message: String, cause: Throwable? = null) : Exception(message, cause) {
|
||||
class DecryptionFailed(message: String = "Decryption failed", cause: Throwable? = null) :
|
||||
CryptoException(message, cause)
|
||||
|
||||
class InvalidSignature(message: String = "Signature verification failed") :
|
||||
CryptoException(message)
|
||||
|
||||
class InvalidKey(message: String = "Invalid key format", cause: Throwable? = null) :
|
||||
CryptoException(message, cause)
|
||||
|
||||
class InvalidPassword(message: String = "Invalid password", cause: Throwable? = null) :
|
||||
CryptoException(message, cause)
|
||||
|
||||
class MaxSkipExceeded(message: String = "Maximum message skip exceeded") :
|
||||
CryptoException(message)
|
||||
|
||||
class InvalidHeader(message: String = "Invalid ratchet header") :
|
||||
CryptoException(message)
|
||||
|
||||
class ChainIdMismatch(message: String = "Sender key chain ID mismatch") :
|
||||
CryptoException(message)
|
||||
|
||||
class InvalidQRCode(message: String = "Invalid verification QR code") :
|
||||
CryptoException(message)
|
||||
|
||||
class X3DHFailed(message: String = "X3DH key agreement failed", cause: Throwable? = null) :
|
||||
CryptoException(message, cause)
|
||||
}
|
||||
396
app/src/main/java/com/kecalek/chat/crypto/DoubleRatchet.kt
Normal file
396
app/src/main/java/com/kecalek/chat/crypto/DoubleRatchet.kt
Normal file
@@ -0,0 +1,396 @@
|
||||
package com.kecalek.chat.crypto
|
||||
|
||||
import org.bouncycastle.crypto.params.X25519PrivateKeyParameters
|
||||
import org.bouncycastle.crypto.params.X25519PublicKeyParameters
|
||||
import org.json.JSONObject
|
||||
import java.security.SecureRandom
|
||||
|
||||
/**
|
||||
* Double Ratchet algorithm for end-to-end encrypted messaging.
|
||||
* Provides forward secrecy and break-in recovery.
|
||||
*
|
||||
* Compatible with Python DoubleRatchet class from crypto_utils.py.
|
||||
*
|
||||
* State:
|
||||
* - dh_pair: current ratchet X25519 keypair
|
||||
* - dh_remote: remote's current ratchet public key
|
||||
* - root_key: 32-byte root key
|
||||
* - send_chain_key / recv_chain_key: current chain keys
|
||||
* - send_n / recv_n: message counters
|
||||
* - prev_send_n: previous sending chain length
|
||||
* - skipped: map of (dh_hex, n) -> message_key for out-of-order delivery
|
||||
*/
|
||||
class DoubleRatchet private constructor() {
|
||||
|
||||
private lateinit var dhPrivate: X25519PrivateKeyParameters
|
||||
private lateinit var dhPublic: X25519PublicKeyParameters
|
||||
private var dhRemote: X25519PublicKeyParameters? = null
|
||||
private lateinit var rootKey: ByteArray
|
||||
private var sendChainKey: ByteArray? = null
|
||||
private var recvChainKey: ByteArray? = null
|
||||
private var sendN: Int = 0
|
||||
private var recvN: Int = 0
|
||||
private var prevSendN: Int = 0
|
||||
|
||||
// skipped[(remotePublicHex, messageNumber)] = messageKey
|
||||
private val skipped = mutableMapOf<String, ByteArray>()
|
||||
|
||||
companion object {
|
||||
private const val MAX_SKIP = 256
|
||||
|
||||
/**
|
||||
* Initialize as Alice (initiator).
|
||||
* Called after X3DH produces a shared secret.
|
||||
*
|
||||
* @param sharedSecret X3DH shared secret
|
||||
* @param bobSpkPub Bob's signed pre-key public (used as initial remote ratchet key)
|
||||
*/
|
||||
fun initAlice(sharedSecret: ByteArray, bobSpkPub: X25519PublicKeyParameters): DoubleRatchet {
|
||||
val ratchet = DoubleRatchet()
|
||||
|
||||
// Generate initial ratchet keypair
|
||||
val (dhPriv, dhPub) = X25519Crypto.generateKeypair()
|
||||
ratchet.dhPrivate = dhPriv
|
||||
ratchet.dhPublic = dhPub
|
||||
ratchet.dhRemote = bobSpkPub
|
||||
|
||||
// Initial DH ratchet step
|
||||
val dhOutput = X25519Crypto.dh(dhPriv, bobSpkPub)
|
||||
val (newRootKey, sendChainKey) = HkdfUtils.kdfRk(sharedSecret, dhOutput)
|
||||
ratchet.rootKey = newRootKey
|
||||
ratchet.sendChainKey = sendChainKey
|
||||
ratchet.recvChainKey = null
|
||||
ratchet.sendN = 0
|
||||
ratchet.recvN = 0
|
||||
ratchet.prevSendN = 0
|
||||
|
||||
return ratchet
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialize as Bob (responder).
|
||||
* Uses SPK pair as initial ratchet key.
|
||||
*
|
||||
* @param sharedSecret X3DH shared secret
|
||||
* @param spkPair Bob's signed pre-key pair (private, public)
|
||||
*/
|
||||
fun initBob(
|
||||
sharedSecret: ByteArray,
|
||||
spkPair: Pair<X25519PrivateKeyParameters, X25519PublicKeyParameters>,
|
||||
): DoubleRatchet {
|
||||
val ratchet = DoubleRatchet()
|
||||
ratchet.dhPrivate = spkPair.first
|
||||
ratchet.dhPublic = spkPair.second
|
||||
ratchet.rootKey = sharedSecret
|
||||
ratchet.sendChainKey = null
|
||||
ratchet.recvChainKey = null
|
||||
ratchet.sendN = 0
|
||||
ratchet.recvN = 0
|
||||
ratchet.prevSendN = 0
|
||||
return ratchet
|
||||
}
|
||||
|
||||
/**
|
||||
* Import ratchet state from JSON bytes.
|
||||
*/
|
||||
fun importState(data: ByteArray): DoubleRatchet {
|
||||
val json = JSONObject(String(data))
|
||||
val ratchet = DoubleRatchet()
|
||||
|
||||
ratchet.dhPrivate = X25519Crypto.loadPrivate(json.getString("dh_priv").hexToBytes())
|
||||
ratchet.dhPublic = X25519Crypto.loadPublic(json.getString("dh_pub").hexToBytes())
|
||||
|
||||
if (json.has("dh_remote") && !json.isNull("dh_remote")) {
|
||||
ratchet.dhRemote = X25519Crypto.loadPublic(json.getString("dh_remote").hexToBytes())
|
||||
}
|
||||
|
||||
ratchet.rootKey = json.getString("root_key").hexToBytes()
|
||||
ratchet.sendChainKey = json.optString("send_ck", "").takeIf { it.isNotEmpty() }?.hexToBytes()
|
||||
ratchet.recvChainKey = json.optString("recv_ck", "").takeIf { it.isNotEmpty() }?.hexToBytes()
|
||||
ratchet.sendN = json.getInt("send_n")
|
||||
ratchet.recvN = json.getInt("recv_n")
|
||||
ratchet.prevSendN = json.getInt("prev_send_n")
|
||||
|
||||
// Import skipped keys
|
||||
if (json.has("skipped")) {
|
||||
val skippedJson = json.getJSONObject("skipped")
|
||||
for (key in skippedJson.keys()) {
|
||||
ratchet.skipped[key] = skippedJson.getString(key).hexToBytes()
|
||||
}
|
||||
}
|
||||
|
||||
return ratchet
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Encrypt plaintext message.
|
||||
* @return RatchetMessage with header dict, ciphertext+tag, nonce
|
||||
*/
|
||||
fun encrypt(plaintext: ByteArray): RatchetMessage {
|
||||
val ck = sendChainKey ?: throw CryptoException.DecryptionFailed("Send chain not initialized")
|
||||
|
||||
val (newChainKey, messageKey) = HkdfUtils.kdfCk(ck)
|
||||
sendChainKey = newChainKey
|
||||
|
||||
val header = RatchetHeader(
|
||||
dhPub = X25519Crypto.serializePublic(dhPublic),
|
||||
n = sendN,
|
||||
pn = prevSendN,
|
||||
)
|
||||
|
||||
val aad = header.serialize()
|
||||
val (nonce, ctWithTag) = AesGcmCrypto.encryptCombined(
|
||||
plaintext = plaintext,
|
||||
key = messageKey,
|
||||
aad = aad,
|
||||
)
|
||||
|
||||
sendN++
|
||||
|
||||
return RatchetMessage(
|
||||
header = header,
|
||||
ciphertext = ctWithTag,
|
||||
nonce = nonce,
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Decrypt received message.
|
||||
* Handles out-of-order delivery via skipped message keys.
|
||||
* Full state rollback on failure.
|
||||
*/
|
||||
fun decrypt(header: RatchetHeader, ciphertext: ByteArray, nonce: ByteArray): ByteArray {
|
||||
val aad = header.serialize()
|
||||
|
||||
// Check skipped message keys first (no state change)
|
||||
val skippedKey = makeSkippedKey(header.dhPub.toHex(), header.n)
|
||||
skipped.remove(skippedKey)?.let { messageKey ->
|
||||
return AesGcmCrypto.decryptCombined(
|
||||
key = messageKey,
|
||||
nonce = nonce,
|
||||
ctWithTag = ciphertext,
|
||||
aad = aad,
|
||||
)
|
||||
}
|
||||
|
||||
// Take snapshot for rollback
|
||||
val snapshot = snapshot()
|
||||
|
||||
try {
|
||||
// New DH ratchet step if remote key changed
|
||||
val remoteHex = header.dhPub.toHex()
|
||||
val currentRemoteHex = dhRemote?.let { X25519Crypto.serializePublic(it).toHex() }
|
||||
|
||||
if (remoteHex != currentRemoteHex) {
|
||||
skipMessages(header.pn)
|
||||
dhRatchet(X25519Crypto.loadPublic(header.dhPub))
|
||||
}
|
||||
|
||||
skipMessages(header.n)
|
||||
|
||||
val ck = recvChainKey ?: throw CryptoException.DecryptionFailed("Receive chain not initialized")
|
||||
val (newChainKey, messageKey) = HkdfUtils.kdfCk(ck)
|
||||
recvChainKey = newChainKey
|
||||
recvN++
|
||||
|
||||
return AesGcmCrypto.decryptCombined(
|
||||
key = messageKey,
|
||||
nonce = nonce,
|
||||
ctWithTag = ciphertext,
|
||||
aad = aad,
|
||||
)
|
||||
} catch (e: Exception) {
|
||||
// Rollback on any failure
|
||||
restore(snapshot)
|
||||
throw if (e is CryptoException) e
|
||||
else CryptoException.DecryptionFailed("Decryption failed", e)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Export full ratchet state as JSON bytes.
|
||||
*/
|
||||
fun exportState(): ByteArray {
|
||||
val json = JSONObject()
|
||||
json.put("dh_priv", X25519Crypto.serializePrivate(dhPrivate).toHex())
|
||||
json.put("dh_pub", X25519Crypto.serializePublic(dhPublic).toHex())
|
||||
json.put("dh_remote", dhRemote?.let { X25519Crypto.serializePublic(it).toHex() })
|
||||
json.put("root_key", rootKey.toHex())
|
||||
json.put("send_ck", sendChainKey?.toHex())
|
||||
json.put("recv_ck", recvChainKey?.toHex())
|
||||
json.put("send_n", sendN)
|
||||
json.put("recv_n", recvN)
|
||||
json.put("prev_send_n", prevSendN)
|
||||
|
||||
val skippedJson = JSONObject()
|
||||
for ((key, value) in skipped) {
|
||||
skippedJson.put(key, value.toHex())
|
||||
}
|
||||
json.put("skipped", skippedJson)
|
||||
|
||||
return json.toString().toByteArray()
|
||||
}
|
||||
|
||||
// --- Private helpers ---
|
||||
|
||||
private fun skipMessages(until: Int) {
|
||||
if (recvChainKey == null) return
|
||||
if (until - recvN > MAX_SKIP) {
|
||||
throw CryptoException.MaxSkipExceeded("Cannot skip more than $MAX_SKIP messages")
|
||||
}
|
||||
|
||||
var ck = recvChainKey!!
|
||||
while (recvN < until) {
|
||||
val (newCk, messageKey) = HkdfUtils.kdfCk(ck)
|
||||
ck = newCk
|
||||
val remoteHex = dhRemote?.let { X25519Crypto.serializePublic(it).toHex() } ?: ""
|
||||
skipped[makeSkippedKey(remoteHex, recvN)] = messageKey
|
||||
recvN++
|
||||
}
|
||||
recvChainKey = ck
|
||||
}
|
||||
|
||||
private fun dhRatchet(remotePublic: X25519PublicKeyParameters) {
|
||||
prevSendN = sendN
|
||||
sendN = 0
|
||||
recvN = 0
|
||||
dhRemote = remotePublic
|
||||
|
||||
// Derive receive chain
|
||||
val dhOutput1 = X25519Crypto.dh(dhPrivate, remotePublic)
|
||||
val (rk1, recvCk) = HkdfUtils.kdfRk(rootKey, dhOutput1)
|
||||
rootKey = rk1
|
||||
recvChainKey = recvCk
|
||||
|
||||
// Generate new DH keypair and derive send chain
|
||||
val (newPriv, newPub) = X25519Crypto.generateKeypair()
|
||||
dhPrivate = newPriv
|
||||
dhPublic = newPub
|
||||
|
||||
val dhOutput2 = X25519Crypto.dh(newPriv, remotePublic)
|
||||
val (rk2, sendCk) = HkdfUtils.kdfRk(rootKey, dhOutput2)
|
||||
rootKey = rk2
|
||||
sendChainKey = sendCk
|
||||
}
|
||||
|
||||
private data class Snapshot(
|
||||
val dhPriv: ByteArray,
|
||||
val dhPub: ByteArray,
|
||||
val dhRemote: ByteArray?,
|
||||
val rootKey: ByteArray,
|
||||
val sendCk: ByteArray?,
|
||||
val recvCk: ByteArray?,
|
||||
val sendN: Int,
|
||||
val recvN: Int,
|
||||
val prevSendN: Int,
|
||||
val skipped: Map<String, ByteArray>,
|
||||
)
|
||||
|
||||
private fun snapshot(): Snapshot {
|
||||
return Snapshot(
|
||||
dhPriv = X25519Crypto.serializePrivate(dhPrivate),
|
||||
dhPub = X25519Crypto.serializePublic(dhPublic),
|
||||
dhRemote = dhRemote?.let { X25519Crypto.serializePublic(it) },
|
||||
rootKey = rootKey.copyOf(),
|
||||
sendCk = sendChainKey?.copyOf(),
|
||||
recvCk = recvChainKey?.copyOf(),
|
||||
sendN = sendN,
|
||||
recvN = recvN,
|
||||
prevSendN = prevSendN,
|
||||
skipped = skipped.toMap(),
|
||||
)
|
||||
}
|
||||
|
||||
private fun restore(s: Snapshot) {
|
||||
dhPrivate = X25519Crypto.loadPrivate(s.dhPriv)
|
||||
dhPublic = X25519Crypto.loadPublic(s.dhPub)
|
||||
dhRemote = s.dhRemote?.let { X25519Crypto.loadPublic(it) }
|
||||
rootKey = s.rootKey
|
||||
sendChainKey = s.sendCk
|
||||
recvChainKey = s.recvCk
|
||||
sendN = s.sendN
|
||||
recvN = s.recvN
|
||||
prevSendN = s.prevSendN
|
||||
skipped.clear()
|
||||
skipped.putAll(s.skipped)
|
||||
}
|
||||
|
||||
private fun makeSkippedKey(dhHex: String, n: Int): String = "$dhHex:$n"
|
||||
}
|
||||
|
||||
/**
|
||||
* Ratchet message header.
|
||||
* Serialized as JSON: {"dh_pub": hex, "n": int, "pn": int}
|
||||
*/
|
||||
data class RatchetHeader(
|
||||
val dhPub: ByteArray,
|
||||
val n: Int,
|
||||
val pn: Int,
|
||||
) {
|
||||
fun serialize(): ByteArray {
|
||||
val json = JSONObject()
|
||||
json.put("dh_pub", dhPub.toHex())
|
||||
json.put("n", n)
|
||||
json.put("pn", pn)
|
||||
return json.toString().toByteArray()
|
||||
}
|
||||
|
||||
fun toMap(): Map<String, Any> = mapOf(
|
||||
"dh_pub" to dhPub.toHex(),
|
||||
"n" to n,
|
||||
"pn" to pn,
|
||||
)
|
||||
|
||||
companion object {
|
||||
fun fromMap(map: Map<String, Any>): RatchetHeader {
|
||||
return RatchetHeader(
|
||||
dhPub = (map["dh_pub"] as String).hexToBytes(),
|
||||
n = (map["n"] as Number).toInt(),
|
||||
pn = (map["pn"] as Number).toInt(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
override fun equals(other: Any?): Boolean {
|
||||
if (this === other) return true
|
||||
if (other !is RatchetHeader) return false
|
||||
return dhPub.contentEquals(other.dhPub) && n == other.n && pn == other.pn
|
||||
}
|
||||
|
||||
override fun hashCode(): Int {
|
||||
var result = dhPub.contentHashCode()
|
||||
result = 31 * result + n
|
||||
result = 31 * result + pn
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
data class RatchetMessage(
|
||||
val header: RatchetHeader,
|
||||
val ciphertext: ByteArray,
|
||||
val nonce: ByteArray,
|
||||
) {
|
||||
override fun equals(other: Any?): Boolean {
|
||||
if (this === other) return true
|
||||
if (other !is RatchetMessage) return false
|
||||
return header == other.header && ciphertext.contentEquals(other.ciphertext) &&
|
||||
nonce.contentEquals(other.nonce)
|
||||
}
|
||||
|
||||
override fun hashCode(): Int {
|
||||
var result = header.hashCode()
|
||||
result = 31 * result + ciphertext.contentHashCode()
|
||||
result = 31 * result + nonce.contentHashCode()
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
// --- Hex extension functions ---
|
||||
|
||||
internal fun ByteArray.toHex(): String = joinToString("") { "%02x".format(it) }
|
||||
|
||||
internal fun String.hexToBytes(): ByteArray {
|
||||
require(length % 2 == 0) { "Hex string must have even length" }
|
||||
return chunked(2).map { it.toInt(16).toByte() }.toByteArray()
|
||||
}
|
||||
130
app/src/main/java/com/kecalek/chat/crypto/Ed25519Crypto.kt
Normal file
130
app/src/main/java/com/kecalek/chat/crypto/Ed25519Crypto.kt
Normal file
@@ -0,0 +1,130 @@
|
||||
package com.kecalek.chat.crypto
|
||||
|
||||
import org.bouncycastle.crypto.params.Ed25519PrivateKeyParameters
|
||||
import org.bouncycastle.crypto.params.Ed25519PublicKeyParameters
|
||||
import org.bouncycastle.crypto.signers.Ed25519Signer
|
||||
import org.bouncycastle.math.ec.rfc7748.X25519Field
|
||||
import java.math.BigInteger
|
||||
import java.security.MessageDigest
|
||||
|
||||
/**
|
||||
* Ed25519 signing and key operations using Bouncy Castle.
|
||||
* Includes Ed25519 -> X25519 conversion for X3DH.
|
||||
* Compatible with Python's Ed25519PrivateKey/PublicKey from cryptography library.
|
||||
*/
|
||||
object Ed25519Crypto {
|
||||
|
||||
/**
|
||||
* Generate Ed25519 keypair.
|
||||
* @return (privateKey, publicKey) as Bouncy Castle parameters
|
||||
*/
|
||||
fun generateKeypair(): Pair<Ed25519PrivateKeyParameters, Ed25519PublicKeyParameters> {
|
||||
val privateKey = Ed25519PrivateKeyParameters(java.security.SecureRandom())
|
||||
return Pair(privateKey, privateKey.generatePublicKey())
|
||||
}
|
||||
|
||||
/**
|
||||
* Get 32-byte raw private key (seed).
|
||||
*/
|
||||
fun serializePrivate(key: Ed25519PrivateKeyParameters): ByteArray {
|
||||
return key.encoded // 32-byte seed
|
||||
}
|
||||
|
||||
/**
|
||||
* Get 32-byte raw public key.
|
||||
*/
|
||||
fun serializePublic(key: Ed25519PublicKeyParameters): ByteArray {
|
||||
return key.encoded // 32 bytes
|
||||
}
|
||||
|
||||
/**
|
||||
* Load Ed25519 private key from 32-byte seed.
|
||||
*/
|
||||
fun loadPrivate(data: ByteArray): Ed25519PrivateKeyParameters {
|
||||
require(data.size == 32) { "Ed25519 private key must be 32 bytes" }
|
||||
return Ed25519PrivateKeyParameters(data, 0)
|
||||
}
|
||||
|
||||
/**
|
||||
* Load Ed25519 public key from 32 bytes.
|
||||
*/
|
||||
fun loadPublic(data: ByteArray): Ed25519PublicKeyParameters {
|
||||
require(data.size == 32) { "Ed25519 public key must be 32 bytes" }
|
||||
return Ed25519PublicKeyParameters(data, 0)
|
||||
}
|
||||
|
||||
/**
|
||||
* Sign data with Ed25519.
|
||||
* @return 64-byte signature
|
||||
*/
|
||||
fun sign(privateKey: Ed25519PrivateKeyParameters, data: ByteArray): ByteArray {
|
||||
val signer = Ed25519Signer()
|
||||
signer.init(true, privateKey)
|
||||
signer.update(data, 0, data.size)
|
||||
return signer.generateSignature()
|
||||
}
|
||||
|
||||
/**
|
||||
* Verify Ed25519 signature.
|
||||
*/
|
||||
fun verify(publicKey: Ed25519PublicKeyParameters, signature: ByteArray, data: ByteArray): Boolean {
|
||||
return try {
|
||||
val verifier = Ed25519Signer()
|
||||
verifier.init(false, publicKey)
|
||||
verifier.update(data, 0, data.size)
|
||||
verifier.verifySignature(signature)
|
||||
} catch (_: Exception) {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert Ed25519 private key to X25519 private key.
|
||||
* Process: SHA-512(seed), take first 32 bytes, clamp per RFC 7748.
|
||||
* Compatible with Python ed25519_private_to_x25519.
|
||||
*/
|
||||
fun privateToX25519(edPrivate: Ed25519PrivateKeyParameters): ByteArray {
|
||||
val seed = edPrivate.encoded // 32-byte seed
|
||||
val hash = MessageDigest.getInstance("SHA-512").digest(seed)
|
||||
val clamped = hash.copyOfRange(0, 32)
|
||||
|
||||
// RFC 7748 clamping
|
||||
clamped[0] = (clamped[0].toInt() and 248).toByte()
|
||||
clamped[31] = (clamped[31].toInt() and 127).toByte()
|
||||
clamped[31] = (clamped[31].toInt() or 64).toByte()
|
||||
|
||||
return clamped
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert Ed25519 public key to X25519 public key.
|
||||
* Uses Montgomery form conversion: u = (1 + y) / (1 - y) mod p
|
||||
* where p = 2^255 - 19, y is the Ed25519 y-coordinate.
|
||||
* Compatible with Python ed25519_public_to_x25519.
|
||||
*/
|
||||
fun publicToX25519(edPublic: Ed25519PublicKeyParameters): ByteArray {
|
||||
val edPubBytes = edPublic.encoded // 32 bytes, little-endian
|
||||
|
||||
// Interpret as little-endian integer
|
||||
var y = BigInteger(1, edPubBytes.reversedArray())
|
||||
// Clear sign bit (bit 255)
|
||||
y = y.and(BigInteger.ONE.shiftLeft(255).subtract(BigInteger.ONE))
|
||||
|
||||
val p = BigInteger.ONE.shiftLeft(255).subtract(BigInteger.valueOf(19))
|
||||
|
||||
val one = BigInteger.ONE
|
||||
val onePlusY = one.add(y).mod(p)
|
||||
val oneMinusY = one.subtract(y).mod(p)
|
||||
|
||||
// Modular inverse via Fermat's little theorem
|
||||
val oneMinusYInv = oneMinusY.modPow(p.subtract(BigInteger.TWO), p)
|
||||
val u = onePlusY.multiply(oneMinusYInv).mod(p)
|
||||
|
||||
// Convert to 32-byte little-endian
|
||||
val uBytes = u.toByteArray().reversedArray()
|
||||
val result = ByteArray(32)
|
||||
System.arraycopy(uBytes, 0, result, 0, minOf(uBytes.size, 32))
|
||||
|
||||
return result
|
||||
}
|
||||
}
|
||||
122
app/src/main/java/com/kecalek/chat/crypto/HkdfUtils.kt
Normal file
122
app/src/main/java/com/kecalek/chat/crypto/HkdfUtils.kt
Normal file
@@ -0,0 +1,122 @@
|
||||
package com.kecalek.chat.crypto
|
||||
|
||||
import javax.crypto.Mac
|
||||
import javax.crypto.spec.SecretKeySpec
|
||||
|
||||
/**
|
||||
* HKDF-SHA256 (RFC 5869) and related KDF functions.
|
||||
* Compatible with Python's HKDF from cryptography.hazmat.primitives.kdf.hkdf.
|
||||
*
|
||||
* Chain key derivation uses HMAC-SHA256 directly (Signal Protocol spec).
|
||||
*/
|
||||
object HkdfUtils {
|
||||
|
||||
// Info strings matching Python/iOS constants
|
||||
const val X3DH_INFO = "EncryptedChat_X3DH"
|
||||
const val ROOT_KEY_INFO = "EncryptedChat_RootKey"
|
||||
const val SELF_ENCRYPTION_SALT = "self_encryption"
|
||||
const val SELF_ENCRYPTION_INFO = "EncryptedChat_SelfKey"
|
||||
const val LOCAL_STORAGE_SALT = "local_storage"
|
||||
const val LOCAL_STORAGE_INFO = "EncryptedChat_LocalStorage"
|
||||
const val SENDER_KEY_CHAIN_INFO = "SenderKeyChain"
|
||||
|
||||
// Chain key KDF constants (Signal spec)
|
||||
private val CK_MSG_INFO = byteArrayOf(0x01) // chain key -> message key
|
||||
private val CK_NEXT_INFO = byteArrayOf(0x02) // chain key -> next chain key
|
||||
|
||||
/**
|
||||
* HKDF-SHA256: Extract + Expand (RFC 5869).
|
||||
* @param inputKey input keying material
|
||||
* @param salt optional salt (if null, uses zeros of hash length)
|
||||
* @param info context/application-specific info
|
||||
* @param length output key length in bytes
|
||||
*/
|
||||
fun derive(
|
||||
inputKey: ByteArray,
|
||||
salt: ByteArray? = null,
|
||||
info: ByteArray,
|
||||
length: Int = 32,
|
||||
): ByteArray {
|
||||
// Extract
|
||||
val prk = hmacSha256(salt ?: ByteArray(32), inputKey)
|
||||
// Expand
|
||||
return expand(prk, info, length)
|
||||
}
|
||||
|
||||
/**
|
||||
* HKDF-Expand (used when PRK is already extracted).
|
||||
*/
|
||||
private fun expand(prk: ByteArray, info: ByteArray, length: Int): ByteArray {
|
||||
val hashLen = 32
|
||||
val n = (length + hashLen - 1) / hashLen
|
||||
require(n <= 255) { "HKDF output too long" }
|
||||
|
||||
val output = ByteArray(n * hashLen)
|
||||
var t = ByteArray(0)
|
||||
for (i in 1..n) {
|
||||
val input = t + info + byteArrayOf(i.toByte())
|
||||
t = hmacSha256(prk, input)
|
||||
System.arraycopy(t, 0, output, (i - 1) * hashLen, hashLen)
|
||||
}
|
||||
return output.copyOfRange(0, length)
|
||||
}
|
||||
|
||||
/**
|
||||
* Root key KDF: derives new root key + chain key from DH output.
|
||||
* kdf_rk(root_key, dh_output) -> (new_root_key, chain_key)
|
||||
*/
|
||||
fun kdfRk(rootKey: ByteArray, dhOutput: ByteArray): Pair<ByteArray, ByteArray> {
|
||||
val derived = derive(
|
||||
inputKey = dhOutput,
|
||||
salt = rootKey,
|
||||
info = ROOT_KEY_INFO.toByteArray(),
|
||||
length = 64,
|
||||
)
|
||||
val newRootKey = derived.copyOfRange(0, 32)
|
||||
val chainKey = derived.copyOfRange(32, 64)
|
||||
return Pair(newRootKey, chainKey)
|
||||
}
|
||||
|
||||
/**
|
||||
* Chain key KDF: derives message key + next chain key.
|
||||
* kdf_ck(chain_key) -> (new_chain_key, message_key)
|
||||
*/
|
||||
fun kdfCk(chainKey: ByteArray): Pair<ByteArray, ByteArray> {
|
||||
val messageKey = hmacSha256(chainKey, CK_MSG_INFO)
|
||||
val newChainKey = hmacSha256(chainKey, CK_NEXT_INFO)
|
||||
return Pair(newChainKey, messageKey)
|
||||
}
|
||||
|
||||
/**
|
||||
* Derive self-encryption key from identity private key.
|
||||
*/
|
||||
fun deriveSelfEncryptionKey(identityPrivateRaw: ByteArray): ByteArray {
|
||||
return derive(
|
||||
inputKey = identityPrivateRaw,
|
||||
salt = SELF_ENCRYPTION_SALT.toByteArray(),
|
||||
info = SELF_ENCRYPTION_INFO.toByteArray(),
|
||||
length = 32,
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Derive local storage encryption key from identity private key.
|
||||
*/
|
||||
fun deriveLocalStorageKey(identityPrivateRaw: ByteArray): ByteArray {
|
||||
return derive(
|
||||
inputKey = identityPrivateRaw,
|
||||
salt = LOCAL_STORAGE_SALT.toByteArray(),
|
||||
info = LOCAL_STORAGE_INFO.toByteArray(),
|
||||
length = 32,
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* HMAC-SHA256.
|
||||
*/
|
||||
fun hmacSha256(key: ByteArray, data: ByteArray): ByteArray {
|
||||
val mac = Mac.getInstance("HmacSHA256")
|
||||
mac.init(SecretKeySpec(key, "HmacSHA256"))
|
||||
return mac.doFinal(data)
|
||||
}
|
||||
}
|
||||
94
app/src/main/java/com/kecalek/chat/crypto/KeyEncryption.kt
Normal file
94
app/src/main/java/com/kecalek/chat/crypto/KeyEncryption.kt
Normal file
@@ -0,0 +1,94 @@
|
||||
package com.kecalek.chat.crypto
|
||||
|
||||
import java.security.SecureRandom
|
||||
import javax.crypto.SecretKeyFactory
|
||||
import javax.crypto.spec.PBEKeySpec
|
||||
|
||||
/**
|
||||
* ECP1 format: Password-based key encryption using PBKDF2-HMAC-SHA256 + AES-256-GCM.
|
||||
* Format: ECP1(4B magic) + salt(16B) + nonce(12B) + ciphertext+tag(N+16B)
|
||||
*
|
||||
* Compatible with Python _encrypt_private_key / _decrypt_private_key.
|
||||
* AAD for AES-GCM = ECP1_MAGIC bytes.
|
||||
*/
|
||||
object KeyEncryption {
|
||||
|
||||
private val ECP1_MAGIC = byteArrayOf(0x45, 0x43, 0x50, 0x31) // "ECP1"
|
||||
private const val PBKDF2_ITERATIONS = 600_000
|
||||
private const val SALT_SIZE = 16
|
||||
private const val NONCE_SIZE = 12
|
||||
private const val KEY_SIZE = 32
|
||||
|
||||
private val secureRandom = SecureRandom()
|
||||
|
||||
/**
|
||||
* Encrypt raw key bytes with password using ECP1 format.
|
||||
* @param rawBytes the private key bytes to encrypt
|
||||
* @param password the password to derive encryption key from
|
||||
* @return ECP1 formatted encrypted data
|
||||
*/
|
||||
fun encrypt(rawBytes: ByteArray, password: String): ByteArray {
|
||||
val salt = ByteArray(SALT_SIZE).also { secureRandom.nextBytes(it) }
|
||||
val aesKey = deriveKey(password, salt)
|
||||
|
||||
val (nonce, ctWithTag) = AesGcmCrypto.encryptCombined(
|
||||
plaintext = rawBytes,
|
||||
key = aesKey,
|
||||
aad = ECP1_MAGIC,
|
||||
)
|
||||
|
||||
// ECP1 format: magic(4) + salt(16) + nonce(12) + ct+tag
|
||||
return ECP1_MAGIC + salt + nonce + ctWithTag
|
||||
}
|
||||
|
||||
/**
|
||||
* Decrypt ECP1-encrypted key bytes with password.
|
||||
* @param data ECP1 formatted encrypted data
|
||||
* @param password the password
|
||||
* @return decrypted raw key bytes
|
||||
* @throws CryptoException.InvalidPassword if password is wrong or data is corrupted
|
||||
*/
|
||||
fun decrypt(data: ByteArray, password: String): ByteArray {
|
||||
if (data.size < 4 + SALT_SIZE + NONCE_SIZE + 16) {
|
||||
throw CryptoException.InvalidKey("Data too short for ECP1 format")
|
||||
}
|
||||
|
||||
// Verify magic
|
||||
if (!data.copyOfRange(0, 4).contentEquals(ECP1_MAGIC)) {
|
||||
throw CryptoException.InvalidKey("Invalid ECP1 magic bytes")
|
||||
}
|
||||
|
||||
val salt = data.copyOfRange(4, 4 + SALT_SIZE)
|
||||
val nonce = data.copyOfRange(4 + SALT_SIZE, 4 + SALT_SIZE + NONCE_SIZE)
|
||||
val ctWithTag = data.copyOfRange(4 + SALT_SIZE + NONCE_SIZE, data.size)
|
||||
|
||||
val aesKey = deriveKey(password, salt)
|
||||
|
||||
return try {
|
||||
AesGcmCrypto.decryptCombined(
|
||||
key = aesKey,
|
||||
nonce = nonce,
|
||||
ctWithTag = ctWithTag,
|
||||
aad = ECP1_MAGIC,
|
||||
)
|
||||
} catch (e: Exception) {
|
||||
throw CryptoException.InvalidPassword("Failed to decrypt: wrong password or corrupted data", e)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if data starts with ECP1 magic bytes.
|
||||
*/
|
||||
fun isEcp1Format(data: ByteArray): Boolean {
|
||||
return data.size >= 4 && data.copyOfRange(0, 4).contentEquals(ECP1_MAGIC)
|
||||
}
|
||||
|
||||
/**
|
||||
* Derive 32-byte AES key from password using PBKDF2-HMAC-SHA256.
|
||||
*/
|
||||
private fun deriveKey(password: String, salt: ByteArray): ByteArray {
|
||||
val spec = PBEKeySpec(password.toCharArray(), salt, PBKDF2_ITERATIONS, KEY_SIZE * 8)
|
||||
val factory = SecretKeyFactory.getInstance("PBKDF2WithHmacSHA256")
|
||||
return factory.generateSecret(spec).encoded
|
||||
}
|
||||
}
|
||||
82
app/src/main/java/com/kecalek/chat/crypto/MessagePadding.kt
Normal file
82
app/src/main/java/com/kecalek/chat/crypto/MessagePadding.kt
Normal file
@@ -0,0 +1,82 @@
|
||||
package com.kecalek.chat.crypto
|
||||
|
||||
import java.nio.ByteBuffer
|
||||
import java.security.SecureRandom
|
||||
|
||||
/**
|
||||
* Bucket-based message padding for metadata privacy.
|
||||
* Pads messages to fixed bucket sizes to prevent message-length analysis.
|
||||
*
|
||||
* Format: 0x01 + plaintext + random_padding + pad_length(4 bytes big-endian)
|
||||
* Compatible with Python pad_plaintext/unpad_plaintext.
|
||||
*/
|
||||
object MessagePadding {
|
||||
|
||||
private const val PAD_MAGIC: Byte = 0x01
|
||||
private val PAD_BUCKETS = intArrayOf(64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536)
|
||||
private val secureRandom = SecureRandom()
|
||||
|
||||
/**
|
||||
* Pad plaintext to nearest bucket size.
|
||||
* @param plaintext raw plaintext bytes
|
||||
* @return padded bytes: 0x01 + plaintext + random_padding + pad_length(4B)
|
||||
*/
|
||||
fun pad(plaintext: ByteArray): ByteArray {
|
||||
// content = magic + plaintext
|
||||
val content = ByteArray(1 + plaintext.size)
|
||||
content[0] = PAD_MAGIC
|
||||
System.arraycopy(plaintext, 0, content, 1, plaintext.size)
|
||||
|
||||
// minimum total size = content + 4 bytes for pad_length
|
||||
val minSize = content.size + 4
|
||||
|
||||
// find nearest bucket
|
||||
var targetSize = minSize
|
||||
for (bucket in PAD_BUCKETS) {
|
||||
if (bucket >= minSize) {
|
||||
targetSize = bucket
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// pad_length includes itself (4 bytes) + random padding bytes
|
||||
val padLength = targetSize - content.size
|
||||
val randomPadSize = padLength - 4
|
||||
|
||||
val result = ByteArray(targetSize)
|
||||
System.arraycopy(content, 0, result, 0, content.size)
|
||||
|
||||
// fill random padding
|
||||
if (randomPadSize > 0) {
|
||||
val randomBytes = ByteArray(randomPadSize)
|
||||
secureRandom.nextBytes(randomBytes)
|
||||
System.arraycopy(randomBytes, 0, result, content.size, randomPadSize)
|
||||
}
|
||||
|
||||
// write pad_length as big-endian uint32 at the end
|
||||
val lenBytes = ByteBuffer.allocate(4).putInt(padLength).array()
|
||||
System.arraycopy(lenBytes, 0, result, targetSize - 4, 4)
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove padding from padded message.
|
||||
* @param data padded message bytes
|
||||
* @return original plaintext
|
||||
*/
|
||||
fun unpad(data: ByteArray): ByteArray {
|
||||
// Legacy unpadded messages (JSON starting with '{')
|
||||
if (data.isEmpty() || data[0] != PAD_MAGIC) return data
|
||||
if (data.size < 5) return data
|
||||
|
||||
// Read pad_length from last 4 bytes
|
||||
val padLength = ByteBuffer.wrap(data, data.size - 4, 4).int
|
||||
|
||||
// Validate
|
||||
if (padLength < 4 || padLength > data.size - 1) return data
|
||||
|
||||
// Strip magic prefix (1 byte) and padding (padLength bytes from end)
|
||||
return data.copyOfRange(1, data.size - padLength)
|
||||
}
|
||||
}
|
||||
136
app/src/main/java/com/kecalek/chat/crypto/RSACrypto.kt
Normal file
136
app/src/main/java/com/kecalek/chat/crypto/RSACrypto.kt
Normal file
@@ -0,0 +1,136 @@
|
||||
package com.kecalek.chat.crypto
|
||||
|
||||
import java.security.KeyFactory
|
||||
import java.security.KeyPairGenerator
|
||||
import java.security.Signature
|
||||
import java.security.interfaces.RSAPrivateKey
|
||||
import java.security.interfaces.RSAPublicKey
|
||||
import java.security.spec.MGF1ParameterSpec
|
||||
import java.security.spec.PKCS8EncodedKeySpec
|
||||
import java.security.spec.PSSParameterSpec
|
||||
import java.security.spec.X509EncodedKeySpec
|
||||
|
||||
/**
|
||||
* RSA-4096 for login challenge-response only.
|
||||
* Uses RSA-PSS with SHA-256, MGF1-SHA256.
|
||||
*
|
||||
* Private key storage: DER PKCS8 raw bytes encrypted via ECP1.
|
||||
* Public key: DER SubjectPublicKeyInfo (X.509).
|
||||
*
|
||||
* Compatible with Python generate_rsa_keypair, rsa_sign, rsa_verify.
|
||||
* Sign uses PSS with salt_length=MAX. Verify accepts MAX or hash-length salt.
|
||||
*/
|
||||
object RSACrypto {
|
||||
|
||||
private const val KEY_SIZE = 4096
|
||||
|
||||
/**
|
||||
* Generate RSA-4096 keypair.
|
||||
*/
|
||||
fun generateKeypair(): Pair<RSAPrivateKey, RSAPublicKey> {
|
||||
val kpg = KeyPairGenerator.getInstance("RSA")
|
||||
kpg.initialize(KEY_SIZE)
|
||||
val kp = kpg.generateKeyPair()
|
||||
return Pair(kp.private as RSAPrivateKey, kp.public as RSAPublicKey)
|
||||
}
|
||||
|
||||
/**
|
||||
* Serialize private key to DER PKCS8 format.
|
||||
* Optionally encrypt with password using ECP1.
|
||||
*/
|
||||
fun serializePrivate(key: RSAPrivateKey, password: String? = null): ByteArray {
|
||||
val der = key.encoded // PKCS8 DER
|
||||
return if (password != null) {
|
||||
KeyEncryption.encrypt(der, password)
|
||||
} else {
|
||||
der
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Serialize public key to DER X.509 format.
|
||||
*/
|
||||
fun serializePublic(key: RSAPublicKey): ByteArray {
|
||||
return key.encoded // X.509 DER
|
||||
}
|
||||
|
||||
/**
|
||||
* Load private key from DER bytes (optionally ECP1-encrypted).
|
||||
*/
|
||||
fun loadPrivate(data: ByteArray, password: String? = null): RSAPrivateKey {
|
||||
val der = if (KeyEncryption.isEcp1Format(data) && password != null) {
|
||||
KeyEncryption.decrypt(data, password)
|
||||
} else {
|
||||
data
|
||||
}
|
||||
val keyFactory = KeyFactory.getInstance("RSA")
|
||||
return keyFactory.generatePrivate(PKCS8EncodedKeySpec(der)) as RSAPrivateKey
|
||||
}
|
||||
|
||||
/**
|
||||
* Load public key from DER X.509 bytes.
|
||||
*/
|
||||
fun loadPublic(data: ByteArray): RSAPublicKey {
|
||||
val keyFactory = KeyFactory.getInstance("RSA")
|
||||
return keyFactory.generatePublic(X509EncodedKeySpec(data)) as RSAPublicKey
|
||||
}
|
||||
|
||||
/**
|
||||
* Sign data with RSA-PSS (SHA-256, MGF1-SHA256, max salt length).
|
||||
* Compatible with Python rsa_sign.
|
||||
*/
|
||||
fun sign(privateKey: RSAPrivateKey, data: ByteArray): ByteArray {
|
||||
// Max salt length = key size in bytes - hash size - 2
|
||||
val maxSaltLen = privateKey.modulus.bitLength() / 8 - 32 - 2
|
||||
val pssSpec = PSSParameterSpec(
|
||||
"SHA-256",
|
||||
"MGF1",
|
||||
MGF1ParameterSpec.SHA256,
|
||||
maxSaltLen,
|
||||
1, // trailer field
|
||||
)
|
||||
val sig = Signature.getInstance("RSASSA-PSS")
|
||||
sig.setParameter(pssSpec)
|
||||
sig.initSign(privateKey)
|
||||
sig.update(data)
|
||||
return sig.sign()
|
||||
}
|
||||
|
||||
/**
|
||||
* Verify RSA-PSS signature.
|
||||
* Uses salt_length = max for verification (Java PSS handles this internally).
|
||||
* For cross-platform compat, we try max salt first, then hash-length salt.
|
||||
*/
|
||||
fun verify(publicKey: RSAPublicKey, signature: ByteArray, data: ByteArray): Boolean {
|
||||
// Try with max salt length first (Python's default for signing)
|
||||
val maxSaltLen = publicKey.modulus.bitLength() / 8 - 32 - 2
|
||||
if (verifyWithSaltLen(publicKey, signature, data, maxSaltLen)) return true
|
||||
// Try with hash-length salt (iOS compatibility)
|
||||
if (verifyWithSaltLen(publicKey, signature, data, 32)) return true
|
||||
return false
|
||||
}
|
||||
|
||||
private fun verifyWithSaltLen(
|
||||
publicKey: RSAPublicKey,
|
||||
signature: ByteArray,
|
||||
data: ByteArray,
|
||||
saltLen: Int,
|
||||
): Boolean {
|
||||
return try {
|
||||
val pssSpec = PSSParameterSpec(
|
||||
"SHA-256",
|
||||
"MGF1",
|
||||
MGF1ParameterSpec.SHA256,
|
||||
saltLen,
|
||||
1,
|
||||
)
|
||||
val sig = Signature.getInstance("RSASSA-PSS")
|
||||
sig.setParameter(pssSpec)
|
||||
sig.initVerify(publicKey)
|
||||
sig.update(data)
|
||||
sig.verify(signature)
|
||||
} catch (_: Exception) {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
214
app/src/main/java/com/kecalek/chat/crypto/SenderKeyState.kt
Normal file
214
app/src/main/java/com/kecalek/chat/crypto/SenderKeyState.kt
Normal file
@@ -0,0 +1,214 @@
|
||||
package com.kecalek.chat.crypto
|
||||
|
||||
import org.json.JSONObject
|
||||
import java.nio.ByteBuffer
|
||||
import java.security.MessageDigest
|
||||
import java.security.SecureRandom
|
||||
|
||||
/**
|
||||
* Sender Key state for group messaging.
|
||||
* Each sender has their own chain that group members can decrypt.
|
||||
*
|
||||
* Compatible with Python SenderKeyState from crypto_utils.py.
|
||||
*
|
||||
* Chain: HKDF(sender_key, salt=0x00*32, info="SenderKeyChain") -> chain_key
|
||||
* Chain ID: SHA-256(sender_key) -> 32 bytes
|
||||
* Message key: kdf_ck(chain_key) -> (new_chain_key, message_key)
|
||||
* AAD: chain_id(32B) + message_number(4B big-endian)
|
||||
*/
|
||||
class SenderKeyState private constructor(
|
||||
private val senderKey: ByteArray,
|
||||
private var chainId: ByteArray,
|
||||
private var chainKey: ByteArray,
|
||||
private var n: Int,
|
||||
private val knownKeys: MutableMap<Int, ByteArray> = mutableMapOf(),
|
||||
) {
|
||||
|
||||
companion object {
|
||||
private const val MAX_SENDER_KEY_SKIP = 256
|
||||
private val ZERO_SALT = ByteArray(32)
|
||||
|
||||
/**
|
||||
* Create new sender key state (for sending).
|
||||
* @param senderKey optional 32-byte sender key (generated if null)
|
||||
*/
|
||||
fun create(senderKey: ByteArray? = null): SenderKeyState {
|
||||
val key = senderKey ?: ByteArray(32).also { SecureRandom().nextBytes(it) }
|
||||
val chainId = MessageDigest.getInstance("SHA-256").digest(key)
|
||||
val chainKey = HkdfUtils.derive(
|
||||
inputKey = key,
|
||||
salt = ZERO_SALT,
|
||||
info = HkdfUtils.SENDER_KEY_CHAIN_INFO.toByteArray(),
|
||||
length = 32,
|
||||
)
|
||||
return SenderKeyState(key, chainId, chainKey, 0)
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialize from received sender key (for receiving/decrypting).
|
||||
* @param exportedKey JSON bytes from exportKey()
|
||||
*/
|
||||
fun fromKey(exportedKey: ByteArray): SenderKeyState {
|
||||
val json = JSONObject(String(exportedKey))
|
||||
val key = json.getString("sender_key").hexToBytes()
|
||||
return create(key)
|
||||
}
|
||||
|
||||
/**
|
||||
* Import full state from JSON bytes.
|
||||
*/
|
||||
fun importState(data: ByteArray): SenderKeyState {
|
||||
val json = JSONObject(String(data))
|
||||
val senderKey = json.getString("sender_key").hexToBytes()
|
||||
val chainId = json.getString("chain_id").hexToBytes()
|
||||
val chainKey = json.getString("chain_key").hexToBytes()
|
||||
val n = json.getInt("n")
|
||||
|
||||
val knownKeys = mutableMapOf<Int, ByteArray>()
|
||||
if (json.has("known_keys")) {
|
||||
val knownJson = json.getJSONObject("known_keys")
|
||||
for (key in knownJson.keys()) {
|
||||
knownKeys[key.toInt()] = knownJson.getString(key).hexToBytes()
|
||||
}
|
||||
}
|
||||
|
||||
return SenderKeyState(senderKey, chainId, chainKey, n, knownKeys)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Encrypt plaintext for group message.
|
||||
* @return SenderKeyMessage with chain_id, n, ciphertext+tag, nonce
|
||||
*/
|
||||
fun encrypt(plaintext: ByteArray): SenderKeyMessage {
|
||||
val (newChainKey, messageKey) = HkdfUtils.kdfCk(chainKey)
|
||||
chainKey = newChainKey
|
||||
|
||||
val aad = buildAAD(chainId, n)
|
||||
val (nonce, ctWithTag) = AesGcmCrypto.encryptCombined(
|
||||
plaintext = plaintext,
|
||||
key = messageKey,
|
||||
aad = aad,
|
||||
)
|
||||
|
||||
val msg = SenderKeyMessage(
|
||||
chainIdHex = chainId.toHex(),
|
||||
n = n,
|
||||
ciphertext = ctWithTag,
|
||||
nonce = nonce,
|
||||
)
|
||||
n++
|
||||
return msg
|
||||
}
|
||||
|
||||
/**
|
||||
* Decrypt received group message.
|
||||
* @param chainIdHex hex string of chain ID
|
||||
* @param messageN message number
|
||||
* @param ciphertext ciphertext+tag bytes
|
||||
* @param nonce 12-byte nonce
|
||||
* @return decrypted plaintext
|
||||
*/
|
||||
fun decrypt(chainIdHex: String, messageN: Int, ciphertext: ByteArray, nonce: ByteArray): ByteArray {
|
||||
// Verify chain ID
|
||||
if (chainIdHex != chainId.toHex()) {
|
||||
throw CryptoException.ChainIdMismatch("Expected ${chainId.toHex()}, got $chainIdHex")
|
||||
}
|
||||
|
||||
if (messageN - n > MAX_SENDER_KEY_SKIP) {
|
||||
throw CryptoException.MaxSkipExceeded("Cannot skip more than $MAX_SENDER_KEY_SKIP sender key messages")
|
||||
}
|
||||
|
||||
// Snapshot for rollback
|
||||
val snapChainKey = chainKey.copyOf()
|
||||
val snapN = n
|
||||
val snapKnownKeys = knownKeys.toMutableMap()
|
||||
|
||||
try {
|
||||
// Fast-forward: derive keys up to target
|
||||
while (n <= messageN) {
|
||||
val (newCk, mk) = HkdfUtils.kdfCk(chainKey)
|
||||
knownKeys[n] = mk
|
||||
chainKey = newCk
|
||||
n++
|
||||
}
|
||||
|
||||
val messageKey = knownKeys.remove(messageN)
|
||||
?: throw CryptoException.DecryptionFailed("Message key not found for n=$messageN")
|
||||
|
||||
val aad = buildAAD(chainId, messageN)
|
||||
return AesGcmCrypto.decryptCombined(
|
||||
key = messageKey,
|
||||
nonce = nonce,
|
||||
ctWithTag = ciphertext,
|
||||
aad = aad,
|
||||
)
|
||||
} catch (e: Exception) {
|
||||
// Rollback
|
||||
chainKey = snapChainKey
|
||||
n = snapN
|
||||
knownKeys.clear()
|
||||
knownKeys.putAll(snapKnownKeys)
|
||||
throw if (e is CryptoException) e
|
||||
else CryptoException.DecryptionFailed("Sender key decryption failed", e)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Export sender key for distribution to group members.
|
||||
* @return JSON bytes containing just the sender_key
|
||||
*/
|
||||
fun exportKey(): ByteArray {
|
||||
val json = JSONObject()
|
||||
json.put("sender_key", senderKey.toHex())
|
||||
return json.toString().toByteArray()
|
||||
}
|
||||
|
||||
/**
|
||||
* Export full state for persistence.
|
||||
*/
|
||||
fun exportState(): ByteArray {
|
||||
val json = JSONObject()
|
||||
json.put("sender_key", senderKey.toHex())
|
||||
json.put("chain_id", chainId.toHex())
|
||||
json.put("chain_key", chainKey.toHex())
|
||||
json.put("n", n)
|
||||
|
||||
val knownJson = JSONObject()
|
||||
for ((k, v) in knownKeys) {
|
||||
knownJson.put(k.toString(), v.toHex())
|
||||
}
|
||||
json.put("known_keys", knownJson)
|
||||
|
||||
return json.toString().toByteArray()
|
||||
}
|
||||
|
||||
fun getChainIdHex(): String = chainId.toHex()
|
||||
|
||||
private fun buildAAD(chainId: ByteArray, messageN: Int): ByteArray {
|
||||
val nBytes = ByteBuffer.allocate(4).putInt(messageN).array()
|
||||
return chainId + nBytes
|
||||
}
|
||||
}
|
||||
|
||||
data class SenderKeyMessage(
|
||||
val chainIdHex: String,
|
||||
val n: Int,
|
||||
val ciphertext: ByteArray,
|
||||
val nonce: ByteArray,
|
||||
) {
|
||||
override fun equals(other: Any?): Boolean {
|
||||
if (this === other) return true
|
||||
if (other !is SenderKeyMessage) return false
|
||||
return chainIdHex == other.chainIdHex && n == other.n &&
|
||||
ciphertext.contentEquals(other.ciphertext) && nonce.contentEquals(other.nonce)
|
||||
}
|
||||
|
||||
override fun hashCode(): Int {
|
||||
var result = chainIdHex.hashCode()
|
||||
result = 31 * result + n
|
||||
result = 31 * result + ciphertext.contentHashCode()
|
||||
result = 31 * result + nonce.contentHashCode()
|
||||
return result
|
||||
}
|
||||
}
|
||||
63
app/src/main/java/com/kecalek/chat/crypto/X25519Crypto.kt
Normal file
63
app/src/main/java/com/kecalek/chat/crypto/X25519Crypto.kt
Normal file
@@ -0,0 +1,63 @@
|
||||
package com.kecalek.chat.crypto
|
||||
|
||||
import org.bouncycastle.crypto.agreement.X25519Agreement
|
||||
import org.bouncycastle.crypto.params.X25519PrivateKeyParameters
|
||||
import org.bouncycastle.crypto.params.X25519PublicKeyParameters
|
||||
import java.security.SecureRandom
|
||||
|
||||
/**
|
||||
* X25519 Diffie-Hellman key agreement using Bouncy Castle.
|
||||
* Compatible with Python's X25519PrivateKey/PublicKey from cryptography library.
|
||||
*/
|
||||
object X25519Crypto {
|
||||
|
||||
/**
|
||||
* Generate X25519 keypair.
|
||||
*/
|
||||
fun generateKeypair(): Pair<X25519PrivateKeyParameters, X25519PublicKeyParameters> {
|
||||
val privateKey = X25519PrivateKeyParameters(SecureRandom())
|
||||
return Pair(privateKey, privateKey.generatePublicKey())
|
||||
}
|
||||
|
||||
/**
|
||||
* Serialize X25519 private key to 32 bytes.
|
||||
*/
|
||||
fun serializePrivate(key: X25519PrivateKeyParameters): ByteArray {
|
||||
return key.encoded // 32 bytes
|
||||
}
|
||||
|
||||
/**
|
||||
* Serialize X25519 public key to 32 bytes.
|
||||
*/
|
||||
fun serializePublic(key: X25519PublicKeyParameters): ByteArray {
|
||||
return key.encoded // 32 bytes
|
||||
}
|
||||
|
||||
/**
|
||||
* Load X25519 private key from 32 bytes.
|
||||
*/
|
||||
fun loadPrivate(data: ByteArray): X25519PrivateKeyParameters {
|
||||
require(data.size == 32) { "X25519 private key must be 32 bytes" }
|
||||
return X25519PrivateKeyParameters(data, 0)
|
||||
}
|
||||
|
||||
/**
|
||||
* Load X25519 public key from 32 bytes.
|
||||
*/
|
||||
fun loadPublic(data: ByteArray): X25519PublicKeyParameters {
|
||||
require(data.size == 32) { "X25519 public key must be 32 bytes" }
|
||||
return X25519PublicKeyParameters(data, 0)
|
||||
}
|
||||
|
||||
/**
|
||||
* Perform X25519 Diffie-Hellman key agreement.
|
||||
* @return 32-byte shared secret
|
||||
*/
|
||||
fun dh(privateKey: X25519PrivateKeyParameters, publicKey: X25519PublicKeyParameters): ByteArray {
|
||||
val agreement = X25519Agreement()
|
||||
agreement.init(privateKey)
|
||||
val secret = ByteArray(agreement.agreementSize)
|
||||
agreement.calculateAgreement(publicKey, secret, 0)
|
||||
return secret
|
||||
}
|
||||
}
|
||||
206
app/src/main/java/com/kecalek/chat/crypto/X3DH.kt
Normal file
206
app/src/main/java/com/kecalek/chat/crypto/X3DH.kt
Normal file
@@ -0,0 +1,206 @@
|
||||
package com.kecalek.chat.crypto
|
||||
|
||||
import org.bouncycastle.crypto.params.Ed25519PrivateKeyParameters
|
||||
import org.bouncycastle.crypto.params.Ed25519PublicKeyParameters
|
||||
import org.bouncycastle.crypto.params.X25519PrivateKeyParameters
|
||||
import org.bouncycastle.crypto.params.X25519PublicKeyParameters
|
||||
import java.util.UUID
|
||||
|
||||
/**
|
||||
* X3DH (Extended Triple Diffie-Hellman) key agreement protocol.
|
||||
* Used to establish shared secrets for initial Double Ratchet sessions.
|
||||
*
|
||||
* Compatible with Python x3dh_initiate / x3dh_respond.
|
||||
*
|
||||
* DH operations:
|
||||
* dh1 = DH(IK_A_x25519, SPK_B)
|
||||
* dh2 = DH(EK_A, IK_B_x25519)
|
||||
* dh3 = DH(EK_A, SPK_B)
|
||||
* dh4 = DH(EK_A, OPK_B) [optional]
|
||||
*
|
||||
* Shared secret = HKDF(dh1||dh2||dh3[||dh4], salt=0x00*32, info="EncryptedChat_X3DH")
|
||||
*/
|
||||
object X3DH {
|
||||
|
||||
private val ZERO_SALT = ByteArray(32)
|
||||
|
||||
/**
|
||||
* Generate a signed pre-key.
|
||||
* @param identityPrivate Ed25519 identity private key (for signing)
|
||||
* @return SignedPreKey with X25519 keypair, signature, and UUID
|
||||
*/
|
||||
fun generateSignedPreKey(identityPrivate: Ed25519PrivateKeyParameters): SignedPreKey {
|
||||
val (spkPrivate, spkPublic) = X25519Crypto.generateKeypair()
|
||||
val spkPubBytes = X25519Crypto.serializePublic(spkPublic)
|
||||
val signature = Ed25519Crypto.sign(identityPrivate, spkPubBytes)
|
||||
val id = UUID.randomUUID().toString()
|
||||
|
||||
return SignedPreKey(
|
||||
id = id,
|
||||
privateKey = spkPrivate,
|
||||
publicKey = spkPublic,
|
||||
signature = signature,
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate one-time pre-keys.
|
||||
* @param count number of OPKs to generate
|
||||
* @return list of OneTimePreKey with X25519 keypair and UUID
|
||||
*/
|
||||
fun generateOneTimePreKeys(count: Int = 50): List<OneTimePreKey> {
|
||||
return (0 until count).map {
|
||||
val (priv, pub) = X25519Crypto.generateKeypair()
|
||||
OneTimePreKey(
|
||||
id = UUID.randomUUID().toString(),
|
||||
privateKey = priv,
|
||||
publicKey = pub,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Initiator (Alice) side of X3DH.
|
||||
*
|
||||
* @param ikPrivateEd Alice's Ed25519 identity private key
|
||||
* @param ikPublicRemoteEd Bob's Ed25519 identity public key
|
||||
* @param spkRemote Bob's signed pre-key (X25519 public)
|
||||
* @param spkSignature Bob's signature over his SPK public bytes
|
||||
* @param opkRemote Bob's one-time pre-key (X25519 public, optional)
|
||||
* @return X3DHResult with shared secret, ephemeral keypair
|
||||
* @throws CryptoException.InvalidSignature if SPK signature verification fails
|
||||
* @throws CryptoException.X3DHFailed if key agreement fails
|
||||
*/
|
||||
fun initiate(
|
||||
ikPrivateEd: Ed25519PrivateKeyParameters,
|
||||
ikPublicRemoteEd: Ed25519PublicKeyParameters,
|
||||
spkRemote: X25519PublicKeyParameters,
|
||||
spkSignature: ByteArray,
|
||||
opkRemote: X25519PublicKeyParameters? = null,
|
||||
): X3DHResult {
|
||||
// Verify SPK signature
|
||||
val spkRemoteBytes = X25519Crypto.serializePublic(spkRemote)
|
||||
if (!Ed25519Crypto.verify(ikPublicRemoteEd, spkSignature, spkRemoteBytes)) {
|
||||
throw CryptoException.InvalidSignature("SPK signature verification failed")
|
||||
}
|
||||
|
||||
try {
|
||||
// Convert identity keys Ed25519 -> X25519
|
||||
val ikPrivateX = X25519Crypto.loadPrivate(Ed25519Crypto.privateToX25519(ikPrivateEd))
|
||||
val ikPublicRemoteX = X25519Crypto.loadPublic(
|
||||
Ed25519Crypto.publicToX25519(ikPublicRemoteEd)
|
||||
)
|
||||
|
||||
// Generate ephemeral keypair
|
||||
val (ekPrivate, ekPublic) = X25519Crypto.generateKeypair()
|
||||
|
||||
// Four DH computations
|
||||
val dh1 = X25519Crypto.dh(ikPrivateX, spkRemote)
|
||||
val dh2 = X25519Crypto.dh(ekPrivate, ikPublicRemoteX)
|
||||
val dh3 = X25519Crypto.dh(ekPrivate, spkRemote)
|
||||
|
||||
var dhConcat = dh1 + dh2 + dh3
|
||||
if (opkRemote != null) {
|
||||
val dh4 = X25519Crypto.dh(ekPrivate, opkRemote)
|
||||
dhConcat += dh4
|
||||
}
|
||||
|
||||
// Derive shared secret
|
||||
val sharedSecret = HkdfUtils.derive(
|
||||
inputKey = dhConcat,
|
||||
salt = ZERO_SALT,
|
||||
info = HkdfUtils.X3DH_INFO.toByteArray(),
|
||||
length = 32,
|
||||
)
|
||||
|
||||
return X3DHResult(
|
||||
sharedSecret = sharedSecret,
|
||||
ephemeralPrivate = ekPrivate,
|
||||
ephemeralPublic = ekPublic,
|
||||
)
|
||||
} catch (e: CryptoException) {
|
||||
throw e
|
||||
} catch (e: Exception) {
|
||||
throw CryptoException.X3DHFailed("X3DH initiate failed", e)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Responder (Bob) side of X3DH.
|
||||
*
|
||||
* @param ikPrivateEd Bob's Ed25519 identity private key
|
||||
* @param spkPrivate Bob's signed pre-key private (X25519)
|
||||
* @param ikRemoteEd Alice's Ed25519 identity public key
|
||||
* @param ekRemote Alice's ephemeral public key (X25519)
|
||||
* @param opkPrivate Bob's one-time pre-key private (X25519, optional)
|
||||
* @return 32-byte shared secret
|
||||
*/
|
||||
fun respond(
|
||||
ikPrivateEd: Ed25519PrivateKeyParameters,
|
||||
spkPrivate: X25519PrivateKeyParameters,
|
||||
ikRemoteEd: Ed25519PublicKeyParameters,
|
||||
ekRemote: X25519PublicKeyParameters,
|
||||
opkPrivate: X25519PrivateKeyParameters? = null,
|
||||
): ByteArray {
|
||||
try {
|
||||
// Convert identity keys Ed25519 -> X25519
|
||||
val ikPrivateX = X25519Crypto.loadPrivate(Ed25519Crypto.privateToX25519(ikPrivateEd))
|
||||
val ikRemoteX = X25519Crypto.loadPublic(Ed25519Crypto.publicToX25519(ikRemoteEd))
|
||||
|
||||
// Mirror DH computations
|
||||
val dh1 = X25519Crypto.dh(spkPrivate, ikRemoteX)
|
||||
val dh2 = X25519Crypto.dh(ikPrivateX, ekRemote)
|
||||
val dh3 = X25519Crypto.dh(spkPrivate, ekRemote)
|
||||
|
||||
var dhConcat = dh1 + dh2 + dh3
|
||||
if (opkPrivate != null) {
|
||||
val dh4 = X25519Crypto.dh(opkPrivate, ekRemote)
|
||||
dhConcat += dh4
|
||||
}
|
||||
|
||||
return HkdfUtils.derive(
|
||||
inputKey = dhConcat,
|
||||
salt = ZERO_SALT,
|
||||
info = HkdfUtils.X3DH_INFO.toByteArray(),
|
||||
length = 32,
|
||||
)
|
||||
} catch (e: Exception) {
|
||||
throw CryptoException.X3DHFailed("X3DH respond failed", e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
data class SignedPreKey(
|
||||
val id: String,
|
||||
val privateKey: X25519PrivateKeyParameters,
|
||||
val publicKey: X25519PublicKeyParameters,
|
||||
val signature: ByteArray,
|
||||
) {
|
||||
override fun equals(other: Any?): Boolean {
|
||||
if (this === other) return true
|
||||
if (other !is SignedPreKey) return false
|
||||
return id == other.id
|
||||
}
|
||||
|
||||
override fun hashCode(): Int = id.hashCode()
|
||||
}
|
||||
|
||||
data class OneTimePreKey(
|
||||
val id: String,
|
||||
val privateKey: X25519PrivateKeyParameters,
|
||||
val publicKey: X25519PublicKeyParameters,
|
||||
)
|
||||
|
||||
data class X3DHResult(
|
||||
val sharedSecret: ByteArray,
|
||||
val ephemeralPrivate: X25519PrivateKeyParameters,
|
||||
val ephemeralPublic: X25519PublicKeyParameters,
|
||||
) {
|
||||
override fun equals(other: Any?): Boolean {
|
||||
if (this === other) return true
|
||||
if (other !is X3DHResult) return false
|
||||
return sharedSecret.contentEquals(other.sharedSecret)
|
||||
}
|
||||
|
||||
override fun hashCode(): Int = sharedSecret.contentHashCode()
|
||||
}
|
||||
Reference in New Issue
Block a user