네이버 지식백과에서 다중 선형 회귀분석은 독립 변수가 두 개 이상인 경우에 사용하는 선형 회귀분석이라고 설명합니다. 다중 선형 회귀분석과 선형 회귀의 다른 부분은 무엇인지 알아보고 Tensorflow를 이용하여 구현까지 해보겠습니다. 선형 회귀분석과 다른 점에는 가설, 독립변수의 개수, 손실 함수 그리고 독립 변수 계산 방법이 있습니다. 다음은 선형 회귀분석과 다중 선형 회귀분석의 가설입니다.
- 선형 회귀분석: H(x) = Wx + b, 독립변수가 1개 일 때
- 다중 선형 회귀분석: H(x1, x2, x3) = w1x1 + w2x2 + w3x3 + b, 독립변수가 3개 일 때
다음은 n개의 독립변수에 대한 다중선형회귀분석의 식을 나타내면 다음과 같습니다.
<그림 1>과 같이 가설 식이 변했으므로 손실 함수도 그에 맞게 변하게 됩니다. 다음은 변화된 손실 함수입니다.
<그림 1>과 <그림 2>를 통해 선형회귀분석과 같이 독립변수 하나하나를 직접 수정해주는 방법은 무리가 있다는 것을 알 수 있습니다. 따라서 독립 변수의 개수가 아무리 많아져도 코드를 수정 없이 동작할 수 있는 방법을 찾아야 합니다. 그 방법은 바로 행렬의 곱셈을 사용하는 것입니다. 행렬의 곱셈 동작 방식은 다음과 같습니다.
<그림 3>가 같이 행렬의 곱셈을 사용한다면 다중선형회귀분석에서 아무리 독립변수의 개수가 많아져도 그에 맞게 대응할 수 있게 됩니다. 이처럼 선형 회귀분석과 다른 점들에 대해서 알아보았으니 코드로 구현해보겠습니다.
data = np.array([
[73., 80., 75., 152.],
[93., 88., 93., 185.],
[89., 91., 90., 180.],
[96., 98., 100., 196.],
[73., 66., 70., 142.]
], dtype=np.float32)
x = data[:,:-1]
y = data[:, [-1]]
x는 독립변수, y는 종속변수입니다. 데이터 슬라이싱을 이용하여 한번에 데이터를 담을 수 있습니다.
W = tf.Variable(tf.random.normal([3,1]))
b = tf.Variable(tf.random.normal([1]))
learning_rate = 0.00001
학습률은 0.00001, W는 랜덤값이 들어간 (3,1)의 모양의 행렬이고, b는 랜덤 값이 들어간 (1,1) 행렬입니다.
def hypothesis(X):
return tf.matmul(X, W) + b
hypothesis는 가설을 정의해주는 함수입니다. matmul함수를 사용하여 X, W를 곱해줍니다.
n_epochs = 2000
for i in range(n_epochs+1):
with tf.GradientTape() as tape:
cost = tf.reduce_mean(tf.square(hypothesis(x) - y))
W_grad, b_grad = tape.gradient(cost, [W, b])
W.assign_sub(learning_rate * W_grad)
b.assign_sub(learning_rate * b_grad)
if i % 100 == 0:
print("{:5} | {:10.4}".format(i, cost.numpy()))
2000번의 학습을 진행하고 100번째 학습마다 손실의 크기를 출력하도록 설정하였습니다.
0 | 6.92e+04
100 | 0.3907
200 | 0.3849
300 | 0.3793
...
1700 | 0.3235
1800 | 0.3207
1900 | 0.318
2000 | 0.3154
W.numpy()
array([[1.0143043],
[0.8079646],
[0.1996759]], dtype=float32)
b.numpy()
array([-0.46934485], dtype=float32)
손실이 점점 작아지는 방향으로 학습이 진행되는 것을 알 수 있습니다. 그리고 학습이 종료된 후의 W, b값도 출력해주었습니다.
참고자료
텐서플로우로 시작하는 딥러닝 기초 강의 - edwith
'Artificial Intelligence > Tensorflow' 카테고리의 다른 글
[Tensorflow] 전이학습을 이용한 음식 이미지 분류 모델(val_acc: 78.33) (0) | 2021.06.08 |
---|---|
[Tensorflow] shutil을 사용한 데이터 분류 (0) | 2021.05.19 |
[Tensorflow] 경사 하강법이 뭐야? (0) | 2021.05.07 |
[Tensorflow] 선형회귀분석이 뭐야? (0) | 2021.04.28 |
[Tensorflow] 김치 이미지 분류기 - 모델학습 (0) | 2021.04.23 |