Inference in Probabilistic Graphical Models by Graph Neural Networks
Author: KiJung Yoon, Renjie Liao, Yuwen Xiong, Lisa Zhang, Ethan Fetaya, Raquel Urtasun, Richard Zemel, Xaq Pitkow
Presenter: Shihao Niu, Zhe Qu, Siqi Liu, Jules Ahmar
TL;DR: Use Graph Neural Networks (GNNs) to learn a message-passing algorithm that solves inference tasks in probabilistic graphical models.
Motivation
● Inference is difficult for probabilistic graphical models.
● Message-passing algorithms, such as belief propagation, struggle when the graph contains loops.
○ Loopy belief propagation: convergence is not guaranteed.
Why GNNs
● Essentially an extension of recurrent neural networks (RNNs) to graph inputs.
● Central idea: update the hidden state at each node iteratively by aggregating incoming messages.
● Have a structure similar to a message-passing algorithm.
Factor graph and belief propagation
● Recall that the distribution of a factor graph is
○ p(x) = (1/Z) ∏_a f_a(x_a)
● Recall the formulas of the belief propagation algorithm
○ variable to factor: μ_{i→a}(x_i) = ∏_{c ∈ N(i)\a} μ_{c→i}(x_i)
○ factor to variable: μ_{a→i}(x_i) = ∑_{x_a \ x_i} f_a(x_a) ∏_{j ∈ N(a)\i} μ_{j→a}(x_j)
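The two BP message updates above can be sketched numerically. A minimal NumPy example on a tiny chain factor graph x1 – f12 – x2 – f23 – x3 (the factor tables are made-up random values, not from the paper); on a tree like this, the BP belief matches the brute-force marginal exactly:

```python
import numpy as np

rng = np.random.default_rng(0)

# Tiny chain factor graph: x1 -- f12 -- x2 -- f23 -- x3, binary variables.
# Factors are random positive tables (illustrative values only).
f12 = rng.random((2, 2)) + 0.1
f23 = rng.random((2, 2)) + 0.1

# Factor-to-variable message into x2 from f12:
#   mu_{f12 -> x2}(x2) = sum_{x1} f12(x1, x2) * mu_{x1 -> f12}(x1)
# Leaf variables send the all-ones message.
mu_x1_f12 = np.ones(2)
mu_f12_x2 = f12.T @ mu_x1_f12

mu_x3_f23 = np.ones(2)
mu_f23_x2 = f23 @ mu_x3_f23

# Belief at x2 is the normalized product of incoming factor messages.
belief_x2 = mu_f12_x2 * mu_f23_x2
belief_x2 /= belief_x2.sum()

# Brute-force marginal for comparison: p(x2) ∝ sum_{x1,x3} f12(x1,x2) f23(x2,x3)
joint = np.einsum('ab,bc->abc', f12, f23)
exact = joint.sum(axis=(0, 2))
exact /= exact.sum()

print(np.allclose(belief_x2, exact))  # → True (BP is exact on trees)
```

On loopy graphs the same updates are iterated to (hoped-for) convergence, which is exactly where the guarantees break down.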
BP to GNNs: mapping the messages
● BP is recursive and graph-based. Naturally, we can map the BP messages to GNN nodes and use neural networks to describe the nonlinear updates.
BP to GNNs: mapping the variable nodes
● Marginal probability of x_i in an MRF: p(x_i) ∝ ∏_{a ∈ N(i)} μ_{a→i}(x_i)
● Marginal joint probability of x_a in a factor graph: p(x_a) ∝ f_a(x_a) ∏_{i ∈ N(a)} μ_{i→a}(x_i)
● All of the messages depend on only one variable node at a time.
● The nonlinear functions between GNN nodes can account for these marginal computations AFTER equilibrium is reached.
Preliminaries for the model
● Binary MRF, a.k.a. Ising model: p(x) ∝ exp(∑_i b_i x_i + ∑_{i<j} J_ij x_i x_j), with x_i ∈ {−1, +1}
● The couplings J_ij and biases b_i are specified randomly and are provided as input for GNN inference.
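For MRFs this small, the ground-truth marginals that supervise the GNN can be obtained by brute-force enumeration. A sketch under the Ising parameterization above (the random J and b here are illustrative, not the paper's exact sampling scheme):

```python
import itertools

import numpy as np

rng = np.random.default_rng(1)
n = 4  # variables x_i in {-1, +1}

# Random symmetric couplings J (zero diagonal) and biases b, as in the
# setup above; values are illustrative placeholders.
J = rng.normal(size=(n, n))
J = (J + J.T) / 2
np.fill_diagonal(J, 0)
b = rng.normal(size=n)

def log_unnorm(x):
    # log p(x) + const = sum_i b_i x_i + sum_{i<j} J_ij x_i x_j
    # (0.5 x^T J x equals the i<j sum because J is symmetric, zero-diagonal)
    return b @ x + 0.5 * x @ J @ x

# Enumerate all 2^n states and normalize.
states = np.array(list(itertools.product([-1, 1], repeat=n)))
logp = np.array([log_unnorm(s) for s in states])
p = np.exp(logp - logp.max())
p /= p.sum()

# Exact marginals p(x_i = +1): the "ground truth" a GNN is trained against.
marginals = np.array([p[states[:, i] == 1].sum() for i in range(n)])
print(marginals)
```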
GNN Recap
● Update the state embedding h_v of node v based on:
- the feature x_v of v
- the features x_co[v] of the edges of v
- the state embeddings h_ne[v] of the neighbors of v
- the features x_ne[v] of the neighbors of v
● State update: h_v = f(x_v, x_co[v], h_ne[v], x_ne[v])
● Local output function: o_v = g(h_v, x_v)
GNN Recap (Cont.)
Scarselli, Franco, et al. "The graph neural network model."
● Decompose the state update function into a sum of per-edge terms: h_v = ∑_{u ∈ ne[v]} f(x_v, x_(v,u), h_u, x_u)
Message Passing Neural Networks
● Define the message from node i to node j at time t+1 as: m_{ij}^{t+1} = M_t(h_i^t, h_j^t, e_{ij})
● Step 1: Aggregate all incoming messages into a single message at the destination node: m_j^{t+1} = ∑_{i ∈ N(j)} m_{ij}^{t+1}
● Step 2: Update the hidden state based on the current hidden state and the aggregated message: h_j^{t+1} = U_t(h_j^t, m_j^{t+1})
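The two steps can be sketched with toy stand-ins for the learned functions (the real M_t and U_t are trained networks; the sizes, weights, and the linear/tanh forms here are made up for illustration):

```python
import numpy as np

rng = np.random.default_rng(2)
d = 8                                  # hidden-state dimension (illustrative)
adj = np.array([[0, 1, 1],
                [1, 0, 0],
                [1, 0, 0]])            # small undirected example graph
h = rng.normal(size=(3, d))            # node hidden states h_v^t
e = rng.normal(size=(3, 3, d))         # edge features e_vw

W = rng.normal(size=(3 * d, d)) * 0.1  # toy linear message function M_t

def message(hv, hw, evw):
    # m_{vw}^{t+1} = M_t(h_v^t, h_w^t, e_vw); a single linear layer here
    return np.concatenate([hv, hw, evw]) @ W

# Step 1: aggregate all incoming messages at each destination node v
m = np.zeros_like(h)
for v in range(3):
    for w in range(3):
        if adj[v, w]:
            m[v] += message(h[v], h[w], e[v, w])

# Step 2: update hidden states, h_v^{t+1} = U_t(h_v^t, m_v^{t+1})
U = rng.normal(size=(2 * d, d)) * 0.1
h_next = np.tanh(np.concatenate([h, m], axis=1) @ U)
print(h_next.shape)  # (3, 8)
```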
An abstraction of several GNN variants
Phase 1: Message Passing
Message Passing Neural Networks (Cont.)
Phase 2: Readout Phase
The message function, node update function, and readout function can have different settings.
MPNN can therefore generalize several different models.
GG-NN (Gated Graph Neural Network)
Source: Zhou, Jie, et al. "Graph neural networks: A review of methods and applications."
Gated Recurrent Units (GRU)
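The GRU used as the node update can be written out directly. A minimal NumPy cell following the standard gating equations (biases omitted for brevity; the weights are random placeholders):

```python
import numpy as np

rng = np.random.default_rng(3)
d = 4  # hidden size (illustrative)

# One weight pair (input, recurrent) per gate; x is the GRU input,
# here playing the role of an aggregated message.
Wz, Uz = rng.normal(size=(d, d)), rng.normal(size=(d, d))
Wr, Ur = rng.normal(size=(d, d)), rng.normal(size=(d, d))
Wh, Uh = rng.normal(size=(d, d)), rng.normal(size=(d, d))

def sigmoid(a):
    return 1 / (1 + np.exp(-a))

def gru_cell(x, h):
    z = sigmoid(Wz @ x + Uz @ h)              # update gate
    r = sigmoid(Wr @ x + Ur @ h)              # reset gate
    h_tilde = np.tanh(Wh @ x + Uh @ (r * h))  # candidate state
    return (1 - z) * h + z * h_tilde          # blend old state and candidate

h = np.zeros(d)          # initial hidden state
x = rng.normal(size=d)   # one incoming (aggregated) message
h = gru_cell(x, h)
print(h.shape)
```

The gates let the node keep or overwrite parts of its state across message-passing iterations, which is why GG-NN-style models use a GRU as the update function.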
Two mappings between factor graphs and GNNs
message-GNN and node-GNN perform similarly, and much better than belief propagation
Mapping I: Message-GNN
● Mapping: (graphical model) → (GNN)
○ Message μ_ij between nodes i and j → node v
○ Message nodes ij and jk → nodes v and w, connected
Motivation: conforms closely to the structure of conventional belief propagation, and reflects how messages depend on each other.
Mapping I: Message-GNN
1. If connected, compute the message from node v to node w: m_{v→w}^{t+1} = M(h_v^t, h_w^t, e_{vw}), where M is a multi-layer perceptron (MLP) with ReLU activations.
2. Then update each node's hidden state with a recurrent neural network (GRU): h_w^{t+1} = U(h_w^t, ∑_v m_{v→w}^{t+1}).
3. Readout function to extract the marginals or MAP state:
a. First aggregate by summation all GNN nodes that share the same target node in the graphical model.
b. Then apply a shared readout function, another MLP with a sigmoid output.
Mapping II: Node-GNN
● Mapping: (graphical model) → (GNN)
○ Variable nodes → GNN nodes
1. Message function: m_{v→w}^{t+1} = M(h_v^t, h_w^t, e_{vw})
2. Aggregate messages: m_w^{t+1} = ∑_{v ∈ N(w)} m_{v→w}^{t+1}
3. Node update function: h_w^{t+1} = U(h_w^t, m_w^{t+1})
4. The readout is generated directly from the hidden states: p̂_w = R(h_w^T)
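Steps 1–4 can be combined into one toy iteration loop. This sketch uses single-layer stand-ins for the learned networks (the paper uses an MLP message function, a GRU update, and a sigmoid readout; the shapes, weights, and couplings here are illustrative):

```python
import numpy as np

rng = np.random.default_rng(4)
n, d = 3, 8  # variable nodes, hidden size (illustrative)

# Couplings double as edge features; zero means "no edge".
J = np.array([[0.0, 0.5, -0.3],
              [0.5, 0.0, 0.0],
              [-0.3, 0.0, 0.0]])
h = rng.normal(size=(n, d)) * 0.1

Wm = rng.normal(size=(2 * d + 1, d)) * 0.1  # message net (one ReLU layer here)
Wu = rng.normal(size=(2 * d, d)) * 0.1      # update net (tanh layer; GRU in the paper)
Wr = rng.normal(size=(d, 1))                # readout net (sigmoid output)

for _ in range(5):  # a few synchronous message-passing iterations
    m = np.zeros((n, d))
    for w in range(n):
        for v in range(n):
            if J[v, w] != 0:
                # 1. message from v to w, conditioned on the coupling J_vw
                m[w] += np.maximum(
                    np.concatenate([h[v], h[w], [J[v, w]]]) @ Wm, 0)
    # 2. aggregation happened in the sum above; 3. node update
    h = np.tanh(np.concatenate([h, m], axis=1) @ Wu)

# 4. readout: estimated marginal p(x_w = 1) directly from hidden states
p_hat = 1 / (1 + np.exp(-(h @ Wr).ravel()))
print(p_hat)
```

With untrained weights the outputs are of course meaningless; training fits Wm, Wu, and Wr so that p_hat matches ground-truth marginals.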
Message-GNN and Node-GNN
● Objective: backpropagation to minimize the total cross-entropy loss: L = −∑_i ∑_{x_i} p(x_i) log q(x_i)
○ p --- ground truth, q --- estimated result
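For binary variables the objective reduces to a sum of per-variable binary cross-entropies. A small worked example with made-up marginals:

```python
import numpy as np

# Total cross-entropy between ground-truth marginals p and estimates q,
# summed over variables and both binary states (illustrative values).
p = np.array([0.9, 0.2, 0.6])   # ground-truth p(x_i = 1)
q = np.array([0.8, 0.3, 0.5])   # GNN-estimated marginals

loss = -np.sum(p * np.log(q) + (1 - p) * np.log(1 - q))
print(round(loss, 4))
```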
Message Passing Function (General)
● Receives external inputs about the couplings associated with edges.
● Depends on the hidden states of the source and destination nodes at the previous time step.
Experiments
● In each experiment, two types of GNNs are tested:
○ Variable nodes (node-GNN)
○ Message nodes (msg-GNN)
● Examine generalization of the model when...
○ Testing on unseen graphs of the same structure
○ Testing on completely random graphs
○ Testing on graphs of the same size
○ Testing on larger graphs
● Analyze performance in estimating both marginal probabilities and the MAP state.
Larger, Novel Test Graphs
Marginal Inference Accuracy
Generalization Performance on Random Graphs
Convergence of Inference Dynamics
Conclusion
● Experiments showed that GNNs provide a flexible learning method for inference in probabilistic graphical models.
● Showed that learned representations and nonlinear transformations on edges generalize to larger graphs with different structures.
● Examined two possible representations of graphical models within GNNs: variable nodes and message nodes.
● Experimental results support GNNs as a promising framework for solving hard inference problems.
● Future work: train and test on larger and more diverse graphs, as well as broader classes of graphical models.
References
1. Zhou, Jie, et al. "Graph neural networks: A review of methods and applications." arXiv preprint.
2. Gilmer, Justin, et al. "Neural message passing for quantum chemistry." Proceedings of the 34th International Conference on Machine Learning, Volume 70. JMLR.org, 2017.
3. Scarselli, Franco, et al. "The graph neural network model." IEEE Transactions on Neural Networks 20.1 (2008): 61-80.
4. Li, Yujia, et al. "Gated graph sequence neural networks." arXiv preprint arXiv:1511.05493 (2015).
5. Wu, Zonghan, et al. "A comprehensive survey on graph neural networks." arXiv preprint.
Homework
1. Where do GNNs outperform belief propagation? Where does belief propagation outperform GNNs?
2. Given the following factor graph, draw the GNN using the Message-GNN mapping.