Neural Classification Trees Disentangle Latent Subgroups for Robust ML
Machine learning models often exploit spurious correlations, which yields high average accuracy but poor performance on underrepresented subgroups. Existing mitigation strategies typically adjust network parameters using subgroup annotations or inferred pseudo-labels. However, these methods generally output only a class prediction at inference time and give no insight into a sample's latent subgroup structure. To address this, the authors propose Neural Classification Trees (NCT), a framework that encodes subgroup structure in its tree-shaped architecture. NCT routes each sample to an easy or a hard node based on prediction correctness and reuses these routes as pseudo-labels in subsequent iterations. This process disentangles conflicting subgroups without requiring explicit subgroup supervision. The approach is evaluated on five benchmarks spanning binary and multi-class spurious correlations. The experiments show that the learned tree topology isolates minority subgroups, which makes the model's decisions interpretable while keeping its robustness competitive with state-of-the-art methods.
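The routing idea described above can be illustrated with a small sketch. It is not the authors' implementation: the summary does not specify how routes are stored or how the network is retrained between iterations, so the function names (`route`, `extend_paths`, `group_by_path`), the tuple-based path encoding, the reading of "easy" as "predicted correctly", and the hard-coded predictions are all assumptions made for illustration.

```python
from collections import defaultdict


def route(predictions, labels):
    """Assumed rule: 'easy' if the prediction matches the label, else 'hard'."""
    return ["easy" if p == y else "hard" for p, y in zip(predictions, labels)]


def extend_paths(paths, predictions, labels):
    """Append this iteration's route to each sample's path through the tree."""
    return [path + (r,) for path, r in zip(paths, route(predictions, labels))]


def group_by_path(paths):
    """Group sample indices by path; each group acts as one pseudo-label."""
    groups = defaultdict(list)
    for idx, path in enumerate(paths):
        groups[path].append(idx)
    return dict(groups)


# Toy data. In a real pipeline the predictions would come from a model that
# is retrained between iterations; here they are hard-coded (hypothetical).
labels = [0, 0, 1, 1, 0, 1]
round1_preds = [0, 0, 1, 0, 1, 1]
round2_preds = [0, 0, 1, 1, 1, 1]

paths = [() for _ in labels]
paths = extend_paths(paths, round1_preds, labels)
paths = extend_paths(paths, round2_preds, labels)

pseudo_groups = group_by_path(paths)
for path, members in sorted(pseudo_groups.items()):
    print(path, members)
```

In this toy run, sample 4 is misclassified in both rounds and ends up alone on the hard-hard path, which is the kind of leaf where a minority subgroup would be expected to collect under these assumptions.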