Skip to content
Prev Previous commit
Next Next commit
add: Radix sort
  • Loading branch information
IgorErin committed Mar 20, 2023
commit 5c6143e516f7534c88f0ab0689b66f6de9b8022b
189 changes: 146 additions & 43 deletions src/GraphBLAS-sharp.Backend/Common/Sort/Radix.fs
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,26 @@ open GraphBLAS.FSharp.Backend.Objects.ArraysExtensions
type Indices = ClArray<int>

module Radix =
let defaultBitCount = 4

let localPrefixSum =
<@ fun (lid: int) (workGroupSize: int) (array: int []) ->
let mutable offset = 1

while offset < workGroupSize do
barrierLocal ()
let mutable value = array.[lid]

if lid >= offset then value <- value + array.[lid - offset]

offset <- offset * 2

barrierLocal ()
array.[lid] <- value
barrierLocal () @>
array.[lid] <- value @>

let count (clContext: ClContext) workGroupSize mask bitCount =
let count (clContext: ClContext) workGroupSize mask =

let bitCount = mask + 1

let kernel =
<@ fun (ndRange: Range1D) length (indices: Indices) (workGroupCount: ClCell<int>) (shift: ClCell<int>) (globalOffsets: Indices) (localOffsets: Indices) ->
Expand All @@ -34,52 +38,48 @@ module Radix =

let position = (indices.[gid] >>> shift.Value) &&& mask

if gid < length then printf "position %i for lid = %i" position lid

let localMask = localArray<int> workGroupSize

if gid < length then localMask.[lid] <- position else localMask.[lid] <- 0

if gid < length then
printf "local mask value = %i for lid = %i" localMask.[lid] lid
if gid < length
then localMask.[lid] <- position
else localMask.[lid] <- 0

let localPositions = localArray<int> workGroupSize

for currentBit in 0 .. bitCount - 1 do
let isCurrentPosition = if localMask.[lid] = currentBit then 1 else 0
if gid < length then printf "is current position %i for lid = %i, localMask of i = %i, currentBit = %i" isCurrentPosition lid localMask.[lid] currentBit
let isCurrentPosition = localMask.[lid] = currentBit

localPositions.[lid] <- if isCurrentPosition = 1 && gid < length then 1 else 0
if isCurrentPosition && gid < length
then localPositions.[lid] <- 1
else localPositions.[lid] <- 0

barrierLocal ()

(%localPrefixSum) lid workGroupSize localPositions

if gid < length && isCurrentPosition = 1 then
barrierLocal ()

if gid < length && isCurrentPosition then
localOffsets.[gid] <- localPositions.[lid] - 1

if lid = 0 then
let processedItemsCount = localPositions.[workGroupSize - 1]
printf "%i processed items count" processedItemsCount
let workGroupNumber = gid / workGroupSize
let wgId = gid / workGroupSize

globalOffsets.[position * workGroupCount.Value + workGroupNumber] <- processedItemsCount @>
globalOffsets.[workGroupCount.Value * currentBit + wgId] <- processedItemsCount @>

let kernel = clContext.Compile kernel
printfn $"code: {kernel.Code}"

fun (processor: MailboxProcessor<_>) (indices: Indices) (clWorkGroupCount: ClCell<int>) (shift: ClCell<int>) ->
let ndRange = Range1D.CreateValid(indices.Length, workGroupSize)

let workGroupCount = (indices.Length - 1) / workGroupSize + 1

let globalOffsetsLength = (pown 2 bitCount) * workGroupCount
let globalOffsetsLength = bitCount * workGroupCount

let globalOffsets =
clContext.CreateClArrayWithSpecificAllocationMode(DeviceOnly, globalOffsetsLength)

printfn "local offset length = %d" indices.Length

let localOffsets =
clContext.CreateClArrayWithSpecificAllocationMode(DeviceOnly, indices.Length)

Expand All @@ -92,7 +92,7 @@ module Radix =

globalOffsets, localOffsets

let scatter (clContext: ClContext) workGroupSize mask bitCount =
let scatter (clContext: ClContext) workGroupSize mask =

let kernel =
<@ fun (ndRange: Range1D) length (keys: Indices) (shift: ClCell<int>) (workGroupCount: ClCell<int>) (globalOffsets: Indices) (localOffsets: Indices) (result: ClArray<int>) ->
Expand All @@ -110,8 +110,7 @@ module Radix =

let offset = globalOffset + localOffset

result.[offset] <- keys.[gid]
shift.Value <- shift.Value <<< bitCount @>
result.[offset] <- keys.[gid] @>

let kernel = clContext.Compile kernel

