Puzzle 21: ์ž„๋ฒ ๋”ฉ Op

๋ฉ”๋ชจ๋ฆฌ ์ ‘๊ทผ ํŒจํ„ด๊ณผ ์„ฑ๋Šฅ

๋ฉ”๋ชจ๋ฆฌ ๋ฐ”์šด๋“œ ์—ฐ์‚ฐ๊ณผ GPU ๋ฉ”๋ชจ๋ฆฌ ์ ‘๊ทผ ์ตœ์ ํ™”์— ์ดˆ์ ์„ ๋งž์ถฐ Part V๋ฅผ ์ด์–ด๊ฐ‘๋‹ˆ๋‹ค.

Puzzle 20: 1D ํ•ฉ์„ฑ๊ณฑ Op์— ์ด์–ด, ๋™์ผํ•œ ์—ฐ์‚ฐ์˜ ์„œ๋กœ ๋‹ค๋ฅธ ์ปค๋„ ๊ตฌํ˜„์ด ์„ฑ๋Šฅ์— ์–ผ๋งˆ๋‚˜ ๊ทน์ ์ธ ์ฐจ์ด๋ฅผ ๊ฐ€์ ธ์˜ฌ ์ˆ˜ ์žˆ๋Š”์ง€ ์•Œ์•„๋ด…๋‹ˆ๋‹ค. ๋ฐฐ์šธ ๋‚ด์šฉ์€ ๋‹ค์Œ๊ณผ ๊ฐ™์Šต๋‹ˆ๋‹ค:

  • GPU ๋ฉ”๋ชจ๋ฆฌ ๋ณ‘ํ•ฉ์ด ์„ฑ๋Šฅ์— ๋ฏธ์น˜๋Š” ์˜ํ–ฅ
  • ๋ฉ”๋ชจ๋ฆฌ ๋ฐ”์šด๋“œ ์—ฐ์‚ฐ์—์„œ ๊ทธ๋ฆฌ๋“œ ๊ตฌ์„ฑ์ด ์ค‘์š”ํ•œ ์ด์œ 
  • ์ตœ์ ์˜ ๋ฉ”๋ชจ๋ฆฌ ์ ‘๊ทผ ํŒจํ„ด์œผ๋กœ ์ปค๋„์„ ์„ค๊ณ„ํ•˜๋Š” ๋ฐฉ๋ฒ•
  • ์„œ๋กœ ๋‹ค๋ฅธ ์Šค๋ ˆ๋”ฉ ์ „๋žต์ด ๊ฐ€์ ธ์˜ค๋Š” ์„ฑ๋Šฅ ์ฐจ์ด

์ด ํผ์ฆ์€ ์–ด๋–ค ์—ฐ์‚ฐ์„ ์ˆ˜ํ–‰ํ•˜๋А๋ƒ๋ณด๋‹ค ๋ฉ”๋ชจ๋ฆฌ์— ์–ด๋–ป๊ฒŒ ์ ‘๊ทผํ•˜๋А๋ƒ๊ฐ€ ๋” ์ค‘์š”ํ•  ์ˆ˜ ์žˆ์Œ์„ ๋ณด์—ฌ์ค๋‹ˆ๋‹ค.

๊ฐœ์š”

์ด ํผ์ฆ์—์„œ๋Š” ์‹ ๊ฒฝ๋ง์˜ ํ•ต์‹ฌ ๊ตฌ์„ฑ ์š”์†Œ์ธ ์ž„๋ฒ ๋”ฉ(embedding) ์—ฐ์‚ฐ์„ ์œ„ํ•œ ๋‘ ๊ฐ€์ง€ GPU ์ปค๋„์„ ๊ตฌํ˜„ํ•ฉ๋‹ˆ๋‹ค. ๋‘ ์ปค๋„ ๋ชจ๋‘ ๋™์ผํ•œ ๊ฒฐ๊ณผ๋ฅผ ์ƒ์„ฑํ•˜์ง€๋งŒ, ์„œ๋กœ ๋‹ค๋ฅธ ๋ฉ”๋ชจ๋ฆฌ ์ ‘๊ทผ ํŒจํ„ด์„ ์‚ฌ์šฉํ•˜์—ฌ ์ƒ๋‹นํ•œ ์„ฑ๋Šฅ ์ฐจ์ด๋ฅผ ๋ณด์ž…๋‹ˆ๋‹ค.

