Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package com.infomaniak.auth.lib

import com.infomaniak.auth.lib.internal.CryptoObjectsBuilder
import com.infomaniak.auth.lib.internal.KeyPairManagerImpl
import com.infomaniak.auth.lib.internal.KeyPairManager
import com.infomaniak.auth.lib.internal.models.PasskeysOptions
import com.infomaniak.auth.lib.internal.models.PubKeyCredParam
import com.infomaniak.auth.lib.internal.models.RelyingParty
Expand Down Expand Up @@ -59,7 +59,7 @@ class WebAuthnTest {

// Just getting the public key to generate RegisterPasskey object
val cryptoObjectsBuilder = CryptoObjectsBuilder()
val keyPairManager = KeyPairManagerImpl()
val keyPairManager = KeyPairManager()
val userId = 12345L
val keyIdAsByteArray = cryptoObjectsBuilder.getKeyIds().first
val keyIdAsString = cryptoObjectsBuilder.getKeyIds().second
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package com.infomaniak.auth.internal
package com.infomaniak.auth.lib.internal

import com.infomaniak.auth.lib.internal.KeyPairManagerImpl
import com.infomaniak.auth.lib.internal.utils.Xor
import kotlinx.coroutines.test.runTest
import kotlin.test.Test
Expand All @@ -28,10 +27,10 @@ class KeyPairManagerTest {

@Test
fun testKeyPairManager() {
val keyPairManager = KeyPairManagerImpl()
val keyPairManager = KeyPairManager()

runTest {
val userId = 12345
val userId = 12345L
val keyId = "keyId"
val error = keyPairManager.generateNewKey(userId, keyId)
assertNull(error)
Comment thread
LouisCAD marked this conversation as resolved.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,18 @@ import com.infomaniak.auth.lib.internal.utils.Xor
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.invoke
import kotlinx.coroutines.withContext
import kotlinx.io.IOException
import splitties.init.appCtx
import java.io.File
import java.nio.file.Files
import java.nio.file.attribute.BasicFileAttributes

internal actual class KeyPairManagerImpl : KeyPairManager {
internal actual fun createKeyPairManager(): KeyPairManager = KeyPairManagerAndroidImpl()

private class KeyPairManagerAndroidImpl : KeyPairManager() {

@Throws(Exception::class)
actual override suspend fun generateNewKey(userId: Long, keyId: String): Failure.KeyManagement.GenerationFailed? {
override suspend fun generateNewKey(userId: Long, keyId: String): Failure.KeyManagement.GenerationFailed? {
val keyPair = generateEcKeyPair().getOrElse {
return Failure.KeyManagement.GenerationFailed(it.toString())
}
Expand All @@ -37,7 +42,7 @@ internal actual class KeyPairManagerImpl : KeyPairManager {
return null
}

actual override suspend fun retrievePublicKey(
override suspend fun retrievePublicKey(
userId: Long,
keyId: String,
): Xor<ByteArray, Failure.KeyManagement.KeyExtractionFailed> = Dispatchers.IO {
Expand All @@ -47,7 +52,7 @@ internal actual class KeyPairManagerImpl : KeyPairManager {
}.getOrElse { Xor.Second(Failure.KeyManagement.KeyExtractionFailed(it.toString())) }
}

actual override suspend fun retrievePrivateKey(
override suspend fun retrievePrivateKey(
userId: Long,
keyId: String,
): Xor<ByteArray, Failure.KeyManagement.KeyExtractionFailed> = Dispatchers.IO {
Expand All @@ -57,21 +62,46 @@ internal actual class KeyPairManagerImpl : KeyPairManager {
}.getOrElse { Xor.Second(Failure.KeyManagement.KeyExtractionFailed(it.toString())) }
}

actual override suspend fun findKeyIdFor(predicate: (name: String) -> Boolean): String? {
override suspend fun getSortedKeyIds(matchOn: MatchOn): List<String> {
val files = withContext(Dispatchers.IO) {
appCtx.filesDir.listFiles()
} ?: return emptyList()
return buildList {
val predicate = matchOn.asFilterPredicate()
for (file in files) {
val fileName = file.name
if (predicate(file.name)) {
val fileTimestampMillis = Dispatchers.IO {
try {
Files.readAttributes(file.toPath(), BasicFileAttributes::class.java).creationTime().toMillis()
} catch (_: IOException) {
file.lastModified()
}
}
add(extractKeyIdFromFileName(fileName) to fileTimestampMillis)
}
Comment thread
LouisCAD marked this conversation as resolved.
}
}.sortedBy { (_, creationTime) ->
creationTime
}.map { (keyId, _) ->
keyId
}.distinct() // Private/public keys pairs have a common id, so we filter duplicates.
}

override suspend fun findKeyIdFor(matchOn: MatchOn): String? {
val predicate = matchOn.asFilterPredicate()
val userPassKey: File = withContext(Dispatchers.IO) {
appCtx.filesDir.listFiles()
}?.find {
predicate(it.name)
} ?: return null

val keyId = userPassKey.name.substring(
startIndex = userPassKey.name.indexOfFirst { it == '-' } + 1,
endIndex = userPassKey.name.indexOfLast { it == '-' }
)
return keyId
//TODO 2: Put keys into a dedicated dir
return extractKeyIdFromFileName(userPassKey.name)
}

actual override suspend fun deleteKeysMatching(predicate: (name: String) -> Boolean): Xor<Unit, Failure.KeyManagement.KeyNotFound> {
override suspend fun deleteKeysMatching(matchOn: MatchOn): Xor<Unit, Failure.KeyManagement.KeyNotFound> {
val predicate = matchOn.asFilterPredicate()
val keys = withContext(Dispatchers.IO) {
appCtx.filesDir.listFiles()
}?.filter {
Expand All @@ -83,6 +113,13 @@ internal actual class KeyPairManagerImpl : KeyPairManager {
return Xor.First(Unit)
}

override fun MatchOn.PasskeyId.asFilterPredicate() = { name: String -> "-$id-" in name }

private fun extractKeyIdFromFileName(name: String): String = name.substring(
startIndex = name.indexOfFirst { it == '-' } + 1,
endIndex = name.indexOfLast { it == '-' }
)

private suspend fun saveFileToFilesDir(fileName: String, key: ByteArray) = Dispatchers.IO {
val file = File(appCtx.filesDir, fileName)
file.writeBytes(key)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,13 @@
package com.infomaniak.auth.lib.internal

import com.infomaniak.auth.lib.internal.extensions.buildCFDictionary
import com.infomaniak.auth.lib.internal.extensions.get
import com.infomaniak.auth.lib.internal.extensions.isNullOrEmpty
import com.infomaniak.auth.lib.internal.extensions.set
import com.infomaniak.auth.lib.internal.extensions.size
import com.infomaniak.auth.lib.internal.extensions.toByteArray
import com.infomaniak.auth.lib.internal.extensions.toNSData
import com.infomaniak.auth.lib.internal.extensions.toNSDate
import com.infomaniak.auth.lib.internal.extensions.toNsData
import com.infomaniak.auth.lib.internal.extensions.tryIt
import com.infomaniak.auth.lib.internal.extensions.use
Expand All @@ -37,22 +41,21 @@ import kotlinx.cinterop.value
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.IO
import kotlinx.coroutines.invoke
import platform.CoreFoundation.CFArrayGetCount
import platform.CoreFoundation.CFArrayGetValueAtIndex
import platform.CoreFoundation.CFArrayRef
import platform.CoreFoundation.CFDataRef
import platform.CoreFoundation.CFDictionaryGetValue
import platform.CoreFoundation.CFDateRef
import platform.CoreFoundation.CFDictionaryRef
import platform.CoreFoundation.CFRelease
import platform.CoreFoundation.CFTypeRef
import platform.CoreFoundation.CFTypeRefVar
import platform.Foundation.timeIntervalSince1970
import platform.Security.SecItemCopyMatching
import platform.Security.SecItemDelete
import platform.Security.SecKeyCopyExternalRepresentation
import platform.Security.SecKeyCopyPublicKey
import platform.Security.SecKeyRef
import platform.Security.errSecSuccess
import platform.Security.kSecAttrApplicationTag
import platform.Security.kSecAttrCreationDate
import platform.Security.kSecAttrKeyClass
import platform.Security.kSecAttrKeyClassPrivate
import platform.Security.kSecAttrKeyType
Expand All @@ -64,17 +67,19 @@ import platform.Security.kSecMatchLimitAll
import platform.Security.kSecReturnAttributes
import platform.Security.kSecReturnRef

internal actual class KeyPairManagerImpl : KeyPairManager {
internal actual fun createKeyPairManager(): KeyPairManager = KeyPairManagerAppleImpl()

actual override suspend fun generateNewKey(
private class KeyPairManagerAppleImpl : KeyPairManager() {

override suspend fun generateNewKey(
userId: Long,
keyId: String,
): Failure.KeyManagement.GenerationFailed? = Dispatchers.IO {

val result = generateEcPrivateKeyInTheKeychain(
tag = "$userId-$keyId",
privateKeyPurposes = KeyPairManager.privateKeyPurposes,
publicKeyPurposes = KeyPairManager.publicKeyPurposes,
privateKeyPurposes = KeyPurposes.privateKeyDefaults,
publicKeyPurposes = KeyPurposes.publicKeyDefaults,
keyAccessGuard = KeyAccessGuard.Unguarded,
accessibility = KeyAccessibility.AfterFirstUnlock.ThisDeviceOnly,
)
Expand All @@ -85,7 +90,7 @@ internal actual class KeyPairManagerImpl : KeyPairManager {
}

@OptIn(ExperimentalForeignApi::class)
actual override suspend fun retrievePublicKey(
override suspend fun retrievePublicKey(
userId: Long,
keyId: String,
): Xor<ByteArray, Failure.KeyManagement.KeyExtractionFailed> = Dispatchers.IO {
Expand All @@ -107,7 +112,7 @@ internal actual class KeyPairManagerImpl : KeyPairManager {
}
}

actual override suspend fun retrievePrivateKey(
override suspend fun retrievePrivateKey(
userId: Long,
keyId: String
): Xor<ByteArray, Failure.KeyManagement.KeyExtractionFailed> {
Expand All @@ -125,43 +130,60 @@ internal actual class KeyPairManagerImpl : KeyPairManager {
}
}

override suspend fun getSortedKeyIds(matchOn: MatchOn): List<String> = Dispatchers.IO {
memScoped {
val resultsArray = getAllPrivateKeysQuery()
if (resultsArray.isNullOrEmpty()) return@memScoped emptyList()
val predicate = matchOn.asFilterPredicate()

buildList {
for (i in 0 until resultsArray.size) {
val item: CFDictionaryRef = resultsArray[i]
val tag = extractTagFromItem(item) ?: continue
val dateRef: CFDateRef = item[kSecAttrCreationDate]

if (predicate(tag)) {
add(extractKeyIdFromTag(tag) to dateRef.toNSDate())
}
Comment thread
LouisCAD marked this conversation as resolved.
}
}.sortedBy { (_, date) ->
date.timeIntervalSince1970
}.map { (keyId, _) ->
keyId
}.distinct() // Private/public keys pairs have a common id, so we filter duplicates.
}
}

@OptIn(BetaInteropApi::class)
actual override suspend fun findKeyIdFor(predicate: (name: String) -> Boolean): String? = Dispatchers.IO {
override suspend fun findKeyIdFor(matchOn: MatchOn): String? = Dispatchers.IO {
memScoped {
//TODO[ik-auth]: Test this code somehow.
val (resultsArray, count) = getAllPrivateKeysQuery()
val resultsArray = getAllPrivateKeysQuery()

if (resultsArray == null || count == 0) return@memScoped null
if (resultsArray.isNullOrEmpty()) return@memScoped null
val predicate = matchOn.asFilterPredicate()

for (i in 0 until count) {
val tag = extractTagFromItem(CFArrayGetValueAtIndex(resultsArray, i.toLong()))
for (i in 0 until resultsArray.size) {
val tag = extractTagFromItem(resultsArray[i]) ?: continue

if (tag != null && predicate(tag)) {
val keyId = tag.substring(
startIndex = tag.indexOfFirst { it == '-' } + 1,
endIndex = tag.indexOfLast { it == '-' }
)
return@memScoped keyId
}
if (predicate(tag)) return@memScoped extractKeyIdFromTag(tag)
}

return@memScoped null
}
}

actual override suspend fun deleteKeysMatching(
predicate: (name: String) -> Boolean
): Xor<Unit, Failure.KeyManagement.KeyNotFound> = Dispatchers.IO {
override suspend fun deleteKeysMatching(matchOn: MatchOn): Xor<Unit, Failure.KeyManagement.KeyNotFound> = Dispatchers.IO {
memScoped {
val (resultsArray, count) = getAllPrivateKeysQuery()
val resultsArray = getAllPrivateKeysQuery()

if (resultsArray == null || count == 0) {
if (resultsArray.isNullOrEmpty()) {
return@memScoped Xor.Second(Failure.KeyManagement.KeyNotFound("No keys found in Keychain"))
}
val predicate = matchOn.asFilterPredicate()

var hasDeletedAtLeastOneKey = false
for (i in 0 until count) {
val tag = extractTagFromItem(CFArrayGetValueAtIndex(resultsArray, i.toLong())) ?: continue
for (i in 0 until resultsArray.size) {
val tag = extractTagFromItem(resultsArray[i]) ?: continue

if (predicate(tag)) {
deleteKeyByTag(tag)
Expand All @@ -177,8 +199,15 @@ internal actual class KeyPairManagerImpl : KeyPairManager {
}
}

override fun MatchOn.PasskeyId.asFilterPredicate(): (String) -> Boolean {
val publicKeyEnd = "$id.pub"
return { name: String -> name.endsWith(id) || name.endsWith(publicKeyEnd) }
}

private fun extractKeyIdFromTag(tag: String): String = tag.substring(startIndex = tag.indexOfFirst { it == '-' } + 1)

@OptIn(BetaInteropApi::class, ExperimentalForeignApi::class)
private fun MemScope.getAllPrivateKeysQuery(): Pair<CFArrayRef?, Int> {
private fun MemScope.getAllPrivateKeysQuery(): CFArrayRef? {
val query = buildCFDictionary {
this[kSecClass] = kSecClassKey
this[kSecAttrKeyClass] = kSecAttrKeyClassPrivate
Expand All @@ -192,18 +221,17 @@ internal actual class KeyPairManagerImpl : KeyPairManager {
CFRelease(query)

return if (status == errSecSuccess && resultRef.value != null) {
defer { CFRelease(resultRef.value) }
@Suppress("unchecked_cast")
val resultsArray = resultRef.value as CFArrayRef
Pair(resultsArray, CFArrayGetCount(resultsArray).toInt())
resultsArray
} else {
Comment thread
LouisCAD marked this conversation as resolved.
Pair(null, 0)
null
}
}

@OptIn(BetaInteropApi::class)
private fun extractTagFromItem(item: CFTypeRef?): String? {
@Suppress("unchecked_cast")
val tagData = CFDictionaryGetValue(item as CFDictionaryRef, kSecAttrApplicationTag) as? CFDataRef ?: return null
private fun extractTagFromItem(item: CFDictionaryRef?): String? {
val tagData: CFDataRef = item[kSecAttrApplicationTag] ?: return null
return tagData.toNSData().toByteArray().decodeToString()
}

Expand All @@ -213,6 +241,7 @@ internal actual class KeyPairManagerImpl : KeyPairManager {
this[kSecAttrApplicationTag] = tag.toNsData()
}
SecItemDelete(deleteQuery)
CFRelease(deleteQuery)
}

@OptIn(ExperimentalForeignApi::class)
Expand Down
Loading
Loading