warp.prefix_sum() Hardware-Optimized Parallel Scan

For warp-level parallel scan operations, prefix_sum() can replace complex shared memory algorithms with a hardware-optimized primitive. This operation enables efficient cumulative computations, parallel partitioning, and advanced coordination algorithms that would otherwise need dozens of lines of shared memory and synchronization code.

Key insight: The prefix_sum() operation uses a hardware-accelerated parallel scan to compute cumulative operations across warp lanes in \(O(\log n)\) steps, replacing complex multi-phase algorithms with a single function call.

What is a parallel scan? A parallel scan (prefix sum) is a fundamental parallel primitive that computes cumulative operations across data elements. For addition, it transforms [a, b, c, d] into [a, a+b, a+b+c, a+b+c+d]. This operation is essential to parallel algorithms such as stream compaction, quicksort partitioning, and parallel sorting.
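
To make the definition concrete, here is a small CPU-side sketch in Python (not Mojo GPU code) of the inclusive scan described above, next to the exclusive variant that shifts every result by one position. The helper names inclusive_scan and exclusive_scan are illustrative and not part of any library.

```python
from itertools import accumulate


def inclusive_scan(values):
    """Inclusive prefix sum: out[i] = values[0] + ... + values[i]."""
    return list(accumulate(values))


def exclusive_scan(values):
    """Exclusive prefix sum: out[i] = values[0] + ... + values[i-1], out[0] = 0."""
    # Prepend the identity (0), scan, then drop the final total.
    return list(accumulate([0] + list(values)))[:-1]


data = [1, 2, 3, 4]
print(inclusive_scan(data))  # [1, 3, 6, 10]
print(exclusive_scan(data))  # [0, 1, 3, 6]
```

The two variants differ only in whether a position's own element is counted, so exclusive[i] equals inclusive[i] minus values[i].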

Key concepts

In this puzzle, you will learn:

  • Hardware-optimized parallel scan with prefix_sum()
  • Inclusive vs. exclusive prefix sum patterns
  • Warp-level stream compaction for data reorganization
  • Advanced parallel partitioning that combines multiple warp primitives
  • Single-warp algorithm optimization that replaces complex shared memory code

This turns multi-phase shared memory algorithms into single function calls and enables efficient parallel scans without explicit synchronization.

1. Warp inclusive prefix sum

Configuration

  • Vector size: SIZE = WARP_SIZE (32 or 64, depending on the GPU)
  • Grid configuration: (1, 1) blocks per grid
  • Block configuration: (WARP_SIZE, 1) threads per block
  • Data type: DType.float32
  • Layout: Layout.row_major(SIZE) (1D row-major)

The prefix_sum advantage

A traditional prefix sum requires a complex multi-phase shared memory algorithm. In Puzzle 14: Prefix Sum, we implemented it the hard way with explicit shared memory management:

fn prefix_sum_simple[
    layout: Layout
](
    output: LayoutTensor[dtype, layout, MutAnyOrigin],
    a: LayoutTensor[dtype, layout, ImmutAnyOrigin],
    size: UInt,
):
    global_i = block_dim.x * block_idx.x + thread_idx.x
    local_i = thread_idx.x
    shared = LayoutTensor[
        dtype,
        Layout.row_major(TPB),
        MutAnyOrigin,
        address_space = AddressSpace.SHARED,
    ].stack_allocation()
    if global_i < size:
        shared[local_i] = a[global_i]

    barrier()

    offset = UInt(1)
    for i in range(Int(log2(Scalar[dtype](TPB)))):
        var current_val: output.element_type = 0
        if local_i >= offset and local_i < size:
            current_val = shared[local_i - offset]  # read

        barrier()
        if local_i >= offset and local_i < size:
            shared[local_i] += current_val

        barrier()
        offset *= 2

    if global_i < size:
        output[global_i] = shared[local_i]


Problems with the traditional approach:

  • Memory overhead: requires a shared memory allocation
  • Multiple barriers: complex multi-phase synchronization
  • Complex indexing: manual stride calculation and bounds checking
  • Poor scaling: \(O(\log n)\) phases, each separated by barriers

With prefix_sum(), the parallel scan becomes simple:

# Hardware-optimized approach - a single function call!
current_val = input[global_i]
scan_result = prefix_sum[exclusive=False](current_val)
output[global_i] = scan_result

Advantages of prefix_sum:

  • Zero memory overhead: the scan runs in registers, with no shared memory allocation
  • No explicit synchronization: a single warp-level operation, with no barrier() calls in user code
  • Hardware-optimized: relies on the GPU's warp-level data exchange hardware rather than shared memory
  • Portable: works for any WARP_SIZE (32, 64, etc.)
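
One way to see why neither shared memory nor barriers are needed is to model the scan as a sequence of lane-to-lane register exchanges. The sketch below is a CPU model in Python of a Hillis-Steele scan built from a shuffle-up style exchange. It assumes, without confirming, that the warp primitive works along similar lines; shuffle_up and warp_inclusive_scan are hypothetical helper names for this illustration only.

```python
def shuffle_up(lanes, delta):
    """Model of a warp shuffle-up: lane i reads the value held by lane i - delta.

    Lanes with no source lane (i < delta) simply get their own value back.
    """
    return [lanes[i - delta] if i >= delta else lanes[i] for i in range(len(lanes))]


def warp_inclusive_scan(lanes):
    """Hillis-Steele inclusive scan over one simulated warp.

    Returns (scanned values, number of exchange steps).
    """
    vals = list(lanes)
    offset, steps = 1, 0
    while offset < len(vals):
        shifted = shuffle_up(vals, offset)
        # Only lanes that have a real source lane accumulate in this step.
        vals = [v + s if i >= offset else v
                for i, (v, s) in enumerate(zip(vals, shifted))]
        offset *= 2
        steps += 1
    return vals, steps


result, steps = warp_inclusive_scan(range(1, 33))
print(result[:5], result[-1], steps)  # [1, 3, 6, 10, 15] 528 5
```

For 32 lanes the model needs 5 exchange steps, which matches the \(O(\log n)\) step count quoted in this section.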

Code to complete

Implement an inclusive prefix sum using the hardware-optimized prefix_sum() primitive.

Mathematical operation: Compute a cumulative sum in which each lane receives the sum of all elements up to and including its own position: \[\Large \text{output}[i] = \sum_{j=0}^{i} \text{input}[j]\]

This transforms the input [1, 2, 3, 4, 5, ...] into the cumulative sums [1, 3, 6, 10, 15, ...], where each position holds the sum of all previous elements plus its own.

fn warp_inclusive_prefix_sum[
    layout: Layout, size: Int
](
    output: LayoutTensor[dtype, layout, MutAnyOrigin],
    input: LayoutTensor[dtype, layout, ImmutAnyOrigin],
):
    """
    Inclusive prefix sum using warp primitive:
    Each thread gets sum of all elements up to and including its position.
    Compare this to Puzzle 14's complex shared memory + barrier approach.

    Puzzle 14 approach:
    - Shared memory allocation
    - Multiple barrier synchronizations
    - Log(n) iterations with manual tree reduction
    - Complex multi-phase algorithm

    Warp prefix_sum approach:
    - Single function call!
    - Hardware-optimized parallel scan
    - Automatic synchronization
    - O(log n) complexity, but implemented in hardware.

    NOTE: This implementation only works correctly within a single warp (WARP_SIZE threads).
    For multi-warp scenarios, additional coordination would be needed.
    """
    global_i = Int(block_dim.x * block_idx.x + thread_idx.x)

    # FILL ME IN (roughly 4 lines)


View full file: problems/p26/p26.mojo

Tips

1. Understanding prefix_sum parameters

The prefix_sum() function has an important compile-time parameter that controls the scan type.

Key questions:

  • What is the difference between an inclusive and an exclusive prefix sum?
  • Which parameter controls this behavior?
  • For an inclusive scan, what should each lane output?

Hint: Look at the function signature and consider what "inclusive" means for a cumulative operation.

2. Single-warp limitation

This hardware primitive only works within a single warp. Consider the implications of this limitation.

Think about:

  • What happens if you have multiple warps?
  • Why is it important to understand this limitation?
  • How would you extend this to multi-warp scenarios?

3. Data type considerations

The prefix_sum function may require specific data types for optimal performance.

Consider:

  • What data type does the input use?
  • Does prefix_sum expect a specific scalar type?
  • How do you handle type conversions if needed?

Test the warp inclusive prefix sum:

pixi run p26 --prefix-sum
pixi run -e amd p26 --prefix-sum
pixi run -e apple p26 --prefix-sum
uv run poe p26 --prefix-sum

Expected output when solved:

WARP_SIZE:  32
SIZE:  32
output: [1.0, 3.0, 6.0, 10.0, 15.0, 21.0, 28.0, 36.0, 45.0, 55.0, 66.0, 78.0, 91.0, 105.0, 120.0, 136.0, 153.0, 171.0, 190.0, 210.0, 231.0, 253.0, 276.0, 300.0, 325.0, 351.0, 378.0, 406.0, 435.0, 465.0, 496.0, 528.0]
expected: [1.0, 3.0, 6.0, 10.0, 15.0, 21.0, 28.0, 36.0, 45.0, 55.0, 66.0, 78.0, 91.0, 105.0, 120.0, 136.0, 153.0, 171.0, 190.0, 210.0, 231.0, 253.0, 276.0, 300.0, 325.0, 351.0, 378.0, 406.0, 435.0, 465.0, 496.0, 528.0]
✅ Warp inclusive prefix sum test passed!

Solution

fn warp_inclusive_prefix_sum[
    layout: Layout, size: Int
](
    output: LayoutTensor[dtype, layout, MutAnyOrigin],
    input: LayoutTensor[dtype, layout, ImmutAnyOrigin],
):
    """
    Inclusive prefix sum using warp primitive: Each thread gets sum of all elements up to and including its position.
    Compare this to Puzzle 14's complex shared memory + barrier approach.

    Puzzle 14 approach:
    - Shared memory allocation
    - Multiple barrier synchronizations
    - Log(n) iterations with manual tree reduction
    - Complex multi-phase algorithm

    Warp prefix_sum approach:
    - Single function call!
    - Hardware-optimized parallel scan
    - Automatic synchronization
    - O(log n) complexity, but implemented in hardware.

    NOTE: This implementation only works correctly within a single warp (WARP_SIZE threads).
    For multi-warp scenarios, additional coordination would be needed.
    """
    global_i = Int(block_dim.x * block_idx.x + thread_idx.x)

    if global_i < size:
        current_val = input[global_i]

        # This one call replaces ~30 lines of complex shared memory logic from Puzzle 14!
        # But it only works within the current warp (WARP_SIZE threads)
        scan_result = prefix_sum[exclusive=False](
            rebind[Scalar[dtype]](current_val)
        )

        output[global_i] = scan_result


This solution shows how prefix_sum() replaces a complex multi-phase algorithm with a single hardware-optimized function call.

Algorithm breakdown:

if global_i < size:
    current_val = input[global_i]

    # This one call replaces ~30 lines of complex shared memory logic from Puzzle 14!
    # But it only works within the current warp (WARP_SIZE threads)
    scan_result = prefix_sum[exclusive=False](
        rebind[Scalar[dtype]](current_val)
    )

    output[global_i] = scan_result

SIMT execution deep dive:

Input: [1, 2, 3, 4, 5, 6, 7, 8, ...]

Step 1: All lanes load their values simultaneously
  Lane 0: current_val = 1
  Lane 1: current_val = 2
  Lane 2: current_val = 3
  Lane 3: current_val = 4
  ...
  Lane 31: current_val = 32

Step 2: prefix_sum[exclusive=False] executes (hardware-accelerated)
  Lane 0: scan_result = 1 (sum of elements 0 to 0)
  Lane 1: scan_result = 3 (sum of elements 0 to 1: 1+2)
  Lane 2: scan_result = 6 (sum of elements 0 to 2: 1+2+3)
  Lane 3: scan_result = 10 (sum of elements 0 to 3: 1+2+3+4)
  ...
  Lane 31: scan_result = 528 (sum of elements 0 to 31)

Step 3: Store results
  Lane 0: output[0] = 1
  Lane 1: output[1] = 3
  Lane 2: output[2] = 6
  Lane 3: output[3] = 10
  ...

Mathematical insight: This implements the inclusive prefix sum operation: \[\Large \text{output}[i] = \sum_{j=0}^{i} \text{input}[j]\]

Comparison with the Puzzle 14 approach:

  • Puzzle 14: Prefix Sum: ~30 lines of shared memory code + multiple barriers + complex indexing
  • Warp primitive: 1 hardware-accelerated function call
  • Performance: the same \(O(\log n)\) step count, but carried out with warp-level register exchanges instead of shared memory round trips and barriers
  • Memory: zero shared memory usage, versus an explicit allocation

Evolution from Puzzle 14: This shows the power of modern GPU architectures. What required a careful manual implementation in Puzzle 14 is now a single hardware-accelerated primitive. The warp-level prefix_sum() gives the same algorithmic benefits with almost no implementation complexity.

Why prefix_sum is the better choice here:

  1. Hardware acceleration: uses the GPU's warp-level data exchange hardware
  2. Zero memory overhead: no shared memory allocation needed
  3. Automatic synchronization: no explicit barriers needed
  4. Portability: works for any WARP_SIZE

Performance characteristics:

  • Latency: a few warp-level instructions (on the order of \(\log_2(\text{WARP_SIZE})\) exchange steps), with no barriers
  • Bandwidth: no memory traffic for the scan itself (register-only operation)
  • Parallelism: all WARP_SIZE lanes participate simultaneously
  • Scalability: \(O(\log n)\) complexity with hardware optimization

Important limitation: This primitive only works within a single warp. Multi-warp scenarios need additional coordination between warps.
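
The additional coordination mentioned above usually follows a three-step pattern: scan each warp locally, scan the per-warp totals, then add each warp's base offset back in. The following is a minimal CPU sketch of that pattern in Python, not the puzzle's Mojo code; block_inclusive_scan and its warp_size argument are names made up for this illustration.

```python
from itertools import accumulate


def block_inclusive_scan(values, warp_size=32):
    """Inclusive scan across several simulated warps."""
    warps = [values[i:i + warp_size] for i in range(0, len(values), warp_size)]

    # Step 1: each warp scans its own lanes (what one warp-level scan covers).
    local = [list(accumulate(w)) for w in warps]

    # Step 2: an exclusive scan of the per-warp totals gives each warp's base offset.
    totals = [scan[-1] for scan in local]
    offsets = list(accumulate([0] + totals))[:-1]

    # Step 3: every lane adds the base offset of its warp.
    return [x + off for scan, off in zip(local, offsets) for x in scan]


two_warps = list(range(1, 65))
print(block_inclusive_scan(two_warps)[31:34])  # [528, 561, 595]
```

On a GPU, step 2 is where warps have to communicate, typically through shared memory and a barrier, which is exactly the part a single-warp primitive cannot provide on its own.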

2. Warp partition

Configuration

  • Vector size: SIZE = WARP_SIZE (32 or 64, depending on the GPU)
  • Grid configuration: (1, 1) blocks per grid
  • Block configuration: (WARP_SIZE, 1) threads per block

Code to complete

Implement single-warp parallel partitioning using both the shuffle_xor and prefix_sum primitives.

Mathematical operation: Partition the elements around a pivot value, placing elements < pivot on the left and elements >= pivot on the right: \[\Large \text{output} = [\text{elements} < \text{pivot}] \,|\, [\text{elements} \geq \text{pivot}]\]

Advanced algorithm: This algorithm combines two sophisticated warp primitives:

  1. shuffle_xor(): a warp-level butterfly reduction that counts the left-partition elements
  2. prefix_sum(): an exclusive scan that computes positions within each partition

This demonstrates how combining multiple warp primitives can implement complex parallel algorithms within a single warp.
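
Before writing the kernel, it can help to have a sequential reference to compare against. The Python sketch below is a CPU-side model of the target result, a stable partition that keeps the original order inside each side; partition_reference is a hypothetical helper name, and the sample input is the one used in the kernel's docstring.

```python
def partition_reference(values, pivot):
    """Stable partition: elements < pivot first, then elements >= pivot."""
    left = [v for v in values if v < pivot]
    right = [v for v in values if v >= pivot]
    return left + right


print(partition_reference([3, 7, 1, 8, 2, 9, 4, 6], 5))  # [3, 1, 2, 4, 7, 8, 9, 6]
```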

fn warp_partition[
    layout: Layout, size: Int
](
    output: LayoutTensor[dtype, layout, MutAnyOrigin],
    input: LayoutTensor[dtype, layout, ImmutAnyOrigin],
    pivot: Float32,
):
    """
    Single-warp parallel partitioning using BOTH shuffle_xor AND prefix_sum.
    This implements a warp-level quicksort partition step that places elements < pivot
    on the left and elements >= pivot on the right.

    ALGORITHM COMPLEXITY - combines two advanced warp primitives:
    1. shuffle_xor(): Butterfly pattern for warp-level reductions
    2. prefix_sum(): Warp-level exclusive scan for position calculation.

    This demonstrates the power of warp primitives for sophisticated parallel algorithms
    within a single warp (works for any WARP_SIZE: 32, 64, etc.).

    Example with pivot=5:
    Input:  [3, 7, 1, 8, 2, 9, 4, 6]
    Result: [3, 1, 2, 4, 7, 8, 9, 6] (< pivot | >= pivot).
    """
    global_i = Int(block_dim.x * block_idx.x + thread_idx.x)

    if global_i < size:
        current_val = input[global_i]

        # FILL ME IN (roughly 13 lines)


Tips

1. Multi-phase algorithm structure

This algorithm requires several coordinated phases. Think about the logical steps needed for partitioning.

Key phases to consider:

  • How do you identify which elements belong to which partition?
  • How do you calculate positions within each partition?
  • How do you determine the total size of the left partition?
  • How do you write elements to their final positions?

2. Predicate creation

You need to create boolean predicates that identify partition membership.

Think about:

  • How do you express "this element belongs to the left partition"?
  • How do you express "this element belongs to the right partition"?
  • What data type should the predicates passed to prefix_sum have?

3. Combining shuffle_xor and prefix_sum

This algorithm uses the two warp primitives for different purposes.

Consider:

  • What is shuffle_xor used for in this context?
  • What is prefix_sum used for in this context?
  • How do these two operations work together?

4. Position calculation

The trickiest part is calculating where each element should be written in the output.

Key insights:

  • Left-partition elements: what determines their final position?
  • Right-partition elements: how do you apply the offset correctly?
  • How do you combine local positions with the partition boundary?

Test the warp partition:

uv run poe p26 --partition
pixi run p26 --partition

Expected output when solved:

WARP_SIZE:  32
SIZE:  32
output: HostBuffer([3.0, 1.0, 2.0, 4.0, 0.0, 3.0, 1.0, 4.0, 3.0, 1.0, 2.0, 4.0, 0.0, 3.0, 1.0, 4.0, 7.0, 8.0, 9.0, 6.0, 10.0, 11.0, 12.0, 13.0, 7.0, 8.0, 9.0, 6.0, 10.0, 11.0, 12.0, 13.0])
expected: HostBuffer([3.0, 1.0, 2.0, 4.0, 0.0, 3.0, 1.0, 4.0, 3.0, 1.0, 2.0, 4.0, 0.0, 3.0, 1.0, 4.0, 7.0, 8.0, 9.0, 6.0, 10.0, 11.0, 12.0, 13.0, 7.0, 8.0, 9.0, 6.0, 10.0, 11.0, 12.0, 13.0])
pivot: 5.0
✅ Warp partition test passed!

Solution

fn warp_partition[
    layout: Layout, size: Int
](
    output: LayoutTensor[dtype, layout, MutAnyOrigin],
    input: LayoutTensor[dtype, layout, ImmutAnyOrigin],
    pivot: Float32,
):
    """
    Single-warp parallel partitioning using BOTH shuffle_xor AND prefix_sum.
    This implements a warp-level quicksort partition step that places elements < pivot
    on the left and elements >= pivot on the right.

    ALGORITHM COMPLEXITY - combines two advanced warp primitives:
    1. shuffle_xor(): Butterfly pattern for warp-level reductions
    2. prefix_sum(): Warp-level exclusive scan for position calculation.

    This demonstrates the power of warp primitives for sophisticated parallel algorithms
    within a single warp (works for any WARP_SIZE: 32, 64, etc.).

    Example with pivot=5:
    Input:  [3, 7, 1, 8, 2, 9, 4, 6]
    Result: [3, 1, 2, 4, 7, 8, 9, 6] (< pivot | >= pivot).
    """
    global_i = Int(block_dim.x * block_idx.x + thread_idx.x)

    if global_i < size:
        current_val = input[global_i]

        # Phase 1: Create warp-level predicates
        predicate_left = Float32(1.0) if current_val < pivot else Float32(0.0)
        predicate_right = Float32(1.0) if current_val >= pivot else Float32(0.0)

        # Phase 2: Warp-level prefix sum to get positions within warp
        warp_left_pos = prefix_sum[exclusive=True](predicate_left)
        warp_right_pos = prefix_sum[exclusive=True](predicate_right)

        # Phase 3: Get total left count using shuffle_xor reduction
        warp_left_total = predicate_left

        # Butterfly reduction to get total across the warp: dynamic for any WARP_SIZE
        offset = WARP_SIZE // 2
        while offset > 0:
            warp_left_total += shuffle_xor(warp_left_total, offset)
            offset //= 2

        # Phase 4: Write to output positions
        if current_val < pivot:
            # Left partition: use warp-level position
            output[Int(warp_left_pos)] = current_val
        else:
            # Right partition: offset by total left count + right position
            output[Int(warp_left_total + warp_right_pos)] = current_val


This solution shows how advanced coordination between multiple warp primitives can implement a sophisticated parallel algorithm.

Complete algorithm analysis:

if global_i < size:
    current_val = input[global_i]

    # Phase 1: Create warp-level predicates
    predicate_left = Float32(1.0) if current_val < pivot else Float32(0.0)
    predicate_right = Float32(1.0) if current_val >= pivot else Float32(0.0)

    # Phase 2: Warp-level prefix sum to get positions within the warp
    warp_left_pos = prefix_sum[exclusive=True](predicate_left)
    warp_right_pos = prefix_sum[exclusive=True](predicate_right)

    # Phase 3: Get the total left count using a shuffle_xor butterfly reduction
    warp_left_total = predicate_left

    # Butterfly reduction to get the total across the warp: dynamic for any WARP_SIZE
    offset = WARP_SIZE // 2
    while offset > 0:
        warp_left_total += shuffle_xor(warp_left_total, offset)
        offset //= 2

    # Phase 4: Write to output positions
    if current_val < pivot:
        # Left partition: use the warp-level position
        output[Int(warp_left_pos)] = current_val
    else:
        # Right partition: offset by total left count + right position
        output[Int(warp_left_total + warp_right_pos)] = current_val

Multi-phase execution trace (8-lane example, pivot=5, values [3,7,1,8,2,9,4,6]):

Initial state:
  Lane 0: current_val=3 (< 5)  Lane 1: current_val=7 (>= 5)
  Lane 2: current_val=1 (< 5)  Lane 3: current_val=8 (>= 5)
  Lane 4: current_val=2 (< 5)  Lane 5: current_val=9 (>= 5)
  Lane 6: current_val=4 (< 5)  Lane 7: current_val=6 (>= 5)

Phase 1: Create predicates
  Lane 0: predicate_left=1.0, predicate_right=0.0
  Lane 1: predicate_left=0.0, predicate_right=1.0
  Lane 2: predicate_left=1.0, predicate_right=0.0
  Lane 3: predicate_left=0.0, predicate_right=1.0
  Lane 4: predicate_left=1.0, predicate_right=0.0
  Lane 5: predicate_left=0.0, predicate_right=1.0
  Lane 6: predicate_left=1.0, predicate_right=0.0
  Lane 7: predicate_left=0.0, predicate_right=1.0

Phase 2: Exclusive prefix sum for positions
  warp_left_pos:  [0, 1, 1, 2, 2, 3, 3, 4]
  warp_right_pos: [0, 0, 1, 1, 2, 2, 3, 3]

Phase 3: Butterfly reduction for the left total
  Initial: [1, 0, 1, 0, 1, 0, 1, 0]
  After reduction: every lane holds warp_left_total = 4

Phase 4: Write to output positions
  Lane 0: current_val=3 < pivot → output[0] = 3
  Lane 1: current_val=7 >= pivot → output[4+0] = output[4] = 7
  Lane 2: current_val=1 < pivot → output[1] = 1
  Lane 3: current_val=8 >= pivot → output[4+1] = output[5] = 8
  Lane 4: current_val=2 < pivot → output[2] = 2
  Lane 5: current_val=9 >= pivot → output[4+2] = output[6] = 9
  Lane 6: current_val=4 < pivot → output[3] = 4
  Lane 7: current_val=6 >= pivot → output[4+3] = output[7] = 6

Final result: [3, 1, 2, 4, 7, 8, 9, 6] (< pivot | >= pivot)
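
The trace can be reproduced on the CPU with a small Python model of the four phases. This is a sketch for checking the arithmetic, not the Mojo kernel: a plain sum stands in for the butterfly reduction, and exclusive_scan and warp_partition_model are hypothetical helper names.

```python
def exclusive_scan(flags):
    """out[i] = number of set flags strictly before position i."""
    out, running = [], 0
    for f in flags:
        out.append(running)
        running += f
    return out


def warp_partition_model(values, pivot):
    # Phase 1: one predicate pair per lane.
    is_left = [1 if v < pivot else 0 for v in values]
    is_right = [1 - f for f in is_left]
    # Phase 2: exclusive scans give each lane its slot inside its partition.
    left_pos = exclusive_scan(is_left)
    right_pos = exclusive_scan(is_right)
    # Phase 3: total left count (stand-in for the shuffle_xor butterfly reduction).
    left_total = sum(is_left)
    # Phase 4: scatter every lane's value to its final position.
    output = [None] * len(values)
    for lane, v in enumerate(values):
        dest = left_pos[lane] if is_left[lane] else left_total + right_pos[lane]
        output[dest] = v
    return left_pos, right_pos, left_total, output


left_pos, right_pos, left_total, output = warp_partition_model(
    [3, 7, 1, 8, 2, 9, 4, 6], 5
)
print(left_pos)    # [0, 1, 1, 2, 2, 3, 3, 4]
print(right_pos)   # [0, 0, 1, 1, 2, 2, 3, 3]
print(left_total)  # 4
print(output)      # [3, 1, 2, 4, 7, 8, 9, 6]
```

Each lane only ever reads the scan result that matches its own partition, so the entries of left_pos at right-partition lanes (and of right_pos at left-partition lanes) are computed but never used.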

Mathematical insight: This implements parallel partitioning with two warp primitives: \[\Large \begin{align} \text{left_pos}[i] &= \text{prefix_sum}_{\text{exclusive}}(\text{predicate_left}[i]) \\ \text{right_pos}[i] &= \text{prefix_sum}_{\text{exclusive}}(\text{predicate_right}[i]) \\ \text{left_total} &= \text{butterfly_reduce}(\text{predicate_left}) \\ \text{final_pos}[i] &= \begin{cases} \text{left_pos}[i] & \text{if } \text{input}[i] < \text{pivot} \\ \text{left_total} + \text{right_pos}[i] & \text{if } \text{input}[i] \geq \text{pivot} \end{cases} \end{align}\]

Why the multi-primitive approach works:

  1. Predicate creation: identifies each element's partition membership
  2. Exclusive prefix sum: computes relative positions within each partition
  3. Butterfly reduction: computes the partition boundary (the left total)
  4. Coordinated write: combines local positions with the global partition structure

Algorithm complexity:

  • Phase 1: \(O(1)\) - predicate creation
  • Phase 2: \(O(\log n)\) - hardware-accelerated prefix sum
  • Phase 3: \(O(\log n)\) - butterfly reduction with shuffle_xor
  • Phase 4: \(O(1)\) - coordinated write
  • Total: \(O(\log n)\) with good constants
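
The \(O(\log n)\) cost of Phase 3 can be checked with a CPU model of the butterfly pattern. The Python sketch below mimics the kernel's loop (offset halves each round, lane i pairs with lane i XOR offset); the list-based shuffle_xor here is a stand-in written for this example, not the GPU intrinsic, and it assumes a power-of-two lane count.

```python
def shuffle_xor(lanes, mask):
    """Model of a warp XOR shuffle: lane i reads the value held by lane i ^ mask."""
    return [lanes[i ^ mask] for i in range(len(lanes))]


def butterfly_sum(lanes):
    """All-lanes sum over a power-of-two number of lanes.

    Returns (per-lane totals, number of exchange rounds).
    """
    vals = list(lanes)
    offset, steps = len(vals) // 2, 0
    while offset > 0:
        partner = shuffle_xor(vals, offset)
        vals = [v + p for v, p in zip(vals, partner)]
        offset //= 2
        steps += 1
    return vals, steps


totals, steps = butterfly_sum([1, 0, 1, 0, 1, 0, 1, 0])
print(totals, steps)  # [4, 4, 4, 4, 4, 4, 4, 4] 3
```

Unlike a tree reduction that leaves the answer in lane 0, the butterfly leaves the full total in every lane, which is why each lane can use warp_left_total directly as its write offset.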

Performance characteristics:

  • Communication steps: roughly \(3 \times \log_2(\text{WARP_SIZE})\) (two prefix sums + one butterfly reduction), assuming each prefix sum takes \(\log_2(\text{WARP_SIZE})\) exchange steps
  • Memory efficiency: zero shared memory; everything stays in registers
  • Parallelism: all lanes stay active throughout the algorithm
  • Scalability: works for any WARP_SIZE (32, 64, etc.)

Practical applications: This pattern is fundamental to:

  • Quicksort partitioning: the core step in parallel sorting algorithms
  • Stream compaction: removing null/invalid elements from data streams
  • Parallel filtering: separating data by complex predicates
  • Load balancing: redistributing work based on computational requirements
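
Stream compaction is the partition pattern with only the left side kept: an exclusive scan over keep/drop flags tells each surviving element where to land. The following Python sketch shows the idea on the CPU under that assumption; stream_compact is an illustrative name, and the sequential loop stands in for the warp-level scan.

```python
def stream_compact(values, keep):
    """Keep the elements for which keep(v) is true, preserving their order."""
    flags = [1 if keep(v) else 0 for v in values]

    # Exclusive scan of the flags: slot[i] = number of kept elements before i.
    slots, running = [], 0
    for f in flags:
        slots.append(running)
        running += f

    # Scatter: only flagged elements write, each to its own slot.
    out = [None] * running
    for v, f, s in zip(values, flags, slots):
        if f:
            out[s] = v
    return out


print(stream_compact([5, None, 7, None, None, 2], lambda v: v is not None))  # [5, 7, 2]
```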

Summary

The prefix_sum() primitive enables hardware-accelerated parallel scan operations that replace complex multi-phase algorithms with a single function call. Through these two problems, you learned:

Core prefix sum patterns

  1. Inclusive prefix sum (prefix_sum[exclusive=False]):

    • Hardware-accelerated cumulative operations
    • Replaces ~30 lines of shared memory code with a single function call
    • \(O(\log n)\) complexity with hardware optimization
  2. Advanced multi-primitive coordination (prefix_sum + shuffle_xor combined):

    • Sophisticated parallel algorithms within a single warp
    • Exclusive scan for positions + butterfly reduction for totals
    • Complex partitioning operations with high parallel efficiency

Key algorithmic insights

Hardware acceleration benefits:

  • prefix_sum() uses the GPU's warp-level hardware rather than shared memory
  • Zero shared memory overhead compared with the traditional approach
  • Automatic synchronization without explicit barriers

Multi-primitive coordination:

# Phase 1: Create predicates for partition membership
predicate = 1.0 if condition else 0.0

# Phase 2: Use prefix_sum for local positions
local_pos = prefix_sum[exclusive=True](predicate)

# Phase 3: Use shuffle_xor for global totals
global_total = butterfly_reduce(predicate)

# Phase 4: Combine for final positioning
final_pos = local_pos + partition_offset

Performance advantages:

  • Hardware-optimized: warp-level primitives instead of a software shared memory implementation
  • Memory-efficient: register-only operations instead of a shared memory allocation
  • Scalable complexity: \(O(\log n)\) with hardware acceleration
  • Single-warp optimization: ideal for algorithms that fit within WARP_SIZE elements

Practical applications

These prefix sum patterns are fundamental to:

  • Parallel scan operations: cumulative sums, cumulative products, min/max scans
  • Stream compaction: parallel filtering and data reorganization
  • Quicksort partitioning: a core building block of parallel sorting
  • Parallel algorithms: load balancing, work distribution, data restructuring
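
The scan pattern generalizes beyond addition to any associative operator, which is what the cumulative product and min/max scans in the list refer to. The short Python sketch below illustrates this on the CPU with itertools.accumulate; whether the warp-level prefix_sum() primitive itself accepts other operators is not stated in this puzzle, so treat this purely as an illustration of the concept.

```python
from itertools import accumulate
import operator

data = [3, 1, 4, 1, 5, 9, 2, 6]

running_sum = list(accumulate(data))                    # addition (the default)
running_product = list(accumulate(data, operator.mul))  # cumulative product
running_max = list(accumulate(data, max))               # max scan
running_min = list(accumulate(data, min))               # min scan

print(running_sum)      # [3, 4, 8, 9, 14, 23, 25, 31]
print(running_product)  # [3, 3, 12, 12, 60, 540, 1080, 6480]
print(running_max)      # [3, 3, 4, 4, 5, 9, 9, 9]
print(running_min)      # [3, 1, 1, 1, 1, 1, 1, 1]
```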

The combination of prefix_sum() and shuffle_xor() shows how modern GPU warp primitives can implement sophisticated parallel algorithms with minimal code complexity and strong performance characteristics.