본문 바로가기

메이플의 개발 스토리

[머신러닝] 사이킷런 K-NN 알고리즘 본문

ML DL

[머신러닝] 사이킷런 K-NN 알고리즘

mapled 2021. 12. 19. 18:03

안녕하세요. 메이플입니다.

해당 포스팅은 혼자 공부하는 머신러닝+딥러닝 책을 교재로 스터디한 내용을 정리한 내용입니다.

보다 자세한 내용은 책을 참고해주시기 바랍니다.


아래 포스팅과 이어지는 글입니다. 

 - [머신러닝 스터디] 인공지능, 딥러닝, 머신러닝이란?

 - [머신러닝 스터디] 맷플롯립을 통해 데이터의 산점도 출력

 - [머신러닝 스터디] numpy 패키지


머신러닝 키워드

훈련 : 머신러닝 알고리즘이 데이터에서 규칙을 찾는 과정

* 모델(model) : 머신러닝 알고리즘을 구현한 프로그램이나 알고리즘을 구체화하여 표현한 것

정확도 : 정확한 답을 몇 개 맞혔는지 백분율로 나타낸 값. (정확히 맞힌 개수) / (전체 데이터 개수)


k-최근접 이웃(k-Nearest Neighbors) 알고리즘

k-최근접 이웃 알고리즘은 지도 학습의 한 종류로 거리기반 분류분석 모델입니다. 머신러닝에서 데이터를 가장 가까운 유사 속성에 따라 분류하고 라벨링한다고 보면 됩니다. 판별하고 싶은 데이터와 인접한 K 개수의 데이터를 찾아, 해당 데이터의 라벨이 다수인 범주로 데이터를 분류하는 방식입니다. 이런 특징의 알고리즘은 데이터가 아주 많은 경우 처리해야 할 계산도 많아지기 때문에 사용하기 어렵습니다. 이 알고리즘은 K-NN 알고리즘이라고도 불립니다.

 

k-최근접 이웃 알고리즘에 대한 자세한 내용은 아래 링크에 나와있습니다.

https://m.blog.naver.com/bestinall/221760380344

 

[머신러닝] K-NN 알고리즘 (K-최근접 이웃) 개념

K-NN 알고리즘 (K-최근접 이웃) 개념 편 될성부른 나무는 떡잎부터 알아본다 모든 기업들에게 'V...

blog.naver.com


사이킷런(scikit-learn) 패키지로 k-최근접 이웃 알고리즘

입력 데이터 변환

k-최근접 이웃 알고리즘을 구현하기 위해서 해당 알고리즘이 포함된 머신러닝 패키지인 사이킷런을 사용했습니다.

사이킷런은 2차원 리스트로 데이터를 입력받기 때문에 방어와 도미 데이터를 합쳐주고 2차원 배열로 바꿔줍니다.

fish_data는 방어와 도미가 합쳐진 데이터에서 하나씩 원소를 꺼낸(zip() 메서드) 다음 배열로 넣는 코드입니다.

fish_target는 해당 데이터가 방어인지 도미인지 알려주는 정답 리스트입니다.

* 리스트 컴프리헨션을 사용하면 for문, if문을 사용하는 것보다 코드 속도가 더 빨라집니다.

 

** 위의 코드를 아래처럼 numpy 패키지를 사용해서 표현할수도 있습니다.

import numpy as np

fish_data = np.column_stack((fish_length, fish_weight))
fish_target = np.concatenate((np.ones(35), np.zeros(14)))

k-최근접 이웃 모델을 생성해서 훈련

- KNeighborsClassifier() : k-최근접 이웃 모델을 만드는 사이킷런 클래스

- fit() : 사이킷런 모델을 훈련할 때 사용하는 메서드

사이언킷 패키지에서 k-최근접 이웃 알고리즘을 구현한 클래스를 임포트하고 객체를 만듭니다.

그리고 특성 데이터 fish_data와 정답 데이터 fish_target을 객체에 전달하여 도미를 찾기 위한 기준을 훈련(training)합니다.

정확도 측정

score() : 훈련된 사이킷런 모델의 성능을 측정하는 메서드

이제 이 모델이 잘 훈련되었는지 확인해봅니다. 1은 모든 데이터를 정확히 맞혔다는 것을 의미합니다.

새로운 데이터로 예측

-predict() : 사이킷런 모델을 훈련하고 예측할 때 사용하는 메서드

길이 30, 무게 600인 새로운 물고기가 있다고 가정하면 산점도 그래프로 확인했을 때, 해당 물고기는 파란색인 도미에 가까운 것을 볼 수 있습니다.

predict() 메서드를 사용하면 이 데이터의 정답을 예측합니다.

최근접 이웃 수 설정

- n_neighbors 매개변수 : 참고할 최근접 이웃 수를 설정할 수 있는 매개변수 (기본값은 5)

객체를 생성할 때 n_neighbors 매개변수를 통해 몇 개의 데이터를 참고할 것인지 설정할 수 있습니다.

가장 가까운 데이터를 49개를 사용하도록 설정하고 위와 똑같이 데이터를 넣고 잘 훈련이 됐는지 score() 메서드를 통해 확인해보면 이전과 다르게 0.71이란 값이 나옵니다.

이는 입력 데이터 49개 중에 도미가 35개로 다수를 차지하므로 어떤 데이터를 넣어도 도미로 예측되기 때문입니다.

(참고로 35/49 = 0.71)

지금처럼 데이터가 편향되어 있고 참고하는 최근접 이웃 수가 너무 많은 경우 예측이 빗나갈 경우가 커지기 때문에 참고하는 최근접 이웃의 수를 설정하는 것을 중요합니다.

 

적절한 n_neighbors 찾기

아래 for문을 통해서 n_neighbors가 18 이상인 경우부터 정확도가 1 아래로 떨어지는 것을 볼 수 있습니다.

Comments