Expand All @@ -126,42 +125,146 @@ module Radix =

processor.Post(Msg.CreateRunMsg<_, _>(kernel))

let run (clContext: ClContext) workGroupSize =
let run (clContext: ClContext) workGroupSize bitCount =
let copy = ClArray.copy clContext workGroupSize

let bitCount = 2
let mask = (pown 2 bitCount) - 1 // TODO()
let mask = (pown 2 bitCount) - 1

let count = count clContext workGroupSize mask bitCount
let count = count clContext workGroupSize mask

let prefixSum = PrefixSum.standardExcludeInplace clContext workGroupSize

let scatter = scatter clContext workGroupSize mask bitCount
let scatter = scatter clContext workGroupSize mask

fun (processor: MailboxProcessor<_>) (keys: Indices) ->
let firstKeys = copy processor DeviceOnly keys

let secondKeys =
clContext.CreateClArrayWithSpecificAllocationMode(DeviceOnly, keys.Length)

let workGroupCount = clContext.CreateClCell((keys.Length - 1) / workGroupSize + 1)

let mutable pair = (firstKeys, secondKeys)
let swap (x, y) = y, x

if keys.Length <= 1 then
keys
else
for i in 0 .. 15 do // TODO()
let shift = clContext.CreateClCell(bitCount * i)

// printfn "keys: %A" <| (fst pair).ToHost processor
// printfn "shift: %i" <| shift.ToHost processor

let globalOffset, localOffset = count processor (fst pair) workGroupCount shift

// printfn "globalOffset: %A" <| globalOffset.ToHost processor
// printfn "localOffset: %A" <| localOffset.ToHost processor

(prefixSum processor globalOffset).Free processor
// printfn "globalOffset after prefix sum: %A" <| globalOffset.ToHost processor

scatter processor (fst pair) shift workGroupCount globalOffset localOffset (snd pair)

pair <- swap pair

// printfn "secondKeys: %A" <| secondKeys.ToHost processor

globalOffset.Free processor
localOffset.Free processor
shift.Free processor

//printfn "result keys: %A" <| (snd pair).ToHost processor
fst pair

let standardRun clContext workGroupSize = run clContext workGroupSize defaultBitCount

let scatter1D (clContext: ClContext) workGroupSize mask =

let kernel =
<@ fun (ndRange: Range1D) length (keys: Indices) (values: ClArray<'a>) (shift: ClCell<int>) (workGroupCount: ClCell<int>) (globalOffsets: Indices) (localOffsets: Indices) (resultKeys: ClArray<int>) (resultValues: ClArray<'a>) ->

let gid = ndRange.GlobalID0
let wgId = gid / workGroupSize

let workGroupCount = workGroupCount.Value

let secondKeys = clContext.CreateClArrayWithSpecificAllocationMode(DeviceOnly, keys.Length)
let workGroupCount = clContext.CreateClCell((keys.Length - 1) / workGroupSize - 1)
let shift = clContext.CreateClCell 0
if gid < length then
let slot = (keys.[gid] >>> shift.Value) &&& mask

let localOffset = localOffsets.[gid]
let globalOffset = globalOffsets.[workGroupCount * slot + wgId]

let offset = globalOffset + localOffset

resultKeys.[offset] <- keys.[gid]
resultValues.[offset] <- values.[gid] @>

let kernel = clContext.Compile kernel

fun (processor: MailboxProcessor<_>) (keys: Indices) (values: ClArray<'a>) (shift: ClCell<int>) (workGroupCount: ClCell<int>) (globalOffset: Indices) (localOffsets: Indices) (resultKeys: ClArray<int>) (resultValues: ClArray<'a>) ->

let ndRange =
Range1D.CreateValid(keys.Length, workGroupSize)

let kernel = kernel.GetKernel()

processor.Post(Msg.MsgSetArguments(fun () -> kernel.KernelFunc ndRange keys.Length keys values shift workGroupCount globalOffset localOffsets resultKeys resultValues))

processor.Post(Msg.CreateRunMsg<_, _>(kernel))

let run1DInplace (clContext: ClContext) workGroupSize bitCount =
let copy = ClArray.copy clContext workGroupSize

let dataCopy = ClArray.copy clContext workGroupSize

let mask = (pown 2 bitCount) - 1

let count = count clContext workGroupSize mask

let prefixSum = PrefixSum.standardExcludeInplace clContext workGroupSize

let scatter1D = scatter1D clContext workGroupSize mask

