์๋ฒ ๋ฉ ์ปค๋: ๋ณํฉ vs ๋น๋ณํฉ
์ด ํผ์ฆ์์๋ ๋์ผํ ๊ฒฐ๊ณผ๋ฅผ ์์ฑํ์ง๋ง ์๋ก ๋ค๋ฅธ ๋ฉ๋ชจ๋ฆฌ ์ ๊ทผ ํจํด์ ์ฌ์ฉํ๋ ๋ ๊ฐ์ง GPU ์๋ฒ ๋ฉ ์ปค๋์ ๊ตฌํํฉ๋๋ค. GPU ์ฑ๋ฅ์์ ๋ฉ๋ชจ๋ฆฌ ๋ณํฉ์ด ์ผ๋ง๋ ์ค์ํ์ง ์ง์ ์ฒดํํ ์ ์์ต๋๋ค.
1D ๋ณํฉ ์ปค๋ (์ต์ ํ๋ ์ ๊ทผ๋ฒ)
์ด ์ปค๋์ ๊ฐ ์ค๋ ๋๊ฐ ์ ํํ ํ๋์ ์ถ๋ ฅ ์์๋ฅผ ์ฒ๋ฆฌํ๋ ๋จ์ํ 1D ๊ทธ๋ฆฌ๋๋ฅผ ์ฌ์ฉํฉ๋๋ค. ํต์ฌ์ ์ฐ์๋ ์ค๋ ๋๊ฐ ์ฐ์๋ ๋ฉ๋ชจ๋ฆฌ ์์น์ ์ ๊ทผํ์ฌ ์ต์ ์ ๋ฉ๋ชจ๋ฆฌ ๋ณํฉ์ ๋ฌ์ฑํ๋ค๋ ์ ์ ๋๋ค.
์ค๋ ๋ ๊ตฌ์ฑ:
- ๊ทธ๋ฆฌ๋ ๊ตฌ์ฑ:
  `[total_elements // 256]` 블록, 블록당 256 스레드
- 스레드 매핑: 각 스레드가 하나의 (batch, seq, embed) 위치 처리
- 메모리 패턴: 연속된 스레드가 연속된 임베딩 차원에 접근
๊ตฌํํ ๋ด์ฉ:
- ๋ธ๋ก ์ธ๋ฑ์ค์ ์ค๋ ๋ ์ธ๋ฑ์ค๋ก๋ถํฐ ์ ์ญ ์ค๋ ๋ ์ธ๋ฑ์ค ๊ณ์ฐ
- 1차원 인덱스를 3D 좌표 (batch_idx, seq_idx, embed_idx)로 변환
- indices 텐서에서 토큰 인덱스 조회
- ํด๋นํ๋ ์๋ฒ ๋ฉ ๋ฒกํฐ ์์๋ฅผ ์ถ๋ ฅ์ ๋ณต์ฌ
์์ฑํ ์ฝ๋
๋ ์๋ฒ ๋ฉ ์ปค๋์ ๋น ๋ถ๋ถ์ ์์ฑํด์ผ ํฉ๋๋ค:
# Threads per block for the 1D kernel launch; 256 is a multiple of the
# GPU warp size (32), so blocks map cleanly onto warps.
comptime THREADS_PER_BLOCK = 256
# Puzzle template: the "FILL IN" sections are intentionally left blank for
# the reader to complete (a full solution appears later in this document).
fn embedding_kernel_coalesced[
    indices_layout: Layout,
    weights_layout: Layout,
    out_layout: Layout,
    batch_size: Int,
    seq_len: Int,
    vocab_size: Int,
    embed_dim: Int,
    dtype: DType = DType.float32,
](
    output: LayoutTensor[dtype, out_layout, MutAnyOrigin],  # [batch_size, seq_len, embed_dim]
    indices: LayoutTensor[DType.int32, indices_layout, MutAnyOrigin],  # [batch_size, seq_len]
    weights: LayoutTensor[dtype, weights_layout, MutAnyOrigin],  # [vocab_size, embed_dim]
):
    """
    Memory-coalescing focused embedding kernel.
    Key insight: The bottleneck is memory access patterns, not computation.
    - Each thread handles one (batch, seq, embed) position
    - Simple 1D grid for maximum simplicity and correctness
    - Focus on getting memory access right first
    """
    # Simple 1D indexing - each thread = one output element
    global_idx = Int(block_idx.x * block_dim.x + thread_idx.x)
    total_elements = batch_size * seq_len * embed_dim
    # Threads launched past the end of the flattened output do nothing.
    if global_idx >= total_elements:
        return
    # Convert to (batch, seq, embed) coordinates
    # FILL IN roughly 4 lines
    # Get token index
    # FILL IN 1 line
    # Simple, correct assignment
    # FILL IN 4 lines
