Transcript
  • End to End Learning and Optimization on Graphs

    Bryan Wilder¹, Eric Ewing², Bistra Dilkina², Milind Tambe¹

    ¹Harvard University
    ²University of Southern California

    To appear at NeurIPS 2019

  • Data → ??? → Decisions

  • Data → arg max_{x ∈ X} f(x, θ) → Decisions

    Standard two stage: predict then optimize

    Training: maximize accuracy

    Challenge: misalignment between “accuracy” and decision quality

  • Data → Decisions

    Pure end to end

    Training: maximize decision quality

    Challenge: optimization is hard

  • Data → arg max_{x ∈ X} f(x, θ) → Decisions

    Decision-focused learning: differentiable optimization during training

    Training: maximize decision quality

    Challenge: how to make optimization differentiable?

  • Relax + differentiate

    Forward pass: run a solver

    Backward pass: sensitivity analysis via KKT conditions
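
    To illustrate the pattern (this is not the talk’s code), here is a minimal sketch using the cvxpylayers library (Agrawal et al. 2019) on a toy convex QP; the objective and all names are illustrative:

    ```python
    import cvxpy as cp
    import torch
    from cvxpylayers.torch import CvxpyLayer

    n = 10
    x = cp.Variable(n)
    theta = cp.Parameter(n)
    # Toy convex relaxation: a QP over the simplex, with the objective
    # parameterized by the model's predictions theta
    problem = cp.Problem(cp.Maximize(theta @ x - cp.sum_squares(x)),
                         [cp.sum(x) == 1, x >= 0])
    layer = CvxpyLayer(problem, parameters=[theta], variables=[x])

    theta_hat = torch.randn(n, requires_grad=True)  # predicted parameters
    x_star, = layer(theta_hat)   # forward pass: run a solver
    loss = -x_star[0]            # any downstream decision-quality loss
    loss.backward()              # backward pass: sensitivity via KKT conditions
    ```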

  • Relax + differentiate

    • Convex QPs [Amos and Kolter 2018, Donti et al 2018]
    • Linear and submodular programs [Wilder, Dilkina, Tambe 2019]
    • MAXSAT (via SDP relaxation) [Wang, Donti, Wilder, Kolter 2019]
    • MIPs [Ferber, Wilder, Dilkina, Tambe 2019]

    • Monday @ 11am, Room 612

  • What’s wrong with relaxations?

    • Some problems don’t have good ones
    • Slow to solve continuous optimization problem
    • Slower to backprop through: O(n³)

  • This work

    • Alternative: include a solver for a simpler proxy problem
    • Learn a representation that maps the hard problem to the simple one
    • Instantiate this idea for a class of graph optimization problems

  • Graph learning + optimization

  • Problem classes

    • Partition the nodes into K disjoint groups
      • Community detection, maxcut, …
    • Select a subset of K nodes
      • Facility location, influence maximization, vaccination, …
    • Methods of choice are often combinatorial/discrete

  • Approach

    • Observation: grouping nodes into communities is a good heuristic
      • Partitioning: communities correspond to well-connected subgroups
      • Facility location: put one facility in each community
    • Observation: graph learning approaches already embed into ℝⁿ

  • Approach

    1. Start with a clustering algorithm (in ℝⁿ)
       • Can (approximately) differentiate very quickly
    2. Train embeddings (representation) to solve the particular problem
       • Automatically learning a good continuous relaxation!

  • ClusterNet Approach

    Node embedding (GCN) → K-means clustering → Locate 1 facility in each community

  • Differentiable K-means

    Forward pass:
    • Softmax update to node assignments
    • Update cluster centers
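
    A minimal PyTorch sketch of this forward pass (illustrative, not the paper’s exact code; soft_kmeans_forward and beta are names chosen here):

    ```python
    import torch

    def soft_kmeans_forward(x, mu, beta=5.0, iters=100):
        """Soft K-means: alternate softmax assignments and center updates.

        x:    (n, d) node embeddings
        mu:   (K, d) initial cluster centers
        beta: softmax inverse temperature (larger = harder assignments)
        """
        for _ in range(iters):
            # Softmax update to node assignments, by distance to each center
            r = torch.softmax(-beta * torch.cdist(x, mu), dim=1)  # (n, K)
            # Update cluster centers as assignment-weighted means of points
            mu = (r.T @ x) / r.sum(dim=0, keepdim=True).T.clamp_min(1e-8)
        return r, mu
    ```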

  • Differentiable K-means

    Backward pass

    • Option 1: differentiate through the fixed-point condition μᵗ = μᵗ⁺¹
      • Prohibitively slow, memory-intensive
    • Option 2: unroll the entire series of updates
      • Cost scales with # iterations
      • Have to stick to differentiable operations
    • Option 3: get the solution, then unroll one update
      • Do anything to solve the forward pass
      • Linear time/memory, implemented in vanilla PyTorch
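
    A sketch of Option 3, reusing the hypothetical soft_kmeans_forward above: solve the forward pass however you like (here, without autograd), then unroll a single differentiable update at the solution:

    ```python
    import torch

    def soft_kmeans(x, mu0, beta=5.0, iters=100):
        # Forward pass: anything goes; no gradients are tracked
        with torch.no_grad():
            _, mu = soft_kmeans_forward(x, mu0, beta, iters)
        # Backward pass: one differentiable update at the (approximate)
        # fixed point, giving linear time/memory in plain PyTorch ops
        r = torch.softmax(-beta * torch.cdist(x, mu), dim=1)
        mu = (r.T @ x) / r.sum(dim=0, keepdim=True).T.clamp_min(1e-8)
        return r, mu
    ```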

  • Differentiable K-means

    Theorem [informal]: provided the clusters are sufficiently balanced and well-separated, the Option 3 approximate gradients converge exponentially quickly to the true ones.

    Idea: show that this corresponds to approximating a particular term in the analytical fixed-point gradients.

  • ClusterNet Approach

    GCN node embeddings → K-means clustering → Locate 1 facility in each community

    Loss: quality of facility assignment

    Differentiate through K-means

    Update GCN params
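
    Assembled into a training step, a hypothetical sketch (gcn and decision_loss are placeholders for the embedding network and a differentiable decision-quality objective, e.g. a relaxed facility-location cost):

    ```python
    import torch

    def train_step(gcn, optimizer, features, adj, decision_loss, K=5):
        optimizer.zero_grad()
        x = gcn(features, adj)                        # GCN node embeddings (n, d)
        mu0 = x[torch.randperm(len(x))[:K]].detach()  # initial cluster centers
        r, mu = soft_kmeans(x, mu0)                   # differentiable K-means
        loss = decision_loss(r, mu, x)                # quality of the decision
        loss.backward()                               # differentiate through K-means
        optimizer.step()                              # update GCN params
        return loss.item()
    ```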

  • Experiments

    • Learning problem: link prediction
    • Optimization: community detection and facility location problems
    • Train GCNs as predictive component

  • Example: community detection

    Observe partial graph → Predict unseen edges → Find communities

    max_r  (1/2m) Σ_{u,v ∈ V} Σ_{k=1}^{K} (A_{uv} − d_u d_v / 2m) r_{uk} r_{vk}

    s.t.  r_{uk} ∈ {0,1}  ∀ u ∈ V, k = 1…K
          Σ_{k=1}^{K} r_{uk} = 1  ∀ u ∈ V

  • Example: community detection

    Observe partial graph → Predict unseen edges → Find communities
    (same modularity objective as above)

    • Useful in scientific discovery (social groups, functional modules in biological networks)
    • In applications, the two-stage approach is common
      [Yan & Gregory ’12, Burgess et al ’16, Berlusconi et al ’16, Tan et al ’16, Bahulkar et al ’18, …]

  • Experiments

    • Learning problem: link prediction
    • Optimization: community detection and facility location problems
    • Train GCNs as predictive component
    • Comparison:
      • Two stage: GCN + expert-designed algorithm (2Stage)
      • Pure end to end: deep GCN to predict the optimal solution (e2e)

  • Results: single-graph link prediction

    [Bar charts: community detection (modularity, higher is better) and facility location (max distance, lower is better), comparing ClusterNet, 2stage, and e2e]

    Representative example from cora, citeseer, protein interaction, facebook, adolescent health networks

  • Results: generalization across graphs

    [Bar charts: community detection (modularity, higher is better) and facility location (max distance, lower is better), comparing ClusterNet, 2stage, and e2e]

    ClusterNet learns generalizable strategies for optimization!

  • Takeaways

    • Good decisions require integrating learning and optimization
    • Pure end-to-end methods miss out on useful structure
    • Even simple optimization primitives provide good inductive bias

    NeurIPS’19 paper: see bryanwilder.github.io
    Code available at https://github.com/bwilder0/clusternet
