Skip to content

Commit 1c05f55

Browse files
committed
figured we cannot fix this without onInterrupt {}, reverted trials.
Found broken jupyter streaming test
1 parent a841611 commit 1c05f55

File tree

5 files changed

+155
-162
lines changed

5 files changed

+155
-162
lines changed

jupyter/src/main/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/Integration.kt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ abstract class Integration : JupyterIntegration() {
3939
*/
4040
abstract fun KotlinKernelHost.onLoaded()
4141

42+
abstract fun KotlinKernelHost.onShutdown()
43+
4244
abstract fun KotlinKernelHost.afterCellExecution(snippetInstance: Any, result: FieldValue)
4345

4446
open val dependencies: Array<String> = arrayOf(
@@ -93,6 +95,10 @@ abstract class Integration : JupyterIntegration() {
9395
afterCellExecution(snippetInstance, result)
9496
}
9597

98+
onShutdown {
99+
onShutdown()
100+
}
101+
96102
// Render Dataset
97103
render<Dataset<*>> {
98104
HTML(it.toHtml())

jupyter/src/main/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/SparkIntegration.kt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,5 +70,9 @@ internal class SparkIntegration : Integration() {
7070
).map(::execute)
7171
}
7272

73+
override fun KotlinKernelHost.onShutdown() {
74+
execute("""spark.stop()""")
75+
}
76+
7377
override fun KotlinKernelHost.afterCellExecution(snippetInstance: Any, result: FieldValue) = Unit
7478
}

jupyter/src/main/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/SparkStreamingIntegration.kt

Lines changed: 79 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -19,29 +19,10 @@
1919
*/
2020
package org.jetbrains.kotlinx.spark.api.jupyter
2121

22-
import kotlinx.html.*
23-
import kotlinx.html.stream.appendHTML
24-
import org.apache.spark.api.java.JavaRDDLike
25-
import org.apache.spark.rdd.RDD
26-
import org.apache.spark.sql.Dataset
27-
import org.apache.spark.unsafe.array.ByteArrayMethods
28-
import org.intellij.lang.annotations.Language
29-
import org.jetbrains.kotlinx.jupyter.api.HTML
30-
import org.jetbrains.kotlinx.jupyter.api.libraries.JupyterIntegration
31-
import org.jetbrains.kotlinx.spark.api.*
32-
import java.io.InputStreamReader
33-
3422

35-
import org.apache.spark.*
36-
import org.apache.spark.streaming.api.java.JavaStreamingContext
23+
import org.intellij.lang.annotations.Language
3724
import org.jetbrains.kotlinx.jupyter.api.FieldValue
3825
import org.jetbrains.kotlinx.jupyter.api.KotlinKernelHost
39-
import scala.collection.*
40-
import org.jetbrains.kotlinx.spark.api.SparkSession
41-
import scala.Product
42-
import java.io.Serializable
43-
import scala.collection.Iterable as ScalaIterable
44-
import scala.collection.Iterator as ScalaIterator
4526

4627
/**
4728
* %use spark-streaming
@@ -60,90 +41,87 @@ internal class SparkStreamingIntegration : Integration() {
6041

6142
@Language("kts")
6243
val _1 = listOf(
63-
"""
64-
val sscCollection = mutableSetOf<JavaStreamingContext>()
65-
""".trimIndent(),
66-
"""
67-
@JvmOverloads
68-
fun withSparkStreaming(
69-
batchDuration: Duration = Durations.seconds(1L),
70-
checkpointPath: String? = null,
71-
hadoopConf: Configuration = SparkHadoopUtil.get().conf(),
72-
createOnError: Boolean = false,
73-
props: Map<String, Any> = emptyMap(),
74-
master: String = SparkConf().get("spark.master", "local[*]"),
75-
appName: String = "Kotlin Spark Sample",
76-
timeout: Long = -1L,
77-
startStreamingContext: Boolean = true,
78-
func: KSparkStreamingSession.() -> Unit,
79-
) {
80-
var ssc: JavaStreamingContext? = null
81-
try {
82-
83-
// will only be set when a new context is created
84-
var kSparkStreamingSession: KSparkStreamingSession? = null
85-
86-
val creatingFunc = {
87-
val sc = SparkConf()
88-
.setAppName(appName)
89-
.setMaster(master)
90-
.setAll(
91-
props
92-
.map { (key, value) -> key X value.toString() }
93-
.asScalaIterable()
94-
)
95-
96-
val ssc1 = JavaStreamingContext(sc, batchDuration)
97-
ssc1.checkpoint(checkpointPath)
98-
99-
kSparkStreamingSession = KSparkStreamingSession(ssc1)
100-
func(kSparkStreamingSession!!)
101-
102-
ssc1
103-
}
104-
105-
ssc = when {
106-
checkpointPath != null ->
107-
JavaStreamingContext.getOrCreate(checkpointPath, creatingFunc, hadoopConf, createOnError)
108-
109-
else -> creatingFunc()
110-
}
111-
112-
sscCollection += ssc!!
113-
114-
if (startStreamingContext) {
115-
ssc!!.start()
116-
kSparkStreamingSession?.invokeRunAfterStart()
117-
}
118-
ssc!!.awaitTerminationOrTimeout(timeout)
119-
} finally {
120-
ssc?.stop()
121-
println("stopping ssc")
122-
ssc?.awaitTermination()
123-
println("ssc stopped")
124-
ssc?.let(sscCollection::remove)
125-
}
126-
}
127-
""".trimIndent(),
44+
// For when onInterrupt is implemented in the Jupyter kernel
45+
// """
46+
// val sscCollection = mutableSetOf<JavaStreamingContext>()
47+
// """.trimIndent(),
48+
// """
49+
// @JvmOverloads
50+
// fun withSparkStreaming(
51+
// batchDuration: Duration = Durations.seconds(1L),
52+
// checkpointPath: String? = null,
53+
// hadoopConf: Configuration = SparkHadoopUtil.get().conf(),
54+
// createOnError: Boolean = false,
55+
// props: Map<String, Any> = emptyMap(),
56+
// master: String = SparkConf().get("spark.master", "local[*]"),
57+
// appName: String = "Kotlin Spark Sample",
58+
// timeout: Long = -1L,
59+
// startStreamingContext: Boolean = true,
60+
// func: KSparkStreamingSession.() -> Unit,
61+
// ) {
62+
//
63+
// // will only be set when a new context is created
64+
// var kSparkStreamingSession: KSparkStreamingSession? = null
65+
//
66+
// val creatingFunc = {
67+
// val sc = SparkConf()
68+
// .setAppName(appName)
69+
// .setMaster(master)
70+
// .setAll(
71+
// props
72+
// .map { (key, value) -> key X value.toString() }
73+
// .asScalaIterable()
74+
// )
75+
//
76+
// val ssc = JavaStreamingContext(sc, batchDuration)
77+
// ssc.checkpoint(checkpointPath)
78+
//
79+
// kSparkStreamingSession = KSparkStreamingSession(ssc)
80+
// func(kSparkStreamingSession!!)
81+
//
82+
// ssc
83+
// }
84+
//
85+
// val ssc = when {
86+
// checkpointPath != null ->
87+
// JavaStreamingContext.getOrCreate(checkpointPath, creatingFunc, hadoopConf, createOnError)
88+
//
89+
// else -> creatingFunc()
90+
// }
91+
// sscCollection += ssc
92+
//
93+
// if (startStreamingContext) {
94+
// ssc.start()
95+
// kSparkStreamingSession?.invokeRunAfterStart()
96+
// }
97+
// ssc.awaitTerminationOrTimeout(timeout)
98+
// ssc.stop()
99+
// }
100+
// """.trimIndent(),
128101
"""
129102
println("To start a spark streaming session, simply use `withSparkStreaming { }` inside a cell. To use Spark normally, use `withSpark { }` in a cell, or use `%use spark` to start a Spark session for the whole notebook.")""".trimIndent(),
130103
).map(::execute)
131104
}
132105

133-
override fun KotlinKernelHost.afterCellExecution(snippetInstance: Any, result: FieldValue) {
134-
135-
@Language("kts")
136-
val _1 = listOf(
137-
"""
138-
while (sscCollection.isNotEmpty())
139-
sscCollection.first().let {
140-
it.stop()
141-
sscCollection.remove(it)
142-
}
143-
""".trimIndent(),
144-
"""
145-
println("afterCellExecution cleanup!")
146-
""".trimIndent()
147-
).map(::execute)
148-
}
106+
override fun KotlinKernelHost.onShutdown() = Unit
107+
108+
override fun KotlinKernelHost.afterCellExecution(snippetInstance: Any, result: FieldValue) = Unit
109+
110+
// For when this feature is implemented in the Jupyter kernel
111+
// override fun KotlinKernelHost.onInterrupt() {
112+
//
113+
// @Language("kts")
114+
// val _1 = listOf(
115+
// """
116+
// while (sscCollection.isNotEmpty())
117+
// sscCollection.first().let {
118+
// it.stop()
119+
// sscCollection.remove(it)
120+
// }
121+
// """.trimIndent(),
122+
// """
123+
// println("onInterrupt cleanup!")
124+
// """.trimIndent()
125+
// ).map(::execute)
126+
// }
149127
}

jupyter/src/test/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/JupyterTests.kt

Lines changed: 38 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,13 @@ package org.jetbrains.kotlinx.spark.api.jupyter
2121

2222
import io.kotest.assertions.throwables.shouldThrowAny
2323
import io.kotest.core.spec.style.ShouldSpec
24-
import io.kotest.matchers.collections.shouldBeIn
2524
import io.kotest.matchers.nulls.shouldNotBeNull
2625
import io.kotest.matchers.shouldBe
2726
import io.kotest.matchers.shouldNotBe
2827
import io.kotest.matchers.string.shouldContain
2928
import io.kotest.matchers.types.shouldBeInstanceOf
3029
import jupyter.kotlin.DependsOn
3130
import org.apache.spark.api.java.JavaSparkContext
32-
import org.apache.spark.streaming.Duration
3331
import org.apache.spark.streaming.api.java.JavaStreamingContext
3432
import org.intellij.lang.annotations.Language
3533
import org.jetbrains.kotlinx.jupyter.EvalRequestData
@@ -41,11 +39,8 @@ import org.jetbrains.kotlinx.jupyter.libraries.EmptyResolutionInfoProvider
4139
import org.jetbrains.kotlinx.jupyter.repl.EvalResultEx
4240
import org.jetbrains.kotlinx.jupyter.testkit.ReplProvider
4341
import org.jetbrains.kotlinx.jupyter.util.PatternNameAcceptanceRule
44-
import org.jetbrains.kotlinx.spark.api.tuples.*
45-
import org.jetbrains.kotlinx.spark.api.*
46-
import scala.Tuple2
42+
import org.jetbrains.kotlinx.spark.api.SparkSession
4743
import java.io.Serializable
48-
import java.util.*
4944
import kotlin.script.experimental.jvm.util.classpathFromClassloader
5045

5146
class JupyterTests : ShouldSpec({
@@ -269,7 +264,8 @@ class JupyterStreamingTests : ShouldSpec({
269264
context("Jupyter") {
270265
withRepl {
271266

272-
should("Have sscCollection instance") {
267+
// For when onInterrupt is implemented in the Jupyter kernel
268+
xshould("Have sscCollection instance") {
273269

274270
@Language("kts")
275271
val sscCollection = exec("""sscCollection""")
@@ -292,29 +288,46 @@ class JupyterStreamingTests : ShouldSpec({
292288
}
293289
}
294290

295-
should("stream") {
296-
val input = listOf("aaa", "bbb", "aaa", "ccc")
297-
val counter = Counter(0)
291+
xshould("stream") {
298292

299-
withSparkStreaming(Duration(10), timeout = 1000) {
300-
301-
val (counterBroadcast, queue) = withSpark(ssc) {
302-
spark.broadcast(counter) X LinkedList(listOf(sc.parallelize(input)))
303-
}
304-
305-
val inputStream = ssc.queueStream(queue)
306-
307-
inputStream.foreachRDD { rdd, _ ->
308-
withSpark(rdd) {
309-
rdd.toDS().forEach {
310-
it shouldBeIn input
311-
counterBroadcast.value.value++
293+
@Language("kts")
294+
val value = exec(
295+
"""
296+
import java.util.LinkedList
297+
import org.apache.spark.api.java.function.ForeachFunction
298+
import org.apache.spark.util.LongAccumulator
299+
300+
301+
val input = arrayListOf("aaa", "bbb", "aaa", "ccc")
302+
303+
@Volatile
304+
var counter: LongAccumulator? = null
305+
306+
withSparkStreaming(Duration(10), timeout = 1_000) {
307+
308+
val queue = withSpark(ssc) {
309+
LinkedList(listOf(sc.parallelize(input)))
310+
}
311+
312+
val inputStream = ssc.queueStream(queue)
313+
314+
inputStream.foreachRDD { rdd, _ ->
315+
withSpark(rdd) {
316+
if (counter == null)
317+
counter = sc.sc().longAccumulator()
318+
319+
rdd.toDS().showDS().forEach {
320+
if (it !in input) error(it + " should be in input")
321+
counter!!.add(1L)
322+
}
312323
}
313324
}
314325
}
315-
}
326+
counter!!.sum()
327+
""".trimIndent()
328+
) as Long
316329

317-
counter.value shouldBe input.size
330+
value shouldBe 4L
318331
}
319332

320333
}

0 commit comments

Comments
 (0)