์ ์ฒด ํ์ผ ๋ณด๊ธฐ: problems/p21/op/embedding.mojo
ํ
- `global_idx = block_idx.x * block_dim.x + thread_idx.x`로 시작하세요
- 나눗셈과 나머지 연산으로 3D 좌표를 구합니다: `batch_idx = global_idx // (seq_len * embed_dim)`
- `remaining = global_idx % (seq_len * embed_dim)`을 사용하면 이후 계산이 간단해집니다
- 항상 경계 검사를 하세요: `if global_idx >= total_elements: return`
- 유효하지 않은 토큰 인덱스는 출력을 0으로 설정하세요
- ์๋ฒ ๋ฉ ์กฐํ:
output[batch_idx, seq_idx, embed_idx] = weights[token_idx, embed_idx]
2D ๋น๋ณํฉ ์ปค๋ (๋น๊ต์ฉ ์ ๊ทผ๋ฒ)
์ด ์ปค๋์ X ์ฐจ์์ด (batch ร seq) ์์น๋ฅผ, Y ์ฐจ์์ด ์๋ฒ ๋ฉ ์ฐจ์์ ๋ด๋นํ๋ 2D ๊ทธ๋ฆฌ๋๋ฅผ ์ฌ์ฉํฉ๋๋ค. ์ด ๋ฐฉ์์ ๋ฉ๋ชจ๋ฆฌ ์ ๊ทผ์ด ๋ณํฉ๋์ง ์์ ์ ์์ต๋๋ค.
์ค๋ ๋ ๊ตฌ์ฑ:
- ๊ทธ๋ฆฌ๋ ๊ตฌ์ฑ:
  `[batch × seq // 16, embed_dim // 16]` 블록, 16 × 16 스레드
- 스레드 매핑: `thread_idx.x`는 batch/sequence에, `thread_idx.y`는 임베딩 차원에 매핑
- 메모리 패턴: 워프 내 스레드들이 흩어진 메모리 위치에 접근할 수 있음
๊ตฌํํ ๋ด์ฉ:
- 2D ๊ทธ๋ฆฌ๋์์ X, Y ์ขํ ๊ณ์ฐ
- X ์ขํ๋ฅผ batch ์ธ๋ฑ์ค์ sequence ์ธ๋ฑ์ค๋ก ๋ถ๋ฆฌ
- Y ์ขํ๋ฅผ ์๋ฒ ๋ฉ ์ฐจ์์ผ๋ก ์ง์ ์ฌ์ฉ
- ๊ฒฝ๊ณ ๊ฒ์ฌ์ ํจ๊ป ๋์ผํ ์๋ฒ ๋ฉ ์กฐํ ์ํ
์์ฑํ ์ฝ๋
๋ ์๋ฒ ๋ฉ ์ปค๋์ ๋น ๋ถ๋ถ์ ์์ฑํด์ผ ํฉ๋๋ค:
# Puzzle template: the "FILL IN" sections are intentionally left blank for
# the reader to complete (a full solution appears later in this document).
fn embedding_kernel_2d[
    indices_layout: Layout,
    weights_layout: Layout,
    out_layout: Layout,
    batch_size: Int,
    seq_len: Int,
    vocab_size: Int,
    embed_dim: Int,
    dtype: DType = DType.float32,
](
    output: LayoutTensor[dtype, out_layout, MutAnyOrigin],  # [batch_size, seq_len, embed_dim]
    indices: LayoutTensor[DType.int32, indices_layout, MutAnyOrigin],  # [batch_size, seq_len]
    weights: LayoutTensor[dtype, weights_layout, MutAnyOrigin],  # [vocab_size, embed_dim]
):
    """
    2D grid non-coalesced embedding kernel.
    Non-optimal approach for comparison:
    - 2D grid: (batch*seq, embed_dim)
    - More complex indexing
    - Potentially worse memory access patterns
    """
    # 2D grid indexing
    # X axis covers flattened (batch, seq) positions; Y covers embed dims.
    batch_seq_idx = Int(block_idx.x * block_dim.x + thread_idx.x)
    embed_idx = Int(block_idx.y * block_dim.y + thread_idx.y)
    total_positions = batch_size * seq_len
    # Guard both grid axes before touching memory.
    if batch_seq_idx >= total_positions or embed_idx >= embed_dim:
        return
    # Convert to (batch, seq) coordinates
    # FILL IN 2 lines
    # Get token index
    # FILL IN 1 line
    # Assignment with 2D grid pattern
    # FILL IN 4 lines
