Introducing ikNN: An Interpretable k Nearest Neighbors Model
In machine learning, obtaining the most accurate model is often the primary objective. However, in some cases, interpretability is equally crucial. While models like XGBoost, CatBoost, and LGBM offer high accuracy, they function as black-box models, making it challenging to understand their predictions and behavior on unseen data. This article introduces a new interpretable model called ikNN, or interpretable k Nearest Neighbors, which offers a balance between accuracy and interpretability.
The Need for Interpretability
In many contexts, black-box models are sufficient as long as they are reasonably accurate. For instance, a model predicting which ads will generate sales on a website can tolerate occasional inaccuracies without significant consequences. However, in high-stakes environments such as medicine and security, understanding why models make specific predictions is essential. It is also critical in environments that require auditing to ensure there are no biases related to race, gender, or other protected classes.
Explainable AI (XAI)
Explainable AI (XAI) provides post-hoc analysis to explain the predictions of black-box models. Techniques such as proxy models, feature importances (e.g., SHAP), counterfactuals, and ALE plots are useful tools in XAI. However, these methods have limitations, and having an inherently interpretable model is preferable when possible.
Proxy Models
Proxy models involve training an interpretable model (e.g., a shallow decision tree) to mimic the behavior of a black-box model. While this approach provides some insight, the proxy only approximates the black-box model, so its explanations may not be faithful.
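As a rough illustration of the idea (this is a sketch using plain scikit-learn, not any particular XAI library's code), a random forest can stand in for the black-box model and a shallow tree can be trained to mimic its predictions:

from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier, export_text

# Sketch only: the random forest plays the role of the black-box model.
X, y = load_iris(return_X_y=True)
black_box = RandomForestClassifier(random_state=0).fit(X, y)

# The proxy tree is trained on the black box's predictions, not the true
# labels, so it approximates (only approximately) the black box's behavior.
proxy = DecisionTreeClassifier(max_depth=3, random_state=0)
proxy.fit(X, black_box.predict(X))
print(export_text(proxy, feature_names=load_iris().feature_names))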
Feature Importances
Feature importances indicate the relevant features in a model but do not explain how these features relate to the prediction or how they interact. They also cannot predict the model’s performance on unseen data.
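As a small illustration of what feature importances do and do not provide, the sketch below uses scikit-learn's permutation importances rather than SHAP, but the output has the same one-number-per-feature form: a ranking of features, with nothing said about how they interact or combine in any given prediction.

from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance

# One importance score per feature: useful for ranking features, but it
# says nothing about how the features relate to individual predictions.
data = load_iris()
model = RandomForestClassifier(random_state=0).fit(data.data, data.target)
result = permutation_importance(model, data.data, data.target,
                                n_repeats=10, random_state=0)
for name, score in zip(data.feature_names, result.importances_mean):
    print(f"{name}: {score:.3f}")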
Interpretable Models
Interpretable models, such as decision trees, rule lists, GAMs (Generalized Additive Models), and linear/logistic regression, provide clear insights into their predictions. However, they often have lower accuracy compared to black-box models. The challenge is finding an interpretable model that offers sufficient accuracy for the given problem.
Introducing ikNN
ikNN is an interpretable model based on an ensemble of 2d kNN models. While not competitive with state-of-the-art models like CatBoost, ikNN offers accuracy that is often sufficient for many problems and is competitive with other interpretable models like decision trees. Interestingly, ikNN tends to outperform plain kNN models.
Key Features of ikNN
- Ensembling: Using an ensemble of 2d kNN models increases the reliability of predictions.
- Visualizability: 2d spaces are straightforward to visualize, providing high interpretability.
- Predictive Power: The model considers the most predictive 2d subspaces for each prediction, weighted by their accuracy on the training data.
Using ikNN
The ikNN model can be included in any project by copying the interpretable_knn.py file and importing the ikNNClassifier class. It provides an interface consistent with scikit-learn classifiers.
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from interpretable_knn import ikNNClassifier

# Load the iris data and hold out a test set.
iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.33, random_state=42)

# Fit the ikNN classifier and predict on the held-out rows.
clf = ikNNClassifier()
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)
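Because the interface follows scikit-learn conventions, the predictions from the code above can be scored with the usual scikit-learn metrics, for example the macro F1 score used in the accuracy tests described later:

from sklearn.metrics import f1_score

# Standard scikit-learn evaluation of the held-out predictions.
print(f1_score(y_test, y_pred, average='macro'))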
The ikNN Algorithm
ikNN creates a standard 2d kNN classifier for each pair of features in the dataset. For example, a table with 10 features results in 45 models, one for each unique pair of features. The ikNN model then assesses the accuracy of each 2d subspace and uses the most predictive subspaces for making predictions.
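As a simplified sketch of this idea (not the actual interpretable_knn.py implementation, which additionally restricts predictions to the most predictive subspaces), each feature pair gets its own 2d kNN, each pair is scored on the training data, and predictions are combined as a vote weighted by those scores:

from itertools import combinations
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score, train_test_split
from sklearn.neighbors import KNeighborsClassifier

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

# One 2d kNN per pair of features, each scored on the training data.
subspaces = []
for i, j in combinations(range(X_train.shape[1]), 2):
    knn = KNeighborsClassifier()
    score = cross_val_score(knn, X_train[:, [i, j]], y_train).mean()
    knn.fit(X_train[:, [i, j]], y_train)
    subspaces.append(((i, j), knn, score))

# Combine the per-subspace predictions as a vote weighted by each
# subspace's training accuracy.
votes = np.zeros((len(X_test), len(np.unique(y_train))))
for (i, j), knn, score in subspaces:
    preds = knn.predict(X_test[:, [i, j]])
    votes[np.arange(len(X_test)), preds] += score
print((votes.argmax(axis=1) == y_test).mean())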
Visualization and Interpretability
ikNN provides tools for understanding the model through the graph_model() and graph_predictions() APIs. These tools allow users to visualize the data space and the predictions made by the model.
ikNN.graph_model(X.columns)
This function provides an overview of the data space by plotting five 2d spaces, showing the classes of the training data and the predictions made by the 2d kNN for each region of each space.
ikNN.graph_predictions(X_test[0])
This function explains the prediction for a specific row, showing where the row is located relative to the training data and the predictions made by the 2d kNN in each plotted 2d space.
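The plots themselves come from the library, but a rough sense of what they convey can be sketched with matplotlib: one 2d feature pair, the training points colored by class, the regions the 2d kNN would assign shaded in the background, and (as in graph_predictions()) the row being explained highlighted:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.neighbors import KNeighborsClassifier

# Sketch of a single panel: decision regions of one 2d kNN plus the
# training points, with one query row marked. The real graph_model()
# and graph_predictions() plots are produced by the library itself.
X, y = load_iris(return_X_y=True)
X2 = X[:, [0, 1]]                              # one feature pair
knn = KNeighborsClassifier().fit(X2, y)

xx, yy = np.meshgrid(np.linspace(X2[:, 0].min(), X2[:, 0].max(), 200),
                     np.linspace(X2[:, 1].min(), X2[:, 1].max(), 200))
regions = knn.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)

plt.contourf(xx, yy, regions, alpha=0.3)
plt.scatter(X2[:, 0], X2[:, 1], c=y, edgecolor='k')
plt.scatter(*X2[0], marker='*', s=300, c='red')   # the row being explained
plt.xlabel('sepal length (cm)')
plt.ylabel('sepal width (cm)')
plt.show()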
Example: Iris Dataset
The code shown above already demonstrates the ikNN model on the iris dataset: fitting the classifier and calling predict() is all that is required for prediction.
The graph_model() and graph_predictions() APIs provide further insight into the fitted model.
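A minimal continuation of the example is sketched below. It assumes X is held in a pandas DataFrame so that column names are available to the plots (the raw iris arrays have none), and it assumes graph_predictions() accepts a single DataFrame row; both are assumptions about the expected inputs rather than confirmed details of the API.

import pandas as pd

# Refit on a DataFrame so feature names can be passed to the plotting APIs.
X = pd.DataFrame(iris.data, columns=iris.feature_names)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.33, random_state=42)
clf = ikNNClassifier()
clf.fit(X_train, y_train)

clf.graph_model(X.columns)              # overview of the most predictive 2d spaces
clf.graph_predictions(X_test.iloc[0])   # where one test row sits in those spaces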
Accuracy Tests
Tests using 100 classification datasets from OpenML show that ikNN outperforms standard kNN models in terms of macro F1 score on 58 of the 100 datasets. When grid search is used to tune the hyperparameters of both models, ikNN performs best in 76 of the 100 cases. ikNN models also tend to show smaller gaps between train and test scores, indicating more stable models.
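A rough sketch of a single step of such a comparison is shown below; the OpenML dataset, preprocessing, and default hyperparameters here are illustrative only, whereas the actual tests covered 100 datasets and included grid search.

from sklearn.datasets import fetch_openml
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import LabelEncoder
from interpretable_knn import ikNNClassifier

# Compare plain kNN and ikNN on a single OpenML dataset using macro F1.
data = fetch_openml('diabetes', version=1, as_frame=True)
X = data.data.to_numpy()
y = LabelEncoder().fit_transform(data.target)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

for name, model in [('kNN', KNeighborsClassifier()), ('ikNN', ikNNClassifier())]:
    model.fit(X_train, y_train)
    print(name, round(f1_score(y_test, model.predict(X_test), average='macro'), 3))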
Conclusion
While ikNN may not always be the strongest model in terms of accuracy, it is a valuable model to consider when interpretability is crucial. The ability to visualize and understand predictions makes ikNN an excellent choice for high-stakes environments where understanding model behavior is essential.