warp.shuffle_xor() ๋ฒํฐํ๋ผ์ด ํต์
์ํ ๋ ๋ฒจ ๋ฒํฐํ๋ผ์ด ํต์ ์์๋ shuffle_xor()์ ์ฌ์ฉํ์ฌ ์ํ ๋ด์ ์ ๊ตํ ํธ๋ฆฌ
๊ธฐ๋ฐ ํต์ ํจํด์ ๊ตฌ์ฑํ ์ ์์ต๋๋ค. ์ด ๊ฐ๋ ฅํ ๊ธฐ๋ณธ ์์๋ฅผ ํตํด ๊ณต์ ๋ฉ๋ชจ๋ฆฌ๋
๋ช
์์ ๋๊ธฐํ ์์ด ํจ์จ์ ์ธ ๋ณ๋ ฌ ๋ฆฌ๋์
, ์ ๋ ฌ ๋คํธ์ํฌ, ๊ณ ๊ธ ์กฐ์ ์๊ณ ๋ฆฌ์ฆ์
๊ตฌํํ ์ ์์ต๋๋ค.
ํต์ฌ ํต์ฐฐ: shuffle_xor() ์ฐ์ฐ์ SIMT ์คํ์ ํ์ฉํ์ฌ XOR ๊ธฐ๋ฐ ํต์ ํธ๋ฆฌ๋ฅผ ์์ฑํ๋ฉฐ, ์ํ ํฌ๊ธฐ์ ๋ํด \(O(\log n)\) ๋ณต์ก๋๋ก ํ์ฅ๋๋ ํจ์จ์ ์ธ ๋ฒํฐํ๋ผ์ด ๋คํธ์ํฌ์ ๋ณ๋ ฌ ์๊ณ ๋ฆฌ์ฆ์ ๊ฐ๋ฅํ๊ฒ ํฉ๋๋ค.
๋ฒํฐํ๋ผ์ด ๋คํธ์ํฌ๋? ๋ฒํฐํ๋ผ์ด ๋คํธ์ํฌ๋ ์ค๋ ๋๋ค์ด ์ธ๋ฑ์ค์ XOR ํจํด์ ๋ฐ๋ผ ๋ฐ์ดํฐ๋ฅผ ๊ตํํ๋ ํต์ ํ ํด๋ก์ง์ ๋๋ค. ์ด๋ฆ์ ์๊ฐ์ ์ผ๋ก ๊ทธ๋ ธ์ ๋ ๋๋น ๋ ๊ฐ์ฒ๋ผ ๋ณด์ด๋ ์ฐ๊ฒฐ ํจํด์์ ์ ๋ํ์ต๋๋ค. ์ด ๋คํธ์ํฌ๋ \(O(\log n)\) ํต์ ๋ณต์ก๋๋ฅผ ๊ฐ๋ฅํ๊ฒ ํ๊ธฐ ๋๋ฌธ์ FFT, bitonic ์ ๋ ฌ, ๋ณ๋ ฌ ๋ฆฌ๋์ ๊ฐ์ ๋ณ๋ ฌ ์๊ณ ๋ฆฌ์ฆ์ ๊ธฐ๋ฐ์ด ๋ฉ๋๋ค.
ํต์ฌ ๊ฐ๋
์ด ํผ์ฆ์์ ๋ฐฐ์ธ ๋ด์ฉ:
shuffle_xor()์ ํ์ฉํ XOR ๊ธฐ๋ฐ ํต์ ํจํด- ๋ณ๋ ฌ ์๊ณ ๋ฆฌ์ฆ์ ์ํ ๋ฒํฐํ๋ผ์ด ๋คํธ์ํฌ ํ ํด๋ก์ง
- \(O(\log n)\) ๋ณต์ก๋์ ํธ๋ฆฌ ๊ธฐ๋ฐ ๋ณ๋ ฌ ๋ฆฌ๋์
- ๊ณ ๊ธ ์กฐ์ ์ ์ํ ์กฐ๊ฑด๋ถ ๋ฒํฐํ๋ผ์ด ์ฐ์ฐ
- ๋ณต์กํ ๊ณต์ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ๋์ฒดํ๋ ํ๋์จ์ด ์ต์ ํ ๋ณ๋ ฌ ๊ธฐ๋ณธ ์์
shuffle_xor ์ฐ์ฐ์ ๊ฐ ๋ ์ธ์ด XOR
ํจํด์ ๋ฐ๋ผ ๋ค๋ฅธ ๋ ์ธ๊ณผ ๋ฐ์ดํฐ๋ฅผ ๊ตํํ ์ ์๊ฒ ํฉ๋๋ค: \[\Large
\text{shuffle_xor}(\text{value}, \text{mask}) =
\text{value_from_lane}(\text{lane_id} \oplus \text{mask})\]
์ด๋ฅผ ํตํด ๋ณต์กํ ๋ณ๋ ฌ ์๊ณ ๋ฆฌ์ฆ์ด ์ฐ์ํ ๋ฒํฐํ๋ผ์ด ํต์ ํจํด์ผ๋ก ๋ณํ๋์ด, ๋ช ์์ ์กฐ์ ์์ด ํจ์จ์ ์ธ ํธ๋ฆฌ ๋ฆฌ๋์ ๊ณผ ์ ๋ ฌ ๋คํธ์ํฌ๊ฐ ๊ฐ๋ฅํฉ๋๋ค.
1. ๊ธฐ๋ณธ ๋ฒํฐํ๋ผ์ด ํ์ด ๊ตํ
๊ตฌ์ฑ
- ๋ฒกํฐ ํฌ๊ธฐ:
SIZE = WARP_SIZE(GPU์ ๋ฐ๋ผ 32 ๋๋ 64) - ๊ทธ๋ฆฌ๋ ๊ตฌ์ฑ:
(1, 1)๊ทธ๋ฆฌ๋๋น ๋ธ๋ก ์ - ๋ธ๋ก ๊ตฌ์ฑ:
(WARP_SIZE, 1)๋ธ๋ก๋น ์ค๋ ๋ ์ - ๋ฐ์ดํฐ ํ์
:
DType.float32 - ๋ ์ด์์:
row_major[SIZE]()(1D row-major)
shuffle_xor ๊ฐ๋
๊ธฐ์กด ํ์ด ๊ตํ ๋ฐฉ์์ ๋ณต์กํ ์ธ๋ฑ์ฑ๊ณผ ์กฐ์ ์ด ํ์ํฉ๋๋ค:
# ๊ธฐ์กด ๋ฐฉ์ - ๋ณต์กํ๊ณ ๋๊ธฐํ๊ฐ ํ์
shared_memory[lane] = input[global_i]
barrier()
if lane % 2 == 0:
partner = lane + 1
else:
partner = lane - 1
if partner < WARP_SIZE:
swapped_val = shared_memory[partner]
๊ธฐ์กด ๋ฐฉ์์ ๋ฌธ์ ์ :
- ๋ฉ๋ชจ๋ฆฌ ์ค๋ฒํค๋: ๊ณต์ ๋ฉ๋ชจ๋ฆฌ ํ ๋น์ด ํ์
- ๋๊ธฐํ: ๋ช ์์ ๋ฐฐ๋ฆฌ์ด๊ฐ ํ์
- ๋ณต์กํ ๋ก์ง: ์๋ ํํธ๋ ๊ณ์ฐ๊ณผ ๊ฒฝ๊ณ ๊ฒ์ฌ
- ๋ฎ์ ํ์ฅ์ฑ: ํ๋์จ์ด ํต์ ์ ํ์ฉํ์ง ๋ชปํจ
shuffle_xor()์ ์ฌ์ฉํ๋ฉด ํ์ด ๊ตํ์ด ์ฐ์ํด์ง๋๋ค:
# ๋ฒํฐํ๋ผ์ด XOR ๋ฐฉ์ - ๊ฐ๋จํ๊ณ ํ๋์จ์ด ์ต์ ํ
current_val = input[global_i]
swapped_val = shuffle_xor(current_val, 1) # 1๊ณผ XORํ๋ฉด ํ์ด๊ฐ ์์ฑ๋จ
output[global_i] = swapped_val
shuffle_xor์ ์ฅ์ :
- ๋ฉ๋ชจ๋ฆฌ ์ค๋ฒํค๋ ์ ๋ก: ๋ ์ง์คํฐ ๊ฐ ์ง์ ํต์
- ๋๊ธฐํ ๋ถํ์: SIMT ์คํ์ด ์ ํ์ฑ์ ๋ณด์ฅ
- ํ๋์จ์ด ์ต์ ํ: ๋ชจ๋ ๋ ์ธ์ ๋ํด ๋จ์ผ ๋ช ๋ น์ผ๋ก ์ฒ๋ฆฌ
- ๋ฒํฐํ๋ผ์ด ๊ธฐ๋ฐ: ๋ณต์กํ ๋ณ๋ ฌ ์๊ณ ๋ฆฌ์ฆ์ ๋น๋ฉ ๋ธ๋ก
์์ฑํ ์ฝ๋
shuffle_xor()์ ์ฌ์ฉํ์ฌ ์ธ์ ํ์ด ๊ฐ ๊ฐ์ ๊ตํํ๋ ํ์ด ๊ตํ์ ๊ตฌํํฉ๋๋ค.
์ํ์ ์ฐ์ฐ: XOR ํจํด์ผ๋ก ์ธ์ ํ์ด๋ฅผ ๋ง๋ค์ด ๊ฐ์ ๊ตํํฉ๋๋ค: \[\Large \text{output}[i] = \text{input}[i \oplus 1]\]
์
๋ ฅ ๋ฐ์ดํฐ [0, 1, 2, 3, 4, 5, 6, 7, ...]์ ํ์ด
[1, 0, 3, 2, 5, 4, 7, 6, ...]์ผ๋ก ๋ณํํ๋ฉฐ, ๊ฐ ํ์ด (i, i+1)์ด XOR ํต์ ์ผ๋ก
๊ฐ์ ๊ตํํฉ๋๋ค.
์ ์ฒด ํ์ผ ๋ณด๊ธฐ: problems/p26/p26.mojo
ํ
1. shuffle_xor ์ดํดํ๊ธฐ
shuffle_xor(value, mask) ์ฐ์ฐ์ ๊ฐ ๋ ์ธ์ด XOR ๋ง์คํฌ๋งํผ ์ฐจ์ด๋๋ ๋ ์ธ๊ณผ
๋ฐ์ดํฐ๋ฅผ ๊ตํํ ์ ์๊ฒ ํฉ๋๋ค. ์๋ก ๋ค๋ฅธ ๋ง์คํฌ ๊ฐ์ผ๋ก ๋ ์ธ ID๋ฅผ XORํ์ ๋
์ด๋ค ์ผ์ด ์ผ์ด๋๋์ง ์๊ฐํด ๋ณด์ธ์.
ํ๊ตฌํ ํต์ฌ ์ง๋ฌธ:
- ๋ ์ธ 0์ด ๋ง์คํฌ 1๋ก XORํ๋ฉด ์ด๋ค ํํธ๋๋ฅผ ์ป๋์?
- ๋ ์ธ 1์ด ๋ง์คํฌ 1๋ก XORํ๋ฉด ์ด๋ค ํํธ๋๋ฅผ ์ป๋์?
- ํจํด์ด ๋ณด์ด๋์?
ํํธ: ์ฒ์ ๋ช ๊ฐ์ ๋ ์ธ ID์ ๋ํด XOR ์ฐ์ฐ์ ์ง์ ํด๋ณด๋ฉด ํ์ด๋ง ํจํด์ ์ดํดํ ์ ์์ต๋๋ค.
2. XOR ํ์ด ํจํด
๋ ์ธ ID์ ์ด์ง ํํ๊ณผ ์ตํ์ ๋นํธ๋ฅผ ๋ค์ง์ผ๋ฉด ์ด๋ป๊ฒ ๋๋์ง ์๊ฐํด ๋ณด์ธ์.
๊ณ ๋ คํ ์ง๋ฌธ:
- ์ง์ ๋ ์ธ์ 1๊ณผ XORํ๋ฉด ์ด๋ป๊ฒ ๋๋์?
- ํ์ ๋ ์ธ์ 1๊ณผ XORํ๋ฉด ์ด๋ป๊ฒ ๋๋์?
- ์ ์ด๊ฒ์ด ์๋ฒฝํ ํ์ด๋ฅผ ๋ง๋๋์?
3. ๊ฒฝ๊ณ ๊ฒ์ฌ ๋ถํ์
shuffle_down()๊ณผ ๋ฌ๋ฆฌ shuffle_xor() ์ฐ์ฐ์ ์ํ ๊ฒฝ๊ณ ๋ด์์ ์ ์ง๋ฉ๋๋ค. ์์
๋ง์คํฌ๋ก์ XOR์ด ์ ๋๋ก ๋ฒ์ ๋ฐ์ ๋ ์ธ ID๋ฅผ ๋ง๋ค์ง ์๋ ์ด์ ๋ฅผ ์๊ฐํด ๋ณด์ธ์.
์๊ฐํด ๋ณด์ธ์: ์ ํจํ ๋ ์ธ ID๋ฅผ 1๊ณผ XORํ์ ๋ ๋์ฌ ์ ์๋ ์ต๋ ๋ ์ธ ID๋ ์ผ๋ง์ธ๊ฐ์?
๋ฒํฐํ๋ผ์ด ํ์ด ๊ตํ ํ ์คํธ:
pixi run p26 --pair-swap
pixi run -e amd p26 --pair-swap
pixi run -e apple p26 --pair-swap
uv run poe p26 --pair-swap
ํ์์ ๋์ ์์ ์ถ๋ ฅ:
WARP_SIZE: 32
SIZE: 32
output: [1.0, 0.0, 3.0, 2.0, 5.0, 4.0, 7.0, 6.0, 9.0, 8.0, 11.0, 10.0, 13.0, 12.0, 15.0, 14.0, 17.0, 16.0, 19.0, 18.0, 21.0, 20.0, 23.0, 22.0, 25.0, 24.0, 27.0, 26.0, 29.0, 28.0, 31.0, 30.0]
expected: [1.0, 0.0, 3.0, 2.0, 5.0, 4.0, 7.0, 6.0, 9.0, 8.0, 11.0, 10.0, 13.0, 12.0, 15.0, 14.0, 17.0, 16.0, 19.0, 18.0, 21.0, 20.0, 23.0, 22.0, 25.0, 24.0, 27.0, 26.0, 29.0, 28.0, 31.0, 30.0]
โ
Butterfly pair swap test passed!
์๋ฃจ์
def butterfly_pair_swap[
size: Int
](
output: TileTensor[mut=True, dtype, LayoutType, MutAnyOrigin],
input: TileTensor[mut=False, dtype, LayoutType, ImmutAnyOrigin],
):
"""
Basic butterfly pair swap: Exchange values between adjacent pairs using XOR pattern.
Each thread exchanges its value with its XOR-1 neighbor, creating pairs: (0,1), (2,3), (4,5), etc.
Uses shuffle_xor(val, 1) to swap values within each pair.
This is the foundation of butterfly network communication patterns.
"""
var global_i = block_dim.x * block_idx.x + thread_idx.x
if global_i < size:
var current_val = input[global_i]
# Exchange with XOR-1 neighbor using butterfly pattern
# Lane 0 exchanges with lane 1, lane 2 with lane 3, etc.
var swapped_val = shuffle_xor(current_val, 1)
# For demonstration, we'll store the swapped value
# In real applications, this might be used for sorting, reduction, etc.
output[global_i] = swapped_val
์ด ํ์ด๋ shuffle_xor()์ด XOR ํต์ ํจํด์ ํตํด ์๋ฒฝํ ํ์ด ๊ตํ์ ์ด๋ป๊ฒ
๋ง๋๋์ง ๋ณด์ฌ์ค๋๋ค.
์๊ณ ๋ฆฌ์ฆ ๋ถ์:
if global_i < size:
current_val = input[global_i] # ๊ฐ ๋ ์ธ์ด ์์ ์ ์์๋ฅผ ์ฝ์
swapped_val = shuffle_xor(current_val, 1) # XOR๋ก ํ์ด ๊ตํ ์์ฑ
# ๊ตํ๋ ๊ฐ์ ์ ์ฅ
output[global_i] = swapped_val
SIMT ์คํ ์์ธ ๋ถ์:
์ฌ์ดํด 1: ๋ชจ๋ ๋ ์ธ์ด ๋์์ ๊ฐ์ ๋ก๋
Lane 0: current_val = input[0] = 0
Lane 1: current_val = input[1] = 1
Lane 2: current_val = input[2] = 2
Lane 3: current_val = input[3] = 3
...
Lane 31: current_val = input[31] = 31
์ฌ์ดํด 2: shuffle_xor(current_val, 1)์ด ๋ชจ๋ ๋ ์ธ์์ ์คํ
Lane 0: Lane 1์์ ์์ (0โ1=1) โ swapped_val = 1
Lane 1: Lane 0์์ ์์ (1โ1=0) โ swapped_val = 0
Lane 2: Lane 3์์ ์์ (2โ1=3) โ swapped_val = 3
Lane 3: Lane 2์์ ์์ (3โ1=2) โ swapped_val = 2
...
Lane 30: Lane 31์์ ์์ (30โ1=31) โ swapped_val = 31
Lane 31: Lane 30์์ ์์ (31โ1=30) โ swapped_val = 30
์ฌ์ดํด 3: ๊ฒฐ๊ณผ ์ ์ฅ
Lane 0: output[0] = 1
Lane 1: output[1] = 0
Lane 2: output[2] = 3
Lane 3: output[3] = 2
...
์ํ์ ํต์ฐฐ: XOR ์์ฑ์ ํ์ฉํ ์๋ฒฝํ ํ์ด ๊ตํ์ ๊ตฌํํฉ๋๋ค: \[\Large \text{XOR}(i, 1) = \begin{cases} i + 1 & \text{if } i \bmod 2 = 0 \\ i - 1 & \text{if } i \bmod 2 = 1 \end{cases}\]
shuffle_xor์ด ์ฐ์ํ ์ด์ :
- ์๋ฒฝํ ๋์นญ: ๋ชจ๋ ๋ ์ธ์ด ์ ํํ ํ๋์ ํ์ด์ ์ฐธ์ฌ
- ์กฐ์ ๋ถํ์: ๋ชจ๋ ํ์ด๊ฐ ๋์์ ๊ตํ
- ํ๋์จ์ด ์ต์ ํ: ์ํ ์ ์ฒด์ ๋ํด ๋จ์ผ ๋ช ๋ น์ผ๋ก ์ฒ๋ฆฌ
- ๋ฒํฐํ๋ผ์ด ๊ธฐ๋ฐ: ๋ณต์กํ ๋ณ๋ ฌ ์๊ณ ๋ฆฌ์ฆ์ ๋น๋ฉ ๋ธ๋ก
์ฑ๋ฅ ํน์ฑ:
- ์ง์ฐ ์๊ฐ: 1 ์ฌ์ดํด (ํ๋์จ์ด ๋ ์ง์คํฐ ๊ตํ)
- ๋์ญํญ: 0 ๋ฐ์ดํธ (๋ฉ๋ชจ๋ฆฌ ํธ๋ํฝ ์์)
- ๋ณ๋ ฌ์ฑ: WARP_SIZE๊ฐ ๋ ์ธ ๋ชจ๋ ๋์์ ๊ตํ
- ํ์ฅ์ฑ: ๋ฐ์ดํฐ ํฌ๊ธฐ์ ๊ด๊ณ์์ด \(O(1)\) ๋ณต์ก๋
2. ๋ฒํฐํ๋ผ์ด ๋ณ๋ ฌ ์ต๋๊ฐ
๊ตฌ์ฑ
- ๋ฒกํฐ ํฌ๊ธฐ:
SIZE = WARP_SIZE(GPU์ ๋ฐ๋ผ 32 ๋๋ 64) - ๊ทธ๋ฆฌ๋ ๊ตฌ์ฑ:
(1, 1)๊ทธ๋ฆฌ๋๋น ๋ธ๋ก ์ - ๋ธ๋ก ๊ตฌ์ฑ:
(WARP_SIZE, 1)๋ธ๋ก๋น ์ค๋ ๋ ์
์์ฑํ ์ฝ๋
๊ฐ์ํ๋ offset์ผ๋ก ๋ฒํฐํ๋ผ์ด shuffle_xor์ ์ฌ์ฉํ์ฌ ๋ณ๋ ฌ ์ต๋๊ฐ ๋ฆฌ๋์
์
๊ตฌํํฉ๋๋ค.
์ํ์ ์ฐ์ฐ: ํธ๋ฆฌ ๋ฆฌ๋์ ์ ํตํด ๋ชจ๋ ์ํ ๋ ์ธ์์ ์ต๋๊ฐ์ ๊ณ์ฐํฉ๋๋ค: \[\Large \text{max_result} = \max_{i=0}^{\small\text{WARP_SIZE}-1} \text{input}[i]\]
๋ฒํฐํ๋ผ์ด ๋ฆฌ๋์
ํจํด: XOR ์คํ์
์ WARP_SIZE/2์์ 1๊น์ง ์ ๋ฐ์ฉ
์ค์ฌ๊ฐ๋ฉฐ, ํต์ ๋ฒ์๊ฐ ๋จ๊ณ๋ง๋ค ๋ฐ์ผ๋ก ์ข์์ง๋ ์ด์ง ํธ๋ฆฌ๋ฅผ ๊ตฌ์ฑํฉ๋๋ค:
- 1๋จ๊ณ:
WARP_SIZE/2๊ฑฐ๋ฆฌ์ ๋ ์ธ๊ณผ ๋น๊ต (์ํ ์ ์ฒด๋ฅผ ํฌ๊ด) - 2๋จ๊ณ:
WARP_SIZE/4๊ฑฐ๋ฆฌ์ ๋ ์ธ๊ณผ ๋น๊ต (๋ฒ์๋ฅผ ์ ๋ฐ์ผ๋ก ์ขํ) - 3๋จ๊ณ:
WARP_SIZE/8๊ฑฐ๋ฆฌ์ ๋ ์ธ๊ณผ ๋น๊ต - 4๋จ๊ณ:
offset = 1์ด ๋ ๋๊น์ง ๊ณ์ ์ ๋ฐ์ผ๋ก ์ค์
\(\log_2(\text{WARP_SIZE})\) ๋จ๊ณ๋ฅผ ๊ฑฐ์น๋ฉด ๋ชจ๋ ๋ ์ธ์ด ์ ์ญ ์ต๋๊ฐ์ ๊ฐ๊ฒ
๋ฉ๋๋ค. ์ด ๋ฐฉ์์ ๋ชจ๋ WARP_SIZE (32, 64 ๋ฑ)์์ ๋์ํฉ๋๋ค.
def butterfly_parallel_max[
size: Int
](
output: TileTensor[mut=True, dtype, LayoutType, MutAnyOrigin],
input: TileTensor[mut=False, dtype, LayoutType, ImmutAnyOrigin],
):
"""
Parallel maximum reduction using butterfly pattern.
Uses shuffle_xor with decreasing offsets starting from WARP_SIZE/2 down to 1.
Each step reduces the active range by half until all threads have the maximum value.
This implements an efficient O(log n) parallel reduction algorithm that works
for any WARP_SIZE (32, 64, etc.).
"""
var global_i = block_dim.x * block_idx.x + thread_idx.x
# FILL ME IN (roughly 7 lines)
ํ
1. ๋ฒํฐํ๋ผ์ด ๋ฆฌ๋์ ์ดํดํ๊ธฐ
๋ฒํฐํ๋ผ์ด ๋ฆฌ๋์ ์ ์ด์ง ํธ๋ฆฌ ํต์ ํจํด์ ์์ฑํฉ๋๋ค. ๊ฐ ๋จ๊ณ์์ ๋ฌธ์ ํฌ๊ธฐ๋ฅผ ์ฒด๊ณ์ ์ผ๋ก ์ค์ด๋ ๋ฐฉ๋ฒ์ ์๊ฐํด ๋ณด์ธ์.
ํต์ฌ ์ง๋ฌธ:
- ์ต๋ ๋ฒ์๋ฅผ ์ปค๋ฒํ๋ ค๋ฉด ์์ offset์ด ์ผ๋ง์ฌ์ผ ํ๋์?
- ๋จ๊ณ ์ฌ์ด์ ์คํ์ ์ ์ด๋ป๊ฒ ๋ณ๊ฒฝํด์ผ ํ๋์?
- ์ธ์ ๋ฆฌ๋์ ์ ๋ฉ์ถฐ์ผ ํ๋์?
ํํธ: โ๋ฒํฐํ๋ผ์ดโ๋ผ๋ ์ด๋ฆ์ ํต์ ํจํด์์ ์ ๋ํฉ๋๋ค - ์์ ์์ ์ ๋ํด ์ง์ ๊ทธ๋ ค๋ณด์ธ์.
2. XOR ๋ฆฌ๋์ ํน์ฑ
XOR์ ๊ฐ ๋จ๊ณ์์ ๊ฒน์น์ง ์๋ ํต์ ํ์ด๋ฅผ ์์ฑํฉ๋๋ค. ์ด๊ฒ์ด ๋ณ๋ ฌ ๋ฆฌ๋์ ์์ ์ ์ค์ํ์ง ์๊ฐํด ๋ณด์ธ์.
์๊ฐํด ๋ณด์ธ์:
- ์๋ก ๋ค๋ฅธ ์คํ์ ์ผ๋ก์ XOR์ด ์ด๋ป๊ฒ ๋ค๋ฅธ ํต์ ํจํด์ ๋ง๋๋์?
- ๊ฐ์ ๋จ๊ณ์์ ๋ ์ธ๋ค์ด ์ ์๋ก ๊ฐ์ญํ์ง ์๋์?
- XOR์ด ํธ๋ฆฌ ๋ฆฌ๋์ ์ ํนํ ์ ํฉํ ์ด์ ๋ ๋ฌด์์ธ๊ฐ์?
3. ์ต๋๊ฐ ๋์
๊ฐ ๋ ์ธ์ ์์ ์ โ์์ญโ์์ ์ต๋๊ฐ์ ์ง์์ ์ ์ง์ ์ผ๋ก ์์๊ฐ์ผ ํฉ๋๋ค.
์๊ณ ๋ฆฌ์ฆ ๊ตฌ์กฐ:
- ์์ ์ ๊ฐ์ผ๋ก ์์
- ๊ฐ ๋จ๊ณ์์ ์ด์์ ๊ฐ๊ณผ ๋น๊ต
- ์ต๋๊ฐ์ ์ ์งํ๊ณ ๊ณ์ ์งํ
ํต์ฌ ํต์ฐฐ: ๊ฐ ๋จ๊ณ ํ, โ์ง์์ ์์ญโ์ด ๋ ๋ฐฐ๋ก ํ์ฅ๋ฉ๋๋ค.
- ๋ง์ง๋ง ๋จ๊ณ ํ: ๊ฐ ๋ ์ธ์ด ์ ์ญ ์ต๋๊ฐ์ ์๊ฒ ๋ฉ๋๋ค
4. ์ด ํจํด์ด ๋์ํ๋ ์ด์
๋ฒํฐํ๋ผ์ด ๋ฆฌ๋์ ์ \(\log_2(\text{WARP_SIZE})\) ๋จ๊ณ ํ์ ๋ค์์ ๋ณด์ฅํฉ๋๋ค:
- ๋ชจ๋ ๋ ์ธ์ด ๋ค๋ฅธ ๋ชจ๋ ๋ ์ธ์ ๊ฐ์ ๊ฐ์ ์ ์ผ๋ก ํ์ธ
- ์ค๋ณต ํต์ ์์: ๊ฐ ํ์ด๊ฐ ๋จ๊ณ๋น ์ ํํ ํ ๋ฒ ๊ตํ
- ์ต์ ๋ณต์ก๋: \(O(n)\) ์์ฐจ ๋น๊ต ๋์ \(O(\log n)\) ๋จ๊ณ
์ถ์ ์์ (4๊ฐ ๋ ์ธ, ๊ฐ [3, 1, 7, 2]):
์ด๊ธฐ ์ํ: Lane 0=3, Lane 1=1, Lane 2=7, Lane 3=2
1๋จ๊ณ (offset=2): 0 โ 2, 1 โ 3
Lane 0: max(3, 7) = 7
Lane 1: max(1, 2) = 2
Lane 2: max(7, 3) = 7
Lane 3: max(2, 1) = 2
2๋จ๊ณ (offset=1): 0 โ 1, 2 โ 3
Lane 0: max(7, 2) = 7
Lane 1: max(2, 7) = 7
Lane 2: max(7, 2) = 7
Lane 3: max(2, 7) = 7
๊ฒฐ๊ณผ: ๋ชจ๋ ๋ ์ธ์ด ์ ์ญ ์ต๋๊ฐ = 7์ ๊ฐ์ง
๋ฒํฐํ๋ผ์ด ๋ณ๋ ฌ ์ต๋๊ฐ ํ ์คํธ:
pixi run p26 --parallel-max
pixi run -e amd p26 --parallel-max
uv run poe p26 --parallel-max
ํ์์ ๋์ ์์ ์ถ๋ ฅ:
WARP_SIZE: 32
SIZE: 32
output: [1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0]
expected: [1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0]
โ
Butterfly parallel max test passed!
์๋ฃจ์
def butterfly_parallel_max[
size: Int
](
output: TileTensor[mut=True, dtype, LayoutType, MutAnyOrigin],
input: TileTensor[mut=False, dtype, LayoutType, ImmutAnyOrigin],
):
"""
Parallel maximum reduction using butterfly pattern.
Uses shuffle_xor with decreasing offsets (16, 8, 4, 2, 1) to perform tree-based reduction.
Each step reduces the active range by half until all threads have the maximum value.
This implements an efficient O(log n) parallel reduction algorithm.
"""
var global_i = block_dim.x * block_idx.x + thread_idx.x
if global_i < size:
var max_val = input[global_i]
# Butterfly reduction tree: dynamic for any WARP_SIZE (32, 64, etc.)
# Start with half the warp size and reduce by half each step
var offset = WARP_SIZE // 2
while offset > 0:
max_val = max(max_val, shuffle_xor(max_val, UInt32(offset)))
offset //= 2
# All threads now have the maximum value across the entire warp
output[global_i] = max_val
์ด ํ์ด๋ shuffle_xor()์ด \(O(\log n)\) ๋ณต์ก๋์ ํจ์จ์ ์ธ ๋ณ๋ ฌ ๋ฆฌ๋์
ํธ๋ฆฌ๋ฅผ ์ด๋ป๊ฒ ์์ฑํ๋์ง ๋ณด์ฌ์ค๋๋ค.
์ ์ฒด ์๊ณ ๋ฆฌ์ฆ ๋ถ์:
if global_i < size:
max_val = input[global_i] # ๋ก์ปฌ ๊ฐ์ผ๋ก ์์
# ๋ฒํฐํ๋ผ์ด ๋ฆฌ๋์
ํธ๋ฆฌ: ๋ชจ๋ WARP_SIZE์ ๋์ ์ผ๋ก ๋์
offset = WARP_SIZE // 2
while offset > 0:
max_val = max(max_val, shuffle_xor(max_val, offset))
offset //= 2
output[global_i] = max_val # ๋ชจ๋ ๋ ์ธ์ด ์ ์ญ ์ต๋๊ฐ์ ๊ฐ์ง
๋ฒํฐํ๋ผ์ด ์คํ ์ถ์ (8-๋ ์ธ ์์ , ๊ฐ [0,2,4,6,8,10,12,1000]):
์ด๊ธฐ ์ํ:
Lane 0: max_val = 0, Lane 1: max_val = 2
Lane 2: max_val = 4, Lane 3: max_val = 6
Lane 4: max_val = 8, Lane 5: max_val = 10
Lane 6: max_val = 12, Lane 7: max_val = 1000
1๋จ๊ณ: shuffle_xor(max_val, 4) - ์ ๋ฐ ๊ตํ
Lane 0โ4: max(0,8)=8, Lane 1โ5: max(2,10)=10
Lane 2โ6: max(4,12)=12, Lane 3โ7: max(6,1000)=1000
Lane 4โ0: max(8,0)=8, Lane 5โ1: max(10,2)=10
Lane 6โ2: max(12,4)=12, Lane 7โ3: max(1000,6)=1000
2๋จ๊ณ: shuffle_xor(max_val, 2) - 1/4 ๊ตํ
Lane 0โ2: max(8,12)=12, Lane 1โ3: max(10,1000)=1000
Lane 2โ0: max(12,8)=12, Lane 3โ1: max(1000,10)=1000
Lane 4โ6: max(8,12)=12, Lane 5โ7: max(10,1000)=1000
Lane 6โ4: max(12,8)=12, Lane 7โ5: max(1000,10)=1000
3๋จ๊ณ: shuffle_xor(max_val, 1) - ํ์ด ๊ตํ
Lane 0โ1: max(12,1000)=1000, Lane 1โ0: max(1000,12)=1000
Lane 2โ3: max(12,1000)=1000, Lane 3โ2: max(1000,12)=1000
Lane 4โ5: max(12,1000)=1000, Lane 5โ4: max(1000,12)=1000
Lane 6โ7: max(12,1000)=1000, Lane 7โ6: max(1000,12)=1000
์ต์ข
๊ฒฐ๊ณผ: ๋ชจ๋ ๋ ์ธ์ max_val = 1000
์ํ์ ํต์ฐฐ: ๋ฒํฐํ๋ผ์ด ํต์ ์ผ๋ก ๋ณ๋ ฌ ๋ฆฌ๋์ ์ฐ์ฐ์๋ฅผ ๊ตฌํํฉ๋๋ค: \[\Large \text{Reduce}(\oplus, [a_0, a_1, \ldots, a_{n-1}]) = a_0 \oplus a_1 \oplus \cdots \oplus a_{n-1}\]
์ฌ๊ธฐ์ \(\oplus\)๋ max ์ฐ์ฐ์ด๋ฉฐ, ๋ฒํฐํ๋ผ์ด ํจํด์ด ์ต์ \(O(\log n)\)
๋ณต์ก๋๋ฅผ ๋ณด์ฅํฉ๋๋ค.
๋ฒํฐํ๋ผ์ด ๋ฆฌ๋์ ์ด ์ฐ์ํ ์ด์ :
- ๋ก๊ทธ ๋ณต์ก๋: ์์ฐจ ๋ฆฌ๋์ ์ \(O(n)\)์ ๋นํด \(O(\log n)\)
- ์๋ฒฝํ ๋ถํ ๋ถ์ฐ: ๋ชจ๋ ๋ ์ธ์ด ๊ฐ ๋จ๊ณ์์ ๋๋ฑํ๊ฒ ์ฐธ์ฌ
- ๋ฉ๋ชจ๋ฆฌ ๋ณ๋ชฉ ์์: ์์ ๋ ์ง์คํฐ ๊ฐ ํต์
- ํ๋์จ์ด ์ต์ ํ: GPU ๋ฒํฐํ๋ผ์ด ๋คํธ์ํฌ์ ์ง์ ๋งคํ
์ฑ๋ฅ ํน์ฑ:
- ๋จ๊ณ ์: \(\log_2(\text{WARP_SIZE})\) (์: 32-์ค๋ ๋ ์ํ๋ 5๋จ๊ณ, 64-์ค๋ ๋ ์ํ๋ 6๋จ๊ณ)
- ๋จ๊ณ๋น ์ง์ฐ ์๊ฐ: 1 ์ฌ์ดํด (๋ ์ง์คํฐ ๊ตํ + ๋น๊ต)
- ์ด ์ง์ฐ ์๊ฐ: ์์ฐจ ๋ฐฉ์์ \((\text{WARP_SIZE}-1)\) ์ฌ์ดํด ๋๋น \(\log_2(\text{WARP_SIZE})\) ์ฌ์ดํด
- ๋ณ๋ ฌ์ฑ: ์๊ณ ๋ฆฌ์ฆ ์ ์ฒด์์ ๋ชจ๋ ๋ ์ธ์ด ํ์ฑ ์ํ
3. ๋ฒํฐํ๋ผ์ด ์กฐ๊ฑด๋ถ ์ต๋๊ฐ
๊ตฌ์ฑ
- ๋ฒกํฐ ํฌ๊ธฐ:
SIZE_2 = 64(๋ฉํฐ ๋ธ๋ก ์๋๋ฆฌ์ค) - ๊ทธ๋ฆฌ๋ ๊ตฌ์ฑ:
BLOCKS_PER_GRID_2 = (2, 1)๊ทธ๋ฆฌ๋๋น ๋ธ๋ก ์ - ๋ธ๋ก ๊ตฌ์ฑ:
THREADS_PER_BLOCK_2 = (WARP_SIZE, 1)๋ธ๋ก๋น ์ค๋ ๋ ์
์์ฑํ ์ฝ๋
์ง์ ๋ ์ธ์ ์ต๋๊ฐ์, ํ์ ๋ ์ธ์ ์ต์๊ฐ์ ์ ์ฅํ๋ ์กฐ๊ฑด๋ถ ๋ฒํฐํ๋ผ์ด ๋ฆฌ๋์ ์ ๊ตฌํํฉ๋๋ค.
์ํ์ ์ฐ์ฐ: ์ต๋๊ฐ๊ณผ ์ต์๊ฐ ๋ชจ๋์ ๋ํด ๋ฒํฐํ๋ผ์ด ๋ฆฌ๋์ ์ ์ํํ ํ, ๋ ์ธ ํ์ง์ ๋ฐ๋ผ ์กฐ๊ฑด๋ถ๋ก ์ถ๋ ฅํฉ๋๋ค: \[\Large \text{output}[i] = \begin{cases} \max_{j=0}^{\text{WARP_SIZE}-1} \text{input}[j] & \text{if} i \bmod 2 = 0 \\ \min_{j=0}^{\text{WARP_SIZE}-1} \text{input}[j] & \text{if } i \bmod 2 = 1 \end{cases}\]
์ด์ค ๋ฆฌ๋์ ํจํด: ๋ฒํฐํ๋ผ์ด ํธ๋ฆฌ๋ฅผ ํตํด ์ต๋๊ฐ๊ณผ ์ต์๊ฐ์ ๋์์ ์ถ์ ํ ํ, ๋ ์ธ ID ํ์ง์ ๋ฐ๋ผ ์กฐ๊ฑด๋ถ๋ก ์ถ๋ ฅํฉ๋๋ค. ์ด๋ ๋ฒํฐํ๋ผ์ด ํจํด์ด ๋ณต์กํ ๋ค์ค ๊ฐ ๋ฆฌ๋์ ์ผ๋ก ์ด๋ป๊ฒ ํ์ฅ๋๋์ง๋ฅผ ๋ณด์ฌ์ค๋๋ค.
comptime SIZE_2 = 64
comptime BLOCKS_PER_GRID_2 = (2, 1)
comptime THREADS_PER_BLOCK_2 = (WARP_SIZE, 1)
comptime layout_2 = row_major[SIZE_2]()
comptime LayoutType_2 = type_of(layout_2)
def butterfly_conditional_max[
size: Int
](
output: TileTensor[mut=True, dtype, LayoutType_2, MutAnyOrigin],
input: TileTensor[mut=False, dtype, LayoutType_2, ImmutAnyOrigin],
):
"""
Conditional butterfly maximum: Perform butterfly max reduction, but only store result
in even-numbered lanes. Odd-numbered lanes store the minimum value seen.
Demonstrates conditional logic combined with butterfly communication patterns.
"""
var global_i = block_dim.x * block_idx.x + thread_idx.x
var lane = lane_id()
if global_i < size:
var current_val = input[global_i]
var min_val = current_val
# FILL ME IN (roughly 11 lines)
ํ
1. ์ด์ค ์ถ์ ๋ฒํฐํ๋ผ์ด ๋ฆฌ๋์
์ด ํผ์ฆ์ ๋ฒํฐํ๋ผ์ด ํธ๋ฆฌ๋ฅผ ํตํด ๋ ๊ฐ์ง ๋ค๋ฅธ ๊ฐ์ ๋์์ ์ถ์ ํด์ผ ํฉ๋๋ค. ์ฌ๋ฌ ๋ฆฌ๋์ ์ ๋ณ๋ ฌ๋ก ์คํํ๋ ๋ฐฉ๋ฒ์ ์๊ฐํด ๋ณด์ธ์.
ํต์ฌ ์ง๋ฌธ:
- ๋ฆฌ๋์ ๊ณผ์ ์์ ์ต๋๊ฐ๊ณผ ์ต์๊ฐ์ ์ด๋ป๊ฒ ๋์์ ์ ์งํ ์ ์๋์?
- ๋ ์ฐ์ฐ์ ๊ฐ์ ๋ฒํฐํ๋ผ์ด ํจํด์ ์ฌ์ฉํ ์ ์๋์?
- ์ด๋ค ๋ณ์๋ฅผ ์ถ์ ํด์ผ ํ๋์?
2. ์กฐ๊ฑด๋ถ ์ถ๋ ฅ ๋ก์ง
๋ฒํฐํ๋ผ์ด ๋ฆฌ๋์ ์ ์๋ฃํ ํ, ๋ ์ธ ํ์ง์ ๋ฐ๋ผ ๋ค๋ฅธ ๊ฐ์ ์ถ๋ ฅํด์ผ ํฉ๋๋ค.
๊ณ ๋ คํ ์ :
- ๋ ์ธ์ด ์ง์์ธ์ง ํ์์ธ์ง ์ด๋ป๊ฒ ํ๋ณํ๋์?
- ์ด๋ค ๋ ์ธ์ด ์ต๋๊ฐ์, ์ด๋ค ๋ ์ธ์ด ์ต์๊ฐ์ ์ถ๋ ฅํด์ผ ํ๋์?
- ๋ ์ธ ID์ ์ด๋ป๊ฒ ์ ๊ทผํ๋์?
3. min๊ณผ max ๋์ ๋ฒํฐํ๋ผ์ด ๋ฆฌ๋์
์ด ๊ณผ์ ์ ํต์ฌ์ ๊ฐ์ ๋ฒํฐํ๋ผ์ด ํต์ ํจํด์ผ๋ก min๊ณผ max๋ฅผ ํจ์จ์ ์ผ๋ก ๋ณ๋ ฌ ๊ณ์ฐํ๋ ๊ฒ์ ๋๋ค.
์๊ฐํด ๋ณด์ธ์:
- min๊ณผ max์ ๋ณ๋์ ์ ํ ์ฐ์ฐ์ด ํ์ํ๊ฐ์?
- ๋ ์ฐ์ฐ์ ๊ฐ์ ์ด์ ๊ฐ์ ์ฌ์ฌ์ฉํ ์ ์๋์?
- ๋ ๋ฆฌ๋์ ๋ชจ๋ ์ฌ๋ฐ๋ฅด๊ฒ ์๋ฃ๋๋ ค๋ฉด ์ด๋ป๊ฒ ํด์ผ ํ๋์?
4. ๋ฉํฐ ๋ธ๋ก ๊ฒฝ๊ณ ๊ณ ๋ ค์ฌํญ
์ด ํผ์ฆ์ ์ฌ๋ฌ ๋ธ๋ก์ ์ฌ์ฉํฉ๋๋ค. ์ด๊ฒ์ด ๋ฆฌ๋์ ๋ฒ์์ ์ด๋ค ์ํฅ์ ๋ฏธ์น๋์ง ์๊ฐํด ๋ณด์ธ์.
์ค์ํ ๊ณ ๋ ค์ฌํญ:
- ๊ฐ ๋ฒํฐํ๋ผ์ด ๋ฆฌ๋์ ์ ๋ฒ์๋ ์ด๋๊น์ง์ธ๊ฐ์?
- ๋ธ๋ก ๊ตฌ์กฐ๊ฐ ๋ ์ธ ๋ฒํธ ๋งค๊ธฐ๊ธฐ์ ์ด๋ค ์ํฅ์ ๋ฏธ์น๋์?
- ์ ์ญ min/max๋ฅผ ๊ณ์ฐํ๋์, ๋ธ๋ก๋ณ min/max๋ฅผ ๊ณ์ฐํ๋์?
๋ฒํฐํ๋ผ์ด ์กฐ๊ฑด๋ถ ์ต๋๊ฐ ํ ์คํธ:
pixi run p26 --conditional-max
pixi run -e amd p26 --conditional-max
uv run poe p26 --conditional-max
ํ์์ ๋์ ์์ ์ถ๋ ฅ:
WARP_SIZE: 32
SIZE_2: 64
output: [9.0, 0.0, 9.0, 0.0, 9.0, 0.0, 9.0, 0.0, 9.0, 0.0, 9.0, 0.0, 9.0, 0.0, 9.0, 0.0, 9.0, 0.0, 9.0, 0.0, 9.0, 0.0, 9.0, 0.0, 9.0, 0.0, 9.0, 0.0, 9.0, 0.0, 9.0, 0.0, 63.0, 32.0, 63.0, 32.0, 63.0, 32.0, 63.0, 32.0, 63.0, 32.0, 63.0, 32.0, 63.0, 32.0, 63.0, 32.0, 63.0, 32.0, 63.0, 32.0, 63.0, 32.0, 63.0, 32.0, 63.0, 32.0, 63.0, 32.0, 63.0, 32.0, 63.0, 32.0]
expected: [9.0, 0.0, 9.0, 0.0, 9.0, 0.0, 9.0, 0.0, 9.0, 0.0, 9.0, 0.0, 9.0, 0.0, 9.0, 0.0, 9.0, 0.0, 9.0, 0.0, 9.0, 0.0, 9.0, 0.0, 9.0, 0.0, 9.0, 0.0, 9.0, 0.0, 9.0, 0.0, 63.0, 32.0, 63.0, 32.0, 63.0, 32.0, 63.0, 32.0, 63.0, 32.0, 63.0, 32.0, 63.0, 32.0, 63.0, 32.0, 63.0, 32.0, 63.0, 32.0, 63.0, 32.0, 63.0, 32.0, 63.0, 32.0, 63.0, 32.0, 63.0, 32.0, 63.0, 32.0]
โ
Butterfly conditional max test passed!
์๋ฃจ์
def butterfly_conditional_max[
size: Int
](
output: TileTensor[mut=True, dtype, Layout2Type, MutAnyOrigin],
input: TileTensor[mut=False, dtype, Layout2Type, ImmutAnyOrigin],
):
"""
Conditional butterfly maximum: Perform butterfly max reduction, but only store result
in even-numbered lanes. Odd-numbered lanes store the minimum value seen.
Demonstrates conditional logic combined with butterfly communication patterns.
"""
var global_i = block_dim.x * block_idx.x + thread_idx.x
var lane = lane_id()
if global_i < size:
var current_val = input[global_i]
var min_val = current_val
# Butterfly reduction for both maximum and minimum: dynamic for any WARP_SIZE
var offset = WARP_SIZE // 2
while offset > 0:
var neighbor_val = shuffle_xor(current_val, UInt32(offset))
current_val = max(current_val, neighbor_val)
var min_neighbor_val = shuffle_xor(min_val, UInt32(offset))
min_val = min(min_val, min_neighbor_val)
offset //= 2
# Conditional output: max for even lanes, min for odd lanes
if lane % 2 == 0:
output[global_i] = current_val # Maximum
else:
output[global_i] = min_val # Minimum
์ด ํ์ด๋ ์ด์ค ์ถ์ ๊ณผ ์กฐ๊ฑด๋ถ ์ถ๋ ฅ์ ์ฌ์ฉํ๋ ๊ณ ๊ธ ๋ฒํฐํ๋ผ์ด ๋ฆฌ๋์ ์ ๋ณด์ฌ์ค๋๋ค.
์ ์ฒด ์๊ณ ๋ฆฌ์ฆ ๋ถ์:
if global_i < size:
current_val = input[global_i]
min_val = current_val # ์ต์๊ฐ์ ๋ณ๋๋ก ์ถ์
# max์ min ๋์ ๋ฒํฐํ๋ผ์ด ๋ฆฌ๋์
(log_2(WARP_SIZE) ๋จ๊ณ)
offset = WARP_SIZE // 2
while offset > 0:
neighbor_val = shuffle_xor(current_val, offset)
current_val = max(current_val, neighbor_val) # Max ๋ฆฌ๋์
min_neighbor_val = shuffle_xor(min_val, offset)
min_val = min(min_val, min_neighbor_val) # Min ๋ฆฌ๋์
offset //= 2
# ๋ ์ธ ํ์ง์ ๋ฐ๋ฅธ ์กฐ๊ฑด๋ถ ์ถ๋ ฅ
if lane % 2 == 0:
output[global_i] = current_val # ์ง์ ๋ ์ธ: ์ต๋๊ฐ
else:
output[global_i] = min_val # ํ์ ๋ ์ธ: ์ต์๊ฐ
์ด์ค ๋ฆฌ๋์ ์คํ ์ถ์ (4-๋ ์ธ ์์ , ๊ฐ [3, 1, 7, 2]):
์ด๊ธฐ ์ํ:
Lane 0: current_val=3, min_val=3
Lane 1: current_val=1, min_val=1
Lane 2: current_val=7, min_val=7
Lane 3: current_val=2, min_val=2
1๋จ๊ณ: shuffle_xor(current_val, 2)์ shuffle_xor(min_val, 2) - ์ ๋ฐ ๊ตํ
Lane 0โ2: max_neighbor=7, min_neighbor=7 โ current_val=max(3,7)=7, min_val=min(3,7)=3
Lane 1โ3: max_neighbor=2, min_neighbor=2 โ current_val=max(1,2)=2, min_val=min(1,2)=1
Lane 2โ0: max_neighbor=3, min_neighbor=3 โ current_val=max(7,3)=7, min_val=min(7,3)=3
Lane 3โ1: max_neighbor=1, min_neighbor=1 โ current_val=max(2,1)=2, min_val=min(2,1)=1
2๋จ๊ณ: shuffle_xor(current_val, 1)์ shuffle_xor(min_val, 1) - ํ์ด ๊ตํ
Lane 0โ1: max_neighbor=2, min_neighbor=1 โ current_val=max(7,2)=7, min_val=min(3,1)=1
Lane 1โ0: max_neighbor=7, min_neighbor=3 โ current_val=max(2,7)=7, min_val=min(1,3)=1
Lane 2โ3: max_neighbor=2, min_neighbor=1 โ current_val=max(7,2)=7, min_val=min(3,1)=1
Lane 3โ2: max_neighbor=7, min_neighbor=3 โ current_val=max(2,7)=7, min_val=min(1,3)=1
์ต์ข
๊ฒฐ๊ณผ: ๋ชจ๋ ๋ ์ธ์ด current_val=7 (์ ์ญ max)๊ณผ min_val=1 (์ ์ญ min)์ ๊ฐ์ง
๋์ ์๊ณ ๋ฆฌ์ฆ (๋ชจ๋ WARP_SIZE์์ ๋์):
offset = WARP_SIZE // 2
while offset > 0:
neighbor_val = shuffle_xor(current_val, offset)
current_val = max(current_val, neighbor_val)
min_neighbor_val = shuffle_xor(min_val, offset)
min_val = min(min_val, min_neighbor_val)
offset //= 2
์ํ์ ํต์ฐฐ: ์กฐ๊ฑด๋ถ ๋๋ฉํฐํ๋ ์ฑ์ ์ฌ์ฉํ๋ ์ด์ค ๋ณ๋ ฌ ๋ฆฌ๋์ ์ ๊ตฌํํฉ๋๋ค: \[\Large \begin{align} \text{max_result} &= \max_{i=0}^{n-1} \text{input}[i] \\ \text{min_result} &= \min_{i=0}^{n-1} \text{input}[i] \\ \text{output}[i] &= \text{lane_parity}(i) \; \text{?} \; \text{min_result}: \text{max_result} \end{align}\]
์ด์ค ๋ฒํฐํ๋ผ์ด ๋ฆฌ๋์ ์ด ๋์ํ๋ ์ด์ :
- ๋ ๋ฆฝ์ ๋ฆฌ๋์ : Max์ min ๋ฆฌ๋์ ์ ์ํ์ ์ผ๋ก ๋ ๋ฆฝ
- ๋ณ๋ ฌ ์คํ: ๋ ๋ค ๊ฐ์ ๋ฒํฐํ๋ผ์ด ํต์ ํจํด์ ์ฌ์ฉ ๊ฐ๋ฅ
- ํต์ ๊ณต์ : ๊ฐ์ ์ ํ ์ฐ์ฐ์ด ๋ ๋ฆฌ๋์ ๋ชจ๋์ ํ์ฉ
- ์กฐ๊ฑด๋ถ ์ถ๋ ฅ: ๋ ์ธ ํ์ง์ด ์ด๋ค ๊ฒฐ๊ณผ๋ฅผ ์ถ๋ ฅํ ์ง ๊ฒฐ์
์ฑ๋ฅ ํน์ฑ:
- ํต์ ๋จ๊ณ: \(\log_2(\text{WARP_SIZE})\) (๋จ์ผ ๋ฆฌ๋์ ๊ณผ ๋์ผ)
- ๋จ๊ณ๋น ์ฐ์ฐ: ๋จ์ผ ๋ฆฌ๋์ ์ 1๊ฐ ๋๋น 2๊ฐ ์ฐ์ฐ (max + min)
- ๋ฉ๋ชจ๋ฆฌ ํจ์จ์ฑ: ๋ณต์กํ ๊ณต์ ๋ฉ๋ชจ๋ฆฌ ๋ฐฉ์ ๋๋น ์ค๋ ๋๋น ๋ ์ง์คํฐ 2๊ฐ
- ์ถ๋ ฅ ์ ์ฐ์ฑ: ์๋ก ๋ค๋ฅธ ๋ ์ธ์ด ๋ค๋ฅธ ๋ฆฌ๋์ ๊ฒฐ๊ณผ๋ฅผ ์ถ๋ ฅ ๊ฐ๋ฅ
์์ฝ
shuffle_xor() ๊ธฐ๋ณธ ์์๋ ํจ์จ์ ์ธ ๋ณ๋ ฌ ์๊ณ ๋ฆฌ์ฆ์ ๊ธฐ๋ฐ์ด ๋๋ ๊ฐ๋ ฅํ
๋ฒํฐํ๋ผ์ด ํต์ ํจํด์ ๊ฐ๋ฅํ๊ฒ ํฉ๋๋ค. ์ธ ๊ฐ์ง ๋ฌธ์ ๋ฅผ ํตํด ๋ค์์ ๋ฐฐ์ ์ต๋๋ค:
ํต์ฌ ๋ฒํฐํ๋ผ์ด ํจํด
-
ํ์ด ๊ตํ (
shuffle_xor(value, 1)):- ์๋ฒฝํ ์ธ์ ํ์ด ์์ฑ: (0,1), (2,3), (4,5), โฆ
- ๋ฉ๋ชจ๋ฆฌ ์ค๋ฒํค๋ ์ ๋ก์ \(O(1)\) ๋ณต์ก๋
- ์ ๋ ฌ ๋คํธ์ํฌ์ ๋ฐ์ดํฐ ์ฌ๋ฐฐ์น์ ๊ธฐ๋ฐ
-
ํธ๋ฆฌ ๋ฆฌ๋์ (๋์ offset:
WARP_SIZE/2โ1):- ๋ก๊ทธ ๋ณ๋ ฌ ๋ฆฌ๋์ : ์์ฐจ์ \(O(n)\) ๋๋น \(O(\log n)\)
- ๋ชจ๋ ๊ฒฐํฉ ์ฐ์ฐ์ ์ ์ฉ ๊ฐ๋ฅ (max, min, sum ๋ฑ)
- ๋ชจ๋ ์ํ ๋ ์ธ์ ๊ฑธ์ณ ์ต์ ์ ๋ถํ ๋ถ์ฐ
-
์กฐ๊ฑด๋ถ ๋ค์ค ๋ฆฌ๋์ (์ด์ค ์ถ์ + ๋ ์ธ ํ์ง):
- ์ฌ๋ฌ ๋ฆฌ๋์ ์ ๋์์ ๋ณ๋ ฌ ์ํ
- ์ค๋ ๋ ํน์ฑ์ ๋ฐ๋ฅธ ์กฐ๊ฑด๋ถ ์ถ๋ ฅ
- ๋ช ์์ ๋๊ธฐํ ์๋ ๊ณ ๊ธ ์กฐ์
ํต์ฌ ์๊ณ ๋ฆฌ์ฆ ํต์ฐฐ
XOR ํต์ ํน์ฑ:
shuffle_xor(value, mask)๊ฐ ๋์นญ์ ์ด๊ณ ๊ฒน์น์ง ์๋ ํ์ด๋ฅผ ์์ฑ- ๊ฐ ๋ง์คํฌ๊ฐ ๊ณ ์ ํ ํต์ ํ ํด๋ก์ง๋ฅผ ์์ฑ
- ์ด์ง XOR ํจํด์์ ๋ฒํฐํ๋ผ์ด ๋คํธ์ํฌ๊ฐ ์์ฐ์ค๋ฝ๊ฒ ๋์ถ
๋์ ์๊ณ ๋ฆฌ์ฆ ์ค๊ณ:
offset = WARP_SIZE // 2
while offset > 0:
neighbor_val = shuffle_xor(current_val, offset)
current_val = operation(current_val, neighbor_val)
offset //= 2
์ฑ๋ฅ ์ด์ :
- ํ๋์จ์ด ์ต์ ํ: ๋ ์ง์คํฐ ๊ฐ ์ง์ ํต์
- ๋๊ธฐํ ๋ถํ์: SIMT ์คํ์ด ์ ํ์ฑ์ ๋ณด์ฅ
- ํ์ฅ ๊ฐ๋ฅํ ๋ณต์ก๋: ๋ชจ๋ WARP_SIZE (32, 64 ๋ฑ)์์ \(O(\log n)\)
- ๋ฉ๋ชจ๋ฆฌ ํจ์จ์ฑ: ๊ณต์ ๋ฉ๋ชจ๋ฆฌ ๋ถํ์
์ค์ฉ์ ํ์ฉ
์ด ๋ฒํฐํ๋ผ์ด ํจํด๋ค์ ๊ธฐ๋ฐ์ด ๋๋ ๋ถ์ผ:
- ๋ณ๋ ฌ ๋ฆฌ๋์ : ํฉ๊ณ, max, min, ๋ ผ๋ฆฌ ์ฐ์ฐ
- ๋์ ํฉ/์ค์บ ์ฐ์ฐ: ๋์ ํฉ, ๋ณ๋ ฌ ์ ๋ ฌ
- FFT ์๊ณ ๋ฆฌ์ฆ: ์ ํธ ์ฒ๋ฆฌ์ ํฉ์ฑ๊ณฑ
- Bitonic ์ ๋ ฌ: ๋ณ๋ ฌ ์ ๋ ฌ ๋คํธ์ํฌ
- ๊ทธ๋ํ ์๊ณ ๋ฆฌ์ฆ: ํธ๋ฆฌ ์ํ์ ์ฐ๊ฒฐ์ฑ
shuffle_xor() ๊ธฐ๋ณธ ์์๋ ๋ณต์กํ ๋ณ๋ ฌ ์กฐ์ ์ ์ฐ์ํ๊ณ ํ๋์จ์ด ์ต์ ํ๋ ํต์
ํจํด์ผ๋ก ๋ณํํ๋ฉฐ, ๋ค์ํ GPU ์ํคํ
์ฒ์์ ํจ์จ์ ์ผ๋ก ํ์ฅ๋ฉ๋๋ค.