warp.prefix_sum() ํ๋์จ์ด ์ต์ ํ ๋ณ๋ ฌ ์ค์บ
์ํ ๋ ๋ฒจ ๋ณ๋ ฌ ์ค์บ ์ฐ์ฐ์์๋ prefix_sum()์ ์ฌ์ฉํ์ฌ ๋ณต์กํ ๊ณต์ ๋ฉ๋ชจ๋ฆฌ ์๊ณ ๋ฆฌ์ฆ์ ํ๋์จ์ด ์ต์ ํ ๊ธฐ๋ณธ ์์๋ก ๋์ฒดํ ์ ์์ต๋๋ค. ์ด ๊ฐ๋ ฅํ ์ฐ์ฐ์ ํตํด ์์ญ ์ค์ ๊ณต์ ๋ฉ๋ชจ๋ฆฌ ๋ฐ ๋๊ธฐํ ์ฝ๋๊ฐ ํ์ํ์ ํจ์จ์ ์ธ ๋์ ๊ณ์ฐ, ๋ณ๋ ฌ ํํฐ์
๋, ๊ณ ๊ธ ์กฐ์ ์๊ณ ๋ฆฌ์ฆ์ ๊ตฌํํ ์ ์์ต๋๋ค.
ํต์ฌ ํต์ฐฐ: prefix_sum() ์ฐ์ฐ์ ํ๋์จ์ด ๊ฐ์ ๋ณ๋ ฌ ์ค์บ์ ํ์ฉํ์ฌ ์ํ ๋ ์ธ์ ๊ฑธ์ณ \(O(\log n)\) ๋ณต์ก๋๋ก ๋์ ์ฐ์ฐ์ ์ํํ๋ฉฐ, ๋ณต์กํ ๋ค๋จ๊ณ ์๊ณ ๋ฆฌ์ฆ์ ๋จ์ผ ํจ์ ํธ์ถ๋ก ๋์ฒดํฉ๋๋ค.
๋ณ๋ ฌ ์ค์บ์ด๋? ๋ณ๋ ฌ ์ค์บ (๋์ ํฉ)์ ๋ฐ์ดํฐ ์์์ ๊ฑธ์ณ ๋์ ์ฐ์ฐ์ ์ํํ๋ ๊ธฐ๋ณธ์ ์ธ ๋ณ๋ ฌ ๊ธฐ๋ณธ ์์์ ๋๋ค. ๋ง์ ์ ๊ฒฝ์ฐ
[a, b, c, d]๋ฅผ[a, a+b, a+b+c, a+b+c+d]๋ก ๋ณํํฉ๋๋ค. ์ด ์ฐ์ฐ์ ์คํธ๋ฆผ ์ปดํฉ์ , quicksort ํํฐ์ ๋, ๋ณ๋ ฌ ์ ๋ ฌ ๊ฐ์ ๋ณ๋ ฌ ์๊ณ ๋ฆฌ์ฆ์ ํ์์ ์ ๋๋ค.
ํต์ฌ ๊ฐ๋
์ด ํผ์ฆ์์ ๋ฐฐ์ธ ๋ด์ฉ:
prefix_sum()์ ํ์ฉํ ํ๋์จ์ด ์ต์ ํ ๋ณ๋ ฌ ์ค์บ- ํฌํจ(inclusive) vs ๋นํฌํจ(exclusive) ๋์ ํฉ ํจํด
- ๋ฐ์ดํฐ ์ฌ๋ฐฐ์น๋ฅผ ์ํ ์ํ ๋ ๋ฒจ ์คํธ๋ฆผ ์ปดํฉ์
- ์ฌ๋ฌ ์ํ ๊ธฐ๋ณธ ์์๋ฅผ ๊ฒฐํฉํ ๊ณ ๊ธ ๋ณ๋ ฌ ํํฐ์ ๋
- ๋ณต์กํ ๊ณต์ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ๋์ฒดํ๋ ๋จ์ผ ์ํ ์๊ณ ๋ฆฌ์ฆ ์ต์ ํ
์ด๋ฅผ ํตํด ๋ค๋จ๊ณ ๊ณต์ ๋ฉ๋ชจ๋ฆฌ ์๊ณ ๋ฆฌ์ฆ์ด ์ฐ์ํ ๋จ์ผ ํจ์ ํธ์ถ๋ก ๋ณํ๋์ด, ๋ช ์์ ๋๊ธฐํ ์์ด ํจ์จ์ ์ธ ๋ณ๋ ฌ ์ค์บ ์ฐ์ฐ์ด ๊ฐ๋ฅํฉ๋๋ค.
1. ์ํ ํฌํจ ๋์ ํฉ
๊ตฌ์ฑ
- ๋ฒกํฐ ํฌ๊ธฐ:
SIZE = WARP_SIZE(GPU์ ๋ฐ๋ผ 32 ๋๋ 64) - ๊ทธ๋ฆฌ๋ ๊ตฌ์ฑ:
(1, 1)๊ทธ๋ฆฌ๋๋น ๋ธ๋ก ์ - ๋ธ๋ก ๊ตฌ์ฑ:
(WARP_SIZE, 1)๋ธ๋ก๋น ์ค๋ ๋ ์ - ๋ฐ์ดํฐ ํ์
:
DType.float32 - ๋ ์ด์์:
Layout.row_major(SIZE)(1D row-major)
prefix_sum์ ์ด์
๊ธฐ์กด ๋์ ํฉ์ ๋ณต์กํ ๋ค๋จ๊ณ ๊ณต์ ๋ฉ๋ชจ๋ฆฌ ์๊ณ ๋ฆฌ์ฆ์ด ํ์ํฉ๋๋ค. Puzzle 14: ๋์ ํฉ์์๋ ๋ช ์์ ๊ณต์ ๋ฉ๋ชจ๋ฆฌ ๊ด๋ฆฌ๋ก ์ด๋ฅผ ํ๋ค๊ฒ ๊ตฌํํ์ต๋๋ค:
fn prefix_sum_simple[
layout: Layout
](
output: LayoutTensor[dtype, layout, MutAnyOrigin],
a: LayoutTensor[dtype, layout, ImmutAnyOrigin],
size: UInt,
):
global_i = block_dim.x * block_idx.x + thread_idx.x
local_i = thread_idx.x
shared = LayoutTensor[
dtype,
Layout.row_major(TPB),
MutAnyOrigin,
address_space = AddressSpace.SHARED,
].stack_allocation()
if global_i < size:
shared[local_i] = a[global_i]
barrier()
offset = UInt(1)
for i in range(Int(log2(Scalar[dtype](TPB)))):
var current_val: output.element_type = 0
if local_i >= offset and local_i < size:
current_val = shared[local_i - offset] # read
barrier()
if local_i >= offset and local_i < size:
shared[local_i] += current_val
barrier()
offset *= 2
if global_i < size:
output[global_i] = shared[local_i]
๊ธฐ์กด ๋ฐฉ์์ ๋ฌธ์ ์ :
- ๋ฉ๋ชจ๋ฆฌ ์ค๋ฒํค๋: ๊ณต์ ๋ฉ๋ชจ๋ฆฌ ํ ๋น์ด ํ์
- ๋ค์ค ๋ฐฐ๋ฆฌ์ด: ๋ณต์กํ ๋ค๋จ๊ณ ๋๊ธฐํ
- ๋ณต์กํ ์ธ๋ฑ์ฑ: ์๋ ์คํธ๋ผ์ด๋ ๊ณ์ฐ๊ณผ ๊ฒฝ๊ณ ๊ฒ์ฌ
- ๋ฎ์ ํ์ฅ์ฑ: ๊ฐ ๋จ๊ณ ์ฌ์ด์ ๋ฐฐ๋ฆฌ์ด๊ฐ ํ์ํ \(O(\log n)\) ๋จ๊ณ
prefix_sum()์ ์ฌ์ฉํ๋ฉด ๋ณ๋ ฌ ์ค์บ์ด ๊ฐ๋จํด์ง๋๋ค:
# ํ๋์จ์ด ์ต์ ํ ๋ฐฉ์ - ๋จ์ผ ํจ์ ํธ์ถ!
current_val = input[global_i]
scan_result = prefix_sum[exclusive=False](current_val)
output[global_i] = scan_result
prefix_sum์ ์ฅ์ :
- ๋ฉ๋ชจ๋ฆฌ ์ค๋ฒํค๋ ์ ๋ก: ํ๋์จ์ด ๊ฐ์ ์ฐ์ฐ
- ๋๊ธฐํ ๋ถํ์: ๋จ์ผ ์ํ ๋ฏน ์ฐ์ฐ
- ํ๋์จ์ด ์ต์ ํ: ์ ์ฉ ์ค์บ ์ ๋ ํ์ฉ
- ์๋ฒฝํ ํ์ฅ์ฑ: ๋ชจ๋
WARP_SIZE(32, 64 ๋ฑ)์์ ๋์
์์ฑํ ์ฝ๋
ํ๋์จ์ด ์ต์ ํ prefix_sum() ๊ธฐ๋ณธ ์์๋ฅผ ์ฌ์ฉํ์ฌ ํฌํจ ๋์ ํฉ์ ๊ตฌํํฉ๋๋ค.
์ํ์ ์ฐ์ฐ: ๊ฐ ๋ ์ธ์ด ์์ ์ ์์น๊น์ง ๋ชจ๋ ์์์ ํฉ์ ํฌํจํ๋ ๋์ ํฉ์ ๊ณ์ฐํฉ๋๋ค: \[\Large \text{output}[i] = \sum_{j=0}^{i} \text{input}[j]\]
์
๋ ฅ ๋ฐ์ดํฐ [1, 2, 3, 4, 5, ...]๋ฅผ ๋์ ํฉ [1, 3, 6, 10, 15, ...]์ผ๋ก ๋ณํํ๋ฉฐ, ๊ฐ ์์น์ ์ด์ ๋ชจ๋ ์์์ ์๊ธฐ ์์ ์ ํฉ์ด ๋ด๊น๋๋ค.
fn warp_inclusive_prefix_sum[
layout: Layout, size: Int
](
output: LayoutTensor[dtype, layout, MutAnyOrigin],
input: LayoutTensor[dtype, layout, ImmutAnyOrigin],
):
"""
Inclusive prefix sum using warp primitive:
Each thread gets sum of all elements up to and including its position.
Compare this to Puzzle 12's complex shared memory + barrier approach.
Puzzle 12 approach:
- Shared memory allocation
- Multiple barrier synchronizations
- Log(n) iterations with manual tree reduction
- Complex multi-phase algorithm
Warp prefix_sum approach:
- Single function call!
- Hardware-optimized parallel scan
- Automatic synchronization
- O(log n) complexity, but implemented in hardware.
NOTE: This implementation only works correctly within a single warp (WARP_SIZE threads).
For multi-warp scenarios, additional coordination would be needed.
"""
global_i = Int(block_dim.x * block_idx.x + thread_idx.x)
# FILL ME IN (roughly 4 lines)
์ ์ฒด ํ์ผ ๋ณด๊ธฐ: problems/p26/p26.mojo
ํ
1. prefix_sum ๋งค๊ฐ๋ณ์ ์ดํดํ๊ธฐ
prefix_sum() ํจ์์๋ ์ค์บ ์ ํ์ ์ ์ดํ๋ ์ค์ํ ํ
ํ๋ฆฟ ๋งค๊ฐ๋ณ์๊ฐ ์์ต๋๋ค.
ํต์ฌ ์ง๋ฌธ:
- ํฌํจ ๋์ ํฉ๊ณผ ๋นํฌํจ ๋์ ํฉ์ ์ฐจ์ด๋ ๋ฌด์์ธ๊ฐ์?
- ์ด๋ค ๋งค๊ฐ๋ณ์๊ฐ ์ด ๋์์ ์ ์ดํ๋์?
- ํฌํจ ์ค์บ์์ ๊ฐ ๋ ์ธ์ ๋ฌด์์ ์ถ๋ ฅํด์ผ ํ๋์?
ํํธ: ํจ์ ์๊ทธ๋์ฒ๋ฅผ ๋ณด๊ณ ๋์ ์ฐ์ฐ์์ โํฌํจ(inclusive)โ์ด ๋ฌด์์ ์๋ฏธํ๋์ง ์๊ฐํด ๋ณด์ธ์.
2. ๋จ์ผ ์ํ ์ ํ
์ด ํ๋์จ์ด ๊ธฐ๋ณธ ์์๋ ๋จ์ผ ์ํ ๋ด์์๋ง ๋์ํฉ๋๋ค. ์ด ์ ํ์ ์๋ฏธ๋ฅผ ์๊ฐํด ๋ณด์ธ์.
์๊ฐํด ๋ณด์ธ์:
- ์ฌ๋ฌ ์ํ๊ฐ ์์ผ๋ฉด ์ด๋ป๊ฒ ๋๋์?
- ์ด ์ ํ์ ์ดํดํ๋ ๊ฒ์ด ์ ์ค์ํ๊ฐ์?
- ๋ฉํฐ ์ํ ์๋๋ฆฌ์ค๋ก ํ์ฅํ๋ ค๋ฉด ์ด๋ป๊ฒ ํด์ผ ํ๋์?
3. ๋ฐ์ดํฐ ํ์ ๊ณ ๋ ค์ฌํญ
prefix_sum ํจ์๋ ์ต์ ์ฑ๋ฅ์ ์ํด ํน์ ๋ฐ์ดํฐ ํ์
์ ์๊ตฌํ ์ ์์ต๋๋ค.
๊ณ ๋ คํ ์ :
- ์ ๋ ฅ์ด ์ด๋ค ๋ฐ์ดํฐ ํ์ ์ ์ฌ์ฉํ๋์?
prefix_sum์ด ํน์ ์ค์นผ๋ผ ํ์ ์ ๊ธฐ๋ํ๋์?- ํ์ํ ๊ฒฝ์ฐ ํ์ ๋ณํ์ ์ด๋ป๊ฒ ์ฒ๋ฆฌํ๋์?
์ํ ํฌํจ ๋์ ํฉ ํ ์คํธ:
pixi run p26 --prefix-sum
pixi run -e amd p26 --prefix-sum
pixi run -e apple p26 --prefix-sum
uv run poe p26 --prefix-sum
ํ์์ ๋์ ์์ ์ถ๋ ฅ:
WARP_SIZE: 32
SIZE: 32
output: [1.0, 3.0, 6.0, 10.0, 15.0, 21.0, 28.0, 36.0, 45.0, 55.0, 66.0, 78.0, 91.0, 105.0, 120.0, 136.0, 153.0, 171.0, 190.0, 210.0, 231.0, 253.0, 276.0, 300.0, 325.0, 351.0, 378.0, 406.0, 435.0, 465.0, 496.0, 528.0]
expected: [1.0, 3.0, 6.0, 10.0, 15.0, 21.0, 28.0, 36.0, 45.0, 55.0, 66.0, 78.0, 91.0, 105.0, 120.0, 136.0, 153.0, 171.0, 190.0, 210.0, 231.0, 253.0, 276.0, 300.0, 325.0, 351.0, 378.0, 406.0, 435.0, 465.0, 496.0, 528.0]
โ
Warp inclusive prefix sum test passed!
์๋ฃจ์
fn warp_inclusive_prefix_sum[
layout: Layout, size: Int
](
output: LayoutTensor[dtype, layout, MutAnyOrigin],
input: LayoutTensor[dtype, layout, ImmutAnyOrigin],
):
"""
Inclusive prefix sum using warp primitive: Each thread gets sum of all elements up to and including its position.
Compare this to Puzzle 12's complex shared memory + barrier approach.
Puzzle 12 approach:
- Shared memory allocation
- Multiple barrier synchronizations
- Log(n) iterations with manual tree reduction
- Complex multi-phase algorithm
Warp prefix_sum approach:
- Single function call!
- Hardware-optimized parallel scan
- Automatic synchronization
- O(log n) complexity, but implemented in hardware.
NOTE: This implementation only works correctly within a single warp (WARP_SIZE threads).
For multi-warp scenarios, additional coordination would be needed.
"""
global_i = Int(block_dim.x * block_idx.x + thread_idx.x)
if global_i < size:
current_val = input[global_i]
# This one call replaces ~30 lines of complex shared memory logic from Puzzle 12!
# But it only works within the current warp (WARP_SIZE threads)
scan_result = prefix_sum[exclusive=False](
rebind[Scalar[dtype]](current_val)
)
output[global_i] = scan_result
์ด ์๋ฃจ์
์ prefix_sum()์ด ๋ณต์กํ ๋ค๋จ๊ณ ์๊ณ ๋ฆฌ์ฆ์ ํ๋์จ์ด ์ต์ ํ๋ ๋จ์ผ ํจ์ ํธ์ถ๋ก ์ด๋ป๊ฒ ๋์ฒดํ๋์ง ๋ณด์ฌ์ค๋๋ค.
์๊ณ ๋ฆฌ์ฆ ๋ถ์:
if global_i < size:
current_val = input[global_i]
# ์ด ํ ์ค์ด Puzzle 14์ ๋ณต์กํ ๊ณต์ ๋ฉ๋ชจ๋ฆฌ ๋ก์ง ~30์ค์ ๋์ฒดํฉ๋๋ค!
# ๋จ, ํ์ฌ ์ํ (WARP_SIZE ์ค๋ ๋) ๋ด์์๋ง ๋์ํฉ๋๋ค
scan_result = prefix_sum[exclusive=False](
rebind[Scalar[dtype]](current_val)
)
output[global_i] = scan_result
SIMT ์คํ ์์ธ ๋ถ์:
์
๋ ฅ: [1, 2, 3, 4, 5, 6, 7, 8, ...]
์ฌ์ดํด 1: ๋ชจ๋ ๋ ์ธ์ด ๋์์ ๊ฐ์ ๋ก๋
Lane 0: current_val = 1
Lane 1: current_val = 2
Lane 2: current_val = 3
Lane 3: current_val = 4
...
Lane 31: current_val = 32
์ฌ์ดํด 2: prefix_sum[exclusive=False] ์คํ (ํ๋์จ์ด ๊ฐ์)
Lane 0: scan_result = 1 (์์ 0~0์ ํฉ)
Lane 1: scan_result = 3 (์์ 0~1์ ํฉ: 1+2)
Lane 2: scan_result = 6 (์์ 0~2์ ํฉ: 1+2+3)
Lane 3: scan_result = 10 (์์ 0~3์ ํฉ: 1+2+3+4)
...
Lane 31: scan_result = 528 (์์ 0~31์ ํฉ)
์ฌ์ดํด 3: ๊ฒฐ๊ณผ ์ ์ฅ
Lane 0: output[0] = 1
Lane 1: output[1] = 3
Lane 2: output[2] = 6
Lane 3: output[3] = 10
...
์ํ์ ํต์ฐฐ: ํฌํจ ๋์ ํฉ ์ฐ์ฐ์ ๊ตฌํํฉ๋๋ค: \[\Large \text{output}[i] = \sum_{j=0}^{i} \text{input}[j]\]
Puzzle 14 ๋ฐฉ์๊ณผ์ ๋น๊ต:
- Puzzle 14: ๋์ ํฉ: ๊ณต์ ๋ฉ๋ชจ๋ฆฌ ~30์ค + ๋ค์ค ๋ฐฐ๋ฆฌ์ด + ๋ณต์กํ ์ธ๋ฑ์ฑ
- ์ํ ๊ธฐ๋ณธ ์์: ํ๋์จ์ด ๊ฐ์์ ํจ์ ํธ์ถ 1๊ฐ
- ์ฑ๋ฅ: ๊ฐ์ \(O(\log n)\) ๋ณต์ก๋์ด์ง๋ง, ์ ์ฉ ํ๋์จ์ด์์ ๊ตฌํ
- ๋ฉ๋ชจ๋ฆฌ: ๋ช ์์ ํ ๋น ๋๋น ๊ณต์ ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋ ์ ๋ก
Puzzle 12์์์ ๋ฐ์ : ํ๋ GPU ์ํคํ
์ฒ์ ๊ฐ๋ ฅํจ์ ๋ณด์ฌ์ค๋๋ค - Puzzle 12์์ ์ ์คํ ์๋ ๊ตฌํ์ด ํ์ํ๋ ๊ฒ์ด ์ด์ ๋ ํ๋์จ์ด ๊ฐ์ ๊ธฐ๋ณธ ์์ ํ๋๋ก ํด๊ฒฐ๋ฉ๋๋ค. ์ํ ๋ ๋ฒจ prefix_sum()์ ๊ตฌํ ๋ณต์ก๋ ์ ๋ก๋ก ๊ฐ์ ์๊ณ ๋ฆฌ์ฆ์ ์ด์ ์ ์ ๊ณตํฉ๋๋ค.
prefix_sum์ด ์ฐ์ํ ์ด์ :
- ํ๋์จ์ด ๊ฐ์: ํ๋ GPU์ ์ ์ฉ ์ค์บ ์ ๋
- ๋ฉ๋ชจ๋ฆฌ ์ค๋ฒํค๋ ์ ๋ก: ๊ณต์ ๋ฉ๋ชจ๋ฆฌ ํ ๋น ๋ถํ์
- ์๋ ๋๊ธฐํ: ๋ช ์์ ๋ฐฐ๋ฆฌ์ด ๋ถํ์
- ์๋ฒฝํ ํ์ฅ์ฑ: ๋ชจ๋
WARP_SIZE์์ ์ต์ ์ผ๋ก ๋์
์ฑ๋ฅ ํน์ฑ:
- ์ง์ฐ ์๊ฐ: ~1-2 ์ฌ์ดํด (ํ๋์จ์ด ์ค์บ ์ ๋)
- ๋์ญํญ: ๋ฉ๋ชจ๋ฆฌ ํธ๋ํฝ ์ ๋ก (๋ ์ง์คํฐ ์ ์ฉ ์ฐ์ฐ)
- ๋ณ๋ ฌ์ฑ:
WARP_SIZE๊ฐ ๋ ์ธ ๋ชจ๋ ๋์์ ์ฐธ์ฌ - ํ์ฅ์ฑ: ํ๋์จ์ด ์ต์ ํ๋ฅผ ๋๋ฐํ \(O(\log n)\) ๋ณต์ก๋
์ค์ํ ์ ํ์ฌํญ: ์ด ๊ธฐ๋ณธ ์์๋ ๋จ์ผ ์ํ ๋ด์์๋ง ๋์ํฉ๋๋ค. ๋ฉํฐ ์ํ ์๋๋ฆฌ์ค์์๋ ์ํ ๊ฐ ์ถ๊ฐ ์กฐ์ ์ด ํ์ํฉ๋๋ค.
2. ์ํ ํํฐ์
๊ตฌ์ฑ
- ๋ฒกํฐ ํฌ๊ธฐ:
SIZE = WARP_SIZE(GPU์ ๋ฐ๋ผ 32 ๋๋ 64) - ๊ทธ๋ฆฌ๋ ๊ตฌ์ฑ:
(1, 1)๊ทธ๋ฆฌ๋๋น ๋ธ๋ก ์ - ๋ธ๋ก ๊ตฌ์ฑ:
(WARP_SIZE, 1)๋ธ๋ก๋น ์ค๋ ๋ ์
์์ฑํ ์ฝ๋
shuffle_xor๊ณผ prefix_sum ๊ธฐ๋ณธ ์์๋ฅผ ๋ชจ๋ ์ฌ์ฉํ์ฌ ๋จ์ผ ์ํ ๋ณ๋ ฌ ํํฐ์
๋์ ๊ตฌํํฉ๋๋ค.
์ํ์ ์ฐ์ฐ: ํผ๋ฒ ๊ฐ์ ๊ธฐ์ค์ผ๋ก ์์๋ฅผ ๋ถํ ํ์ฌ, < pivot์ธ ์์๋ ์ผ์ชฝ์, >= pivot์ธ ์์๋ ์ค๋ฅธ์ชฝ์ ๋ฐฐ์นํฉ๋๋ค:
\[\Large \text{output} = [\text{elements} < \text{pivot}] \,|\, [\text{elements} \geq \text{pivot}]\]
๊ณ ๊ธ ์๊ณ ๋ฆฌ์ฆ: ์ด ์๊ณ ๋ฆฌ์ฆ์ ๋ ๊ฐ์ง ์ ๊ตํ ์ํ ๊ธฐ๋ณธ ์์๋ฅผ ๊ฒฐํฉํฉ๋๋ค:
shuffle_xor(): ์ผ์ชฝ ์์ ๊ฐ์๋ฅผ ์ธ๊ธฐ ์ํ ์ํ ๋ ๋ฒจ ๋ฒํฐํ๋ผ์ด ๋ฆฌ๋์ prefix_sum(): ๊ฐ ํํฐ์ ๋ด ์์น ๊ณ์ฐ์ ์ํ ๋นํฌํจ ์ค์บ
์ด๋ ๋จ์ผ ์ํ ๋ด์์ ์ฌ๋ฌ ์ํ ๊ธฐ๋ณธ ์์๋ฅผ ๊ฒฐํฉํ์ฌ ๋ณต์กํ ๋ณ๋ ฌ ์๊ณ ๋ฆฌ์ฆ์ ๊ตฌํํ๋ ๊ฐ๋ ฅํจ์ ๋ณด์ฌ์ค๋๋ค.
fn warp_partition[
layout: Layout, size: Int
](
output: LayoutTensor[dtype, layout, MutAnyOrigin],
input: LayoutTensor[dtype, layout, ImmutAnyOrigin],
pivot: Float32,
):
"""
Single-warp parallel partitioning using BOTH shuffle_xor AND prefix_sum.
This implements a warp-level quicksort partition step that places elements < pivot
on the left and elements >= pivot on the right.
ALGORITHM COMPLEXITY - combines two advanced warp primitives:
1. shuffle_xor(): Butterfly pattern for warp-level reductions
2. prefix_sum(): Warp-level exclusive scan for position calculation.
This demonstrates the power of warp primitives for sophisticated parallel algorithms
within a single warp (works for any WARP_SIZE: 32, 64, etc.).
Example with pivot=5:
Input: [3, 7, 1, 8, 2, 9, 4, 6]
Result: [3, 1, 2, 4, 7, 8, 9, 6] (< pivot | >= pivot).
"""
global_i = Int(block_dim.x * block_idx.x + thread_idx.x)
if global_i < size:
current_val = input[global_i]
# FILL ME IN (roughly 13 lines)
ํ
1. ๋ค๋จ๊ณ ์๊ณ ๋ฆฌ์ฆ ๊ตฌ์กฐ
์ด ์๊ณ ๋ฆฌ์ฆ์ ์ฌ๋ฌ ์กฐ์ ๋ ๋จ๊ณ๊ฐ ํ์ํฉ๋๋ค. ํํฐ์ ๋์ ํ์ํ ๋ ผ๋ฆฌ์ ๋จ๊ณ๋ฅผ ์๊ฐํด ๋ณด์ธ์.
๊ณ ๋ คํ ํต์ฌ ๋จ๊ณ:
- ์ด๋ค ์์๊ฐ ์ด๋ ํํฐ์ ์ ์ํ๋์ง ์ด๋ป๊ฒ ์๋ณํ๋์?
- ๊ฐ ํํฐ์ ๋ด์์ ์์น๋ฅผ ์ด๋ป๊ฒ ๊ณ์ฐํ๋์?
- ์ผ์ชฝ ํํฐ์ ์ ์ ์ฒด ํฌ๊ธฐ๋ฅผ ์ด๋ป๊ฒ ์ ์ ์๋์?
- ์ต์ข ์์น์ ์์๋ฅผ ์ด๋ป๊ฒ ๊ธฐ๋กํ๋์?
2. ํ๋ ๋์ผ์ดํธ ์์ฑ
์ด๋ ํํฐ์ ์ ์ํ๋์ง ํ๋ณํ๋ ๋ถ๋ฆฌ์ธ ํ๋ ๋์ผ์ดํธ๋ฅผ ๋ง๋ค์ด์ผ ํฉ๋๋ค.
์๊ฐํด ๋ณด์ธ์:
- โ์ด ์์๋ ์ผ์ชฝ ํํฐ์ ์ ์ํ๋คโ๋ฅผ ์ด๋ป๊ฒ ํํํ๋์?
- โ์ด ์์๋ ์ค๋ฅธ์ชฝ ํํฐ์ ์ ์ํ๋คโ๋ฅผ ์ด๋ป๊ฒ ํํํ๋์?
prefix_sum์ ์ ๋ฌํ ํ๋ ๋์ผ์ดํธ๋ ์ด๋ค ๋ฐ์ดํฐ ํ์ ์ด์ด์ผ ํ๋์?
3. shuffle_xor๊ณผ prefix_sum ๊ฒฐํฉ
์ด ์๊ณ ๋ฆฌ์ฆ์ ๋ ์ํ ๊ธฐ๋ณธ ์์๋ฅผ ์๋ก ๋ค๋ฅธ ๋ชฉ์ ์ผ๋ก ์ฌ์ฉํฉ๋๋ค.
๊ณ ๋ คํ ์ :
- ์ด ๋งฅ๋ฝ์์
shuffle_xor์ ๋ฌด์์ ์ฌ์ฉ๋๋์? - ์ด ๋งฅ๋ฝ์์
prefix_sum์ ๋ฌด์์ ์ฌ์ฉ๋๋์? - ์ด ๋ ์ฐ์ฐ์ด ์ด๋ป๊ฒ ํจ๊ป ๋์ํ๋์?
4. ์์น ๊ณ์ฐ
๊ฐ์ฅ ๊น๋ค๋ก์ด ๋ถ๋ถ์ ๊ฐ ์์๊ฐ ์ถ๋ ฅ์์ ์ด๋์ ๊ธฐ๋ก๋์ด์ผ ํ๋์ง ๊ณ์ฐํ๋ ๊ฒ์ ๋๋ค.
ํต์ฌ ํต์ฐฐ:
- ์ผ์ชฝ ํํฐ์ ์์: ์ต์ข ์์น๋ฅผ ๋ฌด์์ด ๊ฒฐ์ ํ๋์?
- ์ค๋ฅธ์ชฝ ํํฐ์ ์์: ์คํ์ ์ ์ด๋ป๊ฒ ์ฌ๋ฐ๋ฅด๊ฒ ์ ์ฉํ๋์?
- ๋ก์ปฌ ์์น์ ํํฐ์ ๊ฒฝ๊ณ๋ฅผ ์ด๋ป๊ฒ ๊ฒฐํฉํ๋์?
์ํ ํํฐ์ ํ ์คํธ:
uv run poe p26 --partition
pixi run p26 --partition
ํ์์ ๋์ ์์ ์ถ๋ ฅ:
WARP_SIZE: 32
SIZE: 32
output: HostBuffer([3.0, 1.0, 2.0, 4.0, 0.0, 3.0, 1.0, 4.0, 3.0, 1.0, 2.0, 4.0, 0.0, 3.0, 1.0, 4.0, 7.0, 8.0, 9.0, 6.0, 10.0, 11.0, 12.0, 13.0, 7.0, 8.0, 9.0, 6.0, 10.0, 11.0, 12.0, 13.0])
expected: HostBuffer([3.0, 1.0, 2.0, 4.0, 0.0, 3.0, 1.0, 4.0, 3.0, 1.0, 2.0, 4.0, 0.0, 3.0, 1.0, 4.0, 7.0, 8.0, 9.0, 6.0, 10.0, 11.0, 12.0, 13.0, 7.0, 8.0, 9.0, 6.0, 10.0, 11.0, 12.0, 13.0])
pivot: 5.0
โ
Warp partition test passed!
์๋ฃจ์
fn warp_partition[
layout: Layout, size: Int
](
output: LayoutTensor[dtype, layout, MutAnyOrigin],
input: LayoutTensor[dtype, layout, ImmutAnyOrigin],
pivot: Float32,
):
"""
Single-warp parallel partitioning using BOTH shuffle_xor AND prefix_sum.
This implements a warp-level quicksort partition step that places elements < pivot
on the left and elements >= pivot on the right.
ALGORITHM COMPLEXITY - combines two advanced warp primitives:
1. shuffle_xor(): Butterfly pattern for warp-level reductions
2. prefix_sum(): Warp-level exclusive scan for position calculation.
This demonstrates the power of warp primitives for sophisticated parallel algorithms
within a single warp (works for any WARP_SIZE: 32, 64, etc.).
Example with pivot=5:
Input: [3, 7, 1, 8, 2, 9, 4, 6]
Result: [3, 1, 2, 4, 7, 8, 9, 6] (< pivot | >= pivot).
"""
global_i = Int(block_dim.x * block_idx.x + thread_idx.x)
if global_i < size:
current_val = input[global_i]
# Phase 1: Create warp-level predicates
predicate_left = Float32(1.0) if current_val < pivot else Float32(0.0)
predicate_right = Float32(1.0) if current_val >= pivot else Float32(0.0)
# Phase 2: Warp-level prefix sum to get positions within warp
warp_left_pos = prefix_sum[exclusive=True](predicate_left)
warp_right_pos = prefix_sum[exclusive=True](predicate_right)
# Phase 3: Get total left count using shuffle_xor reduction
warp_left_total = predicate_left
# Butterfly reduction to get total across the warp: dynamic for any WARP_SIZE
offset = WARP_SIZE // 2
while offset > 0:
warp_left_total += shuffle_xor(warp_left_total, offset)
offset //= 2
# Phase 4: Write to output positions
if current_val < pivot:
# Left partition: use warp-level position
output[Int(warp_left_pos)] = current_val
else:
# Right partition: offset by total left count + right position
output[Int(warp_left_total + warp_right_pos)] = current_val
์ด ์๋ฃจ์ ์ ์ฌ๋ฌ ์ํ ๊ธฐ๋ณธ ์์ ๊ฐ์ ๊ณ ๊ธ ์กฐ์ ์ ํตํด ์ ๊ตํ ๋ณ๋ ฌ ์๊ณ ๋ฆฌ์ฆ์ ๊ตฌํํ๋ ๋ฐฉ๋ฒ์ ๋ณด์ฌ์ค๋๋ค.
์ ์ฒด ์๊ณ ๋ฆฌ์ฆ ๋ถ์:
if global_i < size:
current_val = input[global_i]
# 1๋จ๊ณ: ์ํ ๋ ๋ฒจ ํ๋ ๋์ผ์ดํธ ์์ฑ
predicate_left = Float32(1.0) if current_val < pivot else Float32(0.0)
predicate_right = Float32(1.0) if current_val >= pivot else Float32(0.0)
# 2๋จ๊ณ: ์ํ ๋ ๋ฒจ ๋์ ํฉ์ผ๋ก ์ํ ๋ด ์์น ๊ณ์ฐ
warp_left_pos = prefix_sum[exclusive=True](predicate_left)
warp_right_pos = prefix_sum[exclusive=True](predicate_right)
# 3๋จ๊ณ: shuffle_xor ๋ฒํฐํ๋ผ์ด ๋ฆฌ๋์
์ผ๋ก ์ผ์ชฝ ์ด ๊ฐ์ ๊ตฌํ๊ธฐ
warp_left_total = predicate_left
# ์ํ ์ ์ฒด์ ํฉ์ฐ์ ์ํ ๋ฒํฐํ๋ผ์ด ๋ฆฌ๋์
: ๋ชจ๋ WARP_SIZE์ ๋์ ๋์
offset = WARP_SIZE // 2
while offset > 0:
warp_left_total += shuffle_xor(warp_left_total, offset)
offset //= 2
# 4๋จ๊ณ: ์ถ๋ ฅ ์์น์ ๊ธฐ๋ก
if current_val < pivot:
# ์ผ์ชฝ ํํฐ์
: ์ํ ๋ ๋ฒจ ์์น ์ฌ์ฉ
output[Int(warp_left_pos)] = current_val
else:
# ์ค๋ฅธ์ชฝ ํํฐ์
: ์ผ์ชฝ ์ด ๊ฐ์ + ์ค๋ฅธ์ชฝ ์์น๋ก offset
output[Int(warp_left_total + warp_right_pos)] = current_val
๋ค๋จ๊ณ ์คํ ์ถ์ (8-๋ ์ธ ์์ , pivot=5, ๊ฐ [3,7,1,8,2,9,4,6]):
์ด๊ธฐ ์ํ:
Lane 0: current_val=3 (< 5) Lane 1: current_val=7 (>= 5)
Lane 2: current_val=1 (< 5) Lane 3: current_val=8 (>= 5)
Lane 4: current_val=2 (< 5) Lane 5: current_val=9 (>= 5)
Lane 6: current_val=4 (< 5) Lane 7: current_val=6 (>= 5)
1๋จ๊ณ: ํ๋ ๋์ผ์ดํธ ์์ฑ
Lane 0: predicate_left=1.0, predicate_right=0.0
Lane 1: predicate_left=0.0, predicate_right=1.0
Lane 2: predicate_left=1.0, predicate_right=0.0
Lane 3: predicate_left=0.0, predicate_right=1.0
Lane 4: predicate_left=1.0, predicate_right=0.0
Lane 5: predicate_left=0.0, predicate_right=1.0
Lane 6: predicate_left=1.0, predicate_right=0.0
Lane 7: predicate_left=0.0, predicate_right=1.0
2๋จ๊ณ: ์์น ๊ณ์ฐ์ ์ํ ๋นํฌํจ ๋์ ํฉ
warp_left_pos: [0, 0, 1, 1, 2, 2, 3, 3]
warp_right_pos: [0, 0, 0, 1, 1, 2, 2, 3]
3๋จ๊ณ: ์ผ์ชฝ ์ด ๊ฐ์๋ฅผ ์ํ ๋ฒํฐํ๋ผ์ด ๋ฆฌ๋์
์ด๊ธฐ๊ฐ: [1, 0, 1, 0, 1, 0, 1, 0]
๋ฆฌ๋์
ํ: ๋ชจ๋ ๋ ์ธ์ด warp_left_total = 4๋ฅผ ๊ฐ์ง
4๋จ๊ณ: ์ถ๋ ฅ ์์น์ ๊ธฐ๋ก
Lane 0: current_val=3 < pivot โ output[0] = 3
Lane 1: current_val=7 >= pivot โ output[4+0] = output[4] = 7
Lane 2: current_val=1 < pivot โ output[1] = 1
Lane 3: current_val=8 >= pivot โ output[4+1] = output[5] = 8
Lane 4: current_val=2 < pivot โ output[2] = 2
Lane 5: current_val=9 >= pivot โ output[4+2] = output[6] = 9
Lane 6: current_val=4 < pivot โ output[3] = 4
Lane 7: current_val=6 >= pivot โ output[4+3] = output[7] = 6
์ต์ข
๊ฒฐ๊ณผ: [3, 1, 2, 4, 7, 8, 9, 6] (< pivot | >= pivot)
์ํ์ ํต์ฐฐ: ์ด์ค ์ํ ๊ธฐ๋ณธ ์์๋ฅผ ์ฌ์ฉํ ๋ณ๋ ฌ ํํฐ์ ๋์ ๊ตฌํํฉ๋๋ค: \[\Large \begin{align} \text{left_pos}[i] &= \text{prefix_sum}_{\text{exclusive}}(\text{predicate_left}[i]) \\ \text{right_pos}[i] &= \text{prefix_sum}_{\text{exclusive}}(\text{predicate_right}[i]) \\ \text{left_total} &= \text{butterfly_reduce}(\text{predicate_left}) \\ \text{final_pos}[i] &= \begin{cases} \text{left_pos}[i] & \text{if } \text{input}[i] < \text{pivot} \\ \text{left_total} + \text{right_pos}[i] & \text{if } \text{input}[i] \geq \text{pivot} \end{cases} \end{align}\]
๋ค์ค ๊ธฐ๋ณธ ์์ ์ ๊ทผ ๋ฐฉ์์ด ๋์ํ๋ ์ด์ :
- ํ๋ ๋์ผ์ดํธ ์์ฑ: ๊ฐ ์์์ ํํฐ์ ์์์ ์๋ณ
- ๋นํฌํจ ๋์ ํฉ: ๊ฐ ํํฐ์ ๋ด ์๋์ ์์น๋ฅผ ๊ณ์ฐ
- ๋ฒํฐํ๋ผ์ด ๋ฆฌ๋์ : ํํฐ์ ๊ฒฝ๊ณ (์ผ์ชฝ ์ด ๊ฐ์)๋ฅผ ์ฐ์ถ
- ์กฐ์ ๋ ๊ธฐ๋ก: ๋ก์ปฌ ์์น์ ์ ์ญ ํํฐ์ ๊ตฌ์กฐ๋ฅผ ๊ฒฐํฉ
์๊ณ ๋ฆฌ์ฆ ๋ณต์ก๋:
- 1๋จ๊ณ: \(O(1)\) - ํ๋ ๋์ผ์ดํธ ์์ฑ
- 2๋จ๊ณ: \(O(\log n)\) - ํ๋์จ์ด ๊ฐ์ ๋์ ํฉ
- 3๋จ๊ณ: \(O(\log n)\) -
shuffle_xor์ ํ์ฉํ ๋ฒํฐํ๋ผ์ด ๋ฆฌ๋์ - 4๋จ๊ณ: \(O(1)\) - ์กฐ์ ๋ ๊ธฐ๋ก
- ์ ์ฒด: ์ฐ์ํ ์์๋ฅผ ๊ฐ์ง \(O(\log n)\)
์ฑ๋ฅ ํน์ฑ:
- ํต์ ๋จ๊ณ: \(2 \times \log_2(\text{WARP_SIZE})\) (๋์ ํฉ + ๋ฒํฐํ๋ผ์ด ๋ฆฌ๋์ )
- ๋ฉ๋ชจ๋ฆฌ ํจ์จ์ฑ: ๊ณต์ ๋ฉ๋ชจ๋ฆฌ ์ ๋ก, ๋ชจ๋ ๋ ์ง์คํฐ ๊ธฐ๋ฐ
- ๋ณ๋ ฌ์ฑ: ์๊ณ ๋ฆฌ์ฆ ์ ์ฒด์์ ๋ชจ๋ ๋ ์ธ์ด ํ์ฑ ์ํ
- ํ์ฅ์ฑ: ๋ชจ๋
WARP_SIZE(32, 64 ๋ฑ)์์ ๋์
์ค์ฉ์ ํ์ฉ: ์ด ํจํด์ ๊ธฐ๋ฐ์ด ๋๋ ๋ถ์ผ:
- Quicksort ํํฐ์ ๋: ๋ณ๋ ฌ ์ ๋ ฌ ์๊ณ ๋ฆฌ์ฆ์ ํต์ฌ ๋จ๊ณ
- ์คํธ๋ฆผ ์ปดํฉ์ : ๋ฐ์ดํฐ ์คํธ๋ฆผ์์ null/๋ฌดํจ ์์ ์ ๊ฑฐ
- ๋ณ๋ ฌ ํํฐ๋ง: ๋ณต์กํ ํ๋ ๋์ผ์ดํธ์ ๋ฐ๋ฅธ ๋ฐ์ดํฐ ๋ถ๋ฆฌ
- ๋ถํ ๋ถ์ฐ: ์ฐ์ฐ ์๊ตฌ๋์ ๋ฐ๋ฅธ ์์ ์ฌ๋ถ๋ฐฐ
์์ฝ
prefix_sum() ๊ธฐ๋ณธ ์์๋ ๋ณต์กํ ๋ค๋จ๊ณ ์๊ณ ๋ฆฌ์ฆ์ ๋จ์ผ ํจ์ ํธ์ถ๋ก ๋์ฒดํ๋ ํ๋์จ์ด ๊ฐ์ ๋ณ๋ ฌ ์ค์บ ์ฐ์ฐ์ ๊ฐ๋ฅํ๊ฒ ํฉ๋๋ค. ๋ ๊ฐ์ง ๋ฌธ์ ๋ฅผ ํตํด ๋ค์์ ๋ฐฐ์ ์ต๋๋ค:
ํต์ฌ ๋์ ํฉ ํจํด
-
ํฌํจ ๋์ ํฉ (
prefix_sum[exclusive=False]):- ํ๋์จ์ด ๊ฐ์ ๋์ ์ฐ์ฐ
- ๊ณต์ ๋ฉ๋ชจ๋ฆฌ ์ฝ๋ ~30์ค์ ๋จ์ผ ํจ์ ํธ์ถ๋ก ๋์ฒด
- ์ ์ฉ ํ๋์จ์ด ์ต์ ํ๋ฅผ ๋๋ฐํ \(O(\log n)\) ๋ณต์ก๋
-
๊ณ ๊ธ ๋ค์ค ๊ธฐ๋ณธ ์์ ์กฐ์ (
prefix_sum+shuffle_xor๊ฒฐํฉ):- ๋จ์ผ ์ํ ๋ด ์ ๊ตํ ๋ณ๋ ฌ ์๊ณ ๋ฆฌ์ฆ
- ์์น ๊ณ์ฐ์ ์ํ ๋นํฌํจ ์ค์บ + ์ดํฉ์ ์ํ ๋ฒํฐํ๋ผ์ด ๋ฆฌ๋์
- ์ต์ ์ ๋ณ๋ ฌ ํจ์จ์ฑ์ ๊ฐ์ง ๋ณต์กํ ํํฐ์ ๋ ์ฐ์ฐ
ํต์ฌ ์๊ณ ๋ฆฌ์ฆ ํต์ฐฐ
ํ๋์จ์ด ๊ฐ์์ ์ด์ :
prefix_sum()์ด ํ๋ GPU์ ์ ์ฉ ์ค์บ ์ ๋์ ํ์ฉ- ๊ธฐ์กด ๋ฐฉ์ ๋๋น ๊ณต์ ๋ฉ๋ชจ๋ฆฌ ์ค๋ฒํค๋ ์ ๋ก
- ๋ช ์์ ๋ฐฐ๋ฆฌ์ด ์๋ ์๋ ๋๊ธฐํ
๋ค์ค ๊ธฐ๋ณธ ์์ ์กฐ์ :
# 1๋จ๊ณ: ํํฐ์
์์์ ์ํ ํ๋ ๋์ผ์ดํธ ์์ฑ
predicate = 1.0 if condition else 0.0
# 2๋จ๊ณ: ๋ก์ปฌ ์์น๋ฅผ ์ํ prefix_sum ์ฌ์ฉ
local_pos = prefix_sum[exclusive=True](predicate)
# 3๋จ๊ณ: ์ ์ญ ์ดํฉ์ ์ํ shuffle_xor ์ฌ์ฉ
global_total = butterfly_reduce(predicate)
# 4๋จ๊ณ: ์ต์ข
์์น ๊ฒฐ์ ์ ์ํ ๊ฒฐํฉ
final_pos = local_pos + partition_offset
์ฑ๋ฅ ์ด์ :
- ํ๋์จ์ด ์ต์ ํ: ์ํํธ์จ์ด ๊ตฌํ ๋๋น ์ ์ฉ ์ค์บ ์ ๋
- ๋ฉ๋ชจ๋ฆฌ ํจ์จ์ฑ: ๊ณต์ ๋ฉ๋ชจ๋ฆฌ ํ ๋น ๋๋น ๋ ์ง์คํฐ ์ ์ฉ ์ฐ์ฐ
- ํ์ฅ ๊ฐ๋ฅํ ๋ณต์ก๋: ํ๋์จ์ด ๊ฐ์์ ๋๋ฐํ \(O(\log n)\)
- ๋จ์ผ ์ํ ์ต์ ํ:
WARP_SIZEํ๋ ๋ด ์๊ณ ๋ฆฌ์ฆ์ ์ต์
์ค์ฉ์ ํ์ฉ
์ด ๋์ ํฉ ํจํด๋ค์ ๊ธฐ๋ฐ์ด ๋๋ ๋ถ์ผ:
- ๋ณ๋ ฌ ์ค์บ ์ฐ์ฐ: ๋์ ํฉ, ๋์ ๊ณฑ, min/max ์ค์บ
- ์คํธ๋ฆผ ์ปดํฉ์ : ๋ณ๋ ฌ ํํฐ๋ง๊ณผ ๋ฐ์ดํฐ ์ฌ๋ฐฐ์น
- Quicksort ํํฐ์ ๋: ๋ณ๋ ฌ ์ ๋ ฌ ์๊ณ ๋ฆฌ์ฆ์ ํต์ฌ ๋น๋ฉ ๋ธ๋ก
- ๋ณ๋ ฌ ์๊ณ ๋ฆฌ์ฆ: ๋ถํ ๋ถ์ฐ, ์์ ๋ถ๋ฐฐ, ๋ฐ์ดํฐ ์ฌ๊ตฌ์กฐํ
prefix_sum()๊ณผ shuffle_xor()์ ๊ฒฐํฉ์ ํ๋ GPU ์ํ ๊ธฐ๋ณธ ์์๊ฐ ์ต์ํ์ ์ฝ๋ ๋ณต์ก๋์ ์ต์ ์ ์ฑ๋ฅ ํน์ฑ์ผ๋ก ์ ๊ตํ ๋ณ๋ ฌ ์๊ณ ๋ฆฌ์ฆ์ ์ด๋ป๊ฒ ๊ตฌํํ ์ ์๋์ง๋ฅผ ๋ณด์ฌ์ค๋๋ค.