Overview

Implement a kernel that, for each position in vector a, computes the sum of the last 3 values (the current element and up to two before it) and stores the result in vector output.

Note: There is one thread per position. Each thread needs only 1 global read and 1 global write.

Key concepts

In this puzzle, you'll learn about:

  • Implementing sliding window operations with shared memory
  • Handling boundary conditions in pooling
  • Coordinating threads to access neighboring data

The key is to use shared memory to access the values in each window efficiently. The first positions in the sequence, where the window is not yet full, need special handling.

Configuration

  • Array size: SIZE = 8
  • Threads per block: TPB = 8
  • Window size: 3
  • Shared memory: TPB elements

Notes:

  • Window access: each output depends on up to 3 values ending at the current position
  • Edge handling: the first two positions need special treatment
  • Memory pattern: each thread loads one element into shared memory
  • Thread sync: threads must synchronize before the window operation
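
As a host-side reference for what the kernel should produce, the window rule above can be sketched in plain Python. This is not GPU code and not part of the puzzle files; the function name pooling_reference and its window parameter are hypothetical, and the input values 0.0 through 7.0 are taken from the input array shown in the solution walkthrough.

```python
def pooling_reference(a, window=3):
    """Sum of the last `window` values ending at each position (fewer at the start)."""
    out = []
    for i in range(len(a)):
        start = max(0, i - window + 1)  # clamp the window at the start of the array
        out.append(sum(a[start : i + 1]))
    return out


a = [float(i) for i in range(8)]  # assumed input: 0.0 through 7.0
print(pooling_reference(a))
# prints [0.0, 1.0, 3.0, 6.0, 9.0, 12.0, 15.0, 18.0]
```

The clamp on start is what the kernel later expresses as separate branches for positions 0 and 1.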

Code to complete

comptime TPB = 8
comptime SIZE = 8
comptime BLOCKS_PER_GRID = (1, 1)
comptime THREADS_PER_BLOCK = (TPB, 1)
comptime dtype = DType.float32


fn pooling(
    output: UnsafePointer[Scalar[dtype], MutAnyOrigin],
    a: UnsafePointer[Scalar[dtype], MutAnyOrigin],
    size: UInt,
):
    shared = stack_allocation[
        TPB,
        Scalar[dtype],
        address_space = AddressSpace.SHARED,
    ]()
    global_i = block_dim.x * block_idx.x + thread_idx.x
    local_i = thread_idx.x
    # FILL ME IN (roughly 10 lines)


View full file: problems/p11/p11.mojo

Tips
  1. Load the data, then call barrier()
  2. Special cases: output[0] = shared[0], output[1] = shared[0] + shared[1]
  3. General case: if 1 < global_i < size
  4. Sum of three elements: shared[local_i - 2] + shared[local_i - 1] + shared[local_i]

Running the code

To test your solution, run the command below that matches your setup in your terminal:

pixi run p11
pixi run -e amd p11
pixi run -e apple p11
uv run poe p11

If the puzzle isn't solved yet, the output looks like this:

out: HostBuffer([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
expected: HostBuffer([0.0, 1.0, 3.0, 6.0, 9.0, 12.0, 15.0, 18.0])

Solution

fn pooling(
    output: UnsafePointer[Scalar[dtype], MutAnyOrigin],
    a: UnsafePointer[Scalar[dtype], MutAnyOrigin],
    size: UInt,
):
    shared = stack_allocation[
        TPB,
        Scalar[dtype],
        address_space = AddressSpace.SHARED,
    ]()
    global_i = block_dim.x * block_idx.x + thread_idx.x
    local_i = thread_idx.x
    if global_i < size:
        shared[local_i] = a[global_i]

    barrier()

    if global_i == 0:
        output[0] = shared[0]
    elif global_i == 1:
        output[1] = shared[0] + shared[1]
    elif UInt(1) < global_i < size:
        output[global_i] = (
            shared[local_i - 2] + shared[local_i - 1] + shared[local_i]
        )


This solution implements a sliding window sum using shared memory. The key steps are:

  1. Shared memory setup

    • Allocate TPB elements in shared memory:

      Input array:  [0.0 1.0 2.0 3.0 4.0 5.0 6.0 7.0]
      Block shared: [0.0 1.0 2.0 3.0 4.0 5.0 6.0 7.0]
      
    • Each thread loads one element from global memory

    • barrier() ensures every load has completed before any thread reads a neighbor's value

  2. Boundary cases

    • Position 0: a single element

      output[0] = shared[0] = 0.0
      
    • Position 1: sum of the first two elements

      output[1] = shared[0] + shared[1] = 0.0 + 1.0 = 1.0
      
  3. Main window operation

    • Positions 2 and beyond:

      Position 2: shared[0] + shared[1] + shared[2] = 0.0 + 1.0 + 2.0 = 3.0
      Position 3: shared[1] + shared[2] + shared[3] = 1.0 + 2.0 + 3.0 = 6.0
      Position 4: shared[2] + shared[3] + shared[4] = 2.0 + 3.0 + 4.0 = 9.0
      ...
      
    • Window calculation using local indices:

      # Sliding window of 3 elements
      window_sum = shared[i-2] + shared[i-1] + shared[i]
      
  4. Memory access pattern

    • One global read per thread, into shared memory
    • One global write per thread, of a sum computed from shared memory
    • Shared memory serves the neighbor accesses
    • Global memory accesses stay coalesced
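
The role of barrier() and the read/write counts in the steps above can be illustrated with a small sequential simulation in plain Python. This is only a sketch under stated assumptions: real GPU threads run concurrently and real shared memory starts uninitialized, whereas here threads are stepped one at a time (highest index first) and shared memory starts at 0.0. The names run_block and window_sum are hypothetical.

```python
TPB = 8
SIZE = 8


def window_sum(shared, i):
    # Same three branches as the kernel: positions 0 and 1 have short windows.
    if i == 0:
        return shared[0]
    if i == 1:
        return shared[0] + shared[1]
    return shared[i - 2] + shared[i - 1] + shared[i]


def run_block(a, use_barrier):
    shared = [0.0] * TPB  # stand-in for shared memory (assumed zeroed here)
    output = [0.0] * SIZE
    global_reads = 0
    global_writes = 0
    # Simulated schedule: highest thread index first, so a missing barrier
    # visibly reads neighbors that have not loaded their element yet.
    order = list(range(TPB - 1, -1, -1))

    def load(i):
        nonlocal global_reads
        if i < SIZE:
            shared[i] = a[i]
            global_reads += 1

    def compute(i):
        nonlocal global_writes
        if i < SIZE:
            output[i] = window_sum(shared, i)
            global_writes += 1

    if use_barrier:
        for i in order:
            load(i)
        # barrier(): no thread continues until every load above is done
        for i in order:
            compute(i)
    else:
        for i in order:
            load(i)
            compute(i)  # neighbors i-1 and i-2 are not loaded yet
    return output, global_reads, global_writes


a = [float(i) for i in range(SIZE)]
print(run_block(a, use_barrier=True))
print(run_block(a, use_barrier=False))
```

With the barrier, the simulated block reproduces the expected sums using 8 global reads and 8 global writes, one of each per thread. Without it, under this particular schedule, every thread sees only its own element and the output degenerates to a copy of the input.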

This approach optimizes performance through:

  • Minimal global memory access
  • Fast neighbor lookups in shared memory
  • Clean boundary handling
  • Efficient memory coalescing

The final output shows the sliding window sums:

[0.0, 1.0, 3.0, 6.0, 9.0, 12.0, 15.0, 18.0]