Skip to content

Commit 40ef056

Browse files
authored
[SETL-197] Mermaid diagram with DataFrame as FactoryOutput (#202)
* fix: Mermaid diagram for DataFrame FactoryOutput. * fix: removed unused import. * Update FactoryOutput.scala. Co-authored-by: Qinx <17144939+qxzzxq@users.noreply.github.com>. Closes #197.
1 parent 12bbb5a commit 40ef056

File tree

2 files changed

+40
-7
lines changed

2 files changed

+40
-7
lines changed

src/main/scala/io/github/setl/transformation/FactoryOutput.scala

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,15 @@ private[setl] case class FactoryOutput(override val runtimeType: runtime.univers
3232
}
3333

3434
if (isDataset) {
35-
val datasetTypeArgFields = super.getTypeArgList(this.runtimeType.typeArgs.head)
36-
datasetTypeArgFields.map {
37-
i => s">${i.name}: ${ReflectUtils.getPrettyName(i.typeSignature)}"
35+
if (this.runtimeType.typeArgs.isEmpty) {
36+
// DataFrame
37+
List.empty
38+
} else {
39+
// Dataset
40+
val datasetTypeArgFields = super.getTypeArgList(this.runtimeType.typeArgs.head)
41+
datasetTypeArgFields.map {
42+
i => s">${i.name}: ${ReflectUtils.getPrettyName(i.typeSignature)}"
43+
}
3844
}
3945

4046
} else if (isCaseClass) {

src/test/scala/io/github/setl/workflow/PipelineSuite.scala

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import io.github.setl.storage.SparkRepositoryBuilder
88
import io.github.setl.storage.connector.FileConnector
99
import io.github.setl.transformation.{Deliverable, Factory}
1010
import io.github.setl.workflow.DeliverableDispatcherSuite.FactoryWithMultipleAutoLoad
11-
import org.apache.spark.sql.{Dataset, SparkSession, functions}
11+
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession, functions}
1212
import org.scalatest.funsuite.AnyFunSuite
1313
import org.scalatest.matchers.should.Matchers
1414

@@ -595,6 +595,20 @@ class PipelineSuite extends AnyFunSuite with Matchers {
595595
fac.floatArray should contain theSameElementsAs fltAry
596596
}
597597

598+
test("SETL-197: Mermaid diagram should be shown even when the factory output is a DataFrame") {
599+
val spark = new SparkSessionBuilder("test").setEnv("local").setSparkMaster("local").getOrCreate()
600+
601+
new Pipeline()
602+
.setInput[String]("id_of_product1", classOf[ProductFactory])
603+
.setInput[String]("dataframe", classOf[DataFrameFactory])
604+
.addStage[ProductFactory]()
605+
.addStage[DatasetFactory](Array(spark))
606+
.addStage[DataFrameFactory](Array(spark))
607+
.addStage[DatasetFactory4](Array(spark))
608+
.run()
609+
.showDiagram()
610+
}
611+
598612
}
599613

600614
object PipelineSuite {
@@ -857,15 +871,28 @@ object PipelineSuite {
857871

858872
override def read(): DatasetFactory4.this.type = this
859873

860-
override def process(): DatasetFactory4.this.type = {
861-
this
862-
}
874+
override def process(): DatasetFactory4.this.type = this
863875

864876
override def write(): DatasetFactory4.this.type = this
865877

866878
override def get(): Long = ds1.count()
867879
}
868880

881+
class DataFrameFactory(spark: SparkSession) extends Factory[DataFrame] {
882+
import spark.implicits._
883+
884+
@Delivery
885+
var input: String = null
886+
887+
override def read(): DataFrameFactory.this.type = this
888+
889+
override def process(): DataFrameFactory.this.type = this
890+
891+
override def write(): DataFrameFactory.this.type = this
892+
893+
override def get(): DataFrame = Seq(input).toDF("column1")
894+
}
895+
869896
class PrimaryDeliveryFactory extends Factory[String] {
870897

871898
@Delivery(id = "byte")

0 commit comments

Comments (0)