์ ์ฒด ํ์ผ ๋ณด๊ธฐ: problems/p21/op/embedding.mojo
ํ
- X, Y 스레드 좌표를 모두 사용합니다: `batch_seq_idx = block_idx.x * block_dim.x + thread_idx.x`,
  `embed_idx = block_idx.y * block_dim.y + thread_idx.y`
- batch_seq_idx를 batch와 sequence 인덱스로 분리합니다: `batch_idx = batch_seq_idx // seq_len`
- 두 차원 모두 경계 검사를 잊지 마세요: `if batch_seq_idx >= total_positions or embed_idx >= embed_dim`
- 토큰 조회는 1D 커널과 동일하지만, 이 커널은 전체 벡터가 아닌 스레드당 하나의 임베딩 차원만 처리합니다
์ปค์คํ op ๋ฑ๋ก
์ปค๋๋ค์ PyTorch์ ์ฝ๊ฒ ํตํฉํ ์ ์๋๋ก ์ปค์คํ ์ฐ์ฐ์ผ๋ก ๋ํ๋ฉ๋๋ค. ๋ฑ๋ก ํจํด์ MAX ๊ทธ๋ํ ์ปค์คํ op ์ดํดํ๊ธฐ์์ ์ค๋ช ํ MAX ์ปค์คํ op๊ณผ ๋์ผํฉ๋๋ค:
1D ๋ณํฉ ์ฐ์ฐ
์ด ์ฐ์ฐ์ ์ต์ ํ๋ 1D ์๋ฒ ๋ฉ ์ปค๋์ "embedding"์ผ๋ก ๋ฑ๋กํฉ๋๋ค:
import compiler
from runtime.asyncrt import DeviceContextPtr
from tensor import InputTensor, OutputTensor
from memory import UnsafePointer
from gpu.host import DeviceBuffer
@compiler.register("embedding")
struct EmbeddingCustomOp:
    """Registers the optimized 1D coalesced embedding kernel as custom op "embedding"."""

    @staticmethod
    fn execute[
        target: StaticString,
        batch_size: Int,
        seq_len: Int,
        vocab_size: Int,
        embed_dim: Int,
    ](
        output: OutputTensor[
            dtype = DType.float32, rank=3
        ],  # [batch_size, seq_len, embed_dim]
        indices: InputTensor[
            dtype = DType.int32, rank=2
        ],  # [batch_size, seq_len]
        weights: InputTensor[
            dtype = output.dtype, rank=2
        ],  # [vocab_size, embed_dim]
        ctx: DeviceContextPtr,
    ) raises:
        # View the framework tensors as LayoutTensors for the kernel call.
        output_tensor = output.to_layout_tensor()
        indices_tensor = indices.to_layout_tensor()
        weights_tensor = weights.to_layout_tensor()
        comptime indices_layout = indices_tensor.layout
        comptime weights_layout = weights_tensor.layout
        comptime out_layout = output_tensor.layout
        @parameter
        if target == "gpu":
            gpu_ctx = ctx.get_device_context()
            # Zero out output tensor
            # (a non-owning DeviceBuffer wraps the existing allocation so
            # memset does not take ownership of the output memory)
            gpu_ctx.enqueue_memset(
                DeviceBuffer[output.dtype](
                    gpu_ctx,
                    output_tensor.ptr,
                    batch_size * seq_len * embed_dim,
                    owning=False,
                ),
                0,
            )
            # Calculate 1D grid dimensions (matching kernel's flat indexing)
            total_elements = batch_size * seq_len * embed_dim
            blocks = max(1, ceildiv(total_elements, THREADS_PER_BLOCK))
            # Compile and launch optimized kernel
            comptime kernel = embedding_kernel_coalesced[
                indices_layout,
                weights_layout,
                out_layout,
                batch_size,
                seq_len,
                vocab_size,
                embed_dim,
                output.dtype,
            ]
            compiled_kernel = gpu_ctx.compile_function[kernel, kernel]()
            gpu_ctx.enqueue_function(
                compiled_kernel,
                output_tensor,
                indices_tensor,
                weights_tensor,
                grid_dim=(blocks,),
                block_dim=(THREADS_PER_BLOCK,),
            )
        elif target == "cpu":
            # Sequential fallback with the same lookup semantics as the kernel.
            # NOTE(review): unlike the GPU path, nothing zeroes the output here,
            # so invalid tokens leave their rows unwritten — confirm OutputTensor
            # arrives zero-initialized on CPU.
            for batch in range(batch_size):
                for seq in range(seq_len):
                    token_idx_val = Int(indices_tensor[batch, seq])
                    if token_idx_val >= 0 and token_idx_val < vocab_size:
                        for emb in range(embed_dim):
                            output_tensor[batch, seq, emb] = weights_tensor[
                                token_idx_val, emb
                            ]
        else:
            raise Error("Unsupported target: " + target)
