warp.shuffle_xor() ๋ฒํฐํ๋ผ์ด ํต์
์ํ ๋ ๋ฒจ ๋ฒํฐํ๋ผ์ด ํต์ ์์๋ shuffle_xor()์ ์ฌ์ฉํ์ฌ ์ํ ๋ด์ ์ ๊ตํ ํธ๋ฆฌ ๊ธฐ๋ฐ ํต์ ํจํด์ ๊ตฌ์ฑํ ์ ์์ต๋๋ค. ์ด ๊ฐ๋ ฅํ ๊ธฐ๋ณธ ์์๋ฅผ ํตํด ๊ณต์ ๋ฉ๋ชจ๋ฆฌ๋ ๋ช
์์ ๋๊ธฐํ ์์ด ํจ์จ์ ์ธ ๋ณ๋ ฌ ๋ฆฌ๋์
, ์ ๋ ฌ ๋คํธ์ํฌ, ๊ณ ๊ธ ์กฐ์ ์๊ณ ๋ฆฌ์ฆ์ ๊ตฌํํ ์ ์์ต๋๋ค.
ํต์ฌ ํต์ฐฐ: shuffle_xor() ์ฐ์ฐ์ SIMT ์คํ์ ํ์ฉํ์ฌ XOR ๊ธฐ๋ฐ ํต์ ํธ๋ฆฌ๋ฅผ ์์ฑํ๋ฉฐ, ์ํ ํฌ๊ธฐ์ ๋ํด \(O(\log n)\) ๋ณต์ก๋๋ก ํ์ฅ๋๋ ํจ์จ์ ์ธ ๋ฒํฐํ๋ผ์ด ๋คํธ์ํฌ์ ๋ณ๋ ฌ ์๊ณ ๋ฆฌ์ฆ์ ๊ฐ๋ฅํ๊ฒ ํฉ๋๋ค.
๋ฒํฐํ๋ผ์ด ๋คํธ์ํฌ๋? ๋ฒํฐํ๋ผ์ด ๋คํธ์ํฌ๋ ์ค๋ ๋๋ค์ด ์ธ๋ฑ์ค์ XOR ํจํด์ ๋ฐ๋ผ ๋ฐ์ดํฐ๋ฅผ ๊ตํํ๋ ํต์ ํ ํด๋ก์ง์ ๋๋ค. ์ด๋ฆ์ ์๊ฐ์ ์ผ๋ก ๊ทธ๋ ธ์ ๋ ๋๋น ๋ ๊ฐ์ฒ๋ผ ๋ณด์ด๋ ์ฐ๊ฒฐ ํจํด์์ ์ ๋ํ์ต๋๋ค. ์ด ๋คํธ์ํฌ๋ \(O(\log n)\) ํต์ ๋ณต์ก๋๋ฅผ ๊ฐ๋ฅํ๊ฒ ํ๊ธฐ ๋๋ฌธ์ FFT, bitonic ์ ๋ ฌ, ๋ณ๋ ฌ ๋ฆฌ๋์ ๊ฐ์ ๋ณ๋ ฌ ์๊ณ ๋ฆฌ์ฆ์ ๊ธฐ๋ฐ์ด ๋ฉ๋๋ค.
ํต์ฌ ๊ฐ๋
์ด ํผ์ฆ์์ ๋ฐฐ์ธ ๋ด์ฉ:
shuffle_xor()์ ํ์ฉํ XOR ๊ธฐ๋ฐ ํต์ ํจํด- ๋ณ๋ ฌ ์๊ณ ๋ฆฌ์ฆ์ ์ํ ๋ฒํฐํ๋ผ์ด ๋คํธ์ํฌ ํ ํด๋ก์ง
- \(O(\log n)\) ๋ณต์ก๋์ ํธ๋ฆฌ ๊ธฐ๋ฐ ๋ณ๋ ฌ ๋ฆฌ๋์
- ๊ณ ๊ธ ์กฐ์ ์ ์ํ ์กฐ๊ฑด๋ถ ๋ฒํฐํ๋ผ์ด ์ฐ์ฐ
- ๋ณต์กํ ๊ณต์ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ๋์ฒดํ๋ ํ๋์จ์ด ์ต์ ํ ๋ณ๋ ฌ ๊ธฐ๋ณธ ์์
shuffle_xor ์ฐ์ฐ์ ๊ฐ ๋ ์ธ์ด XOR ํจํด์ ๋ฐ๋ผ ๋ค๋ฅธ ๋ ์ธ๊ณผ ๋ฐ์ดํฐ๋ฅผ ๊ตํํ ์ ์๊ฒ ํฉ๋๋ค:
\[\Large \text{shuffle_xor}(\text{value}, \text{mask}) = \text{value_from_lane}(\text{lane_id} \oplus \text{mask})\]
์ด๋ฅผ ํตํด ๋ณต์กํ ๋ณ๋ ฌ ์๊ณ ๋ฆฌ์ฆ์ด ์ฐ์ํ ๋ฒํฐํ๋ผ์ด ํต์ ํจํด์ผ๋ก ๋ณํ๋์ด, ๋ช ์์ ์กฐ์ ์์ด ํจ์จ์ ์ธ ํธ๋ฆฌ ๋ฆฌ๋์ ๊ณผ ์ ๋ ฌ ๋คํธ์ํฌ๊ฐ ๊ฐ๋ฅํฉ๋๋ค.
1. ๊ธฐ๋ณธ ๋ฒํฐํ๋ผ์ด ํ์ด ๊ตํ
๊ตฌ์ฑ
- ๋ฒกํฐ ํฌ๊ธฐ:
SIZE = WARP_SIZE(GPU์ ๋ฐ๋ผ 32 ๋๋ 64) - ๊ทธ๋ฆฌ๋ ๊ตฌ์ฑ:
(1, 1)๊ทธ๋ฆฌ๋๋น ๋ธ๋ก ์ - ๋ธ๋ก ๊ตฌ์ฑ:
(WARP_SIZE, 1)๋ธ๋ก๋น ์ค๋ ๋ ์ - ๋ฐ์ดํฐ ํ์
:
DType.float32 - ๋ ์ด์์:
Layout.row_major(SIZE)(1D row-major)
shuffle_xor ๊ฐ๋
๊ธฐ์กด ํ์ด ๊ตํ ๋ฐฉ์์ ๋ณต์กํ ์ธ๋ฑ์ฑ๊ณผ ์กฐ์ ์ด ํ์ํฉ๋๋ค:
# ๊ธฐ์กด ๋ฐฉ์ - ๋ณต์กํ๊ณ ๋๊ธฐํ๊ฐ ํ์
shared_memory[lane] = input[global_i]
barrier()
if lane % 2 == 0:
partner = lane + 1
else:
partner = lane - 1
if partner < WARP_SIZE:
swapped_val = shared_memory[partner]
๊ธฐ์กด ๋ฐฉ์์ ๋ฌธ์ ์ :
- ๋ฉ๋ชจ๋ฆฌ ์ค๋ฒํค๋: ๊ณต์ ๋ฉ๋ชจ๋ฆฌ ํ ๋น์ด ํ์
- ๋๊ธฐํ: ๋ช ์์ ๋ฐฐ๋ฆฌ์ด๊ฐ ํ์
- ๋ณต์กํ ๋ก์ง: ์๋ ํํธ๋ ๊ณ์ฐ๊ณผ ๊ฒฝ๊ณ ๊ฒ์ฌ
- ๋ฎ์ ํ์ฅ์ฑ: ํ๋์จ์ด ํต์ ์ ํ์ฉํ์ง ๋ชปํจ
shuffle_xor()์ ์ฌ์ฉํ๋ฉด ํ์ด ๊ตํ์ด ์ฐ์ํด์ง๋๋ค:
# ๋ฒํฐํ๋ผ์ด XOR ๋ฐฉ์ - ๊ฐ๋จํ๊ณ ํ๋์จ์ด ์ต์ ํ
current_val = input[global_i]
swapped_val = shuffle_xor(current_val, 1) # 1๊ณผ XORํ๋ฉด ํ์ด๊ฐ ์์ฑ๋จ
output[global_i] = swapped_val
shuffle_xor์ ์ฅ์ :
- ๋ฉ๋ชจ๋ฆฌ ์ค๋ฒํค๋ ์ ๋ก: ๋ ์ง์คํฐ ๊ฐ ์ง์ ํต์
- ๋๊ธฐํ ๋ถํ์: SIMT ์คํ์ด ์ ํ์ฑ์ ๋ณด์ฅ
- ํ๋์จ์ด ์ต์ ํ: ๋ชจ๋ ๋ ์ธ์ ๋ํด ๋จ์ผ ๋ช ๋ น์ผ๋ก ์ฒ๋ฆฌ
- ๋ฒํฐํ๋ผ์ด ๊ธฐ๋ฐ: ๋ณต์กํ ๋ณ๋ ฌ ์๊ณ ๋ฆฌ์ฆ์ ๋น๋ฉ ๋ธ๋ก
์์ฑํ ์ฝ๋
shuffle_xor()์ ์ฌ์ฉํ์ฌ ์ธ์ ํ์ด ๊ฐ ๊ฐ์ ๊ตํํ๋ ํ์ด ๊ตํ์ ๊ตฌํํฉ๋๋ค.
์ํ์ ์ฐ์ฐ: XOR ํจํด์ผ๋ก ์ธ์ ํ์ด๋ฅผ ๋ง๋ค์ด ๊ฐ์ ๊ตํํฉ๋๋ค: \[\Large \text{output}[i] = \text{input}[i \oplus 1]\]
์
๋ ฅ ๋ฐ์ดํฐ [0, 1, 2, 3, 4, 5, 6, 7, ...]์ ํ์ด [1, 0, 3, 2, 5, 4, 7, 6, ...]์ผ๋ก ๋ณํํ๋ฉฐ, ๊ฐ ํ์ด (i, i+1)์ด XOR ํต์ ์ผ๋ก ๊ฐ์ ๊ตํํฉ๋๋ค.
์ ์ฒด ํ์ผ ๋ณด๊ธฐ: problems/p26/p26.mojo
ํ
1. shuffle_xor ์ดํดํ๊ธฐ
shuffle_xor(value, mask) ์ฐ์ฐ์ ๊ฐ ๋ ์ธ์ด XOR ๋ง์คํฌ๋งํผ ์ฐจ์ด๋๋ ๋ ์ธ๊ณผ ๋ฐ์ดํฐ๋ฅผ ๊ตํํ ์ ์๊ฒ ํฉ๋๋ค. ์๋ก ๋ค๋ฅธ ๋ง์คํฌ ๊ฐ์ผ๋ก ๋ ์ธ ID๋ฅผ XORํ์ ๋ ์ด๋ค ์ผ์ด ์ผ์ด๋๋์ง ์๊ฐํด ๋ณด์ธ์.
ํ๊ตฌํ ํต์ฌ ์ง๋ฌธ:
- ๋ ์ธ 0์ด ๋ง์คํฌ 1๋ก XORํ๋ฉด ์ด๋ค ํํธ๋๋ฅผ ์ป๋์?
- ๋ ์ธ 1์ด ๋ง์คํฌ 1๋ก XORํ๋ฉด ์ด๋ค ํํธ๋๋ฅผ ์ป๋์?
- ํจํด์ด ๋ณด์ด๋์?
ํํธ: ์ฒ์ ๋ช ๊ฐ์ ๋ ์ธ ID์ ๋ํด XOR ์ฐ์ฐ์ ์ง์ ํด๋ณด๋ฉด ํ์ด๋ง ํจํด์ ์ดํดํ ์ ์์ต๋๋ค.
2. XOR ํ์ด ํจํด
๋ ์ธ ID์ ์ด์ง ํํ๊ณผ ์ตํ์ ๋นํธ๋ฅผ ๋ค์ง์ผ๋ฉด ์ด๋ป๊ฒ ๋๋์ง ์๊ฐํด ๋ณด์ธ์.
๊ณ ๋ คํ ์ง๋ฌธ:
- ์ง์ ๋ ์ธ์ 1๊ณผ XORํ๋ฉด ์ด๋ป๊ฒ ๋๋์?
- ํ์ ๋ ์ธ์ 1๊ณผ XORํ๋ฉด ์ด๋ป๊ฒ ๋๋์?
- ์ ์ด๊ฒ์ด ์๋ฒฝํ ํ์ด๋ฅผ ๋ง๋๋์?
3. ๊ฒฝ๊ณ ๊ฒ์ฌ ๋ถํ์
shuffle_down()๊ณผ ๋ฌ๋ฆฌ shuffle_xor() ์ฐ์ฐ์ ์ํ ๊ฒฝ๊ณ ๋ด์์ ์ ์ง๋ฉ๋๋ค. ์์ ๋ง์คํฌ๋ก์ XOR์ด ์ ๋๋ก ๋ฒ์ ๋ฐ์ ๋ ์ธ ID๋ฅผ ๋ง๋ค์ง ์๋ ์ด์ ๋ฅผ ์๊ฐํด ๋ณด์ธ์.
์๊ฐํด ๋ณด์ธ์: ์ ํจํ ๋ ์ธ ID๋ฅผ 1๊ณผ XORํ์ ๋ ๋์ฌ ์ ์๋ ์ต๋ ๋ ์ธ ID๋ ์ผ๋ง์ธ๊ฐ์?
๋ฒํฐํ๋ผ์ด ํ์ด ๊ตํ ํ ์คํธ:
pixi run p26 --pair-swap
pixi run -e amd p26 --pair-swap
pixi run -e apple p26 --pair-swap
uv run poe p26 --pair-swap
ํ์์ ๋์ ์์ ์ถ๋ ฅ:
WARP_SIZE: 32
SIZE: 32
output: [1.0, 0.0, 3.0, 2.0, 5.0, 4.0, 7.0, 6.0, 9.0, 8.0, 11.0, 10.0, 13.0, 12.0, 15.0, 14.0, 17.0, 16.0, 19.0, 18.0, 21.0, 20.0, 23.0, 22.0, 25.0, 24.0, 27.0, 26.0, 29.0, 28.0, 31.0, 30.0]
expected: [1.0, 0.0, 3.0, 2.0, 5.0, 4.0, 7.0, 6.0, 9.0, 8.0, 11.0, 10.0, 13.0, 12.0, 15.0, 14.0, 17.0, 16.0, 19.0, 18.0, 21.0, 20.0, 23.0, 22.0, 25.0, 24.0, 27.0, 26.0, 29.0, 28.0, 31.0, 30.0]
โ
Butterfly pair swap test passed!
์๋ฃจ์
fn butterfly_pair_swap[
layout: Layout, size: Int
](
output: LayoutTensor[dtype, layout, MutAnyOrigin],
input: LayoutTensor[dtype, layout, ImmutAnyOrigin],
):
"""
Basic butterfly pair swap: Exchange values between adjacent pairs using XOR pattern.
Each thread exchanges its value with its XOR-1 neighbor, creating pairs: (0,1), (2,3), (4,5), etc.
Uses shuffle_xor(val, 1) to swap values within each pair.
This is the foundation of butterfly network communication patterns.
"""
global_i = Int(block_dim.x * block_idx.x + thread_idx.x)
if global_i < size:
current_val = input[global_i]
# Exchange with XOR-1 neighbor using butterfly pattern
# Lane 0 exchanges with lane 1, lane 2 with lane 3, etc.
swapped_val = shuffle_xor(current_val, 1)
# For demonstration, we'll store the swapped value
# In real applications, this might be used for sorting, reduction, etc.
output[global_i] = swapped_val
์ด ํ์ด๋ shuffle_xor()์ด XOR ํต์ ํจํด์ ํตํด ์๋ฒฝํ ํ์ด ๊ตํ์ ์ด๋ป๊ฒ ๋ง๋๋์ง ๋ณด์ฌ์ค๋๋ค.
์๊ณ ๋ฆฌ์ฆ ๋ถ์:
if global_i < size:
current_val = input[global_i] # ๊ฐ ๋ ์ธ์ด ์์ ์ ์์๋ฅผ ์ฝ์
swapped_val = shuffle_xor(current_val, 1) # XOR๋ก ํ์ด ๊ตํ ์์ฑ
# ๊ตํ๋ ๊ฐ์ ์ ์ฅ
output[global_i] = swapped_val
SIMT ์คํ ์์ธ ๋ถ์:
์ฌ์ดํด 1: ๋ชจ๋ ๋ ์ธ์ด ๋์์ ๊ฐ์ ๋ก๋
Lane 0: current_val = input[0] = 0
Lane 1: current_val = input[1] = 1
Lane 2: current_val = input[2] = 2
Lane 3: current_val = input[3] = 3
...
Lane 31: current_val = input[31] = 31
์ฌ์ดํด 2: shuffle_xor(current_val, 1)์ด ๋ชจ๋ ๋ ์ธ์์ ์คํ
Lane 0: Lane 1์์ ์์ (0โ1=1) โ swapped_val = 1
Lane 1: Lane 0์์ ์์ (1โ1=0) โ swapped_val = 0
Lane 2: Lane 3์์ ์์ (2โ1=3) โ swapped_val = 3
Lane 3: Lane 2์์ ์์ (3โ1=2) โ swapped_val = 2
...
Lane 30: Lane 31์์ ์์ (30โ1=31) โ swapped_val = 31
Lane 31: Lane 30์์ ์์ (31โ1=30) โ swapped_val = 30
์ฌ์ดํด 3: ๊ฒฐ๊ณผ ์ ์ฅ
Lane 0: output[0] = 1
Lane 1: output[1] = 0
Lane 2: output[2] = 3
Lane 3: output[3] = 2
...
์ํ์ ํต์ฐฐ: XOR ์์ฑ์ ํ์ฉํ ์๋ฒฝํ ํ์ด ๊ตํ์ ๊ตฌํํฉ๋๋ค: \[\Large \text{XOR}(i, 1) = \begin{cases} i + 1 & \text{if } i \bmod 2 = 0 \\ i - 1 & \text{if } i \bmod 2 = 1 \end{cases}\]
shuffle_xor์ด ์ฐ์ํ ์ด์ :
- ์๋ฒฝํ ๋์นญ: ๋ชจ๋ ๋ ์ธ์ด ์ ํํ ํ๋์ ํ์ด์ ์ฐธ์ฌ
- ์กฐ์ ๋ถํ์: ๋ชจ๋ ํ์ด๊ฐ ๋์์ ๊ตํ
- ํ๋์จ์ด ์ต์ ํ: ์ํ ์ ์ฒด์ ๋ํด ๋จ์ผ ๋ช ๋ น์ผ๋ก ์ฒ๋ฆฌ
- ๋ฒํฐํ๋ผ์ด ๊ธฐ๋ฐ: ๋ณต์กํ ๋ณ๋ ฌ ์๊ณ ๋ฆฌ์ฆ์ ๋น๋ฉ ๋ธ๋ก
์ฑ๋ฅ ํน์ฑ:
- ์ง์ฐ ์๊ฐ: 1 ์ฌ์ดํด (ํ๋์จ์ด ๋ ์ง์คํฐ ๊ตํ)
- ๋์ญํญ: 0 ๋ฐ์ดํธ (๋ฉ๋ชจ๋ฆฌ ํธ๋ํฝ ์์)
- ๋ณ๋ ฌ์ฑ: WARP_SIZE๊ฐ ๋ ์ธ ๋ชจ๋ ๋์์ ๊ตํ
- ํ์ฅ์ฑ: ๋ฐ์ดํฐ ํฌ๊ธฐ์ ๊ด๊ณ์์ด \(O(1)\) ๋ณต์ก๋
2. ๋ฒํฐํ๋ผ์ด ๋ณ๋ ฌ ์ต๋๊ฐ
๊ตฌ์ฑ
- ๋ฒกํฐ ํฌ๊ธฐ:
SIZE = WARP_SIZE(GPU์ ๋ฐ๋ผ 32 ๋๋ 64) - ๊ทธ๋ฆฌ๋ ๊ตฌ์ฑ:
(1, 1)๊ทธ๋ฆฌ๋๋น ๋ธ๋ก ์ - ๋ธ๋ก ๊ตฌ์ฑ:
(WARP_SIZE, 1)๋ธ๋ก๋น ์ค๋ ๋ ์
์์ฑํ ์ฝ๋
๊ฐ์ํ๋ offset์ผ๋ก ๋ฒํฐํ๋ผ์ด shuffle_xor์ ์ฌ์ฉํ์ฌ ๋ณ๋ ฌ ์ต๋๊ฐ ๋ฆฌ๋์
์ ๊ตฌํํฉ๋๋ค.
์ํ์ ์ฐ์ฐ: ํธ๋ฆฌ ๋ฆฌ๋์ ์ ํตํด ๋ชจ๋ ์ํ ๋ ์ธ์์ ์ต๋๊ฐ์ ๊ณ์ฐํฉ๋๋ค: \[\Large \text{max_result} = \max_{i=0}^{\small\text{WARP_SIZE}-1} \text{input}[i]\]
๋ฒํฐํ๋ผ์ด ๋ฆฌ๋์
ํจํด: XOR ์คํ์
์ WARP_SIZE/2์์ 1๊น์ง ์ ๋ฐ์ฉ ์ค์ฌ๊ฐ๋ฉฐ, ํต์ ๋ฒ์๊ฐ ๋จ๊ณ๋ง๋ค ๋ฐ์ผ๋ก ์ข์์ง๋ ์ด์ง ํธ๋ฆฌ๋ฅผ ๊ตฌ์ฑํฉ๋๋ค:
- 1๋จ๊ณ:
WARP_SIZE/2๊ฑฐ๋ฆฌ์ ๋ ์ธ๊ณผ ๋น๊ต (์ํ ์ ์ฒด๋ฅผ ํฌ๊ด) - 2๋จ๊ณ:
WARP_SIZE/4๊ฑฐ๋ฆฌ์ ๋ ์ธ๊ณผ ๋น๊ต (๋ฒ์๋ฅผ ์ ๋ฐ์ผ๋ก ์ขํ) - 3๋จ๊ณ:
WARP_SIZE/8๊ฑฐ๋ฆฌ์ ๋ ์ธ๊ณผ ๋น๊ต - 4๋จ๊ณ:
offset = 1์ด ๋ ๋๊น์ง ๊ณ์ ์ ๋ฐ์ผ๋ก ์ค์
\(\log_2(\text{WARP_SIZE})\) ๋จ๊ณ๋ฅผ ๊ฑฐ์น๋ฉด ๋ชจ๋ ๋ ์ธ์ด ์ ์ญ ์ต๋๊ฐ์ ๊ฐ๊ฒ ๋ฉ๋๋ค. ์ด ๋ฐฉ์์ ๋ชจ๋ WARP_SIZE (32, 64 ๋ฑ)์์ ๋์ํฉ๋๋ค.
fn butterfly_parallel_max[
layout: Layout, size: Int
](
output: LayoutTensor[dtype, layout, MutAnyOrigin],
input: LayoutTensor[dtype, layout, ImmutAnyOrigin],
):
"""
Parallel maximum reduction using butterfly pattern.
Uses shuffle_xor with decreasing offsets starting from WARP_SIZE/2 down to 1.
Each step reduces the active range by half until all threads have the maximum value.
This implements an efficient O(log n) parallel reduction algorithm that works
for any WARP_SIZE (32, 64, etc.).
"""
global_i = Int(block_dim.x * block_idx.x + thread_idx.x)
# FILL ME IN (roughly 7 lines)
ํ
1. ๋ฒํฐํ๋ผ์ด ๋ฆฌ๋์ ์ดํดํ๊ธฐ
๋ฒํฐํ๋ผ์ด ๋ฆฌ๋์ ์ ์ด์ง ํธ๋ฆฌ ํต์ ํจํด์ ์์ฑํฉ๋๋ค. ๊ฐ ๋จ๊ณ์์ ๋ฌธ์ ํฌ๊ธฐ๋ฅผ ์ฒด๊ณ์ ์ผ๋ก ์ค์ด๋ ๋ฐฉ๋ฒ์ ์๊ฐํด ๋ณด์ธ์.
ํต์ฌ ์ง๋ฌธ:
- ์ต๋ ๋ฒ์๋ฅผ ์ปค๋ฒํ๋ ค๋ฉด ์์ offset์ด ์ผ๋ง์ฌ์ผ ํ๋์?
- ๋จ๊ณ ์ฌ์ด์ ์คํ์ ์ ์ด๋ป๊ฒ ๋ณ๊ฒฝํด์ผ ํ๋์?
- ์ธ์ ๋ฆฌ๋์ ์ ๋ฉ์ถฐ์ผ ํ๋์?
ํํธ: โ๋ฒํฐํ๋ผ์ดโ๋ผ๋ ์ด๋ฆ์ ํต์ ํจํด์์ ์ ๋ํฉ๋๋ค - ์์ ์์ ์ ๋ํด ์ง์ ๊ทธ๋ ค๋ณด์ธ์.
2. XOR ๋ฆฌ๋์ ํน์ฑ
XOR์ ๊ฐ ๋จ๊ณ์์ ๊ฒน์น์ง ์๋ ํต์ ํ์ด๋ฅผ ์์ฑํฉ๋๋ค. ์ด๊ฒ์ด ๋ณ๋ ฌ ๋ฆฌ๋์ ์์ ์ ์ค์ํ์ง ์๊ฐํด ๋ณด์ธ์.
์๊ฐํด ๋ณด์ธ์:
- ์๋ก ๋ค๋ฅธ ์คํ์ ์ผ๋ก์ XOR์ด ์ด๋ป๊ฒ ๋ค๋ฅธ ํต์ ํจํด์ ๋ง๋๋์?
- ๊ฐ์ ๋จ๊ณ์์ ๋ ์ธ๋ค์ด ์ ์๋ก ๊ฐ์ญํ์ง ์๋์?
- XOR์ด ํธ๋ฆฌ ๋ฆฌ๋์ ์ ํนํ ์ ํฉํ ์ด์ ๋ ๋ฌด์์ธ๊ฐ์?
3. ์ต๋๊ฐ ๋์
๊ฐ ๋ ์ธ์ ์์ ์ โ์์ญโ์์ ์ต๋๊ฐ์ ์ง์์ ์ ์ง์ ์ผ๋ก ์์๊ฐ์ผ ํฉ๋๋ค.
์๊ณ ๋ฆฌ์ฆ ๊ตฌ์กฐ:
- ์์ ์ ๊ฐ์ผ๋ก ์์
- ๊ฐ ๋จ๊ณ์์ ์ด์์ ๊ฐ๊ณผ ๋น๊ต
- ์ต๋๊ฐ์ ์ ์งํ๊ณ ๊ณ์ ์งํ
ํต์ฌ ํต์ฐฐ: ๊ฐ ๋จ๊ณ ํ, โ์ง์์ ์์ญโ์ด ๋ ๋ฐฐ๋ก ํ์ฅ๋ฉ๋๋ค.
- ๋ง์ง๋ง ๋จ๊ณ ํ: ๊ฐ ๋ ์ธ์ด ์ ์ญ ์ต๋๊ฐ์ ์๊ฒ ๋ฉ๋๋ค
4. ์ด ํจํด์ด ๋์ํ๋ ์ด์
๋ฒํฐํ๋ผ์ด ๋ฆฌ๋์ ์ \(\log_2(\text{WARP_SIZE})\) ๋จ๊ณ ํ์ ๋ค์์ ๋ณด์ฅํฉ๋๋ค:
- ๋ชจ๋ ๋ ์ธ์ด ๋ค๋ฅธ ๋ชจ๋ ๋ ์ธ์ ๊ฐ์ ๊ฐ์ ์ ์ผ๋ก ํ์ธ
- ์ค๋ณต ํต์ ์์: ๊ฐ ํ์ด๊ฐ ๋จ๊ณ๋น ์ ํํ ํ ๋ฒ ๊ตํ
- ์ต์ ๋ณต์ก๋: \(O(n)\) ์์ฐจ ๋น๊ต ๋์ \(O(\log n)\) ๋จ๊ณ
์ถ์ ์์ (4๊ฐ ๋ ์ธ, ๊ฐ [3, 1, 7, 2]):
์ด๊ธฐ ์ํ: Lane 0=3, Lane 1=1, Lane 2=7, Lane 3=2
1๋จ๊ณ (offset=2): 0 โ 2, 1 โ 3
Lane 0: max(3, 7) = 7
Lane 1: max(1, 2) = 2
Lane 2: max(7, 3) = 7
Lane 3: max(2, 1) = 2
2๋จ๊ณ (offset=1): 0 โ 1, 2 โ 3
Lane 0: max(7, 2) = 7
Lane 1: max(2, 7) = 7
Lane 2: max(7, 2) = 7
Lane 3: max(2, 7) = 7
๊ฒฐ๊ณผ: ๋ชจ๋ ๋ ์ธ์ด ์ ์ญ ์ต๋๊ฐ = 7์ ๊ฐ์ง
๋ฒํฐํ๋ผ์ด ๋ณ๋ ฌ ์ต๋๊ฐ ํ ์คํธ:
pixi run p26 --parallel-max
pixi run -e amd p26 --parallel-max
uv run poe p26 --parallel-max
ํ์์ ๋์ ์์ ์ถ๋ ฅ:
WARP_SIZE: 32
SIZE: 32
output: [1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0]
expected: [1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0]
โ
Butterfly parallel max test passed!
์๋ฃจ์
fn butterfly_parallel_max[
layout: Layout, size: Int
](
output: LayoutTensor[dtype, layout, MutAnyOrigin],
input: LayoutTensor[dtype, layout, ImmutAnyOrigin],
):
"""
Parallel maximum reduction using butterfly pattern.
Uses shuffle_xor with decreasing offsets (16, 8, 4, 2, 1) to perform tree-based reduction.
Each step reduces the active range by half until all threads have the maximum value.
This implements an efficient O(log n) parallel reduction algorithm.
"""
global_i = Int(block_dim.x * block_idx.x + thread_idx.x)
if global_i < size:
max_val = input[global_i]
# Butterfly reduction tree: dynamic for any WARP_SIZE (32, 64, etc.)
# Start with half the warp size and reduce by half each step
offset = WARP_SIZE // 2
while offset > 0:
max_val = max(max_val, shuffle_xor(max_val, offset))
offset //= 2
# All threads now have the maximum value across the entire warp
output[global_i] = max_val
์ด ํ์ด๋ shuffle_xor()์ด \(O(\log n)\) ๋ณต์ก๋์ ํจ์จ์ ์ธ ๋ณ๋ ฌ ๋ฆฌ๋์
ํธ๋ฆฌ๋ฅผ ์ด๋ป๊ฒ ์์ฑํ๋์ง ๋ณด์ฌ์ค๋๋ค.
์ ์ฒด ์๊ณ ๋ฆฌ์ฆ ๋ถ์:
if global_i < size:
max_val = input[global_i] # ๋ก์ปฌ ๊ฐ์ผ๋ก ์์
# ๋ฒํฐํ๋ผ์ด ๋ฆฌ๋์
ํธ๋ฆฌ: ๋ชจ๋ WARP_SIZE์ ๋์ ์ผ๋ก ๋์
offset = WARP_SIZE // 2
while offset > 0:
max_val = max(max_val, shuffle_xor(max_val, offset))
offset //= 2
output[global_i] = max_val # ๋ชจ๋ ๋ ์ธ์ด ์ ์ญ ์ต๋๊ฐ์ ๊ฐ์ง
๋ฒํฐํ๋ผ์ด ์คํ ์ถ์ (8-๋ ์ธ ์์ , ๊ฐ [0,2,4,6,8,10,12,1000]):
์ด๊ธฐ ์ํ:
Lane 0: max_val = 0, Lane 1: max_val = 2
Lane 2: max_val = 4, Lane 3: max_val = 6
Lane 4: max_val = 8, Lane 5: max_val = 10
Lane 6: max_val = 12, Lane 7: max_val = 1000
1๋จ๊ณ: shuffle_xor(max_val, 4) - ์ ๋ฐ ๊ตํ
Lane 0โ4: max(0,8)=8, Lane 1โ5: max(2,10)=10
Lane 2โ6: max(4,12)=12, Lane 3โ7: max(6,1000)=1000
Lane 4โ0: max(8,0)=8, Lane 5โ1: max(10,2)=10
Lane 6โ2: max(12,4)=12, Lane 7โ3: max(1000,6)=1000
2๋จ๊ณ: shuffle_xor(max_val, 2) - 1/4 ๊ตํ
Lane 0โ2: max(8,12)=12, Lane 1โ3: max(10,1000)=1000
Lane 2โ0: max(12,8)=12, Lane 3โ1: max(1000,10)=1000
Lane 4โ6: max(8,12)=12, Lane 5โ7: max(10,1000)=1000
Lane 6โ4: max(12,8)=12, Lane 7โ5: max(1000,10)=1000
3๋จ๊ณ: shuffle_xor(max_val, 1) - ํ์ด ๊ตํ
Lane 0โ1: max(12,1000)=1000, Lane 1โ0: max(1000,12)=1000
Lane 2โ3: max(12,1000)=1000, Lane 3โ2: max(1000,12)=1000
Lane 4โ5: max(12,1000)=1000, Lane 5โ4: max(1000,12)=1000
Lane 6โ7: max(12,1000)=1000, Lane 7โ6: max(1000,12)=1000
์ต์ข
๊ฒฐ๊ณผ: ๋ชจ๋ ๋ ์ธ์ max_val = 1000
์ํ์ ํต์ฐฐ: ๋ฒํฐํ๋ผ์ด ํต์ ์ผ๋ก ๋ณ๋ ฌ ๋ฆฌ๋์ ์ฐ์ฐ์๋ฅผ ๊ตฌํํฉ๋๋ค: \[\Large \text{Reduce}(\oplus, [a_0, a_1, \ldots, a_{n-1}]) = a_0 \oplus a_1 \oplus \cdots \oplus a_{n-1}\]
์ฌ๊ธฐ์ \(\oplus\)๋ max ์ฐ์ฐ์ด๋ฉฐ, ๋ฒํฐํ๋ผ์ด ํจํด์ด ์ต์ \(O(\log n)\) ๋ณต์ก๋๋ฅผ ๋ณด์ฅํฉ๋๋ค.
๋ฒํฐํ๋ผ์ด ๋ฆฌ๋์ ์ด ์ฐ์ํ ์ด์ :
- ๋ก๊ทธ ๋ณต์ก๋: ์์ฐจ ๋ฆฌ๋์ ์ \(O(n)\)์ ๋นํด \(O(\log n)\)
- ์๋ฒฝํ ๋ถํ ๋ถ์ฐ: ๋ชจ๋ ๋ ์ธ์ด ๊ฐ ๋จ๊ณ์์ ๋๋ฑํ๊ฒ ์ฐธ์ฌ
- ๋ฉ๋ชจ๋ฆฌ ๋ณ๋ชฉ ์์: ์์ ๋ ์ง์คํฐ ๊ฐ ํต์
- ํ๋์จ์ด ์ต์ ํ: GPU ๋ฒํฐํ๋ผ์ด ๋คํธ์ํฌ์ ์ง์ ๋งคํ
์ฑ๋ฅ ํน์ฑ:
- ๋จ๊ณ ์: \(\log_2(\text{WARP_SIZE})\) (์: 32-์ค๋ ๋ ์ํ๋ 5๋จ๊ณ, 64-์ค๋ ๋ ์ํ๋ 6๋จ๊ณ)
- ๋จ๊ณ๋น ์ง์ฐ ์๊ฐ: 1 ์ฌ์ดํด (๋ ์ง์คํฐ ๊ตํ + ๋น๊ต)
- ์ด ์ง์ฐ ์๊ฐ: ์์ฐจ ๋ฐฉ์์ \((\text{WARP_SIZE}-1)\) ์ฌ์ดํด ๋๋น \(\log_2(\text{WARP_SIZE})\) ์ฌ์ดํด
- ๋ณ๋ ฌ์ฑ: ์๊ณ ๋ฆฌ์ฆ ์ ์ฒด์์ ๋ชจ๋ ๋ ์ธ์ด ํ์ฑ ์ํ
3. ๋ฒํฐํ๋ผ์ด ์กฐ๊ฑด๋ถ ์ต๋๊ฐ
๊ตฌ์ฑ
- ๋ฒกํฐ ํฌ๊ธฐ:
SIZE_2 = 64(๋ฉํฐ ๋ธ๋ก ์๋๋ฆฌ์ค) - ๊ทธ๋ฆฌ๋ ๊ตฌ์ฑ:
BLOCKS_PER_GRID_2 = (2, 1)๊ทธ๋ฆฌ๋๋น ๋ธ๋ก ์ - ๋ธ๋ก ๊ตฌ์ฑ:
THREADS_PER_BLOCK_2 = (WARP_SIZE, 1)๋ธ๋ก๋น ์ค๋ ๋ ์
์์ฑํ ์ฝ๋
์ง์ ๋ ์ธ์ ์ต๋๊ฐ์, ํ์ ๋ ์ธ์ ์ต์๊ฐ์ ์ ์ฅํ๋ ์กฐ๊ฑด๋ถ ๋ฒํฐํ๋ผ์ด ๋ฆฌ๋์ ์ ๊ตฌํํฉ๋๋ค.
์ํ์ ์ฐ์ฐ: ์ต๋๊ฐ๊ณผ ์ต์๊ฐ ๋ชจ๋์ ๋ํด ๋ฒํฐํ๋ผ์ด ๋ฆฌ๋์ ์ ์ํํ ํ, ๋ ์ธ ํ์ง์ ๋ฐ๋ผ ์กฐ๊ฑด๋ถ๋ก ์ถ๋ ฅํฉ๋๋ค: \[\Large \text{output}[i] = \begin{cases} \max_{j=0}^{\text{WARP_SIZE}-1} \text{input}[j] & \text{if } i \bmod 2 = 0 \\ \min_{j=0}^{\text{WARP_SIZE}-1} \text{input}[j] & \text{if } i \bmod 2 = 1 \end{cases}\]
์ด์ค ๋ฆฌ๋์ ํจํด: ๋ฒํฐํ๋ผ์ด ํธ๋ฆฌ๋ฅผ ํตํด ์ต๋๊ฐ๊ณผ ์ต์๊ฐ์ ๋์์ ์ถ์ ํ ํ, ๋ ์ธ ID ํ์ง์ ๋ฐ๋ผ ์กฐ๊ฑด๋ถ๋ก ์ถ๋ ฅํฉ๋๋ค. ์ด๋ ๋ฒํฐํ๋ผ์ด ํจํด์ด ๋ณต์กํ ๋ค์ค ๊ฐ ๋ฆฌ๋์ ์ผ๋ก ์ด๋ป๊ฒ ํ์ฅ๋๋์ง๋ฅผ ๋ณด์ฌ์ค๋๋ค.
comptime SIZE_2 = 64
comptime BLOCKS_PER_GRID_2 = (2, 1)
comptime THREADS_PER_BLOCK_2 = (WARP_SIZE, 1)
comptime layout_2 = Layout.row_major(SIZE_2)
fn butterfly_conditional_max[
layout: Layout, size: Int
](
output: LayoutTensor[dtype, layout, MutAnyOrigin],
input: LayoutTensor[dtype, layout, ImmutAnyOrigin],
):
"""
Conditional butterfly maximum: Perform butterfly max reduction, but only store result
in even-numbered lanes. Odd-numbered lanes store the minimum value seen.
Demonstrates conditional logic combined with butterfly communication patterns.
"""
global_i = Int(block_dim.x * block_idx.x + thread_idx.x)
lane = lane_id()
if global_i < size:
current_val = input[global_i]
min_val = current_val
# FILL ME IN (roughly 11 lines)
ํ
1. ์ด์ค ์ถ์ ๋ฒํฐํ๋ผ์ด ๋ฆฌ๋์
์ด ํผ์ฆ์ ๋ฒํฐํ๋ผ์ด ํธ๋ฆฌ๋ฅผ ํตํด ๋ ๊ฐ์ง ๋ค๋ฅธ ๊ฐ์ ๋์์ ์ถ์ ํด์ผ ํฉ๋๋ค. ์ฌ๋ฌ ๋ฆฌ๋์ ์ ๋ณ๋ ฌ๋ก ์คํํ๋ ๋ฐฉ๋ฒ์ ์๊ฐํด ๋ณด์ธ์.
ํต์ฌ ์ง๋ฌธ:
- ๋ฆฌ๋์ ๊ณผ์ ์์ ์ต๋๊ฐ๊ณผ ์ต์๊ฐ์ ์ด๋ป๊ฒ ๋์์ ์ ์งํ ์ ์๋์?
- ๋ ์ฐ์ฐ์ ๊ฐ์ ๋ฒํฐํ๋ผ์ด ํจํด์ ์ฌ์ฉํ ์ ์๋์?
- ์ด๋ค ๋ณ์๋ฅผ ์ถ์ ํด์ผ ํ๋์?
2. ์กฐ๊ฑด๋ถ ์ถ๋ ฅ ๋ก์ง
๋ฒํฐํ๋ผ์ด ๋ฆฌ๋์ ์ ์๋ฃํ ํ, ๋ ์ธ ํ์ง์ ๋ฐ๋ผ ๋ค๋ฅธ ๊ฐ์ ์ถ๋ ฅํด์ผ ํฉ๋๋ค.
๊ณ ๋ คํ ์ :
- ๋ ์ธ์ด ์ง์์ธ์ง ํ์์ธ์ง ์ด๋ป๊ฒ ํ๋ณํ๋์?
- ์ด๋ค ๋ ์ธ์ด ์ต๋๊ฐ์, ์ด๋ค ๋ ์ธ์ด ์ต์๊ฐ์ ์ถ๋ ฅํด์ผ ํ๋์?
- ๋ ์ธ ID์ ์ด๋ป๊ฒ ์ ๊ทผํ๋์?
3. min๊ณผ max ๋์ ๋ฒํฐํ๋ผ์ด ๋ฆฌ๋์
์ด ๊ณผ์ ์ ํต์ฌ์ ๊ฐ์ ๋ฒํฐํ๋ผ์ด ํต์ ํจํด์ผ๋ก min๊ณผ max๋ฅผ ํจ์จ์ ์ผ๋ก ๋ณ๋ ฌ ๊ณ์ฐํ๋ ๊ฒ์ ๋๋ค.
์๊ฐํด ๋ณด์ธ์:
- min๊ณผ max์ ๋ณ๋์ ์ ํ ์ฐ์ฐ์ด ํ์ํ๊ฐ์?
- ๋ ์ฐ์ฐ์ ๊ฐ์ ์ด์ ๊ฐ์ ์ฌ์ฌ์ฉํ ์ ์๋์?
- ๋ ๋ฆฌ๋์ ๋ชจ๋ ์ฌ๋ฐ๋ฅด๊ฒ ์๋ฃ๋๋ ค๋ฉด ์ด๋ป๊ฒ ํด์ผ ํ๋์?
4. ๋ฉํฐ ๋ธ๋ก ๊ฒฝ๊ณ ๊ณ ๋ ค์ฌํญ
์ด ํผ์ฆ์ ์ฌ๋ฌ ๋ธ๋ก์ ์ฌ์ฉํฉ๋๋ค. ์ด๊ฒ์ด ๋ฆฌ๋์ ๋ฒ์์ ์ด๋ค ์ํฅ์ ๋ฏธ์น๋์ง ์๊ฐํด ๋ณด์ธ์.
์ค์ํ ๊ณ ๋ ค์ฌํญ:
- ๊ฐ ๋ฒํฐํ๋ผ์ด ๋ฆฌ๋์ ์ ๋ฒ์๋ ์ด๋๊น์ง์ธ๊ฐ์?
- ๋ธ๋ก ๊ตฌ์กฐ๊ฐ ๋ ์ธ ๋ฒํธ ๋งค๊ธฐ๊ธฐ์ ์ด๋ค ์ํฅ์ ๋ฏธ์น๋์?
- ์ ์ญ min/max๋ฅผ ๊ณ์ฐํ๋์, ๋ธ๋ก๋ณ min/max๋ฅผ ๊ณ์ฐํ๋์?
๋ฒํฐํ๋ผ์ด ์กฐ๊ฑด๋ถ ์ต๋๊ฐ ํ ์คํธ:
pixi run p26 --conditional-max
pixi run -e amd p26 --conditional-max
uv run poe p26 --conditional-max
ํ์์ ๋์ ์์ ์ถ๋ ฅ:
WARP_SIZE: 32
SIZE_2: 64
output: [9.0, 0.0, 9.0, 0.0, 9.0, 0.0, 9.0, 0.0, 9.0, 0.0, 9.0, 0.0, 9.0, 0.0, 9.0, 0.0, 9.0, 0.0, 9.0, 0.0, 9.0, 0.0, 9.0, 0.0, 9.0, 0.0, 9.0, 0.0, 9.0, 0.0, 9.0, 0.0, 63.0, 32.0, 63.0, 32.0, 63.0, 32.0, 63.0, 32.0, 63.0, 32.0, 63.0, 32.0, 63.0, 32.0, 63.0, 32.0, 63.0, 32.0, 63.0, 32.0, 63.0, 32.0, 63.0, 32.0, 63.0, 32.0, 63.0, 32.0, 63.0, 32.0, 63.0, 32.0]
expected: [9.0, 0.0, 9.0, 0.0, 9.0, 0.0, 9.0, 0.0, 9.0, 0.0, 9.0, 0.0, 9.0, 0.0, 9.0, 0.0, 9.0, 0.0, 9.0, 0.0, 9.0, 0.0, 9.0, 0.0, 9.0, 0.0, 9.0, 0.0, 9.0, 0.0, 9.0, 0.0, 63.0, 32.0, 63.0, 32.0, 63.0, 32.0, 63.0, 32.0, 63.0, 32.0, 63.0, 32.0, 63.0, 32.0, 63.0, 32.0, 63.0, 32.0, 63.0, 32.0, 63.0, 32.0, 63.0, 32.0, 63.0, 32.0, 63.0, 32.0, 63.0, 32.0, 63.0, 32.0]
โ
Butterfly conditional max test passed!
์๋ฃจ์
fn butterfly_conditional_max[
layout: Layout, size: Int
](
output: LayoutTensor[dtype, layout, MutAnyOrigin],
input: LayoutTensor[dtype, layout, ImmutAnyOrigin],
):
"""
Conditional butterfly maximum: Perform butterfly max reduction, but only store result
in even-numbered lanes. Odd-numbered lanes store the minimum value seen.
Demonstrates conditional logic combined with butterfly communication patterns.
"""
global_i = Int(block_dim.x * block_idx.x + thread_idx.x)
lane = lane_id()
if global_i < size:
current_val = input[global_i]
min_val = current_val
# Butterfly reduction for both maximum and minimum: dynamic for any WARP_SIZE
offset = WARP_SIZE // 2
while offset > 0:
neighbor_val = shuffle_xor(current_val, offset)
current_val = max(current_val, neighbor_val)
min_neighbor_val = shuffle_xor(min_val, offset)
min_val = min(min_val, min_neighbor_val)
offset //= 2
# Conditional output: max for even lanes, min for odd lanes
if lane % 2 == 0:
output[global_i] = current_val # Maximum
else:
output[global_i] = min_val # Minimum
์ด ํ์ด๋ ์ด์ค ์ถ์ ๊ณผ ์กฐ๊ฑด๋ถ ์ถ๋ ฅ์ ์ฌ์ฉํ๋ ๊ณ ๊ธ ๋ฒํฐํ๋ผ์ด ๋ฆฌ๋์ ์ ๋ณด์ฌ์ค๋๋ค.
์ ์ฒด ์๊ณ ๋ฆฌ์ฆ ๋ถ์:
if global_i < size:
current_val = input[global_i]
min_val = current_val # ์ต์๊ฐ์ ๋ณ๋๋ก ์ถ์
# max์ min ๋์ ๋ฒํฐํ๋ผ์ด ๋ฆฌ๋์
(log_2(WARP_SIZE) ๋จ๊ณ)
offset = WARP_SIZE // 2
while offset > 0:
neighbor_val = shuffle_xor(current_val, offset)
current_val = max(current_val, neighbor_val) # Max ๋ฆฌ๋์
min_neighbor_val = shuffle_xor(min_val, offset)
min_val = min(min_val, min_neighbor_val) # Min ๋ฆฌ๋์
offset //= 2
# ๋ ์ธ ํ์ง์ ๋ฐ๋ฅธ ์กฐ๊ฑด๋ถ ์ถ๋ ฅ
if lane % 2 == 0:
output[global_i] = current_val # ์ง์ ๋ ์ธ: ์ต๋๊ฐ
else:
output[global_i] = min_val # ํ์ ๋ ์ธ: ์ต์๊ฐ
์ด์ค ๋ฆฌ๋์ ์คํ ์ถ์ (4-๋ ์ธ ์์ , ๊ฐ [3, 1, 7, 2]):
์ด๊ธฐ ์ํ:
Lane 0: current_val=3, min_val=3
Lane 1: current_val=1, min_val=1
Lane 2: current_val=7, min_val=7
Lane 3: current_val=2, min_val=2
1๋จ๊ณ: shuffle_xor(current_val, 2)์ shuffle_xor(min_val, 2) - ์ ๋ฐ ๊ตํ
Lane 0โ2: max_neighbor=7, min_neighbor=7 โ current_val=max(3,7)=7, min_val=min(3,7)=3
Lane 1โ3: max_neighbor=2, min_neighbor=2 โ current_val=max(1,2)=2, min_val=min(1,2)=1
Lane 2โ0: max_neighbor=3, min_neighbor=3 โ current_val=max(7,3)=7, min_val=min(7,3)=3
Lane 3โ1: max_neighbor=1, min_neighbor=1 โ current_val=max(2,1)=2, min_val=min(2,1)=1
2๋จ๊ณ: shuffle_xor(current_val, 1)์ shuffle_xor(min_val, 1) - ํ์ด ๊ตํ
Lane 0โ1: max_neighbor=2, min_neighbor=1 โ current_val=max(7,2)=7, min_val=min(3,1)=1
Lane 1โ0: max_neighbor=7, min_neighbor=3 โ current_val=max(2,7)=7, min_val=min(1,3)=1
Lane 2โ3: max_neighbor=2, min_neighbor=1 โ current_val=max(7,2)=7, min_val=min(3,1)=1
Lane 3โ2: max_neighbor=7, min_neighbor=3 โ current_val=max(2,7)=7, min_val=min(1,3)=1
์ต์ข
๊ฒฐ๊ณผ: ๋ชจ๋ ๋ ์ธ์ด current_val=7 (์ ์ญ max)๊ณผ min_val=1 (์ ์ญ min)์ ๊ฐ์ง
๋์ ์๊ณ ๋ฆฌ์ฆ (๋ชจ๋ WARP_SIZE์์ ๋์):
offset = WARP_SIZE // 2
while offset > 0:
neighbor_val = shuffle_xor(current_val, offset)
current_val = max(current_val, neighbor_val)
min_neighbor_val = shuffle_xor(min_val, offset)
min_val = min(min_val, min_neighbor_val)
offset //= 2
์ํ์ ํต์ฐฐ: ์กฐ๊ฑด๋ถ ๋๋ฉํฐํ๋ ์ฑ์ ์ฌ์ฉํ๋ ์ด์ค ๋ณ๋ ฌ ๋ฆฌ๋์ ์ ๊ตฌํํฉ๋๋ค: \[\Large \begin{align} \text{max_result} &= \max_{i=0}^{n-1} \text{input}[i] \\ \text{min_result} &= \min_{i=0}^{n-1} \text{input}[i] \\ \text{output}[i] &= \text{lane_parity}(i) \; \text{?} \; \text{min_result} : \text{max_result} \end{align}\]
์ด์ค ๋ฒํฐํ๋ผ์ด ๋ฆฌ๋์ ์ด ๋์ํ๋ ์ด์ :
- ๋ ๋ฆฝ์ ๋ฆฌ๋์ : Max์ min ๋ฆฌ๋์ ์ ์ํ์ ์ผ๋ก ๋ ๋ฆฝ
- ๋ณ๋ ฌ ์คํ: ๋ ๋ค ๊ฐ์ ๋ฒํฐํ๋ผ์ด ํต์ ํจํด์ ์ฌ์ฉ ๊ฐ๋ฅ
- ํต์ ๊ณต์ : ๊ฐ์ ์ ํ ์ฐ์ฐ์ด ๋ ๋ฆฌ๋์ ๋ชจ๋์ ํ์ฉ
- ์กฐ๊ฑด๋ถ ์ถ๋ ฅ: ๋ ์ธ ํ์ง์ด ์ด๋ค ๊ฒฐ๊ณผ๋ฅผ ์ถ๋ ฅํ ์ง ๊ฒฐ์
์ฑ๋ฅ ํน์ฑ:
- ํต์ ๋จ๊ณ: \(\log_2(\text{WARP_SIZE})\) (๋จ์ผ ๋ฆฌ๋์ ๊ณผ ๋์ผ)
- ๋จ๊ณ๋น ์ฐ์ฐ: ๋จ์ผ ๋ฆฌ๋์ ์ 1๊ฐ ๋๋น 2๊ฐ ์ฐ์ฐ (max + min)
- ๋ฉ๋ชจ๋ฆฌ ํจ์จ์ฑ: ๋ณต์กํ ๊ณต์ ๋ฉ๋ชจ๋ฆฌ ๋ฐฉ์ ๋๋น ์ค๋ ๋๋น ๋ ์ง์คํฐ 2๊ฐ
- ์ถ๋ ฅ ์ ์ฐ์ฑ: ์๋ก ๋ค๋ฅธ ๋ ์ธ์ด ๋ค๋ฅธ ๋ฆฌ๋์ ๊ฒฐ๊ณผ๋ฅผ ์ถ๋ ฅ ๊ฐ๋ฅ
์์ฝ
shuffle_xor() ๊ธฐ๋ณธ ์์๋ ํจ์จ์ ์ธ ๋ณ๋ ฌ ์๊ณ ๋ฆฌ์ฆ์ ๊ธฐ๋ฐ์ด ๋๋ ๊ฐ๋ ฅํ ๋ฒํฐํ๋ผ์ด ํต์ ํจํด์ ๊ฐ๋ฅํ๊ฒ ํฉ๋๋ค. ์ธ ๊ฐ์ง ๋ฌธ์ ๋ฅผ ํตํด ๋ค์์ ๋ฐฐ์ ์ต๋๋ค:
ํต์ฌ ๋ฒํฐํ๋ผ์ด ํจํด
-
ํ์ด ๊ตํ (
shuffle_xor(value, 1)):- ์๋ฒฝํ ์ธ์ ํ์ด ์์ฑ: (0,1), (2,3), (4,5), โฆ
- ๋ฉ๋ชจ๋ฆฌ ์ค๋ฒํค๋ ์ ๋ก์ \(O(1)\) ๋ณต์ก๋
- ์ ๋ ฌ ๋คํธ์ํฌ์ ๋ฐ์ดํฐ ์ฌ๋ฐฐ์น์ ๊ธฐ๋ฐ
-
ํธ๋ฆฌ ๋ฆฌ๋์ (๋์ offset:
WARP_SIZE/2โ1):- ๋ก๊ทธ ๋ณ๋ ฌ ๋ฆฌ๋์ : ์์ฐจ์ \(O(n)\) ๋๋น \(O(\log n)\)
- ๋ชจ๋ ๊ฒฐํฉ ์ฐ์ฐ์ ์ ์ฉ ๊ฐ๋ฅ (max, min, sum ๋ฑ)
- ๋ชจ๋ ์ํ ๋ ์ธ์ ๊ฑธ์ณ ์ต์ ์ ๋ถํ ๋ถ์ฐ
-
์กฐ๊ฑด๋ถ ๋ค์ค ๋ฆฌ๋์ (์ด์ค ์ถ์ + ๋ ์ธ ํ์ง):
- ์ฌ๋ฌ ๋ฆฌ๋์ ์ ๋์์ ๋ณ๋ ฌ ์ํ
- ์ค๋ ๋ ํน์ฑ์ ๋ฐ๋ฅธ ์กฐ๊ฑด๋ถ ์ถ๋ ฅ
- ๋ช ์์ ๋๊ธฐํ ์๋ ๊ณ ๊ธ ์กฐ์
ํต์ฌ ์๊ณ ๋ฆฌ์ฆ ํต์ฐฐ
XOR ํต์ ํน์ฑ:
shuffle_xor(value, mask)๊ฐ ๋์นญ์ ์ด๊ณ ๊ฒน์น์ง ์๋ ํ์ด๋ฅผ ์์ฑ- ๊ฐ ๋ง์คํฌ๊ฐ ๊ณ ์ ํ ํต์ ํ ํด๋ก์ง๋ฅผ ์์ฑ
- ์ด์ง XOR ํจํด์์ ๋ฒํฐํ๋ผ์ด ๋คํธ์ํฌ๊ฐ ์์ฐ์ค๋ฝ๊ฒ ๋์ถ
๋์ ์๊ณ ๋ฆฌ์ฆ ์ค๊ณ:
offset = WARP_SIZE // 2
while offset > 0:
neighbor_val = shuffle_xor(current_val, offset)
current_val = operation(current_val, neighbor_val)
offset //= 2
์ฑ๋ฅ ์ด์ :
- ํ๋์จ์ด ์ต์ ํ: ๋ ์ง์คํฐ ๊ฐ ์ง์ ํต์
- ๋๊ธฐํ ๋ถํ์: SIMT ์คํ์ด ์ ํ์ฑ์ ๋ณด์ฅ
- ํ์ฅ ๊ฐ๋ฅํ ๋ณต์ก๋: ๋ชจ๋ WARP_SIZE (32, 64 ๋ฑ)์์ \(O(\log n)\)
- ๋ฉ๋ชจ๋ฆฌ ํจ์จ์ฑ: ๊ณต์ ๋ฉ๋ชจ๋ฆฌ ๋ถํ์
์ค์ฉ์ ํ์ฉ
์ด ๋ฒํฐํ๋ผ์ด ํจํด๋ค์ ๊ธฐ๋ฐ์ด ๋๋ ๋ถ์ผ:
- ๋ณ๋ ฌ ๋ฆฌ๋์ : ํฉ๊ณ, max, min, ๋ ผ๋ฆฌ ์ฐ์ฐ
- ๋์ ํฉ/์ค์บ ์ฐ์ฐ: ๋์ ํฉ, ๋ณ๋ ฌ ์ ๋ ฌ
- FFT ์๊ณ ๋ฆฌ์ฆ: ์ ํธ ์ฒ๋ฆฌ์ ํฉ์ฑ๊ณฑ
- Bitonic ์ ๋ ฌ: ๋ณ๋ ฌ ์ ๋ ฌ ๋คํธ์ํฌ
- ๊ทธ๋ํ ์๊ณ ๋ฆฌ์ฆ: ํธ๋ฆฌ ์ํ์ ์ฐ๊ฒฐ์ฑ
shuffle_xor() ๊ธฐ๋ณธ ์์๋ ๋ณต์กํ ๋ณ๋ ฌ ์กฐ์ ์ ์ฐ์ํ๊ณ ํ๋์จ์ด ์ต์ ํ๋ ํต์ ํจํด์ผ๋ก ๋ณํํ๋ฉฐ, ๋ค์ํ GPU ์ํคํ
์ฒ์์ ํจ์จ์ ์ผ๋ก ํ์ฅ๋ฉ๋๋ค.