๋น„๊ตํ•  ๋‘ ์ปค๋„:

  • 1D ๋ณ‘ํ•ฉ(coalesced) ์ปค๋„: ์—ฐ์†์ ์ธ ๋ฉ”๋ชจ๋ฆฌ ์ ‘๊ทผ์œผ๋กœ ๋ฉ”๋ชจ๋ฆฌ ๋Œ€์—ญํญ์— ์ตœ์ ํ™”
  • 2D ๋น„๋ณ‘ํ•ฉ(non-coalesced) ์ปค๋„: ๋น„๊ต๋ฅผ ์œ„ํ•œ ์ตœ์ ํ™”๋˜์ง€ ์•Š์€ ๋ฉ”๋ชจ๋ฆฌ ์ ‘๊ทผ ํŒจํ„ด

์ด ๋น„๊ต๋ฅผ ํ†ตํ•ด GPU ์ปค๋„ ์„ฑ๋Šฅ์—์„œ ๋ฉ”๋ชจ๋ฆฌ ๋ณ‘ํ•ฉ์ด ์–ผ๋งˆ๋‚˜ ์ค‘์š”ํ•œ์ง€ ์ฒด๊ฐํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

๋ฐฐ๊ฒฝ: ์ž„๋ฒ ๋”ฉ ์—ฐ์‚ฐ

์ž„๋ฒ ๋”ฉ ์—ฐ์‚ฐ์€ ์ด์‚ฐ์ ์ธ ํ† ํฐ ์ธ๋ฑ์Šค๋ฅผ ๋ฐ€์ง‘ ๋ฒกํ„ฐ ํ‘œํ˜„์œผ๋กœ ๋ณ€ํ™˜ํ•ฉ๋‹ˆ๋‹ค:

# Input: token indices
indices = [[1, 5, 2], [7, 1, 9]]           # Shape: [batch_size, seq_len]

# Embedding table (learned parameters)
embedding_table = [                        # Shape: [vocab_size, embed_dim]
    [0.1, 0.2, 0.3, 0.4],  # Token 0
    [0.5, 0.6, 0.7, 0.8],  # Token 1
    [0.9, 1.0, 1.1, 1.2],  # Token 2
    # ... more tokens
]

# Output: embedded vectors
output[0,0] = embedding_table[1]  # [0.5, 0.6, 0.7, 0.8]
output[0,1] = embedding_table[5]  # lookup token 5's embedding
output[0,2] = embedding_table[2]  # [0.9, 1.0, 1.1, 1.2]
# ... and so on

์ด ์—ฐ์‚ฐ์€ ๋ฉ”๋ชจ๋ฆฌ ๋ฐ”์šด๋“œ์ž…๋‹ˆ๋‹ค. ์„ฑ๋Šฅ์€ ์ž„๋ฒ ๋”ฉ ํ…Œ์ด๋ธ”์—์„œ ์–ผ๋งˆ๋‚˜ ํšจ์œจ์ ์œผ๋กœ ์ฝ๊ณ  ์ถœ๋ ฅ ํ…์„œ์— ์“ธ ์ˆ˜ ์žˆ๋А๋ƒ์— ๋‹ฌ๋ ค ์žˆ์Šต๋‹ˆ๋‹ค.

ํ•™์Šต ๊ฒฝ๋กœ

์ด ํผ์ฆ์€ ์ฒด๊ณ„์ ์ธ ์ดํ•ด๋ฅผ ์œ„ํ•ด ๋‘ ๋ถ€๋ถ„์œผ๋กœ ๊ตฌ์„ฑ๋˜์–ด ์žˆ์Šต๋‹ˆ๋‹ค:

๋ณ‘ํ•ฉ vs ๋น„๋ณ‘ํ•ฉ ์ปค๋„

์—ฌ๊ธฐ์„œ๋ถ€ํ„ฐ ์‹œ์ž‘ํ•˜์—ฌ ์‹ค์ œ ํผ์ฆ ์ฝ”๋“œ๋ฅผ ๊ตฌํ˜„ํ•˜๊ณ  ์ปค๋„ ๊ตฌํ˜„์„ ์ดํ•ดํ•ฉ๋‹ˆ๋‹ค.

