{
"cells": [
{
"cell_type": "markdown",
"id": "45278a2a",
"metadata": {},
"source": [
"(chapter8_part2)=\n",
"\n",
"\n",
"## Decision Trees in Classification\n",
"\n",
"- This is supplementary material for the [Machine Learning Simplified](https://themlsbook.com) book. It focuses on the Python implementation of the topics discussed; all detailed explanations can be found in the book. \n",
"- I also assume you know Python syntax and how it works. If you don't, I highly recommend taking a break to get familiar with the language before going forward with my code. \n",
"- This material can be downloaded as a Jupyter notebook (Download button in the upper-right corner -> `.ipynb`) to reproduce the code and play around with it. \n",
"\n",
"\n",
"This notebook is a supplement to *Chapter 8. Decision Trees* of the **Machine Learning For Everyone** book.\n",
"\n",
"## 1. Required Libraries, Data & Variables\n",
"\n",
"Let's import the required libraries, create the dataset, and have a look at it:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "e3bf2e2d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" x1 x2 Color\n",
"0 0.25 1.41 blue\n",
"1 0.60 0.39 blue\n",
"2 0.71 1.29 blue\n",
"3 1.20 2.30 blue\n",
"4 1.75 0.59 blue\n",
"5 2.26 1.70 green\n",
"6 2.50 1.35 green\n",
"7 2.50 2.90 green\n",
"8 2.88 0.61 green\n",
"9 2.91 2.00 green"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"%config InlineBackend.figure_format = 'retina'  # render sharper, high-resolution plots\n",
"\n",
"# Creating the DataFrame based on the provided data\n",
"data = {\n",
" 'x1': [0.25, 0.60, 0.71, 1.20, 1.75, 2.26, 2.50, 2.50, 2.88, 2.91],\n",
" 'x2': [1.41, 0.39, 1.29, 2.30, 0.59, 1.70, 1.35, 2.90, 0.61, 2.00],\n",
" 'Color': ['blue', 'blue', 'blue', 'blue', 'blue', 'green', 'green', 'green', 'green', 'green']\n",
"}\n",
"\n",
"# Convert the dictionary to a pandas DataFrame\n",
"df = pd.DataFrame(data)\n",
"\n",
"# Display the DataFrame\n",
"df"
]
},
{
"cell_type": "markdown",
"id": "2d0d5d26",
"metadata": {},
"source": [
"## 2. Visualizing the DataFrame"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "018e4001",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[Figure: 'Scatter Plot of Colors', x1 vs x2, points colored blue and green, legend titled 'Point Color']"
]
},
"metadata": {
"image/png": {
"height": 454,
"width": 571
}
},
"output_type": "display_data"
}
],
"source": [
"# Plotting\n",
"fig, ax = plt.subplots()\n",
"colors = {'blue': 'blue', 'green': 'green'}\n",
"\n",
"# Group by color and then plot each group\n",
"for key, group in df.groupby('Color'):\n",
" group.plot(ax=ax, kind='scatter', x='x1', y='x2', label=key, color=colors[key])\n",
"\n",
"# Setting plot labels and title\n",
"ax.set_xlabel('x1')\n",
"ax.set_ylabel('x2')\n",
"ax.set_title('Scatter Plot of Colors')\n",
"\n",
"# Display the legend\n",
"ax.legend(title='Point Color')\n",
"\n",
"# Show the plot\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "b107274b",
"metadata": {},
"source": [
"## 3. Preprocessing the DataFrame"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "a013cdab",
"metadata": {},
"outputs": [],
"source": [
"# Encode the class labels as integers: blue -> 0, green -> 1\n",
"df['Color'] = df['Color'].map({'blue': 0, 'green': 1})"
]
},
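{
"cell_type": "markdown",
"id": "7c1a0e01",
"metadata": {},
"source": [
"Mapping the labels by hand is perfectly fine for two classes. As an alternative sketch (not used in the rest of this notebook), scikit-learn's `LabelEncoder` assigns integer codes to the class names in sorted order, so for these labels it should produce the same encoding (`blue` → 0, `green` → 1) and can also decode predictions back to names:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7c1a0e02",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.preprocessing import LabelEncoder\n",
"\n",
"# Original string labels, in the same row order as the DataFrame\n",
"labels = ['blue'] * 5 + ['green'] * 5\n",
"\n",
"encoder = LabelEncoder()\n",
"encoded = encoder.fit_transform(labels)  # learns the sorted classes, then maps each label to its index\n",
"\n",
"print('classes:', encoder.classes_.tolist())\n",
"print('encoded:', encoded.tolist())\n",
"print('decoded back:', encoder.inverse_transform([0, 1]).tolist())"
]
},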
{
"cell_type": "code",
"execution_count": 4,
"id": "2863845a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" x1 x2 Color\n",
"0 0.25 1.41 0\n",
"1 0.60 0.39 0\n",
"2 0.71 1.29 0\n",
"3 1.20 2.30 0\n",
"4 1.75 0.59 0\n",
"5 2.26 1.70 1\n",
"6 2.50 1.35 1\n",
"7 2.50 2.90 1\n",
"8 2.88 0.61 1\n",
"9 2.91 2.00 1"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df"
]
},
{
"cell_type": "markdown",
"id": "3dd74f94",
"metadata": {},
"source": [
"## 4. Training a Decision Tree with Gini"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "4111c236",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.tree import DecisionTreeClassifier\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.metrics import accuracy_score"
]
},
{
"cell_type": "markdown",
"id": "2bdfa1c7",
"metadata": {},
"source": [
"#### 4.1. Splitting into X and y"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "0964d844",
"metadata": {},
"outputs": [],
"source": [
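"X = df[['x1', 'x2']]  # features\n",
"y = df['Color']        # target: 0 = blue, 1 = green\n",
"# Hold out 20% of the 10 samples (2 points) as a test set; random_state makes the split reproducible\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)"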
]
},
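{
"cell_type": "markdown",
"id": "7c1a0e03",
"metadata": {},
"source": [
"With only 10 samples, `test_size=0.2` holds out just 2 points. A variation worth knowing: passing `stratify=y` to `train_test_split` keeps the class proportions the same in the training and test sets, which here guarantees one blue and one green point in the test set. The sketch below rebuilds the data so it runs on its own (the `_demo`, `X_tr`, `X_te`, `y_tr`, `y_te` names are just illustrative); the rest of the notebook keeps the unstratified split above."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7c1a0e04",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Same toy data as above (0 = blue, 1 = green)\n",
"df_demo = pd.DataFrame({\n",
"    'x1': [0.25, 0.60, 0.71, 1.20, 1.75, 2.26, 2.50, 2.50, 2.88, 2.91],\n",
"    'x2': [1.41, 0.39, 1.29, 2.30, 0.59, 1.70, 1.35, 2.90, 0.61, 2.00],\n",
"    'Color': [0, 0, 0, 0, 0, 1, 1, 1, 1, 1],\n",
"})\n",
"X_demo = df_demo[['x1', 'x2']]\n",
"y_demo = df_demo['Color']\n",
"\n",
"# stratify=y_demo preserves the 50/50 class ratio in both the training and the test set\n",
"X_tr, X_te, y_tr, y_te = train_test_split(\n",
"    X_demo, y_demo, test_size=0.2, random_state=42, stratify=y_demo\n",
")\n",
"\n",
"print('training-set class counts:', y_tr.value_counts().to_dict())\n",
"print('test-set class counts:', y_te.value_counts().to_dict())"
]
},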
{
"cell_type": "markdown",
"id": "b8f48f53",
"metadata": {},
"source": [
"#### 4.2. Building the Decision Tree Classifier"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "a36a8169",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DecisionTreeClassifier(random_state=42)"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tree_classifier = DecisionTreeClassifier(criterion='gini', random_state=42)\n",
"tree_classifier.fit(X_train, y_train)"
]
},
{
"cell_type": "markdown",
"id": "5fe81c9b",
"metadata": {},
"source": [
"#### 4.3. Predict and Evaluate the Model"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "32549979",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy of the decision tree model: 1.0\n"
]
}
],
"source": [
"y_pred = tree_classifier.predict(X_test)\n",
"accuracy = accuracy_score(y_test, y_pred)\n",
"print(\"Accuracy of the decision tree model:\", accuracy)"
]
},
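{
"cell_type": "markdown",
"id": "7c1a0e05",
"metadata": {},
"source": [
"An accuracy of 1.0 measured on only 2 test points says little about how well the tree generalizes. A more stable check is k-fold cross-validation. The sketch below uses scikit-learn's `cross_val_score` with 5 folds (an illustrative choice, not taken from the book), so every sample is used for testing exactly once; no particular scores are claimed here, so run it to see them."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7c1a0e06",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"from sklearn.model_selection import cross_val_score\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"\n",
"# Same toy data as above (0 = blue, 1 = green)\n",
"X_all = pd.DataFrame({\n",
"    'x1': [0.25, 0.60, 0.71, 1.20, 1.75, 2.26, 2.50, 2.50, 2.88, 2.91],\n",
"    'x2': [1.41, 0.39, 1.29, 2.30, 0.59, 1.70, 1.35, 2.90, 0.61, 2.00],\n",
"})\n",
"y_all = pd.Series([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])\n",
"\n",
"# For a classifier, an integer cv means stratified folds: each of the 5 folds holds out 2 points\n",
"model = DecisionTreeClassifier(criterion='gini', random_state=42)\n",
"scores = cross_val_score(model, X_all, y_all, cv=5)\n",
"\n",
"print('accuracy per fold:', scores)\n",
"print('mean accuracy:', scores.mean())"
]
},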
{
"cell_type": "markdown",
"id": "46c97eed",
"metadata": {},
"source": [
"#### 4.4. Visualize the Tree (optional)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "e8a7ce47",
"metadata": {},
"outputs": [],
"source": [
"from sklearn import tree\n",
"import graphviz\n",
"\n",
"dot_data = tree.export_graphviz(tree_classifier, out_file=None, \n",
" feature_names=['x1', 'x2'], \n",
" class_names=['blue', 'green'],\n",
" filled=True, rounded=True, \n",
" special_characters=True) \n",
"graph = graphviz.Source(dot_data) "
]
},
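{
"cell_type": "markdown",
"id": "7c1a0e07",
"metadata": {},
"source": [
"Rendering with `graphviz` needs both the Python package and the Graphviz system binaries. If they are not installed, `sklearn.tree.export_text` prints the same splits as plain text. A minimal, self-contained sketch that refits the Gini tree on the same split (the `X_all`, `y_all`, `clf`, and `report` names are illustrative):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7c1a0e08",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.tree import DecisionTreeClassifier, export_text\n",
"\n",
"# Same toy data and train/test split as above (0 = blue, 1 = green)\n",
"X_all = pd.DataFrame({\n",
"    'x1': [0.25, 0.60, 0.71, 1.20, 1.75, 2.26, 2.50, 2.50, 2.88, 2.91],\n",
"    'x2': [1.41, 0.39, 1.29, 2.30, 0.59, 1.70, 1.35, 2.90, 0.61, 2.00],\n",
"})\n",
"y_all = pd.Series([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])\n",
"X_tr, X_te, y_tr, y_te = train_test_split(X_all, y_all, test_size=0.2, random_state=42)\n",
"\n",
"clf = DecisionTreeClassifier(criterion='gini', random_state=42).fit(X_tr, y_tr)\n",
"\n",
"# Plain-text view of the fitted tree; decimals=3 matches the precision of the Graphviz plot\n",
"report = export_text(clf, feature_names=['x1', 'x2'], decimals=3)\n",
"print(report)"
]
},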
{
"cell_type": "code",
"execution_count": 10,
"id": "928cbc1c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Decision tree (criterion='gini'):\n",
"[node 0] x1 ≤ 2.005 | gini = 0.5 | samples = 8 | value = [4, 4] | class = blue\n",
"  True  -> [node 1] gini = 0.0 | samples = 4 | value = [4, 0] | class = blue\n",
"  False -> [node 2] gini = 0.0 | samples = 4 | value = [0, 4] | class = green"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Visualize the decision tree\n",
"graph"
]
},
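{
"cell_type": "markdown",
"id": "7c1a0e09",
"metadata": {},
"source": [
"The root node above reports `gini = 0.5` for `value = [4, 4]`. Gini impurity is $G = 1 - \\sum_k p_k^2$, where $p_k$ is the share of class $k$ in the node, so a 50/50 node scores $1 - (0.5^2 + 0.5^2) = 0.5$ and a pure node scores 0. As a simplified sketch (an illustration, not scikit-learn's implementation), the cell below computes the impurity by hand and brute-forces every feature and midpoint threshold for the split with the lowest weighted Gini. It uses all 10 points rather than the 8 training points; on this data it should land on the same `x1 ≤ 2.005` split as the fitted tree. The helper names `gini` and `best_split` are hypothetical."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7c1a0e10",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Same toy data as the DataFrame above (0 = blue, 1 = green)\n",
"X_np = np.array([[0.25, 1.41], [0.60, 0.39], [0.71, 1.29], [1.20, 2.30], [1.75, 0.59],\n",
"                 [2.26, 1.70], [2.50, 1.35], [2.50, 2.90], [2.88, 0.61], [2.91, 2.00]])\n",
"y_np = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])\n",
"\n",
"def gini(labels):\n",
"    # Gini impurity: 1 - sum of squared class proportions (0 for an empty or pure node)\n",
"    if len(labels) == 0:\n",
"        return 0.0\n",
"    _, counts = np.unique(labels, return_counts=True)\n",
"    p = counts / counts.sum()\n",
"    return float(1.0 - np.sum(p ** 2))\n",
"\n",
"def best_split(X, y):\n",
"    # Try every feature and every midpoint between neighbouring sorted values;\n",
"    # keep the split whose children have the lowest sample-weighted Gini impurity\n",
"    best = (None, None, float('inf'))  # (feature index, threshold, weighted impurity)\n",
"    n = len(y)\n",
"    for j in range(X.shape[1]):\n",
"        values = np.unique(X[:, j])  # sorted, duplicates removed\n",
"        for t in (values[:-1] + values[1:]) / 2:\n",
"            left, right = y[X[:, j] <= t], y[X[:, j] > t]\n",
"            score = len(left) / n * gini(left) + len(right) / n * gini(right)\n",
"            if score < best[2]:\n",
"                best = (j, float(t), score)\n",
"    return best\n",
"\n",
"feature, threshold, impurity = best_split(X_np, y_np)\n",
"print('root Gini:', gini(y_np))\n",
"print('best split: x%d <= %.3f, weighted Gini = %.3f' % (feature + 1, threshold, impurity))"
]
},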
{
"cell_type": "markdown",
"id": "5246d56a",
"metadata": {},
"source": [
"## 5. Training a Decision Tree with Entropy\n",
"\n",
"#### 5.1. Splitting into X and y"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "c33aa131",
"metadata": {},
"outputs": [],
"source": [
"X = df[['x1', 'x2']]\n",
"y = df['Color']\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)"
]
},
{
"cell_type": "markdown",
"id": "4f28e878",
"metadata": {},
"source": [
"#### 5.2. Building the Decision Tree Classifier"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "7e70ffdf",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DecisionTreeClassifier(criterion='entropy', random_state=42)"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tree_classifier = DecisionTreeClassifier(criterion='entropy', random_state=42)\n",
"tree_classifier.fit(X_train, y_train)"
]
},
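{
"cell_type": "markdown",
"id": "7c1a0e11",
"metadata": {},
"source": [
"With `criterion='entropy'`, node impurity is measured in bits as $H = \\sum_k p_k \\log_2 (1/p_k)$. A 50/50 node has $H = 1$ and a pure node has $H = 0$, which is what the tree visualized at the end of this section should report for its root and leaves. The sketch below (the `entropy` helper is hypothetical) computes these values and the information gain of the `x1 ≤ 2.005` root split, using the class counts of the 8 training points (4 blue, 4 green at the root)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7c1a0e12",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def entropy(counts):\n",
"    # Shannon entropy in bits, H = sum_k p_k * log2(1 / p_k); empty classes contribute 0\n",
"    counts = np.asarray(counts, dtype=float)\n",
"    p = counts[counts > 0] / counts.sum()\n",
"    return float(np.sum(p * np.log2(1.0 / p)))\n",
"\n",
"parent = [4, 4]               # root node of the training set: 4 blue, 4 green\n",
"left, right = [4, 0], [0, 4]  # children of the x1 <= 2.005 root split\n",
"\n",
"n = sum(parent)\n",
"weighted_child_entropy = sum(left) / n * entropy(left) + sum(right) / n * entropy(right)\n",
"gain = entropy(parent) - weighted_child_entropy  # information gain of the split\n",
"\n",
"print('root entropy:', entropy(parent))\n",
"print('information gain:', gain)"
]
},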
{
"cell_type": "markdown",
"id": "ffbb9ce4",
"metadata": {},
"source": [
"#### 5.3. Predict and Evaluate the Model"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "e4b8c2f8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy of the decision tree model: 1.0\n"
]
}
],
"source": [
"y_pred = tree_classifier.predict(X_test)\n",
"accuracy = accuracy_score(y_test, y_pred)\n",
"print(\"Accuracy of the decision tree model:\", accuracy)"
]
},
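{
"cell_type": "markdown",
"id": "7c1a0e13",
"metadata": {},
"source": [
"Both criteria reach an accuracy of 1.0 on the test set. On a cleanly separable dataset like this one, Gini and entropy often pick the same splits, although on messier data they can differ. A hedged sketch that fits both on the same split and compares the learned structures through the fitted estimator's `tree_` attribute (its `feature`, `threshold`, and `node_count` fields; leaves are marked with `-2` in `feature`):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7c1a0e14",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"\n",
"# Same toy data and train/test split as above (0 = blue, 1 = green)\n",
"X_all = pd.DataFrame({\n",
"    'x1': [0.25, 0.60, 0.71, 1.20, 1.75, 2.26, 2.50, 2.50, 2.88, 2.91],\n",
"    'x2': [1.41, 0.39, 1.29, 2.30, 0.59, 1.70, 1.35, 2.90, 0.61, 2.00],\n",
"})\n",
"y_all = pd.Series([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])\n",
"X_tr, X_te, y_tr, y_te = train_test_split(X_all, y_all, test_size=0.2, random_state=42)\n",
"\n",
"# Fit one tree per splitting criterion\n",
"trees = {\n",
"    c: DecisionTreeClassifier(criterion=c, random_state=42).fit(X_tr, y_tr)\n",
"    for c in ('gini', 'entropy')\n",
"}\n",
"\n",
"for name, clf in trees.items():\n",
"    t = clf.tree_\n",
"    print(name, '| nodes:', t.node_count, '| root split: x%d <= %.3f' % (t.feature[0] + 1, t.threshold[0]))\n",
"\n",
"# Same split features and thresholds at every node means the two trees are structurally identical\n",
"same = (np.array_equal(trees['gini'].tree_.feature, trees['entropy'].tree_.feature)\n",
"        and np.allclose(trees['gini'].tree_.threshold, trees['entropy'].tree_.threshold))\n",
"print('identical structure:', same)"
]
},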
{
"cell_type": "markdown",
"id": "0dfbb74c",
"metadata": {},
"source": [
"#### 5.4. Visualize the Tree (optional)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "b6781767",
"metadata": {},
"outputs": [],
"source": [
"from sklearn import tree\n",
"import graphviz\n",
"\n",
"dot_data = tree.export_graphviz(tree_classifier, out_file=None, \n",
" feature_names=['x1', 'x2'], \n",
" class_names=['blue', 'green'],\n",
" filled=True, rounded=True, \n",
" special_characters=True) \n",
"graph = graphviz.Source(dot_data) "
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "cf58ec4e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Decision tree (criterion='entropy'):\n",
"[node 0] x1 ≤ 2.005 | entropy = 1.0 | samples = 8 | value = [4, 4] | class = blue\n",
"  True  -> [node 1] entropy = 0.0 | samples = 4 | value = [4, 0] | class = blue\n",
"  False -> [node 2] entropy = 0.0 | samples = 4 | value = [0, 4] | class = green"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Visualize the decision tree\n",
"graph"
]
}
],
"metadata": {
"jupytext": {
"formats": "md:myst",
"text_representation": {
"extension": ".md",
"format_name": "myst"
}
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.7"
},
"source_map": [
11,
30,
47,
52,
71,
76,
81,
83,
88,
92,
97,
101,
106,
109,
114,
118,
123,
136,
139,
146,
150,
155,
158,
163,
167,
172,
185
]
},
"nbformat": 4,
"nbformat_minor": 5
}