XOR 학습
단층 퍼셉트론으로는 비선형이 학습이 안된다.
따라서 멀티 퍼셉트론을 사용. 입력 레이어를 제외하고 Two-Layer 구성.
Sigmoid를 사용. 0/1 binary구별로 함.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue May 23 14:10:10 2017
@author: crazyj
"""
import numpy as np
import os
# xor simple network.
# X(2) - 2 - Y(1)
# sigmoid activation function use.
# manual gradient
#
# if fail?, try again!
# local minima problem exists...
# make deep and wide network.
#
X = np.array( [[0,0], [0,1], [1,0], [1,1]])
T = np.array( [[0], [1], [1], [0]] )
np.random.seed(int(os.times()[4]))
W1 = np.random.randn(2,2)
b1 = np.zeros([2])
W2 = np.random.randn(2,1)
b2 = np.zeros([1])
def Sigmoid(X):
return 1/(1+np.exp(-X))
def Predict(X, W1, b1, W2, b2):
Z1 = np.dot(X, W1)+b1
A1 = Sigmoid(Z1)
Z2 = np.dot(A1, W2)+b2
A2 = Sigmoid(Z2)
Y = A2
return Y
def Cost(X, W1, b1, W2, b2, T):
epsil = 1e-5
Z1 = np.dot(X, W1)+b1
A1 = Sigmoid(Z1)
Z2 = np.dot(A1, W2)+b2
A2 = Sigmoid(Z2)
Y = A2
return np.mean(-T*np.log(Y+epsil)-(1-T)*np.log(1-Y+epsil))
def Gradient(learning_rate, X, W1, b1, W2, b2, T):
Z1 = np.dot(X, W1)+b1
A1 = Sigmoid(Z1)
Z2 = np.dot(A1, W2)+b2
A2 = Sigmoid(Z2)
deltaY = A2-T
deltaA1 = np.dot(deltaY, W2.T) * (A1*(1-A1))
m = len(X)
gradW2 = np.dot(A1.T, deltaY)
gradW1 = np.dot(X.T, deltaA1)
W2 = W2-(learning_rate/m)*gradW2
b2 = b2-(learning_rate/m)*np.sum(deltaY)
W1 = W1-(learning_rate/m)*gradW1
b1 = b1-(learning_rate/m)*np.sum(deltaA1)
return (W1, b1, W2, b2)
for i in range(3000):
J= Cost(X,W1,b1,W2,b2,T)
W1,b1,W2,b2 = Gradient(1.0, X, W1, b1, W2, b2, T)
print ("Cost=",J)
Y = Predict(X, W1, b1, W2, b2)
print("predict=", Y)
결과
Cost= 0.351125685078
predict= [[ 0.50057071]
[ 0.49643107]
[ 0.99648031]
[ 0.00640712]]
실패?
다시 실행을 반복하다 보니 성공할때도 있다??? local minima 문제가 있음.
이를 해결하기 위해서는 여러번 시도해서 코스트가 낮아질 때까지 처음부터 반복(initialize 가 중요).하던가 network을 deep & wide하게 설계한다.
Cost= 0.00403719259697
predict= [[ 0.00475473]
[ 0.99634993]
[ 0.99634975]
[ 0.00409427]]
이건 성공 결과.
'AI(DeepLearning)' 카테고리의 다른 글
[tf] unknown math polynomial function modeling (0) | 2017.06.01 |
---|---|
[tf] XOR tensorflow로 학습구현 (0) | 2017.05.23 |
[R] multinomial classification. 다중분류 (0) | 2017.05.19 |
[R] binary classification (0) | 2017.05.19 |
[R] linear regression Normal Equation (0) | 2017.05.19 |