Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 61 additions & 7 deletions scalasql/core/src/DbApi.scala
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,45 @@ object DbApi {
flattened.renderSql(castParams)
}

/**
* A listener that can be added to a [[DbApi.Txn]] to be notified of commit and rollback events.
*
* The default implementations of these methods do nothing, but you can override them to
* implement your own behavior.
*/
trait TransactionListener {

/**
* Called before the transaction is committed.
*
* If this method throws an exception, the transaction will be rolled back and the exception
* will be propagated.
*/
def beforeCommit(): Unit = ()

/**
* Called after the transaction is committed.
*
* If this method throws an exception, it will be propagated.
*/
def afterCommit(): Unit = ()

/**
* Called before the transaction is rolled back.
*
* If this method throws an exception, the transaction will be rolled back and the exception
* will be propagated to the caller of rollback().
*/
def beforeRollback(): Unit = ()

/**
* Called after the transaction is rolled back.
*
* If this method throws an exception, it will be propagated to the caller of rollback().
*/
def afterRollback(): Unit = ()
}

/**
* An interface to a SQL database *transaction*, allowing you to run queries,
* create savepoints, or roll back the transaction.
Expand All @@ -151,9 +190,11 @@ object DbApi {
def savepoint[T](block: DbApi.Savepoint => T): T

/**
* Tolls back any active Savepoints and then rolls back this Transaction
* Rolls back any active Savepoints and then rolls back this Transaction
*/
def rollback(): Unit

def addTransactionListener(listener: TransactionListener): Unit
}