๋ฑ๋ก์ ํต์ฌ ์์:
- 단순한 그리드 구성: `ceildiv(total_elements, THREADS_PER_BLOCK)` 블록으로 직관적인 1D 그리드 사용
- 메모리 최적화: 단일 `enqueue_memset` 호출로 출력 버퍼를 효율적으로 초기화
- 컴파일 타임 파라미터: 모든 텐서 차원을 컴파일 타임 파라미터로 전달하여 최선의 성능 달성
- ๋๋ฐ์ด์ค ์ถ์ํ: GPU ์คํ๊ณผ CPU ํด๋ฐฑ์ ๋งค๋๋ฝ๊ฒ ์ฒ๋ฆฌ
2D ๋น๋ณํฉ ์ฐ์ฐ
์ด ์ฐ์ฐ์ ๋น๊ต์ฉ 2D ์๋ฒ ๋ฉ ์ปค๋์ "embedding_2d"๋ก ๋ฑ๋กํฉ๋๋ค:
@compiler.register("embedding_2d")
struct Embedding2DCustomOp:
    """Registers the comparison 2D (non-coalesced) embedding kernel as custom op "embedding_2d"."""

    @staticmethod
    fn execute[
        target: StaticString,
        batch_size: Int,
        seq_len: Int,
        vocab_size: Int,
        embed_dim: Int,
    ](
        output: OutputTensor[
            dtype = DType.float32, rank=3
        ],  # [batch_size, seq_len, embed_dim]
        indices: InputTensor[
            dtype = DType.int32, rank=2
        ],  # [batch_size, seq_len]
        weights: InputTensor[
            dtype = output.dtype, rank=2
        ],  # [vocab_size, embed_dim]
        ctx: DeviceContextPtr,
    ) raises:
        # View the framework tensors as LayoutTensors for the kernel call.
        output_tensor = output.to_layout_tensor()
        indices_tensor = indices.to_layout_tensor()
        weights_tensor = weights.to_layout_tensor()
        comptime indices_layout = indices_tensor.layout
        comptime weights_layout = weights_tensor.layout
        comptime out_layout = output_tensor.layout
        @parameter
        if target == "gpu":
            gpu_ctx = ctx.get_device_context()
            # Zero out output tensor
            # (non-owning DeviceBuffer wraps the existing allocation)
            gpu_ctx.enqueue_memset(
                DeviceBuffer[output.dtype](
                    gpu_ctx,
                    output_tensor.ptr,
                    batch_size * seq_len * embed_dim,
                    owning=False,
                ),
                0,
            )
            # Calculate 2D grid dimensions for non-coalesced access
            total_positions = batch_size * seq_len
            comptime BLOCK_X = 16  # batch*seq dimension
            comptime BLOCK_Y = 16  # embed dimension
            blocks_x = max(1, ceildiv(total_positions, BLOCK_X))
            blocks_y = max(1, ceildiv(embed_dim, BLOCK_Y))
            # Compile and launch 2D kernel
            comptime kernel = embedding_kernel_2d[
                indices_layout,
                weights_layout,
                out_layout,
                batch_size,
                seq_len,
                vocab_size,
                embed_dim,
                output.dtype,
            ]
            compiled_kernel = gpu_ctx.compile_function[kernel, kernel]()
            gpu_ctx.enqueue_function(
                compiled_kernel,
                output_tensor,
                indices_tensor,
                weights_tensor,
                grid_dim=(blocks_x, blocks_y),
                block_dim=(BLOCK_X, BLOCK_Y),
            )
        elif target == "cpu":
            # Same CPU fallback as 1D version
            # NOTE(review): as in the 1D op, the CPU path does not zero output
            # for invalid tokens — confirm OutputTensor arrives zero-initialized.
            for batch in range(batch_size):
                for seq in range(seq_len):
                    token_idx_val = Int(indices_tensor[batch, seq])
                    if token_idx_val >= 0 and token_idx_val < vocab_size:
                        for emb in range(embed_dim):
                            output_tensor[batch, seq, emb] = weights_tensor[
                                token_idx_val, emb
                            ]
        else:
            raise Error("Unsupported target: " + target)
