ComputeRadixSort

@three-blocks/core (WebGPU)
new ComputeRadixSort(keysBuffer : StorageBufferNode, options : Object)

GPU-accelerated stable radix sort using the Blelloch scan algorithm.

ComputeRadixSort provides a stable sort that runs entirely on the GPU, in O(n) time for fixed-width keys. It uses a configurable radix (1, 2, or 4 bits per pass) and computes prefix sums locally within each workgroup, in shared memory.
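For readers unfamiliar with the Blelloch scan named above, the following is a minimal sequential sketch of its two sweeps (up-sweep, then down-sweep) producing an exclusive prefix sum. It is not the library's shader code: the function name `blellochExclusiveScan` is hypothetical, and the sketch assumes a power-of-two input length. On a GPU, the inner loop of each sweep would run in parallel across threads.

```javascript
// Work-efficient (Blelloch) exclusive prefix sum, written sequentially.
// ASSUMPTION: input length is a power of two; name is illustrative only.
function blellochExclusiveScan(input) {
  const n = input.length;
  if (n === 0 || (n & (n - 1)) !== 0) {
    throw new RangeError('length must be a non-zero power of two');
  }
  const data = Uint32Array.from(input);

  // Up-sweep (reduce): build partial sums in place, doubling the stride.
  for (let stride = 1; stride < n; stride *= 2) {
    for (let i = 2 * stride - 1; i < n; i += 2 * stride) {
      data[i] += data[i - stride];
    }
  }

  // Clear the root, then down-sweep: swap and accumulate, halving the stride.
  data[n - 1] = 0;
  for (let stride = n / 2; stride >= 1; stride /= 2) {
    for (let i = 2 * stride - 1; i < n; i += 2 * stride) {
      const left = data[i - stride];
      data[i - stride] = data[i];
      data[i] += left;
    }
  }
  return data;
}

const scanned = blellochExclusiveScan([3, 1, 7, 0, 4, 1, 6, 3]);
// scanned is [0, 3, 4, 11, 11, 15, 16, 22]
```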

Algorithm

For each N-bit pass:

  1. Phase 1 - Local Histogram + Prefix Sum: Each workgroup counts digit occurrences and computes local prefix sums using shared memory
  2. Phase 2 - Global Prefix Sum: Compute prefix sum on workgroup totals
  3. Phase 3 - Scatter: Write elements to sorted positions using deterministic offsets
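The three phases above can be sketched on the CPU as follows. This is a reference model under stated assumptions, not the library's implementation: `radixPass` and `radixSort` are hypothetical names, workgroups are simulated by slicing the input into chunks of `elementsPerWorkgroup`, and the passes are assumed to run from the least significant digit upward.

```javascript
// CPU model of one radix pass. All names here are illustrative.
function radixPass(keys, bitOffset, radixBits, elementsPerWorkgroup) {
  const n = keys.length;
  const bucketCount = 1 << radixBits;
  const bitMask = bucketCount - 1;
  const workgroupCount = Math.ceil(n / elementsPerWorkgroup);

  // Phase 1: per-workgroup histogram plus each element's rank among
  // earlier elements of the same digit in the same workgroup.
  const localRank = new Uint32Array(n);
  const counts = new Uint32Array(bucketCount * workgroupCount); // [digit][workgroup]
  for (let wg = 0; wg < workgroupCount; wg++) {
    const start = wg * elementsPerWorkgroup;
    const end = Math.min(start + elementsPerWorkgroup, n);
    for (let i = start; i < end; i++) {
      const d = (keys[i] >>> bitOffset) & bitMask;
      localRank[i] = counts[d * workgroupCount + wg]++;
    }
  }

  // Phase 2: exclusive prefix sum over the counts in digit-major order,
  // so lower digits and earlier workgroups come first.
  const offsets = new Uint32Array(counts.length);
  let running = 0;
  for (let j = 0; j < counts.length; j++) {
    offsets[j] = running;
    running += counts[j];
  }

  // Phase 3: scatter. Each destination is fully determined by the two
  // prefix sums, so no atomics are needed and equal digits keep their order.
  const out = new Uint32Array(n);
  for (let i = 0; i < n; i++) {
    const wg = Math.floor(i / elementsPerWorkgroup);
    const d = (keys[i] >>> bitOffset) & bitMask;
    out[offsets[d * workgroupCount + wg] + localRank[i]] = keys[i];
  }
  return out;
}

// Full sort: one pass per digit, least significant digit first (assumed).
function radixSort(keys, radixBits = 2, keyBits = 32, elementsPerWorkgroup = 4) {
  let current = Uint32Array.from(keys);
  for (let bitOffset = 0; bitOffset < keyBits; bitOffset += radixBits) {
    current = radixPass(current, bitOffset, radixBits, elementsPerWorkgroup);
  }
  return current;
}

const onePass = radixPass(Uint32Array.from([5, 2, 8, 1, 9, 3]), 0, 2, 4);
// onePass is [8, 5, 1, 9, 2, 3]: grouped by the low two bits, original order kept
```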

Radix Bits Configuration

Bits  Buckets  Passes (32-bit keys)  Memory  Best For
1     2        32                    Lowest  Small datasets, low memory
2     4        16                    Low     General use (default)
4     16       8                     Medium  Large datasets, fewer passes
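The bucket and pass counts in the table follow directly from the documented formulas (bucketCount = 2^radixBits, bitMask = bucketCount - 1, passCount = keyBits / radixBits). A small sketch that reproduces them; `radixConfig` is a hypothetical helper, not part of the library:

```javascript
// Derives the per-pass parameters from radixBits and keyBits.
// Hypothetical helper; the validation mirrors the documented allowed values.
function radixConfig(radixBits, keyBits = 32) {
  if (![1, 2, 4].includes(radixBits)) {
    throw new RangeError('radixBits must be 1, 2, or 4');
  }
  if (![16, 32].includes(keyBits)) {
    throw new RangeError('keyBits must be 16 or 32');
  }
  const bucketCount = 1 << radixBits;
  return {
    bucketCount,
    bitMask: bucketCount - 1,
    passCount: keyBits / radixBits,
  };
}

const defaults = radixConfig(2);        // { bucketCount: 4, bitMask: 3, passCount: 16 }
const halfWidth = radixConfig(2, 16);   // passCount drops to 8
```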

Stability

The algorithm is inherently stable because:

  • Local prefix sums preserve element order within each workgroup
  • Global prefix sums preserve workgroup order
  • No atomic operations in scatter phase (deterministic positions)

Usage

Constructor Parameters
keysBuffer : StorageBufferNode
The storage buffer containing uint32 keys to sort.
options : Object (optional)
Configuration options.
Default is {}.
  • values : StorageBufferNode (optional)
    Optional buffer of values to sort alongside keys.
  • workgroupSize : number (optional)
    The workgroup size for compute shaders.
    Default is 256.
  • radixBits : number (optional)
    Bits per sorting pass (1, 2, or 4). Higher values mean fewer passes but more memory.
    Default is 2.
  • keyBits : number (optional)
    Effective key width in bits (16 or 32). Use 16 for faster sorting with reduced precision.
    Default is 32.
Example
import { instancedArray } from 'three/tsl';
import { ComputeRadixSort } from '@three-blocks/core';

// Create a buffer of uint keys to sort
const keysBuffer = instancedArray(new Uint32Array([5, 2, 8, 1, 9, 3]), 'uint');

// Optional: values buffer to sort alongside keys
const valuesBuffer = instancedArray(new Uint32Array([50, 20, 80, 10, 90, 30]), 'uint');

// Initialize the sorter (default 2-bit radix)
const sorter = new ComputeRadixSort(keysBuffer, { values: valuesBuffer });

// Or use 4-bit radix for fewer passes on large datasets
const fastSorter = new ComputeRadixSort(keysBuffer, { radixBits: 4 });

// In your render loop (renderer is your existing WebGPURenderer instance):
sorter.compute(renderer);
// Result: keys = [1, 2, 3, 5, 8, 9], values = [10, 20, 30, 50, 80, 90]

Properties

.keysBuffer : StorageBufferNode

The storage buffer containing keys to sort.

.valuesBuffer : StorageBufferNode|null

Optional values buffer to sort alongside keys.

.count : number

Number of elements to sort.

.radixBits : number

Number of radix bits per pass (1, 2, or 4).

.keyBits : number

Effective key width in bits (16 or 32). Use 16 for faster sorting with reduced precision (halves the number of passes).

.bucketCount : number

Number of buckets (2^radixBits).

.bitMask : number

Bit mask for extracting digits (bucketCount - 1).

.passCount : number

Number of passes (keyBits / radixBits). With keyBits=16 and radixBits=2, this is 8 passes instead of 16.

.workgroupSize : number

Threads per workgroup.

.elementsPerWorkgroup : number

Elements processed per workgroup (2 per thread).

.workgroupCount : number

Number of workgroups.
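As an illustration of how these sizing properties relate, here is a small sketch. The factor of 2 elements per thread comes from the description above; the ceiling division used for the workgroup count is an assumption about how a partially filled last workgroup would be handled, and `dispatchSize` is a hypothetical helper.

```javascript
// Relates count, workgroupSize, elementsPerWorkgroup, and workgroupCount.
// ASSUMPTION: workgroupCount rounds up so a partial last workgroup is covered.
function dispatchSize(count, workgroupSize = 256) {
  const elementsPerWorkgroup = workgroupSize * 2; // 2 elements per thread
  const workgroupCount = Math.ceil(count / elementsPerWorkgroup);
  return { elementsPerWorkgroup, workgroupCount };
}

const size = dispatchSize(1000); // { elementsPerWorkgroup: 512, workgroupCount: 2 }
```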

.bitOffsetUniform : UniformNode

Uniform for current bit offset.

.initialized : boolean

Whether the sorter has been initialized.

.readBufferName : string

Current read buffer name ('Keys' or 'Temp').

.tempKeysBuffer : StorageBufferNode

Temporary keys buffer for ping-pong.

.tempValuesBuffer : StorageBufferNode|null

Temporary values buffer for ping-pong.

.localPrefixBuffer : StorageBufferNode

Per-element local prefix sums (one uint per bucket per element). Layout: element[i] bucket[d] -> index i * bucketCount + d.

.digitCountsBuffer : StorageBufferNode

Consolidated per-workgroup digit counts for all buckets. Layout: digit[d] workgroup[wg] -> index d * workgroupCount + wg. After the prefix sum phase, it contains cumulative counts.

.digitTotalsBuffer : StorageBufferNode

Global digit totals and base positions (bucketCount * 2 elements).

  • First half [0..bucketCount-1]: total count for each digit
  • Second half [bucketCount..2*bucketCount-1]: base position for each digit
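The index formulas for the three buffers above can be written as small helpers. The helper names are hypothetical. The `fillDigitBases` function additionally assumes that each digit's base position is the exclusive prefix sum of the digit totals, which is the usual choice for a counting-based sort but is not confirmed by this page.

```javascript
// Index helpers matching the documented layouts (names are illustrative).
const localPrefixIndex = (i, d, bucketCount) => i * bucketCount + d;
const digitCountsIndex = (d, wg, workgroupCount) => d * workgroupCount + wg;
const digitTotalIndex = (d) => d;                            // first half
const digitBaseIndex = (d, bucketCount) => bucketCount + d;  // second half

// ASSUMPTION: base[d] = sum of totals for all digits smaller than d.
function fillDigitBases(digitTotals, bucketCount) {
  let running = 0;
  for (let d = 0; d < bucketCount; d++) {
    digitTotals[digitBaseIndex(d, bucketCount)] = running;
    running += digitTotals[digitTotalIndex(d)];
  }
  return digitTotals;
}

// Four buckets with totals 3, 1, 0, 2 in the first half.
const totals = fillDigitBases(new Uint32Array([3, 1, 0, 2, 0, 0, 0, 0]), 4);
// totals is [3, 1, 0, 2, 0, 3, 4, 4]
```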

Methods

init

init(renderer : WebGPURenderer)

Initializes the sorter for the given renderer.

Parameters
renderer : WebGPURenderer
The Three.js WebGPU renderer.

computePass

computePass(renderer : WebGPURenderer, bitOffset : number)

Executes a single sorting pass.

Parameters
renderer : WebGPURenderer
The Three.js WebGPU renderer.
bitOffset : number
The bit offset for this pass.
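To make the role of the bit offset concrete, the sketch below shows how a digit can be extracted from a key with a shift and the bit mask. The mapping from pass index to bit offset (passIndex * radixBits, least significant digit first) is an assumption, and both function names are hypothetical.

```javascript
// ASSUMPTION: passes walk the key from the least significant digit upward.
function bitOffsetForPass(passIndex, radixBits) {
  return passIndex * radixBits;
}

// Extracts the digit at bitOffset; bitMask is bucketCount - 1.
function extractDigit(key, bitOffset, bitMask) {
  return (key >>> bitOffset) & bitMask; // >>> keeps the key unsigned
}

// Key 0b10110110 (182) with radixBits = 2 yields digits 2, 1, 3, 2.
const digits = [0, 1, 2, 3].map((pass) =>
  extractDigit(0b10110110, bitOffsetForPass(pass, 2), 0b11)
);
```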

compute

compute(renderer : WebGPURenderer)

Executes a complete radix sort (all passes).

Parameters
renderer : WebGPURenderer
The Three.js WebGPU renderer.

computeStep

computeStep(renderer : WebGPURenderer, passIndex : number)

Executes a single step of the radix sort (one pass).

Call this method passCount times, once for each passIndex from 0 to passCount - 1, to complete a full sort (16 calls for 32-bit keys with the default 2-bit radix). Useful for amortizing sort cost across multiple frames.

Parameters
renderer : WebGPURenderer
The Three.js WebGPU renderer.
passIndex : number
Pass index (0 to passCount-1).
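One way to spread the passes over several frames is a small stepper that issues one computeStep call per frame and reports when the sort has completed. This is a sketch: `createAmortizedSort` is a hypothetical helper, it relies only on the `passCount` property and the `computeStep(renderer, passIndex)` method documented here, and it assumes the sorter needs no other per-frame calls.

```javascript
// Returns a function that runs one pass per call.
// `sorter` is any object exposing passCount and computeStep(renderer, passIndex).
function createAmortizedSort(sorter, renderer) {
  let passIndex = 0;
  return function step() {
    sorter.computeStep(renderer, passIndex);
    passIndex = (passIndex + 1) % sorter.passCount;
    return passIndex === 0; // true right after the final pass has been issued
  };
}

// Usage with a stand-in sorter that only records the pass indices it receives.
const issued = [];
const mockSorter = {
  passCount: 4,
  computeStep: (renderer, passIndex) => issued.push(passIndex),
};
const step = createAmortizedSort(mockSorter, null);
const finished = [step(), step(), step(), step()];
```

Because each pass presumably builds on the output of the previous one, the keys should not be rewritten until `step()` has returned true; a partially sorted buffer is not meaningful in between.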

dispose

dispose()

Disposes of GPU resources held by this sorter.