scikit-learn 을 통한 머신러닝 - 데이터셋 로딩, 학습, 그리고 예측
scikit-learn 을 통한 간단한 머신러닝에 대해 알아본다.
1. 예제 데이터셋 로딩
다음과 같이 iris
와 digits
데이터셋을 로딩한다.
1 2 3 | from sklearn import datasets iris = datasets.load_iris() digits = datasets.load_digits() | cs |
데이터셋은 데이터와 데이터에 대한 메타 데이터를 가지고 있는 일종의 딕셔너리같은 오브젝트이다. 데이터는 .data 로 저장되고, n_samples, n_features
배열을 가진다. 지도학습의 경우, 하나 또는 그 이상의 대응하는 변수가 .target
으로 저장된다.
,
digits.data를 통해
숫자 샘플을 분류하는데 사용할 수 있는 속성에 접근할 수 있다.1 2 3 4 5 6 7 8 9 | print(digits.data) [[ 0. 0. 5. ..., 0. 0. 0.] [ 0. 0. 0. ..., 10. 0. 0.] [ 0. 0. 0. ..., 16. 9. 0.] ..., [ 0. 0. 1. ..., 6. 0. 0.] [ 0. 0. 2. ..., 12. 0. 0.] [ 0. 0. 10. ..., 12. 1. 0.]] |
그리고
digits.target
은 우리가 학습하고자하는 각각의 숫자 이미지에 대응하는 숫자로 정확한 예측 여부를 판단하는 기준이 된다.
1 2 3 | digits.target array([0, 1, 2, ..., 8, 9, 8]) | cs |
데이터 배열의 구조
데이터는 항상 (n_samples, n_features)
의 구조를 가진 2D 배열이다. 숫자의 경우, 원본 샘플은 (8, 8)
의 구조를 가졌으며 아래와 같이 접근가능하다.
1 2 3 4 5 6 7 8 9 10 | digits.images[0] array([[ 0., 0., 5., 13., 9., 1., 0., 0.], [ 0., 0., 13., 15., 10., 15., 5., 0.], [ 0., 3., 15., 2., 0., 11., 8., 0.], [ 0., 4., 12., 0., 0., 8., 8., 0.], [ 0., 5., 8., 0., 0., 9., 8., 0.], [ 0., 4., 11., 0., 1., 12., 7., 0.], [ 0., 2., 14., 5., 10., 12., 0., 0.], [ 0., 0., 6., 13., 10., 0., 0., 0.]]) | cs |
2. 학습과 예측
숫자 데이터셋의 경우, 목적은 주어진 이미지를 가지고 어떤 숫자를 나타내는지를 예측하는 것이다. 0에서 9까지의 10개의 클래스가 주어지는데, 추정기로 하여금 보지 않았던 샘플이 어디에 속하는지 그 카테고리를 예측할 수 있도록 핏팅해야 한다.
scikit-learn에서 분류 추정기는 fit(X, y)
와
predict(T)
를 구현하는 파이썬 오브젝트이다. 추정기의 예로 들자면 support vector classification 을 구현하는 sklearn.svm.SVC 와 같은 클래스를 들 수 있다. 추정기의 생성자는 모델의 파라메터를 인수로 가진다. 하지만 당분간은 해당 추정기가 블랙박스라고 간주한다.
1 2 | from sklearn import svm clf = svm.SVC(gamma=0.001, C=100.) | cs |
모델의 파라메터 선택하기 : 본 예제에서는 우리는 감마 값을 인위적으로 설정할 것이다. grid search 또는 cross validation 같은 툴을 통해서 자동적으로 파라메터에 적합한 값을 고르는 것이 가능하다.
추정기 인스턴스를 clf 로 호출한다. 해당 모델에 fitted 되어 있어야 하는데, 즉, 그 모델로부터 학습을 해야한다는 것이다. 트레이닝 세트를 fit 메서드로 넘김으로써 가능하다. 트레이닝 셋트에서 데이터 셋의 마지막 하나를 제외한 모든 이미지를 사용할 수 있게 한다. digits.data의 가장 마지막 요소만 뺀 나머지 요소로 구성된 새로운 배열을 만들어내는
파이썬 구문 [:-1]
을 통해 트레이닝 세트를 선택했다.
1 2 3 4 5 6 | clf.fit(digits.data[:-1], digits.target[:-1]) SVC(C=100.0, cache_size=200, class_weight=None, coef0=0.0, decision_function_shape=None, degree=3, gamma=0.001, kernel='rbf', max_iter=-1, probability=False, random_state=None, shrinking=True, tol=0.001, verbose=False) | cs |
이제 새로운 값을 예측할 수 있는데, 특히 분류기로 하여금 트레이닝에 사용하지 않게했던 숫자 데이터셋의 마지막 이미지의 숫자가 무엇인지 분류기에게 물어볼 수 있다.
1 2 3 | clf.predict(digits.data[-1:]) array([8]) | cs |
이에 해당하는 이미지는 다음과 같다. 보는바와 같이 이미지의 해상도가 너무 낮다. 앞으로 이를 개선해나가는 방법을 알아보자.
'프로그래밍 Programming' 카테고리의 다른 글
Jupyter Notebook 셀 크기 조절하기 How to increase/decrease the cell width of the jupyter notebook in browser (0) | 2016.11.03 |
---|---|
scikit-learn (1) - 기계학습이란? (0) | 2016.11.01 |
텐서보드 사용법 (0) | 2016.10.31 |
텐서플로우 코드 에러 - ResourceExhaustedError (0) | 2016.10.31 |
Deep MNIST for Experts - 전체 코드 (2) (0) | 2016.10.31 |