K-최근접 이웃을 사용하여 2개의 종류를 분류하는 머신러닝 모델
- 머신러닝에서 여러 개의 종류(class) 중 하나를 구별해 내는 문제를 분류라고 함
- 2개의 클래스 중 하나를 고르는 문제를 '이진분류'라고 함
도미 데이터
bream_length = [25.4, 26.3, 26.5, 29.0, 29.0, 29.7, 29.7, 30.0, 30.0, 30.7, 31.0, 31.0, 31.5, 32.0, 32.0, 32.0, 33.0, 33.0, 33.5, 33.5, 34.0, 34.0, 34.5, 35.0, 35.0, 35.0, 35.0, 36.0, 36.0, 37.0, 38.5, 38.5, 39.5, 41.0, 41.0]
bream_weight = [242.0, 290.0, 340.0, 363.0, 430.0, 450.0, 500.0, 390.0, 450.0, 500.0, 475.0, 500.0, 500.0, 340.0, 600.0, 600.0, 700.0, 700.0, 610.0, 650.0, 575.0, 685.0, 620.0, 680.0, 700.0, 725.0, 720.0, 714.0, 850.0, 1000.0, 920.0, 955.0, 925.0, 975.0, 950.0]
import matplotlib.pyplot as plt
plt.scatter(bream_length, bream_weight)
plt.xlabel('length')
plt.ylabel('weight')
plt.show()
- 생선의 길이가 길면 무게가 많이 나감
- 산점도 그래프가 일직선에 가까운 형태 => 선형(linear)적이라고 함
빙어 데이터
smelt_length = [9.8, 10.5, 10.6, 11.0, 11.2, 11.3, 11.8, 11.8, 12.0, 12.2, 12.4, 13.0, 14.3, 15.0]
smelt_weight = [6.7, 7.5, 7.0, 9.7, 9.8, 8.7, 10.0, 9.9, 9.8, 12.2, 13.4, 12.2, 19.7, 19.9]
도미와 빙어 데이터 산점도
plt.scatter(bream_length, bream_weight)
plt.scatter(smelt_length, smelt_weight)
plt.xlabel('length')
plt.ylabel('weight')
plt.show()
- 빙어의 산점도도 선형적이기는 하지만 도미에 비해 무게가 길이에 영향을 덜 받음
-> 두 데이터를 스스로 구분하기 위한 머신러닝 프로그램 만들기
K-최근접 이웃(K-Nearest Neighbors)
length = bream_length + smelt_length
weight = bream_weight + smelt_weight
- 사이킷런이라는 머신러닝 패키지를 사용할 예정 -> 이 패키지 사용하려면 각 특성의 리스트를 세로방향으로 늘어뜨린 2차원 리스트 만들어야 함
fish_data = [[l, w] for l, w in zip(length, weight)]
- zip() 함수는 나열된 리스트에서 원소를 하나씩 꺼내주는 일 -> for문으로 반복
-> 생선 49마리의 길이와 무게 모두 준비
=> 정답 데이터 필요: 첫 번째 생선은 도미, 두 번째 생선도 도미라는 식으로 각각 어떤 생선인지 답을 만듦 why?
: 우리는 머신러닝 알고리즘이 생선의 길이와 무게를 보고 도미와 빙어를 구분하는 규칙을 찾기를 원함 -> 그러기 위해서는 어떤 생선이 도미고 빙어인지 알려주어야 함 like 스무고개 할 때 고개마다 정답 알려주듯(컴퓨터라 문자 이해 x, 0과 1로) -> 도미와 빙어 순서대로 나열해서 정답 리스트는 1이 35번, 0이 14번 등장
fish_target = [1]*35 + [0]*14
print(fish_target)
>> [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
- 머신러닝에서는 2개를 구분하는 경우 찾으려는 대상을 1로 놓고 그 외에는 0으로 놓는다
클래스 임포트
from sklearn.neighbors import KNeighborsClassifier
kn = KNeighborsClassifier()
- 이 객체에 fish_target과 fish_data 전달해서 도미 찾기 위한 기준 학습 => 훈련: 사이킷런에서는 fit() 메서드가 이런 역할
kn.fit(fish_data, fish_target)
-> 훈련시키는 것
kn.score(fish_data, fish_target)
>> 1.0
-> 정확도가 1.0
- 얼마나 잘 훈련되었는지 평가
- 사이킷런에서 모델 평가하는 메서드는 score() -> 0~1 값 반환(1은 모든 데이터를 정확히 맞혔다는 것을 의미)
K-최근접 이웃 알고리즘
- 어떤 데이터에 대한 답을 구할 때 주위의 다른 데이터를 보고 다수를 차지하는 것을 정답으로 사용
- 저 초록 삼각형은 도미일까? 빙어일까?
kn.predict([[30, 600]])
>> array([1])
- 앞서 우리는 도미를 1, 빙어를 0으로 가정
- 삼각형은 도미
- k-최근접 이웃 알고리즘을 위해 준비해야 할 일은 데이터를 모두 가지고 있는 게 전부
- 새로운 데이터에 대해 예측할 때는 가장 가까운 직선거리에 어떤 데이터가 있는지 살피기만 하면 됨
- but, k-최근접 이웃 알고리즘의 이런 특징 때문에 데이터가 아주 많은 경우 사용하기 어려움 -> 메모리가 많이 필요 + 직선거리 계산하는데 많은 시간 필요
- k-최근접 이웃 알고리즘은 무언가 훈련되는 것은 없음 -> fit()메서드에 전달한 데이터를 모두 저장하고 있다가 새로운 데이터가 등장하면 가장 가까운 데이터를 참고하여 도미인지 빙어인지 구분 -> 그럼 몇개의 데이터를 참고할까? => 기본값은 5
kn49 = KNeighborsClassifier(n_neighbors=49)
kn49.fit(fish_data, fish_target)
kn49.score(fish_data, fish_target)
>> 0.7142857142857143
- 참고 데이터를 49개로 한 kn49 모델
- 그런데 49개 중에 도미가 35개로 다수를 차지하므로 어떤 데이터를 넣어도 무조건 도미로 예측
- 특성: 데이터를 표현하는 하나의 성질
- fit(): 훈련
- 정확도 = (정확하게 맞힌 개수) / (전체 데이터 수)
'STUDY > ML' 카테고리의 다른 글
[딥러닝] 심층 신경망 (0) | 2023.04.12 |
---|---|
[ML] fit_transform()과 transform() (0) | 2023.03.18 |
[ML] Binary Encoding과 One Hot Encoding (0) | 2023.03.18 |