/**
Expand Down Expand Up @@ -187,9 +228,16 @@ object DbApi {
connection: java.sql.Connection,
config: Config,
dialect: DialectConfig,
autoCommit: Boolean,
rollBack0: () => Unit
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've removed the rollBack0 argument since it was only used to handle the difference between autoCommit true/false behavior.

With the new interface, client code can now hook into the rollback process if needed.

autoCommit: Boolean
) extends DbApi.Txn {
val listeners = collection.mutable.ArrayDeque.empty[TransactionListener]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've considered making listeners thread-safe, but the safepoints member isn't so it seems more consistent this way. Client code that requires thread-safety bears the responsibility of synchronizing around mutations to listeners (and safepoints, as before)


override def addTransactionListener(listener: TransactionListener): Unit = {
if (autoCommit)
throw new IllegalStateException("Cannot add listener to auto-commit transaction")
listeners.append(listener)
}

def run[Q, R](query: Q, fetchSize: Int = -1, queryTimeoutSeconds: Int = -1)(
implicit qr: Queryable[Q, R],
fileName: sourcecode.FileName,
Expand Down Expand Up @@ -218,6 +266,7 @@ object DbApi {
res.toVector.asInstanceOf[R]
}
}

}

def stream[Q, R](query: Q, fetchSize: Int = -1, queryTimeoutSeconds: Int = -1)(
Expand All @@ -229,8 +278,8 @@ object DbApi {
streamFlattened0(
r => {
qr.asInstanceOf[Queryable[Q, R]].construct(query, r) match {
case s: Seq[R] => s.head
case r: R => r
case s: Seq[R] @unchecked => s.head
case r: R @unchecked => r
}
},
flattened,
Expand Down Expand Up @@ -545,8 +594,13 @@ object DbApi {
}

def rollback() = {
savepointStack.clear()
rollBack0()
try {
listeners.foreach(_.beforeRollback())
} finally {
savepointStack.clear()
connection.rollback()
listeners.foreach(_.afterRollback())
}
}

private def cast[T](t: Any): T = t.asInstanceOf[T]
Expand Down
39 changes: 31 additions & 8 deletions scalasql/core/src/DbClient.scala
Original file line number Diff line number Diff line change
Expand Up @@ -49,19 +49,42 @@ object DbClient {

def transaction[T](block: DbApi.Txn => T): T = {
connection.setAutoCommit(false)
val txn =
new DbApi.Impl(connection, config, dialect, false, () => connection.rollback())
try block(txn)
catch {
val txn = new DbApi.Impl(connection, config, dialect, autoCommit = false)
var rolledBack = false
try {
val result = block(txn)
txn.listeners.foreach(_.beforeCommit())
result
} catch {
case e: Throwable =>
connection.rollback()
rolledBack = true
try {
txn.listeners.foreach(_.beforeRollback())
} catch {
case e2: Throwable =>
e.addSuppressed(e2)
} finally {
connection.rollback()
try {
txn.listeners.foreach(_.afterRollback())
} catch {
case e3: Throwable =>
e.addSuppressed(e3)
}
}
throw e
} finally connection.setAutoCommit(true)
} finally {
// this commits uncommitted operations, if any
connection.setAutoCommit(true)
if (!rolledBack) {
txn.listeners.foreach(_.afterCommit())
}
}
}

def getAutoCommitClientConnection: DbApi = {
connection.setAutoCommit(true)
new DbApi.Impl(connection, config, dialect, autoCommit = true, () => ())
new DbApi.Impl(connection, config, dialect, autoCommit = true)
}
}

Expand All @@ -88,7 +111,7 @@ object DbClient {
def getAutoCommitClientConnection: DbApi = {
val connection = dataSource.getConnection
connection.setAutoCommit(true)
new DbApi.Impl(connection, config, dialect, autoCommit = true, () => ())
new DbApi.Impl(connection, config, dialect, autoCommit = true)
}
}
}
103 changes: 103 additions & 0 deletions scalasql/test/src/api/TransactionTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package scalasql.api

import scalasql.Purchase
import scalasql.utils.{ScalaSqlSuite, SqliteSuite}
import scalasql.DbApi
import sourcecode.Text
import utest._

Expand All @@ -12,6 +13,37 @@ trait TransactionTests extends ScalaSqlSuite {
override def utestBeforeEach(path: Seq[String]): Unit = checker.reset()
class FooException extends Exception

class ListenerException(message: String) extends Exception(message)

class StubTransactionListener(
throwOnBeforeCommit: Boolean = false,
throwOnAfterCommit: Boolean = false,
throwOnBeforeRollback: Boolean = false,
throwOnAfterRollback: Boolean = false
) extends DbApi.TransactionListener {
var beforeCommitCalled = false
var afterCommitCalled = false
var beforeRollbackCalled = false
var afterRollbackCalled = false

override def beforeCommit(): Unit = {
beforeCommitCalled = true
if (throwOnBeforeCommit) throw new ListenerException("beforeCommit")
}
override def afterCommit(): Unit = {
afterCommitCalled = true
if (throwOnAfterCommit) throw new ListenerException("afterCommit")
}
override def beforeRollback(): Unit = {
beforeRollbackCalled = true
if (throwOnBeforeRollback) throw new ListenerException("beforeRollback")
}
override def afterRollback(): Unit = {
afterRollbackCalled = true
if (throwOnAfterRollback) throw new ListenerException("afterRollback")
}
}

def tests = Tests {
test("simple") {
test("commit") - checker.recorded(
Expand Down Expand Up @@ -537,5 +569,76 @@ trait TransactionTests extends ScalaSqlSuite {
}
}
}

test("listener") {
test("beforeCommit and afterCommit are called under normal circumstances") {
val listener = new StubTransactionListener()
dbClient.transaction { implicit txn =>
txn.addTransactionListener(listener)
}
listener.beforeCommitCalled ==> true
listener.afterCommitCalled ==> true
listener.beforeRollbackCalled ==> false
listener.afterRollbackCalled ==> false
}

test("if beforeCommit causes an exception, {before,after}Rollback are called") {
val listener = new StubTransactionListener(throwOnBeforeCommit = true)
val e = intercept[ListenerException] {
dbClient.transaction { implicit txn =>
txn.addTransactionListener(listener)
}
}
e.getMessage ==> "beforeCommit"
listener.beforeCommitCalled ==> true
listener.afterCommitCalled ==> false
listener.beforeRollbackCalled ==> true
listener.afterRollbackCalled ==> true
}

test("if afterCommit causes an exception, the exception is propagated") {
val listener = new StubTransactionListener(throwOnAfterCommit = true)
val e = intercept[ListenerException] {
dbClient.transaction { implicit txn =>
txn.addTransactionListener(listener)
}
}
e.getMessage ==> "afterCommit"
listener.beforeCommitCalled ==> true
listener.afterCommitCalled ==> true
listener.beforeRollbackCalled ==> false
listener.afterRollbackCalled ==> false
}

test("if beforeRollback causes an exception, afterRollback is still called") {
val listener = new StubTransactionListener(throwOnBeforeRollback = true)
val e = intercept[FooException] {
dbClient.transaction { implicit txn =>
txn.addTransactionListener(listener)
throw new FooException()
}
}
e.getSuppressed.head.getMessage ==> "beforeRollback"
listener.beforeCommitCalled ==> false
listener.afterCommitCalled ==> false
listener.beforeRollbackCalled ==> true
listener.afterRollbackCalled ==> true
}

test("if afterRollback causes an exception, the exception is propagated") {
val listener = new StubTransactionListener(throwOnAfterRollback = true)
val e = intercept[FooException] {
dbClient.transaction { implicit txn =>
txn.addTransactionListener(listener)
throw new FooException()
}
}
e.getSuppressed.head.getMessage ==> "afterRollback"
listener.beforeCommitCalled ==> false
listener.afterCommitCalled ==> false
listener.beforeRollbackCalled ==> true
listener.afterRollbackCalled ==> true
}
}
}
}