block.prefix_sum() and Parallel Histogram Binning

This puzzle implements parallel histogram binning with the block-level block.prefix_sum operation, which supports advanced parallel filtering and extraction. Each thread determines the target bin for its element, then applies block.prefix_sum() to compute the write positions needed to extract the elements of one specific bin. The puzzle shows how a prefix sum goes beyond simple reductions to enable advanced parallel partitioning.

Key insight: The block.prefix_sum() operation enables parallel filtering and extraction by computing cumulative write positions for matching elements across all threads in a block.

Key concepts

In this puzzle, you'll learn:

  • Block-level prefix sum with block.prefix_sum()
  • Parallel filtering and extraction using cumulative operations
  • Advanced parallel partitioning algorithms
  • Histogram binning with block-wide coordination
  • Exclusive vs inclusive prefix sum patterns

The algorithm builds a histogram by extracting the elements that fall into a specific value range (bin), where \(N\) is the number of bins: \[\Large \text{Bin}_k = \{x_i : k/N \leq x_i < (k+1)/N\}\]

Each thread determines which bin its element belongs to, and block.prefix_sum() coordinates the parallel extraction.

Configuration

  • Vector size: SIZE = 128 elements
  • Data type: DType.float32
  • Block configuration: (128, 1) threads per block (TPB = 128)
  • Grid configuration: (1, 1) blocks per grid
  • Number of bins: NUM_BINS = 8 (ranges [0.0, 0.125), [0.125, 0.25), etc.)
  • Layout: Layout.row_major(SIZE) (1D row-major)
  • Warps per block: 128 / WARP_SIZE (2 or 4 depending on the GPU)

The challenge: parallel bin extraction

Traditional sequential histogram construction processes elements one at a time:

# Sequential approach - hard to parallelize
histogram = [[] for _ in range(NUM_BINS)]
for element in data:
    bin_id = int(element * NUM_BINS)  # Determine the bin
    histogram[bin_id].append(element)  # Sequential append

Problems with naive GPU parallelization:

  • Race conditions: multiple threads write to the same bin simultaneously
  • Uncoalesced memory access: threads access scattered memory locations
  • Load imbalance: some bins may receive far more elements than others
  • Complex synchronization: barriers and atomic operations are required

The advanced approach: block.prefix_sum() coordination

Transform the complex parallel partitioning into coordinated extraction:

Code to complete

block.prefix_sum() approach

Implement parallel histogram binning using block.prefix_sum():

comptime bin_layout = Layout.row_major(SIZE)  # Max SIZE elements per bin


fn block_histogram_bin_extract[
    in_layout: Layout, bin_layout: Layout, out_layout: Layout, tpb: Int
](
    input_data: LayoutTensor[dtype, in_layout, ImmutAnyOrigin],
    bin_output: LayoutTensor[dtype, bin_layout, MutAnyOrigin],
    count_output: LayoutTensor[DType.int32, out_layout, MutAnyOrigin],
    size: Int,
    target_bin: Int,
    num_bins: Int,
):
    """Parallel histogram using block.prefix_sum() for bin extraction.

    This demonstrates advanced parallel filtering and extraction:
    1. Each thread determines which bin its element belongs to
    2. Use block.prefix_sum() to compute write positions for target_bin elements
    3. Extract and pack only elements belonging to target_bin
    """

    global_i = Int(block_dim.x * block_idx.x + thread_idx.x)
    local_i = Int(thread_idx.x)

    # Step 1: Each thread determines its bin and element value

    # FILL IN (roughly 9 lines)

    # Step 2: Create predicate for target bin extraction

    # FILL IN (roughly 3 lines)

    # Step 3: Use block.prefix_sum() for parallel bin extraction!
    # This computes where each thread should write within the target bin

    # FILL IN (1 line)

    # Step 4: Extract and pack elements belonging to target_bin

    # FILL IN (roughly 2 lines)

    # Step 5: Final thread computes total count for this bin

    # FILL IN (roughly 3 lines)


View full file: problems/p27/p27.mojo

Tips

1. Core algorithm structure (adapted from previous puzzles)

Just like block_sum_dot_product, you need these key variables:

global_i = block_dim.x * block_idx.x + thread_idx.x
local_i = thread_idx.x

ํ•จ์ˆ˜๋Š” 5๊ฐ€์ง€ ์ฃผ์š” ๋‹จ๊ณ„(์ด ์•ฝ 15-20์ค„)๋กœ ๊ตฌ์„ฑ๋ฉ๋‹ˆ๋‹ค:

  1. ์š”์†Œ๋ฅผ ๋กœ๋“œํ•˜๊ณ  ๊ตฌ๊ฐ„์„ ๊ฒฐ์ •
  2. ๋Œ€์ƒ ๊ตฌ๊ฐ„์— ๋Œ€ํ•œ ์ด์ง„ ํ”„๋ ˆ๋””์ผ€์ดํŠธ ์ƒ์„ฑ
  3. ํ”„๋ ˆ๋””์ผ€์ดํŠธ์— block.prefix_sum() ์‹คํ–‰
  4. ๊ณ„์‚ฐ๋œ ์˜คํ”„์…‹์„ ์‚ฌ์šฉํ•˜์—ฌ ์กฐ๊ฑด๋ถ€ ์“ฐ๊ธฐ
  5. ๋งˆ์ง€๋ง‰ ์Šค๋ ˆ๋“œ๊ฐ€ ์ด ๊ฐœ์ˆ˜๋ฅผ ๊ณ„์‚ฐ

2. Bin calculation (use math.floor)

To classify a Float32 value into a bin:

my_value = input_data[global_i][0]  # Extract the SIMD element, as in the dot product
bin_number = Int(floor(my_value * num_bins))

Edge case handling: A value of exactly 1.0 would map to bin NUM_BINS, but valid bins run from 0 to NUM_BINS-1. Use an if statement to clamp to the maximum bin.
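
The clamp described above can be sketched on the CPU as follows. This is a plain Python stand-in for the bin calculation, not the Mojo kernel; the helper name bin_index is hypothetical, and the lower clamp for negative inputs is an extra guard.

```python
import math


def bin_index(value, num_bins):
    """Map a value in [0.0, 1.0] to a bin index in 0..num_bins-1."""
    b = math.floor(value * num_bins)
    if b >= num_bins:  # value == 1.0 would otherwise land in bin num_bins
        b = num_bins - 1
    if b < 0:  # guard against slightly negative inputs
        b = 0
    return b


print(bin_index(0.13, 8))  # 1
print(bin_index(1.0, 8))   # 7
```

Without the clamp, an input of exactly 1.0 would produce index 8, which is one past the last valid bin.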

3. Binary predicate creation

Create an integer variable (0 or 1) that indicates whether this thread's element belongs to target_bin:

var belongs_to_target: Int = 0
if (thread_has_valid_element) and (my_bin == target_bin):
    belongs_to_target = 1

์ด๊ฒƒ์ด ํ•ต์‹ฌ ํ†ต์ฐฐ์ž…๋‹ˆ๋‹ค: ๋ˆ„์  ํ•ฉ์ด ์ด ์ด์ง„ ํ”Œ๋ž˜๊ทธ์— ์ž‘์šฉํ•˜์—ฌ ์œ„์น˜๋ฅผ ๊ณ„์‚ฐํ•ฉ๋‹ˆ๋‹ค!

4. block.prefix_sum() call pattern

Following the documentation, the call looks like this:

offset = block.prefix_sum[
    dtype=DType.int32,         # Work with integer predicates
    block_size=tpb,            # Same as block.sum()
    exclusive=True             # Key: gives the position before each thread
](val=SIMD[DType.int32, 1](my_predicate_value))

Why exclusive? A thread at position 5 with predicate=1 should write to output[4] if 4 matching elements came before it.
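
To make the difference concrete, here is a small Python sketch that computes both scan variants over a list of flags. It runs sequentially on the CPU and only illustrates the values each thread would receive; it does not use or model the Mojo API.

```python
from itertools import accumulate

flags = [1, 0, 1, 1, 0, 1]

# Inclusive: running total up to and including each position
inclusive = list(accumulate(flags))
# Exclusive: running total of everything strictly before each position
exclusive = [s - f for s, f in zip(inclusive, flags)]

print(inclusive)  # [1, 1, 2, 3, 3, 4]
print(exclusive)  # [0, 1, 1, 2, 3, 3]
```

The element at index 5 is the fourth match, so it belongs in slot 3. The exclusive scan gives 3 directly, while the inclusive value 4 would be off by one.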

5. Conditional writing pattern

Only threads with belongs_to_target == 1 should write:

if belongs_to_target == 1:
    bin_output[Int(offset[0])] = my_value  # Convert SIMD to Int for indexing

์ด๊ฒƒ์€ Puzzle 12์˜ ๊ฒฝ๊ณ„ ๊ฒ€์‚ฌ ํŒจํ„ด๊ณผ ๋™์ผํ•˜์ง€๋งŒ, ์กฐ๊ฑด์ด โ€œ๋Œ€์ƒ ๊ตฌ๊ฐ„์— ์†ํ•˜๋Š”์ง€โ€œ๋กœ ๋ฐ”๋€Œ์—ˆ์Šต๋‹ˆ๋‹ค.

6. Final count computation

The last thread (not thread 0!) computes the total count:

if local_i == tpb - 1:  # Last thread in the block
    total_count = offset[0] + belongs_to_target  # Inclusive = exclusive + own contribution
    count_output[0] = total_count

Why the last thread? Its exclusive offset counts every matching element before it, which is the highest offset in the block, so offset + contribution gives the total count.
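
A short Python check of this reasoning, using made-up flags, compares what the last and the first position would compute from their exclusive offsets:

```python
flags = [0, 1, 1, 0, 1, 0, 0, 1]
# Exclusive scan: number of set flags strictly before each index
exclusive = [sum(flags[:i]) for i in range(len(flags))]

last = len(flags) - 1
total_from_last = exclusive[last] + flags[last]  # what the last thread computes
total_from_first = exclusive[0] + flags[0]       # thread 0 only sees itself

print(total_from_last)   # 4
print(total_from_first)  # 0
```

Only the last position has seen all earlier contributions, so only its inclusive value equals the full count.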

7. Data types and conversions

Remember the patterns from previous puzzles:

  • LayoutTensor indexing returns SIMD: input_data[i][0]
  • block.prefix_sum() returns SIMD: extract with offset[0]
  • Array indexing requires Int: use Int(offset[0]) for bin_output[...]

Test the block.prefix_sum() approach:

pixi run p27 --histogram
pixi run -e amd p27 --histogram
pixi run -e apple p27 --histogram
uv run poe p27 --histogram

Expected output when solved:

SIZE: 128
TPB: 128
NUM_BINS: 8

Input sample: 0.0 0.01 0.02 0.03 0.04 0.05 0.06 0.07 0.08 0.09 0.1 0.11 0.12 0.13 0.14 0.15 ...

=== Processing Bin 0 (range [ 0.0 , 0.125 )) ===
Bin 0 count: 26
Bin 0 extracted elements: 0.0 0.01 0.02 0.03 0.04 0.05 0.06 0.07 ...

=== Processing Bin 1 (range [ 0.125 , 0.25 )) ===
Bin 1 count: 24
Bin 1 extracted elements: 0.13 0.14 0.15 0.16 0.17 0.18 0.19 0.2 ...

=== Processing Bin 2 (range [ 0.25 , 0.375 )) ===
Bin 2 count: 26
Bin 2 extracted elements: 0.25 0.26 0.27 0.28 0.29 0.3 0.31 0.32 ...

=== Processing Bin 3 (range [ 0.375 , 0.5 )) ===
Bin 3 count: 22
Bin 3 extracted elements: 0.38 0.39 0.4 0.41 0.42 0.43 0.44 0.45 ...

=== Processing Bin 4 (range [ 0.5 , 0.625 )) ===
Bin 4 count: 13
Bin 4 extracted elements: 0.5 0.51 0.52 0.53 0.54 0.55 0.56 0.57 ...

=== Processing Bin 5 (range [ 0.625 , 0.75 )) ===
Bin 5 count: 12
Bin 5 extracted elements: 0.63 0.64 0.65 0.66 0.67 0.68 0.69 0.7 ...

=== Processing Bin 6 (range [ 0.75 , 0.875 )) ===
Bin 6 count: 5
Bin 6 extracted elements: 0.75 0.76 0.77 0.78 0.79

=== Processing Bin 7 (range [ 0.875 , 1.0 )) ===
Bin 7 count: 0
Bin 7 extracted elements:
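
The per-bin counts above can be reproduced on the CPU under one assumption that the text does not state: that the input cycles through 0.00, 0.01, ..., 0.79, which is consistent with the input sample and the reported counts. The Python sketch below is only a plausibility check under that assumed input, not the puzzle's test harness.

```python
import math

SIZE = 128
NUM_BINS = 8

# ASSUMPTION: the input cycles through 0.00, 0.01, ..., 0.79. This is inferred
# from the sample output, not taken from the puzzle source.
data = [(i % 80) / 100 for i in range(SIZE)]

counts = [0] * NUM_BINS
for x in data:
    b = min(math.floor(x * NUM_BINS), NUM_BINS - 1)  # clamp to the last bin
    counts[b] += 1

print(counts)  # [26, 24, 26, 22, 13, 12, 5, 0]
```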

Solution

fn block_histogram_bin_extract[
    in_layout: Layout, bin_layout: Layout, out_layout: Layout, tpb: Int
](
    input_data: LayoutTensor[dtype, in_layout, ImmutAnyOrigin],
    bin_output: LayoutTensor[dtype, bin_layout, MutAnyOrigin],
    count_output: LayoutTensor[DType.int32, out_layout, MutAnyOrigin],
    size: Int,
    target_bin: Int,
    num_bins: Int,
):
    """Parallel histogram using block.prefix_sum() for bin extraction.

    This demonstrates advanced parallel filtering and extraction:
    1. Each thread determines which bin its element belongs to
    2. Use block.prefix_sum() to compute write positions for target_bin elements
    3. Extract and pack only elements belonging to target_bin
    """

    global_i = Int(block_dim.x * block_idx.x + thread_idx.x)
    local_i = Int(thread_idx.x)

    # Step 1: Each thread determines its bin and element value
    var my_value: Scalar[dtype] = 0.0
    var my_bin: Int = -1

    if global_i < size:
        # `[0]` returns the underlying SIMD value
        my_value = input_data[global_i][0]
        # Bin values [0.0, 1.0) into num_bins buckets
        my_bin = Int(floor(my_value * num_bins))
        # Clamp to valid range
        if my_bin >= num_bins:
            my_bin = num_bins - 1
        if my_bin < 0:
            my_bin = 0

    # Step 2: Create predicate for target bin extraction
    var belongs_to_target: Int = 0
    if global_i < size and my_bin == target_bin:
        belongs_to_target = 1

    # Step 3: Use block.prefix_sum() for parallel bin extraction!
    # This computes where each thread should write within the target bin
    write_offset = block.prefix_sum[
        dtype = DType.int32, block_size=tpb, exclusive=True
    ](val=SIMD[DType.int32, 1](belongs_to_target))

    # Step 4: Extract and pack elements belonging to target_bin
    if belongs_to_target == 1:
        bin_output[Int(write_offset[0])] = my_value

    # Step 5: Final thread computes total count for this bin
    if local_i == tpb - 1:
        # Inclusive sum = exclusive sum + my contribution
        total_count = write_offset[0] + belongs_to_target
        count_output[0] = total_count


The block.prefix_sum() kernel demonstrates an advanced parallel coordination pattern that builds on concepts from previous puzzles:

Step-by-step algorithm walkthrough:

Phase 1: Element processing (similar to the Puzzle 12 dot product)

Thread indexing (familiar pattern):
  global_i = block_dim.x * block_idx.x + thread_idx.x  // Global element index
  local_i = thread_idx.x                               // Local thread index

Element loading (same as the LayoutTensor pattern):
  Thread 0:  my_value = input_data[0][0] = 0.00
  Thread 1:  my_value = input_data[1][0] = 0.01
  Thread 13: my_value = input_data[13][0] = 0.13
  Thread 25: my_value = input_data[25][0] = 0.25
  ...

Phase 2: Bin classification (new concept)

Bin calculation using the floor operation:
  Thread 0:  my_bin = Int(floor(0.00 * 8)) = 0  // value in [0.000, 0.125) → bin 0
  Thread 1:  my_bin = Int(floor(0.01 * 8)) = 0  // value in [0.000, 0.125) → bin 0
  Thread 13: my_bin = Int(floor(0.13 * 8)) = 1  // value in [0.125, 0.250) → bin 1
  Thread 25: my_bin = Int(floor(0.25 * 8)) = 2  // value in [0.250, 0.375) → bin 2
  ...

Phase 3: Binary predicate creation (filtering pattern)

For target_bin=0, create the extraction mask:
  Thread 0:  belongs_to_target = 1  (bin 0 == target 0)
  Thread 1:  belongs_to_target = 1  (bin 0 == target 0)
  Thread 13: belongs_to_target = 0  (bin 1 != target 0)
  Thread 25: belongs_to_target = 0  (bin 2 != target 0)
  ...

This creates the binary array: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, ...]

Phase 4: Parallel prefix sum (where the magic happens!)

Apply block.prefix_sum[exclusive=True] to the predicates:
Input:     [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, ...]
Exclusive: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12, -, -, -, ...]
                                                      ^
                                                doesn't matter

Key insight: Each thread receives its write position in the output array!

Phase 5: Coordinated extraction (conditional write)

Only threads with belongs_to_target=1 write:
  Thread 0:  bin_output[0] = 0.00   // offset for thread 0 is 0
  Thread 1:  bin_output[1] = 0.01   // offset for thread 1 is 1
  Thread 12: bin_output[12] = 0.12  // offset for thread 12 is 12
  Thread 13: (no write)             // belongs_to_target = 0
  Thread 25: (no write)             // belongs_to_target = 0
  ...

Result: [0.00, 0.01, 0.02, ..., 0.12, ???, ???, ...] // packed with no gaps!

The walkthrough shows only the first few threads. The expected output above reports 26 elements for bin 0, so later threads also contribute matches that are packed after these.

Phase 6: Count computation (similar to the block.sum() pattern)

The last thread computes the total count (not thread 0!):
  if local_i == tpb - 1:  // thread 127 in this case
      total = write_offset[0] + belongs_to_target  // inclusive sum formula
      count_output[0] = total
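
The six phases can be mirrored in a CPU reference, which is handy for checking kernel output. The Python sketch below assumes a sequential loop in place of the block-wide scan; the function name block_histogram_bin_extract_ref is hypothetical and the code is not a translation of the Mojo solution.

```python
import math


def block_histogram_bin_extract_ref(data, target_bin, num_bins):
    """CPU stand-in for the kernel: returns (packed elements, count)."""
    n = len(data)
    # Phases 1-2: each "thread" loads its value and classifies it
    bins = [min(max(math.floor(v * num_bins), 0), num_bins - 1) for v in data]
    # Phase 3: binary predicate
    flags = [1 if b == target_bin else 0 for b in bins]
    # Phase 4: exclusive prefix sum (sequential here, block-wide on the GPU)
    offsets, running = [], 0
    for f in flags:
        offsets.append(running)
        running += f
    # Phase 5: conditional write into a buffer of n slots
    out = [None] * n
    for i in range(n):
        if flags[i] == 1:
            out[offsets[i]] = data[i]
    # Phase 6: the last thread's inclusive value is the total
    count = offsets[-1] + flags[-1]
    return out[:count], count


data = [i / 100 for i in range(16)]  # 0.00 .. 0.15, as in the walkthrough
elements, count = block_histogram_bin_extract_ref(data, target_bin=0, num_bins=8)
print(count)         # 13
print(elements[-1])  # 0.12
```

With these 16 inputs, bin 0 receives the 13 values from 0.00 to 0.12, matching the exclusive offsets 0 through 12 shown in Phase 4.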

Why this advanced algorithm works:

Connection to Puzzle 12 (traditional dot product):

  • Same thread indexing: the global_i and local_i pattern
  • Same bounds checking: the if global_i < size validation
  • Same data loading: LayoutTensor SIMD extraction with [0]

Connection to block.sum() (earlier in this puzzle):

  • Same block-wide operation: all threads participate in the block primitive
  • Same result handling: a designated thread (the last instead of the first) handles the final result
  • Same SIMD conversion: the Int(result[0]) pattern for array indexing

Advanced concepts unique to block.prefix_sum():

  • Every thread receives a result: unlike block.sum(), where only thread 0 matters
  • Coordinated write positions: the prefix sum eliminates race conditions automatically
  • Parallel filtering: binary predicates enable advanced data reorganization

Performance advantages over naive approaches:

vs. atomic operations:

  • No race conditions: the prefix sum gives each write a unique position
  • Coalesced memory: sequential writes improve cache performance
  • No serialization: all writes happen in parallel

vs. multi-pass algorithms:

  • Single kernel: the extraction for a bin completes in one GPU launch
  • Full utilization: all threads work regardless of the data distribution
  • Efficient memory bandwidth: contiguous writes suit the GPU memory hierarchy

This shows how block.prefix_sum() enables advanced parallel algorithms that would be complex or impossible with simpler primitives such as block.sum().

Performance insights

block.prefix_sum() vs traditional approaches:

  • Algorithm sophistication: advanced parallel partitioning vs sequential processing
  • Memory efficiency: coalesced writes vs scattered random access
  • Synchronization: built-in coordination vs manual barriers and atomic operations
  • Scalability: works with any block size and bin count

block.prefix_sum() vs block.sum():

  • Scope: every thread receives a result vs only thread 0
  • Use case: complex partitioning vs simple aggregation
  • Algorithm type: parallel scan primitive vs reduction primitive
  • Output pattern: per-thread positions vs a single total
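
To give a feel for what a parallel scan primitive does, the Python sketch below simulates the classic Hillis-Steele scan, in which every position is updated in lockstep rounds. This is an illustration of one well-known scan scheme. The text does not say how block.prefix_sum() is implemented internally, so treat the correspondence as an assumption.

```python
def hillis_steele_inclusive(values):
    """Inclusive scan in ceil(log2(n)) lockstep rounds."""
    current = list(values)
    n = len(current)
    stride = 1
    rounds = 0
    while stride < n:
        # Every "thread" i reads from the previous round's array, so all
        # updates within a round are independent of each other.
        nxt = [current[i] + (current[i - stride] if i >= stride else 0)
               for i in range(n)]
        current = nxt
        stride *= 2
        rounds += 1
    return current, rounds


flags = [1, 0, 1, 1, 0, 1, 1, 0]
inclusive, rounds = hillis_steele_inclusive(flags)
exclusive = [s - f for s, f in zip(inclusive, flags)]
print(inclusive)  # [1, 1, 2, 3, 3, 4, 5, 5]
print(exclusive)  # [0, 1, 1, 2, 3, 3, 4, 5]
print(rounds)     # 3
```

A reduction over the same flags would return only the single total 5, whereas the scan leaves every position with its own running count after 3 rounds for 8 elements.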

When to use block.prefix_sum():

  • Parallel filtering: extract elements that match a condition
  • Stream compaction: remove unwanted elements
  • Parallel partitioning: separate data into categories
  • Advanced algorithms: load balancing, sorting, graph algorithms
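
As one example of the partitioning use case, an exclusive scan over flags is enough to build a stable two-way partition. The Python sketch below is a CPU illustration with hypothetical helper names (exclusive_scan, partition_by_flag); flagged elements go to the front and the rest follow, each group keeping its original order.

```python
def exclusive_scan(flags):
    """Count of set flags strictly before each index."""
    out, running = [], 0
    for f in flags:
        out.append(running)
        running += f
    return out


def partition_by_flag(values, flags):
    """Stable two-way partition: flagged elements first, the rest after."""
    offsets = exclusive_scan(flags)
    total = offsets[-1] + flags[-1] if flags else 0
    out = [None] * len(values)
    for i, (v, f) in enumerate(zip(values, flags)):
        if f == 1:
            out[offsets[i]] = v  # slot among flagged elements
        else:
            out[total + (i - offsets[i])] = v  # slot among the rest
    return out, total


values = [7, 2, 9, 4, 1, 8]
flags = [1 if v >= 5 else 0 for v in values]
result, split = partition_by_flag(values, flags)
print(result)  # [7, 9, 8, 2, 4, 1]
print(split)   # 3
```

The position of an unflagged element is derived from the same scan: i - offsets[i] is the number of unflagged elements before index i, so a single scan serves both groups.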

๋‹ค์Œ ๋‹จ๊ณ„

block.prefix_sum() ์—ฐ์‚ฐ์„ ๋ฐฐ์› ์œผ๋‹ˆ, ๋‹ค์Œ์œผ๋กœ ์ง„ํ–‰ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค:

  • block.broadcast()์™€ ๋ฒกํ„ฐ ์ •๊ทœํ™”: ๋ธ”๋ก ๋‚ด ๋ชจ๋“  ์Šค๋ ˆ๋“œ์— ๊ฐ’์„ ๊ณต์œ 
  • ๋ฉ€ํ‹ฐ ๋ธ”๋ก ์•Œ๊ณ ๋ฆฌ์ฆ˜: ๋” ํฐ ๋ฌธ์ œ๋ฅผ ์œ„ํ•œ ์—ฌ๋Ÿฌ ๋ธ”๋ก ๊ฐ„ ์กฐ์œจ
  • ๊ณ ๊ธ‰ ๋ณ‘๋ ฌ ์•Œ๊ณ ๋ฆฌ์ฆ˜: ์ •๋ ฌ, ๊ทธ๋ž˜ํ”„ ํƒ์ƒ‰, ๋™์  ๋ถ€ํ•˜ ๋ถ„์‚ฐ
  • ๋ณต์žกํ•œ ๋ฉ”๋ชจ๋ฆฌ ํŒจํ„ด: ๋ธ”๋ก ์—ฐ์‚ฐ๊ณผ ๊ณ ๊ธ‰ ๋ฉ”๋ชจ๋ฆฌ ์ ‘๊ทผ์˜ ๊ฒฐํ•ฉ

๐Ÿ’ก ํ•ต์‹ฌ ์š”์ : ๋ธ”๋ก ๋ˆ„์  ํ•ฉ ์—ฐ์‚ฐ์€ GPU ํ”„๋กœ๊ทธ๋ž˜๋ฐ์„ ๋‹จ์ˆœํ•œ ๋ณ‘๋ ฌ ๊ณ„์‚ฐ์—์„œ ๊ณ ๊ธ‰ ๋ณ‘๋ ฌ ์•Œ๊ณ ๋ฆฌ์ฆ˜์œผ๋กœ ๋ณ€ํ™˜ํ•ฉ๋‹ˆ๋‹ค. block.sum()์ด ๋ฆฌ๋•์…˜์„ ๋‹จ์ˆœํ™”ํ–ˆ๋‹ค๋ฉด, block.prefix_sum()์€ ๊ณ ์„ฑ๋Šฅ ๋ณ‘๋ ฌ ์•Œ๊ณ ๋ฆฌ์ฆ˜์— ํ•„์ˆ˜์ ์ธ ๊ณ ๊ธ‰ ๋ฐ์ดํ„ฐ ์žฌ๊ตฌ์„ฑ ํŒจํ„ด์„ ๊ฐ€๋Šฅํ•˜๊ฒŒ ํ•ฉ๋‹ˆ๋‹ค.