End to End Learning and Optimization on Graphs
Bryan Wilder¹, Eric Ewing², Bistra Dilkina², Milind Tambe¹
¹Harvard University
²University of Southern California
To appear at NeurIPS 2019
10/27/2019 Bryan Wilder (Harvard)
Data → ??? → Decisions
Data → $\arg\max_{x \in X} f(x, \theta)$ → Decisions
Standard two stage: predict then optimize
Challenge: misalignment between βaccuracyβ and decision quality
Training: maximize accuracy
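To make the misalignment concrete, here is a small illustrative sketch. The numbers and the "pick the single best item" decision rule are invented for this example and do not come from the talk. It shows a prediction with lower mean squared error leading to a worse decision than a less accurate one.

```python
# Illustrative numbers only: true item values and two candidate predictions.
true_values = [1.0, 0.9, 0.0]
pred_accurate = [0.85, 0.95, 0.0]   # small errors, but ranks item 1 above item 0
pred_rough = [2.0, 0.9, 0.0]        # large error on item 0, but the ranking is right

def mse(pred, truth):
    """Mean squared error: the 'accuracy' a two-stage pipeline trains for."""
    return sum((p - t) ** 2 for p, t in zip(pred, truth)) / len(truth)

def decide(pred):
    """Toy optimization stage: pick the item with the highest predicted value."""
    return max(range(len(pred)), key=lambda i: pred[i])

def decision_quality(pred, truth):
    """Value actually obtained when acting on the prediction."""
    return truth[decide(pred)]

for name, pred in [("accurate", pred_accurate), ("rough", pred_rough)]:
    print(name, "MSE:", round(mse(pred, true_values), 4),
          "decision quality:", decision_quality(pred, true_values))
```

The first prediction wins on MSE but selects the wrong item, which is the gap that training on decision quality is meant to close.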
Data → Decisions
Pure end to end
Challenge: optimization is hard
Training: maximize decision quality
Decision-focused learning: differentiable optimization during training
Data → $\arg\max_{x \in X} f(x, \theta)$ → Decisions
Training: maximize decision quality
Challenge: how to make optimization differentiable?
Relax + differentiate
Forward pass: run a solver
Backward pass: sensitivity analysis via KKT conditions
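As a rough illustration of the forward/backward split, the sketch below uses an equality-constrained QP, chosen here only because its KKT conditions form a single linear system. The function names `solve_qp` and `jacobian_wrt_theta` and all numbers are hypothetical, and this is not the implementation from any of the cited papers. The forward pass solves the KKT system; the backward pass reads the sensitivity of the solution to the parameter off the inverse of the same KKT matrix.

```python
import numpy as np

def solve_qp(Q, A, b, theta):
    """Forward pass: solve  min_x 0.5 x^T Q x - theta^T x  s.t.  A x = b."""
    n, m = Q.shape[0], A.shape[0]
    # KKT conditions: Q x + A^T lam = theta (stationarity), A x = b (feasibility).
    K = np.block([[Q, A.T], [A, np.zeros((m, m))]])
    sol = np.linalg.solve(K, np.concatenate([theta, b]))
    return sol[:n], K

def jacobian_wrt_theta(K, n):
    """Backward pass: differentiating K [x; lam] = [theta; b] in theta
    gives dx/dtheta = the top-left n x n block of K^{-1}."""
    return np.linalg.inv(K)[:n, :n]

Q = np.array([[2.0, 0.0], [0.0, 1.0]])
A = np.array([[1.0, 1.0]])
b = np.array([1.0])
theta = np.array([1.0, 0.5])

x_star, K = solve_qp(Q, A, b, theta)
J = jacobian_wrt_theta(K, n=2)

# Sanity check: compare the analytical Jacobian with finite differences.
eps = 1e-6
J_fd = np.column_stack([
    (solve_qp(Q, A, b, theta + eps * e)[0] - x_star) / eps
    for e in np.eye(2)
])
print("x* =", x_star)
print("analytical and finite-difference Jacobians agree:", np.allclose(J, J_fd, atol=1e-6))
```

The backward pass here requires factoring or inverting a dense KKT matrix, which is the kind of cost that grows quickly with problem size.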
Relax + differentiate
• Convex QPs [Amos and Kolter 2018, Donti et al. 2018]
• Linear and submodular programs [Wilder, Dilkina, Tambe 2019]
• MAXSAT (via SDP relaxation) [Wang, Donti, Wilder, Kolter 2019]
• MIPs [Ferber, Wilder, Dilkina, Tambe 2019]
    • Monday @ 11am, Room 612
What's wrong with relaxations?
• Some problems don't have good ones
• Slow to solve continuous optimization problem
• Slower to backprop through: $O(n^3)$
This work
• Alternative: include solver for a simpler proxy problem
• Learn a representation that maps hard problem to simple one
• Instantiate this idea for a class of graph optimization problems
Graph learning + optimization
Problem classes
• Partition the nodes into K disjoint groups
    • Community detection, maxcut, …
• Select a subset of K nodes
    • Facility location, influence maximization, vaccination, …
• Methods of choice are often combinatorial/discrete
Approach
• Observation: grouping nodes into communities is a good heuristic
    • Partitioning: communities correspond to well-connected subgroups
    • Facility location: put one facility in each community
• Observation: graph learning approaches already embed into $\mathbb{R}^p$
Approach
1. Start with clustering algorithm (in $\mathbb{R}^p$)
    • Can (approximately) differentiate very quickly
2. Train embeddings (representation) to solve the particular problem
    • Automatically learning a good continuous relaxation!
ClusterNet Approach
Node embedding (GCN) → K-means clustering → Locate 1 facility in each community
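For readers unfamiliar with the embedding step, the following is a minimal sketch of one graph convolutional layer in the commonly used symmetric-normalization form, written in NumPy. The slides do not specify the architecture details, so the layer form, the function name `gcn_layer`, and the toy graph are assumptions made for illustration.

```python
import numpy as np

def gcn_layer(A, H, W):
    """One GCN layer: ReLU(D^{-1/2} (A + I) D^{-1/2} H W).

    A: (n, n) adjacency matrix, H: (n, f) node features, W: (f, p) weights.
    """
    A_hat = A + np.eye(A.shape[0])            # add self-loops
    d = A_hat.sum(axis=1)                     # degrees including self-loops
    D_inv_sqrt = np.diag(1.0 / np.sqrt(d))
    A_norm = D_inv_sqrt @ A_hat @ D_inv_sqrt  # symmetric normalization
    return np.maximum(A_norm @ H @ W, 0.0)    # ReLU nonlinearity

# Toy path graph 0 - 1 - 2 with one-hot node features (illustrative only).
A = np.array([[0.0, 1.0, 0.0],
              [1.0, 0.0, 1.0],
              [0.0, 1.0, 0.0]])
H = np.eye(3)
W = np.random.default_rng(0).normal(size=(3, 2))  # untrained weights

Z = gcn_layer(A, H, W)   # (3, 2): one embedding in R^2 per node
print(Z.shape)
```

The rows of `Z` play the role of the node embeddings that the clustering step consumes.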
Differentiable K-means
Forward pass
• Update cluster centers
• Softmax update to node assignments
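The two alternating updates of the forward pass can be sketched as follows. This is a hedged NumPy illustration, not the authors' code: the squared Euclidean distance, the inverse-temperature parameter `beta`, and the function names `soft_assign` and `soft_kmeans` are assumptions, since the formulas on the slide did not survive extraction.

```python
import numpy as np

def soft_assign(x, mu, beta):
    """Softmax update: r[j, k] is the soft membership of node j in cluster k."""
    d2 = ((x[:, None, :] - mu[None, :, :]) ** 2).sum(axis=2)   # (n, K) squared distances
    logits = -beta * d2
    logits = logits - logits.max(axis=1, keepdims=True)        # numerical stability
    r = np.exp(logits)
    return r / r.sum(axis=1, keepdims=True)

def soft_kmeans(x, mu, beta=1.0, n_iter=20):
    """Alternate assignment and center updates for a fixed number of iterations."""
    for _ in range(n_iter):
        r = soft_assign(x, mu, beta)
        mu = (r.T @ x) / r.sum(axis=0)[:, None]                # weighted mean per cluster
    return mu, soft_assign(x, mu, beta)

# Four toy embeddings forming two obvious groups (illustrative only).
x = np.array([[0.0, 0.0], [0.0, 1.0], [10.0, 0.0], [10.0, 1.0]])
mu0 = np.array([[1.0, 0.5], [9.0, 0.5]])

mu, r = soft_kmeans(x, mu0)
print(mu)
print(r.argmax(axis=1))
```

Because every step is a smooth function of `x`, the assignments `r` can feed a downstream loss, unlike the hard argmin of standard K-means.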
Differentiable K-means
Backward pass
• Option 1: differentiate through the fixed-point condition $\mu_t = \mu_{t+1}$
    • Prohibitively slow, memory-intensive
• Option 2: unroll the entire series of updates
    • Cost scales with # iterations
    • Have to stick to differentiable operations
• Option 3: get the solution, then unroll one update
    • Do anything to solve the forward pass
    • Linear time/memory, implemented in vanilla PyTorch
Differentiable K-means
Theorem [informal]: provided the clusters are sufficiently balanced and well-separated, the Option 3 approximate gradients converge exponentially quickly to the true ones.
Idea: show that this corresponds to approximating a particular term in the analytical fixed-point gradients.
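The claim about Option 3 can be checked numerically on a toy instance. The sketch below is an illustration under assumptions rather than the authors' PyTorch code: it uses NumPy with finite differences instead of automatic differentiation, a squared-distance softmax with inverse temperature `beta`, and hypothetical helper names `update` and `solve`. It compares the sensitivity obtained by re-solving to convergence with the one obtained by differentiating a single update at the converged centers.

```python
import numpy as np

def update(x, mu, beta):
    """One soft K-means step: softmax assignments, then weighted-mean centers."""
    d2 = ((x[:, None, :] - mu[None, :, :]) ** 2).sum(axis=2)
    logits = -beta * d2
    logits = logits - logits.max(axis=1, keepdims=True)
    r = np.exp(logits)
    r = r / r.sum(axis=1, keepdims=True)
    return (r.T @ x) / r.sum(axis=0)[:, None]

def solve(x, mu0, beta, n_iter=100):
    """Forward pass: iterate the update until (approximately) a fixed point."""
    mu = mu0
    for _ in range(n_iter):
        mu = update(x, mu, beta)
    return mu

# Two balanced, well-separated clusters (toy data, illustrative only).
x = np.array([[0.0, 0.0], [0.0, 1.0], [10.0, 0.0], [10.0, 1.0]])
mu0 = np.array([[1.0, 0.5], [9.0, 0.5]])
beta, eps = 1.0, 1e-5

x_pert = x.copy()
x_pert[0, 0] += eps            # perturb one coordinate of one embedding

mu_star = solve(x, mu0, beta)

# Reference sensitivity of center 0: re-solve to convergence on the perturbed input.
true_grad = (solve(x_pert, mu0, beta)[0, 0] - mu_star[0, 0]) / eps

# Option 3: hold the converged centers fixed and differentiate one update only.
one_step_grad = (update(x_pert, mu_star, beta)[0, 0]
                 - update(x, mu_star, beta)[0, 0]) / eps

print("re-solved:", true_grad, "one-step:", one_step_grad)
```

On this well-separated example the two estimates nearly coincide, which matches the spirit of the informal theorem. In an automatic-differentiation framework the same idea would amount to running the solver without gradient tracking and then applying one tracked update.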
ClusterNet Approach
GCN node embeddings → K-means clustering → Locate 1 facility in each community
Loss: quality of facility assignment
Differentiate through K-means
Update GCN params
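The last two stages of the pipeline, placing one facility per community and scoring the placement, might look roughly like the sketch below. The rounding rule (pick the node whose embedding is nearest to each cluster center), the max-distance objective, and the names `pick_facilities` and `max_distance` are assumptions for illustration; the slides do not spell out the rule used in the talk.

```python
import numpy as np

def pick_facilities(x, mu):
    """For each cluster center, return the index of the nearest node embedding."""
    d2 = ((x[:, None, :] - mu[None, :, :]) ** 2).sum(axis=2)   # (n, K)
    return d2.argmin(axis=0)

def max_distance(dist, facilities):
    """Worst-case distance from any node to its closest facility (lower is better)."""
    return dist[:, facilities].min(axis=1).max()

# Toy path graph 0 - 1 - 2 - 3 - 4 - 5: graph distance between i and j is |i - j|.
nodes = np.arange(6)
dist = np.abs(nodes[:, None] - nodes[None, :]).astype(float)

# Illustrative 1-D embeddings and two cluster centers, as if produced upstream.
x = nodes.astype(float)[:, None]
mu = np.array([[1.0], [4.0]])

facilities = pick_facilities(x, mu)
print("facilities:", facilities, "objective:", max_distance(dist, facilities))
```

One caveat of this simple rule is that two centers can select the same node, so a practical version would need some tie-breaking.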
Experiments
• Learning problem: link prediction
• Optimization: community detection and facility location problems
• Train GCNs as predictive component
Example: community detection
Observe partial graph → Predict unseen edges → Find communities

$$\max_{r} \; \frac{1}{2m} \sum_{u,v \in V} \sum_{k=1}^{K} \left[ A_{u,v} - \frac{d_u d_v}{2m} \right] r_{uk} r_{vk}$$
$$\text{s.t.} \quad r_{uk} \in \{0, 1\} \quad \forall u \in V,\; k = 1 \dots K$$
$$\sum_{k=1}^{K} r_{uk} = 1 \quad \forall u \in V$$

• Useful in scientific discovery (social groups, functional modules in biological networks)
• In applications, two-stage approach is common
    [Yan & Gregory '12, Burgess et al. '16, Berlusconi et al. '16, Tan et al. '16, Bahulkar et al. '18, …]
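The objective on this slide is the standard modularity score, and a short NumPy sketch can make it concrete. It assumes the usual reading of the symbols: A is the adjacency matrix, d the node degrees, m the number of edges, and r a one-hot assignment of nodes to K communities. The function name `modularity` and the toy graph are illustrative.

```python
import numpy as np

def modularity(A, r):
    """(1 / 2m) * sum_{u,v} sum_k [A_uv - d_u d_v / 2m] r_uk r_vk."""
    d = A.sum(axis=1)                       # node degrees
    two_m = d.sum()                         # 2m: each edge is counted twice
    B = A - np.outer(d, d) / two_m          # modularity matrix
    return np.trace(r.T @ B @ r) / two_m    # trace sums over communities k

# Toy graph: two disjoint edges, 0 - 1 and 2 - 3.
A = np.array([[0.0, 1.0, 0.0, 0.0],
              [1.0, 0.0, 0.0, 0.0],
              [0.0, 0.0, 0.0, 1.0],
              [0.0, 0.0, 0.0, 0.0]])
A[3, 2] = 1.0                               # keep the adjacency matrix symmetric

r_split = np.array([[1, 0], [1, 0], [0, 1], [0, 1]], dtype=float)   # {0,1} and {2,3}
r_merged = np.array([[1, 0], [1, 0], [1, 0], [1, 0]], dtype=float)  # everything together

print(modularity(A, r_split), modularity(A, r_merged))
```

Replacing the one-hot rows of `r` with soft assignments gives a continuous version of the same expression, which is what makes it usable as a training loss.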
Experiments
• Comparison
    • Two stage: GCN + expert-designed algorithm (2Stage)
    • Pure end to end: Deep GCN to predict optimal solution (e2e)
Results: single-graph link prediction
[Figure: two panels comparing ClusterNet, 2stage, and e2e. Community detection: modularity (higher is better). Facility location: max distance (lower is better).]
Representative example from cora, citeseer, protein interaction, facebook, adolescent health networks
Results: generalization across graphs
[Figure: two panels comparing ClusterNet, 2stage, and e2e. Community detection: modularity (higher is better). Facility location: max distance (lower is better).]
ClusterNet learns generalizable strategies for optimization!
Takeaways
• Good decisions require integrating learning and optimization
• Pure end-to-end methods miss out on useful structure
• Even simple optimization primitives provide good inductive bias
NeurIPS '19 paper: see bryanwilder.github.io
Code available at https://github.com/bwilder0/clusternet