티스토리 뷰

개요

머신러닝을 배울 때 가장 처음 나오는 Decision Tree에 대해 공부 해봅니다. Decision Tree는 '지도학습'(Supervised Learning)의 대표적인 알고리즘 입니다. 연속되지 않은 값 즉 이산값을 가지고 모델을 만들 때 사용 합니다. '지도학습'은 정답이 있는 학습 입니다. 처음에 잘 모를때는 '지도학습'부터 공부 하는게 입문용으로는 좋은 것 같습니다.

 

어디에 쓸까요?

모델이 여러가지가 있습니다 무엇을 하고 싶은 것일까요?

{"age":35, "gender":"male"} 이런 데이터가 주어졌을 때 이 사람이 사고 싶을 것 같은 물건을 추천해 주고 싶은 걸까요?

 

아니면 데이터가 5천건 정도 있다고 했을 때 이 데이터를 학습시켜서 {"age":35, "gender":"male"} 이런 데이터를 넣었을 때 어떤 분류로 분류를 해줄지 이런 것을 해주고 싶은 것일까요?

 

데이터를 갔다가 sklearn에 넣고 뭐가 나오는지를 지켜보는 것이 과연 의미가 있을까요?

 

이 포스트에서는 '꽃받침의 길이', '꽃받침의 넓이', '꽃잎의 길이', '꽃잎의 넓이'를 넣으면 setosa, versicolor, virginica인지 예측을 하는 모델을 만드는 것이 목표 입니다.

 

iris데이터 해설

유명한 iris데이터 입니다. sklearn패키지 안에 포함 되어 있습니다. 데이터를 불러와서 어떻게 생겼는지 한번 확인 해보겠습니다. 이 데이터는 '꽃받침의 길이', '꽃받침의 넓이', '꽃잎의 길이', '꽃잎의 넓이'를 가지고 어떤 종인지를 예측하는 모델을 위한 데이터 입니다.

 

그런데 target이 종(specise)이기 때문에 연속값이 아니고 이산적인 값 입니다. 어떤 종인지를 분류해주는것입니다.

from sklearn.datasets import load_iris
iris = load_iris()

for key in iris:
    print(key)

결과

data
target
frame
target_names
DESCR
feature_names
filename

 

결과 해석

맴버의 개수는 7개 입니다.

 

그 내용은 아래와 같습니다.

