
Workshop: Table Representation Learning Workshop

GradTree: Learning Axis-Aligned Decision Trees with Gradient Descent

Sascha Marton · Stefan Lüdtke · Christian Bartelt · Heiner Stuckenschmidt

Keywords: [ decision trees ] [ Gradient Descent ]


Decision Trees (DTs) are commonly used for many machine learning tasks due to their high degree of interpretability. However, learning a DT from data is a difficult optimization problem, as it is non-convex and non-differentiable. Therefore, common approaches learn DTs using a greedy growth algorithm that minimizes the impurity locally at each internal node. Unfortunately, this greedy procedure can lead to inaccurate trees. In this paper, we present a novel approach for learning hard, axis-aligned DTs with gradient descent. The proposed method uses backpropagation with a straight-through operator on a dense DT representation to jointly optimize all tree parameters. Our approach outperforms existing methods on a wide range of binary classification benchmarks and is available under:
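
To make the straight-through idea concrete, here is a minimal toy sketch in plain Python. It is not the GradTree implementation: it trains the threshold of a single axis-aligned split (a decision stump) on one feature, with fixed leaf values 0 and 1 and a squared loss. The names `hard_split` and `train_stump`, the sigmoid relaxation, the toy data, and all hyperparameters are assumptions chosen for illustration. The forward pass uses the hard routing decision, while the backward pass replaces the zero gradient of the step function with the gradient of the underlying sigmoid.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def hard_split(x, threshold):
    """Forward pass: hard routing decision (1.0 = right leaf, 0.0 = left leaf).

    Also returns the soft value s, which is used only for the backward pass.
    """
    s = sigmoid(x - threshold)
    return (1.0 if s > 0.5 else 0.0), s


def train_stump(xs, ys, threshold=0.2, lr=0.1, steps=200):
    """Gradient descent on the split threshold with a straight-through gradient."""
    n = len(xs)
    for _ in range(steps):
        grad = 0.0
        for x, y in zip(xs, ys):
            d, s = hard_split(x, threshold)
            # Loss per sample is (d - y)^2, computed with the hard decision d.
            # Straight-through assumption: treat the hard step as the identity
            # in the backward pass, so d(d)/d(threshold) is approximated by
            # d(s)/d(threshold) = -s * (1 - s).
            grad += 2.0 * (d - y) * (-s * (1.0 - s))
        threshold -= lr * grad / n
    return threshold


# Toy data (hypothetical): the label is 1 for the two largest feature values.
xs = [0.1, 0.3, 0.5, 0.7, 0.9]
ys = [0.0, 0.0, 0.0, 1.0, 1.0]

learned = train_stump(xs, ys)
preds = [hard_split(x, learned)[0] for x in xs]
```

In this sketch the gradient vanishes as soon as every sample is routed correctly, so the threshold stops moving once it separates the two classes. A full tree would additionally have to learn which feature each node splits on and the leaf values, which this example deliberately leaves out.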
