@@ -300,4 +300,54 @@ default boolean detectCycle(){
300300 }
301301 return false ;
302302 }
303+
304+ /**
305+ * 查找最小生成树的辅助类
306+ */
307+ class CostEdge extends Edge {
308+ Integer cost ;
309+
310+ public CostEdge (int source , int target , int cost ) {
311+ super (source , target );
312+ this .cost = cost ;
313+ }
314+
315+ @ Override
316+ public int compareTo (Edge o ) {
317+ return cost - ((CostEdge )o ).cost ;
318+ }
319+ }
320+
321+ /**
322+ * 最小生成树。参考graph.md
323+ * @param costFinder
324+ * @return
325+ */
326+ default LinkedList <Edge > minimumSpanningTree (OneArgumentExpression <E ,Integer > costFinder ){
327+ if (!isUndirected ()){
328+ throw new IllegalStateException ("Spanning tree only applicable to undirected trees" );
329+ }
330+ LinkedList <Edge > subGraph = new LinkedList <>();
331+ PriorityQueue <CostEdge > edgeQueue = new PriorityQueue <>((x , y )->x .compareTo (y ));
332+ UnionFind <Integer > unionFind = new UnionFind <>();
333+
334+ //将边按照权重加入到队列中。PriorityQueue优先返回最小值。
335+ this .visitAllConnectedEdges (getAllVertices ().getRoot ().getValue (),
336+ (s ,t ,v )-> edgeQueue .add (new CostEdge (s ,t ,costFinder .compute (v ))), TraversalType .DFT );
337+
338+ //将所有节点加入到unionFind中。
339+ this .getAllVertices ().traversePreOrderNonRecursive ((x )-> unionFind .add (x ));
340+
341+ //优先取权重最小的边,查看节点是否相连,未连接则进行连接。
342+ while (unionFind .getPartitionCount ()>1 && edgeQueue .peek ()!=null ){
343+ Edge e = edgeQueue .poll ();
344+ int sGroup = unionFind .find (e .source );
345+ int tGroup = unionFind .find (e .target );
346+ if (sGroup !=tGroup ){
347+ subGraph .appendLast (e );
348+ unionFind .union (e .source , e .target );
349+ }
350+ }
351+ return subGraph ;
352+ }
303353}
0 commit comments