Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import java.sql.ResultSet
import java.util.Locale
import org.jetbrains.kotlinx.dataframe.io.TableMetadata
import kotlin.reflect.KType
import kotlin.reflect.typeOf

/**
* Represents the H2 database type.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import org.jetbrains.kotlinx.dataframe.schema.ColumnSchema
import java.sql.ResultSet
import org.jetbrains.kotlinx.dataframe.io.TableMetadata
import kotlin.reflect.KType
import kotlin.reflect.typeOf
import kotlin.reflect.full.createType

/**
* Represents the MariaDb database type.
Expand All @@ -18,6 +18,10 @@ public object MariaDb : DbType("mariadb") {
get() = "org.mariadb.jdbc.Driver"

override fun convertSqlTypeToColumnSchemaValue(tableColumnMetadata: TableColumnMetadata): ColumnSchema? {
if(tableColumnMetadata.sqlTypeName == "SMALLINT") {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

don't forget to lint

val kType = Short::class.createType(nullable = tableColumnMetadata.isNullable)
return ColumnSchema.Value(kType)
}
return null
}

Expand All @@ -33,6 +37,8 @@ public object MariaDb : DbType("mariadb") {
}

override fun convertSqlTypeToKType(tableColumnMetadata: TableColumnMetadata): KType? {
if(tableColumnMetadata.sqlTypeName == "SMALLINT")
return Short::class.createType(nullable = tableColumnMetadata.isNullable)
return null
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,9 @@ import org.jetbrains.kotlinx.dataframe.io.TableColumnMetadata
import org.jetbrains.kotlinx.dataframe.schema.ColumnSchema
import java.sql.ResultSet
import java.util.Locale
import org.jetbrains.kotlinx.dataframe.DataRow
import org.jetbrains.kotlinx.dataframe.columns.ColumnGroup
import org.jetbrains.kotlinx.dataframe.io.TableMetadata
import kotlin.reflect.KType
import kotlin.reflect.typeOf
import kotlin.reflect.full.createType

/**
* Represents the MySql database type.
Expand All @@ -21,6 +19,10 @@ public object MySql : DbType("mysql") {
get() = "com.mysql.jdbc.Driver"

override fun convertSqlTypeToColumnSchemaValue(tableColumnMetadata: TableColumnMetadata): ColumnSchema? {
if(tableColumnMetadata.sqlTypeName == "INT UNSIGNED") {
val kType = Long::class.createType(nullable = tableColumnMetadata.isNullable)
return ColumnSchema.Value(kType)
}
return null
}

Expand Down Expand Up @@ -48,6 +50,8 @@ public object MySql : DbType("mysql") {
}

override fun convertSqlTypeToKType(tableColumnMetadata: TableColumnMetadata): KType? {
if(tableColumnMetadata.sqlTypeName == "INT UNSIGNED")
return Long::class.createType(nullable = tableColumnMetadata.isNullable)
return null
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,9 @@ import org.jetbrains.kotlinx.dataframe.io.TableColumnMetadata
import org.jetbrains.kotlinx.dataframe.schema.ColumnSchema
import java.sql.ResultSet
import java.util.Locale
import org.jetbrains.kotlinx.dataframe.DataRow
import org.jetbrains.kotlinx.dataframe.columns.ColumnGroup
import org.jetbrains.kotlinx.dataframe.io.TableMetadata
import kotlin.reflect.KType
import kotlin.reflect.typeOf
import kotlin.reflect.full.createType

/**
* Represents the PostgreSql database type.
Expand All @@ -21,6 +19,12 @@ public object PostgreSql : DbType("postgresql") {
get() = "org.postgresql.Driver"

override fun convertSqlTypeToColumnSchemaValue(tableColumnMetadata: TableColumnMetadata): ColumnSchema? {
// TODO: could be a wrapper of convertSqlTypeToKType
if (tableColumnMetadata.sqlTypeName == "money") // because of https://github.com/pgjdbc/pgjdbc/issues/425
{
val kType = String::class.createType(nullable = tableColumnMetadata.isNullable)
return ColumnSchema.Value(kType)
}
return null
}

Expand All @@ -38,7 +42,7 @@ public object PostgreSql : DbType("postgresql") {

override fun convertSqlTypeToKType(tableColumnMetadata: TableColumnMetadata): KType? {
if(tableColumnMetadata.sqlTypeName == "money") // because of https://github.com/pgjdbc/pgjdbc/issues/425
return typeOf<String>()
return String::class.createType(nullable = tableColumnMetadata.isNullable)
return null
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import org.jetbrains.kotlinx.dataframe.schema.ColumnSchema
import java.sql.ResultSet
import org.jetbrains.kotlinx.dataframe.io.TableMetadata
import kotlin.reflect.KType
import kotlin.reflect.typeOf

/**
* Represents the Sqlite database type.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,16 @@ private const val MULTIPLE_SQL_QUERY_SEPARATOR = ";"
* @property [sqlTypeName] the SQL data type of the column.
* @property [jdbcType] the JDBC data type of the column produced from [java.sql.Types].
* @property [size] the size of the column.
* @property [javaClassName] the class name in Java.
* @property [isNullable] true if column could contain nulls.
*/
public data class TableColumnMetadata(
val name: String,
val sqlTypeName: String,
val jdbcType: Int,
val size: Int,
val isNullable: Boolean = false
val javaClassName: String,
val isNullable: Boolean = false,
)