{'data': array([
	[5.1 3.5 1.4 0.2],[4.9 3.  1.4 0.2],[4.7 3.2 1.3 0.2],[4.6 3.1 1.5 0.2],[5.  3.6 1.4 0.2],
    [5.4 3.9 1.7 0.4],[4.6 3.4 1.4 0.3],[5.  3.4 1.5 0.2],[4.4 2.9 1.4 0.2],[4.9 3.1 1.5 0.1],
    [5.4 3.7 1.5 0.2],[4.8 3.4 1.6 0.2],[4.8 3.  1.4 0.1],[4.3 3.  1.1 0.1],[5.8 4.  1.2 0.2],
    [5.7 4.4 1.5 0.4],[5.4 3.9 1.3 0.4],[5.1 3.5 1.4 0.3],[5.7 3.8 1.7 0.3],[5.1 3.8 1.5 0.3],
    [5.4 3.4 1.7 0.2],[5.1 3.7 1.5 0.4],[4.6 3.6 1.  0.2],[5.1 3.3 1.7 0.5],[4.8 3.4 1.9 0.2],
    [5.  3.  1.6 0.2],[5.  3.4 1.6 0.4],[5.2 3.5 1.5 0.2],[5.2 3.4 1.4 0.2],[4.7 3.2 1.6 0.2],
    [4.8 3.1 1.6 0.2],[5.4 3.4 1.5 0.4],[5.2 4.1 1.5 0.1],[5.5 4.2 1.4 0.2],[4.9 3.1 1.5 0.2],
    [5.  3.2 1.2 0.2],[5.5 3.5 1.3 0.2],[4.9 3.6 1.4 0.1],[4.4 3.  1.3 0.2],[5.1 3.4 1.5 0.2],
    [5.  3.5 1.3 0.3],[4.5 2.3 1.3 0.3],[4.4 3.2 1.3 0.2],[5.  3.5 1.6 0.6],[5.1 3.8 1.9 0.4],
    [4.8 3.  1.4 0.3],[5.1 3.8 1.6 0.2],[4.6 3.2 1.4 0.2],[5.3 3.7 1.5 0.2],[5.  3.3 1.4 0.2],
    [7.  3.2 4.7 1.4],[6.4 3.2 4.5 1.5],[6.9 3.1 4.9 1.5],[5.5 2.3 4.  1.3],[6.5 2.8 4.6 1.5],
    [5.7 2.8 4.5 1.3],[6.3 3.3 4.7 1.6],[4.9 2.4 3.3 1. ],[6.6 2.9 4.6 1.3],[5.2 2.7 3.9 1.4],
    [5.  2.  3.5 1. ],[5.9 3.  4.2 1.5],[6.  2.2 4.  1. ],[6.1 2.9 4.7 1.4],[5.6 2.9 3.6 1.3],
    [6.7 3.1 4.4 1.4],[5.6 3.  4.5 1.5],[5.8 2.7 4.1 1. ],[6.2 2.2 4.5 1.5],[5.6 2.5 3.9 1.1],
    [5.9 3.2 4.8 1.8],[6.1 2.8 4.  1.3],[6.3 2.5 4.9 1.5],[6.1 2.8 4.7 1.2],[6.4 2.9 4.3 1.3],
    [6.6 3.  4.4 1.4],[6.8 2.8 4.8 1.4],[6.7 3.  5.  1.7],[6.  2.9 4.5 1.5],[5.7 2.6 3.5 1. ],
    [5.5 2.4 3.8 1.1],[5.5 2.4 3.7 1. ],[5.8 2.7 3.9 1.2],[6.  2.7 5.1 1.6],[5.4 3.  4.5 1.5],
    [6.  3.4 4.5 1.6],[6.7 3.1 4.7 1.5],[6.3 2.3 4.4 1.3],[5.6 3.  4.1 1.3],[5.5 2.5 4.  1.3],
    [5.5 2.6 4.4 1.2],[6.1 3.  4.6 1.4],[5.8 2.6 4.  1.2],[5.  2.3 3.3 1. ],[5.6 2.7 4.2 1.3],
    [5.7 3.  4.2 1.2],[5.7 2.9 4.2 1.3],[6.2 2.9 4.3 1.3],[5.1 2.5 3.  1.1],[5.7 2.8 4.1 1.3],
    [6.3 3.3 6.  2.5],[5.8 2.7 5.1 1.9],[7.1 3.  5.9 2.1],[6.3 2.9 5.6 1.8],[6.5 3.  5.8 2.2],
    [7.6 3.  6.6 2.1],[4.9 2.5 4.5 1.7],[7.3 2.9 6.3 1.8],[6.7 2.5 5.8 1.8],[7.2 3.6 6.1 2.5],
    [6.5 3.2 5.1 2. ],[6.4 2.7 5.3 1.9],[6.8 3.  5.5 2.1],[5.7 2.5 5.  2. ],[5.8 2.8 5.1 2.4],
    [6.4 3.2 5.3 2.3],[6.5 3.  5.5 1.8],[7.7 3.8 6.7 2.2],[7.7 2.6 6.9 2.3],[6.  2.2 5.  1.5],[6.9 3.2 5.7 2.3],[5.6 2.8 4.9 2. ],[7.7 2.8 6.7 2. ],[6.3 2.7 4.9 1.8],[6.7 3.3 5.7 2.1],[7.2 3.2 6.  1.8],[6.2 2.8 4.8 1.8],[6.1 3.  4.9 1.8],[6.4 2.8 5.6 2.1],[7.2 3.  5.8 1.6],[7.4 2.8 6.1 1.9],[7.9 3.8 6.4 2. ],[6.4 2.8 5.6 2.2],[6.3 2.8 5.1 1.5],[6.1 2.6 5.6 1.4],[7.7 3.  6.1 2.3],[6.3 3.4 5.6 2.4],[6.4 3.1 5.5 1.8],[6.  3.  4.8 1.8],[6.9 3.1 5.4 2.1],[6.7 3.1 5.6 2.4],[6.9 3.1 5.1 2.3],[5.8 2.7 5.1 1.9],[6.8 3.2 5.9 2.3],[6.7 3.3 5.7 2.5],[6.7 3.  5.2 2.3],[6.3 2.5 5.  1.9],[6.5 3.  5.2 2. ],[6.2 3.4 5.4 2.3],[5.9 3.  5.1 1.8]
    ]),
'target': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]),
'frame': None, 
'target_names': array(['setosa', 'versicolor', 'virginica'], dtype='<U10'), 
'DESCR': '''
Iris plants dataset
--------------------

**Data Set Characteristics:**

    :Number of Instances: 150 (50 in each of three classes)
    :Number of Attributes: 4 numeric, predictive attributes and the class
    :Attribute Information:
        - sepal length in cm
        - sepal width in cm
        - petal length in cm
        - petal width in cm
        - class:
                - Iris-Setosa
                - Iris-Versicolour
                - Iris-Virginica
                
    :Summary Statistics:

    ============== ==== ==== ======= ===== ====================
                    Min  Max   Mean    SD   Class Correlation
    ============== ==== ==== ======= ===== ====================
    sepal length:   4.3  7.9   5.84   0.83    0.7826
    sepal width:    2.0  4.4   3.05   0.43   -0.4194
    petal length:   1.0  6.9   3.76   1.76    0.9490  (high!)
    petal width:    0.1  2.5   1.20   0.76    0.9565  (high!)
    ============== ==== ==== ======= ===== ====================

    :Missing Attribute Values: None
    :Class Distribution: 33.3% for each of 3 classes.
    :Creator: R.A. Fisher
    :Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov)
    :Date: July, 1988

The famous Iris database, first used by Sir R.A. Fisher. The dataset is taken
from Fisher's paper. Note that it's the same as in R, but not as in the UCI
Machine Learning Repository, which has two wrong data points.

This is perhaps the best known database to be found in the
pattern recognition literature.  Fisher's paper is a classic in the field and
is referenced frequently to this day.  (See Duda & Hart, for example.)  The
data set contains 3 classes of 50 instances each, where each class refers to a
type of iris plant.  One class is linearly separable from the other 2; the
latter are NOT linearly separable from each other.

.. topic:: References

   - Fisher, R.A. "The use of multiple measurements in taxonomic problems"
     Annual Eugenics, 7, Part II, 179-188 (1936); also in "Contributions to
     Mathematical Statistics" (John Wiley, NY, 1950).
   - Duda, R.O., & Hart, P.E. (1973) Pattern Classification and Scene Analysis.
     (Q327.D83) John Wiley & Sons.  ISBN 0-471-22361-1.  See page 218.
   - Dasarathy, B.V. (1980) "Nosing Around the Neighborhood: A New System
     Structure and Classification Rule for Recognition in Partially Exposed
     Environments".  IEEE Transactions on Pattern Analysis and Machine
     Intelligence, Vol. PAMI-2, No. 1, 67-71.
   - Gates, G.W. (1972) "The Reduced Nearest Neighbor Rule".  IEEE Transactions
     on Information Theory, May 1972, 431-433.
   - See also: 1988 MLC Proceedings, 54-64.  Cheeseman et al"s AUTOCLASS II
     conceptual clustering system finds 3 classes in the data.
   - Many, many more ...
'''
'feature_names': ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)'],
'filename': 'c:\\git\\python\\aip-triton\\venv\\lib\\site-packages\\sklearn\\datasets\\data\\iris.csv'}