1D ์ฐ์ฐ๊ณผ์ ์ฃผ์ ์ฐจ์ด์ :
- ๋ณต์กํ ๊ทธ๋ฆฌ๋ ๊ตฌ์ฑ:
blocks_x์blocks_y๋ฅผ ๋ณ๋๋ก ๊ณ์ฐํ๋ 2D ๊ทธ๋ฆฌ๋ ์ฌ์ฉ - ๊ณ ์ ๋ธ๋ก ์ฐจ์: 2D ์ค๋ ๋ ๊ตฌ์ฑ์ ์ํด
BLOCK_X = 16,BLOCK_Y = 16์ผ๋ก ๊ณ ์ - ๋์ผํ ๋ฉ๋ชจ๋ฆฌ ๊ด๋ฆฌ: ๋ฉ๋ชจ๋ฆฌ ์ด๊ธฐํ์ CPU ํด๋ฐฑ ๋ก์ง์ ๋์ผ
- ๋ค๋ฅธ ์ปค๋ ํธ์ถ ๋ฐฉ์: 2D ๊ทธ๋ฆฌ๋ ์ฐจ์
(blocks_x, blocks_y)๊ณผ ๋ธ๋ก ์ฐจ์(BLOCK_X, BLOCK_Y)์ ๋ฌ
๊ณตํต ๋ํผ ๊ธฐ๋ฅ
๋ ์ปค์คํ ์ฐ์ฐ์ ๋ค์๊ณผ ๊ฐ์ ํ์ ์ธํ๋ผ๋ฅผ ์ ๊ณตํฉ๋๋ค:
-
๋ฉ๋ชจ๋ฆฌ ๊ด๋ฆฌ:
enqueue_memset์ผ๋ก ์ถ๋ ฅ ํ ์ 0 ์ด๊ธฐํ- ์ ์ ํ ๋ฒํผ ์์ฑ๊ณผ ๋ฉ๋ชจ๋ฆฌ ๋ ์ด์์ ์ฒ๋ฆฌ
- ์๋ ์ ๋ฆฌ ๋ฐ ๋ฆฌ์์ค ๊ด๋ฆฌ
-
๋๋ฐ์ด์ค ์ถ์ํ:
- ์ต์ ํ๋ ์ปค๋๋ก GPU ์คํ
- ํธํ์ฑ๊ณผ ๋๋ฒ๊น ์ ์ํ CPU ํด๋ฐฑ
- ์คํ ๋์์ ๊ด๊ณ์์ด ์ผ๊ด๋ ์ธํฐํ์ด์ค
-
ํ๋ผ๋ฏธํฐ ์ ๋ฌ:
- ์ปค๋ ์ต์ ํ๋ฅผ ์ํ ์ปดํ์ผ ํ์ ํ ์ ์ฐจ์
- ๋ ์ด์์ ํ ์ ๋ณํ์ ํตํ ๋ฐํ์ ํ ์ ๋ฐ์ดํฐ
- ํ์ ์์ ํ ํ๋ผ๋ฏธํฐ ๊ฒ์ฆ
-
๊ทธ๋ฆฌ๋ ๊ตฌ์ฑ:
- ์ต์ ์ ๊ทธ๋ฆฌ๋ ์ฐจ์ ์๋ ๊ณ์ฐ
- ๊ฐ ์ปค๋์ ์ ๊ทผ ํจํด์ ์ต์ ํ๋ ์๋ก ๋ค๋ฅธ ์ ๋ต
- ์ ์ ํ ๋ธ๋ก ์ฐจ์ ๊ด๋ฆฌ
PyTorch ํตํฉ
๋ฑ๋ก๋ ์ฐ์ฐ์ CustomOpLibrary๋ฅผ ํตํด ํ์ด์ฌ์์ ํธ์ถํ ์ ์์ต๋๋ค:
# Load the custom operations
ops = CustomOpLibrary(mojo_kernels)
# Call the 1D coalesced version
result_1d = ops.embedding[{"batch_size": B, "seq_len": L, "vocab_size": V, "embed_dim": E}](
indices, weights
)
# Call the 2D non-coalesced version
result_2d = ops.embedding_2d[{"batch_size": B, "seq_len": L, "vocab_size": V, "embed_dim": E}](
indices, weights
)
์ด ์ ๊ทผ๋ฒ์ ์ฅ์ ์ ๋์ผํ ์ปค๋ ๊ตฌํ์ ๋ค์ํ ํ์ด์ฌ ํ๋ ์์ํฌ์์ ์ฌ์ฉํ๋ฉด์๋ ์ต์ ์ ์ฑ๋ฅ ํน์ฑ์ ์ ์งํ ์ ์๋ค๋ ๊ฒ์ ๋๋ค.
์ฝ๋ ์คํ
๋ค์ ๋ช ๋ น์ผ๋ก ํผ์ฆ์ ์คํํ ์ ์์ต๋๋ค:
pixi run p21
pixi run -e amd p21
uv run poe p21
์ฑ๊ณตํ๋ฉด ๋ค์๊ณผ ๋น์ทํ ์ถ๋ ฅ์ ๋ณผ ์ ์์ต๋๋ค:
Puzzle 21: Mojo Embedding Kernel Comparison
======================================================================
Configuration: B=8, L=512, V=10000, E=512
------------------------------------------------------------
Testing Correctness...
1D Coalesced - Max difference: 1.19e-07
2D Non-coalesced - Max difference: 1.19e-07
โ
Both implementations CORRECT
Benchmarking Mojo Kernels...
Performance Results:
1D Coalesced: 2.145 ms
2D Non-coalesced: 3.867 ms
1D is 1.80x faster than 2D
Key Learning Points:
โข Compare different GPU kernel implementations
โข 1D vs 2D grid patterns have different memory access
โข Coalesced memory access should be faster
โข Grid configuration affects GPU utilization
์๋ฃจ์
๋ ์ปค๋์ ์ขํ ๋ณํ๊ณผ ๋ฉ๋ชจ๋ฆฌ ์ฐ์ฐ์ ๊ตฌํํ๋ฉด ๋ฉ๋๋ค:
1D ๋ณํฉ ์ปค๋
fn embedding_kernel_coalesced[
    indices_layout: Layout,
    weights_layout: Layout,
    out_layout: Layout,
    batch_size: Int,
    seq_len: Int,
    vocab_size: Int,
    embed_dim: Int,
    dtype: DType = DType.float32,
](
    output: LayoutTensor[dtype, out_layout, MutAnyOrigin],
    indices: LayoutTensor[DType.int32, indices_layout, MutAnyOrigin],
    weights: LayoutTensor[dtype, weights_layout, MutAnyOrigin],
):
    """Embedding lookup with one thread per output element on a flat 1D grid.

    Consecutive threads write consecutive embedding-dimension slots, so
    global-memory accesses coalesce within a warp. Out-of-range token ids
    produce zeros instead of reading the weight table.
    """
    # Linearize the launch: one thread = one element of the flat output.
    flat_idx = Int(block_idx.x * block_dim.x + thread_idx.x)
    element_count = batch_size * seq_len * embed_dim
    if flat_idx >= element_count:
        return

    # Decompose the flat index into (batch, seq, embed) coordinates.
    per_batch = seq_len * embed_dim
    batch_idx = flat_idx // per_batch
    within_batch = flat_idx % per_batch
    seq_idx = within_batch // embed_dim
    embed_idx = within_batch % embed_dim

    # Token id for this (batch, seq) position.
    token = Int(indices[batch_idx, seq_idx])

    # Invalid tokens zero-fill; valid tokens copy one weight element.
    if token < 0 or token >= vocab_size:
        output[batch_idx, seq_idx, embed_idx] = 0
    else:
        output[batch_idx, seq_idx, embed_idx] = weights[token, embed_idx]
