Workshop: Causal Representation Learning

Curvature and Causal Inference in Network Data

Amirhossein Farzam · Allen Tannenbaum · Guillermo Sapiro

Keywords: [ Geometric Deep Learning ] [ Curvature ] [ Graph Neural Networks ] [ Causal Representation Learning ]


Learning causal mechanisms involving networked units of data is a notoriously challenging task with various applications. Graph Neural Networks (GNNs) have proven effective for learning representations that capture complex dependencies between data units. This effectiveness is largely due to the amenability of GNNs to tools that characterize the geometry of graphs. The potential of geometric deep learning for GNN-based causal representation learning, however, remains underexplored. This work makes three key contributions to bridge this gap. First, we establish a theoretical connection between graph curvature and causal inference, showing that negative curvature poses challenges to learning the causal mechanisms underlying network data. Second, building on this theoretical insight, we present empirical results that use Ricci curvature to gauge the error in treatment effect estimates made from representations learned by GNNs. These results demonstrate that regions of positive curvature yield more accurate estimates. Lastly, as an example of the potential unlocked by this newfound connection between geometry and causal inference, we propose a method that uses Ricci flow to improve treatment effect estimation on networked data. Our experiments confirm that this method reduces the error in treatment effect estimates by flattening the network, showcasing the utility of geometric methods for enhancing causal representation learning. Our findings open new avenues for leveraging discrete geometry in causal representation learning, offering insights and tools that enhance the performance of GNNs in learning robust structural relationships.
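
For readers unfamiliar with discrete curvature, the following is a minimal illustrative sketch of how an edge-level curvature can be computed on a graph. The abstract does not state which discrete Ricci curvature the authors use, so this sketch assumes the augmented Forman-Ricci curvature for unweighted graphs as a simple stand-in; the function name `forman_curvature` and the toy graphs are hypothetical and are not taken from the paper.

```python
def forman_curvature(edges):
    """Augmented Forman-Ricci curvature of each edge in an unweighted,
    undirected graph (an assumed stand-in, not the paper's method):

        F(u, v) = 4 - deg(u) - deg(v) + 3 * (number of triangles through (u, v))
    """
    # Build an adjacency map from the edge list.
    adj = {}
    for u, v in edges:
        adj.setdefault(u, set()).add(v)
        adj.setdefault(v, set()).add(u)

    curvature = {}
    for u, v in edges:
        # Common neighbours of u and v correspond to triangles containing the edge.
        triangles = len(adj[u] & adj[v])
        curvature[(u, v)] = 4 - len(adj[u]) - len(adj[v]) + 3 * triangles
    return curvature


# Toy graphs (hypothetical): a star with a single hub, and a triangle.
star = [(0, leaf) for leaf in range(1, 6)]
triangle = [(0, 1), (1, 2), (2, 0)]

star_curvature = forman_curvature(star)
triangle_curvature = forman_curvature(triangle)
```

Under this definition, edges attached to a high-degree hub in a tree-like region receive negative curvature, while edges in densely connected, clique-like regions receive positive curvature. This is the kind of geometric distinction the abstract relates to the error of treatment effect estimates.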
