CNN 이란?
- 데이터로부터 직접 학습하는 딥러닝의 신경망 아키텍쳐
- 영상, 객체, 클래스, 범주인식을 위한 패턴을 찾을 때 유용함
- 오디오, 시계열 및 신호 데이터를 분류하는 데도 매우 효과적임.
- 수십, 수백개의 계층을 가질 수 있음.
- 각 계층은 영상 및 데이터의 서로 다른 특징을 검출함.
- 각 훈련영성에 서로 다른 해상도의 필터가 적용되며, convolution된 각 영상은 다음 계층의 입력으로 사용됨.
- 필터는 밝기, 경계와 같이 매우 간단한 특징으로 시작하여 객체를 고유하게 정의하는 특징으로 복잡도를 늘려나감.
CNN의 작동방식
가중치 및 편향
- CNN에는 주어진 계층의 모든 은닉 뉴런에 대해 동일하게 공유된 가중치(weight)와 편향값(bias)이 있음
- 가중치 (Weight)
- 가중치는 입력 데이터의 특성에 곱해져서 다음 레이어로 전달되는 값
- 각각의 뉴런은 입력값에 대한 가중치와 활성화 함수의 출력을 곱한 값을 합산하여 처리
- 가중치는 모델이 데이터의 패턴을 학습하고 표현하는 데 중요한 역할을
- 편향 (bias)
- 각 뉴런의 출력에 더해지는 상수 값
- 가중치와 함께 활성화 함수로 전달
- 편향은 모델이 데이터를 올바르게 표현하고 예측하기 위해 데이터와 모델 사이의 차이를 조절하는 역할
- 가중치가 데이터의 패턴을 조정하는 반면, 편향은 모델이 얼마나 잘 데이터를 표현하는지에 영향을 미침
- 가중치와 편향은 모델의 학습 과정에서 최적화되는 매개변수임
- 간단히 말해, 가중치는 입력 데이터의 각 특성에 대한 중요도를 조절하고, 편향은 모델의 전체적인 편향을 조절하는 역할
- 가중치 (Weight)
간단한 예제
- mnist 손글씨 분류
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
# 데이터 로드 및 전처리
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
train_images = train_images.reshape((60000, 28, 28, 1)).astype('float32') / 255
test_images = test_images.reshape((10000, 28, 28, 1)).astype('float32') / 255
train_labels = to_categorical(train_labels)
test_labels = to_categorical(test_labels)
# CNN 모델 정의
model = models.Sequential()
model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.Flatten())
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))
# 모델 컴파일
model.compile(optimizer='adam',
loss='categorical_crossentropy',
metrics=['accuracy'])
# 모델 학습
model.fit(train_images, train_labels, epochs=5, batch_size=64, validation_split=0.2)
# 모델 평가
test_loss, test_acc = model.evaluate(test_images, test_labels)
print('Test accuracy:', test_acc)
참고자료
https://kr.mathworks.com/discovery/convolutional-neural-network-matlab.html