2D ๋น๋ณํฉ ์ปค๋
fn embedding_kernel_2d[
    indices_layout: Layout,
    weights_layout: Layout,
    out_layout: Layout,
    batch_size: Int,
    seq_len: Int,
    vocab_size: Int,
    embed_dim: Int,
    dtype: DType = DType.float32,
](
    output: LayoutTensor[dtype, out_layout, MutAnyOrigin],
    indices: LayoutTensor[DType.int32, indices_layout, MutAnyOrigin],
    weights: LayoutTensor[dtype, weights_layout, MutAnyOrigin],
):
    """2D-grid embedding lookup, kept for comparison with the 1D kernel.

    The X axis covers flattened (batch, seq) positions and the Y axis covers
    the embedding dimension; depending on block shape, accesses may not
    coalesce. Out-of-range token ids produce zeros.
    """
    # Per-axis global coordinates of this thread.
    position = Int(block_idx.x * block_dim.x + thread_idx.x)
    dim = Int(block_idx.y * block_dim.y + thread_idx.y)

    # Guard both grid axes before touching memory.
    if position >= batch_size * seq_len or dim >= embed_dim:
        return

    # Split the flattened position into batch and sequence indices.
    b = position // seq_len
    s = position % seq_len

    # Token id for this position.
    token = Int(indices[b, s])

    # Invalid tokens zero-fill; valid tokens copy one weight element.
    if token < 0 or token >= vocab_size:
        output[b, s, dim] = 0
    else:
        output[b, s, dim] = weights[token, dim]
