This post is a summary of lectures 6 to 8 (videos 17 to 25) of the Stanford CS224W course, ‘Machine Learning with Graphs’. (youtube)

Introduction to Graph Neural Networks

The goal is to find an encoding of node $v$ based on the graph structure:

$$ ENC(v) = \text{Multiple layers of non-linear transformation} \ \text{based on graph structure} $$

Deep Graph Encoder

Tasks

  • Node classification
  • Link prediction
  • Community detection
  • Network similarity

😐 The modern deep learning toolbox is designed for simple sequences and grids

But graphs are different:

  • Arbitrary size and complex topological structure
  • No fixed ordering or reference point
  • Often dynamic, with multimodal features

Basics of Deep Learning

I skip this lecture.

Deep Learning for Graphs

| Notation | Description |
| --- | --- |
| $G$ | Graph |
| $V$ | Vertex set |
| $A$ | Adjacency matrix |
| $X\in \mathbb{R}^{m\times\lvert V\rvert}$ | Matrix of node features |
| $v\in V$ | A node |
| $N(v)$ | Set of neighbors of $v$ |

Naive approach

Concatenate the adjacency matrix and node features, then feed them into a deep neural network:

Naive Approach

  • $O(|V|)$ parameters
  • Not applicable to graphs of different sizes
  • Depends on node ordering

Similar idea to CNNs

💡Transform information at the neighbors and combine it.

  • Transform messages $h_i$ from neighbors: $W_ih_i$
  • Add them up: $\sum_iW_ih_i$

Graph Convolutional Neural Network (GCN)

💡Network neighborhood defines a computation graph

Computation graph based on graph structure

The math

  1. Initial 0th-layer embeddings are equal to the node features:

$$h_v^0=x_v$$

  2. For each layer:

$$ h_v^{(l+1)} = \textcolor{orange}{\sigma}( \textcolor{pink}{W_l} \textcolor{skyblue}{\sum_{u\in N(v)}\frac{h_u^{(l)}}{|N(v)|}} + \textcolor{pink}{B_l}h_v^{(l)} ), \quad \forall l \in \lbrace 0, \dots, \textcolor{purple}{L}-1 \rbrace $$

  • Where:
    • $\textcolor{orange}{\sigma}$: Non-linearity (e.g. ReLU)
    • $\textcolor{pink}{W_l}$, $\textcolor{pink}{B_l}$: Trainable weight matrices
    • $\textcolor{skyblue}{\sum_{u\in N(v)}\frac{h_u^{(l)}}{|N(v)|}}$: Average of the neighbors’ previous-layer embeddings
    • $\textcolor{purple}{L}$: Total number of layers
  3. Embedding after $L$ layers of neighborhood aggregation: $$z_v=h_v^{(L)}$$
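Below is a minimal NumPy sketch of this per-node update (the toy triangle graph, the dimensions, and the helper names are my own, purely for illustration):

```python
import numpy as np

def relu(x):
    return np.maximum(0, x)

def gcn_node_update(v, neighbors, H_prev, W, B):
    """h_v^(l+1) = ReLU(W * mean of neighbor embeddings + B * h_v^(l))."""
    agg = np.mean([H_prev[u] for u in neighbors[v]], axis=0)  # average neighbors' embeddings
    return relu(W @ agg + B @ H_prev[v])

# Toy example: 3 nodes forming a triangle, 2-dimensional features
neighbors = {0: [1, 2], 1: [0, 2], 2: [0, 1]}
H0 = np.random.randn(3, 2)                           # h_v^(0) = x_v
W, B = np.random.randn(2, 2), np.random.randn(2, 2)
H1 = np.stack([gcn_node_update(v, neighbors, H0, W, B) for v in range(3)])
```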

Matrix formulation

  • $H^{(l)}=[h_1^{(l)}\dots h_{|V|}^{(l)}]^T$

Matrix form of hidden embedding

  • $\sum_{u\in N(v)} h_u^{(l)} = A_{v,:}H^{(l)}$
    • $A_{v,:}$ denotes row $v$ of the adjacency matrix $A$
  • $D$ is a diagonal matrix where $D_{v,v}=\text{Deg}(v) + \epsilon$
    • $D_{v,v}^{-1}=1/(|N(v)|+\epsilon)$
    • $\epsilon$ prevents dividing by zero.
  • $\tilde{A}=D^{-1}A$

Matrix form

$$ H^{(l+1)}=\sigma(\textcolor{tomato}{\tilde{A}H^{(l)}W_l^T } + \textcolor{aqua}{H^{(l)}B_l^T}) $$

  • Neighborhood aggregation
  • Self transformation

🙃 Not all GNNs can be expressed in matrix form; this breaks down when the aggregation function is complex.
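Here is a minimal NumPy sketch of the matrix form above (the toy graph, the ReLU choice, and the $\epsilon$ value are illustrative assumptions):

```python
import numpy as np

def gcn_layer_matrix(A, H, W, B, eps=1e-9):
    """H^(l+1) = sigma(A_tilde H W^T + H B^T), with A_tilde = D^{-1} A."""
    deg = A.sum(axis=1) + eps                           # D_vv = Deg(v) + eps
    A_tilde = A / deg[:, None]                          # row-normalize: D^{-1} A
    return np.maximum(0, A_tilde @ H @ W.T + H @ B.T)   # ReLU non-linearity

# Toy graph: 3 fully connected nodes with 2-dimensional embeddings
A = np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]], dtype=float)
H = np.random.randn(3, 2)
W, B = np.random.randn(2, 2), np.random.randn(2, 2)
H_next = gcn_layer_matrix(A, H, W, B)
```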

How to train a GNN

Supervised

Minimize the loss between the GNN output and the labels:

$$ \min_{\theta}\mathcal{L}(y,f(z_v)) $$

Unsupervised

Because no labels are available, use the graph structure as supervision (e.g. node similarity based on random walks, matrix factorization, etc.).

Inductive capability

The same aggregation parameters are shared for all nodes

  • The number of model parameters is sublinear in $|V|$
  • Generalize to unseen nodes

A General GNN Framework

Each topic will be discussed in the next sections:

General Framework of GNNs

GNN Layer = Message + Aggregation

  • Different instantiations under this perspective
  • GCN, GraphSAGE, GAT, …

Layer Connectivity

  • Stack layers sequentially
  • Ways of adding skip connections

Graph Augmentation

  • Graph feature augmentation
  • Graph structure augmentation

Learning Objective

  • Supervised/Unsupervised objectives
  • Node/Edge/Graph level objectives

A Single Layer of GNN

Message + Aggregation Framework

Message

$$ m_u^{(l)} = \text{MSG}^{(l)}(h_u^{(l-1)}), \quad u\in N(v) \cup \lbrace v \rbrace $$

Example: A linear layer $m_u^{(l)}=W^{(l)}h_u^{(l-1)}$

✍️ Note: Usually a different message computation is used for the neighbors and for node $v$ itself. Example:

  • $m_u^{(l)}=\textcolor{tomato}{W^{(l)}}h_u^{(l-1)}$
  • $m_v^{(l)}=\textcolor{dodgerblue}{B^{(l)}}h_v^{(l-1)}$

Aggregation

$$ h_v^{(l)} = \text{AGG}^{(l)}(\lbrace m_u^{(l)}, u \in N(v) \rbrace, m_v^{(l)}) $$

⚠️ $\text{AGG}^{(l)}$ should be an order-invariant function (it works on sets/multisets, not sequences)

Example: Sum(.), Mean(.), Max(.), etc.

  • Pooling aggregators (like Max) operate coordinate-wise.

✍️ Note: We can (and should!) add expressiveness using a nonlinearity

  • $\sigma(\cdot)$: $\text{Sigmoid}(\cdot)$, $\text{ReLU}(\cdot)$, etc.
  • Can be added to message or aggregation

⬇️ In the following, deep pink marks the message and orange marks the aggregation.

GCN

$$ h_v^{(l)} = \textcolor{orange}{\sigma( \sum_{u\in N(v)}} \textcolor{deeppink}{W^{(l)}\frac{h_u^{(l-1)}}{|N(v)|}} \textcolor{orange}{)} $$

✍️ The original GCN paper uses a different normalization than $1/|N(v)|$ (link).

GraphSAGE

$$ h_v^{(l)} = \textcolor{orange}{\sigma( W^{(l)}\cdot \text{CONCAT(}} \textcolor{deeppink}{h_v^{(l-1)}}, \textcolor{orange}{\text{AGG}(} \lbrace\textcolor{deeppink}{h_u^{(l-1)}}, \forall u \in N(v)\rbrace \textcolor{orange}{)))} $$

✍️ Notes:

  • Two-stage aggregation is used:
    • Aggregation from node neighbors
      • the output of this stage is a message itself
    • Further aggregation from the node itself
  • $\text{AGG}$ can be Mean, Pool, or an LSTM applied to a reshuffled ordering of the neighbors.
  • (Optional) $l_2$ Normalization
    • $h_v^{(l)} \leftarrow h_v^{(l)}/||h_v^{(l)} ||_2$
    • In some cases this results in a performance improvement
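A minimal NumPy sketch of the GraphSAGE update with Mean as the neighbor aggregator and the optional $l_2$ normalization (the function names, shapes, and toy graph are my own assumptions):

```python
import numpy as np

def graphsage_update(v, neighbors, H_prev, W, l2_normalize=True):
    """h_v^(l) = sigma(W . CONCAT(h_v^(l-1), AGG of neighbor embeddings))."""
    agg = np.mean([H_prev[u] for u in neighbors[v]], axis=0)   # stage 1: aggregate neighbors
    h = np.maximum(0, W @ np.concatenate([H_prev[v], agg]))    # stage 2: combine with self
    if l2_normalize:
        h = h / (np.linalg.norm(h) + 1e-12)                    # optional l2 normalization
    return h

# Toy example: input dim 2, so W maps the 4-dimensional concatenation to 4 outputs
neighbors = {0: [1, 2], 1: [0], 2: [0]}
H0 = np.random.randn(3, 2)
W = np.random.randn(4, 4)
h0_new = graphsage_update(0, neighbors, H0, W)
```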

GAT (Graph ATtention)

$$ h_v^{(l)} = \textcolor{orange}{\sigma( \sum_{u\in N(v)}} \textcolor{deeppink}{\alpha_{vu}W_V^{(l)}h_u^{(l-1)}} \textcolor{orange}{)} $$

Where:

  1. Attention coefficient $e_{vu}$ (indicates importance of $u$’s message to node $v$)

$$ e_{vu} = a(W_Q^{(l)}h_v^{(l-1)},W_K^{(l)}h_u^{(l-1)}) $$

  2. Normalize $e_{vu}$ to get the attention weight $\alpha_{vu}$ (using softmax)

$$ \alpha_{vu} = \frac{\text{exp}(e_{vu})}{\sum_{k\in N(v)}\text{exp}{(e_{vk})}} $$

✍️ Notes:

  • Attention is inspired by cognitive attention
  • In GCN/GraphSAGE: $\alpha_{vu}=1/|N(v)|$
  • $a$ can be a simple single-layer neural network
  • In attention terminology, $K$ is key, $Q$ is query, and $V$ is value.
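A minimal NumPy sketch of a single attention head, taking $a$ to be one linear layer over the concatenated query and key (all names, shapes, and the toy graph are illustrative assumptions):

```python
import numpy as np

def gat_update(v, neighbors, H_prev, W_Q, W_K, W_V, a):
    """One GAT head: softmax-normalized attention over neighbors, then a weighted sum."""
    nbrs = neighbors[v]
    q = W_Q @ H_prev[v]
    # attention coefficients e_vu = a(W_Q h_v, W_K h_u); here a is a single linear layer
    e = np.array([a @ np.concatenate([q, W_K @ H_prev[u]]) for u in nbrs])
    alpha = np.exp(e - e.max())
    alpha = alpha / alpha.sum()                                # softmax over the neighborhood
    msgs = np.stack([W_V @ H_prev[u] for u in nbrs])
    return np.maximum(0, (alpha[:, None] * msgs).sum(axis=0))  # sigma(sum_u alpha_vu W_V h_u)

# Toy example: 3 nodes, 2-dimensional embeddings
neighbors = {0: [1, 2], 1: [0, 2], 2: [0, 1]}
H0 = np.random.randn(3, 2)
W_Q = W_K = W_V = np.eye(2)
a = np.random.randn(4)          # scores the 4-dimensional concat of query and key
h0_new = gat_update(0, neighbors, H0, W_Q, W_K, W_V, a)
```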

Multi-head attention

  • Create multiple attention scores (one set per head)
    • $h_v^{(l)}[1] = \sigma(\sum_{u\in N(v)}\textcolor{red}{\alpha_{vu}^1}W^{(l)}h_u^{(l-1)})$
    • $h_v^{(l)}[2] = \sigma(\sum_{u\in N(v)}\textcolor{green}{\alpha_{vu}^2}W^{(l)}h_u^{(l-1)})$
    • $h_v^{(l)}[3] = \sigma(\sum_{u\in N(v)}\textcolor{blue}{\alpha_{vu}^3}W^{(l)}h_u^{(l-1)})$
  • Aggregate (e.g. concatenation)
    • $h_v^{(l)} = \text{AGG}(h_v^{(l)}[1],h_v^{(l)}[2],h_v^{(l)}[3])$

Benefits of attention mechanism

  • Computationally efficient
  • Storage efficient
  • Localized
  • Inductive capability

GNN Layer in Practice

Many deep learning techniques can be used here; I just mention a few of them:

  • Batch Normalization
    • Stabilize neural network training
  • Dropout
    • Prevent over-fitting
  • Attention/Gating
    • Control the importance of a message
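As a rough illustration of where these pieces usually sit inside a single layer, here is a NumPy sketch; the ordering (linear → batch norm → dropout → ReLU) and all hyper-parameters are common practice, not something prescribed by the lecture:

```python
import numpy as np

def gnn_layer_in_practice(H_agg, W, gamma, beta, p_drop=0.5, training=True, eps=1e-5):
    """Aggregated messages -> linear -> batch norm -> dropout -> ReLU."""
    H = H_agg @ W.T                                    # linear transformation
    mu, var = H.mean(axis=0), H.var(axis=0)            # batch norm over the node dimension
    H = gamma * (H - mu) / np.sqrt(var + eps) + beta
    if training:                                       # dropout only at training time
        mask = (np.random.rand(*H.shape) > p_drop) / (1 - p_drop)
        H = H * mask
    return np.maximum(0, H)                            # ReLU non-linearity

H_agg = np.random.randn(5, 8)                          # 5 nodes, 8-dim aggregated messages
W = np.random.randn(8, 8)
gamma, beta = np.ones(8), np.zeros(8)
H_out = gnn_layer_in_practice(H_agg, W, gamma, beta)
```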

Stacking Multiple Layers of GNNs

Over-smoothing problem

Link

After stacking multiple GNN layers sequentially, all node embeddings converge to similar values, which makes the nodes hard to differentiate. This happens because of receptive-field overlap in the deeper layers.

Solutions

  • Be cautious when adding layers: link
    • Make the aggregation/transformation itself a deep neural network. link
    • Add layers that don’t pass messages (MLP layers for pre-processing and post-processing). link
  • Add skip connections in GNNs: link
    • How to apply: link
    • Other methods of skip connection: link
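A minimal sketch of the residual-style skip connection, where the layer output is added to its input so that deeper stacks can fall back to shallower behavior (the plain GCN layer used here is just a placeholder):

```python
import numpy as np

def gcn_layer(A, H, W, eps=1e-9):
    """Plain mean-aggregation GCN layer, as in the matrix form earlier."""
    A_tilde = A / (A.sum(axis=1, keepdims=True) + eps)
    return np.maximum(0, A_tilde @ H @ W.T)

def gcn_layer_with_skip(A, H, W):
    """Skip connection: output = F(H) + H."""
    return gcn_layer(A, H, W) + H

A = np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]], dtype=float)
H = np.random.randn(3, 4)
W = np.random.randn(4, 4)       # square, so F(H) and H have matching shapes
H_next = gcn_layer_with_skip(A, H, W)
```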

Graph Augmentation for GNNs

Our assumption so far

Raw input graph = Computation graph

Reasons for breaking this assumption

  • Features
    • The input graph lacks features
      • $\rightarrow$ Feature augmentation
  • Graph structure
    • The graph is too sparse
      • $\rightarrow$ Add virtual nodes / edges
    • The graph is too dense
      • $\rightarrow$ Sample neighbors when doing message passing
    • The graph is too large
      • $\rightarrow$ Sample subgraphs to compute embedding

Feature augmentation

Used when the graph does not have node features (common when we only have the adjacency matrix)

Standard approaches

  • Assign a constant value to all nodes
  • Assign a unique ID to each node
    • IDs should be converted to one-hot vectors
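A minimal sketch of the two options above (purely illustrative):

```python
import numpy as np

num_nodes = 5

# Option 1: constant value -- every node gets the same feature, e.g. [1]
X_const = np.ones((num_nodes, 1))

# Option 2: unique IDs converted to one-hot vectors -- row v encodes node v's ID
X_onehot = np.eye(num_nodes)
```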

Comparison

constant vs one-hot node features

⚠️ Some structures are hard for GNNs to learn

Example: Cycle count feature

cycle count feature

A GNN can’t learn the length of the cycle that $v_1$ resides in.

Therefore, we should manually augment the node features with a vector of cycle counts, like this:

$$ \text{cycle count feature: } [0, 0, 0, 1, 0, 0] $$

where index 0 indicates whether $v_1$ resides in a cycle of length $0$, index 1 a cycle of length $1$, and so on.

Commonly used augmented features

  • Node Degree
  • PageRank
  • Clustering coefficient
  • Node centrality
    • Eigenvector
    • Betweenness
    • Closeness
  • Graphlet

Add virtual nodes / edges

Virtual edge

Common approach: connect 2-hop neighbors

  • Instead of using the adjacency matrix $A$ for GNN computation, use $A^2 + A$ (a small sketch follows at the end of this subsection)

Use cases:

  • Bipartite graphs
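A minimal sketch of adding 2-hop virtual edges to a small bipartite graph (the thresholding and removal of self-loops are my own choices; the lecture only says to use $A^2 + A$):

```python
import numpy as np

# Adjacency matrix of a toy bipartite graph: authors {0, 1} vs. papers {2, 3}
A = np.array([[0, 0, 1, 1],
              [0, 0, 0, 1],
              [1, 0, 0, 0],
              [1, 1, 0, 0]], dtype=float)

A2 = (A @ A > 0).astype(float)         # 2-hop connectivity (e.g. author-author co-authorship)
np.fill_diagonal(A2, 0)                # drop the self-loops introduced by A @ A
A_aug = ((A + A2) > 0).astype(float)   # use A^2 + A instead of A for message passing
```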

Virtual node

Add a virtual node and connect it to other nodes (all or some of them)

Benefits:

  • Greatly improves message passing in sparse graphs

Node neighborhood sampling

Instead of using all of a node’s neighbors for message passing, (randomly) sample its neighborhood. For example, if a node has 5 neighbors, sample 2 of them in the message-passing phase (see the sketch below).

Benefits:

  • Reduce computational cost
    • Allows for scaling to large graphs
  • In practice it works great 👌
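A minimal sketch of sampling a fixed number of neighbors per message-passing step (the sample size of 2 matches the example above; everything else is illustrative):

```python
import random

def sampled_neighbors(neighbors, v, k=2, rng=random):
    """Return at most k randomly chosen neighbors of v for this message-passing step."""
    nbrs = neighbors[v]
    return nbrs if len(nbrs) <= k else rng.sample(nbrs, k)

neighbors = {0: [1, 2, 3, 4, 5], 1: [0], 2: [0], 3: [0], 4: [0], 5: [0]}
print(sampled_neighbors(neighbors, 0))   # e.g. [3, 1] -- a different subset each time
```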

Prediction with GNNs

GNN training pipeline

GNN training pipeline

GNN output

The GNN output is a set of node embeddings:

$$ \lbrace h_v^{(L)}, \forall v \in G \rbrace $$

Prediction head

  • Node-level tasks
  • Edge-level tasks
  • Graph-level tasks

Node-level

We can directly use the node embeddings or transform them into the label space ($y_v$ is the ground-truth label and $\widehat{y}_v$ is the model output):

$$ \textcolor{Chartreuse}{\widehat{y}_v = h_v^{(L)}} $$

$$ \text{or} $$

$$ \textcolor{Chartreuse}{\widehat{y}_v = Head_{node}(h_v^{(L)}) = W^{(H)}h_v^{(L)}} $$

Edge-level

$$ \color{Chartreuse} \widehat{y}_{uv} = Head_{edge}(h_u^{(L)},h_v^{(L)}) $$

Options for $Head_{edge}$:

  • Concatenation + Linear: $\text{Linear}(\text{Concat}(h_u^{(L)},h_v^{(L)}))$
    • Similar to graph attention
  • Dot product: $(h_u^{(L)})^T h_v^{(L)}$
    • This approach only applies to 1-way prediction
    • Applying to $k$-way prediction
      • Similar to multi-head attention: $W^{(1)}, \dots, W^{(K)}$ are trainable

$$ \widehat{y}_{uv}^{(1)} = (h_u^{(L)})^T \textcolor{red}{W^{(1)}}h_v^{(L)} $$

$$ \dots $$

$$ \widehat{y}_{uv}^{(K)} = (h_u^{(L)})^T \textcolor{red}{W^{(K)}}h_v^{(L)} $$

$$ \widehat{y}_{uv} = \text{Concat}(\widehat{y}_{uv}^{(1)}, \dots, \widehat{y}_{uv}^{(K)}) \in \mathbb{R}^K $$
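A minimal NumPy sketch of the two edge-level heads (the dot product for 1-way prediction and the $k$-way version with one trainable $W^{(k)}$ per output; shapes are illustrative):

```python
import numpy as np

def edge_head_dot(h_u, h_v):
    """1-way prediction: plain dot product of the two node embeddings."""
    return h_u @ h_v

def edge_head_kway(h_u, h_v, Ws):
    """k-way prediction: y_uv^(k) = h_u^T W^(k) h_v, one W per output dimension."""
    return np.array([h_u @ W @ h_v for W in Ws])

d, K = 4, 3
h_u, h_v = np.random.randn(d), np.random.randn(d)
Ws = [np.random.randn(d, d) for _ in range(K)]
y_uv = edge_head_kway(h_u, h_v, Ws)      # shape (K,)
```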

Graph-level

$$ \color{Chartreuse} \widehat{y}_G = Head_{graph}(\lbrace h_v^{(L)} \in \mathbb{R}^d, \forall v \in G \rbrace) $$

Options for $Head_{graph}$:

  • Global mean pooling
  • Global max pooling
  • Global sum pooling
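A minimal sketch of the three options (illustrative only):

```python
import numpy as np

H = np.random.randn(6, 4)      # final-layer embeddings h_v^(L) for a toy 6-node graph

h_G_mean = H.mean(axis=0)      # global mean pooling
h_G_max = H.max(axis=0)        # global max pooling (coordinate-wise)
h_G_sum = H.sum(axis=0)        # global sum pooling
```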

These global pooling methods work great on small graphs, but:

⚠️ Global pooling over a (large) graph will lose information

Hierarchical global pooling

Hierarchically pool sections (clusters) of nodes.

Which sections?

The DiffPool idea:

Diffpool

  • Leverage 2 independent GNNs at each level
    • GNN A: Compute node embeddings
    • GNN B: Compute the cluster that a node belongs to
  • GNNs A and B at each level can be executed in parallel
  • GNNs A and B are trained jointly
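A minimal NumPy sketch of one DiffPool coarsening level under my reading of the idea: GNN A produces node embeddings $Z$, GNN B produces a soft cluster assignment $S$, and the pooled graph is $X' = S^\top Z$, $A' = S^\top A S$ (the one-layer placeholder GNNs and all shapes are assumptions):

```python
import numpy as np

def simple_gnn(A, X, W, eps=1e-9):
    """Placeholder one-layer GNN: mean aggregation + linear transform + ReLU."""
    A_tilde = A / (A.sum(axis=1, keepdims=True) + eps)
    return np.maximum(0, A_tilde @ X @ W)

def softmax_rows(X):
    E = np.exp(X - X.max(axis=1, keepdims=True))
    return E / E.sum(axis=1, keepdims=True)

def diffpool_level(A, X, W_embed, W_assign):
    Z = simple_gnn(A, X, W_embed)                 # GNN A: node embeddings
    S = softmax_rows(simple_gnn(A, X, W_assign))  # GNN B: soft cluster assignment
    X_coarse = S.T @ Z                            # pooled cluster features
    A_coarse = S.T @ A @ S                        # pooled (weighted) cluster adjacency
    return A_coarse, X_coarse

A = np.array([[0, 1, 1, 0], [1, 0, 1, 0], [1, 1, 0, 1], [0, 0, 1, 0]], dtype=float)
X = np.random.randn(4, 3)
A_c, X_c = diffpool_level(A, X, np.random.randn(3, 3), np.random.randn(3, 2))  # 4 nodes -> 2 clusters
```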

Training GNNs

Setting-up GNN Prediction Tasks

When Things Don’t Go As Planned

The three topics above are handled almost the same as in other deep learning areas, so I will put them aside for now.