Skip to content

Commit 6f2adb3

Browse files
committed
wip: SpGeMM expand phase
1 parent 089e11e commit 6f2adb3

File tree

12 files changed

+437
-367
lines changed

12 files changed

+437
-367
lines changed

src/GraphBLAS-sharp.Backend/Matrix/CSR/Matrix.fs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ module Matrix =
167167
let pairwise = ClArray.pairwise clContext workGroupSize
168168

169169
let subtract =
170-
ClArray.map clContext workGroupSize Map.pairSubtraction
170+
ClArray.map clContext workGroupSize <@ fun (fst, snd) -> snd - fst @>
171171

172172
fun (processor: MailboxProcessor<_>) (matrix: ClMatrix.CSR<'b>) ->
173173
let pointerPairs =

src/GraphBLAS-sharp.Backend/Matrix/Matrix.fs

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -424,19 +424,18 @@ module Matrix =
424424
| ClMatrix.CSR m1, ClMatrix.CSC m2, ClMatrix.COO mask -> runCSRnCSC queue m1 m2 mask |> ClMatrix.COO
425425
| _ -> failwith "Matrix formats are not matching"
426426

427-
// let expand // TODO()
428-
// (clContext: ClContext)
429-
// workGroupSize
430-
// (opAdd: Expr<'c -> 'c -> 'c option>)
431-
// (opMul: Expr<'a -> 'b -> 'c option>)
432-
// =
433-
//
434-
// let run =
435-
// SpGeMM.Expand.run clContext workGroupSize opAdd opMul
436-
//
437-
// fun (processor: MailboxProcessor<_>) allocationMode (leftMatrix: ClMatrix<'a>) (rightMatrix: ClMatrix<'b>) ->
438-
// match leftMatrix, rightMatrix with
439-
// | ClMatrix.CSR leftMatrix, ClMatrix.CSR rightMatrix ->
440-
// run processor allocationMode leftMatrix rightMatrix
441-
// |> ClMatrix.COO
442-
// | _ -> failwith "Matrix formats are not matching"
427+
let expand
428+
(clContext: ClContext)
429+
workGroupSize
430+
(opAdd: Expr<'c -> 'c -> 'c option>)
431+
(opMul: Expr<'a -> 'b -> 'c option>)
432+
=
433+
434+
let run =
435+
SpGeMM.Expand.run clContext workGroupSize opAdd opMul
436+
437+
fun (processor: MailboxProcessor<_>) allocationMode (leftMatrix: ClMatrix<'a>) (rightMatrix: ClMatrix<'b>) ->
438+
match leftMatrix, rightMatrix with
439+
| ClMatrix.CSR leftMatrix, ClMatrix.CSR rightMatrix ->
440+
ClMatrix.Rows <| run processor allocationMode leftMatrix rightMatrix
441+
| _ -> failwith "Matrix formats are not matching"

src/GraphBLAS-sharp.Backend/Matrix/SpGeMM/Expand.fs

Lines changed: 77 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ open GraphBLAS.FSharp.Backend.Objects.ClCell
1212
open FSharp.Quotations
1313
open GraphBLAS.FSharp.Backend.Vector.Sparse
1414
open GraphBLAS.FSharp.Backend.Objects.ClVector
15+
open GraphBLAS.FSharp.Backend.Objects.ClMatrix
1516

1617
type Indices = ClArray<int>
1718

@@ -70,60 +71,62 @@ module Expand =
7071

7172
let rightMatrixGather = Gather.run clContext workGroupSize
7273

73-
fun (processor: MailboxProcessor<_>) lengths (segmentsPointers: Indices) (leftMatrixRow: ClVector.Sparse<'a>) (rightMatrix: ClMatrix.CSR<'b>) ->
74-
75-
// Compute left matrix positions
76-
let leftMatrixPositions = zeroCreate processor DeviceOnly lengths
74+
fun (processor: MailboxProcessor<_>) length (segmentsPointers: Indices) (leftMatrixRow: ClVector.Sparse<'a>) (rightMatrix: ClMatrix.CSR<'b>) ->
75+
if length = 0 then None
76+
else
77+
printfn "expand length: %A" length
78+
// Compute left matrix positions
79+
let leftMatrixPositions = zeroCreate processor DeviceOnly length
7780

78-
idScatter processor segmentsPointers leftMatrixPositions
81+
idScatter processor segmentsPointers leftMatrixPositions
7982

80-
(maxPrefixSum processor leftMatrixPositions 0)
81-
.Free processor
83+
(maxPrefixSum processor leftMatrixPositions 0)
84+
.Free processor
8285

83-
// Compute right matrix positions
84-
let rightMatrixPositions = create processor DeviceOnly lengths 1
86+
// Compute right matrix positions
87+
let rightMatrixPositions = create processor DeviceOnly length 1
8588

86-
let requiredRightMatrixPointers =
87-
zeroCreate processor DeviceOnly leftMatrixRow.Indices.Length
89+
let requiredRightMatrixPointers =
90+
zeroCreate processor DeviceOnly leftMatrixRow.Indices.Length
8891

89-
gather processor leftMatrixRow.Indices rightMatrix.RowPointers requiredRightMatrixPointers
92+
gather processor leftMatrixRow.Indices rightMatrix.RowPointers requiredRightMatrixPointers
9093

91-
scatter processor segmentsPointers requiredRightMatrixPointers rightMatrixPositions
94+
scatter processor segmentsPointers requiredRightMatrixPointers rightMatrixPositions
9295

93-
requiredRightMatrixPointers.Free processor
96+
requiredRightMatrixPointers.Free processor
9497

95-
// another way to get offsets ???
96-
let offsets =
97-
removeDuplicates processor segmentsPointers
98+
// TODO: is there another (cheaper) way to get the offsets?
99+
let offsets =
100+
removeDuplicates processor segmentsPointers
98101

99-
segmentPrefixSum processor offsets.Length rightMatrixPositions leftMatrixPositions offsets
102+
segmentPrefixSum processor offsets.Length rightMatrixPositions leftMatrixPositions offsets
100103

101-
offsets.Free processor
104+
offsets.Free processor
102105

103-
// compute columns
104-
let columns =
105-
clContext.CreateClArrayWithSpecificAllocationMode(DeviceOnly, lengths)
106+
// compute columns
107+
let columns =
108+
clContext.CreateClArrayWithSpecificAllocationMode(DeviceOnly, length)
106109

107-
gather processor rightMatrixPositions rightMatrix.Columns columns
110+
gather processor rightMatrixPositions rightMatrix.Columns columns
108111

109-
// compute left matrix values
110-
let leftMatrixValues =
111-
clContext.CreateClArrayWithSpecificAllocationMode(DeviceOnly, lengths)
112+
// compute left matrix values
113+
let leftMatrixValues =
114+
clContext.CreateClArrayWithSpecificAllocationMode(DeviceOnly, length)
112115

113-
leftMatrixGather processor leftMatrixPositions leftMatrixRow.Values leftMatrixValues
116+
leftMatrixGather processor leftMatrixPositions leftMatrixRow.Values leftMatrixValues
114117

115-
leftMatrixPositions.Free processor
118+
leftMatrixPositions.Free processor
116119

117-
// compute right matrix values
118-
let rightMatrixValues =
119-
clContext.CreateClArrayWithSpecificAllocationMode(DeviceOnly, lengths)
120+
// compute right matrix values
121+
let rightMatrixValues =
122+
clContext.CreateClArrayWithSpecificAllocationMode(DeviceOnly, length)
120123

121-
rightMatrixGather processor rightMatrixPositions rightMatrix.Values rightMatrixValues
124+
rightMatrixGather processor rightMatrixPositions rightMatrix.Values rightMatrixValues
122125

123-
rightMatrixPositions.Free processor
126+
rightMatrixPositions.Free processor
124127

125-
// left, right matrix values, columns and rows indices
126-
leftMatrixValues, rightMatrixValues, columns
128+
// left matrix values, right matrix values, column indices
129+
Some (leftMatrixValues, rightMatrixValues, columns)
127130

128131
let multiply (clContext: ClContext) workGroupSize (predicate: Expr<'a -> 'b -> 'c option>) =
129132
let getBitmap =
@@ -235,42 +238,48 @@ module Expand =
235238
let length, segmentPointers =
236239
getSegmentPointers processor leftMatrixRow leftMatrixRowsLengths
237240

241+
if length < 0 then failwith "length < 0"
242+
238243
// expand
239-
let leftMatrixValues, rightMatrixValues, columns =
244+
let expandResult =
240245
expand processor length segmentPointers leftMatrixRow rightMatrix
241246

242-
// multiplication
243-
let mulResult =
244-
multiply processor leftMatrixValues rightMatrixValues columns
247+
segmentPointers.Free processor
248+
249+
expandResult
250+
|> Option.bind (fun (leftMatrixValues, rightMatrixValues, columns) ->
251+
// multiplication
252+
let mulResult =
253+
multiply processor leftMatrixValues rightMatrixValues columns
245254

246-
leftMatrixValues.Free processor
247-
rightMatrixValues.Free processor
248-
columns.Free processor
255+
leftMatrixValues.Free processor
256+
rightMatrixValues.Free processor
257+
columns.Free processor
249258

250-
// check multiplication result
251-
mulResult
252-
|> Option.bind (fun (resultValues, resultColumns) ->
253-
// sort
254-
let sortedValues, sortedColumns =
255-
sort processor resultValues resultColumns
259+
// check multiplication result
260+
mulResult
261+
|> Option.bind (fun (resultValues, resultColumns) ->
262+
// sort
263+
let sortedValues, sortedColumns =
264+
sort processor resultValues resultColumns
256265

257-
resultValues.Free processor
258-
resultColumns.Free processor
266+
resultValues.Free processor
267+
resultColumns.Free processor
259268

260-
let reduceResult =
261-
reduce processor allocationMode sortedValues sortedColumns
269+
let reduceResult =
270+
reduce processor allocationMode sortedValues sortedColumns
262271

263-
sortedValues.Free processor
264-
sortedColumns.Free processor
272+
sortedValues.Free processor
273+
sortedColumns.Free processor
265274

266-
// create sparse vector (TODO(empty vector))
267-
reduceResult
268-
|> Option.bind (fun (values, columns) ->
269-
{ Context = clContext
270-
Indices = columns
271-
Values = values
272-
Size = rightMatrix.ColumnCount }
273-
|> Some))
275+
// create sparse vector (TODO(empty vector))
276+
reduceResult
277+
|> Option.bind (fun (values, columns) ->
278+
{ Context = clContext
279+
Indices = columns
280+
Values = values
281+
Size = rightMatrix.ColumnCount }
282+
|> Some)))
274283

275284
let run<'a, 'b, 'c when 'a : struct and 'b : struct and 'c : struct>
276285
(clContext: ClContext)
@@ -296,4 +305,10 @@ module Expand =
296305
split processor allocationMode leftMatrix
297306
|> Seq.map (fun lazyRow -> Option.bind runRow lazyRow.Value)
298307
|> Seq.toArray
308+
|> fun rows ->
309+
{ Rows.Context = clContext
310+
RowCount = leftMatrix.RowCount
311+
ColumnCount = rightMatrix.ColumnCount
312+
Rows = rows
313+
NNZ = -1 } // TODO(nnz count)
299314

src/GraphBLAS-sharp.Backend/Objects/Vector.fs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ module ClVector =
2222

2323
member this.Dispose(q) = (this :> IDeviceMemObject).Dispose(q)
2424

25+
member this.NNZ = this.Values.Length
26+
2527
[<RequireQualifiedAccess>]
2628
type ClVector<'a when 'a: struct> =
2729
| Sparse of ClVector.Sparse<'a>

src/GraphBLAS-sharp.Backend/Quotes/Map.fs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,6 @@ module Map =
3030

3131
let inc = <@ fun item -> item + 1 @>
3232

33-
let pairSubtraction = <@ fun (first, second) -> first - second @>
34-
3533
let subtraction = <@ fun first second -> first - second @>
3634

3735
let fst () = <@ fun fst _ -> fst @>

src/GraphBLAS-sharp/Objects/MatrixExtensions.fs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,10 @@ module MatrixExtensions =
3939
|> Array.map (Option.bind (fun row -> Some <| row.ToHost q))
4040
NNZ = m.NNZ }
4141
|> Matrix.Rows
42+
43+
member this.ToHostAndDispose(processor: MailboxProcessor<_>) =
44+
let result = this.ToHost processor
45+
46+
this.Dispose processor
47+
48+
result

tests/GraphBLAS-sharp.Tests/Common/ClArray/Pairwise.fs

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,21 +21,17 @@ let makeTest<'a> isEqual testFun (array: 'a [] ) =
2121

2222
let clArray = context.CreateClArray array
2323

24-
testFun processor HostInterop clArray
25-
|> Option.bind (fun (actual: ClArray<_>) ->
26-
let firstActual, secondActual =
27-
actual.ToHostAndFree processor
28-
|> Array.unzip
24+
match testFun processor HostInterop clArray with
25+
| Some (actual: ClArray<_>) ->
26+
let actual = actual.ToHostAndFree processor
2927

30-
let firstExpected, secondExpected = Array.pairwise array |> Array.unzip
28+
let expected = Array.pairwise array
3129

3230
"First results must be the same"
33-
|> Utils.compareArrays isEqual firstActual firstExpected
34-
35-
"Second results must be the same"
36-
|> Utils.compareArrays isEqual secondActual secondExpected
37-
None)
38-
|> ignore
31+
|> Utils.compareArrays isEqual actual expected
32+
| None ->
33+
"Result must be empty"
34+
|> Expect.isTrue (array.Size <= 1)
3935

4036
let createTest<'a> isEqual =
4137
ClArray.pairwise context Utils.defaultWorkGroupSize

tests/GraphBLAS-sharp.Tests/Generators.fs

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ module Generators =
318318
Arb.generate<CustomDatatypes.WrappedInt>
319319
|> Arb.fromGen
320320

321-
type PairOfSparseVectorAndMatrixOfCompatibleSize() =
321+
type PairOfSparseVectorAndMatrixAndMaskOfCompatibleSize() =
322322
static let pairOfVectorAndMatrixOfCompatibleSizeGenerator (valuesGenerator: Gen<'a>) =
323323
gen {
324324
let! nRows, nColumns = dimension2DGenerator
@@ -376,6 +376,63 @@ module Generators =
376376
|> genericSparseGenerator false Arb.generate<bool>
377377
|> Arb.fromGen
378378

379+
type VectorXMatrix() =
380+
static let pairOfVectorAndMatrixOfCompatibleSizeGenerator (valuesGenerator: Gen<'a>) =
381+
gen {
382+
let! nRows, nColumns = dimension2DGenerator
383+
let! vector = valuesGenerator |> Gen.arrayOfLength nRows
384+
385+
let! matrix =
386+
valuesGenerator
387+
|> Gen.array2DOfDim (nRows, nColumns)
388+
389+
return (vector, matrix)
390+
}
391+
392+
static member IntType() =
393+
pairOfVectorAndMatrixOfCompatibleSizeGenerator
394+
|> genericSparseGenerator 0 Arb.generate<int>
395+
|> Arb.fromGen
396+
397+
static member FloatType() =
398+
pairOfVectorAndMatrixOfCompatibleSizeGenerator
399+
|> genericSparseGenerator
400+
0.
401+
(Arb.Default.NormalFloat()
402+
|> Arb.toGen
403+
|> Gen.map float)
404+
|> Arb.fromGen
405+
406+
static member Float32Type() =
407+
pairOfVectorAndMatrixOfCompatibleSizeGenerator
408+
|> genericSparseGenerator 0.0f (normalFloat32Generator <| System.Random())
409+
|> Arb.fromGen
410+
411+
static member SByteType() =
412+
pairOfVectorAndMatrixOfCompatibleSizeGenerator
413+
|> genericSparseGenerator 0y Arb.generate<sbyte>
414+
|> Arb.fromGen
415+
416+
static member ByteType() =
417+
pairOfVectorAndMatrixOfCompatibleSizeGenerator
418+
|> genericSparseGenerator 0uy Arb.generate<byte>
419+
|> Arb.fromGen
420+
421+
static member Int16Type() =
422+
pairOfVectorAndMatrixOfCompatibleSizeGenerator
423+
|> genericSparseGenerator 0s Arb.generate<int16>
424+
|> Arb.fromGen
425+
426+
static member UInt16Type() =
427+
pairOfVectorAndMatrixOfCompatibleSizeGenerator
428+
|> genericSparseGenerator 0us Arb.generate<uint16>
429+
|> Arb.fromGen
430+
431+
static member BoolType() =
432+
pairOfVectorAndMatrixOfCompatibleSizeGenerator
433+
|> genericSparseGenerator false Arb.generate<bool>
434+
|> Arb.fromGen
435+
379436
type PairOfMatricesOfCompatibleSize() =
380437
static let pairOfMatricesOfCompatibleSizeGenerator (valuesGenerator: Gen<'a>) =
381438
gen {

tests/GraphBLAS-sharp.Tests/Helpers.fs

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ module Utils =
2525
typeof<Generators.PairOfSparseMatricesOfEqualSize>
2626
typeof<Generators.PairOfMatricesOfCompatibleSize>
2727
typeof<Generators.PairOfSparseMatrixAndVectorsCompatibleSize>
28-
typeof<Generators.PairOfSparseVectorAndMatrixOfCompatibleSize>
28+
typeof<Generators.PairOfSparseVectorAndMatrixAndMaskOfCompatibleSize>
2929
typeof<Generators.ArrayOfDistinctKeys2D>
3030
typeof<Generators.ArrayOfAscendingKeys>
3131
typeof<Generators.BufferCompatibleArray>
@@ -150,11 +150,6 @@ module Utils =
150150

151151
result
152152

153-
let castMatrixToCSR =
154-
function
155-
| Matrix.CSR matrix -> matrix
156-
| _ -> failwith "matrix format must be CSR"
157-
158153
module HostPrimitives =
159154
let prefixSumInclude zero add array =
160155
Array.scan add zero array

0 commit comments

Comments
 (0)