๋ฌด์—‡์„ ํ•˜๊ฒŒ ๋ ๊นŒ์š”:

  • ๋‘ ๊ฐ€์ง€ GPU ์ž„๋ฒ ๋”ฉ ์ปค๋„ ์™„์„ฑ (1D ๋ณ‘ํ•ฉ vs 2D ๋น„๋ณ‘ํ•ฉ)
  • GPU ํ”„๋กœ๊ทธ๋ž˜๋ฐ์˜ ๊ธฐ๋ณธ ๋ฉ”๋ชจ๋ฆฌ ์ ‘๊ทผ ํŒจํ„ด ํ•™์Šต
  • ๋™์ผํ•œ ์•Œ๊ณ ๋ฆฌ์ฆ˜์„ ์„œ๋กœ ๋‹ค๋ฅธ ์Šค๋ ˆ๋”ฉ ์ „๋žต์œผ๋กœ ๊ตฌํ˜„ํ•˜๋Š” ์‚ฌ๋ก€ ํ™•์ธ
  • Mojo์—์„œ์˜ ์ปค์Šคํ…€ ์—ฐ์‚ฐ ๋“ฑ๋ก ์ดํ•ด

์„ฑ๋Šฅ ๋น„๊ต

์ปค๋„ ์„ฑ๋Šฅ์ด ์™œ ๋‹ค๋ฅธ์ง€, ๋ฉ”๋ชจ๋ฆฌ ๋ณ‘ํ•ฉ์˜ ์ด๋ก ์„ ๊นŠ์ด ํŒŒ๊ณ ๋“ญ๋‹ˆ๋‹ค.

๋ฌด์—‡์„ ๋ฐฐ์šธ๊นŒ์š”:

  • GPU ์„ฑ๋Šฅ์—์„œ ๋ฉ”๋ชจ๋ฆฌ ๋ณ‘ํ•ฉ์ด ์ค‘์š”ํ•œ ์ด์œ 
  • ์Šค๋ ˆ๋“œ ๊ตฌ์„ฑ์ด ๋ฉ”๋ชจ๋ฆฌ ๋Œ€์—ญํญ ํ™œ์šฉ์— ๋ฏธ์น˜๋Š” ์˜ํ–ฅ
  • ์‹ ๊ฒฝ๋ง ์ตœ์ ํ™”์— ๋Œ€ํ•œ ์‹ค์ œ ์‹œ์‚ฌ์ 
  • ๋ฉ”๋ชจ๋ฆฌ ๋ฐ”์šด๋“œ ์—ฐ์‚ฐ์„ ์œ„ํ•œ ์ตœ์ ํ™” ์ „๋žต

์‹œ์ž‘ํ•˜๊ธฐ

GPU ๋ฉ”๋ชจ๋ฆฌ ์ตœ์ ํ™”๋ฅผ ํƒ๊ตฌํ•  ์ค€๋น„๊ฐ€ ๋˜์…จ๋‚˜์š”? ๋ณ‘ํ•ฉ vs ๋น„๋ณ‘ํ•ฉ ์ปค๋„ ์—์„œ ์ฝ”๋“œ๋ฅผ ๊ตฌํ˜„ํ•œ ํ›„, ์„ฑ๋Šฅ ๋น„๊ต ๋กœ ๋„˜์–ด๊ฐ€ ์„ฑ๋Šฅ ์ฐจ์ด์˜ ์›์ธ์„ ์ดํ•ดํ•ด ๋ณด์„ธ์š”.

๐Ÿ’ก ์„ฑ๊ณต ํŒ: ์„œ๋กœ ๋‹ค๋ฅธ ๊ทธ๋ฆฌ๋“œ ๊ตฌ์„ฑ(1D vs 2D)์ด ๋ฉ”๋ชจ๋ฆฌ ์ ‘๊ทผ ํŒจํ„ด์— ์–ด๋–ค ์˜ํ–ฅ์„ ๋ฏธ์น˜๋Š”์ง€ ์ฃผ์˜ ๊นŠ๊ฒŒ ์‚ดํŽด๋ณด์„ธ์š”. ์ด ํ†ต์ฐฐ์€ ์ž„๋ฒ ๋”ฉ์„ ๋„˜์–ด ๋‹ค์–‘ํ•œ GPU ํ”„๋กœ๊ทธ๋ž˜๋ฐ ์‹œ๋‚˜๋ฆฌ์˜ค์— ์ ์šฉ๋ฉ๋‹ˆ๋‹ค.