@@ -8,7 +8,7 @@ import io.github.setl.storage.SparkRepositoryBuilder
88import io .github .setl .storage .connector .FileConnector
99import io .github .setl .transformation .{Deliverable , Factory }
1010import io .github .setl .workflow .DeliverableDispatcherSuite .FactoryWithMultipleAutoLoad
11- import org .apache .spark .sql .{Dataset , SparkSession , functions }
11+ import org .apache .spark .sql .{DataFrame , Dataset , SparkSession , functions }
1212import org .scalatest .funsuite .AnyFunSuite
1313import org .scalatest .matchers .should .Matchers
1414
@@ -595,6 +595,20 @@ class PipelineSuite extends AnyFunSuite with Matchers {
595595 fac.floatArray should contain theSameElementsAs fltAry
596596 }
597597
598+ test(" SETL-197: Mermaid diagram should be shown even when the factory output is a DataFrame" ) {
599+ val spark = new SparkSessionBuilder (" test" ).setEnv(" local" ).setSparkMaster(" local" ).getOrCreate()
600+
601+ new Pipeline ()
602+ .setInput[String ](" id_of_product1" , classOf [ProductFactory ])
603+ .setInput[String ](" dataframe" , classOf [DataFrameFactory ])
604+ .addStage[ProductFactory ]()
605+ .addStage[DatasetFactory ](Array (spark))
606+ .addStage[DataFrameFactory ](Array (spark))
607+ .addStage[DatasetFactory4 ](Array (spark))
608+ .run()
609+ .showDiagram()
610+ }
611+
598612}
599613
600614object PipelineSuite {
@@ -857,15 +871,28 @@ object PipelineSuite {
857871
858872 override def read (): DatasetFactory4 .this .type = this
859873
860- override def process (): DatasetFactory4 .this .type = {
861- this
862- }
874+ override def process (): DatasetFactory4 .this .type = this
863875
864876 override def write (): DatasetFactory4 .this .type = this
865877
866878 override def get (): Long = ds1.count()
867879 }
868880
881+ class DataFrameFactory (spark : SparkSession ) extends Factory [DataFrame ] {
882+ import spark .implicits ._
883+
884+ @ Delivery
885+ var input : String = null
886+
887+ override def read (): DataFrameFactory .this .type = this
888+
889+ override def process (): DataFrameFactory .this .type = this
890+
891+ override def write (): DataFrameFactory .this .type = this
892+
893+ override def get (): DataFrame = Seq (input).toDF(" column1" )
894+ }
895+
869896 class PrimaryDeliveryFactory extends Factory [String ] {
870897
871898 @ Delivery (id = " byte" )
0 commit comments