์ž„๋ฒ ๋”ฉ ์ปค๋„: ๋ณ‘ํ•ฉ vs ๋น„๋ณ‘ํ•ฉ

์ด ํผ์ฆ์—์„œ๋Š” ๋™์ผํ•œ ๊ฒฐ๊ณผ๋ฅผ ์ƒ์„ฑํ•˜์ง€๋งŒ ์„œ๋กœ ๋‹ค๋ฅธ ๋ฉ”๋ชจ๋ฆฌ ์ ‘๊ทผ ํŒจํ„ด์„ ์‚ฌ์šฉํ•˜๋Š” ๋‘ ๊ฐ€์ง€ GPU ์ž„๋ฒ ๋”ฉ ์ปค๋„์„ ๊ตฌํ˜„ํ•ฉ๋‹ˆ๋‹ค. GPU ์„ฑ๋Šฅ์—์„œ ๋ฉ”๋ชจ๋ฆฌ ๋ณ‘ํ•ฉ์ด ์–ผ๋งˆ๋‚˜ ์ค‘์š”ํ•œ์ง€ ์ง์ ‘ ์ฒดํ—˜ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

1D ๋ณ‘ํ•ฉ ์ปค๋„ (์ตœ์ ํ™”๋œ ์ ‘๊ทผ๋ฒ•)

์ด ์ปค๋„์€ ๊ฐ ์Šค๋ ˆ๋“œ๊ฐ€ ์ •ํ™•ํžˆ ํ•˜๋‚˜์˜ ์ถœ๋ ฅ ์š”์†Œ๋ฅผ ์ฒ˜๋ฆฌํ•˜๋Š” ๋‹จ์ˆœํ•œ 1D ๊ทธ๋ฆฌ๋“œ๋ฅผ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค. ํ•ต์‹ฌ์€ ์—ฐ์†๋œ ์Šค๋ ˆ๋“œ๊ฐ€ ์—ฐ์†๋œ ๋ฉ”๋ชจ๋ฆฌ ์œ„์น˜์— ์ ‘๊ทผํ•˜์—ฌ ์ตœ์ ์˜ ๋ฉ”๋ชจ๋ฆฌ ๋ณ‘ํ•ฉ์„ ๋‹ฌ์„ฑํ•œ๋‹ค๋Š” ์ ์ž…๋‹ˆ๋‹ค.

์Šค๋ ˆ๋“œ ๊ตฌ์„ฑ:

  • 그리드 구성: [ceildiv(total_elements, 256)] 블록, 블록당 256 스레드
  • ์Šค๋ ˆ๋“œ ๋งคํ•‘: ๊ฐ ์Šค๋ ˆ๋“œ๊ฐ€ ํ•˜๋‚˜์˜ (batch, seq, embed) ์œ„์น˜ ์ฒ˜๋ฆฌ
  • ๋ฉ”๋ชจ๋ฆฌ ํŒจํ„ด: ์—ฐ์†๋œ ์Šค๋ ˆ๋“œ๊ฐ€ ์—ฐ์†๋œ ์ž„๋ฒ ๋”ฉ ์ฐจ์› ์ ‘๊ทผ

๊ตฌํ˜„ํ•  ๋‚ด์šฉ:

  1. ๋ธ”๋ก ์ธ๋ฑ์Šค์™€ ์Šค๋ ˆ๋“œ ์ธ๋ฑ์Šค๋กœ๋ถ€ํ„ฐ ์ „์—ญ ์Šค๋ ˆ๋“œ ์ธ๋ฑ์Šค ๊ณ„์‚ฐ
  2. 1์ฐจ์› ์ธ๋ฑ์Šค๋ฅผ 3D ์ขŒํ‘œ (batch_idx, seq_idx, embed_idx)๋กœ ๋ณ€ํ™˜
  3. indices ํ…์„œ์—์„œ ํ† ํฐ ์ธ๋ฑ์Šค ์กฐํšŒ
  4. ํ•ด๋‹นํ•˜๋Š” ์ž„๋ฒ ๋”ฉ ๋ฒกํ„ฐ ์š”์†Œ๋ฅผ ์ถœ๋ ฅ์— ๋ณต์‚ฌ
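위 1~4단계를 순수 파이썬으로 옮겨 본 스케치입니다(작은 예시 크기와 장난감 데이터는 설명용 가정입니다). GPU에서는 이 루프 본문 하나가 스레드 하나의 작업에 해당합니다:

```python
# 1D 병합 커널이 스레드마다 수행하는 작업을 순수 파이썬으로 흉내 낸 스케치
# (가정: batch_size=2, seq_len=3, embed_dim=4, vocab_size=10 의 작은 예시 크기)
batch_size, seq_len, embed_dim, vocab_size = 2, 3, 4, 10

# 장난감 데이터: indices[b][s] = 토큰 인덱스, weights[v][e] = v * embed_dim + e
indices = [[1, 0, 9], [3, 1, 7]]
weights = [[float(v * embed_dim + e) for e in range(embed_dim)]
           for v in range(vocab_size)]
output = [[[0.0] * embed_dim for _ in range(seq_len)] for _ in range(batch_size)]

total_elements = batch_size * seq_len * embed_dim
for global_idx in range(total_elements):      # GPU에서는 스레드당 1회 실행
    # 1~2) 1차원 인덱스를 (batch, seq, embed) 3D 좌표로 변환
    batch_idx = global_idx // (seq_len * embed_dim)
    remaining = global_idx % (seq_len * embed_dim)
    seq_idx = remaining // embed_dim
    embed_idx = remaining % embed_dim
    # 3) indices에서 토큰 인덱스 조회
    token_idx = indices[batch_idx][seq_idx]
    # 4) 유효 범위 검사 후 임베딩 벡터 요소를 출력에 복사
    if 0 <= token_idx < vocab_size:
        output[batch_idx][seq_idx][embed_idx] = weights[token_idx][embed_idx]

# output[0][0] 은 weights[1] (토큰 1의 임베딩 벡터)와 같아야 합니다
```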

์™„์„ฑํ•  ์ฝ”๋“œ

๋‘ ์ž„๋ฒ ๋”ฉ ์ปค๋„์˜ ๋นˆ ๋ถ€๋ถ„์„ ์™„์„ฑํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค:

# 커널이 사용하는 import (전체 파일은 problems/p21/op/embedding.mojo 참고)
from gpu import block_dim, block_idx, thread_idx
from layout import Layout, LayoutTensor

comptime THREADS_PER_BLOCK = 256


fn embedding_kernel_coalesced[
    indices_layout: Layout,
    weights_layout: Layout,
    out_layout: Layout,
    batch_size: Int,
    seq_len: Int,
    vocab_size: Int,
    embed_dim: Int,
    dtype: DType = DType.float32,
](
    output: LayoutTensor[dtype, out_layout, MutAnyOrigin],
    indices: LayoutTensor[DType.int32, indices_layout, MutAnyOrigin],
    weights: LayoutTensor[dtype, weights_layout, MutAnyOrigin],
):
    """
    Memory-coalescing focused embedding kernel.

    Key insight: The bottleneck is memory access patterns, not computation.
    - Each thread handles one (batch, seq, embed) position
    - Simple 1D grid for maximum simplicity and correctness
    - Focus on getting memory access right first
    """

    # Simple 1D indexing - each thread = one output element
    global_idx = Int(block_idx.x * block_dim.x + thread_idx.x)
    total_elements = batch_size * seq_len * embed_dim

    if global_idx >= total_elements:
        return

    # Convert to (batch, seq, embed) coordinates
    # FILL IN roughly 4 lines

    # Get token index
    # FILL IN 1 line

    # Simple, correct assignment
    # FILL IN 4 lines


์ „์ฒด ํŒŒ์ผ ๋ณด๊ธฐ: problems/p21/op/embedding.mojo

ํŒ
  • global_idx = block_idx.x * block_dim.x + thread_idx.x๋กœ ์‹œ์ž‘ํ•˜์„ธ์š”
  • ๋‚˜๋ˆ—์…ˆ๊ณผ ๋‚˜๋จธ์ง€ ์—ฐ์‚ฐ์œผ๋กœ 3D ์ขŒํ‘œ๋ฅผ ๊ตฌํ•ฉ๋‹ˆ๋‹ค: batch_idx = global_idx // (seq_len * embed_dim)
  • remaining = global_idx % (seq_len * embed_dim)์„ ์‚ฌ์šฉํ•˜๋ฉด ์ดํ›„ ๊ณ„์‚ฐ์ด ๊ฐ„๋‹จํ•ด์ง‘๋‹ˆ๋‹ค
  • ํ•ญ์ƒ ๊ฒฝ๊ณ„ ๊ฒ€์‚ฌ๋ฅผ ํ•˜์„ธ์š”: if global_idx >= total_elements: return
  • ์œ ํšจํ•˜์ง€ ์•Š์€ ํ† ํฐ ์ธ๋ฑ์Šค๋Š” ์ถœ๋ ฅ์„ 0์œผ๋กœ ์„ค์ •ํ•˜์„ธ์š”
  • ์ž„๋ฒ ๋”ฉ ์กฐํšŒ: output[batch_idx, seq_idx, embed_idx] = weights[token_idx, embed_idx]

2D ๋น„๋ณ‘ํ•ฉ ์ปค๋„ (๋น„๊ต์šฉ ์ ‘๊ทผ๋ฒ•)

์ด ์ปค๋„์€ X ์ฐจ์›์ด (batch ร— seq) ์œ„์น˜๋ฅผ, Y ์ฐจ์›์ด ์ž„๋ฒ ๋”ฉ ์ฐจ์›์„ ๋‹ด๋‹นํ•˜๋Š” 2D ๊ทธ๋ฆฌ๋“œ๋ฅผ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค. ์ด ๋ฐฉ์‹์€ ๋ฉ”๋ชจ๋ฆฌ ์ ‘๊ทผ์ด ๋ณ‘ํ•ฉ๋˜์ง€ ์•Š์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

์Šค๋ ˆ๋“œ ๊ตฌ์„ฑ:

  • 그리드 구성: [ceildiv(batch × seq, 16), ceildiv(embed_dim, 16)] 블록, 블록당 16 × 16 스레드
  • ์Šค๋ ˆ๋“œ ๋งคํ•‘: thread_idx.x๋Š” batch/sequence์—, thread_idx.y๋Š” ์ž„๋ฒ ๋”ฉ ์ฐจ์›์— ๋งคํ•‘
  • ๋ฉ”๋ชจ๋ฆฌ ํŒจํ„ด: ์›Œํ”„ ๋‚ด ์Šค๋ ˆ๋“œ๋“ค์ด ํฉ์–ด์ง„ ๋ฉ”๋ชจ๋ฆฌ ์œ„์น˜์— ์ ‘๊ทผํ•  ์ˆ˜ ์žˆ์Œ

๊ตฌํ˜„ํ•  ๋‚ด์šฉ:

  1. 2D ๊ทธ๋ฆฌ๋“œ์—์„œ X, Y ์ขŒํ‘œ ๊ณ„์‚ฐ
  2. X ์ขŒํ‘œ๋ฅผ batch ์ธ๋ฑ์Šค์™€ sequence ์ธ๋ฑ์Šค๋กœ ๋ถ„๋ฆฌ
  3. Y ์ขŒํ‘œ๋ฅผ ์ž„๋ฒ ๋”ฉ ์ฐจ์›์œผ๋กœ ์ง์ ‘ ์‚ฌ์šฉ
  4. ๊ฒฝ๊ณ„ ๊ฒ€์‚ฌ์™€ ํ•จ๊ป˜ ๋™์ผํ•œ ์ž„๋ฒ ๋”ฉ ์กฐํšŒ ์ˆ˜ํ–‰
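위 1~3단계의 좌표 계산을 파이썬으로 옮겨 본 스케치입니다(블록 크기 16×16과 예시 크기는 설명용 가정입니다):

```python
# 2D 커널의 (block, thread) 좌표 → (batch, seq, embed) 좌표 변환 스케치
# (가정: block_dim = (16, 16), 예시 크기 batch_size=2, seq_len=5, embed_dim=8)
batch_size, seq_len, embed_dim = 2, 5, 8
BLOCK_X = BLOCK_Y = 16

def thread_coords(block_x, block_y, tx, ty):
    """좌표 변환. 그리드가 텐서보다 크므로 범위 밖 스레드는 None(early return)."""
    batch_seq_idx = block_x * BLOCK_X + tx   # X: batch*seq 위치
    embed_idx = block_y * BLOCK_Y + ty       # Y: 임베딩 차원 (그대로 사용)
    # 두 차원 모두 경계 검사
    if batch_seq_idx >= batch_size * seq_len or embed_idx >= embed_dim:
        return None
    # X 좌표를 batch 인덱스와 sequence 인덱스로 분리
    batch_idx = batch_seq_idx // seq_len
    seq_idx = batch_seq_idx % seq_len
    return batch_idx, seq_idx, embed_idx

# 블록 (0,0)의 스레드 (7,3): batch_seq_idx=7 → (batch=1, seq=2), embed=3
```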

์™„์„ฑํ•  ์ฝ”๋“œ

๋‘ ์ž„๋ฒ ๋”ฉ ์ปค๋„์˜ ๋นˆ ๋ถ€๋ถ„์„ ์™„์„ฑํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค:

fn embedding_kernel_2d[
    indices_layout: Layout,
    weights_layout: Layout,
    out_layout: Layout,
    batch_size: Int,
    seq_len: Int,
    vocab_size: Int,
    embed_dim: Int,
    dtype: DType = DType.float32,
](
    output: LayoutTensor[dtype, out_layout, MutAnyOrigin],
    indices: LayoutTensor[DType.int32, indices_layout, MutAnyOrigin],
    weights: LayoutTensor[dtype, weights_layout, MutAnyOrigin],
):
    """
    2D grid non-coalesced embedding kernel.

    Non-optimal approach for comparison:
    - 2D grid: (batch*seq, embed_dim)
    - More complex indexing
    - Potentially worse memory access patterns
    """

    # 2D grid indexing
    batch_seq_idx = Int(block_idx.x * block_dim.x + thread_idx.x)
    embed_idx = Int(block_idx.y * block_dim.y + thread_idx.y)
    total_positions = batch_size * seq_len

    if batch_seq_idx >= total_positions or embed_idx >= embed_dim:
        return

    # Convert to (batch, seq) coordinates
    # FILL IN 2 lines

    # Get token index
    # FILL IN 1 line

    # Assignment with 2D grid pattern
    # FILL IN 4 lines


์ „์ฒด ํŒŒ์ผ ๋ณด๊ธฐ: problems/p21/op/embedding.mojo

ํŒ
  • X, Y ์Šค๋ ˆ๋“œ ์ขŒํ‘œ๋ฅผ ๋ชจ๋‘ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค: batch_seq_idx = block_idx.x * block_dim.x + thread_idx.x
  • ๊ทธ๋ฆฌ๊ณ : embed_idx = block_idx.y * block_dim.y + thread_idx.y
  • batch_seq_idx๋ฅผ batch์™€ sequence ์ธ๋ฑ์Šค๋กœ ๋ถ„๋ฆฌํ•ฉ๋‹ˆ๋‹ค: batch_idx = batch_seq_idx // seq_len
  • ๋‘ ์ฐจ์› ๋ชจ๋‘ ๊ฒฝ๊ณ„ ๊ฒ€์‚ฌ๋ฅผ ์žŠ์ง€ ๋งˆ์„ธ์š”: if batch_seq_idx >= total_positions or embed_idx >= embed_dim
  • 토큰 조회는 1D 커널과 동일하지만, 각 스레드는 전체 벡터가 아닌 하나의 임베딩 차원만 처리합니다

์ปค์Šคํ…€ op ๋“ฑ๋ก

์ปค๋„๋“ค์€ PyTorch์™€ ์‰ฝ๊ฒŒ ํ†ตํ•ฉํ•  ์ˆ˜ ์žˆ๋„๋ก ์ปค์Šคํ…€ ์—ฐ์‚ฐ์œผ๋กœ ๋ž˜ํ•‘๋ฉ๋‹ˆ๋‹ค. ๋“ฑ๋ก ํŒจํ„ด์€ MAX ๊ทธ๋ž˜ํ”„ ์ปค์Šคํ…€ op ์ดํ•ดํ•˜๊ธฐ์—์„œ ์„ค๋ช…ํ•œ MAX ์ปค์Šคํ…€ op๊ณผ ๋™์ผํ•ฉ๋‹ˆ๋‹ค:

1D ๋ณ‘ํ•ฉ ์—ฐ์‚ฐ

์ด ์—ฐ์‚ฐ์€ ์ตœ์ ํ™”๋œ 1D ์ž„๋ฒ ๋”ฉ ์ปค๋„์„ "embedding"์œผ๋กœ ๋“ฑ๋กํ•ฉ๋‹ˆ๋‹ค:

import compiler
from math import ceildiv
from runtime.asyncrt import DeviceContextPtr
from tensor import InputTensor, OutputTensor
from memory import UnsafePointer
from gpu.host import DeviceBuffer


@compiler.register("embedding")
struct EmbeddingCustomOp:
    @staticmethod
    fn execute[
        target: StaticString,
        batch_size: Int,
        seq_len: Int,
        vocab_size: Int,
        embed_dim: Int,
    ](
        output: OutputTensor[
            dtype = DType.float32, rank=3
        ],  # [batch_size, seq_len, embed_dim]
        indices: InputTensor[
            dtype = DType.int32, rank=2
        ],  # [batch_size, seq_len]
        weights: InputTensor[
            dtype = output.dtype, rank=2
        ],  # [vocab_size, embed_dim]
        ctx: DeviceContextPtr,
    ) raises:
        output_tensor = output.to_layout_tensor()
        indices_tensor = indices.to_layout_tensor()
        weights_tensor = weights.to_layout_tensor()

        comptime indices_layout = indices_tensor.layout
        comptime weights_layout = weights_tensor.layout
        comptime out_layout = output_tensor.layout

        @parameter
        if target == "gpu":
            gpu_ctx = ctx.get_device_context()

            # Zero out output tensor
            gpu_ctx.enqueue_memset(
                DeviceBuffer[output.dtype](
                    gpu_ctx,
                    output_tensor.ptr,
                    batch_size * seq_len * embed_dim,
                    owning=False,
                ),
                0,
            )

            # Calculate 1D grid dimensions (matching kernel's flat indexing)
            total_elements = batch_size * seq_len * embed_dim
            blocks = max(1, ceildiv(total_elements, THREADS_PER_BLOCK))

            # Compile and launch optimized kernel
            comptime kernel = embedding_kernel_coalesced[
                indices_layout,
                weights_layout,
                out_layout,
                batch_size,
                seq_len,
                vocab_size,
                embed_dim,
                output.dtype,
            ]
            compiled_kernel = gpu_ctx.compile_function[kernel, kernel]()

            gpu_ctx.enqueue_function(
                compiled_kernel,
                output_tensor,
                indices_tensor,
                weights_tensor,
                grid_dim=(blocks,),
                block_dim=(THREADS_PER_BLOCK,),
            )

        elif target == "cpu":
            for batch in range(batch_size):
                for seq in range(seq_len):
                    token_idx_val = Int(indices_tensor[batch, seq])
                    if token_idx_val >= 0 and token_idx_val < vocab_size:
                        for emb in range(embed_dim):
                            output_tensor[batch, seq, emb] = weights_tensor[
                                token_idx_val, emb
                            ]
        else:
            raise Error("Unsupported target: " + target)


๋“ฑ๋ก์˜ ํ•ต์‹ฌ ์š”์†Œ:

  • ๋‹จ์ˆœํ•œ ๊ทธ๋ฆฌ๋“œ ๊ตฌ์„ฑ: ceildiv(total_elements, THREADS_PER_BLOCK) ๋ธ”๋ก์œผ๋กœ ์ง๊ด€์ ์ธ 1D ๊ทธ๋ฆฌ๋“œ ์‚ฌ์šฉ
  • ๋ฉ”๋ชจ๋ฆฌ ์ตœ์ ํ™”: ๋‹จ์ผ enqueue_memset ํ˜ธ์ถœ๋กœ ์ถœ๋ ฅ ๋ฒ„ํผ๋ฅผ ํšจ์œจ์ ์œผ๋กœ ์ดˆ๊ธฐํ™”
  • ์ปดํŒŒ์ผ ํƒ€์ž„ ํŒŒ๋ผ๋ฏธํ„ฐ: ๋ชจ๋“  ํ…์„œ ์ฐจ์›์„ ์ปดํŒŒ์ผ ํƒ€์ž„ ํŒŒ๋ผ๋ฏธํ„ฐ๋กœ ์ „๋‹ฌํ•˜์—ฌ ์ตœ์  ์„ฑ๋Šฅ ๋‹ฌ์„ฑ
  • ๋””๋ฐ”์ด์Šค ์ถ”์ƒํ™”: GPU ์‹คํ–‰๊ณผ CPU ํด๋ฐฑ์„ ๋งค๋„๋Ÿฝ๊ฒŒ ์ฒ˜๋ฆฌ

2D ๋น„๋ณ‘ํ•ฉ ์—ฐ์‚ฐ

์ด ์—ฐ์‚ฐ์€ ๋น„๊ต์šฉ 2D ์ž„๋ฒ ๋”ฉ ์ปค๋„์„ "embedding_2d"๋กœ ๋“ฑ๋กํ•ฉ๋‹ˆ๋‹ค:

@compiler.register("embedding_2d")
struct Embedding2DCustomOp:
    @staticmethod
    fn execute[
        target: StaticString,
        batch_size: Int,
        seq_len: Int,
        vocab_size: Int,
        embed_dim: Int,
    ](
        output: OutputTensor[
            dtype = DType.float32, rank=3
        ],  # [batch_size, seq_len, embed_dim]
        indices: InputTensor[
            dtype = DType.int32, rank=2
        ],  # [batch_size, seq_len]
        weights: InputTensor[
            dtype = output.dtype, rank=2
        ],  # [vocab_size, embed_dim]
        ctx: DeviceContextPtr,
    ) raises:
        output_tensor = output.to_layout_tensor()
        indices_tensor = indices.to_layout_tensor()
        weights_tensor = weights.to_layout_tensor()

        comptime indices_layout = indices_tensor.layout
        comptime weights_layout = weights_tensor.layout
        comptime out_layout = output_tensor.layout

        @parameter
        if target == "gpu":
            gpu_ctx = ctx.get_device_context()

            # Zero out output tensor
            gpu_ctx.enqueue_memset(
                DeviceBuffer[output.dtype](
                    gpu_ctx,
                    output_tensor.ptr,
                    batch_size * seq_len * embed_dim,
                    owning=False,
                ),
                0,
            )

            # Calculate 2D grid dimensions for non-coalesced access
            total_positions = batch_size * seq_len
            comptime BLOCK_X = 16  # batch*seq dimension
            comptime BLOCK_Y = 16  # embed dimension
            blocks_x = max(1, ceildiv(total_positions, BLOCK_X))
            blocks_y = max(1, ceildiv(embed_dim, BLOCK_Y))

            # Compile and launch 2D kernel
            comptime kernel = embedding_kernel_2d[
                indices_layout,
                weights_layout,
                out_layout,
                batch_size,
                seq_len,
                vocab_size,
                embed_dim,
                output.dtype,
            ]

            compiled_kernel = gpu_ctx.compile_function[kernel, kernel]()

            gpu_ctx.enqueue_function(
                compiled_kernel,
                output_tensor,
                indices_tensor,
                weights_tensor,
                grid_dim=(blocks_x, blocks_y),
                block_dim=(BLOCK_X, BLOCK_Y),
            )

        elif target == "cpu":
            # Same CPU fallback as 1D version
            for batch in range(batch_size):
                for seq in range(seq_len):
                    token_idx_val = Int(indices_tensor[batch, seq])
                    if token_idx_val >= 0 and token_idx_val < vocab_size:
                        for emb in range(embed_dim):
                            output_tensor[batch, seq, emb] = weights_tensor[
                                token_idx_val, emb
                            ]
        else:
            raise Error("Unsupported target: " + target)


1D ์—ฐ์‚ฐ๊ณผ์˜ ์ฃผ์š” ์ฐจ์ด์ :

  • ๋ณต์žกํ•œ ๊ทธ๋ฆฌ๋“œ ๊ตฌ์„ฑ: blocks_x์™€ blocks_y๋ฅผ ๋ณ„๋„๋กœ ๊ณ„์‚ฐํ•˜๋Š” 2D ๊ทธ๋ฆฌ๋“œ ์‚ฌ์šฉ
  • ๊ณ ์ • ๋ธ”๋ก ์ฐจ์›: 2D ์Šค๋ ˆ๋“œ ๊ตฌ์„ฑ์„ ์œ„ํ•ด BLOCK_X = 16, BLOCK_Y = 16์œผ๋กœ ๊ณ ์ •
  • ๋™์ผํ•œ ๋ฉ”๋ชจ๋ฆฌ ๊ด€๋ฆฌ: ๋ฉ”๋ชจ๋ฆฌ ์ดˆ๊ธฐํ™”์™€ CPU ํด๋ฐฑ ๋กœ์ง์€ ๋™์ผ
  • ๋‹ค๋ฅธ ์ปค๋„ ํ˜ธ์ถœ ๋ฐฉ์‹: 2D ๊ทธ๋ฆฌ๋“œ ์ฐจ์› (blocks_x, blocks_y)๊ณผ ๋ธ”๋ก ์ฐจ์› (BLOCK_X, BLOCK_Y) ์ „๋‹ฌ

๊ณตํ†ต ๋ž˜ํผ ๊ธฐ๋Šฅ

๋‘ ์ปค์Šคํ…€ ์—ฐ์‚ฐ์€ ๋‹ค์Œ๊ณผ ๊ฐ™์€ ํ•„์ˆ˜ ์ธํ”„๋ผ๋ฅผ ์ œ๊ณตํ•ฉ๋‹ˆ๋‹ค:

  1. ๋ฉ”๋ชจ๋ฆฌ ๊ด€๋ฆฌ:

    • enqueue_memset์œผ๋กœ ์ถœ๋ ฅ ํ…์„œ 0 ์ดˆ๊ธฐํ™”
    • ์ ์ ˆํ•œ ๋ฒ„ํผ ์ƒ์„ฑ๊ณผ ๋ฉ”๋ชจ๋ฆฌ ๋ ˆ์ด์•„์›ƒ ์ฒ˜๋ฆฌ
    • ์ž๋™ ์ •๋ฆฌ ๋ฐ ๋ฆฌ์†Œ์Šค ๊ด€๋ฆฌ
  2. ๋””๋ฐ”์ด์Šค ์ถ”์ƒํ™”:

    • ์ตœ์ ํ™”๋œ ์ปค๋„๋กœ GPU ์‹คํ–‰
    • ํ˜ธํ™˜์„ฑ๊ณผ ๋””๋ฒ„๊น…์„ ์œ„ํ•œ CPU ํด๋ฐฑ
    • ์‹คํ–‰ ๋Œ€์ƒ์— ๊ด€๊ณ„์—†์ด ์ผ๊ด€๋œ ์ธํ„ฐํŽ˜์ด์Šค
  3. ํŒŒ๋ผ๋ฏธํ„ฐ ์ „๋‹ฌ:

    • ์ปค๋„ ์ตœ์ ํ™”๋ฅผ ์œ„ํ•œ ์ปดํŒŒ์ผ ํƒ€์ž„ ํ…์„œ ์ฐจ์›
    • ๋ ˆ์ด์•„์›ƒ ํ…์„œ ๋ณ€ํ™˜์„ ํ†ตํ•œ ๋Ÿฐํƒ€์ž„ ํ…์„œ ๋ฐ์ดํ„ฐ
    • ํƒ€์ž… ์•ˆ์ „ํ•œ ํŒŒ๋ผ๋ฏธํ„ฐ ๊ฒ€์ฆ
  4. ๊ทธ๋ฆฌ๋“œ ๊ตฌ์„ฑ:

    • ์ตœ์ ์˜ ๊ทธ๋ฆฌ๋“œ ์ฐจ์› ์ž๋™ ๊ณ„์‚ฐ
    • ๊ฐ ์ปค๋„์˜ ์ ‘๊ทผ ํŒจํ„ด์— ์ตœ์ ํ™”๋œ ์„œ๋กœ ๋‹ค๋ฅธ ์ „๋žต
    • ์ ์ ˆํ•œ ๋ธ”๋ก ์ฐจ์› ๊ด€๋ฆฌ

PyTorch ํ†ตํ•ฉ

๋“ฑ๋ก๋œ ์—ฐ์‚ฐ์€ CustomOpLibrary๋ฅผ ํ†ตํ•ด ํŒŒ์ด์ฌ์—์„œ ํ˜ธ์ถœํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค:

# Load the custom operations
from max.torch import CustomOpLibrary

ops = CustomOpLibrary(mojo_kernels)  # mojo_kernels: 커스텀 op이 정의된 디렉터리 경로

# Call the 1D coalesced version
result_1d = ops.embedding[{"batch_size": B, "seq_len": L, "vocab_size": V, "embed_dim": E}](
    indices, weights
)

# Call the 2D non-coalesced version
result_2d = ops.embedding_2d[{"batch_size": B, "seq_len": L, "vocab_size": V, "embed_dim": E}](
    indices, weights
)

์ด ์ ‘๊ทผ๋ฒ•์˜ ์žฅ์ ์€ ๋™์ผํ•œ ์ปค๋„ ๊ตฌํ˜„์„ ๋‹ค์–‘ํ•œ ํŒŒ์ด์ฌ ํ”„๋ ˆ์ž„์›Œํฌ์—์„œ ์‚ฌ์šฉํ•˜๋ฉด์„œ๋„ ์ตœ์ ์˜ ์„ฑ๋Šฅ ํŠน์„ฑ์„ ์œ ์ง€ํ•  ์ˆ˜ ์žˆ๋‹ค๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค.

์ฝ”๋“œ ์‹คํ–‰

๋‹ค์Œ ๋ช…๋ น์œผ๋กœ ํผ์ฆ์„ ์‹คํ–‰ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค:

pixi run p21
pixi run -e amd p21
uv run poe p21

์„ฑ๊ณตํ•˜๋ฉด ๋‹ค์Œ๊ณผ ๋น„์Šทํ•œ ์ถœ๋ ฅ์„ ๋ณผ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค:

Puzzle 21: Mojo Embedding Kernel Comparison
======================================================================
Configuration: B=8, L=512, V=10000, E=512
------------------------------------------------------------

Testing Correctness...
   1D Coalesced - Max difference: 1.19e-07
   2D Non-coalesced - Max difference: 1.19e-07
   ✅ Both implementations CORRECT

Benchmarking Mojo Kernels...

Performance Results:
   1D Coalesced:     2.145 ms
   2D Non-coalesced: 3.867 ms
   1D is 1.80x faster than 2D

Key Learning Points:
• Compare different GPU kernel implementations
• 1D vs 2D grid patterns have different memory access
• Coalesced memory access should be faster
• Grid configuration affects GPU utilization
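예시 출력의 "1.80x"는 두 실행 시간의 단순 비율입니다. 같은 계산을 직접 확인해 보면:

```python
# 예시 출력의 벤치마크 수치로 속도 비율 계산
time_1d_ms = 2.145   # 1D Coalesced
time_2d_ms = 3.867   # 2D Non-coalesced
speedup = time_2d_ms / time_1d_ms
print(f"1D is {speedup:.2f}x faster than 2D")  # 1.80x
```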

์†”๋ฃจ์…˜

๋‘ ์ปค๋„์˜ ์ขŒํ‘œ ๋ณ€ํ™˜๊ณผ ๋ฉ”๋ชจ๋ฆฌ ์—ฐ์‚ฐ์„ ๊ตฌํ˜„ํ•˜๋ฉด ๋ฉ๋‹ˆ๋‹ค:

1D ๋ณ‘ํ•ฉ ์ปค๋„

fn embedding_kernel_coalesced[
    indices_layout: Layout,
    weights_layout: Layout,
    out_layout: Layout,
    batch_size: Int,
    seq_len: Int,
    vocab_size: Int,
    embed_dim: Int,
    dtype: DType = DType.float32,
](
    output: LayoutTensor[dtype, out_layout, MutAnyOrigin],
    indices: LayoutTensor[DType.int32, indices_layout, MutAnyOrigin],
    weights: LayoutTensor[dtype, weights_layout, MutAnyOrigin],
):
    """
    Memory-coalescing focused embedding kernel.

    Key insight: The bottleneck is memory access patterns, not computation.
    - Each thread handles one (batch, seq, embed) position
    - Simple 1D grid for maximum simplicity and correctness
    - Focus on getting memory access right first
    """

    # Simple 1D indexing - each thread = one output element
    global_idx = Int(block_idx.x * block_dim.x + thread_idx.x)
    total_elements = batch_size * seq_len * embed_dim

    if global_idx >= total_elements:
        return

    # Convert to (batch, seq, embed) coordinates
    batch_idx = global_idx // (seq_len * embed_dim)
    remaining = global_idx % (seq_len * embed_dim)
    seq_idx = remaining // embed_dim
    embed_idx = remaining % embed_dim

    # Get token index
    token_idx_val = Int(indices[batch_idx, seq_idx])

    # Simple, correct assignment
    if token_idx_val >= 0 and token_idx_val < vocab_size:
        output[batch_idx, seq_idx, embed_idx] = weights[
            token_idx_val, embed_idx
        ]
    else:
        output[batch_idx, seq_idx, embed_idx] = 0


2D ๋น„๋ณ‘ํ•ฉ ์ปค๋„

fn embedding_kernel_2d[
    indices_layout: Layout,
    weights_layout: Layout,
    out_layout: Layout,
    batch_size: Int,
    seq_len: Int,
    vocab_size: Int,
    embed_dim: Int,
    dtype: DType = DType.float32,
](
    output: LayoutTensor[dtype, out_layout, MutAnyOrigin],
    indices: LayoutTensor[DType.int32, indices_layout, MutAnyOrigin],
    weights: LayoutTensor[dtype, weights_layout, MutAnyOrigin],
):
    """
    2D grid non-coalesced embedding kernel.

    Non-optimal approach for comparison:
    - 2D grid: (batch*seq, embed_dim)
    - More complex indexing
    - Potentially worse memory access patterns
    """

    # 2D grid indexing
    batch_seq_idx = Int(block_idx.x * block_dim.x + thread_idx.x)
    embed_idx = Int(block_idx.y * block_dim.y + thread_idx.y)

    total_positions = batch_size * seq_len

    # Bounds check
    if batch_seq_idx >= total_positions or embed_idx >= embed_dim:
        return

    # Convert to (batch, seq) coordinates
    batch_idx = batch_seq_idx // seq_len
    seq_idx = batch_seq_idx % seq_len

    # Get token index
    token_idx_val = Int(indices[batch_idx, seq_idx])

    # Assignment with 2D grid pattern
    if token_idx_val >= 0 and token_idx_val < vocab_size:
        output[batch_idx, seq_idx, embed_idx] = weights[
            token_idx_val, embed_idx
        ]
    else:
        output[batch_idx, seq_idx, embed_idx] = 0


๋‘ ํ’€์ด ๋ชจ๋‘ ๋™์ผํ•œ ์ž„๋ฒ ๋”ฉ ์กฐํšŒ ๋กœ์ง์„ ๊ตฌํ˜„ํ•˜์ง€๋งŒ ์Šค๋ ˆ๋“œ ๊ตฌ์„ฑ์ด ๋‹ค๋ฆ…๋‹ˆ๋‹ค:

์ฃผ์š” ์ฐจ์ด์ 

  1. ์Šค๋ ˆ๋“œ ๋งคํ•‘:

    • 1D ์ปค๋„: ์ถœ๋ ฅ ์š”์†Œ๋‹น ํ•˜๋‚˜์˜ ์Šค๋ ˆ๋“œ, ๋‹จ์ˆœํ•œ 1์ฐจ์› ์ธ๋ฑ์‹ฑ
    • 2D ์ปค๋„: (batchร—seq, embed_dim) ์ขŒํ‘œ์— ๋Œ€ํ•œ 2D ๊ทธ๋ฆฌ๋“œ ๋งคํ•‘
  2. ๋ฉ”๋ชจ๋ฆฌ ์ ‘๊ทผ ํŒจํ„ด:

    • 1D ์ปค๋„: ์—ฐ์†๋œ ์Šค๋ ˆ๋“œ๊ฐ€ ์—ฐ์†๋œ ์ž„๋ฒ ๋”ฉ ์ฐจ์›์— ์ ‘๊ทผ โ†’ ๋ณ‘ํ•ฉ๋จ
    • 2D ์ปค๋„: ์Šค๋ ˆ๋“œ ์ ‘๊ทผ ํŒจํ„ด์ด ๋ธ”๋ก ๊ตฌ์„ฑ์— ๋”ฐ๋ผ ๋‹ฌ๋ผ์ง โ†’ ๋ณ‘ํ•ฉ๋˜์ง€ ์•Š์„ ์ˆ˜ ์žˆ์Œ
  3. ์ธ๋ฑ์‹ฑ ๋ณต์žก๋„:

    • 1D ์ปค๋„: ๋‹จ์ผ ๋‚˜๋ˆ—์…ˆ/๋‚˜๋จธ์ง€ ์ฒด์ธ์œผ๋กœ 3D ์ขŒํ‘œ ๊ณ„์‚ฐ
    • 2D ์ปค๋„: X/Y ์ขŒํ‘œ๋ฅผ ๋ณ„๋„๋กœ ๊ณ„์‚ฐ

์„ฑ๋Šฅ์— ๋ฏธ์น˜๋Š” ์˜ํ–ฅ

1D ์ปค๋„์ด ์ผ๋ฐ˜์ ์œผ๋กœ ๋” ๋†’์€ ์„ฑ๋Šฅ์„ ๋ณด์ด๋Š” ์ด์œ :

  • ๋ฉ”๋ชจ๋ฆฌ ๋ณ‘ํ•ฉ: ์—ฐ์†๋œ ์Šค๋ ˆ๋“œ๊ฐ€ ์—ฐ์†๋œ ๋ฉ”๋ชจ๋ฆฌ ์ฃผ์†Œ์— ์ ‘๊ทผ
  • ๋‹จ์ˆœํ•œ ์ธ๋ฑ์‹ฑ: ์ขŒํ‘œ ๊ณ„์‚ฐ์˜ ์—ฐ์‚ฐ ์˜ค๋ฒ„ํ—ค๋“œ๊ฐ€ ๋‚ฎ์Œ
  • ๋” ๋‚˜์€ ์บ์‹œ ํ™œ์šฉ: ์˜ˆ์ธก ๊ฐ€๋Šฅํ•œ ๋ฉ”๋ชจ๋ฆฌ ์ ‘๊ทผ ํŒจํ„ด

2D ์ปค๋„์˜ ์„ฑ๋Šฅ์ด ๋–จ์–ด์งˆ ์ˆ˜ ์žˆ๋Š” ์ด์œ :

  • ํฉ์–ด์ง„ ๋ฉ”๋ชจ๋ฆฌ ์ ‘๊ทผ: ์›Œํ”„ ๋‚ด ์Šค๋ ˆ๋“œ๋“ค์ด ์„œ๋กœ ๋‹ค๋ฅธ ์ž„๋ฒ ๋”ฉ ๋ฒกํ„ฐ์— ์ ‘๊ทผํ•  ์ˆ˜ ์žˆ์Œ
  • 복잡한 그리드 구성: 16×16 블록이 메모리 레이아웃과 최적으로 맞지 않을 수 있음
  • ์›Œํ”„ ๋ถ„๊ธฐ: ์„œ๋กœ ๋‹ค๋ฅธ ์Šค๋ ˆ๋“œ๊ฐ€ ์„œ๋กœ ๋‹ค๋ฅธ ์‹คํ–‰ ๊ฒฝ๋กœ๋ฅผ ๋”ฐ๋ฅผ ์ˆ˜ ์žˆ์Œ
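위 차이를 주소 수준에서 확인해 볼 수 있는 스케치입니다(가정: 행우선(row-major) 레이아웃, 출력 텐서의 평탄화 주소만 계산). 워프의 연속 스레드 8개가 접근하는 출력 위치를 두 매핑에서 비교합니다:

```python
# 두 매핑에서 연속 스레드 8개가 접근하는 평탄화(flat) 주소 비교
# (가정: 행우선 레이아웃, seq_len=512, embed_dim=512)
seq_len, embed_dim = 512, 512

def flat_addr(batch_idx, seq_idx, embed_idx):
    """[B, L, E] 텐서의 행우선 평탄화 주소."""
    return (batch_idx * seq_len + seq_idx) * embed_dim + embed_idx

# 1D 커널: 연속 스레드 t=0..7 → embed_idx만 1씩 증가 → 연속 주소 (병합됨)
coalesced = [flat_addr(0, 0, t) for t in range(8)]

# 2D 커널: 연속 스레드(thread_idx.x) → seq_idx가 1씩 증가, embed_idx 고정
#          → 주소 간격이 embed_dim(=512)만큼 벌어짐 (병합되지 않음)
strided = [flat_addr(0, t, 0) for t in range(8)]
```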

ํ•ต์‹ฌ ๊ฐœ๋…

๊ฐœ๋…1D ๋ณ‘ํ•ฉ2D ๋น„๋ณ‘ํ•ฉ
์Šค๋ ˆ๋“œ ๊ตฌ์„ฑ1D 1์ฐจ์› ์ธ๋ฑ์‹ฑ2D ๊ทธ๋ฆฌ๋“œ (batchร—seq, embed)
๋ฉ”๋ชจ๋ฆฌ ์ ‘๊ทผ์—ฐ์†๋œ ์ฃผ์†Œํฉ์–ด์งˆ ์ˆ˜ ์žˆ์Œ
๊ทธ๋ฆฌ๋“œ ๊ตฌ์„ฑ๋‹จ์ˆœ: [total_elements // 256]๋ณต์žก: [batchร—seq // 16, embed // 16]
์„ฑ๋Šฅ๋ฉ”๋ชจ๋ฆฌ ๋Œ€์—ญํญ์— ์ตœ์ ํ™”์ตœ์ ํ™”๋˜์ง€ ์•Š์€ ๋ฉ”๋ชจ๋ฆฌ ํŒจํ„ด
์‚ฌ์šฉ ๋ชฉ์ ํ”„๋กœ๋•์…˜ ์ปค๋„๊ต์œก์šฉ ๋น„๊ต

ํ•ต์‹ฌ ๊ตํ›ˆ: ๋ฉ”๋ชจ๋ฆฌ ๋ณ‘ํ•ฉ์€ ์ž„๋ฒ ๋”ฉ๊ณผ ๊ฐ™์€ ๋ฉ”๋ชจ๋ฆฌ ๋ฐ”์šด๋“œ ์—ฐ์‚ฐ์—์„œ 2~3๋ฐฐ์˜ ์„ฑ๋Šฅ ์ฐจ์ด๋ฅผ ๊ฐ€์ ธ์˜ฌ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.