@@ -12,6 +12,7 @@ open GraphBLAS.FSharp.Backend.Objects.ClCell
1212open FSharp.Quotations
1313open GraphBLAS.FSharp .Backend .Vector .Sparse
1414open GraphBLAS.FSharp .Backend .Objects .ClVector
15+ open GraphBLAS.FSharp .Backend .Objects .ClMatrix
1516
1617type Indices = ClArray< int>
1718
@@ -70,60 +71,62 @@ module Expand =
7071
7172 let rightMatrixGather = Gather.run clContext workGroupSize
7273
73- fun ( processor : MailboxProcessor < _ >) lengths ( segmentsPointers : Indices ) ( leftMatrixRow : ClVector.Sparse < 'a >) ( rightMatrix : ClMatrix.CSR < 'b >) ->
74-
75- // Compute left matrix positions
76- let leftMatrixPositions = zeroCreate processor DeviceOnly lengths
74+ fun ( processor : MailboxProcessor < _ >) length ( segmentsPointers : Indices ) ( leftMatrixRow : ClVector.Sparse < 'a >) ( rightMatrix : ClMatrix.CSR < 'b >) ->
75+ if length = 0 then None
76+ else
77+ printfn " expand length: %A " length
78+ // Compute left matrix positions
79+ let leftMatrixPositions = zeroCreate processor DeviceOnly length
7780
78- idScatter processor segmentsPointers leftMatrixPositions
81+ idScatter processor segmentsPointers leftMatrixPositions
7982
80- ( maxPrefixSum processor leftMatrixPositions 0 )
81- .Free processor
83+ ( maxPrefixSum processor leftMatrixPositions 0 )
84+ .Free processor
8285
83- // Compute right matrix positions
84- let rightMatrixPositions = create processor DeviceOnly lengths 1
86+ // Compute right matrix positions
87+ let rightMatrixPositions = create processor DeviceOnly length 1
8588
86- let requiredRightMatrixPointers =
87- zeroCreate processor DeviceOnly leftMatrixRow.Indices.Length
89+ let requiredRightMatrixPointers =
90+ zeroCreate processor DeviceOnly leftMatrixRow.Indices.Length
8891
89- gather processor leftMatrixRow.Indices rightMatrix.RowPointers requiredRightMatrixPointers
92+ gather processor leftMatrixRow.Indices rightMatrix.RowPointers requiredRightMatrixPointers
9093
91- scatter processor segmentsPointers requiredRightMatrixPointers rightMatrixPositions
94+ scatter processor segmentsPointers requiredRightMatrixPointers rightMatrixPositions
9295
93- requiredRightMatrixPointers.Free processor
96+ requiredRightMatrixPointers.Free processor
9497
95- // another way to get offsets ???
96- let offsets =
97- removeDuplicates processor segmentsPointers
98+ // another way to get offsets ???
99+ let offsets =
100+ removeDuplicates processor segmentsPointers
98101
99- segmentPrefixSum processor offsets.Length rightMatrixPositions leftMatrixPositions offsets
102+ segmentPrefixSum processor offsets.Length rightMatrixPositions leftMatrixPositions offsets
100103
101- offsets.Free processor
104+ offsets.Free processor
102105
103- // compute columns
104- let columns =
105- clContext.CreateClArrayWithSpecificAllocationMode( DeviceOnly, lengths )
106+ // compute columns
107+ let columns =
108+ clContext.CreateClArrayWithSpecificAllocationMode( DeviceOnly, length )
106109
107- gather processor rightMatrixPositions rightMatrix.Columns columns
110+ gather processor rightMatrixPositions rightMatrix.Columns columns
108111
109- // compute left matrix values
110- let leftMatrixValues =
111- clContext.CreateClArrayWithSpecificAllocationMode( DeviceOnly, lengths )
112+ // compute left matrix values
113+ let leftMatrixValues =
114+ clContext.CreateClArrayWithSpecificAllocationMode( DeviceOnly, length )
112115
113- leftMatrixGather processor leftMatrixPositions leftMatrixRow.Values leftMatrixValues
116+ leftMatrixGather processor leftMatrixPositions leftMatrixRow.Values leftMatrixValues
114117
115- leftMatrixPositions.Free processor
118+ leftMatrixPositions.Free processor
116119
117- // compute right matrix values
118- let rightMatrixValues =
119- clContext.CreateClArrayWithSpecificAllocationMode( DeviceOnly, lengths )
120+ // compute right matrix values
121+ let rightMatrixValues =
122+ clContext.CreateClArrayWithSpecificAllocationMode( DeviceOnly, length )
120123
121- rightMatrixGather processor rightMatrixPositions rightMatrix.Values rightMatrixValues
124+ rightMatrixGather processor rightMatrixPositions rightMatrix.Values rightMatrixValues
122125
123- rightMatrixPositions.Free processor
126+ rightMatrixPositions.Free processor
124127
125- // left, right matrix values, columns and rows indices
126- leftMatrixValues, rightMatrixValues, columns
128+ // left, right matrix values, columns indices
129+ Some ( leftMatrixValues, rightMatrixValues, columns)
127130
128131 let multiply ( clContext : ClContext ) workGroupSize ( predicate : Expr < 'a -> 'b -> 'c option >) =
129132 let getBitmap =
@@ -235,42 +238,48 @@ module Expand =
235238 let length , segmentPointers =
236239 getSegmentPointers processor leftMatrixRow leftMatrixRowsLengths
237240
241+ if length < 0 then failwith " length < 0"
242+
238243 // expand
239- let leftMatrixValues , rightMatrixValues , columns =
244+ let expandResult =
240245 expand processor length segmentPointers leftMatrixRow rightMatrix
241246
242- // multiplication
243- let mulResult =
244- multiply processor leftMatrixValues rightMatrixValues columns
247+ segmentPointers.Free processor
248+
249+ expandResult
250+ |> Option.bind ( fun ( leftMatrixValues , rightMatrixValues , columns ) ->
251+ // multiplication
252+ let mulResult =
253+ multiply processor leftMatrixValues rightMatrixValues columns
245254
246- leftMatrixValues.Free processor
247- rightMatrixValues.Free processor
248- columns.Free processor
255+ leftMatrixValues.Free processor
256+ rightMatrixValues.Free processor
257+ columns.Free processor
249258
250- // check multiplication result
251- mulResult
252- |> Option.bind ( fun ( resultValues , resultColumns ) ->
253- // sort
254- let sortedValues , sortedColumns =
255- sort processor resultValues resultColumns
259+ // check multiplication result
260+ mulResult
261+ |> Option.bind ( fun ( resultValues , resultColumns ) ->
262+ // sort
263+ let sortedValues , sortedColumns =
264+ sort processor resultValues resultColumns
256265
257- resultValues.Free processor
258- resultColumns.Free processor
266+ resultValues.Free processor
267+ resultColumns.Free processor
259268
260- let reduceResult =
261- reduce processor allocationMode sortedValues sortedColumns
269+ let reduceResult =
270+ reduce processor allocationMode sortedValues sortedColumns
262271
263- sortedValues.Free processor
264- sortedColumns.Free processor
272+ sortedValues.Free processor
273+ sortedColumns.Free processor
265274
266- // create sparse vector (TODO(empty vector))
267- reduceResult
268- |> Option.bind ( fun ( values , columns ) ->
269- { Context = clContext
270- Indices = columns
271- Values = values
272- Size = rightMatrix.ColumnCount }
273- |> Some))
275+ // create sparse vector (TODO(empty vector))
276+ reduceResult
277+ |> Option.bind ( fun ( values , columns ) ->
278+ { Context = clContext
279+ Indices = columns
280+ Values = values
281+ Size = rightMatrix.ColumnCount }
282+ |> Some) ))
274283
275284 let run < 'a , 'b , 'c when 'a : struct and 'b : struct and 'c : struct >
276285 ( clContext : ClContext )
@@ -296,4 +305,10 @@ module Expand =
296305 split processor allocationMode leftMatrix
297306 |> Seq.map ( fun lazyRow -> Option.bind runRow lazyRow.Value)
298307 |> Seq.toArray
308+ |> fun rows ->
309+ { Rows.Context = clContext
310+ RowCount = leftMatrix.RowCount
311+ ColumnCount = rightMatrix.ColumnCount
312+ Rows = rows
313+ NNZ = - 1 } // TODO(nnz count)
299314
0 commit comments