여기에서 중요한 것은 data와 target입니다. data의 개수를 세어보면 150개이고 target의 개수도 150개 입니다.

 

data의 한개 element(ex [5.1 3.5 1.4 0.2])는 4개의 항목으로 구성 되어 있습니다. 그 각각이 의미하는 것은 'feature_names'에 있습니다. 내용은 다음과 같습니다.

['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']

sepal은 꽃받침 입니다. petal은 꽃잎입니다. 그러면 위 4개의 항목은 꽃받침의 길이, 꽃받침의 넓이, 꽃잎의 길이, 꽃잎의 넓이가 되겠습니다.

 

이 데이터는 target이 있으므로 Supervised 데이터 입니다. Supervised는 지도학습 입니다.

 

target_names

'target_names': array(['setosa', 'versicolor', 'virginica'], dtype='<U10'), 

target_names는 위와 같이 되어 있습니다. 그리고 target은 아래와 같이 되어 있습니다.

'target': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]),

target은 0, 1, 2 3가지의 값이고 target_names로 봤을 때 0번은 setosa 1번은 versicolor, 2번은 virginica 입니다.

 

collections.Counter를 이용해 개수를 세어보겠습니다.

from collections import Counter
print(iris_data.target)
cnt = Counter(iris_data.target)
print(cnt)

결과

[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2]
Counter({0: 50, 1: 50, 2: 50})

 

0번, 1번, 2번 각각 50개씩 150개가 나옵니다.

 

train_test_split() 사용법

train_test_split()은 데이터를 나누어 주는 기능을 합니다.

random_state는 난수의 seed로 같은 값을 넘기면 같은 데이터가 뽑힙니다. random이라도 추출을 재현하기 위해 씁니다.

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

iris_data=load_iris()

X_train, X_test, y_train, y_test = train_test_split(iris_data.data, iris_data.target, test_size=0.3, random_state=24)

print('iris_data.data:',len(iris_data.data))
print('X_train:',len(X_train))

결과

iris_data.data: 150

X_train: 105

 

결과 해석

test_size를 0.3으로 했기 때문에 X_train의 개수가 105개가 뽑혔습니다. 0.2로 하면 120개가 뽑힙니다.

 

 

DecisionTree로 학습시켜 모델 만들고 예측하기

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris_data=load_iris()
dt = DecisionTreeClassifier()

dt.fit(iris_data.data, iris_data.target)
print(dt)

pred = dt.predict([[6.5, 3.,  5.2, 2. ]])
print(pred)

결과

DecisionTreeClassifier()
[2]

 

결과 해석

iris_data.data와 iris_data.target으로 학습을 시킵니다.

 

그리고 iris_data.data에 있던 데이터 중 하나인 [6.5, 3.,  5.2, 2. ] 를 넣고 예측을 하면 2가 나온 것을 볼 수 있습니다.

 

dt는 모델 입니다. 이 모델에 4가지 값 꽃받침의 길이, 꽃받침의 넓이, 꽃잎의 길이, 꽃잎의 넓이를 넣으면 setosa, versicolor, virginica인지 예측을 해줍니다.

 

end.

 

 

728x90
공지사항
최근에 올라온 글
최근에 달린 댓글
Total
Today
Yesterday
링크
«   2025/01   »
1 2 3 4
5 6 7 8 9 10 11
12 13 14 15 16 17 18
19 20 21 22 23 24 25
26 27 28 29 30 31
글 보관함