
AI_Basic) Backpropagation: From Training to Where Backpropagation Is Used (Code-Focused)

proggg 2024. 12. 2. 14:42

With simple linear expressions, backpropagation is easy to grasp intuitively. But if you try to carry that understanding straight over to a CNN model, you have probably had the experience of your head going blank. So I wrote this post to organize how backpropagation actually works in CNN-based training.

Understanding the Training Process and Backpropagation in a CNN-Based Machine Learning System

Training a machine learning system consists of six steps: data preparation, model definition, the forward pass, loss computation, backpropagation, and parameter updates. Here we go through each step using a CNN (Convolutional Neural Network) model with the cross-entropy loss function.

1. Data Preparation

CNN models mainly work with image data, so we first need an image dataset. PyTorch's torchvision library makes it easy to load and preprocess image data.

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

def prepare_data():
    """
    CIFAR-10 ๋ฐ์ดํ„ฐ์…‹์„ ๋กœ๋“œํ•˜๊ณ  ์ „์ฒ˜๋ฆฌํ•˜๋Š” ํ•จ์ˆ˜

    reference: CIFAR-10 ๋ฐ์ดํ„ฐ์…‹์„ ๋‹ค์šด๋กœ๋“œํ•˜๊ณ , ๋ฐ์ดํ„ฐ ์ฆ๊ฐ• ๋ฐ ์ •๊ทœํ™”๋ฅผ ์ ์šฉํ•œ๋‹ค.
    return: train_loader, test_loader - ํ•™์Šต ๋ฐ ํ…Œ์ŠคํŠธ ๋ฐ์ดํ„ฐ ๋กœ๋”
    """
    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

    return train_loader, test_loader

train_loader, test_loader = prepare_data()

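This code uses the CIFAR-10 dataset. Data augmentation (RandomHorizontalFlip, RandomCrop) shows the model slightly different versions of each training image, which helps it generalize. Normalize rescales each channel from the [0, 1] range produced by ToTensor() to [-1, 1]. Augmentation is meant for training only; the test set should be evaluated on unmodified images.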
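To see what `Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))` does, the sketch below reproduces its per-channel formula `(x - mean) / std` by hand on dummy tensors. It assumes pixel values are already in [0, 1], as ToTensor() guarantees. The helper `normalize` is hypothetical and only mirrors the arithmetic, not torchvision's implementation.

```python
import torch

# Per-channel mean and std from the Normalize call above, shaped to broadcast over (C, H, W)
mean = torch.tensor([0.5, 0.5, 0.5]).view(3, 1, 1)
std = torch.tensor([0.5, 0.5, 0.5]).view(3, 1, 1)

def normalize(img: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: the (x - mean) / std arithmetic that Normalize applies."""
    return (img - mean) / std

black = torch.zeros(3, 32, 32)        # darkest image ToTensor() can produce
white = torch.ones(3, 32, 32)         # brightest image
gray = torch.full((3, 32, 32), 0.5)   # mid-gray

print(normalize(black).min().item())       # -1.0
print(normalize(white).max().item())       # 1.0
print(normalize(gray).abs().max().item())  # 0.0
```

A mean and std of 0.5 are a simple convention that centers the data around zero. Per-channel statistics computed from the training set are a common alternative.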

2. Defining the CNN Model

Next, we define the CNN model. In PyTorch, a model is implemented by subclassing nn.Module.

import torch.nn as nn
import torch.nn.functional as F

class CNN(nn.Module):
    """
    CNN ๋ชจ๋ธ ํด๋ž˜์Šค

    reference: 3๊ฐœ์˜ ์ปจ๋ณผ๋ฃจ์…˜ ์ธต๊ณผ 2๊ฐœ์˜ ์™„์ „์—ฐ๊ฒฐ์ธต์œผ๋กœ ๊ตฌ์„ฑ๋œ CNN ๋ชจ๋ธ์„ ์ •์˜ํ•œ๋‹ค.
    """
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 4 * 4, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = x.view(-1, 64 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

model = CNN().to(device)

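This CNN consists of three convolutional layers and two fully connected layers. Each convolutional layer is followed by a ReLU activation and 2×2 max pooling, so the 32×32 input shrinks to 16×16, then 8×8, and finally 4×4. That is why the first fully connected layer expects 64 × 4 × 4 = 1024 inputs.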
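The spatial arithmetic can be checked directly. The sketch below rebuilds the three convolutional layers and the pooling layer with the same sizes as the CNN class, so that it runs on its own. It then pushes one dummy 32×32 image through them and records the shape after each stage.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical re-creation of the three conv layers and the pooling layer from the CNN class
convs = [
    nn.Conv2d(3, 32, 3, padding=1),
    nn.Conv2d(32, 64, 3, padding=1),
    nn.Conv2d(64, 64, 3, padding=1),
]
pool = nn.MaxPool2d(2, 2)

x = torch.randn(1, 3, 32, 32)   # one dummy image with CIFAR-10's size
shapes = [tuple(x.shape)]
for conv in convs:
    # a 3x3 kernel with padding=1 keeps H and W; each 2x2 max pool halves them
    x = pool(F.relu(conv(x)))
    shapes.append(tuple(x.shape))

print(shapes)                 # [(1, 3, 32, 32), (1, 32, 16, 16), (1, 64, 8, 8), (1, 64, 4, 4)]
print(x.flatten(1).shape[1])  # 1024, i.e. 64 * 4 * 4, the in_features of fc1
```

If the input size or the number of pooling layers changes, the `64 * 4 * 4` in `fc1` and in `x.view` must change with it. A trace like this is a quick way to find the new value.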

3. Forward Pass

The forward pass feeds the input images through the model to obtain predictions. In a CNN, the images first pass through convolutional and pooling layers that extract features. They then pass through fully connected layers that output a score for each class.

def forward_pass(model, inputs):
    """
    ์ˆœ์ „ํŒŒ ์ˆ˜ํ–‰ ํ•จ์ˆ˜

    reference: CNN ๋ชจ๋ธ์— ์ž…๋ ฅ ์ด๋ฏธ์ง€๋ฅผ ํ†ต๊ณผ์‹œ์ผœ ์˜ˆ์ธก๊ฐ’์„ ์–ป๋Š”๋‹ค.
    argument: model - CNN ๋ชจ๋ธ, inputs - ์ž…๋ ฅ ์ด๋ฏธ์ง€ ๋ฐฐ์น˜
    return: outputs - ๋ชจ๋ธ์˜ ์˜ˆ์ธก๊ฐ’
    """
    outputs = model(inputs)
    return outputs

# ์ˆœ์ „ํŒŒ ์ˆ˜ํ–‰ ์˜ˆ์‹œ
inputs, _ = next(iter(train_loader))
inputs = inputs.to(device)
outputs = forward_pass(model, inputs)

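Here, forward_pass sends the input images through the CNN and returns its predictions: for a batch of 64 images, a (64, 10) tensor of raw class scores (logits). Along the way, PyTorch automatically builds a computation graph that records every operation performed. Backpropagation later walks this graph in reverse.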
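The automatically built computation graph is visible through the `grad_fn` attribute of the output tensor. The sketch below uses a hypothetical one-layer stand-in instead of the CNN; any `nn.Module` behaves the same here. It shows that a normal forward pass records a graph and that a forward pass inside `torch.no_grad()` does not.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the CNN: any nn.Module behaves the same way here
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
inputs = torch.randn(64, 3, 32, 32)   # dummy batch shaped like one CIFAR-10 batch

outputs = model(inputs)
print(outputs.shape)                  # torch.Size([64, 10]): 10 logits per image
print(outputs.grad_fn is not None)    # True: autograd recorded the operations

with torch.no_grad():                 # e.g. for evaluation, where no backward pass is needed
    eval_outputs = model(inputs)
print(eval_outputs.grad_fn is None)   # True: no graph was recorded
```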

4. Computing the Cross-Entropy Loss

Classification problems typically use the cross-entropy loss. It measures the difference between the probability distribution predicted by the model and the distribution of the true labels.
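Written out for a single image, using the same $C$ as in the backpropagation equations below, the loss is as follows. Here $z$ are the logits (the 10 outputs of fc2), $p$ the softmax probabilities, and $y$ the one-hot label whose true class is $t$. By default, nn.CrossEntropyLoss averages this value over the batch.

$$p_k = \frac{e^{z_k}}{\sum_{j} e^{z_j}}, \qquad C = -\sum_{k} y_k \log p_k = -\log p_t$$

Only the probability of the correct class enters the loss. For example, $p_t = 0.9$ gives $C \approx 0.105$, while $p_t = 0.1$ gives $C \approx 2.303$, so confident wrong predictions are penalized heavily.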

def compute_loss(outputs, targets, criterion):
    """
    Cross Entropy ์†์‹ค ๊ณ„์‚ฐ ํ•จ์ˆ˜

    reference: ๋ชจ๋ธ์˜ ์˜ˆ์ธก๊ฐ’๊ณผ ์‹ค์ œ ๋ ˆ์ด๋ธ” ์‚ฌ์ด์˜ Cross Entropy ์†์‹ค์„ ๊ณ„์‚ฐํ•œ๋‹ค.
    argument: outputs - ๋ชจ๋ธ์˜ ์˜ˆ์ธก๊ฐ’, targets - ์‹ค์ œ ๋ ˆ์ด๋ธ”, criterion - ์†์‹ค ํ•จ์ˆ˜
    return: loss - ๊ณ„์‚ฐ๋œ ์†์‹ค๊ฐ’
    """
    loss = criterion(outputs, targets)
    return loss

# ์†์‹ค ํ•จ์ˆ˜ ์ •์˜ ๋ฐ ์†์‹ค ๊ณ„์‚ฐ ์˜ˆ์‹œ
criterion = nn.CrossEntropyLoss()
_, targets = next(iter(train_loader))
targets = targets.to(device)
loss = compute_loss(outputs, targets, criterion)

Here, nn.CrossEntropyLoss() defines the loss function. It combines log-softmax with the negative log-likelihood loss, which makes it suitable for multi-class classification. Because the softmax is applied inside the loss, the model should output raw logits (as fc2 does above) rather than probabilities.
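The claim that nn.CrossEntropyLoss is log-softmax followed by negative log-likelihood can be verified numerically. The sketch below compares it with `nn.NLLLoss` applied to `F.log_softmax` and with a hand-written $-\log p_t$ averaged over a batch. The logits and labels are dummy values.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)              # dummy raw scores for 4 images, 10 classes
targets = torch.tensor([3, 0, 9, 1])     # dummy class indices

ce = nn.CrossEntropyLoss()(logits, targets)
nll = nn.NLLLoss()(F.log_softmax(logits, dim=1), targets)
# Manual version of -log p_t, averaged over the batch
manual = -F.log_softmax(logits, dim=1)[torch.arange(4), targets].mean()

print(torch.allclose(ce, nll), torch.allclose(ce, manual))  # True True
```

Adding a softmax layer at the end of the model would apply softmax twice. This still runs but distorts the loss and its gradients, which is why the CNN returns raw logits.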

5. Backpropagation

Backpropagation starts from the computed loss and calculates its gradient with respect to every model parameter. In a CNN, these are the gradients of the weights and biases of the convolutional and fully connected layers. Pooling layers have no parameters of their own, but the gradient still has to pass through them to reach the earlier convolutional layers.

The mathematics of backpropagation for a fully connected network is as follows:
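How the gradient passes through a max-pooling layer can be seen with a hand-picked 4×4 input. In the sketch below, the upstream gradient values (10, 20, 30, 40) are arbitrary stand-ins for what the next layer would send back. Each value lands only on the position that held the maximum in its 2×2 window.

```python
import torch
import torch.nn as nn

pool = nn.MaxPool2d(2, 2)
print(len(list(pool.parameters())))  # 0: nothing to learn

x = torch.tensor([[[[1.0, 5.0, 2.0, 0.0],
                    [3.0, 4.0, 8.0, 1.0],
                    [0.0, 2.0, 6.0, 7.0],
                    [9.0, 1.0, 3.0, 4.0]]]], requires_grad=True)
out = pool(x)                                        # window maxima: [[5, 8], [9, 7]]
upstream = torch.tensor([[[[10.0, 20.0], [30.0, 40.0]]]])  # pretend gradients from the next layer
out.backward(upstream)
# 10, 20, 30, 40 land on the positions of 5, 8, 9, 7; every other entry of x.grad is 0
print(x.grad[0, 0])
```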

  1. Error at the output layer: $$\delta^L = \nabla_a C \odot \sigma'(z^L)$$
    Here $C$ is the cost function, $a^L$ is the activation of the output layer, $\nabla_a C$ is the gradient of $C$ with respect to $a^L$, $z^L$ is the weighted input of the output layer, and $\sigma$ is the activation function. The error of neuron $j$ in layer $l$ is defined as $\delta^l_j = \partial C / \partial z^l_j$.

  2. Propagating the error to the previous layer: $$\delta^l = ((w^{l+1})^T \delta^{l+1}) \odot \sigma'(z^l)$$
    Here $w^{l+1}$ is the weight matrix of layer $l+1$.

  3. Gradients with respect to the weights and biases:
    $$\frac{\partial C}{\partial w^l_{jk}} = a^{l-1}_k \delta^l_j$$
    $$\frac{\partial C}{\partial b^l_j} = \delta^l_j$$
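The three equations can be implemented directly and checked against finite differences. The sketch below uses a hypothetical tiny network (4 inputs, 5 ReLU hidden units, 3 classes) written in NumPy rather than PyTorch so that nothing is hidden by autograd. With softmax followed by cross-entropy, equation 1 reduces to $\delta^L = p - y$. Softmax is not an elementwise activation, so this combined form replaces $\nabla_a C \odot \sigma'(z^L)$. Equations 2 and 3 are used exactly as written.

```python
import numpy as np

rng = np.random.default_rng(0)

def relu(z):
    return np.maximum(z, 0.0)

def relu_prime(z):
    return (z > 0).astype(float)

def softmax(z):
    e = np.exp(z - z.max())   # subtract the max for numerical stability
    return e / e.sum()

# Hypothetical tiny network: x (4) -> hidden layer (5, ReLU) -> logits (3), softmax + cross-entropy
x = rng.normal(size=4)
y = np.array([0.0, 1.0, 0.0])                     # one-hot label
W1, b1 = rng.normal(size=(5, 4)), rng.normal(size=5)
W2, b2 = rng.normal(size=(3, 5)), rng.normal(size=3)

def loss():
    """Cross-entropy C for the current parameter values."""
    p = softmax(W2 @ relu(W1 @ x + b1) + b2)
    return -np.sum(y * np.log(p))

# Forward pass, keeping z and a for the backward pass
z1 = W1 @ x + b1
a1 = relu(z1)
p = softmax(W2 @ a1 + b2)

# Backward pass
delta2 = p - y                              # output error for softmax + cross-entropy
delta1 = (W2.T @ delta2) * relu_prime(z1)   # equation 2
grads = {
    "W2": np.outer(delta2, a1), "b2": delta2,  # equation 3 for the output layer
    "W1": np.outer(delta1, x), "b1": delta1,   # equation 3 for the hidden layer
}

def numeric_grad(param, eps=1e-6):
    """Central finite differences of loss() w.r.t. param (modified in place, then restored)."""
    g = np.zeros_like(param)
    for idx in np.ndindex(param.shape):
        old = param[idx]
        param[idx] = old + eps
        up = loss()
        param[idx] = old - eps
        down = loss()
        param[idx] = old
        g[idx] = (up - down) / (2 * eps)
    return g

params = {"W1": W1, "b1": b1, "W2": W2, "b2": b2}
matches = {name: np.allclose(grads[name], numeric_grad(params[name]), atol=1e-6) for name in params}
print(matches)  # every value should be True
```

Finite differences need two loss evaluations per parameter, which is hopeless for a real CNN. Backpropagation obtains every gradient in a single backward sweep, so this kind of check is only used to test an implementation.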

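In a CNN, the same chain rule applies, but it takes the form of convolutions. Because a filter's weights are shared across every position of the input, the gradient of a filter weight is the sum of its contributions from all positions, which amounts to cross-correlating the layer's input with $\delta$. The error passed back to the previous layer is obtained by zero-padding $\delta$ and cross-correlating it with the filter rotated by 180°.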
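Both rules can be checked against autograd on a single-channel example. In the sketch below, `delta` is a random stand-in for $\partial C / \partial(\text{out})$ arriving from later layers. The manual weight and input gradients are then built from `F.conv2d` exactly as described above. Note that PyTorch's `conv2d` is itself a cross-correlation.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1, 1, 5, 5, requires_grad=True)   # dummy single-channel input
w = torch.randn(1, 1, 3, 3, requires_grad=True)   # one 3x3 filter
out = F.conv2d(x, w)                               # (1, 1, 3, 3)

delta = torch.randn(1, 1, 3, 3)                    # pretend dC/d(out) from later layers
out.backward(delta)                                # autograd fills w.grad and x.grad

with torch.no_grad():
    # Weight gradient: the shared filter was used at every output position, so the
    # contributions are summed, which is a cross-correlation of the input with delta
    dw_manual = F.conv2d(x, delta)
    # Input gradient: pad delta by (kernel size - 1) = 2 and cross-correlate with the 180-degree-rotated filter
    dx_manual = F.conv2d(F.pad(delta, (2, 2, 2, 2)), torch.flip(w, dims=(2, 3)))

print(torch.allclose(w.grad, dw_manual, atol=1e-6), torch.allclose(x.grad, dx_manual, atol=1e-6))  # True True
```

With multiple channels, stride, or padding, the same idea holds, but the bookkeeping grows. In practice, autograd handles it.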

def backward_pass(loss):
    """
    ์—ญ์ „ํŒŒ ์ˆ˜ํ–‰ ํ•จ์ˆ˜

    reference: ๊ณ„์‚ฐ๋œ ์†์‹ค์„ ๊ธฐ๋ฐ˜์œผ๋กœ ๋ชจ๋ธ์˜ ๋ชจ๋“  ํŒŒ๋ผ๋ฏธํ„ฐ์— ๋Œ€ํ•œ ๊ทธ๋ž˜๋””์–ธํŠธ๋ฅผ ๊ณ„์‚ฐํ•œ๋‹ค.
    argument: loss - ๊ณ„์‚ฐ๋œ ์†์‹ค๊ฐ’
    """
    loss.backward()

# ์—ญ์ „ํŒŒ ์ˆ˜ํ–‰ ์˜ˆ์‹œ
backward_pass(loss)

Calling loss.backward() makes PyTorch's autograd engine compute the gradients of all parameters automatically, traversing the recorded computation graph from the loss back toward the inputs. This yields gradients for the convolutional filters, the fully connected weights, and every layer's bias. Each gradient is stored in the .grad attribute of its parameter.
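Where the gradients end up can be inspected directly. The sketch below uses a hypothetical small conv layer and a stand-in loss (mean of squared outputs). `.grad` is empty before `backward()`. Afterwards, every parameter holds a gradient of its own shape, and a second `backward()` without clearing adds to the stored value instead of replacing it.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Conv2d(3, 8, 3, padding=1)      # hypothetical small conv layer
x = torch.randn(2, 3, 8, 8)                # dummy input batch

print(layer.weight.grad)                   # None: no backward pass has run yet
loss = layer(x).pow(2).mean()              # any scalar works as a stand-in loss
loss.backward()
for name, param in layer.named_parameters():
    print(name, tuple(param.shape), tuple(param.grad.shape))  # each gradient has its parameter's shape

first = layer.weight.grad.clone()
layer(x).pow(2).mean().backward()          # second backward WITHOUT clearing .grad
print(torch.allclose(layer.weight.grad, 2 * first))  # True: the new gradient was added to the old one
```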

6. Updating the Parameters

Finally, the computed gradients are used to update the CNN's parameters. This is done with an optimizer.

import torch.optim as optim

def update_parameters(model, optimizer):
    """
    ํŒŒ๋ผ๋ฏธํ„ฐ ์—…๋ฐ์ดํŠธ ํ•จ์ˆ˜

    reference: ๊ณ„์‚ฐ๋œ ๊ทธ๋ž˜๋””์–ธํŠธ๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ CNN ๋ชจ๋ธ์˜ ํŒŒ๋ผ๋ฏธํ„ฐ๋ฅผ ์—…๋ฐ์ดํŠธํ•œ๋‹ค.
    argument: model - CNN ๋ชจ๋ธ, optimizer - ์ตœ์ ํ™” ์•Œ๊ณ ๋ฆฌ์ฆ˜
    """
    optimizer.step()
    optimizer.zero_grad()

# ์˜ตํ‹ฐ๋งˆ์ด์ € ์ •์˜ ๋ฐ ํŒŒ๋ผ๋ฏธํ„ฐ ์—…๋ฐ์ดํŠธ ์˜ˆ์‹œ
optimizer = optim.Adam(model.parameters(), lr=0.001)
update_parameters(model, optimizer)

Here the Adam optimizer updates the model's parameters. optimizer.step() adjusts each parameter using its computed gradient, and optimizer.zero_grad() resets the gradients for the next iteration. The reset is necessary because backward() adds new gradients to those already stored in .grad instead of overwriting them.
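Putting steps 3 to 6 together, one training iteration usually looks like the sketch below. To stay self-contained, it uses a hypothetical tiny CNN and a single fixed random batch instead of CIFAR-10 and train_loader. It also calls `zero_grad()` at the start of each step, a common alternative to placing it after `step()`; either way, the gradients are cleared before the next `backward()`. Evaluation wraps the forward pass in `torch.no_grad()` because no backpropagation is needed there.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins: a tiny CNN and one fixed random batch instead of CIFAR-10
model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2, 2),  # (N, 8, 4, 4) for 8x8 inputs
    nn.Flatten(), nn.Linear(8 * 4 * 4, 10),
)
inputs = torch.randn(32, 3, 8, 8)
targets = torch.randint(0, 10, (32,))

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

def train_step(inputs, targets):
    model.train()
    optimizer.zero_grad()                 # clear gradients left over from the previous step
    outputs = model(inputs)               # 3. forward pass (builds the autograd graph)
    loss = criterion(outputs, targets)    # 4. cross-entropy loss on raw logits
    loss.backward()                       # 5. backpropagation: fills p.grad for every parameter
    optimizer.step()                      # 6. parameter update using those gradients
    return loss.item()

losses = [train_step(inputs, targets) for _ in range(50)]
print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}")

# Evaluation: no backpropagation, so no graph is recorded
model.eval()
with torch.no_grad():
    accuracy = (model(inputs).argmax(dim=1) == targets).float().mean().item()
print(f"accuracy on the same batch: {accuracy:.2f}")
```

Fitting one fixed batch like this only shows that the loop lowers the loss. Real training iterates over train_loader for several epochs and measures accuracy on test_loader.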

Conclusion

This post walked through the training process of a machine learning system that uses a CNN model and the cross-entropy loss, with a focus on backpropagation. A CNN is an architecture specialized for image processing that extracts features effectively through convolutions. The cross-entropy loss suits multi-class classification and measures how far the model's predictions are from the true labels. Backpropagation sits between the loss and the parameter update: it turns the single loss value into a gradient for every filter, weight, and bias, which the optimizer then uses to improve the model. Understanding and properly implementing each step is what makes it possible to build an effective image classification model.
