반응형

linear regression을 학습하는데, Normal Equation 방정식으로 풀기.

+Computing parameters analytically

Gradient descent 말고 다른 방법은 많다. 그 중에 하나가  Normal Equation

Alpha (학습률 learning rate)가 필요없다. 반복(iteratoin)도 필요없다.

단,    구하는 비용이 크다.  N x N 매트릭스가 필요.

n이 크면 느림. 100정도? 10000을 초과하면 gradient descent 방식이 낫다.


example)
m=4 (training data count)
n=4 (attributes/features)

=house price=
house size, # rooms , # floor, age of home , price
 x1,            x2,           x3,         x4,                y
data1..
data2.
data3..
data4..
가장 좌측에 x0 컬럼을 추가하여 모두 1로 설정. (x0=1)  (for bias)

X=matrix [ x0, x1, ... x4] ; 4x5 dim ; m x (n+1)
Y=[y] ; 4x1 dim  ; m x 1

Y = X S
X^-1 Y = S
S = (X^-1) Y
그러나 X의 inverse 벡터(X^-1)를 구하려면... 정사각행렬이어야 한다.
그래서 정사각 행렬로 먼저 변환해줘야 한다.

Y = X S
X^t Y = X^t X S     ; 왼쪽에 X^t(X의 transpose)를 곱함

(X^t X )^-1 X^t Y = (X^t X )^-1 (X^t X )  S    ; 왼쪽에 (X^t X)^-1 를 곱한다.
(X^t X )^-1 X^t Y = S

따라서 아래와 같이된다.


Octave에서는 :   pinv(X' * X)* X' * Y

R에서는 : W = solve( t(X) %*% X ) %*% t(X) %*% Y


example을 만들어 보자


# deep learning test

# linear regression ; test.. y=2x+1 learning 

#


# training , X1=1 (bias)

X=matrix( c(1,1,1 , 1,2,3), nrow=3 ) 


# training result Y

Y=matrix( c(3,5,7) ) 


# searching parameter, w1=bias 

W=c(2,2)



# H(X,W)=w1+ w2*x2

H = function (X, W) {

  H = X %*% W

  return (H)

}


Cost =function (X, W, Y) {

  m = nrow(X)

  return (sum((H(X,W)-Y)^2) / m)

}


NormalEquation = function (X, W, Y) {

  # no need alpha

  # no iteration

  # S =  inv( t(X) X ) t(X) Y

  W = solve( t(X) %*% X ) %*% t(X) %*% Y

  return (W)

}


print( Cost(X, W, Y) )


#learning

W = NormalEquation(X, W, Y)

print(paste(" Cost=", Cost(X,W,Y), " W1(b)=", W[1,1], " W2=", W[2,1]) )


# predict

qx = c(7,8,9)

xmat = cbind( rep(1, length(qx)), qx)

qy = H( xmat, W )

print (qx)

print (qy)



[1] 1

[1] " Cost= 3.15544362088405e-30  W1(b)= 1  W2= 2"

[1] 7 8 9

     [,1]

[1,]   15

[2,]   17

[3,]   19



'AI(DeepLearning)' 카테고리의 다른 글

[tf] XOR manual solve  (0) 2017.05.23
[R] multinomial classification. 다중분류  (0) 2017.05.19
[R] binary classification  (0) 2017.05.19
[R] linear regression (multi variable) 더하기 학습  (0) 2017.05.11
[R] linear regression  (0) 2017.05.11

+ Recent posts