fun (processor: MailboxProcessor<_>) (keys: Indices) (values: ClArray<'a>) ->
let firstKeys = copy processor DeviceOnly keys

let secondKeys =
clContext.CreateClArrayWithSpecificAllocationMode(DeviceOnly, keys.Length)

let secondValues = dataCopy processor DeviceOnly values

let workGroupCount = clContext.CreateClCell((keys.Length - 1) / workGroupSize + 1)

let mutable keysPair = (firstKeys, secondKeys)
let mutable valuesPair = (values, secondValues)

let mutable pair = (keys, secondKeys)
let swap (x, y) = y, x

//for i in 0 .. 4 do
printfn "keys: %A" <| keys.ToHost processor
if keys.Length <= 1 then
keys, values
else
for i in 0 .. 15 do
let shift = clContext.CreateClCell(bitCount * i)

let currentKeys = fst keysPair
let resultKeysBuffer = snd keysPair

let globalOffset, localOffset = count processor (fst pair) workGroupCount shift
let currentValues = fst valuesPair
let resultValuesBuffer = snd valuesPair

printfn "globalOffset: %A" <| globalOffset.ToHost processor
printfn "localOffset: %A" <| localOffset.ToHost processor
let globalOffset, localOffset = count processor currentKeys workGroupCount shift

(prefixSum processor globalOffset).Free processor
(prefixSum processor globalOffset).Free processor

scatter processor (fst pair) shift workGroupCount globalOffset localOffset (snd pair)
scatter1D processor currentKeys currentValues shift workGroupCount globalOffset localOffset resultKeysBuffer resultValuesBuffer

//pair <- swap pair
keysPair <- swap keysPair
valuesPair <- swap valuesPair

globalOffset.Free processor
localOffset.Free processor
localOffset.Free processor
shift.Free processor

keys
(fst keysPair), (fst valuesPair)

let run1DInplaceStandard clContext workGroupSize = run1DInplace clContext workGroupSize defaultBitCount
3 changes: 1 addition & 2 deletions src/GraphBLAS-sharp.Backend/Objects/ArraysExtentions.fs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ module ArraysExtensions =
q.PostAndReply(fun ch -> Msg.CreateToHostMsg(this, dst, ch))

member this.ToHostAndFree(q: MailboxProcessor<_>) =
let dst = Array.zeroCreate this.Length
let result = q.PostAndReply(fun ch -> Msg.CreateToHostMsg(this, dst, ch))
let result = this.ToHost q
this.Free q

result
Expand Down
14 changes: 7 additions & 7 deletions src/GraphBLAS-sharp.Backend/Objects/ClCell.fs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ open Brahma.FSharp

module ClCell =
type ClCell<'a> with
member this.ToHostAndFree(processor: MailboxProcessor<_>) =
let res =
processor.PostAndReply(fun ch -> Msg.CreateToHostMsg<_>(this, (Array.zeroCreate<'a> 1), ch))

processor.Post(Msg.CreateFreeMsg<_>(this))

res.[0]
member this.ToHost(processor: MailboxProcessor<_>) =
processor.PostAndReply(fun ch -> Msg.CreateToHostMsg<_>(this, (Array.zeroCreate<'a> 1), ch)).[0]

member this.Free(processor: MailboxProcessor<_>) = processor.Post(Msg.CreateFreeMsg<_>(this))

member this.ToHostAndFree(processor: MailboxProcessor<_>) =
let result = this.ToHost processor
this.Free processor

result
13 changes: 4 additions & 9 deletions tests/GraphBLAS-sharp.Tests/Common/BitonicSort.fs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ open GraphBLAS.FSharp.Backend.Common
open Brahma.FSharp
open GraphBLAS.FSharp.Tests
open GraphBLAS.FSharp.Tests.Context
open GraphBLAS.FSharp.Backend.Objects.ArraysExtensions

let logger = Log.create "BitonicSort.Tests"

Expand Down Expand Up @@ -38,15 +39,9 @@ let makeTest sort (array: ('n * 'n * 'a) []) =
let actualRows, actualCols, actualValues =
sort q clRows clColumns clValues

let rows = Array.zeroCreate<'n> clRows.Length
let columns = Array.zeroCreate<'n> clColumns.Length
let values = Array.zeroCreate<'a> clValues.Length

q.Post(Msg.CreateToHostMsg(clRows, rows))
q.Post(Msg.CreateToHostMsg(clColumns, columns))

q.PostAndReply(fun ch -> Msg.CreateToHostMsg(clValues, values, ch))
|> ignore
let rows = clRows.ToHostAndFree q
let columns = clColumns.ToHostAndFree q
let values = clValues.ToHostAndFree q

rows, columns, values

Expand Down
Loading