Domain-Invariant Feature Learning for Patient-Level Phenotype Prediction from Single-Cell Data
Abstract
Accurate prediction of patient-level disease status from single-cell RNA sequencing (scRNA-seq) data is critical for precision diagnostics. However, study-specific artifacts induce spurious correlations that limit generalization and interpretability. We study this problem in the context of Multiple Instance Learning (MIL), a framework in which each patient is modeled as a set of single-cell profiles. To improve robustness to domain shift, we propose an adversarial and metric-based approach that learns domain-invariant representations while preserving task-relevant biological variation. We benchmark our method on a systemic lupus erythematosus (SLE) dataset with synthetically added spurious features and evaluate it on two real-world scRNA-seq atlases: a cross-tissue immune dataset and a COVID-19 severity atlas. Across all settings, we observe consistent improvements in out-of-domain accuracy and more biologically faithful model attributions. Our findings establish a new standard for robust, interpretable patient-level prediction from scRNA-seq data under domain shift.