We will find a way. We always have.

고려대학교에서 인공지능과 금융공학을 연구하고 있는 어느 대학원생의 블로그입니다.

딥러닝(Deep Learning)

파이썬으로 퍼셉트론 구현하기

MinsukSung 2020. 4. 25. 15:51

대학원 과제에서 퍼셉트론(Perceptron)을 구현하라는 과제를 받았다. 사실 파이썬으로 퍼셉트론을 구현해둔 코드는 많지만 의외로 결정 경계(Decision Boundary)를 표현해주는 코드는 많이 없었다. 그래서 직접 코딩해서 만들어보았다.

 

핵심 코드는 다음과 같다. 모든 코드는 깃헙에 올려두었다.

class Perceptron():
    # 초기화
    def __init__(self,example,thresholds=0.0,eta=0.01,n_iter=10):
        self.thresholds = thresholds
        self.eta = eta
        self.n_iter = n_iter
        self.example = example
        self.gif_path = './img/{}/'.format(example)

    # 학습
    def fit(self,X,y):
        self.w_ = np.random.normal(size=1+X.shape[1])
        self.errors_ = []

        for _iter in tqdm(range(self.n_iter)):
            errors = 0
            for xi,target in zip(X,y):
                update = self.eta * (self.predict(xi)-target)
                self.w_[1:] -= update * xi
                self.w_[0] -= update
                errors += int(update!=0.0)
            self.errors_.append(errors)
            self.show_decision_boundary(X,y,_iter)
        return self

    # 추론
    def predict(self,X):
        return np.sum(X*self.w_[1:])+self.w_[0]

    # 
    def show_weight(self):
        print(self.w_)

    def show_decision_boundary(self,X,y,_iter):
        markers = {1:'o',-1:'x'}
        colors = ['b','r']
        for xs,ys in zip(X,y):
            if self.predict(xs) > self.thresholds:
                plt.plot(xs[0],xs[1],markers[ys],c=colors[0])
            else:
                plt.plot(xs[0],xs[1],markers[ys],c=colors[1])
        matplotlib.pyplot.xlim([-0.25,1.25])
        matplotlib.pyplot.ylim([-0.25,1.25])
        plt.grid(True)
        _x = np.arange(-10,10)
        _y = -(self.w_[0] / self.w_[2]) - (self.w_[1]/self.w_[2])*_x
        plt.plot(_x, _y)
        matplotlib.pyplot.title('[{}] Decision Boundary: {} Gate'.format(_iter,self.example))
        matplotlib.pyplot.savefig(self.gif_path+'{}.jpg'.format(str(_iter).zfill(4)))
        matplotlib.pyplot.pause(0.000000001)
        plt.lines.pop(-1)

    def save_gif_decision_boundary(self):
        images = []
        for filename in tqdm(sorted(glob.glob(self.gif_path+'/*.jpg'))):
            images.append(matplotlib.pyplot.imread(filename))
        imageio.mimsave('./{}.gif'.format(self.example), images)