/**
Expand Down Expand Up @@ -521,8 +523,9 @@ private fun getTableColumnsMetadata(rs: ResultSet): MutableList<TableColumnMetad
val size = metaData.getColumnDisplaySize(i)
val type = metaData.getColumnTypeName(i)
val jdbcType = metaData.getColumnType(i)
val javaClassName = metaData.getColumnClassName(i)

tableColumns += TableColumnMetadata(name, type, jdbcType, size, isNullable)
tableColumns += TableColumnMetadata(name, type, jdbcType, size, javaClassName, isNullable)
}
return tableColumns
}
Expand Down Expand Up @@ -642,8 +645,8 @@ private fun generateKType(dbType: DbType, tableColumnMetadata: TableColumnMetada
private fun makeCommonSqlToKTypeMapping(tableColumnMetadata: TableColumnMetadata): KType {
val jdbcTypeToKTypeMapping = mapOf(
Types.BIT to Boolean::class,
Types.TINYINT to Byte::class,
Types.SMALLINT to Short::class,
Types.TINYINT to Int::class,
Types.SMALLINT to Int::class,
Types.INTEGER to Int::class,
Types.BIGINT to Long::class,
Types.FLOAT to Float::class,
Expand Down Expand Up @@ -681,6 +684,7 @@ private fun makeCommonSqlToKTypeMapping(tableColumnMetadata: TableColumnMetadata
Types.TIME_WITH_TIMEZONE to Time::class,
Types.TIMESTAMP_WITH_TIMEZONE to Timestamp::class
)
// TODO: check mapping of JDBC types and classes correctly
val kClass = jdbcTypeToKTypeMapping[tableColumnMetadata.jdbcType] ?: String::class
return kClass.createType(nullable = tableColumnMetadata.isNullable)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ import java.sql.Connection
import java.sql.DriverManager
import java.sql.ResultSet
import java.sql.SQLException
import org.jetbrains.kotlinx.dataframe.api.add
import org.jetbrains.kotlinx.dataframe.api.select
import kotlin.reflect.typeOf

private const val URL = "jdbc:h2:mem:test;DB_CLOSE_DELAY=-1;MODE=MySQL;DATABASE_TO_UPPER=false"
Expand Down Expand Up @@ -52,14 +54,14 @@ interface TestTableData {
val binaryVaryingCol: ByteArray?
val binaryLargeObjectCol: ByteArray?
val booleanCol: Boolean?
val tinyIntCol: Byte?
val smallIntCol: Short?
val tinyIntCol: Int?
val smallIntCol: Int?
val integerCol: Int?
val bigIntCol: Long?
val numericCol: Double?
val numericCol: BigDecimal?
val realCol: Float?
val doublePrecisionCol: Double?
val decFloatCol: Double?
val decFloatCol: BigDecimal?
val dateCol: String?
val timeCol: String?
val timeWithTimeZoneCol: String?
Expand Down Expand Up @@ -235,9 +237,56 @@ class JdbcTest {
""".trimIndent()
).executeUpdate()

val df = DataFrame.readSqlTable(connection, "TestTable").cast<TestTableData>()
val tableName = "TestTable"
val df = DataFrame.readSqlTable(connection, tableName).cast<TestTableData>()
df.rowsCount() shouldBe 3
df.filter { it[TestTableData::integerCol]!! > 1000 }.rowsCount() shouldBe 2

// testing numeric columns
val result = df.select("tinyIntCol")
.add("tinyIntCol2") { it[TestTableData::tinyIntCol] }

result[0][1] shouldBe 1

val result1 = df.select("smallIntCol")
.add("smallIntCol2") { it[TestTableData::smallIntCol] }

result1[0][1] shouldBe 100

val result2 = df.select("bigIntCol")
.add("bigIntCol2") { it[TestTableData::bigIntCol] }

result2[0][1] shouldBe 100000

val result3 = df.select("numericCol")
.add("numericCol2") { it[TestTableData::numericCol] }

BigDecimal("123.45").compareTo(result3[0][1] as BigDecimal) shouldBe 0

val result4 = df.select("realCol")
.add("realCol2") { it[TestTableData::realCol] }

result4[0][1] shouldBe 1.23f

val result5 = df.select("doublePrecisionCol")
.add("doublePrecisionCol2") { it[TestTableData::doublePrecisionCol] }

result5[0][1] shouldBe 3.14

val result6 = df.select("decFloatCol")
.add("decFloatCol2") { it[TestTableData::decFloatCol] }

BigDecimal("2.71").compareTo(result6[0][1] as BigDecimal) shouldBe 0

val schema = DataFrame.getSchemaForSqlTable(connection, tableName)

schema.columns["tinyIntCol"]!!.type shouldBe typeOf<Int?>()
schema.columns["smallIntCol"]!!.type shouldBe typeOf<Int?>()
schema.columns["bigIntCol"]!!.type shouldBe typeOf<Long?>()
schema.columns["numericCol"]!!.type shouldBe typeOf<BigDecimal?>()
schema.columns["realCol"]!!.type shouldBe typeOf<Float?>()
schema.columns["doublePrecisionCol"]!!.type shouldBe typeOf<Double?>()
schema.columns["decFloatCol"]!!.type shouldBe typeOf<BigDecimal?>()
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import io.kotest.matchers.shouldBe
import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.annotations.DataSchema
import org.jetbrains.kotlinx.dataframe.api.cast
import org.jetbrains.kotlinx.dataframe.api.print
import org.junit.Test
import java.sql.DriverManager
import java.util.Properties
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ import java.math.BigDecimal
import java.sql.Connection
import java.sql.DriverManager
import java.sql.SQLException
import org.jetbrains.kotlinx.dataframe.api.add
import org.jetbrains.kotlinx.dataframe.api.select
import org.junit.Ignore
import kotlin.reflect.typeOf

Expand All @@ -26,16 +28,16 @@ interface Table1MariaDb {
val id: Int
val bitCol: Boolean
val tinyintCol: Int
val smallintCol: Int
val smallintCol: Short?
val mediumintCol: Int
val mediumintUnsignedCol: Long
val mediumintUnsignedCol: Int
val integerCol: Int
val intCol: Int
val integerUnsignedCol: Long
val bigintCol: Long
val floatCol: Float
val doubleCol: Double
val decimalCol: Double
val decimalCol: BigDecimal
val dateCol: String
val datetimeCol: String
val timestampCol: String
Expand Down Expand Up @@ -64,7 +66,7 @@ interface Table2MariaDb {
val tinyintCol: Int?
val smallintCol: Int?
val mediumintCol: Int?
val mediumintUnsignedCol: Long?
val mediumintUnsignedCol: Int?
val integerCol: Int?
val intCol: Int?
val integerUnsignedCol: Long?
Expand Down Expand Up @@ -138,7 +140,7 @@ class MariadbTest {
id INT AUTO_INCREMENT PRIMARY KEY,
bitCol BIT NOT NULL,
tinyintCol TINYINT NOT NULL,
smallintCol SMALLINT NOT NULL,
smallintCol SMALLINT,
mediumintCol MEDIUMINT NOT NULL,
mediumintUnsignedCol MEDIUMINT UNSIGNED NOT NULL,
integerCol INTEGER NOT NULL,
Expand Down Expand Up @@ -387,4 +389,66 @@ class MariadbTest {
table2Df[0][11] shouldBe 20.0
table2Df[0][26] shouldBe null
}

@Test
fun `reading numeric types`() {
val df1 = DataFrame.readSqlTable(connection, "table1").cast<Table1MariaDb>()

val result = df1.select("tinyintCol")
.add("tinyintCol2") { it[Table1MariaDb::tinyintCol] }

result[0][1] shouldBe 1

val result1 = df1.select("smallintCol")
.add("smallintCol2") { it[Table1MariaDb::smallintCol] }

result1[0][1] shouldBe 10

val result2 = df1.select("mediumintCol")
.add("mediumintCol2") { it[Table1MariaDb::mediumintCol] }

result2[0][1] shouldBe 100

val result3 = df1.select("mediumintUnsignedCol")
.add("mediumintUnsignedCol2") { it[Table1MariaDb::mediumintUnsignedCol] }

result3[0][1] shouldBe 100

val result4 = df1.select("integerUnsignedCol")
.add("integerUnsignedCol2") { it[Table1MariaDb::integerUnsignedCol] }

result4[0][1] shouldBe 100L

val result5 = df1.select("bigintCol")
.add("bigintCol2") { it[Table1MariaDb::bigintCol] }

result5[0][1] shouldBe 100

val result6 = df1.select("floatCol")
.add("floatCol2") { it[Table1MariaDb::floatCol] }

result6[0][1] shouldBe 10.0f

val result7 = df1.select("doubleCol")
.add("doubleCol2") { it[Table1MariaDb::doubleCol] }

result7[0][1] shouldBe 10.0

val result8 = df1.select("decimalCol")
.add("decimalCol2") { it[Table1MariaDb::decimalCol] }

result8[0][1] shouldBe BigDecimal("10")

val schema = DataFrame.getSchemaForSqlTable(connection, "table1")

schema.columns["tinyintCol"]!!.type shouldBe typeOf<Int>()
schema.columns["smallintCol"]!!.type shouldBe typeOf<Short?>()
schema.columns["mediumintCol"]!!.type shouldBe typeOf<Int>()
schema.columns["mediumintUnsignedCol"]!!.type shouldBe typeOf<Int>()
schema.columns["integerUnsignedCol"]!!.type shouldBe typeOf<Long>()
schema.columns["bigintCol"]!!.type shouldBe typeOf<Long>()
schema.columns["floatCol"]!!.type shouldBe typeOf<Float>()
schema.columns["doubleCol"]!!.type shouldBe typeOf<Double>()
schema.columns["decimalCol"]!!.type shouldBe typeOf<BigDecimal>()
}
}
Loading