warp.broadcast() One-to-Many Communication

For warp-level coordination, broadcast() shares data from one lane with every other lane in the warp. This primitive makes block-level computations, conditional logic coordination, and one-to-many communication patterns efficient, without shared memory or explicit synchronization.

Key insight: The broadcast() operation leverages SIMT execution to deliver a value computed by one lane (typically lane 0) to all lanes in the same warp, enabling efficient coordination patterns and collective decision-making.

What are broadcast operations? A broadcast operation is a communication pattern in which one thread computes a value and shares it with all other threads in a group. It is essential for coordination tasks such as computing block-level statistics, making collective decisions, or passing configuration parameters to every thread in a warp.

Key concepts

In this puzzle, you'll learn:

  • Warp-level broadcasting with broadcast()
  • One-to-many communication patterns
  • Collective computation strategies
  • Conditional coordination across lanes
  • Combined broadcast-shuffle operations

The broadcast() operation lets one lane (lane 0 by default) share its value with all other lanes: \[\Large \text{broadcast}(\text{value}) = \text{value_from_lane_0_to_all_lanes}\]

This turns complex coordination patterns into simple warp-level operations, enabling efficient collective computation without explicit synchronization.
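To make the definition concrete, here is a minimal host-side sketch in plain Python (not GPU code). It models a warp as a list with one entry per lane. The helper name `broadcast_model` and its `src_lane` parameter are illustrative assumptions, not part of the Mojo API.

```python
def broadcast_model(lane_values, src_lane=0):
    """Model of a warp broadcast: every lane receives the source lane's value."""
    return [lane_values[src_lane] for _ in lane_values]


# Lane 0 holds the computed value; whatever the other lanes hold is ignored.
before = [10.0, 99.0, -5.0, 0.0]
after = broadcast_model(before)
print(after)  # [10.0, 10.0, 10.0, 10.0]
```

The list comprehension stands in for what the hardware does in one warp-level operation: no lane writes to memory, each lane simply ends up with the source lane's value.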

The broadcast concept

Traditional coordination requires complex shared memory patterns:

# Traditional approach - complex and error-prone
shared_memory[lane] = local_computation()
sync_threads()  # Expensive synchronization
if lane == 0:
    # Lane 0 publishes its result in slot 0 so the other threads can read it
    shared_memory[0] = compute_from_shared_memory()
sync_threads()  # Another expensive synchronization
final_result = shared_memory[0]  # All threads read

Problems with the traditional approach:

  • Memory overhead: requires a shared memory allocation
  • Synchronization: needs multiple expensive barrier operations
  • Complex logic: shared memory indices and access patterns must be managed
  • Error-prone: race conditions are easy to introduce

With broadcast(), the coordination becomes concise:

# Warp broadcast approach - simple and safe
collective_value = 0
if lane == 0:
    collective_value = compute_block_statistic()
collective_value = broadcast(collective_value)  # Share with all lanes
result = use_collective_value(collective_value)

Benefits of broadcast:

  • Zero memory overhead: no shared memory required
  • Automatic synchronization: SIMT execution guarantees correctness
  • Simple pattern: one lane computes, all lanes receive
  • Composable: combines easily with other warp operations

1. Basic broadcast

Implement a basic broadcast pattern in which lane 0 computes a block-level statistic and shares it with all lanes.

Requirements:

  • Lane 0 must compute the sum of the first 4 elements of the current block
  • This computed value must be shared with all other lanes in the warp using broadcast()
  • Each lane must then add this shared value to its own input element

Test data: Input [1, 2, 3, 4, 5, 6, 7, 8, ...] should produce output [11, 12, 13, 14, 15, 16, 17, 18, ...]

Challenge: How do you coordinate so that only one lane performs the block-level computation, yet every lane uses the result in its own operation?

Configuration

  • Vector size: SIZE = WARP_SIZE (32 or 64 depending on the GPU)
  • Grid configuration: (1, 1) blocks per grid
  • Block configuration: (WARP_SIZE, 1) threads per block
  • Data type: DType.float32
  • Layout: Layout.row_major(SIZE) (1D row-major)

Code to complete

fn basic_broadcast[
    layout: Layout, size: Int
](
    output: LayoutTensor[dtype, layout, MutAnyOrigin],
    input: LayoutTensor[dtype, layout, ImmutAnyOrigin],
):
    """
    Basic broadcast: Lane 0 computes a block-local value, broadcasts it to all lanes.
    Each lane then uses this broadcast value in its own computation.
    """
    global_i = Int(block_dim.x * block_idx.x + thread_idx.x)
    lane = Int(lane_id())
    if global_i < size:
        var broadcast_value: output.element_type = 0.0

        # FILL IN (roughly 10 lines)


View full file: problems/p25/p25.mojo

Tips

1. Understanding broadcast mechanics

The broadcast(value) operation takes the value from lane 0 and delivers it to all lanes in the warp.

Key insight: Only lane 0's value matters for the broadcast. Other lanes' values are ignored, but every lane receives lane 0's value.

Visualization:

Before broadcast: Lane 0 has val₀, Lane 1 has val₁, Lane 2 has val₂, ...
After broadcast:  Lane 0 has val₀, Lane 1 has val₀, Lane 2 has val₀, ...

Think about: How can you ensure that only lane 0 computes the value you want to broadcast?

2. Lane-specific computation

Design your algorithm so that lane 0 performs the special computation while the other lanes wait.

Pattern to consider:

var shared_value = initial_value
if lane == 0:
    # Only lane 0 computes
    shared_value = special_computation()
# All lanes participate in the broadcast
shared_value = broadcast(shared_value)

Critical questions:

  • What should the other lanes' values be before the broadcast?
  • How do you ensure lane 0 has the correct value to broadcast?

3. Collective usage

After the broadcast, all lanes hold the same value and can use it in their own individual computations.

Think about: How does each lane combine the broadcast value with its own local data?

Test the basic broadcast:

pixi run p25 --broadcast-basic
pixi run -e amd p25 --broadcast-basic
pixi run -e apple p25 --broadcast-basic
uv run poe p25 --broadcast-basic

Expected output when solved:

WARP_SIZE:  32
SIZE:  32
output: HostBuffer([11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0])
expected: HostBuffer([11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0])
✅ Basic broadcast test passed!

Solution

fn basic_broadcast[
    layout: Layout, size: Int
](
    output: LayoutTensor[dtype, layout, MutAnyOrigin],
    input: LayoutTensor[dtype, layout, ImmutAnyOrigin],
):
    """
    Basic broadcast: Lane 0 computes a block-local value, broadcasts it to all lanes.
    Each lane then uses this broadcast value in its own computation.
    """
    global_i = Int(block_dim.x * block_idx.x + thread_idx.x)
    lane = Int(lane_id())

    if global_i < size:
        # Step 1: Lane 0 computes special value (sum of first 4 elements in this block)
        var broadcast_value: output.element_type = 0.0
        if lane == 0:
            block_start = Int(block_idx.x * block_dim.x)
            var sum: output.element_type = 0.0
            for i in range(4):
                if block_start + i < size:
                    sum += input[block_start + i]
            broadcast_value = sum

        # Step 2: Broadcast lane 0's value to all lanes in this warp
        broadcast_value = broadcast(broadcast_value)

        # Step 3: All lanes use broadcast value in their computation
        output[global_i] = broadcast_value + input[global_i]


This solution demonstrates the basic broadcast pattern for warp-level coordination.

Algorithm breakdown:

if global_i < size:
    # Step 1: Lane 0 computes the special value
    var broadcast_value: output.element_type = 0.0
    if lane == 0:
        # Only lane 0 performs this computation
        block_start = Int(block_idx.x * block_dim.x)
        var sum: output.element_type = 0.0
        for i in range(4):
            if block_start + i < size:
                sum += input[block_start + i]
        broadcast_value = sum

    # Step 2: Share lane 0's value with all lanes
    broadcast_value = broadcast(broadcast_value)

    # Step 3: All lanes use the broadcast value
    output[global_i] = broadcast_value + input[global_i]

SIMT execution trace:

Step 1: Lane-specific computation
  Lane 0: Computes input[0] + input[1] + input[2] + input[3] = 1+2+3+4 = 10
  Lane 1: broadcast_value stays 0.0 (not lane 0)
  Lane 2: broadcast_value stays 0.0 (not lane 0)
  ...
  Lane 31: broadcast_value stays 0.0 (not lane 0)

Step 2: broadcast(broadcast_value) executes
  Lane 0: Keeps its own value → broadcast_value = 10.0
  Lane 1: Receives lane 0's value → broadcast_value = 10.0
  Lane 2: Receives lane 0's value → broadcast_value = 10.0
  ...
  Lane 31: Receives lane 0's value → broadcast_value = 10.0

Step 3: Individual computation using the broadcast value
  Lane 0: output[0] = 10.0 + input[0] = 10.0 + 1.0 = 11.0
  Lane 1: output[1] = 10.0 + input[1] = 10.0 + 2.0 = 12.0
  Lane 2: output[2] = 10.0 + input[2] = 10.0 + 3.0 = 13.0
  ...
  Lane 31: output[31] = 10.0 + input[31] = 10.0 + 32.0 = 42.0
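The trace above can be checked with a small host-side reference in plain Python. This is a sketch of the arithmetic only, not GPU code: the function name `basic_broadcast_reference` is hypothetical, and `WARP_SIZE = 32` is an assumption (some GPUs use 64).

```python
WARP_SIZE = 32  # assumed warp size; some GPUs use 64


def basic_broadcast_reference(inp):
    """Host-side reference for the three traced steps, for a single warp."""
    # Step 1: only lane 0 computes the sum of the first 4 elements.
    per_lane = [0.0] * len(inp)
    per_lane[0] = sum(inp[:4])
    # Step 2: every lane receives lane 0's value.
    shared = [per_lane[0]] * len(inp)
    # Step 3: each lane adds the shared value to its own element.
    return [shared[lane] + inp[lane] for lane in range(len(inp))]


inp = [float(i + 1) for i in range(WARP_SIZE)]
out = basic_broadcast_reference(inp)
print(out[:4], out[-1])  # [11.0, 12.0, 13.0, 14.0] 42.0
```

A reference like this is handy for comparing against the kernel's output buffer when debugging.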

Why broadcast is superior:

  1. Coordination efficiency: a single operation coordinates all lanes
  2. Memory efficiency: no shared memory allocation required
  3. Synchronization-free: SIMT execution handles the coordination automatically
  4. Scalable pattern: works the same way regardless of warp size

Performance characteristics:

  • Latency: a single warp-level broadcast operation, with no shared-memory round trip
  • Bandwidth: 0 bytes of memory traffic (direct register-to-register communication)
  • Coordination: all lanes of the warp (32 or 64, depending on the GPU) are synchronized automatically

2. Conditional broadcast

Implement conditional coordination in which lane 0 analyzes block data and makes a decision that affects all lanes.

Requirements:

  • Lane 0 must analyze the first 8 elements of the current block and find their maximum
  • This maximum must be delivered to all other lanes using broadcast()
  • Each lane then applies conditional logic: it doubles its element if the element is at least half the maximum (>=, matching the provided code), and halves it otherwise

Test data: Input [3, 1, 7, 2, 9, 4, 6, 8, ...] (repeating pattern) should produce output [1.5, 0.5, 14.0, 1.0, 18.0, 2.0, 12.0, 16.0, ...]

Challenge: How do you coordinate block-level analysis with an element-wise conditional transformation across all lanes?

Configuration

  • Vector size: SIZE = WARP_SIZE (32 or 64 depending on the GPU)
  • Grid configuration: (1, 1) blocks per grid
  • Block configuration: (WARP_SIZE, 1) threads per block

Code to complete

fn conditional_broadcast[
    layout: Layout, size: Int
](
    output: LayoutTensor[dtype, layout, MutAnyOrigin],
    input: LayoutTensor[dtype, layout, ImmutAnyOrigin],
):
    """
    Conditional broadcast: Lane 0 makes a decision based on block-local data, broadcasts it to all lanes.
    All lanes apply different logic based on the broadcast decision.
    """
    global_i = Int(block_dim.x * block_idx.x + thread_idx.x)
    lane = Int(lane_id())
    if global_i < size:
        var decision_value: output.element_type = 0.0

        # FILL IN (roughly 10 lines)

        current_input = input[global_i]
        threshold = decision_value / 2.0
        if current_input >= threshold:
            output[global_i] = current_input * 2.0  # Double if >= threshold
        else:
            output[global_i] = current_input / 2.0  # Halve if < threshold


Tips

1. Analysis and decision-making

Lane 0 needs to analyze multiple data points and make a decision that guides the behavior of all other lanes.

Key questions:

  • How can lane 0 analyze multiple elements efficiently?
  • What kind of decision should be broadcast to coordinate lane behavior?
  • How do you handle boundary conditions while analyzing the data?

Pattern to consider:

var decision = default_value
if lane == 0:
    # Analyze block-local data
    decision = analyze_and_decide()
decision = broadcast(decision)

2. Conditional execution coordination

After receiving the broadcast decision, all lanes need to apply different logic based on it.

Think about:

  • How do lanes use the broadcast value to make local decisions?
  • Which operation should be applied in each conditional branch?
  • How do you ensure consistent behavior across all lanes?

Conditional pattern:

if (local_data meets broadcast_criterion):
    # Apply one transformation
else:
    # Apply a different transformation

3. Data analysis strategies

Consider how lane 0 can analyze multiple data points efficiently.

Approaches to consider:

  • Finding maximum/minimum values
  • Computing averages or sums
  • Detecting patterns or thresholds
  • Making binary decisions based on data characteristics

Test the conditional broadcast:

pixi run p25 --broadcast-conditional
pixi run -e amd p25 --broadcast-conditional
uv run poe p25 --broadcast-conditional

Expected output when solved:

WARP_SIZE:  32
SIZE:  32
output: HostBuffer([1.5, 0.5, 14.0, 1.0, 18.0, 2.0, 12.0, 16.0, 1.5, 0.5, 14.0, 1.0, 18.0, 2.0, 12.0, 16.0, 1.5, 0.5, 14.0, 1.0, 18.0, 2.0, 12.0, 16.0, 1.5, 0.5, 14.0, 1.0, 18.0, 2.0, 12.0, 16.0])
expected: HostBuffer([1.5, 0.5, 14.0, 1.0, 18.0, 2.0, 12.0, 16.0, 1.5, 0.5, 14.0, 1.0, 18.0, 2.0, 12.0, 16.0, 1.5, 0.5, 14.0, 1.0, 18.0, 2.0, 12.0, 16.0, 1.5, 0.5, 14.0, 1.0, 18.0, 2.0, 12.0, 16.0])
✅ Conditional broadcast test passed!

Solution

fn conditional_broadcast[
    layout: Layout, size: Int
](
    output: LayoutTensor[dtype, layout, MutAnyOrigin],
    input: LayoutTensor[dtype, layout, ImmutAnyOrigin],
):
    """
    Conditional broadcast: Lane 0 makes a decision based on block-local data, broadcasts it to all lanes.
    All lanes apply different logic based on the broadcast decision.
    """
    global_i = Int(block_dim.x * block_idx.x + thread_idx.x)
    lane = Int(lane_id())

    if global_i < size:
        # Step 1: Lane 0 analyzes block-local data and makes decision (find max of first 8 in block)
        var decision_value: output.element_type = 0.0
        if lane == 0:
            block_start = Int(block_idx.x * block_dim.x)
            decision_value = input[block_start] if block_start < size else 0.0
            for i in range(1, min(8, min(WARP_SIZE, size - block_start))):
                if block_start + i < size:
                    current_val = input[block_start + i]
                    if current_val > decision_value:
                        decision_value = current_val

        # Step 2: Broadcast decision to all lanes in this warp
        decision_value = broadcast(decision_value)

        # Step 3: All lanes apply conditional logic based on broadcast decision
        current_input = input[global_i]
        threshold = decision_value / 2.0
        if current_input >= threshold:
            output[global_i] = current_input * 2.0  # Double if >= threshold
        else:
            output[global_i] = current_input / 2.0  # Halve if < threshold


This solution demonstrates an advanced broadcast pattern for conditional coordination across lanes.

Complete algorithm analysis:

if global_i < size:
    # Step 1: Lane 0 analyzes the block data and makes a decision
    var decision_value: output.element_type = 0.0
    if lane == 0:
        # Find the maximum among the first 8 elements of the block
        block_start = Int(block_idx.x * block_dim.x)
        decision_value = input[block_start] if block_start < size else 0.0
        for i in range(1, min(8, min(WARP_SIZE, size - block_start))):
            if block_start + i < size:
                current_val = input[block_start + i]
                if current_val > decision_value:
                    decision_value = current_val

    # Step 2: Broadcast the decision to coordinate all lanes
    decision_value = broadcast(decision_value)

    # Step 3: All lanes apply conditional logic based on the broadcast
    current_input = input[global_i]
    threshold = decision_value / 2.0
    if current_input >= threshold:
        output[global_i] = current_input * 2.0  # Double if >= threshold
    else:
        output[global_i] = current_input / 2.0  # Halve if < threshold

Decision execution trace:

Input data: [3.0, 1.0, 7.0, 2.0, 9.0, 4.0, 6.0, 8.0, ...]

Step 1: Lane 0 finds the maximum of the first 8 elements
  Lane 0 analysis:
    Start with input[0] = 3.0
    Compare with input[1] = 1.0 → keep 3.0
    Compare with input[2] = 7.0 → update to 7.0
    Compare with input[3] = 2.0 → keep 7.0
    Compare with input[4] = 9.0 → update to 9.0
    Compare with input[5] = 4.0 → keep 9.0
    Compare with input[6] = 6.0 → keep 9.0
    Compare with input[7] = 8.0 → keep 9.0
    Final decision_value = 9.0

Step 2: Broadcast decision_value = 9.0 to all lanes
  All lanes: decision_value = 9.0, threshold = 4.5

Step 3: Per-lane conditional execution
  Lane 0: input[0] = 3.0 < 4.5 → output[0] = 3.0 / 2.0 = 1.5
  Lane 1: input[1] = 1.0 < 4.5 → output[1] = 1.0 / 2.0 = 0.5
  Lane 2: input[2] = 7.0 ≥ 4.5 → output[2] = 7.0 * 2.0 = 14.0
  Lane 3: input[3] = 2.0 < 4.5 → output[3] = 2.0 / 2.0 = 1.0
  Lane 4: input[4] = 9.0 ≥ 4.5 → output[4] = 9.0 * 2.0 = 18.0
  Lane 5: input[5] = 4.0 < 4.5 → output[5] = 4.0 / 2.0 = 2.0
  Lane 6: input[6] = 6.0 ≥ 4.5 → output[6] = 6.0 * 2.0 = 12.0
  Lane 7: input[7] = 8.0 ≥ 4.5 → output[7] = 8.0 * 2.0 = 16.0
  ...the pattern repeats for the remaining lanes

Mathematical foundation: This implements a threshold-based transformation: \[\Large f(x) = \begin{cases} 2x & \text{if } x \geq \tau \\ \frac{x}{2} & \text{if } x < \tau \end{cases}\]

where \(\tau = \frac{\max(\text{block_data})}{2}\) is the broadcast threshold.
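The threshold transformation can be reproduced on the host with a few lines of plain Python. This is a sketch of the arithmetic rather than GPU code; the function name `conditional_reference` is hypothetical, and the max is taken over the first 8 elements as in the requirements.

```python
def conditional_reference(inp):
    """Host-side reference for f(x) with tau = max(first 8 elements) / 2."""
    tau = max(inp[:8]) / 2.0  # the value lane 0 would compute and broadcast, halved
    return [2.0 * x if x >= tau else x / 2.0 for x in inp]


pattern = [3.0, 1.0, 7.0, 2.0, 9.0, 4.0, 6.0, 8.0]
out = conditional_reference(pattern * 4)  # 32 elements, one per lane
print(out[:8])  # [1.5, 0.5, 14.0, 1.0, 18.0, 2.0, 12.0, 16.0]
```

Because every lane compares against the same `tau`, the per-lane branches may differ while the decision rule stays identical across the warp.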

Coordination pattern benefits:

  1. Centralized analysis: one lane analyzes, all lanes benefit
  2. Consistent decisions: all lanes use the same threshold
  3. Adaptive behavior: the threshold adapts to block-local data characteristics
  4. Efficient coordination: a single broadcast coordinates complex conditional logic

Applications:

  • Adaptive algorithms: adjusting parameters based on local data characteristics
  • Quality control: applying different processing depending on data quality metrics
  • Load balancing: distributing work based on block-local complexity analysis

3. Broadcast-shuffle coordination

Implement advanced coordination that combines both broadcast() and shuffle_down().

Requirements:

  • Lane 0 must compute the average of the first 4 elements of the block and broadcast this scaling factor to all lanes
  • Each lane must use shuffle_down(offset=1) to fetch its next neighbor's value
  • Most lanes: multiply the scaling factor by (current_value + next_neighbor_value)
  • Last lane in the warp: multiply the scaling factor by current_value only (no valid neighbor)

Test data: Input follows the pattern [2, 4, 6, 8, 1, 3, 5, 7, ...] (first 4 elements: 2,4,6,8, then 1,3,5,7 repeating)

  • Lane 0 computes the scaling factor: (2+4+6+8)/4 = 5.0
  • Expected output: [30.0, 50.0, 70.0, 45.0, 20.0, 40.0, 60.0, 40.0, ...]

Challenge: How do you coordinate multiple warp primitives when one lane's computation affects all lanes and each lane also needs access to its neighbor's data?

Configuration

  • Vector size: SIZE = WARP_SIZE (32 or 64 depending on the GPU)
  • Grid configuration: (1, 1) blocks per grid
  • Block configuration: (WARP_SIZE, 1) threads per block

Code to complete

fn broadcast_shuffle_coordination[
    layout: Layout, size: Int
](
    output: LayoutTensor[dtype, layout, MutAnyOrigin],
    input: LayoutTensor[dtype, layout, ImmutAnyOrigin],
):
    """
    Combine broadcast() and shuffle_down() for advanced warp coordination.
    Lane 0 computes block-local scaling factor, broadcasts it to all lanes in the warp.
    Each lane uses shuffle_down() for neighbor access and applies broadcast factor.
    """
    global_i = Int(block_dim.x * block_idx.x + thread_idx.x)
    lane = Int(lane_id())
    if global_i < size:
        var scale_factor: output.element_type = 0.0

        # FILL IN (roughly 14 lines)


Tips

1. Multi-primitive coordination

This puzzle requires orchestrating broadcast and shuffle operations in sequence.

Think about the flow:

  1. One lane computes a value for the entire warp
  2. This value is broadcast to all lanes
  3. Each lane uses shuffle to access neighbor data
  4. The broadcast value influences how the neighbor data is processed

Coordination pattern:

# Phase 1: Broadcast coordination
var shared_param = compute_if_lane_0()
shared_param = broadcast(shared_param)

# Phase 2: Shuffle neighbor access
current_val = input[global_i]
neighbor_val = shuffle_down(current_val, offset)

# Phase 3: Combined computation
result = combine(current_val, neighbor_val, shared_param)

2. Parameter calculation strategy

Consider which block-level parameter would be useful for scaling neighbor operations.

Questions to explore:

  • What statistic should lane 0 compute from the block data?
  • How should this parameter influence the neighbor-based computation?
  • What happens at warp boundaries when shuffle operations are involved?

3. Combined operation design

Think about how to combine the broadcast parameter with shuffle-based neighbor access in a meaningful way.

Pattern considerations:

  • Should the broadcast parameter scale the inputs, the outputs, or the computation?
  • How do you handle boundary cases where shuffle returns undefined data?
  • What is the most efficient order of operations?
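The boundary question can be explored with a small host-side model in plain Python. This is a sketch under stated assumptions, not GPU code: `shuffle_down_model` is a hypothetical helper that marks lanes without a valid source as `None`, whereas on real hardware the returned value is simply undefined and must not be used.

```python
def shuffle_down_model(lane_values, offset):
    """Model of shuffle_down: lane i reads the value held by lane i + offset."""
    n = len(lane_values)
    # Lanes whose source lane would be out of range get None (undefined on hardware).
    return [lane_values[i + offset] if i + offset < n else None for i in range(n)]


values = [2.0, 4.0, 6.0, 8.0]  # a 4-lane "warp" for brevity
neighbors = shuffle_down_model(values, 1)
print(neighbors)  # [4.0, 6.0, 8.0, None]

# Guard: only lanes that have a valid neighbor use the shuffled value.
sums = [
    v + nb if lane < len(values) - 1 else v
    for lane, (v, nb) in enumerate(zip(values, neighbors))
]
print(sums)  # [6.0, 10.0, 14.0, 8.0]
```

The guard is checked before the neighbor value is touched, so the last lane never reads the undefined slot.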

Test the broadcast-shuffle coordination:

pixi run p25 --broadcast-shuffle-coordination
pixi run -e amd p25 --broadcast-shuffle-coordination
uv run poe p25 --broadcast-shuffle-coordination

Expected output when solved:

WARP_SIZE:  32
SIZE:  32
output: HostBuffer([30.0, 50.0, 70.0, 45.0, 20.0, 40.0, 60.0, 40.0, 20.0, 40.0, 60.0, 40.0, 20.0, 40.0, 60.0, 40.0, 20.0, 40.0, 60.0, 40.0, 20.0, 40.0, 60.0, 40.0, 20.0, 40.0, 60.0, 40.0, 20.0, 40.0, 60.0, 35.0])
expected: HostBuffer([30.0, 50.0, 70.0, 45.0, 20.0, 40.0, 60.0, 40.0, 20.0, 40.0, 60.0, 40.0, 20.0, 40.0, 60.0, 40.0, 20.0, 40.0, 60.0, 40.0, 20.0, 40.0, 60.0, 40.0, 20.0, 40.0, 60.0, 40.0, 20.0, 40.0, 60.0, 35.0])
✅ Broadcast + Shuffle coordination test passed!

Solution

fn broadcast_shuffle_coordination[
    layout: Layout, size: Int
](
    output: LayoutTensor[dtype, layout, MutAnyOrigin],
    input: LayoutTensor[dtype, layout, ImmutAnyOrigin],
):
    """
    Combine broadcast() and shuffle_down() for advanced warp coordination.
    Lane 0 computes block-local scaling factor, broadcasts it to all lanes in the warp.
    Each lane uses shuffle_down() for neighbor access and applies broadcast factor.
    """
    global_i = Int(block_dim.x * block_idx.x + thread_idx.x)
    lane = Int(lane_id())

    if global_i < size:
        # Step 1: Lane 0 computes block-local scaling factor
        var scale_factor: output.element_type = 0.0
        if lane == 0:
            # Compute average of first 4 elements in this block's data
            block_start = Int(block_idx.x * block_dim.x)
            var sum: output.element_type = 0.0
            for i in range(4):
                if block_start + i < size:
                    sum += input[block_start + i]
            scale_factor = sum / 4.0

        # Step 2: Broadcast scaling factor to all lanes in this warp
        scale_factor = broadcast(scale_factor)

        # Step 3: Each lane gets current and next values
        current_val = input[global_i]
        next_val = shuffle_down(current_val, 1)

        # Step 4: Apply broadcast factor with neighbor coordination
        if lane < WARP_SIZE - 1 and global_i < size - 1:
            # Combine current + next, then scale by broadcast factor
            output[global_i] = (current_val + next_val) * scale_factor
        else:
            # Last lane in warp or last element: only current value, scaled by broadcast factor
            output[global_i] = current_val * scale_factor


This solution demonstrates the most advanced warp coordination pattern, combining the broadcast and shuffle primitives.

Complete algorithm analysis:

if global_i < size:
    # Step 1: Lane 0 computes the block-local scaling factor
    var scale_factor: output.element_type = 0.0
    if lane == 0:
        block_start = Int(block_idx.x * block_dim.x)
        var sum: output.element_type = 0.0
        for i in range(4):
            if block_start + i < size:
                sum += input[block_start + i]
        scale_factor = sum / 4.0

    # Step 2: Broadcast the scaling factor to all lanes
    scale_factor = broadcast(scale_factor)

    # Step 3: Each lane gets its current value and, via shuffle, the next value
    current_val = input[global_i]
    next_val = shuffle_down(current_val, 1)

    # Step 4: Apply the broadcast factor together with neighbor coordination
    if lane < WARP_SIZE - 1 and global_i < size - 1:
        output[global_i] = (current_val + next_val) * scale_factor
    else:
        output[global_i] = current_val * scale_factor

Multi-primitive execution trace:

Input data: [2, 4, 6, 8, 1, 3, 5, 7, ...]

Step 1: Lane 0 computes the scaling factor
  Lane 0 computes: (input[0] + input[1] + input[2] + input[3]) / 4
                 = (2 + 4 + 6 + 8) / 4 = 20 / 4 = 5.0
  Other lanes: scale_factor stays 0.0

Step 2: Broadcast scale_factor = 5.0 to all lanes
  All lanes: scale_factor = 5.0

Step 3: Shuffle operation for neighbor access
  Lane 0: current_val = input[0] = 2, next_val = shuffle_down(2, 1) = input[1] = 4
  Lane 1: current_val = input[1] = 4, next_val = shuffle_down(4, 1) = input[2] = 6
  Lane 2: current_val = input[2] = 6, next_val = shuffle_down(6, 1) = input[3] = 8
  Lane 3: current_val = input[3] = 8, next_val = shuffle_down(8, 1) = input[4] = 1
  ...
  Lane 31: current_val = input[31], next_val = undefined

Step 4: Combined computation with broadcast scaling
  Lane 0: output[0] = (2 + 4) * 5.0 = 6 * 5.0 = 30.0
  Lane 1: output[1] = (4 + 6) * 5.0 = 10 * 5.0 = 50.0
  Lane 2: output[2] = (6 + 8) * 5.0 = 14 * 5.0 = 70.0
  Lane 3: output[3] = (8 + 1) * 5.0 = 9 * 5.0 = 45.0
  ...
  Lane 31: output[31] = 7 * 5.0 = 35.0 (boundary - no neighbor)

Communication pattern analysis: This algorithm implements a hierarchical coordination pattern:

  1. Vertical coordination (broadcast): lane 0 → all lanes
  2. Horizontal coordination (shuffle): lane i reads from lane i+1
  3. Combined computation: uses both the broadcast data and the shuffled data

Mathematical foundation: \[\Large \text{output}[i] = \begin{cases} (\text{input}[i] + \text{input}[i+1]) \cdot \beta & \text{if lane } i < \text{WARP_SIZE} - 1 \\ \text{input}[i] \cdot \beta & \text{if lane } i = \text{WARP_SIZE} - 1 \end{cases}\]

where \(\beta = \frac{1}{4}\sum_{k=0}^{3} \text{input}[\text{block_start} + k]\) is the broadcast scaling factor.
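The formula above can be evaluated directly on the host to confirm the expected output. The following plain-Python sketch is not GPU code: the function name `broadcast_shuffle_reference` is hypothetical, and it assumes a single warp of 32 lanes with `block_start = 0`.

```python
WARP_SIZE = 32  # assumed warp size; some GPUs use 64


def broadcast_shuffle_reference(inp):
    """Evaluate output[i] from the piecewise formula for a single warp."""
    beta = sum(inp[:4]) / 4.0  # scaling factor lane 0 would broadcast
    last = len(inp) - 1
    return [
        (inp[i] + inp[i + 1]) * beta if i < last else inp[i] * beta
        for i in range(len(inp))
    ]


inp = [2.0, 4.0, 6.0, 8.0] + [1.0, 3.0, 5.0, 7.0] * 7  # 32 elements
out = broadcast_shuffle_reference(inp)
print(out[:8], out[-1])  # [30.0, 50.0, 70.0, 45.0, 20.0, 40.0, 60.0, 40.0] 35.0
```

Here `inp[i + 1]` plays the role of the value returned by shuffle_down, and the `i < last` test corresponds to the boundary guard on the last lane.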

Advanced coordination benefits:

  1. Multi-level communication: combines global (broadcast) and local (shuffle) coordination
  2. Adaptive scaling: a block-level parameter influences the neighbor operations
  3. Efficient composition: the two primitives work together seamlessly
  4. Complex algorithms: enables sophisticated parallel algorithms

Real-world applications:

  • Adaptive filtering: block-level noise estimation with neighbor-based filtering
  • Dynamic load balancing: global work distribution with local coordination
  • Multi-scale processing: global parameters controlling local stencil operations

Summary

The core pattern of this section is:

var shared_value = initial_value
if lane == 0:
    shared_value = compute_block_statistic()
shared_value = broadcast(shared_value)
result = use_shared_value(shared_value, local_data)
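As a host-side illustration of this pattern, the sketch below expresses it as a small higher-order function in plain Python. It is a model only, not GPU code: the names `lane0_compute_then_broadcast`, `compute_statistic`, and `combine` are hypothetical.

```python
def lane0_compute_then_broadcast(local_data, compute_statistic, combine):
    """Model of the pattern: one computation, result shared with every lane."""
    shared_value = compute_statistic(local_data)  # conceptually done by lane 0 only
    # Every lane receives the same shared_value and combines it with its own element.
    return [combine(shared_value, x) for x in local_data]


# Example: normalize each element by the maximum of the data.
out = lane0_compute_then_broadcast(
    [1.0, 2.0, 3.0, 4.0],
    compute_statistic=max,
    combine=lambda shared, x: x / shared,
)
print(out)  # [0.25, 0.5, 0.75, 1.0]
```

Swapping `compute_statistic` and `combine` yields the sum-and-add, max-and-threshold, and average-and-scale variants used in the three exercises.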

Key benefits:

  • One-to-many coordination: one lane computes, all lanes benefit
  • Zero synchronization overhead: SIMT execution handles the coordination
  • Composable patterns: combines easily with shuffle and other warp operations

Applications: block statistics, collective decision-making, parameter sharing, adaptive algorithms.