GradTree: Learning Axis-Aligned Decision Trees with Gradient Descent
Published in AAAI Conference on Artificial Intelligence, 2024
Decision Trees (DTs) are commonly used for many machine learning tasks due to their high degree of interpretability. However, learning a DT from data is a difficult optimization problem, as it is non-convex and non-differentiable. Therefore, common approaches learn DTs using a greedy growth algorithm that minimizes the impurity locally at each internal node. Unfortunately, this greedy procedure can lead to inaccurate trees. In this paper, we present a novel approach for learning hard, axis-aligned DTs with gradient descent. The proposed method uses backpropagation with a straight-through operator on a dense DT representation to jointly optimize all tree parameters. Our approach outperforms existing methods on binary classification benchmarks and achieves competitive results on multi-class tasks. The implementation is available at: https://github.com/s-marton/GradTree
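To make the two key ideas in the abstract concrete (a dense DT representation and a straight-through operator), here is a minimal NumPy sketch under stated assumptions. It is not the authors' implementation. The function names (`split_with_ste`, `dense_tree_forward`, `stump_threshold_grad`) are hypothetical. The choices of sigmoid plus rounding for the hard split, argmax for hard feature selection, "1 means go right", and heap-style node indexing are all assumptions made for illustration. The tree is stored densely as per-node feature logits, per-node thresholds, and per-leaf values. The forward pass routes every sample to exactly one leaf. The straight-through estimator treats the rounding step as the identity during backpropagation, so thresholds still receive a nonzero surrogate gradient.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def split_with_ste(z):
    """Hard split decision plus its straight-through surrogate gradient.

    Forward: round(sigmoid(z)), a hard 0/1 decision (1 = go right, an assumption).
    Backward: treat round() as the identity, so d(hard)/dz ~ s * (1 - s).
    Note: np.round(0.5) == 0.0, so z == 0 goes left in this sketch.
    """
    s = sigmoid(z)
    hard = np.round(s)
    surrogate_grad = s * (1.0 - s)
    return hard, surrogate_grad


def dense_tree_forward(X, feature_logits, thresholds, leaf_values):
    """Hard forward pass of a complete binary tree stored as dense arrays.

    feature_logits: (n_internal, n_features), argmax picks each node's feature
    thresholds:     (n_internal,)
    leaf_values:    (n_leaves,), with n_leaves = n_internal + 1
    Nodes use heap indexing: root 0, children 2n+1 (left) and 2n+2 (right).
    """
    n_internal = thresholds.shape[0]
    depth = (n_internal + 1).bit_length() - 1
    n_features = feature_logits.shape[1]
    # Hard, axis-aligned feature selection: one-hot row per internal node.
    one_hot = np.eye(n_features)[np.argmax(feature_logits, axis=1)]
    # z[i, j] = x_i[feature of node j] - threshold of node j
    z = X @ one_hot.T - thresholds
    go_right, _ = split_with_ste(z)

    n_leaves = 2 ** depth
    leaf_weight = np.ones((X.shape[0], n_leaves))
    for leaf in range(n_leaves):
        node = 0
        for level in range(depth):
            # Bit of the leaf index (MSB first) gives the branch taken at this level.
            bit = (leaf >> (depth - 1 - level)) & 1
            p = go_right[:, node] if bit else 1.0 - go_right[:, node]
            leaf_weight[:, leaf] *= p
            node = 2 * node + 1 + bit
    # With hard splits, each row of leaf_weight is one-hot.
    return leaf_weight @ leaf_values, leaf_weight


def stump_threshold_grad(x, tau, v_left, v_right):
    """Output of a single hard split and d(output)/d(tau) under the STE."""
    hard, surrogate = split_with_ste(x - tau)
    output = hard * v_right + (1.0 - hard) * v_left
    # Chain rule: d out/d hard = v_right - v_left, d hard/dz ~ surrogate, dz/d tau = -1.
    grad_tau = (v_right - v_left) * surrogate * -1.0
    return output, grad_tau
```

For example, a depth-2 tree with feature logits `[[2, 0], [0, 1], [-1, 3]]` splits on feature 0 at the root and on feature 1 at both children. It routes each sample to a single leaf. For a stump evaluated exactly at its threshold, the true gradient of the hard step is zero almost everywhere. The straight-through surrogate instead yields `-0.25 * (v_right - v_left)`, which is what lets gradient descent move the threshold.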
Recommended citation: Marton, Sascha, et al. (2024). "GradTree: Learning Axis-Aligned Decision Trees with Gradient Descent." Proceedings of the AAAI Conference on Artificial Intelligence. 38(13).