๋ ํ์ด ๋ชจ๋ ๋์ผํ ์๋ฒ ๋ฉ ์กฐํ ๋ก์ง์ ๊ตฌํํ์ง๋ง ์ค๋ ๋ ๊ตฌ์ฑ์ด ๋ค๋ฆ ๋๋ค:
์ฃผ์ ์ฐจ์ด์
-
์ค๋ ๋ ๋งคํ:
- 1D ์ปค๋: ์ถ๋ ฅ ์์๋น ํ๋์ ์ค๋ ๋, ๋จ์ํ 1์ฐจ์ ์ธ๋ฑ์ฑ
- 2D ์ปค๋: (batchรseq, embed_dim) ์ขํ์ ๋ํ 2D ๊ทธ๋ฆฌ๋ ๋งคํ
-
๋ฉ๋ชจ๋ฆฌ ์ ๊ทผ ํจํด:
- 1D ์ปค๋: ์ฐ์๋ ์ค๋ ๋๊ฐ ์ฐ์๋ ์๋ฒ ๋ฉ ์ฐจ์์ ์ ๊ทผ โ ๋ณํฉ๋จ
- 2D ์ปค๋: ์ค๋ ๋ ์ ๊ทผ ํจํด์ด ๋ธ๋ก ๊ตฌ์ฑ์ ๋ฐ๋ผ ๋ฌ๋ผ์ง โ ๋ณํฉ๋์ง ์์ ์ ์์
-
์ธ๋ฑ์ฑ ๋ณต์ก๋:
- 1D ์ปค๋: ๋จ์ผ ๋๋์ /๋๋จธ์ง ์ฒด์ธ์ผ๋ก 3D ์ขํ ๊ณ์ฐ
- 2D ์ปค๋: X/Y ์ขํ๋ฅผ ๋ณ๋๋ก ๊ณ์ฐ
์ฑ๋ฅ์ ๋ฏธ์น๋ ์ํฅ
1D ์ปค๋์ด ์ผ๋ฐ์ ์ผ๋ก ๋ ๋์ ์ฑ๋ฅ์ ๋ณด์ด๋ ์ด์ :
- ๋ฉ๋ชจ๋ฆฌ ๋ณํฉ: ์ฐ์๋ ์ค๋ ๋๊ฐ ์ฐ์๋ ๋ฉ๋ชจ๋ฆฌ ์ฃผ์์ ์ ๊ทผ
- ๋จ์ํ ์ธ๋ฑ์ฑ: ์ขํ ๊ณ์ฐ์ ์ฐ์ฐ ์ค๋ฒํค๋๊ฐ ๋ฎ์
- ๋ ๋์ ์บ์ ํ์ฉ: ์์ธก ๊ฐ๋ฅํ ๋ฉ๋ชจ๋ฆฌ ์ ๊ทผ ํจํด
2D ์ปค๋์ ์ฑ๋ฅ์ด ๋จ์ด์ง ์ ์๋ ์ด์ :
- ํฉ์ด์ง ๋ฉ๋ชจ๋ฆฌ ์ ๊ทผ: ์ํ ๋ด ์ค๋ ๋๋ค์ด ์๋ก ๋ค๋ฅธ ์๋ฒ ๋ฉ ๋ฒกํฐ์ ์ ๊ทผํ ์ ์์
- ๋ณต์กํ ๊ทธ๋ฆฌ๋ ๊ตฌ์ฑ: 16ร16 ๋ธ๋ก์ด ๋ฉ๋ชจ๋ฆฌ ๋ ์ด์์๊ณผ ์ต์ ์ผ๋ก ๋ง์ง ์์ ์ ์์
- ์ํ ๋ถ๊ธฐ: ์๋ก ๋ค๋ฅธ ์ค๋ ๋๊ฐ ์๋ก ๋ค๋ฅธ ์คํ ๊ฒฝ๋ก๋ฅผ ๋ฐ๋ฅผ ์ ์์
ํต์ฌ ๊ฐ๋
| ๊ฐ๋ | 1D ๋ณํฉ | 2D ๋น๋ณํฉ |
|---|---|---|
| ์ค๋ ๋ ๊ตฌ์ฑ | 1D 1์ฐจ์ ์ธ๋ฑ์ฑ | 2D ๊ทธ๋ฆฌ๋ (batchรseq, embed) |
| ๋ฉ๋ชจ๋ฆฌ ์ ๊ทผ | ์ฐ์๋ ์ฃผ์ | ํฉ์ด์ง ์ ์์ |
| ๊ทธ๋ฆฌ๋ ๊ตฌ์ฑ | ๋จ์: [total_elements // 256] | ๋ณต์ก: [batchรseq // 16, embed // 16] |
| ์ฑ๋ฅ | ๋ฉ๋ชจ๋ฆฌ ๋์ญํญ์ ์ต์ ํ | ์ต์ ํ๋์ง ์์ ๋ฉ๋ชจ๋ฆฌ ํจํด |
| ์ฌ์ฉ ๋ชฉ์ | ํ๋ก๋์ ์ปค๋ | ๊ต์ก์ฉ ๋น๊ต |
ํต์ฌ ๊ตํ: ๋ฉ๋ชจ๋ฆฌ ๋ณํฉ์ ์๋ฒ ๋ฉ๊ณผ ๊ฐ์ ๋ฉ๋ชจ๋ฆฌ ๋ฐ์ด๋ ์ฐ์ฐ์์ 2~3๋ฐฐ์ ์ฑ๋ฅ ์ฐจ์ด๋ฅผ ๊ฐ์ ธ์ฌ ์ ์์ต๋๋ค.