diff --git a/qbit-core/src/commonMain/kotlin/qbit/Conn.kt b/qbit-core/src/commonMain/kotlin/qbit/Conn.kt index 55547c27..ac30c123 100644 --- a/qbit-core/src/commonMain/kotlin/qbit/Conn.kt +++ b/qbit-core/src/commonMain/kotlin/qbit/Conn.kt @@ -24,7 +24,7 @@ import qbit.index.Indexer import qbit.index.InternalDb import qbit.index.RawEntity import qbit.ns.Namespace -import qbit.resolving.lastWriterWinsResolve +import qbit.resolving.crdtResolve import qbit.resolving.logsDiff import qbit.serialization.* import qbit.spi.Storage @@ -122,14 +122,14 @@ class QConn( } } - override suspend fun update(trxLog: TrxLog, newLog: TrxLog, newDb: InternalDb) { + override suspend fun update(trxLog: TrxLog, baseDb: InternalDb, newLog: TrxLog, newDb: InternalDb) { val (log, db) = if (hasConcurrentTrx(trxLog)) { - mergeLogs(trxLog, this.trxLog, newLog, newDb) + mergeLogs(trxLog, this.trxLog, newLog, baseDb, newDb) } else { newLog to newDb } - storage.overwrite(Namespace("refs")["head"], newLog.hash.bytes) + storage.overwrite(Namespace("refs")["head"], log.hash.bytes) this.trxLog = log this.db = db } @@ -141,6 +141,7 @@ class QConn( baseLog: TrxLog, committedLog: TrxLog, committingLog: TrxLog, + baseDb: InternalDb, newDb: InternalDb ): Pair { val logsDifference = logsDiff(baseLog, committedLog, committingLog, resolveNode) @@ -149,7 +150,7 @@ class QConn( .logAEntities() .toEavsList() val reconciliationEavs = logsDifference - .reconciliationEntities(lastWriterWinsResolve { db.attr(it) }) + .reconciliationEntities(crdtResolve(baseDb::pullEntity, db::attr)) .toEavsList() val mergedDb = newDb diff --git a/qbit-core/src/commonMain/kotlin/qbit/api/model/DataTypes.kt b/qbit-core/src/commonMain/kotlin/qbit/api/model/DataTypes.kt index ce35fdea..9527a331 100644 --- a/qbit-core/src/commonMain/kotlin/qbit/api/model/DataTypes.kt +++ b/qbit-core/src/commonMain/kotlin/qbit/api/model/DataTypes.kt @@ -21,6 +21,12 @@ import kotlin.reflect.KClass * - List */ +val scalarRange = 0u..31u +val listRange = 32u..63u +val pnCounterRange = 64u..95u +val registerRange = 96u..127u +val setRange = 128u..159u + @Suppress("UNCHECKED_CAST") sealed class DataType { @@ -31,14 +37,16 @@ sealed class DataType { private val values: Array> get() = arrayOf(QBoolean, QByte, QInt, QLong, QString, QBytes, QGid, QRef) - fun ofCode(code: Byte): DataType<*>? = - if (code <= 19) { - values.firstOrNull { it.code == code } - } else { - values.map { it.list() }.firstOrNull { it.code == code } - } + fun ofCode(code: Byte): DataType<*>? = when(code.toUByte()) { + in scalarRange -> values.firstOrNull { it.code == code } + in listRange -> ofCode((code.toUByte() - listRange.first).toByte())?.list() + in pnCounterRange -> ofCode((code.toUByte() - pnCounterRange.first).toByte())?.counter() + in registerRange -> ofCode((code.toUByte() - registerRange.first).toByte())?.register() + in setRange -> ofCode((code.toUByte() - setRange.first).toByte())?.set() + else -> null + } - fun ofValue(value: T?): DataType? = when (value) { + fun ofValue(value: T?): DataType? = when (value) { // TODO REFACTOR is Boolean -> QBoolean as DataType is Byte -> QByte as DataType is Int -> QInt as DataType @@ -46,20 +54,46 @@ sealed class DataType { is String -> QString as DataType is ByteArray -> QBytes as DataType is Gid -> QGid as DataType - is List<*> -> value.firstOrNull()?.let { ofValue(it)?.list() } as DataType + is List<*> -> value.firstOrNull()?.let { ofValue(it)?.list() } as DataType? else -> QRef as DataType } } + fun isScalar(): Boolean = code.toUByte() in scalarRange + fun list(): QList { // TODO: make types hierarchy: Type -> List | (Scalar -> (Ref | Value)) - require(!isList()) { "Nested lists is not allowed" } + require(this.isScalar()) { "Nested wrappers is not allowed" } return QList(this) } - fun isList(): Boolean = (code.toInt().and(32)) > 0 + fun isList(): Boolean = code.toUByte() in listRange - fun ref(): Boolean = this == QRef || this is QList<*> && this.itemsType == QRef + fun counter(): QCounter { + require(this is QByte || this is QInt || this is QLong) { "Only primitive number values are allowed in counters" } + return QCounter(this) + } + + fun isCounter(): Boolean = code.toUByte() in pnCounterRange + + fun register(): QRegister { + require(this.isScalar()) { "Nested wrappers is not allowed" } + return QRegister(this) + } + + fun isRegister(): Boolean = code.toUByte() in registerRange + + fun set(): QSet { + require(this.isScalar()) { "Nested wrappers is not allowed" } + return QSet(this) + } + + fun isSet(): Boolean = code.toUByte() in setRange + + fun ref(): Boolean = this == QRef || + this is QList<*> && this.itemsType == QRef || + this is QRegister<*> && this.itemsType == QRef || + this is QSet<*> && this.itemsType == QRef fun value(): Boolean = !ref() @@ -73,15 +107,35 @@ sealed class DataType { is QBytes -> ByteArray::class is QGid -> Gid::class is QList<*> -> this.itemsType.typeClass() + is QCounter<*> -> this.primitiveType.typeClass() + is QRegister<*> -> this.itemsType.typeClass() + is QSet<*> -> this.itemsType.typeClass() QRef -> Any::class } } - } data class QList(val itemsType: DataType) : DataType>() { - override val code = (32 + itemsType.code).toByte() + override val code = (listRange.first.toByte() + itemsType.code).toByte() + +} + +data class QCounter(val primitiveType: DataType) : DataType() { + + override val code = (pnCounterRange.first.toByte() + primitiveType.code).toByte() + +} + +data class QRegister(val itemsType: DataType) : DataType() { + + override val code = (registerRange.first.toByte() + itemsType.code).toByte() + +} + +data class QSet(val itemsType: DataType) : DataType>() { + + override val code = (setRange.first.toByte() + itemsType.code).toByte() } @@ -134,4 +188,4 @@ object QGid : DataType() { } fun isListOfVals(list: List?) = - list == null || list.isEmpty() || list.firstOrNull()?.let { DataType.ofValue(it)?.value() } ?: true \ No newline at end of file + list == null || list.isEmpty() || list.firstOrNull()?.let { DataType.ofValue(it)?.value() } ?: true // TODO REFACTOR \ No newline at end of file diff --git a/qbit-core/src/commonMain/kotlin/qbit/api/model/Register.kt b/qbit-core/src/commonMain/kotlin/qbit/api/model/Register.kt new file mode 100644 index 00000000..7677932b --- /dev/null +++ b/qbit-core/src/commonMain/kotlin/qbit/api/model/Register.kt @@ -0,0 +1,14 @@ +package qbit.api.model + +import kotlinx.serialization.Serializable + +@Serializable +class Register( + private var entries: List +) { + fun getValues(): List = entries + + fun setValue(t: T) { + entries = listOf(t) + } +} \ No newline at end of file diff --git a/qbit-core/src/commonMain/kotlin/qbit/factoring/EntityGraph.kt b/qbit-core/src/commonMain/kotlin/qbit/factoring/EntityGraph.kt index dabf60ec..29a91b63 100644 --- a/qbit-core/src/commonMain/kotlin/qbit/factoring/EntityGraph.kt +++ b/qbit-core/src/commonMain/kotlin/qbit/factoring/EntityGraph.kt @@ -32,11 +32,11 @@ internal data class EntityBuilder( return DetachedEntity(gid!!, attrValues) } - private fun resolveRefs(attrVallue: Any, resolve: (Any) -> Gid): Any { + private fun resolveRefs(attrValue: Any, resolve: (Any) -> Gid): Any { return when { - attrVallue is Ref -> resolve(attrVallue.obj) - attrVallue is List<*> && attrVallue.firstOrNull() is Ref -> (attrVallue as List).map { resolve(it.obj) } - else -> attrVallue + attrValue is Ref -> resolve(attrValue.obj) + attrValue is List<*> && attrValue.firstOrNull() is Ref -> (attrValue as List).map { resolve(it.obj) } + else -> attrValue } } diff --git a/qbit-core/src/commonMain/kotlin/qbit/factoring/serializatoin/SerializationFactorizer.kt b/qbit-core/src/commonMain/kotlin/qbit/factoring/serializatoin/SerializationFactorizer.kt index e1e4ba93..ab0b522f 100644 --- a/qbit-core/src/commonMain/kotlin/qbit/factoring/serializatoin/SerializationFactorizer.kt +++ b/qbit-core/src/commonMain/kotlin/qbit/factoring/serializatoin/SerializationFactorizer.kt @@ -14,10 +14,7 @@ import kotlinx.serialization.modules.SerializersModule import kotlinx.serialization.modules.SerializersModuleCollector import qbit.api.QBitException import qbit.api.gid.Gid -import qbit.api.model.Attr -import qbit.api.model.Eav -import qbit.api.model.Entity -import qbit.api.model.Tombstone +import qbit.api.model.* import qbit.api.tombstone import qbit.collections.IdentityMap import qbit.factoring.* @@ -99,6 +96,18 @@ internal class EntityEncoder( ValueKind.REF_LIST -> { serializeRefList(value as Iterable) } + ValueKind.VALUE_REGISTER -> { + (value as Register).getValues() + } + ValueKind.REF_REGISTER -> { + serializeRefList((value as Register).getValues()) + } + ValueKind.VALUE_SET -> { + (value as Set).toList() + } + ValueKind.REF_SET -> { + serializeRefList(value as Set) + } } val fieldPointer = Pointer( @@ -185,7 +194,7 @@ internal class EntityEncoder( enum class ValueKind { - SCALAR_VALUE, SCALAR_REF, VALUE_LIST, REF_LIST; + SCALAR_VALUE, SCALAR_REF, VALUE_LIST, REF_LIST, VALUE_REGISTER, REF_REGISTER, VALUE_SET, REF_SET; companion object { fun of(descriptor: SerialDescriptor, index: Int, value: Any): ValueKind { @@ -194,7 +203,10 @@ enum class ValueKind { isScalarValue(value) -> { SCALAR_VALUE } - isScalarRef(elementDescriptor) -> { + isScalarRef( + elementDescriptor, + value + ) -> { SCALAR_REF } isValueList( @@ -209,6 +221,30 @@ enum class ValueKind { ) -> { REF_LIST } + isValueRegister( + elementDescriptor, + value + ) -> { + VALUE_REGISTER + } + isRefRegister( + elementDescriptor, + value + ) -> { + REF_REGISTER + } + isValueSet( + elementDescriptor, + value + ) -> { + VALUE_SET + } + isRefSet( + elementDescriptor, + value + ) -> { + REF_SET + } else -> { throw AssertionError("Writing primitive via encodeSerializableElement") } @@ -219,8 +255,8 @@ enum class ValueKind { // other primitive values are encoded directly via encodeXxxElement value is Gid || value is ByteArray - private fun isScalarRef(elementDescriptor: SerialDescriptor) = - elementDescriptor.kind == StructureKind.CLASS + private fun isScalarRef(elementDescriptor: SerialDescriptor, value: Any) = + elementDescriptor.kind == StructureKind.CLASS && value !is Register<*> private fun isValueList(elementDescriptor: SerialDescriptor, value: Any) = elementDescriptor.kind == StructureKind.LIST && @@ -231,6 +267,21 @@ enum class ValueKind { private fun isRefList(elementDescriptor: SerialDescriptor, value: Any) = elementDescriptor.kind == StructureKind.LIST && value is List<*> + private fun isValueRegister(elementDescriptor: SerialDescriptor, value: Any) = + value is Register<*> && //TODO DEDUPLICATE + (elementDescriptor.getElementDescriptor(0).getElementDescriptor(0).kind is PrimitiveKind || + elementDescriptor.getElementDescriptor(0).getElementDescriptor(0).kind == StructureKind.LIST) // ByteArray + + private fun isRefRegister(elementDescriptor: SerialDescriptor, value: Any) = // TODO REFACTOR + value is Register<*> && elementDescriptor.getElementDescriptor(0).getElementDescriptor(0).kind is StructureKind.CLASS + + private fun isValueSet(elementDescriptor: SerialDescriptor, value: Any) = + value is Set<*> && + (elementDescriptor.getElementDescriptor(0).kind is PrimitiveKind || + elementDescriptor.getElementDescriptor(0).kind == StructureKind.LIST) // ByteArray + + private fun isRefSet(elementDescriptor: SerialDescriptor, value: Any) = + value is Set<*> && elementDescriptor.getElementDescriptor(0).kind is StructureKind.CLASS } } diff --git a/qbit-core/src/commonMain/kotlin/qbit/index/IndexDb.kt b/qbit-core/src/commonMain/kotlin/qbit/index/IndexDb.kt index 96e79231..3ce8b9d5 100644 --- a/qbit-core/src/commonMain/kotlin/qbit/index/IndexDb.kt +++ b/qbit-core/src/commonMain/kotlin/qbit/index/IndexDb.kt @@ -50,9 +50,9 @@ class IndexDb( val attrValues = rawEntity.entries.map { val attr = schema[it.key] require(attr != null) { "There is no attribute with name ${it.key}" } - require(attr.list || it.value.size == 1) { "Corrupted ${attr.name} of $gid - it is scalar, but multiple values has been found: ${it.value}" } + require(attr.list || DataType.ofCode(attr.type)!!.isRegister() || DataType.ofCode(attr.type)!!.isSet() || it.value.size == 1) { "Corrupted ${attr.name} of $gid - it is scalar, but multiple values has been found: ${it.value}" } val value = - if (attr.list) it.value.map { e -> fixNumberType(attr, e) } + if (attr.list || DataType.ofCode(attr.type)!!.isRegister() || DataType.ofCode(attr.type)!!.isSet()) it.value.map { e -> fixNumberType(attr, e) } else fixNumberType(attr, it.value[0]) attr to value } diff --git a/qbit-core/src/commonMain/kotlin/qbit/resolving/ConflictResolving.kt b/qbit-core/src/commonMain/kotlin/qbit/resolving/ConflictResolving.kt index 4d3bdbc8..b296f7b7 100644 --- a/qbit-core/src/commonMain/kotlin/qbit/resolving/ConflictResolving.kt +++ b/qbit-core/src/commonMain/kotlin/qbit/resolving/ConflictResolving.kt @@ -2,10 +2,9 @@ package qbit.resolving import kotlinx.coroutines.flow.toList import qbit.api.Instances +import qbit.api.QBitException import qbit.api.gid.Gid -import qbit.api.model.Attr -import qbit.api.model.Eav -import qbit.api.model.Hash +import qbit.api.model.* import qbit.index.RawEntity import qbit.serialization.* import qbit.trx.TrxLog @@ -72,6 +71,67 @@ internal fun lastWriterWinsResolve(resolveAttrName: (String) -> Attr?): (Li } } +internal fun crdtResolve( + resolveEntity: (Gid) -> StoredEntity?, + resolveAttrName: (String) -> Attr? +): (List, List) -> List = { eavsFromA, eavsFromB -> + require(eavsFromA.isNotEmpty()) { "eavsFromA should be not empty" } + require(eavsFromB.isNotEmpty()) { "eavsFromB should be not empty" } + + val gid = eavsFromA[0].eav.gid + val attr = resolveAttrName(eavsFromA[0].eav.attr) + ?: throw IllegalArgumentException("Cannot resolve ${eavsFromA[0].eav.attr}") + + when { + // temporary dirty hack until crdt counter or custom resolution strategy support is implemented + attr == Instances.nextEid -> listOf((eavsFromA + eavsFromB).maxByOrNull { it.eav.value as Int }!!.eav) + attr.list -> (eavsFromA + eavsFromB).map { it.eav }.distinct() + DataType.ofCode(attr.type)!!.isCounter() -> { + val latestFromA = eavsFromA.maxByOrNull { it.timestamp }!!.eav.value + val latestFromB = eavsFromB.maxByOrNull { it.timestamp }!!.eav.value + val previous = resolveEntity(gid)?.tryGet(attr) + + listOf( + if (previous != null) + Eav( + eavsFromA[0].eav.gid, + eavsFromA[0].eav.attr, + if (previous is Byte && latestFromA is Byte && latestFromB is Byte) latestFromA + latestFromB - previous + else if (previous is Int && latestFromA is Int && latestFromB is Int) latestFromA + latestFromB - previous + else if (previous is Long && latestFromA is Long && latestFromB is Long) latestFromA + latestFromB - previous + else throw QBitException("Unexpected counter value type for eav with gid=$gid, attr=$attr") + ) + else + Eav( + eavsFromA[0].eav.gid, + eavsFromA[0].eav.attr, + if (latestFromA is Byte && latestFromB is Byte) latestFromA + latestFromB + else if (latestFromA is Int && latestFromB is Int) latestFromA + latestFromB + else if (latestFromA is Long && latestFromB is Long) latestFromA + latestFromB + else throw QBitException("Unexpected counter value type for eav with gid=$gid, attr=$attr") + ) + ) + } + DataType.ofCode(attr.type)!!.isRegister() -> { + val latestFromA = + eavsFromA.maxOf { it.timestamp }.let { timestamp -> eavsFromA.filter { it.timestamp == timestamp } } + val latestFromB = + eavsFromB.maxOf { it.timestamp }.let { timestamp -> eavsFromB.filter { it.timestamp == timestamp } } + + latestFromA.map { it.eav } + latestFromB.map { it.eav } + } + DataType.ofCode(attr.type)!!.isSet() -> { + val latestFromA = + eavsFromA.maxOf { it.timestamp }.let { timestamp -> eavsFromA.filter { it.timestamp == timestamp } } + val latestFromB = + eavsFromB.maxOf { it.timestamp }.let { timestamp -> eavsFromB.filter { it.timestamp == timestamp } } + + (latestFromA.map { it.eav } + latestFromB.map { it.eav }).distinctBy { it.value } + } + else -> listOf((eavsFromA + eavsFromB).maxByOrNull { it.timestamp }!!.eav) + } +} + internal fun findBaseNode(node1: Node, node2: Node, nodesDepth: Map): Node { return when { node1 == node2 -> node1 diff --git a/qbit-core/src/commonMain/kotlin/qbit/schema/SchemaDsl.kt b/qbit-core/src/commonMain/kotlin/qbit/schema/SchemaDsl.kt index 384abb6e..505d439b 100644 --- a/qbit-core/src/commonMain/kotlin/qbit/schema/SchemaDsl.kt +++ b/qbit-core/src/commonMain/kotlin/qbit/schema/SchemaDsl.kt @@ -24,7 +24,7 @@ class SchemaBuilder(private val serialModule: SerializersModule) { ?: throw QBitException("Cannot find descriptor for $type") val eb = EntityBuilder(descr) eb.body() - attrs.addAll(schemaFor(descr, eb.uniqueProps)) + attrs.addAll(schemaFor(descr, eb.uniqueProps, eb.counters)) } } @@ -33,6 +33,8 @@ class EntityBuilder(private val descr: SerialDescriptor) { internal val uniqueProps = HashSet() + internal val counters = HashSet() + fun uniqueInt(prop: KProperty1) { uniqueAttr(prop) } @@ -42,40 +44,69 @@ class EntityBuilder(private val descr: SerialDescriptor) { } private fun uniqueAttr(prop: KProperty1) { + uniqueProps.add(getAttrName(prop)) + } + + fun byteCounter(prop: KProperty1) { + counter(prop) + } + + fun intCounter(prop: KProperty1) { + counter(prop) + } + + fun longCounter(prop: KProperty1) { + counter(prop) + } + + private fun counter(prop: KProperty1) { + counters.add(getAttrName(prop)) + } + + private fun getAttrName(prop: KProperty1): String { val (idx, _) = descr.elementNames .withIndex().firstOrNull { (_, name) -> name == prop.name } ?: throw QBitException("Cannot find attr for ${prop.name} in $descr") - uniqueProps.add(AttrName(descr, idx).asString()) + return AttrName(descr, idx).asString() } } -fun schemaFor(rootDesc: SerialDescriptor, unique: Set = emptySet()): List> { +fun schemaFor(rootDesc: SerialDescriptor, unique: Set = emptySet(), counters: Set = emptySet()): List> { return rootDesc.elementDescriptors .withIndex() .filter { rootDesc.getElementName(it.index) !in setOf("id", "gid") } .map { (idx, desc) -> - val dataType = DataType.of(desc) val attr = AttrName(rootDesc, idx).asString() + val dataType = if (attr in counters) DataType.of(desc).counter() else DataType.of(desc) Attr(null, attr, dataType.code, attr in unique, dataType.isList()) } } private fun DataType.Companion.of(desc: SerialDescriptor): DataType<*> = when (desc.kind) { - StructureKind.CLASS -> QRef + StructureKind.CLASS -> { + when (desc.serialName) { + "qbit.api.model.Register" -> DataType.of(desc.getElementDescriptor(0).getElementDescriptor(0)).register() + else -> QRef + } + } StructureKind.LIST -> { val listElementDesc = desc.getElementDescriptor(0) - when (listElementDesc.kind) { - PrimitiveKind.BYTE -> { - when (desc.serialName) { - "kotlin.ByteArray", "kotlin.ByteArray?" -> QBytes - "kotlin.collections.ArrayList", "kotlin.collections.ArrayList?" -> QByte.list() - else -> throw AssertionError("Unexpected descriptor: ${desc.serialName}") + if(desc.serialName == "kotlin.collections.LinkedHashSet" || desc.serialName == "kotlin.collections.LinkedHashSet?") { + DataType.of(listElementDesc).set() + } else { + when (listElementDesc.kind) { + PrimitiveKind.BYTE -> { + when (desc.serialName) { + "kotlin.ByteArray", "kotlin.ByteArray?" -> QBytes + "kotlin.collections.ArrayList", "kotlin.collections.ArrayList?" -> QByte.list() + else -> throw AssertionError("Unexpected descriptor: ${desc.serialName}") + } } + StructureKind.LIST -> QBytes.list() + else -> DataType.of(listElementDesc).list() } - StructureKind.LIST -> QBytes.list() - else -> DataType.of(listElementDesc).list() } } PrimitiveKind.STRING -> QString diff --git a/qbit-core/src/commonMain/kotlin/qbit/serialization/Simple.kt b/qbit-core/src/commonMain/kotlin/qbit/serialization/Simple.kt index f198ee6f..622560d8 100644 --- a/qbit-core/src/commonMain/kotlin/qbit/serialization/Simple.kt +++ b/qbit-core/src/commonMain/kotlin/qbit/serialization/Simple.kt @@ -176,7 +176,7 @@ internal fun deserialize(ins: Input): Any { private fun readMark(ins: Input, expectedMark: DataType): Any { return when (expectedMark) { QBoolean -> (ins.readByte() == 1.toByte()) as T - QByte, QInt, QLong -> readLong(ins) as T + QByte, QInt, QLong, is QCounter<*> -> readLong(ins) as T QBytes -> readLong(ins).let { count -> readBytes(ins, count.toInt()) as T @@ -187,7 +187,7 @@ private fun readMark(ins: Input, expectedMark: DataType): Any { } QGid -> Gid(readLong(ins)) as T QRef -> throw AssertionError("Should never happen") - is QList<*> -> throw AssertionError("Should never happen") + is QList<*>, is QRegister<*>, is QSet<*> -> throw AssertionError("Should never happen") } } diff --git a/qbit-core/src/commonMain/kotlin/qbit/trx/CommitHandler.kt b/qbit-core/src/commonMain/kotlin/qbit/trx/CommitHandler.kt index 830fc897..2f8264c2 100644 --- a/qbit-core/src/commonMain/kotlin/qbit/trx/CommitHandler.kt +++ b/qbit-core/src/commonMain/kotlin/qbit/trx/CommitHandler.kt @@ -5,6 +5,6 @@ import qbit.index.InternalDb internal interface CommitHandler { - suspend fun update(trxLog: TrxLog, newLog: TrxLog, newDb: InternalDb) + suspend fun update(trxLog: TrxLog, baseDb: InternalDb, newLog: TrxLog, newDb: InternalDb) } \ No newline at end of file diff --git a/qbit-core/src/commonMain/kotlin/qbit/trx/Trx.kt b/qbit-core/src/commonMain/kotlin/qbit/trx/Trx.kt index f598b549..3c70cbed 100644 --- a/qbit-core/src/commonMain/kotlin/qbit/trx/Trx.kt +++ b/qbit-core/src/commonMain/kotlin/qbit/trx/Trx.kt @@ -64,8 +64,7 @@ internal class QTrx( val instance = factor(inst.copy(nextEid = gids.next().eid), curDb::attr, EmptyIterator) val newLog = trxLog.append(factsBuffer + instance) try { - base = curDb.with(instance) - commitHandler.update(trxLog, newLog, base) + commitHandler.update(trxLog, base, newLog, curDb.with(instance)) factsBuffer.clear() } catch (e: Throwable) { // todo clean up @@ -92,10 +91,10 @@ fun Entity.toFacts(): Collection = val type = DataType.ofCode(attr.type)!! @Suppress("UNCHECKED_CAST") when { - type.value() && !attr.list -> listOf(valToFacts(gid, attr, value)) - type.value() && attr.list -> listToFacts(gid, attr, value as List) - type.ref() && !attr.list -> listOf(refToFacts(gid, attr, value)) - type.ref() && attr.list -> refListToFacts(gid, attr, value as List) + type.value() && !(attr.list || DataType.ofCode(attr.type)!!.isRegister() || DataType.ofCode(attr.type)!!.isSet()) -> listOf(valToFacts(gid, attr, value)) + type.value() && (attr.list || DataType.ofCode(attr.type)!!.isRegister() || DataType.ofCode(attr.type)!!.isSet()) -> listToFacts(gid, attr, value as List) + type.ref() && !(attr.list || DataType.ofCode(attr.type)!!.isRegister() || DataType.ofCode(attr.type)!!.isSet()) -> listOf(refToFacts(gid, attr, value)) + type.ref() && (attr.list || DataType.ofCode(attr.type)!!.isRegister() || DataType.ofCode(attr.type)!!.isSet()) -> refListToFacts(gid, attr, value as List) else -> throw AssertionError("Unexpected attr kind: $attr") } } diff --git a/qbit-core/src/commonMain/kotlin/qbit/trx/Validation.kt b/qbit-core/src/commonMain/kotlin/qbit/trx/Validation.kt index 5b6d4bc3..f6a4d737 100644 --- a/qbit-core/src/commonMain/kotlin/qbit/trx/Validation.kt +++ b/qbit-core/src/commonMain/kotlin/qbit/trx/Validation.kt @@ -3,6 +3,7 @@ package qbit.trx import qbit.api.QBitException import qbit.api.db.attrIs import qbit.api.model.Attr +import qbit.api.model.DataType import qbit.api.model.Eav import qbit.index.InternalDb @@ -39,7 +40,7 @@ fun validate(db: InternalDb, facts: List, newAttrs: List> = emptyLi // check that scalar attrs has single fact facts.groupBy { it.gid to it.attr } - .filter { !factAttrs.getValue(it.key.second)!!.list } + .filter { factAttrs.getValue(it.key.second)!!.let { attr -> !attr.list && !DataType.ofCode(attr.type)!!.isRegister() && !DataType.ofCode(attr.type)!!.isSet() } } .forEach { if (it.value.size > 1) { throw QBitException("Duplicate facts $it for scalar attribute: ${it.value}") diff --git a/qbit-core/src/commonMain/kotlin/qbit/typing/ListDecoder.kt b/qbit-core/src/commonMain/kotlin/qbit/typing/ListDecoder.kt new file mode 100644 index 00000000..4a5fb22e --- /dev/null +++ b/qbit-core/src/commonMain/kotlin/qbit/typing/ListDecoder.kt @@ -0,0 +1,132 @@ +package qbit.typing + +import kotlinx.serialization.DeserializationStrategy +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.descriptors.PrimitiveKind +import kotlinx.serialization.descriptors.SerialDescriptor +import kotlinx.serialization.descriptors.StructureKind +import kotlinx.serialization.encoding.CompositeDecoder +import kotlinx.serialization.encoding.Decoder +import kotlinx.serialization.modules.SerializersModule +import qbit.api.QBitException +import qbit.api.gid.Gid +import qbit.api.model.Attr +import qbit.api.model.StoredEntity +import qbit.factoring.serializatoin.AttrName + +@Suppress("UNCHECKED_CAST") +class ListDecoder( + val schema: (String) -> Attr<*>?, + val entity: StoredEntity, + private val elements: List, + override val serializersModule: SerializersModule, + private val cache: HashMap +) : StubDecoder() { + + private var isRefList = false + + private var indexCounter = 0 + + override fun beginStructure(descriptor: SerialDescriptor): CompositeDecoder { + val elementDescriptor = descriptor.getElementDescriptor(0) + val elementKind = elementDescriptor.kind + + isRefList = when { + isValueAttr(elementDescriptor) -> false + isRefAttr(elementDescriptor) -> true + else -> throw QBitException("$elementKind not yet supported") + } + + return this + } + + override fun decodeNullableSerializableElement( + descriptor: SerialDescriptor, + index: Int, + deserializer: DeserializationStrategy, + previousValue: T? + ): T? { + val element = elements[index] + return if(!isRefList) { + element as T? + } else { + decodeReferred(element as Gid, deserializer) as T? + } + } + + override fun decodeSerializableElement( + descriptor: SerialDescriptor, + index: Int, + deserializer: DeserializationStrategy, + previousValue: T?, + ): T { + return decodeNullableSerializableElement(descriptor, index, deserializer as DeserializationStrategy) as T + } + + private fun isValueAttr(elementDescriptor: SerialDescriptor): Boolean { + val listElementsKind = elementDescriptor.kind + return listElementsKind is PrimitiveKind || + listElementsKind is StructureKind.LIST // ByteArrays + } + + private fun isRefAttr(elementDescriptor: SerialDescriptor): Boolean { + return elementDescriptor.kind is StructureKind.CLASS + } + + private fun decodeReferred(gid: Gid, deserializer: DeserializationStrategy<*>): Any? { + val referee = entity.pull(gid) ?: throw QBitException("Dangling ref: $gid") + val decoder = EntityDecoder(schema, referee, serializersModule) + return cache.getOrPut(gid) { deserializer.deserialize(decoder) } + } + + private fun decodeElement(descriptor: SerialDescriptor, index: Int): T { + val attrName = AttrName(descriptor, index).asString() + return entity[schema(attrName) as Attr] + } + + override fun decodeElementIndex(descriptor: SerialDescriptor): Int { + return if (indexCounter < elements.size) indexCounter++ else CompositeDecoder.DECODE_DONE + } + + @ExperimentalSerializationApi + override fun decodeInline(inlineDescriptor: SerialDescriptor): Decoder { + return this + } + + @ExperimentalSerializationApi + override fun decodeInlineElement(descriptor: SerialDescriptor, index: Int): Decoder { + return this + } + + override fun decodeBooleanElement(descriptor: SerialDescriptor, index: Int): Boolean { + return decodeElement(descriptor, index) + } + + override fun decodeByteElement(descriptor: SerialDescriptor, index: Int): Byte { + return decodeElement(descriptor, index) + } + + override fun decodeCharElement(descriptor: SerialDescriptor, index: Int): Char { + return decodeElement(descriptor, index) + } + + override fun decodeDoubleElement(descriptor: SerialDescriptor, index: Int): Double { + return decodeElement(descriptor, index) + } + + override fun decodeFloatElement(descriptor: SerialDescriptor, index: Int): Float { + return decodeElement(descriptor, index) + } + + override fun decodeIntElement(descriptor: SerialDescriptor, index: Int): Int { + return decodeElement(descriptor, index) + } + + override fun decodeLongElement(descriptor: SerialDescriptor, index: Int): Long { + return decodeElement(descriptor, index) + } + + override fun decodeStringElement(descriptor: SerialDescriptor, index: Int): String { + return decodeElement(descriptor, index) + } +} \ No newline at end of file diff --git a/qbit-core/src/commonMain/kotlin/qbit/typing/RegisterDecoder.kt b/qbit-core/src/commonMain/kotlin/qbit/typing/RegisterDecoder.kt new file mode 100644 index 00000000..f62b4790 --- /dev/null +++ b/qbit-core/src/commonMain/kotlin/qbit/typing/RegisterDecoder.kt @@ -0,0 +1,102 @@ +package qbit.typing + +import kotlinx.serialization.DeserializationStrategy +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.descriptors.SerialDescriptor +import kotlinx.serialization.encoding.CompositeDecoder +import kotlinx.serialization.encoding.Decoder +import kotlinx.serialization.modules.SerializersModule +import qbit.api.gid.Gid +import qbit.api.model.Attr +import qbit.api.model.StoredEntity +import qbit.factoring.serializatoin.AttrName + +@Suppress("UNCHECKED_CAST") +class RegisterDecoder( + val schema: (String) -> Attr<*>?, + val entity: StoredEntity, + private val elements: List, + override val serializersModule: SerializersModule, + private val cache: HashMap +) : StubDecoder() { + private var listDecoded = false + + override fun decodeNullableSerializableElement( + descriptor: SerialDescriptor, + index: Int, + deserializer: DeserializationStrategy, + previousValue: T? + ): T? { + val decoder = ListDecoder(schema, entity, elements, serializersModule, cache) + return deserializer.deserialize(decoder) + } + + override fun decodeSerializableElement( + descriptor: SerialDescriptor, + index: Int, + deserializer: DeserializationStrategy, + previousValue: T?, + ): T { + return decodeNullableSerializableElement(descriptor, index, deserializer as DeserializationStrategy) as T + } + + private fun decodeElement(descriptor: SerialDescriptor, index: Int): T { + val attrName = AttrName(descriptor, index).asString() + return entity[schema(attrName) as Attr] + } + + override fun decodeElementIndex(descriptor: SerialDescriptor): Int { + if(listDecoded) { + return CompositeDecoder.DECODE_DONE + } else { + listDecoded = true + return 0 + } + } + + override fun beginStructure(descriptor: SerialDescriptor): CompositeDecoder { + return this + } + + @ExperimentalSerializationApi + override fun decodeInline(inlineDescriptor: SerialDescriptor): Decoder { + return this + } + + @ExperimentalSerializationApi + override fun decodeInlineElement(descriptor: SerialDescriptor, index: Int): Decoder { + return this + } + + override fun decodeBooleanElement(descriptor: SerialDescriptor, index: Int): Boolean { + return decodeElement(descriptor, index) + } + + override fun decodeByteElement(descriptor: SerialDescriptor, index: Int): Byte { + return decodeElement(descriptor, index) + } + + override fun decodeCharElement(descriptor: SerialDescriptor, index: Int): Char { + return decodeElement(descriptor, index) + } + + override fun decodeDoubleElement(descriptor: SerialDescriptor, index: Int): Double { + return decodeElement(descriptor, index) + } + + override fun decodeFloatElement(descriptor: SerialDescriptor, index: Int): Float { + return decodeElement(descriptor, index) + } + + override fun decodeIntElement(descriptor: SerialDescriptor, index: Int): Int { + return decodeElement(descriptor, index) + } + + override fun decodeLongElement(descriptor: SerialDescriptor, index: Int): Long { + return decodeElement(descriptor, index) + } + + override fun decodeStringElement(descriptor: SerialDescriptor, index: Int): String { + return decodeElement(descriptor, index) + } +} \ No newline at end of file diff --git a/qbit-core/src/commonMain/kotlin/qbit/typing/SerializationTyping.kt b/qbit-core/src/commonMain/kotlin/qbit/typing/SerializationTyping.kt index a7d22cbc..6b0aaf30 100644 --- a/qbit-core/src/commonMain/kotlin/qbit/typing/SerializationTyping.kt +++ b/qbit-core/src/commonMain/kotlin/qbit/typing/SerializationTyping.kt @@ -2,9 +2,7 @@ package qbit.typing import kotlinx.serialization.DeserializationStrategy import kotlinx.serialization.ExperimentalSerializationApi -import kotlinx.serialization.descriptors.PrimitiveKind -import kotlinx.serialization.descriptors.SerialDescriptor -import kotlinx.serialization.descriptors.StructureKind +import kotlinx.serialization.descriptors.* import kotlinx.serialization.encoding.CompositeDecoder import kotlinx.serialization.encoding.CompositeDecoder.Companion.DECODE_DONE import kotlinx.serialization.encoding.Decoder @@ -12,6 +10,7 @@ import kotlinx.serialization.modules.SerializersModule import qbit.api.QBitException import qbit.api.gid.Gid import qbit.api.model.Attr +import qbit.api.model.DataType import qbit.api.model.StoredEntity import qbit.factoring.serializatoin.AttrName import kotlin.reflect.KClass @@ -71,14 +70,22 @@ class EntityDecoder( } } - if (descriptor.kind == StructureKind.LIST && elementKind == StructureKind.CLASS) { - val decoder = EntityDecoder(schema, entity, serializersModule) - return deserializer.deserialize(decoder) - } - val attrName = AttrName(descriptor, index).asString() val attr: Attr = schema(attrName) as Attr? ?: throw QBitException("Corrupted entity $entity, there is no attr $attrName in schema") + val dataType = DataType.ofCode(attr.type)!! + + if(dataType.isList() || dataType.isSet()) { + val elements = entity.tryGet(attr) ?: return null // TODO CHECK NULLABILITY + val decoder = ListDecoder(schema, entity, elements as List, serializersModule, cache) + return deserializer.deserialize(decoder) + } + + if(dataType.isRegister()) { + val elements = entity.tryGet(attr) ?: return null + val decoder = RegisterDecoder(schema, entity, elements as List, serializersModule, cache) + return deserializer.deserialize(decoder) + } return when { isValueAttr(elementDescriptor) -> entity.tryGet(attr) @@ -95,50 +102,28 @@ class EntityDecoder( private fun decodeReferred( elementDescriptor: SerialDescriptor, attrName: String, - gids: Any?, + gid: Gid?, deserializer: DeserializationStrategy, ): Any? { when { - gids == null && elementDescriptor.isNullable -> return null - gids == null && !elementDescriptor.isNullable -> throw QBitException("Corrupted entity: $entity, no value for $attrName") + gid == null && elementDescriptor.isNullable -> return null + gid == null && !elementDescriptor.isNullable -> throw QBitException("Corrupted entity: $entity, no value for $attrName") } - check(gids != null) + check(gid != null) - val sureGids = when (gids) { - is Gid -> listOf(gids) - is List<*> -> gids as List - else -> throw AssertionError("Unexpected gids: $gids") - } - - val referreds = sureGids.map { - val referee = entity.pull(it) ?: throw QBitException("Dangling ref: $it") - val decoder = EntityDecoder(schema, referee, serializersModule) - val res = cache.getOrPut(it, { deserializer.deserialize(decoder) }) - if (res is List<*>) { - res[0] as T - } else { - res as T - } - } - - return when (elementDescriptor.kind) { - is StructureKind.CLASS -> referreds[0] - is StructureKind.LIST -> referreds - else -> throw AssertionError("Unexpected kind: ${elementDescriptor.kind}") - } + val referee = entity.pull(gid) ?: throw QBitException("Dangling ref: $gid") + val decoder = EntityDecoder(schema, referee, serializersModule) + return cache.getOrPut(gid) { deserializer.deserialize(decoder) } } private fun isValueAttr(elementDescriptor: SerialDescriptor): Boolean { val elementKind = elementDescriptor.kind val listElementsKind = elementDescriptor.takeIf { it.kind is StructureKind.LIST }?.getElementDescriptor(0)?.kind - return elementKind is PrimitiveKind || listElementsKind is PrimitiveKind || - listElementsKind is StructureKind.LIST // List of ByteArrays + return elementKind is PrimitiveKind || listElementsKind is PrimitiveKind } private fun isRefAttr(elementDescriptor: SerialDescriptor): Boolean { - val elementKind = elementDescriptor.kind - val listElementsKind = elementDescriptor.takeIf { it.kind is StructureKind.LIST }?.getElementDescriptor(0)?.kind - return elementKind is StructureKind.CLASS || listElementsKind is StructureKind.CLASS + return elementDescriptor.kind is StructureKind.CLASS } override fun decodeSerializableElement( diff --git a/qbit-core/src/commonTest/kotlin/qbit/ConnTest.kt b/qbit-core/src/commonTest/kotlin/qbit/ConnTest.kt index 29fb3db2..d945bb23 100644 --- a/qbit-core/src/commonTest/kotlin/qbit/ConnTest.kt +++ b/qbit-core/src/commonTest/kotlin/qbit/ConnTest.kt @@ -63,7 +63,7 @@ class ConnTest { ) val newLog = FakeTrxLog(storedLeaf.hash) - conn.update(conn.trxLog, newLog, EmptyDb) + conn.update(conn.trxLog, EmptyDb, newLog, EmptyDb) assertArrayEquals(newLog.hash.bytes, storage.load(Namespace("refs")["head"])) } diff --git a/qbit-core/src/commonTest/kotlin/qbit/FakeConn.kt b/qbit-core/src/commonTest/kotlin/qbit/FakeConn.kt index 1636e341..9b9d5adc 100644 --- a/qbit-core/src/commonTest/kotlin/qbit/FakeConn.kt +++ b/qbit-core/src/commonTest/kotlin/qbit/FakeConn.kt @@ -40,7 +40,7 @@ internal class FakeConn : Conn(), CommitHandler { override val head: Hash get() = TODO("not implemented") - override suspend fun update(trxLog: TrxLog, newLog: TrxLog, newDb: InternalDb) { + override suspend fun update(trxLog: TrxLog, baseDb: InternalDb, newLog: TrxLog, newDb: InternalDb) { updatesCalls++ } diff --git a/qbit-core/src/commonTest/kotlin/qbit/FunTest.kt b/qbit-core/src/commonTest/kotlin/qbit/FunTest.kt index c53945e2..22ae8285 100644 --- a/qbit-core/src/commonTest/kotlin/qbit/FunTest.kt +++ b/qbit-core/src/commonTest/kotlin/qbit/FunTest.kt @@ -4,6 +4,7 @@ import kotlinx.coroutines.delay import qbit.api.QBitException import qbit.api.db.* import qbit.api.gid.Gid +import qbit.api.model.Register import qbit.index.InternalDb import qbit.platform.runBlocking import qbit.serialization.CommonNodesStorage @@ -400,7 +401,7 @@ class FunTest { assertEquals(bomb.country, storedBomb.country) assertEquals(bomb.optCountry, storedBomb.optCountry) assertEquals( - listOf(Country(12884901889, "Country1", 0), Country(4294967383, "Country3", 2)), + listOf(Country(12884901889, "Country1", 0), Country(4294967388, "Country3", 2)), storedBomb.countiesList ) // todo: assertEquals(bomb.countriesListOpt, storedBomb.countriesListOpt) @@ -459,9 +460,9 @@ class FunTest { trx1.persist(eBrewer.copy(name = "Im different change")) val trx2 = conn.trx() trx2.persist(eCodd.copy(name = "Im change 2")) - delay(100) trx2.persist(pChen.copy(name = "Im different change")) trx1.commit() + delay(1) trx2.commit() conn.db { assertEquals("Im change 2", it.pull(eCodd.id!!)!!.name) @@ -540,6 +541,7 @@ class FunTest { ) ) trx1.commit() + delay(1) trx2.commit() conn.db { assertEquals("Im change 2", it.pull(eCodd.id!!)!!.name) @@ -574,4 +576,125 @@ class FunTest { assertEquals(Gid(nsk.id!!), trx2EntityAttrValues.first { it.attr.name == "City/region" }.value) } } + + @JsName("qbit_should_accumulate_concurrent_increments_of_counter") + @Test + fun `qbit should accumulate concurrent increments of counter`() { + runBlocking { + val conn = setupTestSchema() + val counter = IntCounterEntity(1, 10) + val trx = conn.trx() + trx.persist(counter) + trx.commit() + + val trx1 = conn.trx() + val trx2 = conn.trx() + trx1.persist(counter.copy(counter = 40)) + trx2.persist(counter.copy(counter = 70)) + trx1.commit() + trx2.commit() + + assertEquals(conn.db().pull(1)?.counter, 100) + } + } + + @JsName("qbit_should_keep_both_concurrent_writes_to_a_value_register") + @Test + fun `qbit should keep both concurrent writes to a value register`() { + runBlocking { + val conn = setupTestSchema() + conn.trx { + persist(IntRegisterEntity(1, Register(listOf(1)))) + } + assertEquals(conn.db().pull(1)?.register?.getValues(), listOf(1)) + + val trx1 = conn.trx() + val trx2 = conn.trx() + trx1.persist(IntRegisterEntity(1, Register(listOf(2)))) + trx2.persist(IntRegisterEntity(1, Register(listOf(3)))) + trx1.commit() + trx2.commit() + assertEquals(conn.db().pull(1)?.register?.getValues()?.sorted(), listOf(2, 3)) + + conn.trx { + persist(IntRegisterEntity(1, Register(listOf(4)))) + } + assertEquals(conn.db().pull(1)?.register?.getValues()?.sorted(), listOf(4)) + } + } + + @JsName("qbit_should_keep_both_concurrent_writes_to_a_ref_register") + @Test + fun `qbit should keep both concurrent writes to a ref register`() { + runBlocking { + val conn = setupTestSchema() + val sweden = Country(null, "Sweden", 10350000) + val norway = Country(null, "Norway", 5379000) + val denmark = Country(null, "Denmark", 5831000) + val finland = Country(null, "Finland", 5531000) + + conn.trx { + persist(CountryRegisterEntity(1, Register(listOf(sweden)))) + } + assertEquals(conn.db().pull(1)?.register?.getValues()?.map { it.copy(id = null) }, listOf(sweden)) + + val trx1 = conn.trx() + val trx2 = conn.trx() + trx1.persist(CountryRegisterEntity(1, Register(listOf(norway)))) + trx2.persist(CountryRegisterEntity(1, Register(listOf(denmark)))) + trx1.commit() + trx2.commit() + assertEquals(conn.db().pull(1)?.register?.getValues()?.map { it.copy(id = null) }?.sortedBy { it.name }, listOf(denmark, norway)) + + conn.trx { + persist(CountryRegisterEntity(1, Register(listOf(finland)))) + } + assertEquals(conn.db().pull(1)?.register?.getValues()?.map { it.copy(id = null) }, listOf(finland)) + } + } + + @JsName("qbit_should_merge_concurrent_writes_to_value_set") + @Test + fun `qbit should merge concurrent writes to value set`() { + runBlocking { + val conn = setupTestSchema() + conn.trx { + persist(IntSetEntity(1, setOf(1))) + } + assertEquals(conn.db().pull(1)?.set, setOf(1)) + + val trx1 = conn.trx() + val trx2 = conn.trx() + trx1.persist(IntSetEntity(1, setOf(1, 2))) + trx2.persist(IntSetEntity(1, setOf(1, 3))) + trx1.commit() + trx2.commit() + assertEquals(conn.db().pull(1)?.set, setOf(1, 2, 3)) + } + } + + @JsName("qbit_should_merge_concurrent_writes_to_ref_set") + @Test + fun `qbit should merge concurrent writes to ref set`() { + runBlocking { + val conn = setupTestSchema() + val sweden = Country(null, "Sweden", 10350000) + val norway = Country(null, "Norway", 5379000) + val denmark = Country(null, "Denmark", 5831000) + + conn.trx { + persist(CountrySetEntity(1, setOf(sweden))) + } + val persistedEntity = conn.db().pull(1) + assertEquals(persistedEntity?.set?.map { it.copy(id = null) }, listOf(sweden)) + + val trx1 = conn.trx() + val trx2 = conn.trx() + trx1.persist(CountrySetEntity(1, persistedEntity!!.set + norway)) + trx2.persist(CountrySetEntity(1, persistedEntity.set + denmark)) + trx1.commit() + trx2.commit() + assertEquals(conn.db().pull(1)?.set?.map { it.copy(id = null) }?.sortedBy { it.name }, listOf(denmark, norway, sweden)) + } + } } \ No newline at end of file diff --git a/qbit-core/src/commonTest/kotlin/qbit/TestSchema.kt b/qbit-core/src/commonTest/kotlin/qbit/TestSchema.kt index fc7d77a9..7978b685 100644 --- a/qbit-core/src/commonTest/kotlin/qbit/TestSchema.kt +++ b/qbit-core/src/commonTest/kotlin/qbit/TestSchema.kt @@ -6,6 +6,7 @@ import kotlinx.serialization.modules.plus import qbit.api.gid.Gid import qbit.api.gid.nextGids import qbit.api.model.Attr +import qbit.api.model.Register import qbit.factoring.AttrName import qbit.platform.collections.EmptyIterator import qbit.schema.schema @@ -19,8 +20,16 @@ import kotlin.reflect.KProperty1 @Serializable data class GidEntity(val id: Gid?, val bool: Boolean) +@Serializable +data class IntRegisterEntity(val id: Long?, val register: Register) + +@Serializable +data class CountryRegisterEntity(val id: Long?, val register: Register) + val internalTestsSerialModule = testsSerialModule + SerializersModule { contextual(GidEntity::class, GidEntity.serializer()) + contextual(IntRegisterEntity::class, IntRegisterEntity.serializer()) + contextual(CountryRegisterEntity::class, CountryRegisterEntity.serializer()) } val testSchema = schema(internalTestsSerialModule) { @@ -36,6 +45,13 @@ val testSchema = schema(internalTestsSerialModule) { entity(NullableList::class) entity(NullableRef::class) entity(IntEntity::class) + entity(IntCounterEntity::class) { + intCounter(IntCounterEntity::counter) + } + entity(IntSetEntity::class) + entity(CountrySetEntity::class) + entity(IntRegisterEntity::class) + entity(CountryRegisterEntity::class) entity(ResearchGroup::class) entity(EntityWithByteArray::class) entity(EntityWithListOfBytes::class) diff --git a/qbit-core/src/commonTest/kotlin/qbit/TestUtilsCore.kt b/qbit-core/src/commonTest/kotlin/qbit/TestUtilsCore.kt index 7df25f1b..89b10c40 100644 --- a/qbit-core/src/commonTest/kotlin/qbit/TestUtilsCore.kt +++ b/qbit-core/src/commonTest/kotlin/qbit/TestUtilsCore.kt @@ -142,7 +142,7 @@ inline fun > ListAttr(id: Gid?, name: Strin } suspend fun setupTestSchema(storage: Storage = MemStorage()): Conn { - val conn = qbit(storage, testsSerialModule) + val conn = qbit(storage, internalTestsSerialModule) conn.trx { testSchema.forEach { persist(it) diff --git a/qbit-core/src/commonTest/kotlin/qbit/TrxTest.kt b/qbit-core/src/commonTest/kotlin/qbit/TrxTest.kt index 840ccd9e..92d65d54 100644 --- a/qbit-core/src/commonTest/kotlin/qbit/TrxTest.kt +++ b/qbit-core/src/commonTest/kotlin/qbit/TrxTest.kt @@ -6,18 +6,16 @@ import qbit.api.Attrs import qbit.api.Instances import qbit.api.QBitException import qbit.api.db.Conn -import qbit.api.db.attrIs import qbit.api.db.pull -import qbit.api.db.query import qbit.api.gid.Gid import qbit.api.gid.nextGids -import qbit.api.model.Attr import qbit.api.system.Instance import qbit.ns.Key import qbit.ns.ns import qbit.platform.runBlocking import qbit.spi.Storage import qbit.storage.MemStorage +import qbit.test.model.IntCounterEntity import qbit.test.model.Region import qbit.test.model.Scientist import qbit.test.model.testsSerialModule @@ -176,6 +174,25 @@ class TrxTest { } } + @JsName("Counter_test") + @Test + fun `Counter test`() { // TODO: find an appropriate place for this test + runBlocking { + val conn = setupTestData() + val counterEntity = IntCounterEntity(1, 10) + + conn.trx { + persist(counterEntity) + } + assertEquals(conn.db().pull(1)?.counter, 10) + + conn.trx { + persist(counterEntity.copy(counter = 90)) + } + assertEquals(conn.db().pull(1)?.counter, 90) + } + } + private suspend fun openEmptyConn(): Pair { val storage = MemStorage() val conn = qbit(storage, testsSerialModule) diff --git a/qbit-test-fixtures/src/commonMain/kotlin/qbit/test/model/TestModels.kt b/qbit-test-fixtures/src/commonMain/kotlin/qbit/test/model/TestModels.kt index d88500fc..27b0dccb 100644 --- a/qbit-test-fixtures/src/commonMain/kotlin/qbit/test/model/TestModels.kt +++ b/qbit-test-fixtures/src/commonMain/kotlin/qbit/test/model/TestModels.kt @@ -9,6 +9,15 @@ data class TheSimplestEntity(val id: Long?, val scalar: String) @Serializable data class IntEntity(val id: Long?, val int: Int) +@Serializable +data class IntCounterEntity(val id: Long?, val counter: Int) + +@Serializable +data class IntSetEntity(val id: Long?, val set: Set) + +@Serializable +data class CountrySetEntity(val id: Long?, val set: Set) + @Serializable data class NullableIntEntity(val id: Long?, val int: Int?) @@ -307,6 +316,9 @@ val testsSerialModule = SerializersModule { contextual(ByteArrayEntity::class, ByteArrayEntity.serializer()) contextual(ListOfByteArraysEntity::class, ListOfByteArraysEntity.serializer()) contextual(IntEntity::class, IntEntity.serializer()) + contextual(IntCounterEntity::class, IntCounterEntity.serializer()) + contextual(IntSetEntity::class, IntSetEntity.serializer()) + contextual(CountrySetEntity::class, CountrySetEntity.serializer()) contextual(Region::class, Region.serializer()) contextual(ParentToChildrenTreeEntity::class, ParentToChildrenTreeEntity.serializer()) contextual(EntityWithRefsToSameType::class, EntityWithRefsToSameType.serializer())