How to display the graphical decision tree for this scikit-learn decision tree script?
```python
from sklearn import tree
clf = tree.DecisionTreeClassifier()
# [height, weight, shoe_size]
X = [[181, 80, 44], [177, 70, 43], [160, 60, 38], [154, 54, 37], [166, 65, 40],
     [190, 90, 47], [175, 64, 39],
     [177, 70, 40], [159, 55, 37], [171, 75, 42], [181, 85, 43]]
Y = ['male', 'male', 'female', 'female', 'male', 'male', 'female', 'female',
     'female', 'male', 'male']
clf = clf.fit(X, Y)
prediction = clf.predict([[160, 60, 22]])
print(prediction)
```
The script above works, but it does not display the tree graphically. How can I do that?
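One way is to export the fitted tree to Graphviz DOT format with `tree.export_graphviz` and render it with the `graphviz` Python package. Note that the package only wraps Graphviz, so the Graphviz system binaries (the `dot` executable) must also be installed. The following example renders the tree: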
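```python
from sklearn import tree
import graphviz  # Python bindings; the Graphviz "dot" executable must also be installed

clf = tree.DecisionTreeClassifier()
# [height, weight, shoe_size]
X = [[181, 80, 44], [177, 70, 43], [160, 60, 38], [154, 54, 37], [166, 65, 40],
     [190, 90, 47], [175, 64, 39],
     [177, 70, 40], [159, 55, 37], [171, 75, 42], [181, 85, 43]]
Y = ['male', 'male', 'female', 'female', 'male', 'male', 'female', 'female',
     'female', 'male', 'male']
clf = clf.fit(X, Y)
prediction = clf.predict([[160, 60, 22]])
print(prediction)

# Export the fitted tree as DOT source (returned as a string because out_file=None)
dot_data = tree.export_graphviz(
    clf,
    out_file=None,
    feature_names=['height', 'weight', 'shoe_size'],  # label splits by feature name
    class_names=list(clf.classes_),  # label nodes with 'female' / 'male'
    filled=True,  # color each node by its majority class
)
graph = graphviz.Source(dot_data)
# Save the DOT source to a file named "gender" and render it to "gender.pdf"
graph.render("gender")
```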
The last line saves the DOT source to a file named `gender` and renders it to `gender.pdf` (PDF is the default output format), which shows the decision tree.
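If installing the Graphviz binaries is inconvenient, scikit-learn (0.21 and later) also provides `tree.plot_tree`, which draws the tree with matplotlib instead. The snippet below is a minimal sketch, assuming matplotlib is installed; the output filename `gender_tree.png`, the figure size, and `random_state=0` are arbitrary choices for illustration. It selects the non-interactive `Agg` backend so it also runs on a machine without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend; remove to get an interactive window
import matplotlib.pyplot as plt
from sklearn import tree

# [height, weight, shoe_size]
X = [[181, 80, 44], [177, 70, 43], [160, 60, 38], [154, 54, 37], [166, 65, 40],
     [190, 90, 47], [175, 64, 39],
     [177, 70, 40], [159, 55, 37], [171, 75, 42], [181, 85, 43]]
Y = ['male', 'male', 'female', 'female', 'male', 'male', 'female', 'female',
     'female', 'male', 'male']

# random_state=0 is an illustrative choice so the drawn tree is reproducible
clf = tree.DecisionTreeClassifier(random_state=0)
clf.fit(X, Y)

fig, ax = plt.subplots(figsize=(8, 6))
# plot_tree draws one annotation box per node onto the given axes
annotations = tree.plot_tree(
    clf,
    feature_names=["height", "weight", "shoe_size"],
    class_names=list(clf.classes_),
    filled=True,
    ax=ax,
)
fig.savefig("gender_tree.png", dpi=150)  # hypothetical output filename
```

To view the plot in a window instead of saving it, drop the `matplotlib.use("Agg")` line and call `plt.show()` after `plot_tree`. This avoids the Graphviz dependency, at the cost of a layout that is sometimes more cramped for deep trees.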