Introduction to TensorFlow GNN
For my first post in this blog, I’ve decided to write a beginner-level introduction to Graph Neural Networks with TensorFlow-GNN. An interactive version of this post, where you can run and modify the code, is available as a Kaggle notebook and also on GoogleColab. Please let me know if you find any bugs or errors, and also if you thought this was useful or you’d like more details about some particular point.
[July 1, 2022 UPDATE] The Kaggle notebook version of this post has been awarded the Google Open Source Expert Prize for the month of July.
[July 5, 2022 UPDATE] I’ve now updated this notebook to work with the latest TF-GNN release, v0.2.0; some typos were also corrected, and more resources have been added to §5.1.
Contents
- 1. Introduction
- 2. Data preparation
- 3. Vanilla MPNN models
- 4. Graph binary classification
- 5. Conclusions
1. Introduction
1.1 Graphs and Graph Neural Networks
A graph consists of a collection \(\mathcal{V}=\{1,\dots,n\}\) of nodes, sometimes also called vertices, and a set of edges \(\mathcal{E} \subseteq \mathcal{V} \times \mathcal{V}\). An edge \((i,j) \in \mathcal{E}\), also denoted \(i\to j\), represents a relation from node \(i\) to node \(j\) which in the general case is directed (i.e. does not necessarily imply the converse relation \(j\to i\)). Heterogeneous graphs can have different types of nodes and/or edges, and decorated graphs may have features associated to their nodes, \(h_i \in \mathbb{R}^{d_\mathcal{V}}\) for \(i=1,\dots,n\), and/or their edges, \(h_{i\to j} \in \mathbb{R}^{d_\mathcal{E}}\) for \((i,j)\in\mathcal{E}\).
Graphs are clearly a very general and powerful concept, which can be used to represent many different kinds of data. For example:
- A social network may be thought of as a graph where the nodes correspond to different users, and the edges correspond to their relationships (i.e. “friendship”, “follower”, etc).
- A country may be represented as a graph where cities are considered to be nodes, with roads connecting them playing the role of edges.
- …
Graph Neural Networks (GNN) are one way in which we can apply neural networks to graph-structured data in order to learn to make predictions about it. Their architecture typically involves stacking message passing layers, each of which updates the features of each node \(i\) in the graph by applying some function to the features of its neighbors, i.e. all the nodes \(j\) which have edges going from \(j\) to \(i\). More formally, a message passing layer implements the transformation
\[h_i' = f_{\theta_\mathcal{V}}\left(h_i, \tilde{\sum}_{(j,i) \in \mathcal{E}} g_{\theta_\mathcal{E}}\left(h_j, h_{j\to i}\right)\right) \qquad\text{for}\qquad i = 1, \dots, n,\]where \(f_{\theta_\mathcal{V}}\) is a neural network applied at each node with parameters \(\theta_\mathcal{V}\), and \(g_{\theta_\mathcal{E}}\) is a neural network applied at each edge with parameters \(\theta_\mathcal{E}\). Note that the latter may be trivial if edges do not have features. In the formula above \(\tilde\sum\) in fact represents any orderless aggregation function, which can be a conventional summation, an averaging operation, etc. The key point is that, being orderless, a GNN constructed in this way will be invariant under the permutation of the node labels \(i=1,\dots,n\), a property which of course should hold since the labels themselves are generally speaking arbitrary: the information is encoded in the graph structure itself, not the particular labels chosen to represent it.
1.2 TensorFlow GNN
TensorFlow GNN (TF-GNN) is a fairly recent addition to the TensorFlow ecosystem, having been first released in late 2021. At this time there are relatively few resources available explaining how to use TF-GNN in practice (see §5.1 for more); this notebook attempts to start filling the gap by providing a basic example of the application of TF-GNN to a real-world problem, namely the classification of molecules to detect cardiotoxicity.
TF-GNN is very powerful and flexible, and can deal with large, heterogeneous graphs. We will however keep things simple here, and work instead with rather small, homogeneous graphs having few features. The main focus will be on showcasing the basic concepts and design abstractions introduced in TF-GNN to facilitate and accelerate end-to-end training of GNN models, particularly:
-
GraphTensor
objects, encapsulating graph structure and features in a TF-compatible format. - Pre-built GraphUpdate layers implementing some of the message-passing protocols most commonly used in the literature.
- The orchestrator and related protocols, which can be used to skip most of the technical boilerplate that would otherwise be required for end-to-end training.
1.3 Graph classification
A prototypical example of application of GNNs are graph classification problems, whereupon we are tasked to predict the class of a given graph among a finite set of possible options. In the supervised setting, we have a set of labeled graphs from which we can learn the parameters \(\{\theta_\mathcal{V}, \theta_{\mathcal{E}}\}\) of the GNN by minimizing a categorical crossentropy loss for our predictions. These are obtained from the GNN after applying one final aggregation or pooling operation over all nodes, which again should be orderless to preserve the important property that the output is invariant under relabeling of the nodes.
In this notebook we will undertake the classification of molecules which have been labeled as being either toxic or non-toxic in the CardioTox dataset introduced by:
[1] K. Han, B. Lakshminarayanan and J. Liu, “Reliable Graph Neural Networks for Drug Discovery Under Distributional Shift,” (arxiv:2111.12951)
We will take the nodes of our graphs to be the atoms in a molecule, with edges representing atomic bonds between them. Thus, our problem is a binary classification one (because we have just two classes) on homogeneous graphs (because we have only one type of nodes and one type of edges) which are both node and edge decorated (because, as will be seen later, the data contains features for atoms and for the bonds between them).
The authors of [1] introduce various clever architectural choices to improve the performance of their model and reduce the risk of false-negative predictions, which are particularly undesirable in a toxicity-detection problem. For illustration purposes here we will not attempt to reproduce their results in full generality, but instead implement a very simple baseline based on Graph Attention Networks introduced in:
[2] P. Veličković, G. Cucurull, A. Casanova, A. Romero, P. Liò and Y. Bengio “Graph Attention Networks,” (arxiv:1710.10903)
[3] S. Brody, U. Alon and E. Yahav, “How Attentive are Graph Attention Networks?,” (arxiv:2105.14491)
The interested reader can later try different improvements to the code in this notebook, and thanks to TF-GNN these will be easy to introduce once the training pipeline has been established.
1.4 Dependencies and imports
Since TF-GNN is currently in an early alpha release stage, installation may have a few hiccups. I have found that the following combination works well in the Kaggle notebook environment:
For GoogleColab, you can instead use:
Either way, once everything is installed we can proceed:
Using TensorFlow v2.8.0 and TensorFlow-GNN v0.2.0 GPUs available: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
2. Data preparation
Unlike other types of data, there is no standard encoding for graphs. Indeed, depending on the intended application the graph structure can be represented by:
- An adjacency matrix \(A_{ij}\) specifying whether the edge \(i\to j\) is present or not in \(\mathcal{E}\).
- An adjacency list \(N_i\) for each node \(i\in\mathcal{V}\), specifying all the nodes \(j\in\mathcal{V}\) for which an edge \((i, j)\in\mathcal{E}\)
- A list of edges \((i,j)\in\mathcal{E}\)
To address this issue TF-GNN introduces the GraphTensor
object, which encapsulates both the graph structure and the features of the nodes, edges and the graph itself. These objects follow a graph schema, which specifies the types of nodes and edges as well as all the features that should be present in the graphs. The first step in any TF-GNN training pipeline should therefore be to convert the input data from whichever format is given into GraphTensor
format. These GraphTensor
objects can then be batched and consumed by our TF-GNN models just like a tf.Tensor
, hugely simplifying our lives in the process.
The goal of this section is to perform the task described above, resulting in DatasetProvider
objects for our data (see here and the following section for a description of the DatasetProvider
protocol). Let us then take a look at the CardioTox dataset, which we can automatically download from TensorFlow Datasets (TF-DS):
Drug Cardiotoxicity dataset [1-2] is a molecule classification task to detect cardiotoxicity caused by binding hERG target, a protein associated with heart beat rhythm. The data covers over 9000 molecules with hERG activity. Note: 1. The data is split into four splits: train, test-iid, test-ood1, test-ood2. 2. Each molecule in the dataset has 2D graph annotations which is designed to facilitate graph neural network modeling. Nodes are the atoms of the molecule and edges are the bonds. Each atom is represented as a vector encoding basic atom information such as atom type. Similar logic applies to bonds. 3. We include Tanimoto fingerprint distance (to training data) for each molecule in the test sets to facilitate research on distributional shift in graph domain. For each example, the features include: atoms: a 2D tensor with shape (60, 27) storing node features. Molecules with less than 60 atoms are padded with zeros. Each atom has 27 atom features. pairs: a 3D tensor with shape (60, 60, 12) storing edge features. Each edge has 12 edge features. atom_mask: a 1D tensor with shape (60, ) storing node masks. 1 indicates the corresponding atom is real, othewise a padded one. pair_mask: a 2D tensor with shape (60, 60) storing edge masks. 1 indicates the corresponding edge is real, othewise a padded one. active: a one-hot vector indicating if the molecule is toxic or not. [0, 1] indicates it's toxic, otherwise [1, 0] non-toxic. ## References [1]: V. B. Siramshetty et al. Critical Assessment of Artificial Intelligence Methods for Prediction of hERG Channel Inhibition in the Big Data Era. JCIM, 2020. https://pubs.acs.org/doi/10.1021/acs.jcim.0c00884 [2]: K. Han et al. Reliable Graph Neural Networks for Drug Discovery Under Distributional Shift. NeurIPS DistShift Workshop 2021. https://arxiv.org/abs/2111.12951
As mentioned before, we will have only one set of nodes, which we call 'atom'
, and one set of edges, which we call 'bond'
. Of course, all edges have 'atom'
-type nodes at both endpoints. Both nodes and edges have a single feature vector, the former being a 27-dimensional 'atom_features'
vector, and the latter being a 12-dimensional 'bond_features'
vector. Moreover, the graphs themselves have global features giving their context, in this case a toxicity class 'toxicity'
, which is in fact the label we want to predict, and a molecule id 'molecule_id'
which we will mostly ignore.
All of the above can be encoded in the following graph schema specifying the structure and contents of our graphs:
This schema is a textual protobuf, which we can parse to obtain a GraphTensorSpec
(think TensorSpec
for GraphTensor
objects):
We should then convert the input dataset into GraphTensor
objects complying to graph_spec
, which we will do with the following helper function:
We can now map this function over the datasets to have them stream GraphTensor
objects:
GraphTensor( context=Context(features={'toxicity': <tf.Tensor: shape=(1,), dtype=tf.int64>, 'molecule_id': <tf.Tensor: shape=(1,), dtype=tf.string>}, sizes=[1], shape=(), indices_dtype=tf.int32), node_set_names=['atom'], edge_set_names=['bond'])
And check that the GraphTensor
thus produced are compatible with the GraphTensorSpec
we defined before:
True
However, to avoid processing the data multiple times, which would slow down all of our input pipeline, it is convenient to first dump all the data into TFRecord files. Later on we can easily load these instead of the original TF-DS datasets over which we mapped the make_graph_tensor
function.
NOTE: The create_tfrecords
method below works nicely, is rather general and could be immediately reused for other small-scale applications. For large-scale datasets, however, alternative approaches using tf.data.Dataset.cache
or tf.data.Dataset.snapshot
would be preferable, as they would allow for more optimizations such as e.g. sharding.
creating data/cardiotox-train.tfrecord... 100%|██████████| 6523/6523 [00:46<00:00, 140.23it/s] creating data/cardiotox-validation.tfrecord... 100%|██████████| 1631/1631 [00:11<00:00, 138.38it/s] creating data/cardiotox-test.tfrecord... 100%|██████████| 839/839 [00:05<00:00, 142.62it/s] creating data/cardiotox-test2.tfrecord... 100%|██████████| 177/177 [00:01<00:00, 140.72it/s]
Finally, we can use the TFRecordDatasetProvider
class to create DatasetProvider
-compliant objects that read these TFRecord files and provide tf.data.Dataset
objects for us to use through their get_dataset
method:
2.1 The DatasetProvider protocol
Each of the DatasetProvider
defined above conventionally produces a dataset of serialized GraphTensor
objects, which we need to parse before we can inspect. This is mentioned here for reference purposes only: the orchestrator will transparently deal with this during actual training.
To obtain the dataset we need to provide an input context:
We then map tfgnn.parse_single_example
over this dataset, specifying the appropriate GraphTensorSpec
for our graphs:
And we can then stream GraphTensor
objects as before
GraphTensor( context=Context(features={'toxicity': <tf.Tensor: shape=(1,), dtype=tf.int64>, 'molecule_id': <tf.Tensor: shape=(1,), dtype=tf.string>}, sizes=[1], shape=(), indices_dtype=tf.int32), node_set_names=['atom'], edge_set_names=['bond'])
2.2 Data inspection
The node and edge features are not particularly illustrative, but we can nevertheless access them directly if necessary. First, note that this particular molecule has the following number \(V = \left\vert\mathcal{V}\right\vert\) of atoms:
<tf.Tensor: shape=(1,), dtype=int32, numpy=array([33], dtype=int32)>
Their features are collected in a tensor of shape (V, 27)
, which we can access like so:
<tf.Tensor: shape=(33, 27), dtype=float32, numpy= array([[0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.], [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0.], [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0.], [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.], [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.], [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.], [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.], [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.], [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.], [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.], [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.], [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.], [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.], [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.], [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.], [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.], [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.], [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.], [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.], [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.], [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.], [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.], [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.], [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0.], [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0.], [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 2., 0., 0., 0., 1., 0., 0., 0., 1.], [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1.], [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1.], [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1.], [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1.], [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 2., 0., 0., 0., 1., 0., 0., 0., 1.], [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0.], [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0.]], dtype=float32)>
Similarly, the number \(E = \left\vert\mathcal{E}\right\vert\) of bonds in the molecule is:
<tf.Tensor: shape=(1,), dtype=int32, numpy=array([68], dtype=int32)>
Their features are collected in a tensor of shape (E, 12)
, which we can access like so:
<tf.Tensor: shape=(68, 12), dtype=float32, numpy= array([[1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.], [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.], [0., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1.], [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1.], [0., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1.], [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.], [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1.], [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.], [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.], [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.], [0., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.], [0., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.], [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.], [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.], [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.], [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.], [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.], [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.], [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.], [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.], [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.], [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.], [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.], [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.], [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.], [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.], [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.], [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.], [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.], [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.], [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.], [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.], [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.], [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.], [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.], [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.], [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.], [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.], [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.], [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.], [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.], [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.], [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.], [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.], [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.], [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.], [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1.], [0., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.], [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1.], [0., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.], [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1.], [0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1.], [0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1.], [0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1.], [0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1.], [0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1.], [0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1.], [0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1.], [0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1.], [0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1.], [0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1.], [0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1.], [0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1.], [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1.], [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1.], [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1.], [0., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.], [0., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.]], dtype=float32)>
The ids of the edge endpoints are then stored in a couple of tensors of shape (E,)
<tf.Tensor: shape=(68,), dtype=int32, numpy= array([ 0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 5, 5, 5, 6, 7, 7, 8, 8, 9, 9, 10, 10, 10, 11, 12, 12, 13, 13, 14, 14, 15, 15, 15, 16, 17, 17, 18, 18, 19, 19, 20, 20, 20, 21, 22, 23, 23, 23, 24, 25, 25, 25, 26, 26, 27, 27, 28, 28, 29, 29, 30, 30, 30, 31, 31, 31, 32], dtype=int32)>
<tf.Tensor: shape=(68,), dtype=int32, numpy= array([ 1, 0, 2, 31, 1, 3, 23, 2, 4, 3, 5, 4, 6, 7, 5, 5, 8, 7, 9, 8, 10, 9, 11, 12, 10, 10, 13, 12, 14, 13, 15, 14, 16, 17, 15, 15, 18, 17, 19, 18, 20, 19, 21, 22, 20, 20, 2, 24, 25, 23, 23, 26, 30, 25, 27, 26, 28, 27, 29, 28, 30, 25, 29, 31, 1, 30, 32, 31], dtype=int32)>
Finally, global information about the graph is provided by its context
<tf.Tensor: shape=(1,), dtype=int64, numpy=array([0])>
<tf.Tensor: shape=(1,), dtype=string, numpy= array([b'CC1=C(C/C=C(\\C)CCC[C@H](C)CCC[C@H](C)CCCC(C)C)C(=O)c2ccccc2C1=O'], dtype=object)>
With all of this we can write the following helper function to visualize the graphs:
2.3 GraphTensor batching
GraphTensor
datasets can be batched as usual, resulting in new datasets that produce higher-rank GraphTensor
objects:
1
The resulting GraphTensor
now contains features in the form of tf.RaggedTensor
, since different graphs can have different numbers of nodes and edges:
TensorShape([64, None, 27])
TensorShape([64, None, 12])
where the shapes now correspond to (batch_size, V, 27)
and (batch_size, E, 12)
.
However, all layers in TF-GNN expect scalar graphs as their inputs, so before actually using a batch of graphs we should always “merge” the different graphs in the batch into a single graph with multiple disconnected components (of which TF-GNN automatically keeps track):
0
Now the atom features again have shape (V', 27)
, where \(V' = \sum_{k=1}^{\rm batch\_size} V_k\)
TensorShape([1562, 27])
And the bond fetures have shape \((E', 12)\) with \(E' = \sum_{k=1}^{\rm batch\_size} E_k\)
TensorShape([3370, 12])
We should note, however, that once more the orchestrator will transparently take care of batching and merging components for us, so that we needn’t worry about this as long as we don’t customize the training routines.
3. Vanilla MPNN models
A common architecture for GNNs consists of an initial layer which preprocesses the graph features, typically producing hidden states for nodes and/or edges, which is followed by one or more layers of message-passing working as described in the Introduction. The goal of this section is to define a vanilla_mpnn_model
function which can be used to create such simple GNNs from:
- an initial layer performing the pre-processing
- a layer stacking multiple message-passing layers
3.1 Initial graph embedding
For the first of these tasks we will use a tfgnn.keras.layers.MapFeatures
layer to create hidden state vectors for atoms and bonds from their respective features, by passing these through a dense layer. The resulting hidden states will have dimension hidden_size
, corresponding in the notation of the Introduction to \(d_\mathcal{V}\) and \(d_\mathcal{E}\).
The following helper function will create an initial MapFeatures
layer for the given hyperparameters:
-
hidden_size
: the hidden dimensions \(d_\mathcal{V}\) and \(d_\mathcal{E}\) -
activation
: the activation for the dense layers
We can check that the resulting layer replaces the 'atom_features'
and 'bond_features'
with hidden states of the specified dimensions
{'hidden_state': <tf.Tensor: shape=(1562, 128), dtype=float32, numpy= array([[0.15272579, 0. , 0. , ..., 0. , 0. , 0.00244589], [0.20299977, 0.15906705, 0. , ..., 0.2211346 , 0. , 0.23549727], [0.20299977, 0.15906705, 0. , ..., 0.2211346 , 0. , 0.23549727], ..., [0.3452986 , 0. , 0.01703222, ..., 0. , 0. , 0. ], [0.23036027, 0. , 0. , ..., 0. , 0. , 0. ], [0.595324 , 0. , 0.3375612 , ..., 0.23380664, 0. , 0.11104017]], dtype=float32)>}
{'hidden_state': <tf.Tensor: shape=(3370, 128), dtype=float32, numpy= array([[0. , 0. , 0. , ..., 0. , 0. , 0. ], [0. , 0. , 0. , ..., 0. , 0. , 0. ], [0. , 0.19551635, 0.05461648, ..., 0. , 0.06973472, 0. ], ..., [0.01244138, 0. , 0. , ..., 0. , 0.02111641, 0. ], [0.01244138, 0. , 0. , ..., 0. , 0.02111641, 0. ], [0.01244138, 0. , 0. , ..., 0. , 0.02111641, 0. ]], dtype=float32)>}
Note that both the atom and bond features are now named 'hidden_state'
; we could of course have chosen a different name, but leaving the default tfgnn.HIDDEN_STATE
will save us from having to specify feature names in what follows.
3.2 Stacking message-passing layers
To illustrate how to build a stack of message-passing layers, we will use the pre-built Graph Attention (GAT) [2] layers provided in the models.gat_v2
module. We then define a message-passing neural network (MPNN) layer successively applying these layers with hyperparameters:
-
hidden_size
: the hidden dimensions \(d_\mathcal{V}\) and \(d_\mathcal{E}\) -
hops
: the number of layers in the stack
We can now check that this layer processes the embedded graphs that come out from the initial feature map:
{'hidden_state': <tf.Tensor: shape=(1562, 128), dtype=float32, numpy= array([[0. , 0. , 0.4671823 , ..., 0.18130356, 0. , 0. ], [0. , 0. , 0.4360276 , ..., 0.20995092, 0. , 0. ], [0. , 0. , 0.4278612 , ..., 0.20583078, 0. , 0. ], ..., [0. , 0.00174278, 0.51690316, ..., 0.20055819, 0. , 0. ], [0. , 0.01096402, 0.54201955, ..., 0.20254537, 0. , 0. ], [0. , 0.0008509 , 0.5069808 , ..., 0.19695306, 0. , 0. ]], dtype=float32)>}
{'hidden_state': <tf.Tensor: shape=(3370, 128), dtype=float32, numpy= array([[0. , 0. , 0. , ..., 0. , 0. , 0. ], [0. , 0. , 0. , ..., 0. , 0. , 0. ], [0. , 0.19551635, 0.05461648, ..., 0. , 0.06973472, 0. ], ..., [0.01244138, 0. , 0. , ..., 0. , 0.02111641, 0. ], [0.01244138, 0. , 0. , ..., 0. , 0.02111641, 0. ], [0.01244138, 0. , 0. , ..., 0. , 0.02111641, 0. ]], dtype=float32)>}
3.3 Model construction
We are now ready to combine both ingredients into a tf.keras.Model
that takes a GraphTensor
representing a molecule as input, and produces as its output another GraphTensor
with hidden states for all the atoms. We use Keras’ functional API to define a vanilla_mpnn_model
helper function returning the desired tf.keras.Model
:
Model: "model_2" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input_3 (InputLayer) [()] 0 graph_embedding (MapFeature () 5248 s) gat_mpnn (MPNN) () 396288 ================================================================= Total params: 401,536 Trainable params: 401,536 Non-trainable params: 0 _________________________________________________________________
For later convenience, let us encapsulate all of this logic in a function we can use to get model constructors for fixed hyperparameters. The returned constructor by convention takes only the GraphTensorSpec
of the input graphs for the model, and for good measure our constructor will also add some \(L_2\) regularization through an l2_coefficient
hyperparameter:
Model: "model_3" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input_4 (InputLayer) [()] 0 graph_embedding (MapFeature () 5248 s) gat_mpnn (MPNN) () 396288 ================================================================= Total params: 401,536 Trainable params: 401,536 Non-trainable params: 0 _________________________________________________________________
4. Graph binary classification
4.1 Task specification
Having a GNN model at our disposal, we are now ready to apply it to the task at hand, namely the binary classification of molecules predicting their toxicity. This involves:
- Adding a readout and prediction head, which computes logits for each class from the features computed by the GNN.
- Defining the loss function to be minimized, which in this case should be a categorical crossentropy loss.
- Defining the metrics we are interested in measuring during training and validation.
The orchestrator defines the Task
protocol to achieve these goals, and conveniently provides a pre-implemented GraphBinaryClassification
class which complies with this protocol. While we could use it as is, for illustration purposes here we will extend its basic implementation in two ways:
- We will include the AUROC metric, which is reported by the authors of [1] (other metrics reported there can be added in a similar fashion).
- We will generalize the readout and prediction head to include a hidden layer.
First, we define a simple wrapper around the tf.keras.metrics.AUC
class to adapt it to our conventions:
Next, we subclass the GraphBinaryClassification
task and override its adapt
and metrics
methods:
To create an instance of this class we need to specify the node set which will be used to aggregate hidden states for prediction (remember in our case there is only one, 'atom'
) and the number of classes (two, for toxic and non-toxic), as well as the new hyperparameter hidden_dim
:
This instance then provides everything we will need for training, namely:
- The loss function
(<keras.losses.SparseCategoricalCrossentropy at 0x7fe2bc861650>,)
- The metrics
(<keras.metrics.SparseCategoricalAccuracy at 0x7fe2bc795550>, <keras.metrics.SparseCategoricalCrossentropy at 0x7fe2bc861850>, <__main__.AUROC at 0x7fe2bc90a590>)
- An
adapt
method to place the readout and prediction head on top of the GNN
Model: "model_4" __________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to ================================================================================================== input_4 (InputLayer) [()] 0 [] graph_embedding (MapFeatures) () 5248 ['input_4[0][0]'] gat_mpnn (MPNN) () 396288 ['graph_embedding[0][0]'] input.node_sets (InstancePrope {'atom': ()} 0 ['gat_mpnn[0][0]'] rty) input.sizes (InstanceProperty) (1,) 0 ['input.node_sets[0][0]'] tf.math.cumsum (TFOpLambda) (1,) 0 ['input.sizes[0][0]'] tf.math.reduce_sum (TFOpLambda () 0 ['input.sizes[0][0]'] ) tf.ones_like (TFOpLambda) (1,) 0 ['tf.math.cumsum[0][0]'] tf.__operators__.add (TFOpLamb () 0 ['tf.math.reduce_sum[0][0]'] da) tf.math.unsorted_segment_sum ( (None,) 0 ['tf.ones_like[0][0]', TFOpLambda) 'tf.__operators__.add[0][0]', 'tf.math.cumsum[0][0]'] tf.math.cumsum_1 (TFOpLambda) (None,) 0 ['tf.math.unsorted_segment_sum[0] [0]'] input._get_features_ref_4 (Ins {'hidden_state': (N 0 ['input.node_sets[0][0]'] tanceProperty) one, 128)} tf.__operators__.getitem (Slic (None,) 0 ['tf.math.cumsum_1[0][0]', ingOpLambda) 'tf.math.reduce_sum[0][0]'] tf.math.unsorted_segment_mean (1, 128) 0 ['input._get_features_ref_4[0][0] (TFOpLambda) ', 'tf.__operators__.getitem[0][0]' ] hidden_layer (Dense) (1, 256) 33024 ['tf.math.unsorted_segment_mean[0 ][0]'] logits (Dense) (1, 2) 514 ['hidden_layer[0][0]'] ================================================================================================== Total params: 435,074 Trainable params: 435,074 Non-trainable params: 0 ________________________________________________________________________________________________
The resulting model then produces logits for each class, given a GraphTensor
as its input:
<tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[-0.05114917, -0.23557672]], dtype=float32)>
- A
preprocessors
method, which could be used to pre-process the graphs before reaching the GNN but will remain unused here
()
4.2 Training
We are now ready to train the model. First, we create a KerasTrainer
instance which implements the orchestrator’s Trainer
protocol making use of Keras’ fit
method:
Next, we define a simple function conforming to the GraphTensorProcessorFn
protocol, which extracts the labels from the GraphTensor
objects, for use during supervised training (this function will be mapped over the datasets that are then passed on to the tf.keras.Model.fit
method):
Lastly, we can put everything together and get a coffee while watching some progress bars move :-)
Epoch 1/50 /opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/indexed_slices.py:446: UserWarning: Converting sparse IndexedSlices(IndexedSlices(indices=Tensor("gradient_tape/model_7/gat_mpnn/message_passing_7/node_set_update_7/gat_v2_conv/Reshape_3:0", shape=(None,), dtype=int32), values=Tensor("gradient_tape/model_7/gat_mpnn/message_passing_7/node_set_update_7/gat_v2_conv/Reshape_2:0", shape=(None, 1, 1), dtype=float32), dense_shape=Tensor("gradient_tape/model_7/gat_mpnn/message_passing_7/node_set_update_7/gat_v2_conv/Cast:0", shape=(3,), dtype=int32))) to a dense Tensor of unknown shape. This may consume a large amount of memory. "shape. This may consume a large amount of memory." % value) /opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/indexed_slices.py:446: UserWarning: Converting sparse IndexedSlices(IndexedSlices(indices=Tensor("gradient_tape/model_7/gat_mpnn/message_passing_7/node_set_update_7/gat_v2_conv/Reshape_6:0", shape=(None,), dtype=int32), values=Tensor("gradient_tape/model_7/gat_mpnn/message_passing_7/node_set_update_7/gat_v2_conv/Reshape_5:0", shape=(None, 1, 1), dtype=float32), dense_shape=Tensor("gradient_tape/model_7/gat_mpnn/message_passing_7/node_set_update_7/gat_v2_conv/Cast_1:0", shape=(3,), dtype=int32))) to a dense Tensor of unknown shape. This may consume a large amount of memory. "shape. This may consume a large amount of memory." % value) /opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/indexed_slices.py:446: UserWarning: Converting sparse IndexedSlices(IndexedSlices(indices=Tensor("gradient_tape/model_7/gat_mpnn/message_passing_6/node_set_update_6/gat_v2_conv/Reshape_3:0", shape=(None,), dtype=int32), values=Tensor("gradient_tape/model_7/gat_mpnn/message_passing_6/node_set_update_6/gat_v2_conv/Reshape_2:0", shape=(None, 1, 1), dtype=float32), dense_shape=Tensor("gradient_tape/model_7/gat_mpnn/message_passing_6/node_set_update_6/gat_v2_conv/Cast:0", shape=(3,), dtype=int32))) to a dense Tensor of unknown shape. This may consume a large amount of memory. "shape. This may consume a large amount of memory." % value) /opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/indexed_slices.py:446: UserWarning: Converting sparse IndexedSlices(IndexedSlices(indices=Tensor("gradient_tape/model_7/gat_mpnn/message_passing_6/node_set_update_6/gat_v2_conv/Reshape_6:0", shape=(None,), dtype=int32), values=Tensor("gradient_tape/model_7/gat_mpnn/message_passing_6/node_set_update_6/gat_v2_conv/Reshape_5:0", shape=(None, 1, 1), dtype=float32), dense_shape=Tensor("gradient_tape/model_7/gat_mpnn/message_passing_6/node_set_update_6/gat_v2_conv/Cast_1:0", shape=(3,), dtype=int32))) to a dense Tensor of unknown shape. This may consume a large amount of memory. "shape. This may consume a large amount of memory." % value) /opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/indexed_slices.py:446: UserWarning: Converting sparse IndexedSlices(IndexedSlices(indices=Tensor("gradient_tape/model_7/gat_mpnn/message_passing_5/node_set_update_5/gat_v2_conv/Reshape_3:0", shape=(None,), dtype=int32), values=Tensor("gradient_tape/model_7/gat_mpnn/message_passing_5/node_set_update_5/gat_v2_conv/Reshape_2:0", shape=(None, 1, 1), dtype=float32), dense_shape=Tensor("gradient_tape/model_7/gat_mpnn/message_passing_5/node_set_update_5/gat_v2_conv/Cast:0", shape=(3,), dtype=int32))) to a dense Tensor of unknown shape. This may consume a large amount of memory. "shape. This may consume a large amount of memory." % value) /opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/indexed_slices.py:446: UserWarning: Converting sparse IndexedSlices(IndexedSlices(indices=Tensor("gradient_tape/model_7/gat_mpnn/message_passing_5/node_set_update_5/gat_v2_conv/Reshape_6:0", shape=(None,), dtype=int32), values=Tensor("gradient_tape/model_7/gat_mpnn/message_passing_5/node_set_update_5/gat_v2_conv/Reshape_5:0", shape=(None, 1, 1), dtype=float32), dense_shape=Tensor("gradient_tape/model_7/gat_mpnn/message_passing_5/node_set_update_5/gat_v2_conv/Cast_1:0", shape=(3,), dtype=int32))) to a dense Tensor of unknown shape. This may consume a large amount of memory. "shape. This may consume a large amount of memory." % value) /opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/indexed_slices.py:446: UserWarning: Converting sparse IndexedSlices(IndexedSlices(indices=Tensor("gradient_tape/model_7/gat_mpnn/message_passing_4/node_set_update_4/gat_v2_conv/Reshape_3:0", shape=(None,), dtype=int32), values=Tensor("gradient_tape/model_7/gat_mpnn/message_passing_4/node_set_update_4/gat_v2_conv/Reshape_2:0", shape=(None, 1, 1), dtype=float32), dense_shape=Tensor("gradient_tape/model_7/gat_mpnn/message_passing_4/node_set_update_4/gat_v2_conv/Cast:0", shape=(3,), dtype=int32))) to a dense Tensor of unknown shape. This may consume a large amount of memory. "shape. This may consume a large amount of memory." % value) /opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/indexed_slices.py:446: UserWarning: Converting sparse IndexedSlices(IndexedSlices(indices=Tensor("gradient_tape/model_7/gat_mpnn/message_passing_4/node_set_update_4/gat_v2_conv/Reshape_6:0", shape=(None,), dtype=int32), values=Tensor("gradient_tape/model_7/gat_mpnn/message_passing_4/node_set_update_4/gat_v2_conv/Reshape_5:0", shape=(None, 1, 1), dtype=float32), dense_shape=Tensor("gradient_tape/model_7/gat_mpnn/message_passing_4/node_set_update_4/gat_v2_conv/Cast_1:0", shape=(3,), dtype=int32))) to a dense Tensor of unknown shape. This may consume a large amount of memory. "shape. This may consume a large amount of memory." % value) /opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/indexed_slices.py:446: UserWarning: Converting sparse IndexedSlices(IndexedSlices(indices=Tensor("gradient_tape/model_7/gat_mpnn/message_passing_3/node_set_update_3/gat_v2_conv/Reshape_3:0", shape=(None,), dtype=int32), values=Tensor("gradient_tape/model_7/gat_mpnn/message_passing_3/node_set_update_3/gat_v2_conv/Reshape_2:0", shape=(None, 1, 1), dtype=float32), dense_shape=Tensor("gradient_tape/model_7/gat_mpnn/message_passing_3/node_set_update_3/gat_v2_conv/Cast:0", shape=(3,), dtype=int32))) to a dense Tensor of unknown shape. This may consume a large amount of memory. "shape. This may consume a large amount of memory." % value) /opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/indexed_slices.py:446: UserWarning: Converting sparse IndexedSlices(IndexedSlices(indices=Tensor("gradient_tape/model_7/gat_mpnn/message_passing_3/node_set_update_3/gat_v2_conv/Reshape_6:0", shape=(None,), dtype=int32), values=Tensor("gradient_tape/model_7/gat_mpnn/message_passing_3/node_set_update_3/gat_v2_conv/Reshape_5:0", shape=(None, 1, 1), dtype=float32), dense_shape=Tensor("gradient_tape/model_7/gat_mpnn/message_passing_3/node_set_update_3/gat_v2_conv/Cast_1:0", shape=(3,), dtype=int32))) to a dense Tensor of unknown shape. This may consume a large amount of memory. "shape. This may consume a large amount of memory." % value) /opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/indexed_slices.py:446: UserWarning: Converting sparse IndexedSlices(IndexedSlices(indices=Tensor("gradient_tape/model_7/gat_mpnn/message_passing_2/node_set_update_2/gat_v2_conv/Reshape_3:0", shape=(None,), dtype=int32), values=Tensor("gradient_tape/model_7/gat_mpnn/message_passing_2/node_set_update_2/gat_v2_conv/Reshape_2:0", shape=(None, 1, 1), dtype=float32), dense_shape=Tensor("gradient_tape/model_7/gat_mpnn/message_passing_2/node_set_update_2/gat_v2_conv/Cast:0", shape=(3,), dtype=int32))) to a dense Tensor of unknown shape. This may consume a large amount of memory. "shape. This may consume a large amount of memory." % value) /opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/indexed_slices.py:446: UserWarning: Converting sparse IndexedSlices(IndexedSlices(indices=Tensor("gradient_tape/model_7/gat_mpnn/message_passing_2/node_set_update_2/gat_v2_conv/Reshape_6:0", shape=(None,), dtype=int32), values=Tensor("gradient_tape/model_7/gat_mpnn/message_passing_2/node_set_update_2/gat_v2_conv/Reshape_5:0", shape=(None, 1, 1), dtype=float32), dense_shape=Tensor("gradient_tape/model_7/gat_mpnn/message_passing_2/node_set_update_2/gat_v2_conv/Cast_1:0", shape=(3,), dtype=int32))) to a dense Tensor of unknown shape. This may consume a large amount of memory. "shape. This may consume a large amount of memory." % value) /opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/indexed_slices.py:446: UserWarning: Converting sparse IndexedSlices(IndexedSlices(indices=Tensor("gradient_tape/model_7/gat_mpnn/message_passing_1/node_set_update_1/gat_v2_conv/Reshape_3:0", shape=(None,), dtype=int32), values=Tensor("gradient_tape/model_7/gat_mpnn/message_passing_1/node_set_update_1/gat_v2_conv/Reshape_2:0", shape=(None, 1, 1), dtype=float32), dense_shape=Tensor("gradient_tape/model_7/gat_mpnn/message_passing_1/node_set_update_1/gat_v2_conv/Cast:0", shape=(3,), dtype=int32))) to a dense Tensor of unknown shape. This may consume a large amount of memory. "shape. This may consume a large amount of memory." % value) /opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/indexed_slices.py:446: UserWarning: Converting sparse IndexedSlices(IndexedSlices(indices=Tensor("gradient_tape/model_7/gat_mpnn/message_passing_1/node_set_update_1/gat_v2_conv/Reshape_6:0", shape=(None,), dtype=int32), values=Tensor("gradient_tape/model_7/gat_mpnn/message_passing_1/node_set_update_1/gat_v2_conv/Reshape_5:0", shape=(None, 1, 1), dtype=float32), dense_shape=Tensor("gradient_tape/model_7/gat_mpnn/message_passing_1/node_set_update_1/gat_v2_conv/Cast_1:0", shape=(3,), dtype=int32))) to a dense Tensor of unknown shape. This may consume a large amount of memory. "shape. This may consume a large amount of memory." % value) /opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/indexed_slices.py:446: UserWarning: Converting sparse IndexedSlices(IndexedSlices(indices=Tensor("gradient_tape/model_7/gat_mpnn/message_passing_0/node_set_update/gat_v2_conv/Reshape_3:0", shape=(None,), dtype=int32), values=Tensor("gradient_tape/model_7/gat_mpnn/message_passing_0/node_set_update/gat_v2_conv/Reshape_2:0", shape=(None, 1, 1), dtype=float32), dense_shape=Tensor("gradient_tape/model_7/gat_mpnn/message_passing_0/node_set_update/gat_v2_conv/Cast:0", shape=(3,), dtype=int32))) to a dense Tensor of unknown shape. This may consume a large amount of memory. "shape. This may consume a large amount of memory." % value) /opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/indexed_slices.py:446: UserWarning: Converting sparse IndexedSlices(IndexedSlices(indices=Tensor("gradient_tape/model_7/gat_mpnn/message_passing_0/node_set_update/gat_v2_conv/Reshape_6:0", shape=(None,), dtype=int32), values=Tensor("gradient_tape/model_7/gat_mpnn/message_passing_0/node_set_update/gat_v2_conv/Reshape_5:0", shape=(None, 1, 1), dtype=float32), dense_shape=Tensor("gradient_tape/model_7/gat_mpnn/message_passing_0/node_set_update/gat_v2_conv/Cast_1:0", shape=(3,), dtype=int32))) to a dense Tensor of unknown shape. This may consume a large amount of memory. "shape. This may consume a large amount of memory." % value) 50/50 [==============================] - 10s 89ms/step - loss: 0.5478 - sparse_categorical_accuracy: 0.7352 - sparse_categorical_crossentropy: 0.5478 - AUROC: 0.6414 - val_loss: 0.5037 - val_sparse_categorical_accuracy: 0.7311 - val_sparse_categorical_crossentropy: 0.5037 - val_AUROC: 0.7434 Epoch 2/50 50/50 [==============================] - 3s 50ms/step - loss: 0.4676 - sparse_categorical_accuracy: 0.7645 - sparse_categorical_crossentropy: 0.4676 - AUROC: 0.7896 - val_loss: 0.4758 - val_sparse_categorical_accuracy: 0.7520 - val_sparse_categorical_crossentropy: 0.4758 - val_AUROC: 0.7870 Epoch 3/50 50/50 [==============================] - 2s 50ms/step - loss: 0.4404 - sparse_categorical_accuracy: 0.7889 - sparse_categorical_crossentropy: 0.4404 - AUROC: 0.8219 - val_loss: 0.4469 - val_sparse_categorical_accuracy: 0.7884 - val_sparse_categorical_crossentropy: 0.4469 - val_AUROC: 0.8204 Epoch 4/50 50/50 [==============================] - 2s 49ms/step - loss: 0.4192 - sparse_categorical_accuracy: 0.8055 - sparse_categorical_crossentropy: 0.4192 - AUROC: 0.8423 - val_loss: 0.4276 - val_sparse_categorical_accuracy: 0.8027 - val_sparse_categorical_crossentropy: 0.4276 - val_AUROC: 0.8406 Epoch 5/50 50/50 [==============================] - 3s 58ms/step - loss: 0.4015 - sparse_categorical_accuracy: 0.8191 - sparse_categorical_crossentropy: 0.4015 - AUROC: 0.8577 - val_loss: 0.4257 - val_sparse_categorical_accuracy: 0.8079 - val_sparse_categorical_crossentropy: 0.4257 - val_AUROC: 0.8441 Epoch 6/50 50/50 [==============================] - 3s 50ms/step - loss: 0.3943 - sparse_categorical_accuracy: 0.8202 - sparse_categorical_crossentropy: 0.3943 - AUROC: 0.8631 - val_loss: 0.4186 - val_sparse_categorical_accuracy: 0.8099 - val_sparse_categorical_crossentropy: 0.4186 - val_AUROC: 0.8497 Epoch 7/50 50/50 [==============================] - 2s 50ms/step - loss: 0.3878 - sparse_categorical_accuracy: 0.8286 - sparse_categorical_crossentropy: 0.3878 - AUROC: 0.8675 - val_loss: 0.4138 - val_sparse_categorical_accuracy: 0.8132 - val_sparse_categorical_crossentropy: 0.4138 - val_AUROC: 0.8529 Epoch 8/50 50/50 [==============================] - 3s 54ms/step - loss: 0.3860 - sparse_categorical_accuracy: 0.8322 - sparse_categorical_crossentropy: 0.3860 - AUROC: 0.8687 - val_loss: 0.4108 - val_sparse_categorical_accuracy: 0.8145 - val_sparse_categorical_crossentropy: 0.4108 - val_AUROC: 0.8522 Epoch 9/50 50/50 [==============================] - 2s 50ms/step - loss: 0.3823 - sparse_categorical_accuracy: 0.8327 - sparse_categorical_crossentropy: 0.3823 - AUROC: 0.8717 - val_loss: 0.4095 - val_sparse_categorical_accuracy: 0.8125 - val_sparse_categorical_crossentropy: 0.4095 - val_AUROC: 0.8551 Epoch 10/50 50/50 [==============================] - 3s 52ms/step - loss: 0.3759 - sparse_categorical_accuracy: 0.8350 - sparse_categorical_crossentropy: 0.3759 - AUROC: 0.8762 - val_loss: 0.4053 - val_sparse_categorical_accuracy: 0.8171 - val_sparse_categorical_crossentropy: 0.4053 - val_AUROC: 0.8586 Epoch 11/50 50/50 [==============================] - 2s 47ms/step - loss: 0.3726 - sparse_categorical_accuracy: 0.8363 - sparse_categorical_crossentropy: 0.3726 - AUROC: 0.8780 - val_loss: 0.4177 - val_sparse_categorical_accuracy: 0.8145 - val_sparse_categorical_crossentropy: 0.4177 - val_AUROC: 0.8496 Epoch 12/50 50/50 [==============================] - 2s 49ms/step - loss: 0.3692 - sparse_categorical_accuracy: 0.8389 - sparse_categorical_crossentropy: 0.3692 - AUROC: 0.8800 - val_loss: 0.3957 - val_sparse_categorical_accuracy: 0.8184 - val_sparse_categorical_crossentropy: 0.3957 - val_AUROC: 0.8635 Epoch 13/50 50/50 [==============================] - 2s 48ms/step - loss: 0.3590 - sparse_categorical_accuracy: 0.8458 - sparse_categorical_crossentropy: 0.3590 - AUROC: 0.8871 - val_loss: 0.3963 - val_sparse_categorical_accuracy: 0.8216 - val_sparse_categorical_crossentropy: 0.3963 - val_AUROC: 0.8652 Epoch 14/50 50/50 [==============================] - 3s 54ms/step - loss: 0.3579 - sparse_categorical_accuracy: 0.8466 - sparse_categorical_crossentropy: 0.3579 - AUROC: 0.8877 - val_loss: 0.3913 - val_sparse_categorical_accuracy: 0.8294 - val_sparse_categorical_crossentropy: 0.3913 - val_AUROC: 0.8689 Epoch 15/50 50/50 [==============================] - 3s 54ms/step - loss: 0.3519 - sparse_categorical_accuracy: 0.8519 - sparse_categorical_crossentropy: 0.3519 - AUROC: 0.8916 - val_loss: 0.3798 - val_sparse_categorical_accuracy: 0.8359 - val_sparse_categorical_crossentropy: 0.3798 - val_AUROC: 0.8759 Epoch 16/50 50/50 [==============================] - 2s 48ms/step - loss: 0.3451 - sparse_categorical_accuracy: 0.8536 - sparse_categorical_crossentropy: 0.3451 - AUROC: 0.8961 - val_loss: 0.3811 - val_sparse_categorical_accuracy: 0.8372 - val_sparse_categorical_crossentropy: 0.3811 - val_AUROC: 0.8737 Epoch 17/50 50/50 [==============================] - 2s 49ms/step - loss: 0.3378 - sparse_categorical_accuracy: 0.8580 - sparse_categorical_crossentropy: 0.3378 - AUROC: 0.9010 - val_loss: 0.3823 - val_sparse_categorical_accuracy: 0.8294 - val_sparse_categorical_crossentropy: 0.3823 - val_AUROC: 0.8750 Epoch 18/50 50/50 [==============================] - 2s 48ms/step - loss: 0.3436 - sparse_categorical_accuracy: 0.8573 - sparse_categorical_crossentropy: 0.3436 - AUROC: 0.8964 - val_loss: 0.3918 - val_sparse_categorical_accuracy: 0.8294 - val_sparse_categorical_crossentropy: 0.3918 - val_AUROC: 0.8753 Epoch 19/50 50/50 [==============================] - 2s 48ms/step - loss: 0.3439 - sparse_categorical_accuracy: 0.8562 - sparse_categorical_crossentropy: 0.3439 - AUROC: 0.8969 - val_loss: 0.3861 - val_sparse_categorical_accuracy: 0.8294 - val_sparse_categorical_crossentropy: 0.3861 - val_AUROC: 0.8746 Epoch 20/50 50/50 [==============================] - 2s 47ms/step - loss: 0.3337 - sparse_categorical_accuracy: 0.8612 - sparse_categorical_crossentropy: 0.3337 - AUROC: 0.9032 - val_loss: 0.3891 - val_sparse_categorical_accuracy: 0.8210 - val_sparse_categorical_crossentropy: 0.3891 - val_AUROC: 0.8750 Epoch 21/50 50/50 [==============================] - 2s 49ms/step - loss: 0.3286 - sparse_categorical_accuracy: 0.8661 - sparse_categorical_crossentropy: 0.3286 - AUROC: 0.9056 - val_loss: 0.3782 - val_sparse_categorical_accuracy: 0.8464 - val_sparse_categorical_crossentropy: 0.3782 - val_AUROC: 0.8756 Epoch 22/50 50/50 [==============================] - 3s 51ms/step - loss: 0.3286 - sparse_categorical_accuracy: 0.8658 - sparse_categorical_crossentropy: 0.3286 - AUROC: 0.9066 - val_loss: 0.3853 - val_sparse_categorical_accuracy: 0.8294 - val_sparse_categorical_crossentropy: 0.3853 - val_AUROC: 0.8775 Epoch 23/50 50/50 [==============================] - 2s 47ms/step - loss: 0.3229 - sparse_categorical_accuracy: 0.8692 - sparse_categorical_crossentropy: 0.3229 - AUROC: 0.9090 - val_loss: 0.3849 - val_sparse_categorical_accuracy: 0.8346 - val_sparse_categorical_crossentropy: 0.3849 - val_AUROC: 0.8746 Epoch 24/50 50/50 [==============================] - 2s 48ms/step - loss: 0.3130 - sparse_categorical_accuracy: 0.8747 - sparse_categorical_crossentropy: 0.3130 - AUROC: 0.9151 - val_loss: 0.3758 - val_sparse_categorical_accuracy: 0.8398 - val_sparse_categorical_crossentropy: 0.3758 - val_AUROC: 0.8801 Epoch 25/50 50/50 [==============================] - 2s 50ms/step - loss: 0.3132 - sparse_categorical_accuracy: 0.8756 - sparse_categorical_crossentropy: 0.3132 - AUROC: 0.9146 - val_loss: 0.3712 - val_sparse_categorical_accuracy: 0.8431 - val_sparse_categorical_crossentropy: 0.3712 - val_AUROC: 0.8833 Epoch 26/50 50/50 [==============================] - 3s 66ms/step - loss: 0.3085 - sparse_categorical_accuracy: 0.8766 - sparse_categorical_crossentropy: 0.3085 - AUROC: 0.9181 - val_loss: 0.3681 - val_sparse_categorical_accuracy: 0.8470 - val_sparse_categorical_crossentropy: 0.3681 - val_AUROC: 0.8848 Epoch 27/50 50/50 [==============================] - 3s 51ms/step - loss: 0.2998 - sparse_categorical_accuracy: 0.8797 - sparse_categorical_crossentropy: 0.2998 - AUROC: 0.9225 - val_loss: 0.3674 - val_sparse_categorical_accuracy: 0.8496 - val_sparse_categorical_crossentropy: 0.3674 - val_AUROC: 0.8859 Epoch 28/50 50/50 [==============================] - 2s 49ms/step - loss: 0.3059 - sparse_categorical_accuracy: 0.8789 - sparse_categorical_crossentropy: 0.3059 - AUROC: 0.9190 - val_loss: 0.3640 - val_sparse_categorical_accuracy: 0.8548 - val_sparse_categorical_crossentropy: 0.3640 - val_AUROC: 0.8892 Epoch 29/50 50/50 [==============================] - 3s 52ms/step - loss: 0.2951 - sparse_categorical_accuracy: 0.8828 - sparse_categorical_crossentropy: 0.2951 - AUROC: 0.9240 - val_loss: 0.3736 - val_sparse_categorical_accuracy: 0.8483 - val_sparse_categorical_crossentropy: 0.3736 - val_AUROC: 0.8828 Epoch 30/50 50/50 [==============================] - 2s 49ms/step - loss: 0.2938 - sparse_categorical_accuracy: 0.8825 - sparse_categorical_crossentropy: 0.2938 - AUROC: 0.9254 - val_loss: 0.3759 - val_sparse_categorical_accuracy: 0.8470 - val_sparse_categorical_crossentropy: 0.3759 - val_AUROC: 0.8827 Epoch 31/50 50/50 [==============================] - 2s 49ms/step - loss: 0.2848 - sparse_categorical_accuracy: 0.8856 - sparse_categorical_crossentropy: 0.2848 - AUROC: 0.9301 - val_loss: 0.3668 - val_sparse_categorical_accuracy: 0.8451 - val_sparse_categorical_crossentropy: 0.3668 - val_AUROC: 0.8877 Epoch 32/50 50/50 [==============================] - 3s 58ms/step - loss: 0.2885 - sparse_categorical_accuracy: 0.8861 - sparse_categorical_crossentropy: 0.2885 - AUROC: 0.9275 - val_loss: 0.3867 - val_sparse_categorical_accuracy: 0.8294 - val_sparse_categorical_crossentropy: 0.3867 - val_AUROC: 0.8832 Epoch 33/50 50/50 [==============================] - 2s 48ms/step - loss: 0.2904 - sparse_categorical_accuracy: 0.8831 - sparse_categorical_crossentropy: 0.2904 - AUROC: 0.9269 - val_loss: 0.3913 - val_sparse_categorical_accuracy: 0.8372 - val_sparse_categorical_crossentropy: 0.3913 - val_AUROC: 0.8765 Epoch 34/50 50/50 [==============================] - 2s 47ms/step - loss: 0.2891 - sparse_categorical_accuracy: 0.8834 - sparse_categorical_crossentropy: 0.2891 - AUROC: 0.9276 - val_loss: 0.3857 - val_sparse_categorical_accuracy: 0.8444 - val_sparse_categorical_crossentropy: 0.3857 - val_AUROC: 0.8809 Epoch 35/50 50/50 [==============================] - 2s 50ms/step - loss: 0.2780 - sparse_categorical_accuracy: 0.8894 - sparse_categorical_crossentropy: 0.2780 - AUROC: 0.9333 - val_loss: 0.3794 - val_sparse_categorical_accuracy: 0.8470 - val_sparse_categorical_crossentropy: 0.3794 - val_AUROC: 0.8894 Epoch 36/50 50/50 [==============================] - 3s 57ms/step - loss: 0.2670 - sparse_categorical_accuracy: 0.8945 - sparse_categorical_crossentropy: 0.2670 - AUROC: 0.9382 - val_loss: 0.4007 - val_sparse_categorical_accuracy: 0.8197 - val_sparse_categorical_crossentropy: 0.4007 - val_AUROC: 0.8775 Epoch 37/50 50/50 [==============================] - 3s 52ms/step - loss: 0.2598 - sparse_categorical_accuracy: 0.8964 - sparse_categorical_crossentropy: 0.2598 - AUROC: 0.9422 - val_loss: 0.3599 - val_sparse_categorical_accuracy: 0.8542 - val_sparse_categorical_crossentropy: 0.3599 - val_AUROC: 0.8921 Epoch 38/50 50/50 [==============================] - 3s 55ms/step - loss: 0.2585 - sparse_categorical_accuracy: 0.8959 - sparse_categorical_crossentropy: 0.2585 - AUROC: 0.9428 - val_loss: 0.3738 - val_sparse_categorical_accuracy: 0.8529 - val_sparse_categorical_crossentropy: 0.3738 - val_AUROC: 0.8895 Epoch 39/50 50/50 [==============================] - 2s 47ms/step - loss: 0.2507 - sparse_categorical_accuracy: 0.9002 - sparse_categorical_crossentropy: 0.2507 - AUROC: 0.9458 - val_loss: 0.3730 - val_sparse_categorical_accuracy: 0.8561 - val_sparse_categorical_crossentropy: 0.3730 - val_AUROC: 0.8937 Epoch 40/50 50/50 [==============================] - 2s 49ms/step - loss: 0.2546 - sparse_categorical_accuracy: 0.8975 - sparse_categorical_crossentropy: 0.2546 - AUROC: 0.9446 - val_loss: 0.3717 - val_sparse_categorical_accuracy: 0.8535 - val_sparse_categorical_crossentropy: 0.3717 - val_AUROC: 0.8954 Epoch 41/50 50/50 [==============================] - 2s 50ms/step - loss: 0.2463 - sparse_categorical_accuracy: 0.9020 - sparse_categorical_crossentropy: 0.2463 - AUROC: 0.9484 - val_loss: 0.3772 - val_sparse_categorical_accuracy: 0.8470 - val_sparse_categorical_crossentropy: 0.3772 - val_AUROC: 0.8899 Epoch 42/50 50/50 [==============================] - 2s 47ms/step - loss: 0.2490 - sparse_categorical_accuracy: 0.9028 - sparse_categorical_crossentropy: 0.2490 - AUROC: 0.9459 - val_loss: 0.3963 - val_sparse_categorical_accuracy: 0.8477 - val_sparse_categorical_crossentropy: 0.3963 - val_AUROC: 0.8955 Epoch 43/50 50/50 [==============================] - 2s 48ms/step - loss: 0.2403 - sparse_categorical_accuracy: 0.9038 - sparse_categorical_crossentropy: 0.2403 - AUROC: 0.9498 - val_loss: 0.3779 - val_sparse_categorical_accuracy: 0.8522 - val_sparse_categorical_crossentropy: 0.3779 - val_AUROC: 0.9005 Epoch 44/50 50/50 [==============================] - 2s 48ms/step - loss: 0.2527 - sparse_categorical_accuracy: 0.8994 - sparse_categorical_crossentropy: 0.2527 - AUROC: 0.9447 - val_loss: 0.3837 - val_sparse_categorical_accuracy: 0.8405 - val_sparse_categorical_crossentropy: 0.3837 - val_AUROC: 0.8924 Epoch 45/50 50/50 [==============================] - 3s 53ms/step - loss: 0.2330 - sparse_categorical_accuracy: 0.9084 - sparse_categorical_crossentropy: 0.2330 - AUROC: 0.9539 - val_loss: 0.3765 - val_sparse_categorical_accuracy: 0.8652 - val_sparse_categorical_crossentropy: 0.3765 - val_AUROC: 0.8977 Epoch 46/50 50/50 [==============================] - 3s 52ms/step - loss: 0.2286 - sparse_categorical_accuracy: 0.9127 - sparse_categorical_crossentropy: 0.2286 - AUROC: 0.9541 - val_loss: 0.3828 - val_sparse_categorical_accuracy: 0.8535 - val_sparse_categorical_crossentropy: 0.3828 - val_AUROC: 0.8976 Epoch 47/50 50/50 [==============================] - 2s 48ms/step - loss: 0.2172 - sparse_categorical_accuracy: 0.9156 - sparse_categorical_crossentropy: 0.2172 - AUROC: 0.9594 - val_loss: 0.4157 - val_sparse_categorical_accuracy: 0.8457 - val_sparse_categorical_crossentropy: 0.4157 - val_AUROC: 0.8968 Epoch 48/50 50/50 [==============================] - 2s 48ms/step - loss: 0.2285 - sparse_categorical_accuracy: 0.9142 - sparse_categorical_crossentropy: 0.2285 - AUROC: 0.9549 - val_loss: 0.4125 - val_sparse_categorical_accuracy: 0.8509 - val_sparse_categorical_crossentropy: 0.4125 - val_AUROC: 0.8921 Epoch 49/50 50/50 [==============================] - 3s 52ms/step - loss: 0.2226 - sparse_categorical_accuracy: 0.9144 - sparse_categorical_crossentropy: 0.2226 - AUROC: 0.9574 - val_loss: 0.3726 - val_sparse_categorical_accuracy: 0.8620 - val_sparse_categorical_crossentropy: 0.3726 - val_AUROC: 0.9007 Epoch 50/50 50/50 [==============================] - 2s 48ms/step - loss: 0.2209 - sparse_categorical_accuracy: 0.9155 - sparse_categorical_crossentropy: 0.2209 - AUROC: 0.9583 - val_loss: 0.3729 - val_sparse_categorical_accuracy: 0.8600 - val_sparse_categorical_crossentropy: 0.3729 - val_AUROC: 0.9016 2022-06-23 19:33:36.568640: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them. /opt/conda/lib/python3.7/site-packages/tensorflow/python/saved_model/nested_structure_coder.py:524: UserWarning: Encoding a StructuredValue with type tensorflow_gnn.GraphTensorSpec; loading this StructuredValue will require that this type be imported and registered. "imported and registered." % type_spec_class_name) /opt/conda/lib/python3.7/site-packages/tensorflow/python/saved_model/nested_structure_coder.py:524: UserWarning: Encoding a StructuredValue with type tensorflow_gnn.ContextSpec.v2; loading this StructuredValue will require that this type be imported and registered. "imported and registered." % type_spec_class_name) /opt/conda/lib/python3.7/site-packages/tensorflow/python/saved_model/nested_structure_coder.py:524: UserWarning: Encoding a StructuredValue with type tensorflow_gnn.NodeSetSpec; loading this StructuredValue will require that this type be imported and registered. "imported and registered." % type_spec_class_name) /opt/conda/lib/python3.7/site-packages/tensorflow/python/saved_model/nested_structure_coder.py:524: UserWarning: Encoding a StructuredValue with type tensorflow_gnn.EdgeSetSpec; loading this StructuredValue will require that this type be imported and registered. "imported and registered." % type_spec_class_name) /opt/conda/lib/python3.7/site-packages/tensorflow/python/saved_model/nested_structure_coder.py:524: UserWarning: Encoding a StructuredValue with type tensorflow_gnn.AdjacencySpec; loading this StructuredValue will require that this type be imported and registered. "imported and registered." % type_spec_class_name)
4.3 Metric visualization
A straightforward way to visualize the various metrics collected during training and validation is to use TensorBoard. Ideally, the following magic should work in your system:
For reference, I have uploaded the logs of a local run to TensorBoard.dev. There you should see that the loss decreases until we start to overfit (orange line is training, blue line is validation):
The accuracy reaches approximately 87%
And the AUROC increases up to approximately 0.9
Conclusions
In this notebook we have seen how to train a GNN model for graph binary classification using TF-GNN in an end-to-end fashion. The final cell running the orchestrator brings together all the elements we introduced along the way, namely:
- The
DatasetProvider
-compliant objectstrain_dataset_provider
andvalid_dataset_provider
constructed in §2 to provide the data - The model constructor function
get_model_creation_fn
built in §3.3 with the components of §3.1 and §3.2 to assemble the GNN - The
GraphBinaryClassification
task defined in §4.1 to specify the readout and prediction head, as well as the loss and metrics. - The
KerasTrainer
and target feature extractor created in §4.2 for supervised training
While it may not be immediately obvious from such a small example, TF-GNN helped at each step by providing not just the underlying operations we need to perform on graphs, but also many useful protocols and helper functions taking care of much of the boilerplate code we would have otherwise required. Using these in tandem with the orchestrator means that all of the components are easily extendable and/or replaceable. Moreover, it allows us at least in principle to easily scale the various moving parts independently and without unnecessary pain. For example, introducing a non-trivial strategy in the trainer we could distribute our training across multiple GPUs or, eventually, TPUs, while also parallelizing our input pipeline through the InputContext
that is passed on to the DatasetProvider
.
The results we obtained for the drug cardiotoxicity dataset are good, but not impressive. This was to be expected given the very simple GAT-based model that we implemented and the fact that we did no hyperparameter optimization or principled architectural choices. For comparison, we partially quote here Table 1 from [1], where we see that our AUROC results are essentially consistent with the GNN baseline considered there:
The reader is encouraged to try various modifications of the GNN architecture to improve the performance of the model, as well as study its behavior on the out-of-distribution test sets test1_dataset_provider
and test2_dataset_provider
that we prepared in §2 but did not use here.
5.1 Resources and acknowledgements
There are a number of resources that have inspired or been used in this notebook, some of which are cited along the way. They are also collected here for reference purposes:
-
The paper [1] introducing the CardioTox dataset digs deeper in the data we have used, and builds better classification models for it
-
The papers [2] and [3] introduced and popularized Graph Attention Networks
-
A Gentle Introduction to Graph Neural Networks is a good reference to start learning more about GNNs and their applications
-
For those looking to dig deeper into GNNs, the Graph Representation Learning Book by W. L. Hamilton is fairly up-to-date and more comprehensive
-
The TensorFlow GNN guide is a good place to start learning how to use TF-GNN, and some of this notebook’s code was inspired by the examples provided there
-
[July 5, 2022 UPDATE] Some official notebooks with TF-GNN examples are now available, see e.g. Molecular Graph Classification, Solving OGBN-MAG end-to-end and Learning shortest paths with GraphNetworks