diff --git a/backend/build.gradle.kts b/backend/build.gradle.kts index 66284b8..0b527e0 100644 --- a/backend/build.gradle.kts +++ b/backend/build.gradle.kts @@ -49,6 +49,9 @@ dependencies { implementation("net.datafaker:datafaker:2.1.0") implementation("org.springframework.boot:spring-boot-starter-actuator") testImplementation("org.mockito.kotlin:mockito-kotlin:5.2.1") + testImplementation("org.testcontainers:testcontainers") + testImplementation("org.testcontainers:mongodb") + testImplementation("org.testcontainers:junit-jupiter") } tasks.withType { diff --git a/backend/src/main/kotlin/com/opendatamask/adapter/output/connector/MongoDBConnector.kt b/backend/src/main/kotlin/com/opendatamask/adapter/output/connector/MongoDBConnector.kt index fe24927..9038c8a 100644 --- a/backend/src/main/kotlin/com/opendatamask/adapter/output/connector/MongoDBConnector.kt +++ b/backend/src/main/kotlin/com/opendatamask/adapter/output/connector/MongoDBConnector.kt @@ -4,6 +4,8 @@ import com.mongodb.ConnectionString import com.mongodb.MongoClientSettings import com.mongodb.client.MongoClient import com.mongodb.client.MongoClients +import com.mongodb.client.model.ReplaceOneModel +import com.mongodb.client.model.ReplaceOptions import org.bson.Document open class MongoDBConnector( @@ -81,7 +83,19 @@ open class MongoDBConnector( if (rows.isEmpty()) return 0 createMongoClient().use { client -> val collection = client.getDatabase(getDatabaseName()).getCollection(tableName) - collection.insertMany(rows.map { Document(it) }) + val (withId, withoutId) = rows.partition { it["_id"] != null } + if (withId.isNotEmpty()) { + val upserts = withId.map { row -> + val doc = Document(row) + val id = doc["_id"] + val filter = Document("_id", id) + ReplaceOneModel(filter, doc, ReplaceOptions().upsert(true)) + } + collection.bulkWrite(upserts) + } + if (withoutId.isNotEmpty()) { + collection.insertMany(withoutId.map { Document(it) }) + } } return rows.size } diff --git a/backend/src/main/kotlin/com/opendatamask/application/service/DataConnectionService.kt b/backend/src/main/kotlin/com/opendatamask/application/service/DataConnectionService.kt index da29a22..cc6f1ad 100644 --- a/backend/src/main/kotlin/com/opendatamask/application/service/DataConnectionService.kt +++ b/backend/src/main/kotlin/com/opendatamask/application/service/DataConnectionService.kt @@ -8,6 +8,7 @@ import com.opendatamask.domain.port.output.DataConnectionPort import com.opendatamask.domain.port.input.dto.ConnectionTestResult import com.opendatamask.domain.port.input.dto.DataConnectionRequest import com.opendatamask.domain.port.input.dto.DataConnectionResponse +import com.opendatamask.domain.model.ConnectionType import com.opendatamask.domain.model.DataConnection import org.springframework.stereotype.Service import org.springframework.transaction.annotation.Transactional @@ -21,11 +22,16 @@ class DataConnectionService( @Transactional override fun createConnection(workspaceId: Long, request: DataConnectionRequest): DataConnectionResponse { + val connStr = request.connectionString + if (connStr.isNullOrBlank()) { + throw IllegalArgumentException("Connection string is required when creating a connection") + } val connection = DataConnection( workspaceId = workspaceId, name = request.name, type = request.type, - connectionString = encryptionPort.encrypt(request.connectionString), + connectionString = encryptionPort.encrypt(connStr), + host = extractHost(request.type, connStr), username = request.username, password = request.password?.let { encryptionPort.encrypt(it) }, database = request.database, @@ -49,11 +55,23 @@ class DataConnectionService( @Transactional override fun updateConnection(workspaceId: Long, connectionId: Long, request: DataConnectionRequest): DataConnectionResponse { val connection = findConnection(workspaceId, connectionId) + // If the connector type is changing, a new connection string must be provided because the + // existing stored string is for the old type and would be invalid for the new one. + if (request.type != connection.type && request.connectionString.isNullOrBlank()) { + throw IllegalArgumentException( + "A new connection string is required when changing the connection type" + ) + } connection.name = request.name connection.type = request.type - connection.connectionString = encryptionPort.encrypt(request.connectionString) + if (!request.connectionString.isNullOrBlank()) { + connection.connectionString = encryptionPort.encrypt(request.connectionString) + connection.host = extractHost(request.type, request.connectionString) + } connection.username = request.username - connection.password = request.password?.let { encryptionPort.encrypt(it) } + if (!request.password.isNullOrBlank()) { + connection.password = encryptionPort.encrypt(request.password) + } connection.database = request.database connection.isSource = request.isSource connection.isDestination = request.isDestination @@ -99,11 +117,39 @@ class DataConnectionService( return connection } + // Extracts the host (and port) portion from a connection string for display purposes. + private fun extractHost(type: ConnectionType, connectionString: String): String? { + return try { + when (type) { + ConnectionType.POSTGRESQL, ConnectionType.MYSQL, ConnectionType.AZURE_SQL -> { + // JDBC URL formats: + // jdbc:postgresql://host:port/db + // jdbc:mysql://host:port/db + // jdbc:sqlserver://host:port;databaseName=db;... (semicolon-delimited params) + val afterSlashes = connectionString.substringAfter("//", "") + afterSlashes.substringBefore("/").substringBefore(";").substringBefore("?").ifBlank { null } + } + ConnectionType.MONGODB, ConnectionType.MONGODB_COSMOS -> { + // MongoDB URI: mongodb://[user:pass@]host:port[/db] + val afterSlashes = connectionString.substringAfter("//", "") + val hostPart = afterSlashes.substringBefore("/").substringBefore("?") + // Strip credentials (user:pass@) + val withoutCreds = if (hostPart.contains("@")) hostPart.substringAfter("@") else hostPart + withoutCreds.ifBlank { null } + } + else -> null + } + } catch (e: Exception) { + null + } + } + private fun DataConnection.toResponse() = DataConnectionResponse( id = id, workspaceId = workspaceId, name = name, type = type, + host = host, username = username, database = database, isSource = isSource, diff --git a/backend/src/main/kotlin/com/opendatamask/domain/model/DataConnection.kt b/backend/src/main/kotlin/com/opendatamask/domain/model/DataConnection.kt index 3f00d85..d6fc25d 100644 --- a/backend/src/main/kotlin/com/opendatamask/domain/model/DataConnection.kt +++ b/backend/src/main/kotlin/com/opendatamask/domain/model/DataConnection.kt @@ -27,6 +27,9 @@ class DataConnection( @Column(nullable = false, length = 2048) var connectionString: String, + @Column + var host: String? = null, + @Column var username: String? = null, diff --git a/backend/src/main/kotlin/com/opendatamask/domain/port/input/dto/DataConnectionDto.kt b/backend/src/main/kotlin/com/opendatamask/domain/port/input/dto/DataConnectionDto.kt index afac10c..0a383f3 100644 --- a/backend/src/main/kotlin/com/opendatamask/domain/port/input/dto/DataConnectionDto.kt +++ b/backend/src/main/kotlin/com/opendatamask/domain/port/input/dto/DataConnectionDto.kt @@ -12,8 +12,8 @@ data class DataConnectionRequest( @field:NotNull(message = "Connection type is required") val type: ConnectionType, - @field:NotBlank(message = "Connection string is required") - val connectionString: String, + // Null or blank means "keep existing connection string" on update + val connectionString: String? = null, val username: String? = null, val password: String? = null, @@ -27,6 +27,7 @@ data class DataConnectionResponse( val workspaceId: Long, val name: String, val type: ConnectionType, + val host: String?, val username: String?, val database: String?, val isSource: Boolean, diff --git a/backend/src/test/kotlin/com/opendatamask/adapter/input/rest/DataConnectionControllerTest.kt b/backend/src/test/kotlin/com/opendatamask/adapter/input/rest/DataConnectionControllerTest.kt index affccb6..b926d7d 100644 --- a/backend/src/test/kotlin/com/opendatamask/adapter/input/rest/DataConnectionControllerTest.kt +++ b/backend/src/test/kotlin/com/opendatamask/adapter/input/rest/DataConnectionControllerTest.kt @@ -41,7 +41,7 @@ class DataConnectionControllerTest { private fun makeResponse(id: Long = 1L, workspaceId: Long = 1L) = DataConnectionResponse( id = id, workspaceId = workspaceId, name = "My DB", type = ConnectionType.POSTGRESQL, - username = "user", database = null, isSource = true, isDestination = false, + host = "localhost:5432", username = "user", database = null, isSource = true, isDestination = false, createdAt = LocalDateTime.now() ) diff --git a/backend/src/test/kotlin/com/opendatamask/adapter/output/connector/MongoDBConnectorTest.kt b/backend/src/test/kotlin/com/opendatamask/adapter/output/connector/MongoDBConnectorTest.kt index 6a37571..3e4ff6b 100644 --- a/backend/src/test/kotlin/com/opendatamask/adapter/output/connector/MongoDBConnectorTest.kt +++ b/backend/src/test/kotlin/com/opendatamask/adapter/output/connector/MongoDBConnectorTest.kt @@ -4,6 +4,8 @@ import com.mongodb.client.FindIterable import com.mongodb.client.MongoClient import com.mongodb.client.MongoCollection import com.mongodb.client.MongoDatabase +import com.mongodb.client.model.ReplaceOneModel +import com.mongodb.client.model.WriteModel import org.bson.Document import org.junit.jupiter.api.Test import org.junit.jupiter.api.Assertions.* @@ -98,6 +100,64 @@ class MongoDBConnectorTest { val count = connector.writeData("users", rows) assertEquals(2, count) verify(mockCollection).insertMany(any()) + verify(mockCollection, never()).bulkWrite(any>>()) + } + + @Test + fun `writeData uses bulkWrite upsert when rows contain _id field`() { + val mockCollection = mock>() + val mockDb = mock() + val mockClient = mock() + whenever(mockClient.getDatabase("testdb")).thenReturn(mockDb) + whenever(mockDb.getCollection("users")).thenReturn(mockCollection) + + val connector = createConnector(mockClient) + val rows = listOf( + mapOf("_id" to "abc123", "name" to "Alice"), + mapOf("_id" to "def456", "name" to "Bob") + ) + val count = connector.writeData("users", rows) + assertEquals(2, count) + + val captor = argumentCaptor>>() + verify(mockCollection).bulkWrite(captor.capture()) + verify(mockCollection, never()).insertMany(any()) + + // Each ReplaceOneModel must have upsert=true and an _id filter + val models = captor.firstValue + assertEquals(2, models.size) + val expectedIds = setOf("abc123", "def456") + models.forEach { model -> + assertTrue(model is ReplaceOneModel<*>, + "Expected ReplaceOneModel but was ${model::class.simpleName}") + @Suppress("UNCHECKED_CAST") + val replaceModel = model as ReplaceOneModel + assertTrue(replaceModel.replaceOptions.isUpsert, + "ReplaceOneModel must have upsert=true") + val filter = replaceModel.filter as Document + assertTrue(filter.containsKey("_id"), "Filter must contain _id field") + assertTrue(expectedIds.contains(filter["_id"]), + "Filter _id must be one of the input row ids, got ${filter["_id"]}") + } + } + + @Test + fun `writeData handles mixed rows with and without _id`() { + val mockCollection = mock>() + val mockDb = mock() + val mockClient = mock() + whenever(mockClient.getDatabase("testdb")).thenReturn(mockDb) + whenever(mockDb.getCollection("users")).thenReturn(mockCollection) + + val connector = createConnector(mockClient) + val rows = listOf( + mapOf("_id" to "abc123", "name" to "Alice"), + mapOf("name" to "Bob") + ) + val count = connector.writeData("users", rows) + assertEquals(2, count) + verify(mockCollection).bulkWrite(any>>()) + verify(mockCollection).insertMany(any()) } @Test diff --git a/backend/src/test/kotlin/com/opendatamask/adapter/output/connector/MongoDBMaskingPipelineTest.kt b/backend/src/test/kotlin/com/opendatamask/adapter/output/connector/MongoDBMaskingPipelineTest.kt new file mode 100644 index 0000000..153cdb0 --- /dev/null +++ b/backend/src/test/kotlin/com/opendatamask/adapter/output/connector/MongoDBMaskingPipelineTest.kt @@ -0,0 +1,277 @@ +package com.opendatamask.adapter.output.connector + +import com.mongodb.client.MongoClients +import com.opendatamask.application.service.GeneratorService +import com.opendatamask.domain.model.ColumnGenerator +import com.opendatamask.domain.model.ConsistencyMode +import com.opendatamask.domain.model.GeneratorType +import org.junit.jupiter.api.Assertions.* +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.testcontainers.containers.MongoDBContainer +import org.testcontainers.junit.jupiter.Container +import org.testcontainers.junit.jupiter.Testcontainers +import org.testcontainers.utility.DockerImageName + +// --------------------------------------------------------------------------- +// MongoDBMaskingPipelineTest +// +// End-to-end pipeline using a real MongoDB instance (via Testcontainers): +// source collection -> read -> apply generators -> write to target -> verify +// --------------------------------------------------------------------------- +@Testcontainers(disabledWithoutDocker = true) +class MongoDBMaskingPipelineTest { + + companion object { + @Container + @JvmStatic + val mongoContainer: MongoDBContainer = MongoDBContainer(DockerImageName.parse("mongo:7.0")) + } + + private lateinit var source: MongoDBConnector + private lateinit var target: MongoDBConnector + private lateinit var generatorService: GeneratorService + + // Sample source data - a collection of customer documents + private val sourceCustomers = listOf( + mapOf("_id" to "c001", "name" to "Alice Smith", "email" to "alice@example.com", + "phone" to "555-111-2222", "ssn" to "123-45-6789", "account_status" to "active"), + mapOf("_id" to "c002", "name" to "Bob Jones", "email" to "bob@example.com", + "phone" to "555-333-4444", "ssn" to "987-65-4321", "account_status" to "inactive"), + mapOf("_id" to "c003", "name" to "Carol White", "email" to "carol@example.com", + "phone" to "555-555-6666", "ssn" to "111-22-3333", "account_status" to "active") + ) + + @BeforeEach + fun setUp() { + val uri = mongoContainer.connectionString + // Drop entire databases to guarantee a clean state for every test + MongoClients.create(uri).use { client -> + client.getDatabase("source_db").drop() + client.getDatabase("target_db").drop() + } + source = MongoDBConnector(uri, "source_db") + target = MongoDBConnector(uri, "target_db") + generatorService = GeneratorService("0123456789abcdef") + + // Seed source collection + source.writeData("customers", sourceCustomers) + } + + // ── Helper to build a ColumnGenerator without JPA machinery ───────────── + + private fun buildColumnGenerator( + tableConfigId: Long = 1L, + columnName: String, + type: GeneratorType, + params: String? = null, + consistencyMode: ConsistencyMode = ConsistencyMode.RANDOM + ) = ColumnGenerator( + id = 0, + tableConfigurationId = tableConfigId, + columnName = columnName, + generatorType = type, + generatorParams = params, + consistencyMode = consistencyMode + ) + + // ── 1. Basic masking pipeline ───────────────────────────────────────────── + + @Test + fun `mask pipeline replaces sensitive fields and preserves non-sensitive fields`() { + val generators = listOf( + buildColumnGenerator(columnName = "name", type = GeneratorType.FULL_NAME), + buildColumnGenerator(columnName = "email", type = GeneratorType.EMAIL), + buildColumnGenerator(columnName = "phone", type = GeneratorType.PHONE), + buildColumnGenerator(columnName = "ssn", type = GeneratorType.SSN) + ) + + val sourceRows = source.fetchData("customers") + assertEquals(3, sourceRows.size) + + val maskedRows = sourceRows.map { row -> generatorService.applyGenerators(row, generators) } + target.writeData("customers", maskedRows) + + val result = target.fetchData("customers") + assertEquals(3, result.size) + + // Build a lookup by _id so order-independence is maintained + val originalById = sourceCustomers.associateBy { it["_id"] } + result.forEach { row -> + val id = row["_id"] + val original = originalById[id] ?: fail("Unexpected _id in result: $id") + + // _id is preserved (not in generator list, passed through) + assertEquals(original["_id"], row["_id"], "_id should be preserved for $id") + + // account_status is not in generators, passed through unchanged + assertEquals(original["account_status"], row["account_status"], + "account_status should be passed through for $id") + + // Sensitive columns must be replaced with non-blank values + assertNotNull(row["name"], "masked name must not be null for $id") + assertNotNull(row["email"], "masked email must not be null for $id") + assertNotNull(row["phone"], "masked phone must not be null for $id") + assertNotNull(row["ssn"], "masked ssn must not be null for $id") + + assertTrue((row["name"] as String).isNotBlank(), "masked name must not be blank for $id") + assertTrue((row["email"] as String).isNotBlank(), "masked email must not be blank for $id") + assertTrue((row["phone"] as String).isNotBlank(), "masked phone must not be blank for $id") + + // Masked values must differ from originals (false-positive probability negligible) + assertNotEquals(original["name"], row["name"], "name must be masked for $id") + assertNotEquals(original["email"], row["email"], "email must be masked for $id") + } + } + + // ── 2. Passthrough mode - write source data as-is ───────────────────────── + + @Test + fun `passthrough mode copies all rows unchanged`() { + val sourceRows = source.fetchData("customers") + target.writeData("customers_passthrough", sourceRows) + + val result = target.fetchData("customers_passthrough") + assertEquals(sourceCustomers.size, result.size) + + val originalById = sourceCustomers.associateBy { it["_id"] } + result.forEach { row -> + val original = originalById[row["_id"]] ?: fail("Unexpected _id: ${row["_id"]}") + assertEquals(original["name"], row["name"]) + assertEquals(original["email"], row["email"]) + assertEquals(original["_id"], row["_id"]) + } + } + + // ── 3. Row limit is respected when fetching from source ─────────────────── + + @Test + fun `fetchData respects row limit`() { + val rows = source.fetchData("customers", limit = 2) + assertEquals(2, rows.size) + } + + // ── 4. Upsert path: masked rows with _id replace existing documents ──────── + + @Test + fun `upsert updates existing documents in target preserving _id`() { + // First pass - write original data to target + val sourceRows = source.fetchData("customers") + target.writeData("customers", sourceRows) + assertEquals(3, target.fetchData("customers").size) + + // Second pass - mask and upsert (should replace, not add) + val generators = listOf( + buildColumnGenerator(columnName = "name", type = GeneratorType.FULL_NAME), + buildColumnGenerator(columnName = "email", type = GeneratorType.EMAIL) + ) + val maskedRows = sourceRows.map { row -> generatorService.applyGenerators(row, generators) } + target.writeData("customers", maskedRows) + + // Still 3 documents (upsert replaced, did not duplicate) + val result = target.fetchData("customers") + assertEquals(3, result.size) + + // _id values are unchanged + val resultIds = result.map { it["_id"] }.toSet() + val expectedIds = setOf("c001", "c002", "c003") + assertEquals(expectedIds, resultIds) + + // Names were replaced by masked values + val resultNames = result.map { it["name"] }.toSet() + val originalNames = sourceCustomers.map { it["name"] }.toSet() + assertTrue( + resultNames.none { it in originalNames }, + "All names should have been replaced by masked values" + ) + } + + // ── 5. Mixed batch: rows with and without _id ────────────────────────────── + + @Test + fun `writeData handles mixed batch of rows with and without _id`() { + val mixedRows = listOf( + mapOf("_id" to "x001", "field" to "upsert-me"), + mapOf( "field" to "insert-me") + ) + val count = target.writeData("mixed", mixedRows) + assertEquals(2, count) + + val stored = target.fetchData("mixed") + assertEquals(2, stored.size) + assertEquals(1, stored.count { it.containsKey("_id") && it["_id"] == "x001" }) + assertEquals(1, stored.count { it["field"] == "insert-me" }) + } + + // ── 6. Truncate clears target before re-run ───────────────────────────────── + + @Test + fun `truncateTable then writeData yields clean result`() { + // Write all 3 rows once + target.writeData("customers", source.fetchData("customers")) + assertEquals(3, target.fetchData("customers").size) + + // Truncate and re-write a single row + target.truncateTable("customers") + assertEquals(0, target.fetchData("customers").size) + + target.writeData("customers", listOf(sourceCustomers.first())) + assertEquals(1, target.fetchData("customers").size) + } + + // ── 7. Consistent masking produces identical output for same input ────────── + + @Test + fun `CONSISTENT mode produces deterministic output for same original value`() { + val generator = buildColumnGenerator( + columnName = "name", + type = GeneratorType.FULL_NAME, + consistencyMode = ConsistencyMode.CONSISTENT + ) + val workspaceSecret = generatorService.computeWorkspaceSecret(42L) + + val row = mapOf("_id" to "c001", "name" to "Alice Smith") + val first = generatorService.applyGenerators(row, listOf(generator), workspaceSecret) + val second = generatorService.applyGenerators(row, listOf(generator), workspaceSecret) + + assertEquals(first["name"], second["name"], + "CONSISTENT mode must produce the same masked name for the same original value") + assertNotEquals("Alice Smith", first["name"], + "CONSISTENT masked value must differ from original") + } + + // ── 8. NULL generator nullifies the column ────────────────────────────────── + + @Test + fun `NULL generator sets column to null in target`() { + val generators = listOf( + buildColumnGenerator(columnName = "ssn", type = GeneratorType.NULL) + ) + val sourceRows = source.fetchData("customers") + val maskedRows = sourceRows.map { row -> generatorService.applyGenerators(row, generators) } + target.writeData("customers", maskedRows) + + val result = target.fetchData("customers") + assertEquals(3, result.size) + result.forEach { row -> + // The NULL generator sets ssn to null in the document map before writing. + // MongoDB stores a BSON null for the field; when read back, the value is null. + assertNull(row["ssn"], "SSN should be null for ${row["_id"]}") + // Other fields remain + assertNotNull(row["name"]) + assertNotNull(row["email"]) + } + } + + // ── 9. Empty source collection writes zero rows ────────────────────────────── + + @Test + fun `empty source collection results in zero rows written to target`() { + val empty = source.fetchData("nonexistent_collection") + assertEquals(0, empty.size) + + val written = target.writeData("output", empty) + assertEquals(0, written) + assertEquals(0, target.fetchData("output").size) + } +} diff --git a/backend/src/test/kotlin/com/opendatamask/application/service/DataConnectionServiceTest.kt b/backend/src/test/kotlin/com/opendatamask/application/service/DataConnectionServiceTest.kt index 2a033dc..8edc407 100644 --- a/backend/src/test/kotlin/com/opendatamask/application/service/DataConnectionServiceTest.kt +++ b/backend/src/test/kotlin/com/opendatamask/application/service/DataConnectionServiceTest.kt @@ -31,7 +31,7 @@ class DataConnectionServiceTest { private fun makeRequest( name: String = "My DB", type: ConnectionType = ConnectionType.POSTGRESQL, - connectionString: String = "jdbc:postgresql://localhost/db", + connectionString: String? = "jdbc:postgresql://localhost/db", username: String? = "user", password: String? = "pass", database: String? = null, @@ -87,6 +87,18 @@ class DataConnectionServiceTest { verify(EncryptionPort, times(1)).encrypt(any()) // Only conn string encrypted } + @Test + fun `createConnection throws for null connection string`() { + val request = makeRequest(connectionString = null) + assertThrows { service.createConnection(10L, request) } + } + + @Test + fun `createConnection throws for blank connection string`() { + val request = makeRequest(connectionString = " ") + assertThrows { service.createConnection(10L, request) } + } + // ── getConnection ────────────────────────────────────────────────────── @Test @@ -138,6 +150,48 @@ class DataConnectionServiceTest { assertTrue(result.isEmpty()) } + @Test + fun `createConnection extracts host from JDBC URL`() { + val request = makeRequest(connectionString = "jdbc:postgresql://myhost:5432/mydb") + whenever(EncryptionPort.encrypt(any())).thenReturn("encrypted") + val captor = argumentCaptor() + whenever(dataConnectionRepository.save(captor.capture())).thenAnswer { it.arguments[0] as DataConnection } + + service.createConnection(10L, request) + + assertEquals("myhost:5432", captor.firstValue.host) + } + + @Test + fun `createConnection extracts host from Azure SQL JDBC URL ignoring semicolon params`() { + val request = makeRequest( + type = ConnectionType.AZURE_SQL, + connectionString = "jdbc:sqlserver://myserver:1433;databaseName=mydb;encrypt=true" + ) + whenever(EncryptionPort.encrypt(any())).thenReturn("encrypted") + val captor = argumentCaptor() + whenever(dataConnectionRepository.save(captor.capture())).thenAnswer { it.arguments[0] as DataConnection } + + service.createConnection(10L, request) + + assertEquals("myserver:1433", captor.firstValue.host) + } + + @Test + fun `createConnection strips credentials from MongoDB URI when extracting host`() { + val request = makeRequest( + type = ConnectionType.MONGODB, + connectionString = "mongodb://user:pass@mycluster:27017/mydb" + ) + whenever(EncryptionPort.encrypt(any())).thenReturn("encrypted") + val captor = argumentCaptor() + whenever(dataConnectionRepository.save(captor.capture())).thenAnswer { it.arguments[0] as DataConnection } + + service.createConnection(10L, request) + + assertEquals("mycluster:27017", captor.firstValue.host) + } + // ── updateConnection ─────────────────────────────────────────────────── @Test @@ -155,6 +209,35 @@ class DataConnectionServiceTest { verify(dataConnectionRepository).save(any()) } + @Test + fun `updateConnection keeps existing connection string when connectionString is null`() { + val conn = makeConnection(id = 1L, workspaceId = 10L) + val updateRequest = makeRequest(name = "Updated DB", connectionString = null) + whenever(dataConnectionRepository.findById(1L)).thenReturn(Optional.of(conn)) + whenever(dataConnectionRepository.save(any())).thenAnswer { it.arguments[0] as DataConnection } + + service.updateConnection(10L, 1L, updateRequest) + + // encrypt should NOT be called for connection string (no new string provided) + verify(EncryptionPort, never()).encrypt(eq("jdbc:postgresql://localhost/db")) + } + + @Test + fun `updateConnection throws when type changes without new connection string`() { + val conn = makeConnection(id = 1L, workspaceId = 10L) // type=POSTGRESQL + // Attempt to switch to MONGODB without providing a new connection string + val updateRequest = makeRequest( + name = "Updated DB", + type = ConnectionType.MONGODB, + connectionString = null + ) + whenever(dataConnectionRepository.findById(1L)).thenReturn(Optional.of(conn)) + + assertThrows { + service.updateConnection(10L, 1L, updateRequest) + } + } + @Test fun `updateConnection throws when connection not found`() { whenever(dataConnectionRepository.findById(99L)).thenReturn(Optional.empty()) diff --git a/docs/mongodb-masking-example.yml b/docs/mongodb-masking-example.yml new file mode 100644 index 0000000..61afd60 --- /dev/null +++ b/docs/mongodb-masking-example.yml @@ -0,0 +1,116 @@ +# OpenDataMask – MongoDB Selective Attribute Masking Example +# +# This file shows how to export/import a workspace configuration that: +# 1. Reads documents from a MongoDB source collection. +# 2. Masks specific fields while leaving all other fields as-is ("passthrough"). +# 3. Writes the result to a target MongoDB collection (upsert on _id, insert otherwise). +# +# Prerequisites +# ------------- +# Before importing this config create two Data Connections in the workspace: +# +# Source connection +# Type : MONGODB +# URI : mongodb://readonly_user:secret@mongo-source.example.com:27017 +# Database (optional): myapp # can also be embedded in the URI +# Roles : Source ✓ / Destination ✗ +# +# Destination connection +# Type : MONGODB +# URI : mongodb://rw_user:secret@mongo-dest.example.com:27017 +# Database (optional): myapp_masked +# Roles : Source ✗ / Destination ✓ +# +# Connection URI tips +# ------------------- +# Standard standalone: mongodb://user:pass@host:27017/dbname +# Replica set: mongodb://user:pass@host1:27017,host2:27017/dbname?replicaSet=rs0 +# MongoDB Atlas (SRV): mongodb+srv://user:pass@cluster.mongodb.net/dbname +# Azure Cosmos DB: mongodb://account:key@account.mongo.cosmos.azure.com:10255/?ssl=true&replicaSet=globaldb&retrywrites=false +# +# How masking works +# ----------------- +# * Table name = MongoDB collection name. +# * Mode MASK : listed column generators replace field values; all other +# fields are transferred unchanged ("as-is"). +# * Mode PASSTHROUGH : every field is copied without modification. +# * Mode SKIP : the collection is not processed. +# * Nested paths (e.g. "address.street") are supported when the generator +# operates on the top-level key that contains the object. +# For deep-nested masking add one generator per top-level field path. + +version: "1.0" + +tables: + + # ── Collection: customers ────────────────────────────────────────────────── + - tableName: customers + schemaName: null # MongoDB has no schema concept; set to null + mode: MASK + rowLimit: null # null = process all documents + whereClause: null # null = no filter (process every document) + # MongoDB JSON filter example: '{"status": "active"}' + columnGenerators: + # PII fields to replace with realistic fake data + - columnName: firstName + generatorType: FIRST_NAME + generatorParams: null + + - columnName: lastName + generatorType: LAST_NAME + generatorParams: null + + - columnName: email + generatorType: EMAIL + generatorParams: null + + - columnName: phone + generatorType: PHONE + generatorParams: null + + - columnName: dateOfBirth + generatorType: BIRTH_DATE + generatorParams: null + + # Replace SSN with a NULL value (field remains present in masked output) + - columnName: ssn + generatorType: NULL + generatorParams: null + + # Partially mask a credit card number (keep last 4 digits) + - columnName: creditCardNumber + generatorType: PARTIAL_MASK + generatorParams: '{"visibleSuffix":"4"}' + + # ── Collection: orders ──────────────────────────────────────────────────── + # Copy the entire orders collection without any modifications. + - tableName: orders + schemaName: null + mode: PASSTHROUGH + rowLimit: null + whereClause: null + columnGenerators: [] + + # ── Collection: audit_logs ──────────────────────────────────────────────── + # Skip this collection entirely – it is not copied to the destination. + - tableName: audit_logs + schemaName: null + mode: SKIP + rowLimit: null + whereClause: null + columnGenerators: [] + + # ── Collection: products ────────────────────────────────────────────────── + # Subset: copy only documents matching the MongoDB filter. + - tableName: products + schemaName: null + mode: SUBSET + rowLimit: 1000 # cap at 1 000 documents + whereClause: '{"inStock": true}' + columnGenerators: [] + +# Post-job actions (optional) +actions: + - actionType: WEBHOOK + config: '{"url":"https://hooks.example.com/masking-done","method":"POST"}' + enabled: true diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 995aea8..72f429d 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -64,25 +64,24 @@ export interface DataConnection { workspaceId: number name: string type: ConnectionType - host: string - port: number - database: string - username: string - /** password is never returned from the API */ - sslEnabled: boolean + host: string | null + database: string | null + username: string | null + isSource: boolean + isDestination: boolean createdAt: string - updatedAt: string } export interface DataConnectionRequest { name: string type: ConnectionType - host: string - port: number - database: string - username: string - password: string - sslEnabled?: boolean + // Full connection string (MongoDB URI or JDBC URL). Null/omitted on update means "keep existing". + connectionString?: string + username?: string + password?: string + database?: string + isSource: boolean + isDestination: boolean } export interface ConnectionTestResult { diff --git a/frontend/src/views/ConnectionsView.vue b/frontend/src/views/ConnectionsView.vue index 59c30f5..09b0a9f 100644 --- a/frontend/src/views/ConnectionsView.vue +++ b/frontend/src/views/ConnectionsView.vue @@ -19,30 +19,110 @@ const testing = ref(null) const testResults = ref>({}) const formError = ref('') -const defaultForm = (): DataConnectionRequest => ({ +// Shared form fields +const form = ref({ name: '', type: ConnectionType.POSTGRESQL, + connectionString: '', host: 'localhost', port: 5432, database: '', username: '', password: '', - sslEnabled: false + sslEnabled: false, + isSource: false, + isDestination: false }) -const form = ref(defaultForm()) +// Track the type that was saved in the DB (used to detect type changes on edit) +const originalType = ref(null) -const connectionTypes = Object.values(ConnectionType) +// True when the user has changed SQL connection details (host/port/database/ssl) during edit. +// Only in this case do we rebuild and send the connection string on update to avoid overwriting +// custom JDBC params or unintentionally disabling SSL. +const sqlConnectionChanged = ref(false) -const defaultPorts: Record = { +// Types that use a direct connection string instead of host/port +const mongoTypes = new Set([ConnectionType.MONGODB, ConnectionType.MONGODB_COSMOS]) + +const isMongoType = computed(() => mongoTypes.has(form.value.type)) + +const defaultPorts: Partial> = { [ConnectionType.POSTGRESQL]: 5432, [ConnectionType.MONGODB]: 27017, [ConnectionType.AZURE_SQL]: 1433, - [ConnectionType.MONGODB_COSMOS]: 10255 + [ConnectionType.MONGODB_COSMOS]: 10255, + [ConnectionType.MYSQL]: 3306 +} + +// Labels shown in the type selector (exclude FILE from the standard form) +const displayConnectionTypes = Object.values(ConnectionType).filter( + (t) => t !== ConnectionType.FILE +) + +function typeLabel(t: ConnectionType) { + const labels: Record = { + [ConnectionType.POSTGRESQL]: 'PostgreSQL', + [ConnectionType.MONGODB]: 'MongoDB', + [ConnectionType.AZURE_SQL]: 'Azure SQL', + [ConnectionType.MONGODB_COSMOS]: 'MongoDB Cosmos', + [ConnectionType.FILE]: 'File', + [ConnectionType.MYSQL]: 'MySQL' + } + return labels[t] ?? t +} + +function resetForm() { + form.value = { + name: '', + type: ConnectionType.POSTGRESQL, + connectionString: '', + host: 'localhost', + port: defaultPorts[ConnectionType.POSTGRESQL] ?? 5432, + database: '', + username: '', + password: '', + sslEnabled: false, + isSource: false, + isDestination: false + } + sqlConnectionChanged.value = false + originalType.value = null } function onTypeChange() { - form.value.port = defaultPorts[form.value.type] + const port = defaultPorts[form.value.type] + if (port !== undefined) { + form.value.port = port + } + // Clear connection string whenever the type changes + form.value.connectionString = '' + // Changing type always means we need a new connection string + sqlConnectionChanged.value = true +} + +// Called by SQL input fields to mark connection details as changed +function onSqlConnectionChange() { + sqlConnectionChanged.value = true +} + +// Build a JDBC / MongoDB URI from the form fields +function buildConnectionString(): string { + const { type, host, port, database, sslEnabled } = form.value + switch (type) { + case ConnectionType.POSTGRESQL: + return `jdbc:postgresql://${host}:${port}/${database}${sslEnabled ? '?sslmode=require' : ''}` + case ConnectionType.MYSQL: + return `jdbc:mysql://${host}:${port}/${database}${sslEnabled ? '?useSSL=true' : ''}` + case ConnectionType.AZURE_SQL: + return `jdbc:sqlserver://${host}:${port};databaseName=${database};encrypt=true` + case ConnectionType.MONGODB: + case ConnectionType.MONGODB_COSMOS: + // User enters the full URI directly in connectionString field + return form.value.connectionString ?? '' + default: + return form.value.connectionString ?? '' + } } async function fetchConnections() { @@ -61,36 +141,137 @@ onMounted(fetchConnections) function openCreate() { editingConnection.value = null - form.value = defaultForm() + resetForm() formError.value = '' showModal.value = true } +function parseStoredSqlHost(hostValue: string | undefined, type: ConnectionType) { + let host = 'localhost' + let port = defaultPorts[type] ?? 5432 + + if (!hostValue || mongoTypes.has(type)) { + return { host, port } + } + + if (hostValue.startsWith('[')) { + // Bracketed IPv6 literal: [::1]:5432 + const closingBracket = hostValue.indexOf(']') + if (closingBracket !== -1) { + host = hostValue.slice(1, closingBracket) || 'localhost' + const remainder = hostValue.slice(closingBracket + 1) + if (remainder.startsWith(':')) { + const p = parseInt(remainder.slice(1), 10) + if (!isNaN(p)) port = p + } + } + } else { + const lastColon = hostValue.lastIndexOf(':') + if (lastColon !== -1) { + const p = parseInt(hostValue.slice(lastColon + 1), 10) + if (!isNaN(p)) { + host = hostValue.slice(0, lastColon) || 'localhost' + port = p + } else { + host = hostValue + } + } else { + host = hostValue + } + } + + return { host, port } +} + function openEdit(conn: DataConnection) { editingConnection.value = conn + originalType.value = conn.type + sqlConnectionChanged.value = false + // Parse host/port back from stored host string for SQL types (e.g. "localhost:5432") + const { host, port } = parseStoredSqlHost(conn.host, conn.type) form.value = { name: conn.name, type: conn.type, - host: conn.host, - port: conn.port, - database: conn.database, - username: conn.username, + connectionString: '', // Never pre-filled – must be re-entered if changed + host, + port, + database: conn.database ?? '', + username: conn.username ?? '', password: '', - sslEnabled: conn.sslEnabled + // sslEnabled is not separately stored; the user must re-check if they are rebuilding + // the connection string (i.e. when they change host/port/database/ssl) + sslEnabled: false, + isSource: conn.isSource, + isDestination: conn.isDestination } formError.value = '' showModal.value = true } +function validateForm(): boolean { + if (!form.value.name) { + formError.value = 'Connection name is required.' + return false + } + const typeChanged = editingConnection.value !== null && form.value.type !== originalType.value + if (isMongoType.value) { + // Require a URI on create, or when the type has changed to Mongo (old string is for a different type) + const uriRequired = !editingConnection.value || typeChanged + if (uriRequired && form.value.connectionString.trim().length === 0) { + formError.value = 'Connection URI is required.' + return false + } + } else { + // For SQL types, host, database and username are required + if (!form.value.host || !form.value.database || !form.value.username) { + formError.value = 'Host, database, and username are required.' + return false + } + if (!editingConnection.value && !form.value.password) { + formError.value = 'Password is required.' + return false + } + } + if (!form.value.isSource && !form.value.isDestination) { + formError.value = 'Select at least one role: Source or Destination.' + return false + } + return true +} + async function submitForm() { - if (!form.value.name || !form.value.host || !form.value.database || !form.value.username) { - formError.value = 'Please fill in all required fields.' - return + if (!validateForm()) return + + const payload: DataConnectionRequest = { + name: form.value.name, + type: form.value.type, + username: form.value.username || undefined, + database: form.value.database || undefined, + isSource: form.value.isSource, + isDestination: form.value.isDestination } - if (!editingConnection.value && !form.value.password) { - formError.value = 'Password is required.' - return + + const isCreate = !editingConnection.value + const typeChanged = editingConnection.value !== null && form.value.type !== originalType.value + + if (isMongoType.value) { + // For Mongo: include URI if provided (non-empty) + const uri = form.value.connectionString?.trim() + if (uri) { + payload.connectionString = uri + } + } else { + // For SQL: only rebuild and send the connection string when the user explicitly changed + // connection details, or when creating, or when the type changed. + if (isCreate || typeChanged || sqlConnectionChanged.value) { + payload.connectionString = buildConnectionString() + } + } + + if (form.value.password) { + payload.password = form.value.password } + saving.value = true formError.value = '' try { @@ -98,12 +279,12 @@ async function submitForm() { const updated = await connectionsApi.updateConnection( workspaceId.value, editingConnection.value.id, - form.value + payload ) const idx = connections.value.findIndex((c) => c.id === updated.id) if (idx !== -1) connections.value[idx] = updated } else { - const created = await connectionsApi.createConnection(workspaceId.value, form.value) + const created = await connectionsApi.createConnection(workspaceId.value, payload) connections.value.push(created) } showModal.value = false @@ -135,10 +316,6 @@ async function handleTest(conn: DataConnection) { testing.value = null } } - -function typeLabel(t: ConnectionType) { - return t.charAt(0) + t.slice(1).toLowerCase() -} + +