warp.broadcast() ์ผ๋๋ค ํต์
์ํ ๋ ๋ฒจ ์กฐ์ ์์๋ broadcast()๋ฅผ ์ฌ์ฉํ์ฌ ํ๋์ ๋ ์ธ์์ ์ํ ๋ด ๋ค๋ฅธ ๋ชจ๋ ๋ ์ธ์ผ๋ก ๋ฐ์ดํฐ๋ฅผ ๊ณต์ ํ ์ ์์ต๋๋ค. ์ด ๊ฐ๋ ฅํ ๊ธฐ๋ณธ ์์๋ฅผ ํตํด ๊ณต์ ๋ฉ๋ชจ๋ฆฌ๋ ๋ช
์์ ๋๊ธฐํ ์์ด ๋ธ๋ก ๋ ๋ฒจ ๊ณ์ฐ, ์กฐ๊ฑด๋ถ ๋ก์ง ์กฐ์ , ์ผ๋๋ค ํต์ ํจํด์ ํจ์จ์ ์ผ๋ก ์ํํ ์ ์์ต๋๋ค.
ํต์ฌ ํต์ฐฐ: broadcast() ์ฐ์ฐ์ SIMT ์คํ์ ํ์ฉํ์ฌ ํ๋์ ๋ ์ธ(๋ณดํต ๋ ์ธ 0)์ด ๊ณ์ฐํ ๊ฐ์ ๊ฐ์ ์ํ์ ๋ชจ๋ ๋ ์ธ์ ์ ๋ฌํ๋ฉฐ, ํจ์จ์ ์ธ ์กฐ์ ํจํด๊ณผ ์งํฉ์ ์์ฌ๊ฒฐ์ ์ ๊ฐ๋ฅํ๊ฒ ํฉ๋๋ค.
๋ธ๋ก๋์บ์คํธ ์ฐ์ฐ์ด๋? ๋ธ๋ก๋์บ์คํธ ์ฐ์ฐ์ ํ๋์ ์ค๋ ๋๊ฐ ๊ฐ์ ๊ณ์ฐํ๊ณ ๊ทธ๋ฃน ๋ด ๋ค๋ฅธ ๋ชจ๋ ์ค๋ ๋์ ๊ณต์ ํ๋ ํต์ ํจํด์ ๋๋ค. ๋ธ๋ก ๋ ๋ฒจ ํต๊ณ ๊ณ์ฐ, ์งํฉ์ ์์ฌ๊ฒฐ์ , ์ํ ๋ด ๋ชจ๋ ์ค๋ ๋์ ์ค์ ํ๋ผ๋ฏธํฐ ์ ๋ฌ ๋ฑ์ ์กฐ์ ์์ ์ ํ์์ ์ ๋๋ค.
ํต์ฌ ๊ฐ๋
์ด ํผ์ฆ์์ ๋ฐฐ์ธ ๋ด์ฉ:
broadcast()๋ฅผ ํ์ฉํ ์ํ ๋ ๋ฒจ ๋ธ๋ก๋์บ์คํธ- ์ผ๋๋ค ํต์ ํจํด
- ์งํฉ ๊ณ์ฐ ์ ๋ต
- ๋ ์ธ ๊ฐ ์กฐ๊ฑด๋ถ ์กฐ์
- ๋ธ๋ก๋์บ์คํธ-shuffle ๊ฒฐํฉ ์ฐ์ฐ
broadcast() ์ฐ์ฐ์ ํ๋์ ๋ ์ธ(๊ธฐ๋ณธ์ ์ผ๋ก ๋ ์ธ 0)์ด ์์ ์ ๊ฐ์ ๋ค๋ฅธ ๋ชจ๋ ๋ ์ธ๊ณผ ๊ณต์ ํ ์ ์๊ฒ ํฉ๋๋ค:
\[\Large \text{broadcast}(\text{value}) = \text{value_from_lane_0_to_all_lanes}\]
์ด๋ฅผ ํตํด ๋ณต์กํ ์กฐ์ ํจํด์ด ๊ฐ๋จํ ์ํ ๋ ๋ฒจ ์ฐ์ฐ์ผ๋ก ๋ณํ๋์ด, ๋ช ์์ ๋๊ธฐํ ์์ด ํจ์จ์ ์ธ ์งํฉ ๊ณ์ฐ์ด ๊ฐ๋ฅํฉ๋๋ค.
๋ธ๋ก๋์บ์คํธ ๊ฐ๋
๊ธฐ์กด ์กฐ์ ๋ฐฉ์์ ๋ณต์กํ ๊ณต์ ๋ฉ๋ชจ๋ฆฌ ํจํด์ด ํ์ํฉ๋๋ค:
# ๊ธฐ์กด ๋ฐฉ์ - ๋ณต์กํ๊ณ ์ค๋ฅ๊ฐ ๋ฐ์ํ๊ธฐ ์ฌ์
shared_memory[lane] = local_computation()
sync_threads() # ๋น์ฉ์ด ํฐ ๋๊ธฐํ
if lane == 0:
result = compute_from_shared_memory()
sync_threads() # ๋ ๋ค๋ฅธ ๋น์ฉ์ด ํฐ ๋๊ธฐํ
final_result = shared_memory[0] # ๋ชจ๋ ์ค๋ ๋๊ฐ ์ฝ์
๊ธฐ์กด ๋ฐฉ์์ ๋ฌธ์ ์ :
- ๋ฉ๋ชจ๋ฆฌ ์ค๋ฒํค๋: ๊ณต์ ๋ฉ๋ชจ๋ฆฌ ํ ๋น์ด ํ์
- ๋๊ธฐํ: ๋น์ฉ์ด ํฐ ๋ฐฐ๋ฆฌ์ด ์ฐ์ฐ์ด ์ฌ๋ฌ ๋ฒ ํ์
- ๋ณต์กํ ๋ก์ง: ๊ณต์ ๋ฉ๋ชจ๋ฆฌ ์ธ๋ฑ์ค์ ์ ๊ทผ ํจํด ๊ด๋ฆฌ
- ์ค๋ฅ ๋ฐ์ ๊ฐ๋ฅ์ฑ: ๊ฒฝ์ ์ํ๊ฐ ์ฝ๊ฒ ๋ฐ์
broadcast()๋ฅผ ์ฌ์ฉํ๋ฉด ์กฐ์ ์ด ๊ฐ๊ฒฐํด์ง๋๋ค:
# ์ํ ๋ธ๋ก๋์บ์คํธ ๋ฐฉ์ - ๊ฐ๋จํ๊ณ ์์
collective_value = 0
if lane == 0:
collective_value = compute_block_statistic()
collective_value = broadcast(collective_value) # ๋ชจ๋ ๋ ์ธ๊ณผ ๊ณต์
result = use_collective_value(collective_value)
๋ธ๋ก๋์บ์คํธ์ ์ฅ์ :
- ๋ฉ๋ชจ๋ฆฌ ์ค๋ฒํค๋ ์ ๋ก: ๊ณต์ ๋ฉ๋ชจ๋ฆฌ ๋ถํ์
- ์๋ ๋๊ธฐํ: SIMT ์คํ์ด ์ ํ์ฑ์ ๋ณด์ฅ
- ๊ฐ๋จํ ํจํด: ํ๋์ ๋ ์ธ์ด ๊ณ์ฐํ๊ณ ๋ชจ๋ ๋ ์ธ์ด ์์
- ์กฐํฉ ๊ฐ๋ฅ: ๋ค๋ฅธ ์ํ ์ฐ์ฐ๊ณผ ์ฝ๊ฒ ๊ฒฐํฉ
1. ๊ธฐ๋ณธ ๋ธ๋ก๋์บ์คํธ
๋ ์ธ 0์ด ๋ธ๋ก ๋ ๋ฒจ ํต๊ณ๋ฅผ ๊ณ์ฐํ๊ณ ๋ชจ๋ ๋ ์ธ๊ณผ ๊ณต์ ํ๋ ๊ธฐ๋ณธ ๋ธ๋ก๋์บ์คํธ ํจํด์ ๊ตฌํํฉ๋๋ค.
์๊ตฌ์ฌํญ:
- ๋ ์ธ 0์ด ํ์ฌ ๋ธ๋ก์ ์ฒ์ 4๊ฐ ์์์ ํฉ์ ๊ณ์ฐํด์ผ ํฉ๋๋ค
- ์ด ๊ณ์ฐ๋ ๊ฐ์
broadcast()๋ฅผ ์ฌ์ฉํ์ฌ ์ํ์ ๋ค๋ฅธ ๋ชจ๋ ๋ ์ธ๊ณผ ๊ณต์ ํด์ผ ํฉ๋๋ค - ๊ฐ ๋ ์ธ์ ์ด ๊ณต์ ๋ ๊ฐ์ ์์ ์ ์ ๋ ฅ ์์์ ๋ํด์ผ ํฉ๋๋ค
ํ
์คํธ ๋ฐ์ดํฐ: ์
๋ ฅ [1, 2, 3, 4, 5, 6, 7, 8, ...]์ ์ถ๋ ฅ [11, 12, 13, 14, 15, 16, 17, 18, ...]์ ์์ฑํด์ผ ํฉ๋๋ค
๊ณผ์ : ํ๋์ ๋ ์ธ๋ง ๋ธ๋ก ๋ ๋ฒจ ๊ณ์ฐ์ ์ํํ๋, ๋ชจ๋ ๋ ์ธ์ด ๊ทธ ๊ฒฐ๊ณผ๋ฅผ ์์ ์ ๊ฐ๋ณ ์ฐ์ฐ์ ์ฌ์ฉํ๋ ค๋ฉด ์ด๋ป๊ฒ ์กฐ์ ํด์ผ ํ ๊น์?
๊ตฌ์ฑ
- ๋ฒกํฐ ํฌ๊ธฐ:
SIZE = WARP_SIZE(GPU์ ๋ฐ๋ผ 32 ๋๋ 64) - ๊ทธ๋ฆฌ๋ ๊ตฌ์ฑ:
(1, 1)๊ทธ๋ฆฌ๋๋น ๋ธ๋ก ์ - ๋ธ๋ก ๊ตฌ์ฑ:
(WARP_SIZE, 1)๋ธ๋ก๋น ์ค๋ ๋ ์ - ๋ฐ์ดํฐ ํ์
:
DType.float32 - ๋ ์ด์์:
Layout.row_major(SIZE)(1D row-major)
์์ฑํ ์ฝ๋
fn basic_broadcast[
layout: Layout, size: Int
](
output: LayoutTensor[dtype, layout, MutAnyOrigin],
input: LayoutTensor[dtype, layout, ImmutAnyOrigin],
):
"""
Basic broadcast: Lane 0 computes a block-local value, broadcasts it to all lanes.
Each lane then uses this broadcast value in its own computation.
"""
global_i = Int(block_dim.x * block_idx.x + thread_idx.x)
lane = Int(lane_id())
if global_i < size:
var broadcast_value: output.element_type = 0.0
# FILL IN (roughly 10 lines)
์ ์ฒด ํ์ผ ๋ณด๊ธฐ: problems/p25/p25.mojo
ํ
1. ๋ธ๋ก๋์บ์คํธ ๋์ ๋ฐฉ์ ์ดํดํ๊ธฐ
broadcast(value) ์ฐ์ฐ์ ๋ ์ธ 0์ ๊ฐ์ ๊ฐ์ ธ์ ์ํ์ ๋ชจ๋ ๋ ์ธ์ ์ ๋ฌํฉ๋๋ค.
ํต์ฌ ํต์ฐฐ: ๋ธ๋ก๋์บ์คํธ์์๋ ๋ ์ธ 0์ ๊ฐ๋ง ์๋ฏธ๊ฐ ์์ต๋๋ค. ๋ค๋ฅธ ๋ ์ธ์ ๊ฐ์ ๋ฌด์๋์ง๋ง, ๋ชจ๋ ๋ ์ธ์ด ๋ ์ธ 0์ ๊ฐ์ ์์ ํฉ๋๋ค.
์๊ฐํ:
๋ธ๋ก๋์บ์คํธ ์ : ๋ ์ธ 0์ valโ, ๋ ์ธ 1์ valโ, ๋ ์ธ 2๋ valโ, ...
๋ธ๋ก๋์บ์คํธ ํ: ๋ ์ธ 0์ valโ, ๋ ์ธ 1์ valโ, ๋ ์ธ 2๋ valโ, ...
์๊ฐํด ๋ณด์ธ์: ๋ ์ธ 0๋ง ๋ธ๋ก๋์บ์คํธํ๋ ค๋ ๊ฐ์ ๊ณ์ฐํ๋๋ก ํ๋ ค๋ฉด ์ด๋ป๊ฒ ํด์ผ ํ ๊น์?
2. ๋ ์ธ๋ณ ๊ณ์ฐ
๋ ์ธ 0์ด ํน๋ณํ ๊ณ์ฐ์ ์ํํ๊ณ ๋ค๋ฅธ ๋ ์ธ์ ๋๊ธฐํ๋๋ก ์๊ณ ๋ฆฌ์ฆ์ ์ค๊ณํฉ๋๋ค.
๊ณ ๋ คํ ํจํด:
var shared_value = ์ด๊ธฐ๊ฐ
if lane == 0:
# ๋ ์ธ 0๋ง ๊ณ์ฐ
shared_value = ํน๋ณํ_๊ณ์ฐ()
# ๋ชจ๋ ๋ ์ธ์ด ๋ธ๋ก๋์บ์คํธ์ ์ฐธ์ฌ
shared_value = broadcast(shared_value)
ํต์ฌ ์ง๋ฌธ:
- ๋ธ๋ก๋์บ์คํธ ์ ์ ๋ค๋ฅธ ๋ ์ธ์ ๊ฐ์ ์ด๋ค ์ํ์ฌ์ผ ํ ๊น์?
- ๋ ์ธ 0์ด ๋ธ๋ก๋์บ์คํธํ ์ฌ๋ฐ๋ฅธ ๊ฐ์ ๊ฐ๋๋ก ํ๋ ค๋ฉด ์ด๋ป๊ฒ ํด์ผ ํ ๊น์?
3. ์งํฉ์ ํ์ฉ
๋ธ๋ก๋์บ์คํธ ํ ๋ชจ๋ ๋ ์ธ์ด ๊ฐ์ ๊ฐ์ ๊ฐ๊ฒ ๋๋ฉฐ, ์ด๋ฅผ ๊ฐ์์ ๊ฐ๋ณ ๊ณ์ฐ์ ํ์ฉํ ์ ์์ต๋๋ค.
์๊ฐํด ๋ณด์ธ์: ๊ฐ ๋ ์ธ์ด ๋ธ๋ก๋์บ์คํธ ๊ฐ๊ณผ ์์ ์ ๋ก์ปฌ ๋ฐ์ดํฐ๋ฅผ ์ด๋ป๊ฒ ๊ฒฐํฉํ ๊น์?
๊ธฐ๋ณธ ๋ธ๋ก๋์บ์คํธ ํ ์คํธ:
pixi run p25 --broadcast-basic
pixi run -e amd p25 --broadcast-basic
pixi run -e apple p25 --broadcast-basic
uv run poe p25 --broadcast-basic
ํ์์ ๋์ ์์ ์ถ๋ ฅ:
WARP_SIZE: 32
SIZE: 32
output: HostBuffer([11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0])
expected: HostBuffer([11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0])
โ
Basic broadcast test passed!
์๋ฃจ์
fn basic_broadcast[
layout: Layout, size: Int
](
output: LayoutTensor[dtype, layout, MutAnyOrigin],
input: LayoutTensor[dtype, layout, ImmutAnyOrigin],
):
"""
Basic broadcast: Lane 0 computes a block-local value, broadcasts it to all lanes.
Each lane then uses this broadcast value in its own computation.
"""
global_i = Int(block_dim.x * block_idx.x + thread_idx.x)
lane = Int(lane_id())
if global_i < size:
# Step 1: Lane 0 computes special value (sum of first 4 elements in this block)
var broadcast_value: output.element_type = 0.0
if lane == 0:
block_start = Int(block_idx.x * block_dim.x)
var sum: output.element_type = 0.0
for i in range(4):
if block_start + i < size:
sum += input[block_start + i]
broadcast_value = sum
# Step 2: Broadcast lane 0's value to all lanes in this warp
broadcast_value = broadcast(broadcast_value)
# Step 3: All lanes use broadcast value in their computation
output[global_i] = broadcast_value + input[global_i]
์ด ์๋ฃจ์ ์ ์ํ ๋ ๋ฒจ ์กฐ์ ์ ์ํ ๊ธฐ๋ณธ ๋ธ๋ก๋์บ์คํธ ํจํด์ ๋ณด์ฌ์ค๋๋ค.
์๊ณ ๋ฆฌ์ฆ ๋ถ์:
if global_i < size:
# ๋จ๊ณ 1: ๋ ์ธ 0์ด ํน๋ณํ ๊ฐ์ ๊ณ์ฐ
var broadcast_value: output.element_type = 0.0
if lane == 0:
# ๋ ์ธ 0๋ง ์ด ๊ณ์ฐ์ ์ํ
block_start = block_idx.x * block_dim.x
var sum: output.element_type = 0.0
for i in range(4):
if block_start + i < size:
sum += input[block_start + i]
broadcast_value = sum
# ๋จ๊ณ 2: ๋ ์ธ 0์ ๊ฐ์ ๋ชจ๋ ๋ ์ธ๊ณผ ๊ณต์
broadcast_value = broadcast(broadcast_value)
# ๋จ๊ณ 3: ๋ชจ๋ ๋ ์ธ์ด ๋ธ๋ก๋์บ์คํธ ๊ฐ์ ํ์ฉ
output[global_i] = broadcast_value + input[global_i]
SIMT ์คํ ์ถ์ :
์ฌ์ดํด 1: ๋ ์ธ๋ณ ๊ณ์ฐ
๋ ์ธ 0: input[0] + input[1] + input[2] + input[3] = 1+2+3+4 = 10์ ๊ณ์ฐ
๋ ์ธ 1: broadcast_value๋ 0.0 ์ ์ง (๋ ์ธ 0์ด ์๋)
๋ ์ธ 2: broadcast_value๋ 0.0 ์ ์ง (๋ ์ธ 0์ด ์๋)
...
๋ ์ธ 31: broadcast_value๋ 0.0 ์ ์ง (๋ ์ธ 0์ด ์๋)
์ฌ์ดํด 2: broadcast(broadcast_value) ์คํ
๋ ์ธ 0: ์์ ์ ๊ฐ ์ ์ง โ broadcast_value = 10.0
๋ ์ธ 1: ๋ ์ธ 0์ ๊ฐ ์์ โ broadcast_value = 10.0
๋ ์ธ 2: ๋ ์ธ 0์ ๊ฐ ์์ โ broadcast_value = 10.0
...
๋ ์ธ 31: ๋ ์ธ 0์ ๊ฐ ์์ โ broadcast_value = 10.0
์ฌ์ดํด 3: ๋ธ๋ก๋์บ์คํธ ๊ฐ์ ํ์ฉํ ๊ฐ๋ณ ๊ณ์ฐ
๋ ์ธ 0: output[0] = 10.0 + input[0] = 10.0 + 1.0 = 11.0
๋ ์ธ 1: output[1] = 10.0 + input[1] = 10.0 + 2.0 = 12.0
๋ ์ธ 2: output[2] = 10.0 + input[2] = 10.0 + 3.0 = 13.0
...
๋ ์ธ 31: output[31] = 10.0 + input[31] = 10.0 + 32.0 = 42.0
๋ธ๋ก๋์บ์คํธ๊ฐ ์ฐ์ํ ์ด์ :
- ์กฐ์ ํจ์จ์ฑ: ๋จ์ผ ์ฐ์ฐ์ผ๋ก ๋ชจ๋ ๋ ์ธ์ ์กฐ์
- ๋ฉ๋ชจ๋ฆฌ ํจ์จ์ฑ: ๊ณต์ ๋ฉ๋ชจ๋ฆฌ ํ ๋น ๋ถํ์
- ๋๊ธฐํ ๋ถํ์: SIMT ์คํ์ด ์๋์ผ๋ก ์กฐ์ ์ ์ฒ๋ฆฌ
- ํ์ฅ ๊ฐ๋ฅํ ํจํด: ์ํ ํฌ๊ธฐ์ ๋ฌด๊ดํ๊ฒ ๋์ผํ๊ฒ ๋์
์ฑ๋ฅ ํน์ฑ:
- ์ง์ฐ ์๊ฐ: ๋ธ๋ก๋์บ์คํธ ์ฐ์ฐ 1 ์ฌ์ดํด
- ๋์ญํญ: 0 ๋ฐ์ดํธ (๋ ์ง์คํฐ ๊ฐ ์ง์ ํต์ )
- ์กฐ์ : 32๊ฐ ๋ ์ธ ๋ชจ๋ ์๋ ๋๊ธฐํ
2. ์กฐ๊ฑด๋ถ ๋ธ๋ก๋์บ์คํธ
๋ ์ธ 0์ด ๋ธ๋ก ๋ฐ์ดํฐ๋ฅผ ๋ถ์ํ๊ณ ๋ชจ๋ ๋ ์ธ์ ์ํฅ์ ๋ฏธ์น๋ ๊ฒฐ์ ์ ๋ด๋ฆฌ๋ ์กฐ๊ฑด๋ถ ์กฐ์ ์ ๊ตฌํํฉ๋๋ค.
์๊ตฌ์ฌํญ:
- ๋ ์ธ 0์ด ํ์ฌ ๋ธ๋ก์ ์ฒ์ 8๊ฐ ์์๋ฅผ ๋ถ์ํ๊ณ ์ต๋๊ฐ์ ์ฐพ์์ผ ํฉ๋๋ค
- ์ด ์ต๋๊ฐ์
broadcast()๋ฅผ ์ฌ์ฉํ์ฌ ๋ค๋ฅธ ๋ชจ๋ ๋ ์ธ์ ์ ๋ฌํด์ผ ํฉ๋๋ค - ๊ฐ ๋ ์ธ์ ์กฐ๊ฑด๋ถ ๋ก์ง์ ์ ์ฉํฉ๋๋ค: ์์ ์ ์์๊ฐ ์ต๋๊ฐ์ ์ ๋ฐ๋ณด๋ค ํฌ๋ฉด 2๋ฐฐ๋ก, ๊ทธ๋ ์ง ์์ผ๋ฉด ์ ๋ฐ์ผ๋ก ๋ง๋ญ๋๋ค
ํ
์คํธ ๋ฐ์ดํฐ: ์
๋ ฅ [3, 1, 7, 2, 9, 4, 6, 8, ...] (๋ฐ๋ณต ํจํด)์ ์ถ๋ ฅ [1.5, 0.5, 14.0, 1.0, 18.0, 2.0, 12.0, 16.0, ...]์ ์์ฑํด์ผ ํฉ๋๋ค
๊ณผ์ : ๋ธ๋ก ๋ ๋ฒจ ๋ถ์๊ณผ ์์๋ณ ์กฐ๊ฑด๋ถ ๋ณํ์ ๋ชจ๋ ๋ ์ธ์ ๊ฑธ์ณ ์ด๋ป๊ฒ ์กฐ์ ํ ๊น์?
๊ตฌ์ฑ
- ๋ฒกํฐ ํฌ๊ธฐ:
SIZE = WARP_SIZE(GPU์ ๋ฐ๋ผ 32 ๋๋ 64) - ๊ทธ๋ฆฌ๋ ๊ตฌ์ฑ:
(1, 1)๊ทธ๋ฆฌ๋๋น ๋ธ๋ก ์ - ๋ธ๋ก ๊ตฌ์ฑ:
(WARP_SIZE, 1)๋ธ๋ก๋น ์ค๋ ๋ ์
์์ฑํ ์ฝ๋
fn conditional_broadcast[
layout: Layout, size: Int
](
output: LayoutTensor[dtype, layout, MutAnyOrigin],
input: LayoutTensor[dtype, layout, ImmutAnyOrigin],
):
"""
Conditional broadcast: Lane 0 makes a decision based on block-local data, broadcasts it to all lanes.
All lanes apply different logic based on the broadcast decision.
"""
global_i = Int(block_dim.x * block_idx.x + thread_idx.x)
lane = Int(lane_id())
if global_i < size:
var decision_value: output.element_type = 0.0
# FILL IN (roughly 10 lines)
current_input = input[global_i]
threshold = decision_value / 2.0
if current_input >= threshold:
output[global_i] = current_input * 2.0 # Double if >= threshold
else:
output[global_i] = current_input / 2.0 # Halve if < threshold
ํ
1. ๋ถ์๊ณผ ์์ฌ๊ฒฐ์
๋ ์ธ 0์ด ์ฌ๋ฌ ๋ฐ์ดํฐ ํฌ์ธํธ๋ฅผ ๋ถ์ํ๊ณ ๋ค๋ฅธ ๋ชจ๋ ๋ ์ธ์ ๋์์ ์๋ดํ ๊ฒฐ์ ์ ๋ด๋ ค์ผ ํฉ๋๋ค.
ํต์ฌ ์ง๋ฌธ:
- ๋ ์ธ 0์ด ์ฌ๋ฌ ์์๋ฅผ ํจ์จ์ ์ผ๋ก ๋ถ์ํ๋ ค๋ฉด ์ด๋ป๊ฒ ํด์ผ ํ ๊น์?
- ๋ ์ธ์ ๋์์ ์กฐ์ ํ๊ธฐ ์ํด ์ด๋ค ์ข ๋ฅ์ ๊ฒฐ์ ์ ๋ธ๋ก๋์บ์คํธํด์ผ ํ ๊น์?
- ๋ฐ์ดํฐ๋ฅผ ๋ถ์ํ ๋ ๊ฒฝ๊ณ ์กฐ๊ฑด์ ์ด๋ป๊ฒ ์ฒ๋ฆฌํ ๊น์?
๊ณ ๋ คํ ํจํด:
var decision = ๊ธฐ๋ณธ๊ฐ
if lane == 0:
# ๋ธ๋ก ๋ก์ปฌ ๋ฐ์ดํฐ ๋ถ์
decision = ๋ถ์_ํ_๊ฒฐ์ ()
decision = broadcast(decision)
2. ์กฐ๊ฑด๋ถ ์คํ ์กฐ์
๋ธ๋ก๋์บ์คํธ๋ ๊ฒฐ์ ์ ์์ ํ ํ, ๋ชจ๋ ๋ ์ธ์ด ๊ทธ ๊ฒฐ์ ์ ๊ธฐ๋ฐํ์ฌ ์๋ก ๋ค๋ฅธ ๋ก์ง์ ์ ์ฉํด์ผ ํฉ๋๋ค.
์๊ฐํด ๋ณด์ธ์:
- ๋ ์ธ์ด ๋ธ๋ก๋์บ์คํธ ๊ฐ์ ์ฌ์ฉํ์ฌ ๋ก์ปฌ ๊ฒฐ์ ์ ๋ด๋ฆฌ๋ ๋ฐฉ๋ฒ์?
- ๊ฐ ์กฐ๊ฑด๋ถ ๋ถ๊ธฐ์์ ์ด๋ค ์ฐ์ฐ์ ์ ์ฉํด์ผ ํ ๊น์?
- ๋ชจ๋ ๋ ์ธ์์ ์ผ๊ด๋ ๋์์ ๋ณด์ฅํ๋ ค๋ฉด ์ด๋ป๊ฒ ํด์ผ ํ ๊น์?
์กฐ๊ฑด๋ถ ํจํด:
if (๋ก์ปฌ_๋ฐ์ดํฐ๊ฐ broadcast_๊ธฐ์ค์ ์ถฉ์กฑ):
# ํ๋์ ๋ณํ ์ ์ฉ
else:
# ๋ค๋ฅธ ๋ณํ ์ ์ฉ
3. ๋ฐ์ดํฐ ๋ถ์ ์ ๋ต
๋ ์ธ 0์ด ์ฌ๋ฌ ๋ฐ์ดํฐ ํฌ์ธํธ๋ฅผ ํจ์จ์ ์ผ๋ก ๋ถ์ํ๋ ๋ฐฉ๋ฒ์ ๊ณ ๋ คํ์ธ์.
๊ณ ๋ คํ ์ ๊ทผ๋ฒ:
- ์ต๋๊ฐ/์ต์๊ฐ ์ฐพ๊ธฐ
- ํ๊ท ์ด๋ ํฉ๊ณ ๊ณ์ฐ
- ํจํด์ด๋ ์๊ณ๊ฐ ๊ฐ์ง
- ๋ฐ์ดํฐ ํน์ฑ์ ๊ธฐ๋ฐํ ์ด์ง ๊ฒฐ์
์กฐ๊ฑด๋ถ ๋ธ๋ก๋์บ์คํธ ํ ์คํธ:
pixi run p25 --broadcast-conditional
pixi run -e amd p25 --broadcast-conditional
uv run poe p25 --broadcast-conditional
ํ์์ ๋์ ์์ ์ถ๋ ฅ:
WARP_SIZE: 32
SIZE: 32
output: HostBuffer([1.5, 0.5, 14.0, 1.0, 18.0, 2.0, 12.0, 16.0, 1.5, 0.5, 14.0, 1.0, 18.0, 2.0, 12.0, 16.0, 1.5, 0.5, 14.0, 1.0, 18.0, 2.0, 12.0, 16.0, 1.5, 0.5, 14.0, 1.0, 18.0, 2.0, 12.0, 16.0])
expected: HostBuffer([1.5, 0.5, 14.0, 1.0, 18.0, 2.0, 12.0, 16.0, 1.5, 0.5, 14.0, 1.0, 18.0, 2.0, 12.0, 16.0, 1.5, 0.5, 14.0, 1.0, 18.0, 2.0, 12.0, 16.0, 1.5, 0.5, 14.0, 1.0, 18.0, 2.0, 12.0, 16.0])
โ
Conditional broadcast test passed!
์๋ฃจ์
fn conditional_broadcast[
layout: Layout, size: Int
](
output: LayoutTensor[dtype, layout, MutAnyOrigin],
input: LayoutTensor[dtype, layout, ImmutAnyOrigin],
):
"""
Conditional broadcast: Lane 0 makes a decision based on block-local data, broadcasts it to all lanes.
All lanes apply different logic based on the broadcast decision.
"""
global_i = Int(block_dim.x * block_idx.x + thread_idx.x)
lane = Int(lane_id())
if global_i < size:
# Step 1: Lane 0 analyzes block-local data and makes decision (find max of first 8 in block)
var decision_value: output.element_type = 0.0
if lane == 0:
block_start = Int(block_idx.x * block_dim.x)
decision_value = input[block_start] if block_start < size else 0.0
for i in range(1, min(8, min(WARP_SIZE, size - block_start))):
if block_start + i < size:
current_val = input[block_start + i]
if current_val > decision_value:
decision_value = current_val
# Step 2: Broadcast decision to all lanes in this warp
decision_value = broadcast(decision_value)
# Step 3: All lanes apply conditional logic based on broadcast decision
current_input = input[global_i]
threshold = decision_value / 2.0
if current_input >= threshold:
output[global_i] = current_input * 2.0 # Double if >= threshold
else:
output[global_i] = current_input / 2.0 # Halve if < threshold
์ด ์๋ฃจ์ ์ ๋ ์ธ ๊ฐ ์กฐ๊ฑด๋ถ ์กฐ์ ์ ์ํ ๊ณ ๊ธ ๋ธ๋ก๋์บ์คํธ ํจํด์ ๋ณด์ฌ์ค๋๋ค.
์ ์ฒด ์๊ณ ๋ฆฌ์ฆ ๋ถ์:
if global_i < size:
# ๋จ๊ณ 1: ๋ ์ธ 0์ด ๋ธ๋ก ๋ฐ์ดํฐ๋ฅผ ๋ถ์ํ๊ณ ๊ฒฐ์ ์ ๋ด๋ฆผ
var decision_value: output.element_type = 0.0
if lane == 0:
# ๋ธ๋ก์ ์ฒ์ 8๊ฐ ์์ ์ค ์ต๋๊ฐ ์ฐพ๊ธฐ
block_start = block_idx.x * block_dim.x
decision_value = input[block_start] if block_start < size else 0.0
for i in range(1, min(8, min(WARP_SIZE, size - block_start))):
if block_start + i < size:
current_val = input[block_start + i]
if current_val > decision_value:
decision_value = current_val
# ๋จ๊ณ 2: ๊ฒฐ์ ์ broadcastํ์ฌ ๋ชจ๋ ๋ ์ธ์ ์กฐ์
decision_value = broadcast(decision_value)
# ๋จ๊ณ 3: ๋ชจ๋ ๋ ์ธ์ด ๋ธ๋ก๋์บ์คํธ์ ๊ธฐ๋ฐํ ์กฐ๊ฑด๋ถ ๋ก์ง์ ์ ์ฉ
current_input = input[global_i]
threshold = decision_value / 2.0
if current_input >= threshold:
output[global_i] = current_input * 2.0 # ์๊ณ๊ฐ ์ด์์ด๋ฉด 2๋ฐฐ
else:
output[global_i] = current_input / 2.0 # ์๊ณ๊ฐ ๋ฏธ๋ง์ด๋ฉด ์ ๋ฐ
์์ฌ๊ฒฐ์ ์คํ ์ถ์ :
์
๋ ฅ ๋ฐ์ดํฐ: [3.0, 1.0, 7.0, 2.0, 9.0, 4.0, 6.0, 8.0, ...]
๋จ๊ณ 1: ๋ ์ธ 0์ด ์ฒ์ 8๊ฐ ์์์ ์ต๋๊ฐ์ ์ฐพ์
๋ ์ธ 0 ๋ถ์:
input[0] = 3.0์ผ๋ก ์์
input[1] = 1.0๊ณผ ๋น๊ต โ 3.0 ์ ์ง
input[2] = 7.0๊ณผ ๋น๊ต โ 7.0์ผ๋ก ๊ฐฑ์
input[3] = 2.0๊ณผ ๋น๊ต โ 7.0 ์ ์ง
input[4] = 9.0๊ณผ ๋น๊ต โ 9.0์ผ๋ก ๊ฐฑ์
input[5] = 4.0๊ณผ ๋น๊ต โ 9.0 ์ ์ง
input[6] = 6.0๊ณผ ๋น๊ต โ 9.0 ์ ์ง
input[7] = 8.0๊ณผ ๋น๊ต โ 9.0 ์ ์ง
์ต์ข
decision_value = 9.0
๋จ๊ณ 2: decision_value = 9.0์ ๋ชจ๋ ๋ ์ธ์ broadcast
๋ชจ๋ ๋ ์ธ: decision_value = 9.0, threshold = 4.5
๋จ๊ณ 3: ๋ ์ธ๋ณ ์กฐ๊ฑด๋ถ ์คํ
๋ ์ธ 0: input[0] = 3.0 < 4.5 โ output[0] = 3.0 / 2.0 = 1.5
๋ ์ธ 1: input[1] = 1.0 < 4.5 โ output[1] = 1.0 / 2.0 = 0.5
๋ ์ธ 2: input[2] = 7.0 โฅ 4.5 โ output[2] = 7.0 * 2.0 = 14.0
๋ ์ธ 3: input[3] = 2.0 < 4.5 โ output[3] = 2.0 / 2.0 = 1.0
๋ ์ธ 4: input[4] = 9.0 โฅ 4.5 โ output[4] = 9.0 * 2.0 = 18.0
๋ ์ธ 5: input[5] = 4.0 < 4.5 โ output[5] = 4.0 / 2.0 = 2.0
๋ ์ธ 6: input[6] = 6.0 โฅ 4.5 โ output[6] = 6.0 * 2.0 = 12.0
๋ ์ธ 7: input[7] = 8.0 โฅ 4.5 โ output[7] = 8.0 * 2.0 = 16.0
...๋๋จธ์ง ๋ ์ธ์ ํจํด ๋ฐ๋ณต
์ํ์ ๊ธฐ๋ฐ: ์๊ณ๊ฐ ๊ธฐ๋ฐ ๋ณํ์ ๊ตฌํํฉ๋๋ค: \[\Large f(x) = \begin{cases} 2x & \text{if } x \geq \tau \\ \frac{x}{2} & \text{if } x < \tau \end{cases}\]
์ฌ๊ธฐ์ \(\tau = \frac{\max(\text{block_data})}{2}\)๋ ๋ธ๋ก๋์บ์คํธ๋ ์๊ณ๊ฐ์ ๋๋ค.
์กฐ์ ํจํด์ ์ฅ์ :
- ์ค์ํ๋ ๋ถ์: ํ๋์ ๋ ์ธ์ด ๋ถ์ํ๊ณ ๋ชจ๋ ๋ ์ธ์ด ํํ์ ๋ฐ์
- ์ผ๊ด๋ ๊ฒฐ์ : ๋ชจ๋ ๋ ์ธ์ด ๊ฐ์ ์๊ณ๊ฐ์ ์ฌ์ฉ
- ์ ์ํ ๋์: ์๊ณ๊ฐ์ด ๋ธ๋ก ๋ก์ปฌ ๋ฐ์ดํฐ ํน์ฑ์ ๋ฐ๋ผ ์ ์
- ํจ์จ์ ์กฐ์ : ๋จ์ผ ๋ธ๋ก๋์บ์คํธ๋ก ๋ณต์กํ ์กฐ๊ฑด๋ถ ๋ก์ง์ ์กฐ์
ํ์ฉ ๋ถ์ผ:
- ์ ์ํ ์๊ณ ๋ฆฌ์ฆ: ๋ก์ปฌ ๋ฐ์ดํฐ ํน์ฑ์ ๋ฐ๋ผ ํ๋ผ๋ฏธํฐ ์กฐ์
- ํ์ง ๊ด๋ฆฌ: ๋ฐ์ดํฐ ํ์ง ์งํ์ ๋ฐ๋ผ ๋ค๋ฅธ ์ฒ๋ฆฌ ์ ์ฉ
- ๋ถํ ๋ถ์ฐ: ๋ธ๋ก ๋ก์ปฌ ๋ณต์ก๋ ๋ถ์์ ๊ธฐ๋ฐํ ์์ ๋ถ๋ฐฐ
3. ๋ธ๋ก๋์บ์คํธ-shuffle ์กฐ์
broadcast()์ shuffle_down()์ ๋ชจ๋ ๊ฒฐํฉํ ๊ณ ๊ธ ์กฐ์ ์ ๊ตฌํํฉ๋๋ค.
์๊ตฌ์ฌํญ:
- ๋ ์ธ 0์ด ๋ธ๋ก์ ์ฒ์ 4๊ฐ ์์์ ํ๊ท ์ ๊ณ์ฐํ๊ณ ์ด ์ค์ผ์ผ๋ง ํฉํฐ๋ฅผ ๋ชจ๋ ๋ ์ธ์ ๋ธ๋ก๋์บ์คํธํด์ผ ํฉ๋๋ค
- ๊ฐ ๋ ์ธ์
shuffle_down(offset=1)์ ์ฌ์ฉํ์ฌ ๋ค์ ์ด์์ ๊ฐ์ ๊ฐ์ ธ์์ผ ํฉ๋๋ค - ๋๋ถ๋ถ์ ๋ ์ธ: ์ค์ผ์ผ๋ง ํฉํฐ์
(ํ์ฌ_๊ฐ + ๋ค์_์ด์_๊ฐ)์ ๊ณฑํฉ๋๋ค - ์ํ์ ๋ง์ง๋ง ๋ ์ธ: ์ค์ผ์ผ๋ง ํฉํฐ์
ํ์ฌ_๊ฐ๋ง ๊ณฑํฉ๋๋ค (์ ํจํ ์ด์ ์์)
ํ
์คํธ ๋ฐ์ดํฐ: ์
๋ ฅ์ [2, 4, 6, 8, 1, 3, 5, 7, ...] ํจํด์ ๋ฐ๋ฆ
๋๋ค (์ฒ์ 4๊ฐ ์์: 2,4,6,8 ์ดํ 1,3,5,7 ๋ฐ๋ณต)
- ๋ ์ธ 0์ด ์ค์ผ์ผ๋ง ํฉํฐ๋ฅผ ๊ณ์ฐ:
(2+4+6+8)/4 = 5.0 - ์์ ์ถ๋ ฅ:
[30.0, 50.0, 70.0, 45.0, 20.0, 40.0, 60.0, 40.0, ...]
๊ณผ์ : ํ๋์ ๋ ์ธ์ ๊ณ์ฐ์ด ๋ชจ๋ ๋ ์ธ์ ์ํฅ์ ๋ฏธ์น๋ฉด์, ๊ฐ ๋ ์ธ์ด ์์ ์ ์ด์ ๋ฐ์ดํฐ์๋ ์ ๊ทผํด์ผ ํ๋ ์ํฉ์์ ์ฌ๋ฌ ์ํ ๊ธฐ๋ณธ ์์๋ฅผ ์ด๋ป๊ฒ ์กฐ์ ํ ๊น์?
๊ตฌ์ฑ
- ๋ฒกํฐ ํฌ๊ธฐ:
SIZE = WARP_SIZE(GPU์ ๋ฐ๋ผ 32 ๋๋ 64) - ๊ทธ๋ฆฌ๋ ๊ตฌ์ฑ:
(1, 1)๊ทธ๋ฆฌ๋๋น ๋ธ๋ก ์ - ๋ธ๋ก ๊ตฌ์ฑ:
(WARP_SIZE, 1)๋ธ๋ก๋น ์ค๋ ๋ ์
์์ฑํ ์ฝ๋
fn broadcast_shuffle_coordination[
layout: Layout, size: Int
](
output: LayoutTensor[dtype, layout, MutAnyOrigin],
input: LayoutTensor[dtype, layout, ImmutAnyOrigin],
):
"""
Combine broadcast() and shuffle_down() for advanced warp coordination.
Lane 0 computes block-local scaling factor, broadcasts it to all lanes in the warp.
Each lane uses shuffle_down() for neighbor access and applies broadcast factor.
"""
global_i = Int(block_dim.x * block_idx.x + thread_idx.x)
lane = Int(lane_id())
if global_i < size:
var scale_factor: output.element_type = 0.0
# FILL IN (roughly 14 lines)
ํ
1. ๋ค์ค ๊ธฐ๋ณธ ์์ ์กฐ์
์ด ํผ์ฆ์ broadcast์ ์ ํ ์ฐ์ฐ์ ์์๋๋ก ์กฐ์จํด์ผ ํฉ๋๋ค.
ํ๋ฆ์ ์๊ฐํด ๋ณด์ธ์:
- ํ๋์ ๋ ์ธ์ด ์ ์ฒด ์ํ๋ฅผ ์ํ ๊ฐ์ ๊ณ์ฐ
- ์ด ๊ฐ์ด ๋ชจ๋ ๋ ์ธ์ broadcast๋จ
- ๊ฐ ๋ ์ธ์ด ์ ํ๋ก ์ด์ ๋ฐ์ดํฐ์ ์ ๊ทผ
- ๋ธ๋ก๋์บ์คํธ ๊ฐ์ด ์ด์ ๋ฐ์ดํฐ์ ์ฒ๋ฆฌ ๋ฐฉ์์ ์ํฅ
์กฐ์ ํจํด:
# ๋จ๊ณ 1: ๋ธ๋ก๋์บ์คํธ ์กฐ์
var shared_param = lane_0์ด๋ฉด_๊ณ์ฐ()
shared_param = broadcast(shared_param)
# ๋จ๊ณ 2: ์
ํ ์ด์ ์ ๊ทผ
current_val = input[global_i]
neighbor_val = shuffle_down(current_val, offset)
# ๋จ๊ณ 3: ๊ฒฐํฉ ๊ณ์ฐ
result = ๊ฒฐํฉ(current_val, neighbor_val, shared_param)
2. ํ๋ผ๋ฏธํฐ ๊ณ์ฐ ์ ๋ต
์ด์ ์ฐ์ฐ์ ์ค์ผ์ผ๋งํ๋ ๋ฐ ์ ์ฉํ ๋ธ๋ก ๋ ๋ฒจ ํ๋ผ๋ฏธํฐ๊ฐ ๋ฌด์์ผ์ง ๊ณ ๋ คํ์ธ์.
ํ๊ตฌํ ์ง๋ฌธ:
- ๋ ์ธ 0์ด ๋ธ๋ก ๋ฐ์ดํฐ์์ ์ด๋ค ํต๊ณ๋ฅผ ๊ณ์ฐํด์ผ ํ ๊น์?
- ์ด ํ๋ผ๋ฏธํฐ๊ฐ ์ด์ ๊ธฐ๋ฐ ๊ณ์ฐ์ ์ด๋ค ์ํฅ์ ๋ฏธ์ณ์ผ ํ ๊น์?
- ์ ํ ์ฐ์ฐ์ด ํฌํจ๋ ๋ ์ํ ๊ฒฝ๊ณ์์ ๋ฌด์จ ์ผ์ด ์ผ์ด๋ ๊น์?
3. ๊ฒฐํฉ ์ฐ์ฐ ์ค๊ณ
๋ธ๋ก๋์บ์คํธ ํ๋ผ๋ฏธํฐ์ ์ ํ ๊ธฐ๋ฐ ์ด์ ์ ๊ทผ์ ์๋ฏธ ์๊ฒ ๊ฒฐํฉํ๋ ๋ฐฉ๋ฒ์ ์๊ฐํ์ธ์.
ํจํด ๊ณ ๋ ค์ฌํญ:
- ๋ธ๋ก๋์บ์คํธ ํ๋ผ๋ฏธํฐ๊ฐ ์ ๋ ฅ, ์ถ๋ ฅ, ๋๋ ๊ณ์ฐ์ ์ค์ผ์ผ๋งํด์ผ ํ ๊น์?
- ์ ํ์ด ๋ฏธ์ ์ ๋ฐ์ดํฐ๋ฅผ ๋ฐํํ๋ ๊ฒฝ๊ณ ์ผ์ด์ค๋ฅผ ์ด๋ป๊ฒ ์ฒ๋ฆฌํ ๊น์?
- ๊ฐ์ฅ ํจ์จ์ ์ธ ์ฐ์ฐ ์์๋ ๋ฌด์์ผ๊น์?
๋ธ๋ก๋์บ์คํธ-shuffle ์กฐ์ ํ ์คํธ:
pixi run p25 --broadcast-shuffle-coordination
pixi run -e amd p25 --broadcast-shuffle-coordination
uv run poe p25 --broadcast-shuffle-coordination
ํ์์ ๋์ ์์ ์ถ๋ ฅ:
WARP_SIZE: 32
SIZE: 32
output: HostBuffer([30.0, 50.0, 70.0, 45.0, 20.0, 40.0, 60.0, 40.0, 20.0, 40.0, 60.0, 40.0, 20.0, 40.0, 60.0, 40.0, 20.0, 40.0, 60.0, 40.0, 20.0, 40.0, 60.0, 40.0, 20.0, 40.0, 60.0, 40.0, 20.0, 40.0, 60.0, 35.0])
expected: HostBuffer([30.0, 50.0, 70.0, 45.0, 20.0, 40.0, 60.0, 40.0, 20.0, 40.0, 60.0, 40.0, 20.0, 40.0, 60.0, 40.0, 20.0, 40.0, 60.0, 40.0, 20.0, 40.0, 60.0, 40.0, 20.0, 40.0, 60.0, 40.0, 20.0, 40.0, 60.0, 35.0])
โ
๋ธ๋ก๋์บ์คํธ + ์
ํ coordination test passed!
์๋ฃจ์
fn broadcast_shuffle_coordination[
layout: Layout, size: Int
](
output: LayoutTensor[dtype, layout, MutAnyOrigin],
input: LayoutTensor[dtype, layout, ImmutAnyOrigin],
):
"""
Combine broadcast() and shuffle_down() for advanced warp coordination.
Lane 0 computes block-local scaling factor, broadcasts it to all lanes in the warp.
Each lane uses shuffle_down() for neighbor access and applies broadcast factor.
"""
global_i = Int(block_dim.x * block_idx.x + thread_idx.x)
lane = Int(lane_id())
if global_i < size:
# Step 1: Lane 0 computes block-local scaling factor
var scale_factor: output.element_type = 0.0
if lane == 0:
# Compute average of first 4 elements in this block's data
block_start = Int(block_idx.x * block_dim.x)
var sum: output.element_type = 0.0
for i in range(4):
if block_start + i < size:
sum += input[block_start + i]
scale_factor = sum / 4.0
# Step 2: Broadcast scaling factor to all lanes in this warp
scale_factor = broadcast(scale_factor)
# Step 3: Each lane gets current and next values
current_val = input[global_i]
next_val = shuffle_down(current_val, 1)
# Step 4: Apply broadcast factor with neighbor coordination
if lane < WARP_SIZE - 1 and global_i < size - 1:
# Combine current + next, then scale by broadcast factor
output[global_i] = (current_val + next_val) * scale_factor
else:
# Last lane in warp or last element: only current value, scaled by broadcast factor
output[global_i] = current_val * scale_factor
์ด ์๋ฃจ์ ์ broadcast์ ์ ํ ๊ธฐ๋ณธ ์์๋ฅผ ๊ฒฐํฉํ ๊ฐ์ฅ ๊ณ ๊ธ ์ํ ์กฐ์ ํจํด์ ๋ณด์ฌ์ค๋๋ค.
์ ์ฒด ์๊ณ ๋ฆฌ์ฆ ๋ถ์:
if global_i < size:
# ๋จ๊ณ 1: ๋ ์ธ 0์ด ๋ธ๋ก ๋ก์ปฌ ์ค์ผ์ผ๋ง ํฉํฐ๋ฅผ ๊ณ์ฐ
var scale_factor: output.element_type = 0.0
if lane == 0:
block_start = block_idx.x * block_dim.x
var sum: output.element_type = 0.0
for i in range(4):
if block_start + i < size:
sum += input[block_start + i]
scale_factor = sum / 4.0
# ๋จ๊ณ 2: ์ค์ผ์ผ๋ง ํฉํฐ๋ฅผ ๋ชจ๋ ๋ ์ธ์ broadcast
scale_factor = broadcast(scale_factor)
# ๋จ๊ณ 3: ๊ฐ ๋ ์ธ์ด shuffle์ ํตํด ํ์ฌ ๊ฐ๊ณผ ๋ค์ ๊ฐ์ ๊ฐ์ ธ์ด
current_val = input[global_i]
next_val = shuffle_down(current_val, 1)
# ๋จ๊ณ 4: ๋ธ๋ก๋์บ์คํธ ํฉํฐ๋ฅผ ์ด์ ์กฐ์ ๊ณผ ๊ฒฐํฉํ์ฌ ์ ์ฉ
if lane < WARP_SIZE - 1 and global_i < size - 1:
output[global_i] = (current_val + next_val) * scale_factor
else:
output[global_i] = current_val * scale_factor
๋ค์ค ๊ธฐ๋ณธ ์์ ์คํ ์ถ์ :
์
๋ ฅ ๋ฐ์ดํฐ: [2, 4, 6, 8, 1, 3, 5, 7, ...]
๋จ๊ณ 1: ๋ ์ธ 0์ด ์ค์ผ์ผ๋ง ํฉํฐ๋ฅผ ๊ณ์ฐ
๋ ์ธ 0 ๊ณ์ฐ: (input[0] + input[1] + input[2] + input[3]) / 4
= (2 + 4 + 6 + 8) / 4 = 20 / 4 = 5.0
๋ค๋ฅธ ๋ ์ธ: scale_factor๋ 0.0 ์ ์ง
๋จ๊ณ 2: scale_factor = 5.0์ ๋ชจ๋ ๋ ์ธ์ broadcast
๋ชจ๋ ๋ ์ธ: scale_factor = 5.0
๋จ๊ณ 3: ์ด์ ์ ๊ทผ์ ์ํ ์
ํ ์ฐ์ฐ
๋ ์ธ 0: current_val = input[0] = 2, next_val = shuffle_down(2, 1) = input[1] = 4
๋ ์ธ 1: current_val = input[1] = 4, next_val = shuffle_down(4, 1) = input[2] = 6
๋ ์ธ 2: current_val = input[2] = 6, next_val = shuffle_down(6, 1) = input[3] = 8
๋ ์ธ 3: current_val = input[3] = 8, next_val = shuffle_down(8, 1) = input[4] = 1
...
๋ ์ธ 31: current_val = input[31], next_val = ๋ฏธ์ ์
๋จ๊ณ 4: ๋ธ๋ก๋์บ์คํธ ์ค์ผ์ผ๋ง๊ณผ ๊ฒฐํฉํ ๊ณ์ฐ
๋ ์ธ 0: output[0] = (2 + 4) * 5.0 = 6 * 5.0 = 30.0
๋ ์ธ 1: output[1] = (4 + 6) * 5.0 = 10 * 5.0 = 50.0
๋ ์ธ 2: output[2] = (6 + 8) * 5.0 = 14 * 5.0 = 70.0
๋ ์ธ 3: output[3] = (8 + 1) * 5.0 = 9 * 5.0 = 45.0
...
๋ ์ธ 31: output[31] = 7 * 5.0 = 35.0 (๊ฒฝ๊ณ - ์ด์ ์์)
ํต์ ํจํด ๋ถ์: ์ด ์๊ณ ๋ฆฌ์ฆ์ ๊ณ์ธต์ ์กฐ์ ํจํด์ ๊ตฌํํฉ๋๋ค:
- ์์ง ์กฐ์ (broadcast): ๋ ์ธ 0 โ ๋ชจ๋ ๋ ์ธ
- ์ํ ์กฐ์ (shuffle): ๋ ์ธ i โ ๋ ์ธ i+1
- ๊ฒฐํฉ ๊ณ์ฐ: ๋ธ๋ก๋์บ์คํธ ๋ฐ์ดํฐ์ ์ ํ ๋ฐ์ดํฐ๋ฅผ ๋ชจ๋ ํ์ฉ
์ํ์ ๊ธฐ๋ฐ: \[\Large \text{output}[i] = \begin{cases} (\text{input}[i] + \text{input}[i+1]) \cdot \beta & \text{if lane } i < \text{WARP_SIZE} - 1 \\ \text{input}[i] \cdot \beta & \text{if lane } i = \text{WARP_SIZE} - 1 \end{cases}\]
์ฌ๊ธฐ์ \(\beta = \frac{1}{4}\sum_{k=0}^{3} \text{input}[\text{block_start} + k]\)๋ ๋ธ๋ก๋์บ์คํธ๋ ์ค์ผ์ผ๋ง ํฉํฐ์ ๋๋ค.
๊ณ ๊ธ ์กฐ์ ์ ์ฅ์ :
- ๋ค๋จ๊ณ ํต์ : ์ ์ญ(broadcast)๊ณผ ์ง์ญ(shuffle) ์กฐ์ ์ ๊ฒฐํฉ
- ์ ์ํ ์ค์ผ์ผ๋ง: ๋ธ๋ก ๋ ๋ฒจ ํ๋ผ๋ฏธํฐ๊ฐ ์ด์ ์ฐ์ฐ์ ์ํฅ
- ํจ์จ์ ๊ตฌ์ฑ: ๋ ๊ธฐ๋ณธ ์์๊ฐ ๋งค๋๋ฝ๊ฒ ํ๋ ฅ
- ๋ณต์กํ ์๊ณ ๋ฆฌ์ฆ ๊ตฌํ: ์ ๊ตํ ๋ณ๋ ฌ ์๊ณ ๋ฆฌ์ฆ์ ๊ฐ๋ฅํ๊ฒ ํจ
์ค์ ํ์ฉ ์ฌ๋ก:
- ์ ์ํ ํํฐ๋ง: ๋ธ๋ก ๋ ๋ฒจ ๋ ธ์ด์ฆ ์ถ์ ๊ณผ ์ด์ ๊ธฐ๋ฐ ํํฐ๋ง
- ๋์ ๋ถํ ๋ถ์ฐ: ์ ์ญ ์์ ๋ถ๋ฐฐ์ ๋ก์ปฌ ์กฐ์
- ๋ค์ค ์ค์ผ์ผ ์ฒ๋ฆฌ: ์ ์ญ ํ๋ผ๋ฏธํฐ๊ฐ ๋ก์ปฌ ์คํ ์ค ์ฐ์ฐ์ ์ ์ด
์์ฝ
์ด ์น์ ์ ํต์ฌ ํจํด์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค
var shared_value = initial_value
if lane == 0:
shared_value = compute_block_statistic()
shared_value = broadcast(shared_value)
result = use_shared_value(shared_value, local_data)
ํต์ฌ ์ฅ์ :
- ์ผ๋๋ค ์กฐ์ : ํ๋์ ๋ ์ธ์ด ๊ณ์ฐํ๊ณ ๋ชจ๋ ๋ ์ธ์ด ํํ์ ๋ฐ์
- ๋๊ธฐํ ์ค๋ฒํค๋ ์ ๋ก: SIMT ์คํ์ด ์กฐ์ ์ ์ฒ๋ฆฌ
- ์กฐํฉ ๊ฐ๋ฅํ ํจํด: ์ ํ๊ณผ ๋ค๋ฅธ ์ํ ์ฐ์ฐ๊ณผ ์ฝ๊ฒ ๊ฒฐํฉ
ํ์ฉ ๋ถ์ผ: ๋ธ๋ก ํต๊ณ, ์งํฉ์ ์์ฌ๊ฒฐ์ , ํ๋ผ๋ฏธํฐ ๊ณต์ , ์ ์ํ ์๊ณ ๋ฆฌ์ฆ.