Author: Bijan Naimi
According to the CDC, approximately 697,000 people die of heart disease each year in the United States alone. That makes heart disease the cause of about 1 in every 5 deaths in the US. The fact of the matter is, however, that a change in lifestyle can very often reduce a person's risk of heart disease. Also, if someone already has undiagnosed heart disease, the lack of treatment can be unnecessarily fatal. The problem at hand is to identify individuals with heart disease using basic health data.
Heart-related illnesses are often referred to as the "silent killer": it is very often the case that people who suffer from them do not realize their condition until it is too late. This is the main motivation behind choosing this topic as my final project: can we effectively classify patients with heart disease based on basic data pertaining to their health? In addition, it would be a bonus if I could identify a leading factor that suggests someone is likely to have heart disease.
I found the following dataset to use for this project: https://archive.ics.uci.edu/ml/datasets/heart+disease
I also found that this exact data was uploaded to kaggle.com at the following link (the original uploader mislabeled the dataset metadata / title): https://www.kaggle.com/datasets/rashikrahmanpritom/heart-attack-analysis-prediction-dataset
Instead of wrangling the data from the former link, I used the CSV from the second one and read it into pandas below.
I first read the CSV into pandas and took a look at the data in a more organized form than just a raw CSV.
import pandas as pd
# Read into pandas
df = pd.read_csv('./heart.csv')
df
 | age | sex | cp | trtbps | chol | fbs | restecg | thalachh | exng | oldpeak | slp | caa | thall | output |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 63 | 1 | 3 | 145 | 233 | 1 | 0 | 150 | 0 | 2.3 | 0 | 0 | 1 | 1 |
1 | 37 | 1 | 2 | 130 | 250 | 0 | 1 | 187 | 0 | 3.5 | 0 | 0 | 2 | 1 |
2 | 41 | 0 | 1 | 130 | 204 | 0 | 0 | 172 | 0 | 1.4 | 2 | 0 | 2 | 1 |
3 | 56 | 1 | 1 | 120 | 236 | 0 | 1 | 178 | 0 | 0.8 | 2 | 0 | 2 | 1 |
4 | 57 | 0 | 0 | 120 | 354 | 0 | 1 | 163 | 1 | 0.6 | 2 | 0 | 2 | 1 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
298 | 57 | 0 | 0 | 140 | 241 | 0 | 1 | 123 | 1 | 0.2 | 1 | 0 | 3 | 0 |
299 | 45 | 1 | 3 | 110 | 264 | 0 | 1 | 132 | 0 | 1.2 | 1 | 0 | 3 | 0 |
300 | 68 | 1 | 0 | 144 | 193 | 1 | 1 | 141 | 0 | 3.4 | 1 | 2 | 3 | 0 |
301 | 57 | 1 | 0 | 130 | 131 | 0 | 1 | 115 | 1 | 1.2 | 1 | 1 | 3 | 0 |
302 | 57 | 0 | 1 | 130 | 236 | 0 | 0 | 174 | 0 | 0.0 | 1 | 1 | 2 | 0 |
303 rows × 14 columns
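Since the table above only shows the first and last few rows, a quick summary of the column types and basic statistics could also be pulled with standard pandas calls, for example:
# Quick structural summary of the dataframe loaded above
print(df.shape)       # number of rows and columns
print(df.dtypes)      # data type of each column
print(df.describe())  # count, mean, std, min, quartiles, and max for each column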
Here is a breakdown of the significance of each column:
age: Age of the observation (years)
sex: Sex of the observation (0 or 1)
cp: Chest pain type (0, 1, 2, 3)
trtbps: Resting blood pressure (mm Hg)
chol: Cholesterol (mg/dl)
fbs: Whether or not the observation has high fasting blood sugar (0 or 1)
restecg: Resting electrocardiographic results (0, 1, 2)
thalachh: Maximum heartrate achieved
exng: Exercise induced angina (0 or 1)
oldpeak: Previous peak (ST depression relative to rest)
slp: Slope of the peak exercise ST segment (0, 1, 2)
caa: Number of major blood vessels (0, 1, 2, 3)
thall: Thalassemia type (1, 2, 3)
output: Heart disease diagnosis (0 = negative, 1 = positive)
I noticed how some data was converted from quantitative to qualitative by the creator of the dataset; for example, the fasting blood sugar attribute is 1 if it is higher than 120 mg/dl and 0 otherwise. I felt it was important to make this distinction. I also confirmed there are no missing values in the data, as seen below.
df = df.copy().dropna()
df
 | age | sex | cp | trtbps | chol | fbs | restecg | thalachh | exng | oldpeak | slp | caa | thall | output |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 63 | 1 | 3 | 145 | 233 | 1 | 0 | 150 | 0 | 2.3 | 0 | 0 | 1 | 1 |
1 | 37 | 1 | 2 | 130 | 250 | 0 | 1 | 187 | 0 | 3.5 | 0 | 0 | 2 | 1 |
2 | 41 | 0 | 1 | 130 | 204 | 0 | 0 | 172 | 0 | 1.4 | 2 | 0 | 2 | 1 |
3 | 56 | 1 | 1 | 120 | 236 | 0 | 1 | 178 | 0 | 0.8 | 2 | 0 | 2 | 1 |
4 | 57 | 0 | 0 | 120 | 354 | 0 | 1 | 163 | 1 | 0.6 | 2 | 0 | 2 | 1 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
298 | 57 | 0 | 0 | 140 | 241 | 0 | 1 | 123 | 1 | 0.2 | 1 | 0 | 3 | 0 |
299 | 45 | 1 | 3 | 110 | 264 | 0 | 1 | 132 | 0 | 1.2 | 1 | 0 | 3 | 0 |
300 | 68 | 1 | 0 | 144 | 193 | 1 | 1 | 141 | 0 | 3.4 | 1 | 2 | 3 | 0 |
301 | 57 | 1 | 0 | 130 | 131 | 0 | 1 | 115 | 1 | 1.2 | 1 | 1 | 3 | 0 |
302 | 57 | 0 | 1 | 130 | 236 | 0 | 0 | 174 | 0 | 0.0 | 1 | 1 | 2 | 0 |
303 rows × 14 columns
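As an additional check, the missing values per column could also be counted directly; something like the following should print a count of zero for every column if nothing is missing:
# Count missing values in each column (every count should be 0)
print(df.isna().sum())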
To get a better idea of how everything is distributed, I made a histogram of each attribute using matplotlib.
import numpy as np
import matplotlib.pyplot as plt
# Setting a universal figure size
plt.rcParams['figure.figsize'] = [5.0, 5.0]
plt.rcParams['figure.dpi'] = 140
# For each column make one plot
for col in df.columns:
    # Create a histogram from this column's data
    plt.hist(df[col])
    plt.title('Histogram for Column "' + col + '"')
    plt.xlabel("Value")
    plt.ylabel("Frequency")
    plt.show()
I did not see anything out of the norm in these distributions. They mainly showed me the frequencies of some of the categorical values I was not familiar with before, which will be useful later. To get an idea of how the attributes interact with one another, I decided to make a heatmap as the next step in my EDA. I used seaborn, a data visualization library, to do so.
import seaborn as sns
# Get a correlation matrix from the dataframe and use it in seaborn
corr_matrix = df.corr()
ax = plt.axes()
# Use correlation matrix in heatmap
sns.heatmap(corr_matrix, annot=False , linewidths=.5, ax=ax)
ax.set_title('Correlation Heatmap for Data Attributes')
[Correlation heatmap for data attributes]
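To read the heatmap more precisely, the correlations with the output column could also be pulled out and sorted, for example:
# Correlation of each attribute with the diagnosis, strongest positive first
print(corr_matrix['output'].drop('output').sort_values(ascending=False))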
I was pretty surprised to see which variables correlated with the output variable and which did not.
I expected a correlation with maximum heartrate achieved, which there was, but I thought it was really interesting that the type of chest pain was one of the most strongly correlated variables with heart disease.
On the other hand, I intuitively thought that things like age and blood pressure would play more of a role in heart disease, but these were only weakly correlated.
Since maximum heartrate achieved (thalachh) was another strong correlate of the output, I dug deeper and looked at how it linearly relates to the other quantitative variables.
from sklearn.linear_model import LinearRegression
# Maximum heartrate achieved will be the x axis
x_heartrate = df['thalachh']
# Reshape for sklearn, which expects a 2D feature array
x = np.array(x_heartrate).reshape((-1, 1))
# List of quantitative variables for the y axis
y_list = ['age', 'chol', 'trtbps', 'oldpeak']
for curr in y_list:
    # Get the column for the current variable
    y = np.array(df[curr])
    # Fit a linear regression model
    reg = LinearRegression().fit(x, y)
    y_pred = reg.intercept_ + (reg.coef_[0] * x)
    # Plot the fitted regression line and the data points
    plt.plot(x, y_pred, '-r')
    plt.plot(x_heartrate, df[curr], '.')
    plt.xlabel("Maximum Heartrate Achieved")
    plt.ylabel(curr)
    plt.title("Data in Column " + curr + " Against Maximum Heartrate Achieved")
    plt.show()
Up until this point, I had not directly distinguished between positive and negative diagnoses when looking at the variables, so I decided to use violin plots to do so, in order to see if there is anything anomalous in their distributions when separated this way.
# Separate each column's values based on the diagnosis classification
# EXCLUDE THE OUTPUT VARIABLE
cols = list(df.columns)[:-1]
for c in cols:
    # Reset the per-diagnosis value lists for each column
    data_for_target = {
        0: [],
        1: []
    }
    for index, row in df.iterrows():
        data_for_target[row['output']].append(row[c])
    plt.violinplot([data_for_target[0], data_for_target[1]], [0, 1])
    plt.title("Distributions of " + c + " for Each Heart Disease Diagnosis")
    plt.ylabel("Values")
    plt.xlabel("Heart Disease Diagnosis (0 = Negative, 1 = Positive)")
    plt.show()
From these graphs we can see that the distributions of these variables are almost identical when separated between positive and negative diagnoses.
When performing the linear regression on the quantitative variables in the data, I noticed a correlation between age and maximum heartrate achieved. My thinking is that if I could find correlations between quantitative values in the data and a variable that is strongly associated with the heart disease diagnosis, this information could be used to help predict the classification. To further investigate this, I performed a hypothesis test on each pair's linear correlation using the statsmodels library, fitting the models with the least squares method.
$$H_{0} : \beta_1 = 0$$
$$H_{a} : \beta_1 \ne 0$$
The null hypothesis is essentially that the linear model does not work, and the alternative hypothesis is that the slope ($\beta_1$) is non-zero.
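For reference, the statistic statsmodels reports for each slope is the usual least squares t-statistic,
$$t = \frac{\hat{\beta}_1}{SE(\hat{\beta}_1)}$$
which is compared against a t distribution with n - 2 degrees of freedom to obtain the p-value for the slope.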
import statsmodels.api as sm
# List of variables to run hypothesis tests against maximum heartrate achieved
y_list = ['age', 'chol', 'trtbps', 'oldpeak']
for curr in y_list:
    X = list(df['thalachh'])
    X = sm.add_constant(X)
    y = list(df[curr])
    # Run the least squares fit for the hypothesis test
    model = sm.OLS(y, X)
    results = model.fit()
    # Print out the data we want from the test
    print(results.summary())
The four OLS summaries (one per variable regressed against thalachh) are condensed here to the statistics relevant to the hypothesis test:

Dependent variable | R² | Slope coefficient on thalachh | p-value for slope
---|---|---|---
age | 0.159 | -0.1580 | < 0.001
chol | 0.000 | -0.0225 | 0.863
trtbps | 0.002 | -0.0358 | 0.418
oldpeak | 0.118 | -0.0174 | < 0.001
We can reject the null hypothesis for age and oldpeak at the 0.05 significance level, since the p-values for their slopes are below 0.05. This means there is a statistically significant linear relationship between each of these variables and maximum heartrate achieved. For chol and trtbps, on the other hand, we fail to reject the null hypothesis.
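Rather than reading the p-values off the printed summaries, they could also be collected programmatically; for example, reusing the loop above:
# Collect the p-value of the slope coefficient for each regression against thalachh
for curr in ['age', 'chol', 'trtbps', 'oldpeak']:
    X = sm.add_constant(list(df['thalachh']))
    results = sm.OLS(list(df[curr]), X).fit()
    print(curr, results.pvalues[1])  # index 1 is the slope (index 0 is the constant)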
Finally, I am going to see if I can produce a model to accurately classify an observation's heart disease diagnosis based on the features I have analyzed thus far. I will be using 33 percent of the data at hand as testing data and the rest to train the models.
As for the models, I will be using Logistic Regression, Decision Trees, and Random Forest. To evaluate them, we will use both accuracy and false negatives, since classifying someone as not having heart disease when they do is pretty severe.
In the following code, I separate the target column from everything else and split the data into training and testing sets using sklearn's built-in functions.
X = []
y = []
cols = list(df.columns)[:-1]  # cols is all of the non-target columns
for index, row in df.iterrows():
    to_append = []
    for c in cols:
        to_append.append(row[c])
    X.append(to_append)  # All of the data in cols
    y.append(row['output'])  # Our target variable
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42) # Training data and testing data split
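As a side note, the same feature matrix and target vector could be built more compactly with pandas indexing instead of the row-by-row loop, for example:
# Equivalent, vectorized construction of the features and target
X_alt = df[cols].values      # every column except 'output'
y_alt = df['output'].values  # the diagnosis column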
Now that I have training and testing data, I will train a logistic regression model using sklearn. Logistic regression is well suited to this task, since it takes in both categorical and continuous quantitative values and produces a binary result, which is exactly what we want.
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
clf = LogisticRegression()  # Using logistic regression
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)  # Fit the scaler on the training data and transform it
X_test = scaler.transform(X_test)  # Transform the test data with the same scaler
clf.fit(X_train, y_train)  # Fit the model using the training data
result = clf.predict(X_test)  # Get predictions on the test data
acc = 0
false_neg = 0
for i in range(len(result)):
    if (y_test[i] == 1 and result[i] == 0) == False:  # NOT a false negative
        false_neg += 1
    if y_test[i] == result[i]:
        acc += 1
print("Accuracy", acc / len(result))
print("(Not a) False Negative Rate", false_neg / len(result))
Accuracy 0.81
(Not a) False Negative Rate 0.9
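These two numbers could also be cross-checked with sklearn's confusion matrix; for example, using the y_test and result arrays from above:
from sklearn.metrics import confusion_matrix
# ravel() unpacks the 2x2 matrix into true negatives, false positives, false negatives, true positives
tn, fp, fn, tp = confusion_matrix(y_test, result).ravel()
print("Accuracy", (tn + tp) / len(result))
print("(Not a) False Negative Rate", 1 - fn / len(result))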
I also tried using a decision tree. Decision trees are a classification algorithm that works more organically: the data points are repeatedly split into groups using the different features we supply to the model.
from sklearn import tree
clf = tree.DecisionTreeClassifier()  # Using a decision tree
clf.fit(X_train, y_train)
result = clf.predict(X_test)
acc = 0
false_neg = 0
for i in range(len(result)):
    if (y_test[i] == 1 and result[i] == 0) == False:  # NOT a false negative
        false_neg += 1
    if y_test[i] == result[i]:
        acc += 1
print("Accuracy", acc / len(result))
print("(Not a) False Negative Rate", false_neg / len(result))
tree.plot_tree(clf)
Accuracy 0.74
(Not a) False Negative Rate 0.83
[Rendering of the fitted decision tree (tree.plot_tree output)]
The issue with a single decision tree is that it forces splits on variables that, as we saw in the exploratory data analysis, are not necessarily correlated with the diagnosis. Random forest classifiers mitigate this problem by training many decision trees on random subsets of the data and features and combining their predictions, rather than relying on any single tree.
from sklearn.ensemble import RandomForestClassifier
clf = RandomForestClassifier()  # Using a random forest
clf.fit(X_train, y_train)
result = clf.predict(X_test)
acc = 0
false_neg = 0
for i in range(len(result)):
    if (y_test[i] == 1 and result[i] == 0) == False:  # NOT a false negative
        false_neg += 1
    if y_test[i] == result[i]:
        acc += 1
print("Accuracy", acc / len(result))
print("(Not a) False Negative Rate", false_neg / len(result))
# Plot the first tree in the forest; use cols so the feature names match the 13 input features
tree.plot_tree(clf.estimators_[0], feature_names=cols, filled=True)
Accuracy 0.83
(Not a) False Negative Rate 0.92
[Rendering of the first tree in the fitted random forest (tree.plot_tree output)]
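To relate the random forest back to the exploratory analysis, its feature importances could also be inspected; for example, with the fitted clf from above:
# Mean impurity-based importance of each feature across the trees in the forest
importances = pd.Series(clf.feature_importances_, index=cols).sort_values(ascending=False)
print(importances)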
From the resulting evaluations, we can see that the decision tree performs noticeably worse than the random forest and logistic regression. Random forest and logistic regression are close in accuracy, with the random forest doing slightly better on false negatives.
I was able to create models that predict heart disease in our observations with roughly 80-90% accuracy. I also recognized that when comparing similarly accurate models, the false negative rate should be weighted more heavily than other evaluation metrics, since falsely telling someone they do not have heart disease is worse than the opposite.
I also found correlations among variables in the dataset other than the diagnosis variable. Concerning the diagnosis itself, however, as suggested by the exploratory data analysis and the structure of the random forest model, the leading factor pointing to heart disease is maximum heartrate achieved, that is, the highest heartrate the body produces when under maximum stress.