Puzzle 20: 1D Convolution Op

From MAX Graph to PyTorch custom ops

You have now reached Part V of the GPU puzzle journey: integrating PyTorch custom ops.

In Puzzle 17: 1D Convolution Op, you learned how to connect Mojo GPU kernels to Python using MAX Graph. This puzzle covers how to:

  • Use the same Mojo kernel through PyTorch's CustomOpLibrary
  • Integrate with PyTorch's tensor system and autograd
  • Compare the MAX Graph and PyTorch approaches to custom operations
  • Understand the key pattern of explicit output tensor allocation

This transition shows how the same optimized GPU kernel behaves under different Python integration approaches.

Overview

In this puzzle, you take the 1D convolution kernel from Puzzle 17: 1D Convolution Op as is and integrate it with PyTorch using CustomOpLibrary instead of MAX Graph.

The key point is that the same Mojo kernel works without modification. Only the Python integration layer differs between the MAX Graph and PyTorch approaches.

Code to complete

To complete this puzzle, you only need to fill in the one line that calls the custom operation:

from pathlib import Path

import torch
from max.torch import CustomOpLibrary


def conv1d_pytorch(
    input_tensor: torch.Tensor, kernel_tensor: torch.Tensor
) -> torch.Tensor:
    """
    1D convolution using our custom PyTorch operation.

    This demonstrates the transition from MAX Graph (Puzzle 17) to PyTorch CustomOpLibrary.
    Uses the EXACT same Mojo kernel, but different Python integration!
    """
    # Load our custom operations
    mojo_kernels = Path(__file__).parent / "op"
    ops = CustomOpLibrary(mojo_kernels)

    # Create output tensor with same shape as input
    output_tensor = torch.empty_like(input_tensor)

    # Call our custom conv1d operation with explicit output tensor
    # The Mojo signature expects: (out, input, kernel)
    conv1d = ops.conv1d[
        {
            "input_size": input_tensor.shape[0],
            "conv_size": kernel_tensor.shape[0],
        }
    ]

    # FILL IN with 1 line of code

    return output_tensor


View full file: problems/p20/p20.py

You can run the puzzle with:

pixi run p20
pixi run -e amd p20
uv run poe p20

On success, you should see output similar to the following:

Puzzle 20: From MAX Graph to PyTorch Custom Ops
============================================================
Input array: [ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11. 12. 13. 14.]
Convolution kernel: [0. 1. 2. 3.]

NumPy reference result: [14. 20. 26. 32. 38. 44. 50. 56. 62. 68. 74. 80. 41. 14.  0.]

Testing PyTorch Custom Op (device: cuda)
----------------------------------------
PyTorch custom op result: [14. 20. 26. 32. 38. 44. 50. 56. 62. 68. 74. 80. 41. 14.  0.]
✅ PyTorch custom op verification PASSED

Comparing with MAX Graph approach (like p15)
--------------------------------------------
MAX Graph result: [14. 20. 26. 32. 38. 44. 50. 56. 62. 68. 74. 80. 41. 14.  0.]
✅ MAX Graph verification PASSED
✅ PyTorch and MAX Graph results MATCH
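
The NumPy reference values in this output can be reproduced by hand. The sketch below is a plain-Python stand-in, not the puzzle's own reference code. It assumes each output element is the dot product of the kernel with the input window that starts at the same index, and that positions past the end of the input count as zero. The name conv1d_reference is made up for illustration.

```python
def conv1d_reference(values, kernel):
    """Illustrative 1D convolution: windows that run past the input are zero-padded."""
    n = len(values)
    out = []
    for i in range(n):
        acc = 0.0
        for j, weight in enumerate(kernel):
            if i + j < n:  # elements beyond the input contribute nothing
                acc += values[i + j] * weight
        out.append(acc)
    return out


input_values = [float(v) for v in range(15)]
kernel_values = [0.0, 1.0, 2.0, 3.0]
reference = conv1d_reference(input_values, kernel_values)
print(reference)
```

For the first twelve positions the full window fits, so the value is 6*i + 14. The last three positions see a truncated window, which is why the sequence ends with 41, 14, 0.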

Solution

Call the compiled custom operation with the appropriate arguments:

    # Call our custom conv1d operation with explicit output tensor
    # The Mojo signature expects: (out, input, kernel)
    conv1d = ops.conv1d[
        {
            "input_size": input_tensor.shape[0],
            "conv_size": kernel_tensor.shape[0],
        }
    ]
    torch.compile(conv1d)(output_tensor, input_tensor, kernel_tensor)

This solution demonstrates several key concepts:

1. torch.compile() integration

The torch.compile integration looks like this:

torch.compile(conv1d)(output_tensor, input_tensor, kernel_tensor)

2. Explicit output tensor allocation

output_tensor = torch.empty_like(input_tensor)
  • MAX Graph allocates the output automatically
  • PyTorch CustomOpLibrary requires a pre-allocated output tensor
  • The Mojo operation signature expects the arguments in the order (out, input, kernel)
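
The output-first calling convention can be pictured with a small plain-Python sketch. The function conv1d_into below is a hypothetical stand-in for the compiled op, not the real Mojo kernel or the CustomOpLibrary API: the caller allocates the buffer, passes it as the first argument, and reads the result from that buffer afterwards because the call itself returns nothing.

```python
def conv1d_into(out, values, kernel):
    """Hypothetical stand-in for the op: fills `out` in place and returns None."""
    if len(out) != len(values):
        raise ValueError("output buffer must have the same length as the input")
    n = len(values)
    for i in range(n):
        out[i] = sum(
            values[i + j] * kernel[j] for j in range(len(kernel)) if i + j < n
        )


values = [0.0, 1.0, 2.0, 3.0, 4.0]
kernel = [1.0, 1.0]

# Plays the role of torch.empty_like(input_tensor): same length, contents overwritten.
out = [0.0] * len(values)

returned = conv1d_into(out, values, kernel)  # argument order: (out, input, kernel)
print(returned, out)
```

This is why conv1d_pytorch returns output_tensor rather than the value of the op call.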

3. Parameter dictionary

ops.conv1d[{"input_size": input_tensor.shape[0], "conv_size": kernel_tensor.shape[0]}]
  • Parameters are passed to the operation as a dictionary
  • These values become compile-time parameters of the Mojo kernel
  • The keys must match the parameter names in the Mojo @staticmethod fn execute signature
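
Because the dictionary keys have to line up with the Mojo parameter names, it can help to build and check the dictionary in one place. The helpers below are a hypothetical sketch and not part of max.torch. They assume the only parameters are the two names used in this puzzle, input_size and conv_size.

```python
# Parameter names taken from this puzzle's dictionary; treated as the full set here.
EXPECTED_PARAMS = ("input_size", "conv_size")


def build_conv1d_params(input_shape, kernel_shape):
    """Derive the parameter dictionary from the shapes of 1D tensors."""
    return {
        "input_size": int(input_shape[0]),
        "conv_size": int(kernel_shape[0]),
    }


def check_param_names(params, expected=EXPECTED_PARAMS):
    """Raise KeyError if a name is missing or misspelled, otherwise return params."""
    missing = [name for name in expected if name not in params]
    unexpected = [name for name in params if name not in expected]
    if missing or unexpected:
        raise KeyError(f"missing={missing}, unexpected={unexpected}")
    return params


params = check_param_names(build_conv1d_params((15,), (4,)))
print(params)
```

In the puzzle code the resulting dictionary would be used as ops.conv1d[params], with input_tensor.shape and kernel_tensor.shape as the inputs to the builder.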

4. Same kernel, different integration

The underlying Mojo kernel (conv1d_kernel) is identical to the one in Puzzle 17:

  • Same GPU kernel code
  • Same memory access patterns
  • Same computational logic
  • Only the Python wrapper layer changes

Key concepts

This puzzle illustrates the main patterns for PyTorch custom operations:

Concept              MAX Graph (Puzzle 17)      PyTorch CustomOpLibrary (Puzzle 20)
Output allocation    Automatic                  Manual (torch.empty_like())
Operation call       ops.custom(...)            torch.compile(op)(...)
Parameter passing    parameters={...}           op[{...}]
Device management    Explicit device context    Device of the PyTorch tensor
Memory management    MAX Graph tensors          PyTorch tensors

Key pattern: explicit output tensor allocation

The most important difference is that PyTorch CustomOpLibrary requires you to allocate the output tensor explicitly:

# ❌ Does not work: no output tensor
result = torch.compile(conv1d)(input_tensor, kernel_tensor)

# ✅ Works: pre-allocated output tensor
output_tensor = torch.empty_like(input_tensor)
torch.compile(conv1d)(output_tensor, input_tensor, kernel_tensor)

This pattern ensures that:

  • Memory is allocated on the correct device
  • The output tensor has the right shape and dtype
  • The Mojo kernel can write directly into the output buffer

torch.compile() integration

torch.compile() is essential because it:

  • Handles memory layout conversion between PyTorch and Mojo
  • Manages device synchronization (CPU ↔ GPU)
  • Optimizes tensor format conversion
  • Provides proper error handling for memory operations

Note: Without torch.compile(), you may hit std::bad_alloc errors, because the raw operation cannot handle PyTorch's tensor memory management.

Debugging custom operations

Common issues and how to resolve them:

  1. Memory allocation errors: always use torch.compile()
  2. Wrong output shape: make sure the output tensor matches the expected dimensions
  3. Device mismatch: all tensors must be on the same device
  4. Parameter errors: check that the parameter names match the Mojo operation signature
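
Items 2 and 3 can be checked before the op is ever called. The function below is a hypothetical helper, not part of PyTorch or MAX. It assumes, as in this puzzle, that the output has the same shape as the input, and it only relies on the .shape and .device attributes that torch.Tensor also provides. Simple stand-in objects are used here so the sketch runs without a GPU.

```python
from types import SimpleNamespace


def find_conv1d_arg_problems(out, inp, kernel):
    """Return a list of human-readable problems; an empty list means the checks passed."""
    problems = []
    if tuple(out.shape) != tuple(inp.shape):
        problems.append(
            f"output shape {tuple(out.shape)} != input shape {tuple(inp.shape)}"
        )
    devices = {str(t.device) for t in (out, inp, kernel)}
    if len(devices) > 1:
        problems.append(f"tensors are on different devices: {sorted(devices)}")
    return problems


# Stand-ins for tensors: the kernel was accidentally left on the CPU.
out = SimpleNamespace(shape=(15,), device="cuda:0")
inp = SimpleNamespace(shape=(15,), device="cuda:0")
kernel = SimpleNamespace(shape=(4,), device="cpu")

problems = find_conv1d_arg_problems(out, inp, kernel)
print(problems)
```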

Debugging approach: compare your PyTorch result with the MAX Graph reference implementation, which runs the same kernel.
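
One way to make that comparison concrete is to look for the first index where the two results disagree. The sketch below works on plain lists of floats, so tensor results would first need to be copied to the host and converted to lists. The function name and the tolerances are illustrative choices, not values taken from the puzzle.

```python
import math


def first_mismatch(actual, expected, rel_tol=1e-5, abs_tol=1e-6):
    """Return the index of the first element that differs, or None if all match."""
    if len(actual) != len(expected):
        raise ValueError(f"length mismatch: {len(actual)} vs {len(expected)}")
    for i, (a, e) in enumerate(zip(actual, expected)):
        if not math.isclose(a, e, rel_tol=rel_tol, abs_tol=abs_tol):
            return i
    return None


reference = [14.0, 20.0, 26.0]
candidate_ok = [14.0, 20.000001, 26.0]  # differs only by rounding noise
candidate_bad = [14.0, 20.0, 0.0]       # last element was never written

print(first_mismatch(candidate_ok, reference))
print(first_mismatch(candidate_bad, reference))
```

A mismatch near the end of the array often points at boundary handling, while a mismatch at index 0 suggests a wrong argument order or parameter value.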