이번 포스팅에서는 선형회귀 알고리즘을 구현한다.
선형회귀 알고리즘의 틀을 아래와 같이 정리했다.
1. 가중치 초기화
for ( 에포크 반복 ) :
for ( 샘플 반복 ) :
2. 예측값 구하기
3. 가중치 업데이트
하나하나씩 살펴보고 구현해보자.
| 훈련 데이터 세팅
본격적으로 알고리즘을 구현하기 전에, 훈련데이터부터 세팅해보자.
사이킷런에서 제공하는 당뇨병 데이터를 로드해서 사용한다.
다만 특성을 2개 이상 사용하면 3차원 이상의 그래프를 그려야하기 때문에 x의 특성 중 하나만 선택했다.
from sklearn.datasets import load_diabetes
diabetes = load_diabetes()
x = diabetes.data[:,2]
y = diabetes.target
| STEP 1
1. 가중치 초기화
for ( 에포크 반복 ) :
for ( 샘플 반복 ) :
2. 예측값 구하기
3. 가중치 업데이트
선형회귀 함수의 초기 가중치와 절편을 지정하는 단계이다.
초기값을 선정하는 방법은 여러가지가 있지만,
여기서는 편의상 각각 1.0으로 설정했다.
w = 1.0
b = 1.0
그럼 초기 가중치와 절편일 때 어떤 그래프가 그려지는지 확인해보자.
import matplotlib.pyplot as plt
# 입력 x와 타깃 y
plt.scatter(x,y)
# 입력 x와 예측값 y_hat
y_hat = w * x + b
plt.plot(x,y_hat)
plt.show()
그래프가 훈련데이터의 경향을 전혀 설명하지 못하고 있다.
이제 선형회귀 알고리즘을 통해 가중치와 절편을 조절해주자.
| STEP 2
1. 가중치 초기화
for ( 에포크 반복 ) :
for ( 샘플 반복 ) :
2. 예측값 구하기
3. 가중치 업데이트
선형회귀함수의 예측값을 구하는 forpass() 함수를 정의했다.
def forpass(x) :
y_hat = x * w + b
return y_hat
| STEP 3
1. 가중치 초기화
for ( 에포크 반복 ) :
for ( 샘플 반복 ) :
2. 예측값 구하기
3. 가중치 업데이트
오차를 토대로 새로운 가중치와 절편을 구하는 함수 backprop()을 정의했다.
def backprop(x, err) :
w_grad = x * err
b_grad = 1* err
return w_grad, b_grad
위와 같은 식이 도출되는 과정에 대해서는 아래 더보기 포스팅에서 다뤘으므로 참고 바란다!
| STEP 4
1. 가중치 초기화
for ( 에포크 반복 ) :
for ( 샘플 반복 ) :
2. 예측값 구하기
3. 가중치 업데이트
모든 샘플과 에포크에 대해서 반복한다.
* 에포크란?!
경사하강법에서는 주어진 훈련데이터로 훈련을 여러번 반복한다.
이렇게 전체 데이터를 모두 이용하여 한 단위의 작업을 진행하는 것을 에포크라고 한다.
일반적으로 수십에서 수천번의 에포크를 반복한다.
여기서는 에포크 횟수를 100으로 설정하고,
fit() 이라는 함수로 정의했다.
def fit(x, y, epochs=100) :
global w
global b
for i in range(epochs) :
for x_i, y_i in zip(x,y) :
y_hat = forpass(x_i)
err = - (y_i - y_hat)
w_grad, b_grad = backprop(x_i, err)
w -= w_grad
b -= b_grad
| 알고리즘 실행 후 결과 그래프 확인
fit(x,y)
# 입력 x와 타깃 y
plt.scatter(x,y)
# 입력 x와 예측값 y_hat
y_hat = w * x + b
plt.plot(x,y_hat)
plt.show()
알고리즘 실행 후 그래프가 훈련데이터의 경향을 더 잘 설명하고 있다는 것을 확인할 수 있다.
전체 소스코드는 아래에 첨부한다.
'Deep Learning > [Books] Do it! 정직하게 코딩하며 배우는 딥러닝 입문' 카테고리의 다른 글
[모델 구축] 로지스틱 손실함수와 오류 역전파 이해하기 (0) | 2020.07.13 |
---|---|
[모델 선정] 이진분류 알고리즘 3가지 (퍼셉트론, 아달린, 로지스틱 회귀) (2) | 2020.07.09 |
[데이터 탐색] 데이터 탐색을 위한 파이썬 명령어 3가지 (1) | 2020.07.02 |
[모델 구축] 경사하강법을 구현하는 방법 - ② 손실함수 미분하기 (7) | 2020.06.29 |
[모델 구축] 경사하강법을 구현하는 방법 - ① 직접 변화율 계산하기 (9) | 2020.06.25 |