Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,10 @@ public abstract class DbType(public val dbTypeInJdbcUrl: String) {
tableColumnMetadata.jdbcType == Types.NUMERIC &&
tableColumnMetadata.javaClassName == "java.lang.Double" -> Double::class

// Force BIGINT to always be Long, regardless of javaClassName
// Some JDBC drivers (e.g., MariaDB) may report Integer for small BIGINT values
// TODO: tableColumnMetadata.jdbcType == Types.BIGINT -> Long::class

else -> jdbcTypeToKTypeMapping[tableColumnMetadata.jdbcType] ?: String::class
}

Expand All @@ -402,14 +406,22 @@ public abstract class DbType(public val dbTypeInJdbcUrl: String) {
/**
* Retrieves column metadata from a JDBC ResultSet.
*
* By default, this method reads column metadata from [ResultSetMetaData],
* which is fast and supported by most JDBC drivers.
* If the driver does not provide sufficient information (e.g., `isNullable` unknown),
* it falls back to using [DatabaseMetaData.getColumns] for affected columns.
* This method reads column metadata from [ResultSetMetaData] with graceful fallbacks
* for JDBC drivers that throw [java.sql.SQLFeatureNotSupportedException] for certain methods
* (e.g., Apache Hive).
*
* Fallback behavior for unsupported methods:
* - `getColumnName()` → `getColumnLabel()` → `"columnN"` (where `N` is the 1-based column index)
* - `getTableName()` → extract from column name if contains '.' → `null`
* - `isNullable()` → [DatabaseMetaData.getColumns] → `true` (assume nullable)
* - `getColumnTypeName()` → `"OTHER"`
* - `getColumnType()` → [java.sql.Types.OTHER]
* - `getColumnDisplaySize()` → `0`
* - `getColumnClassName()` → `"java.lang.Object"`
*
* Override this method in subclasses to provide database-specific behavior
* (for example, to disable fallback for databases like Teradata or Oracle
* where `DatabaseMetaData.getColumns` is known to be slow).
* where [DatabaseMetaData.getColumns] is known to be slow).
*
* @param resultSet The [ResultSet] containing query results.
* @return A list of [TableColumnMetadata] objects.
Expand All @@ -418,16 +430,44 @@ public abstract class DbType(public val dbTypeInJdbcUrl: String) {
val rsMetaData = resultSet.metaData
val connection = resultSet.statement.connection
val dbMetaData = connection.metaData
val catalog = connection.catalog.takeUnless { it.isNullOrBlank() }
val schema = connection.schema.takeUnless { it.isNullOrBlank() }

// Some JDBC drivers (e.g., Hive) throw SQLFeatureNotSupportedException when the catalog or schema
// is requested; in that case treat the value as unknown (null)
val catalog = try {
connection.catalog.takeUnless { it.isNullOrBlank() }
} catch (_: Exception) {
null
}

val schema = try {
connection.schema.takeUnless { it.isNullOrBlank() }
} catch (_: Exception) {
null
}

val columnCount = rsMetaData.columnCount
val columns = mutableListOf<TableColumnMetadata>()
val nameCounter = mutableMapOf<String, Int>()

for (index in 1..columnCount) {
val columnName = rsMetaData.getColumnName(index)
val tableName = rsMetaData.getTableName(index)
// Try getColumnName first, fall back to getColumnLabel, then generate a name from the column index
val columnName = try {
rsMetaData.getColumnName(index)
} catch (_: Exception) {
try {
rsMetaData.getColumnLabel(index)
} catch (_: Exception) {
"column$index"
}
}

// Some JDBC drivers (e.g., Apache Hive) throw SQLFeatureNotSupportedException from getTableName
val tableName = try {
rsMetaData.getTableName(index).takeUnless { it.isBlank() }
} catch (_: Exception) {
// Fallback: try to extract table name from column name if it contains '.'
val dotIndex = columnName.lastIndexOf('.')
if (dotIndex > 0) columnName.take(dotIndex) else null
}

// Try to detect nullability from ResultSetMetaData
val isNullable = try {
Expand All @@ -436,25 +476,48 @@ public abstract class DbType(public val dbTypeInJdbcUrl: String) {

ResultSetMetaData.columnNullable -> true

ResultSetMetaData.columnNullableUnknown -> {
// Unknown nullability: assume it nullable, may trigger fallback
true
}
// Unknown nullability: assume the column is nullable (the safest default)
ResultSetMetaData.columnNullableUnknown -> true

else -> true
}
} catch (_: Exception) {
// Some drivers may throw for unsupported features
// In that case, fallback to DatabaseMetaData
dbMetaData.getColumns(catalog, schema, tableName, columnName).use { cols ->
if (cols.next()) !cols.getString("IS_NULLABLE").equals("NO", ignoreCase = true) else true
// Try fallback to DatabaseMetaData, with additional safety
try {
dbMetaData.getColumns(catalog, schema, tableName, columnName).use { cols ->
if (cols.next()) !cols.getString("IS_NULLABLE").equals("NO", ignoreCase = true) else true
}
} catch (_: Exception) {
// Fallback failed, assume nullable as the safest default
true
}
}

val columnType = rsMetaData.getColumnTypeName(index)
val jdbcType = rsMetaData.getColumnType(index)
val displaySize = rsMetaData.getColumnDisplaySize(index)
val javaClassName = rsMetaData.getColumnClassName(index)
// Fall back to safe defaults when the driver throws for unsupported metadata methods
val columnType = try {
rsMetaData.getColumnTypeName(index)
} catch (_: Exception) {
"OTHER"
}

val jdbcType = try {
rsMetaData.getColumnType(index)
} catch (_: Exception) {
Types.OTHER
}

val displaySize = try {
rsMetaData.getColumnDisplaySize(index)
} catch (_: Exception) {
0
}

val javaClassName = try {
rsMetaData.getColumnClassName(index)
} catch (_: Exception) {
"java.lang.Object"
}

val uniqueName = manageColumnNameDuplication(nameCounter, columnName)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,22 @@ public object MariaDb : DbType("mariadb") {
get() = "org.mariadb.jdbc.Driver"

override fun convertSqlTypeToColumnSchemaValue(tableColumnMetadata: TableColumnMetadata): ColumnSchema? {
// Force BIGINT to always be Long, regardless of javaClassName
// MariaDB JDBC driver may report Integer for small BIGINT values
// TODO: investigate the corner case

// if (tableColumnMetadata.jdbcType == java.sql.Types.BIGINT) {
// val kType = Long::class.createType(nullable = tableColumnMetadata.isNullable)
// return ColumnSchema.Value(kType)
// }

if (tableColumnMetadata.sqlTypeName == "INTEGER UNSIGNED" ||
tableColumnMetadata.sqlTypeName == "INT UNSIGNED"
) {
val kType = Long::class.createType(nullable = tableColumnMetadata.isNullable)
return ColumnSchema.Value(kType)
}

if (tableColumnMetadata.sqlTypeName == "SMALLINT" && tableColumnMetadata.javaClassName == "java.lang.Short") {
val kType = Short::class.createType(nullable = tableColumnMetadata.isNullable)
return ColumnSchema.Value(kType)
Expand All @@ -35,6 +51,19 @@ public object MariaDb : DbType("mariadb") {
)

override fun convertSqlTypeToKType(tableColumnMetadata: TableColumnMetadata): KType? {
// Force BIGINT to always be Long, regardless of javaClassName
// MariaDB JDBC driver may report Integer for small BIGINT values
// TODO: investigate the corner case
// if (tableColumnMetadata.jdbcType == java.sql.Types.BIGINT) {
// return Long::class.createType(nullable = tableColumnMetadata.isNullable)
// }

if (tableColumnMetadata.sqlTypeName == "INTEGER UNSIGNED" ||
tableColumnMetadata.sqlTypeName == "INT UNSIGNED"
) {
return Long::class.createType(nullable = tableColumnMetadata.isNullable)
}

if (tableColumnMetadata.sqlTypeName == "SMALLINT" && tableColumnMetadata.javaClassName == "java.lang.Short") {
return Short::class.createType(nullable = tableColumnMetadata.isNullable)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,48 +13,51 @@ import java.sql.Connection
import java.sql.ResultSet
import kotlin.reflect.typeOf

private const val TEST_TABLE_NAME = "testtable123"

internal fun inferNullability(connection: Connection) {
connection.createStatement().use { st -> st.execute("DROP TABLE IF EXISTS $TEST_TABLE_NAME") }

// prepare tables and data
@Language("SQL")
val createTestTable1Query = """
CREATE TABLE TestTable1 (
CREATE TABLE $TEST_TABLE_NAME (
id INT PRIMARY KEY,
name VARCHAR(50),
surname VARCHAR(50),
age INT NOT NULL
)
"""

connection.createStatement().execute(createTestTable1Query)
connection.createStatement().use { st -> st.execute(createTestTable1Query) }

connection.createStatement()
.execute("INSERT INTO TestTable1 (id, name, surname, age) VALUES (1, 'John', 'Crawford', 40)")
.execute("INSERT INTO $TEST_TABLE_NAME (id, name, surname, age) VALUES (1, 'John', 'Crawford', 40)")
connection.createStatement()
.execute("INSERT INTO TestTable1 (id, name, surname, age) VALUES (2, 'Alice', 'Smith', 25)")
.execute("INSERT INTO $TEST_TABLE_NAME (id, name, surname, age) VALUES (2, 'Alice', 'Smith', 25)")
connection.createStatement()
.execute("INSERT INTO TestTable1 (id, name, surname, age) VALUES (3, 'Bob', 'Johnson', 47)")
.execute("INSERT INTO $TEST_TABLE_NAME (id, name, surname, age) VALUES (3, 'Bob', 'Johnson', 47)")
connection.createStatement()
.execute("INSERT INTO TestTable1 (id, name, surname, age) VALUES (4, 'Sam', NULL, 15)")
.execute("INSERT INTO $TEST_TABLE_NAME (id, name, surname, age) VALUES (4, 'Sam', NULL, 15)")

// start testing `readSqlTable` method

// with default inferNullability: Boolean = true
val tableName = "TestTable1"
val df = DataFrame.readSqlTable(connection, tableName)
val df = DataFrame.readSqlTable(connection, TEST_TABLE_NAME)
df.schema().columns["id"]!!.type shouldBe typeOf<Int>()
df.schema().columns["name"]!!.type shouldBe typeOf<String>()
df.schema().columns["surname"]!!.type shouldBe typeOf<String?>()
df.schema().columns["age"]!!.type shouldBe typeOf<Int>()

val dataSchema = DataFrameSchema.readSqlTable(connection, tableName)
val dataSchema = DataFrameSchema.readSqlTable(connection, TEST_TABLE_NAME)
dataSchema.columns.size shouldBe 4
dataSchema.columns["id"]!!.type shouldBe typeOf<Int>()
dataSchema.columns["name"]!!.type shouldBe typeOf<String?>()
dataSchema.columns["surname"]!!.type shouldBe typeOf<String?>()
dataSchema.columns["age"]!!.type shouldBe typeOf<Int>()

// with inferNullability: Boolean = false
val df1 = DataFrame.readSqlTable(connection, tableName, inferNullability = false)
val df1 = DataFrame.readSqlTable(connection, TEST_TABLE_NAME, inferNullability = false)
df1.schema().columns["id"]!!.type shouldBe typeOf<Int>()

// this column changed a type because it doesn't contain nulls
Expand All @@ -70,7 +73,7 @@ internal fun inferNullability(connection: Connection) {
@Language("SQL")
val sqlQuery =
"""
SELECT name, surname, age FROM TestTable1
SELECT name, surname, age FROM $TEST_TABLE_NAME
""".trimIndent()

val df2 = DataFrame.readSqlQuery(connection, sqlQuery)
Expand All @@ -97,7 +100,7 @@ internal fun inferNullability(connection: Connection) {

connection.createStatement(ResultSet.TYPE_SCROLL_SENSITIVE, ResultSet.CONCUR_UPDATABLE).use { st ->
@Language("SQL")
val selectStatement = "SELECT * FROM TestTable1"
val selectStatement = "SELECT * FROM $TEST_TABLE_NAME"

st.executeQuery(selectStatement).use { rs ->
// with default inferNullability: Boolean = true
Expand Down Expand Up @@ -130,7 +133,7 @@ internal fun inferNullability(connection: Connection) {
}
// end testing `readResultSet` method

connection.createStatement().execute("DROP TABLE TestTable1")
connection.createStatement().use { st -> st.execute("DROP TABLE IF EXISTS $TEST_TABLE_NAME") }
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ class MariadbTest {
val result = df1.filter { it[Table1MariaDb::id] == 1 }
result[0][26] shouldBe "textValue1"
val byteArray = "tinyblobValue".toByteArray()
(result[0][22] as Blob).getBytes(1, byteArray.size) contentEquals byteArray
result[0][22] shouldBe byteArray

val schema = DataFrameSchema.readSqlTable(connection, "table1")
schema.columns["id"]!!.type shouldBe typeOf<Int>()
Expand Down