Graph Neural Networks to learn an energy potential

May 20, 2024

Graph neural networks (GNNs) are an increasingly popular approach for learning from data that can be represented as a graph. This is a common scenario in computational science:

The training data, 500 particles uniformly placed in a square box.

To test this method I generated particles in a square 2D domain (the unit square), with random positions $x_i, y_i \sim \mathcal{U}(0, 1)$. Each particle has an associated feature $f_i$. These form the state of the system.
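
As a concrete sketch, this state could be generated as follows (the unit box, the feature distribution, and the seed are illustrative assumptions):

import numpy as np

rng = np.random.default_rng(seed=0)

n = 500
pos = rng.uniform(0.0, 1.0, size=(n, 2))   # random positions (x_i, y_i) in the unit square
feat = rng.uniform(0.0, 1.0, size=(n, 1))  # per-particle features f_i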

Now, we define an "energy" associated with each particle, which can be computed from the state of the system:

$$E_i = \sum_{\substack{j \neq i \\ r_{ij} < r_c}} f_i f_j \, w(r_{ij}),$$

where $r_c$ is the cutoff radius, $r_{ij} = \lVert \mathbf{r}_i - \mathbf{r}_j \rVert$ the distance separating particles $i$ and $j$, and $w$ a pairwise kernel that vanishes beyond the cutoff.
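
This energy can be computed directly from the state; the kernel $w(r) = 1 - r/r_c$ and the cutoff value below are assumed examples, not taken from the original:

import numpy as np

def compute_energy(pos, feat, r_c=0.2):
    """Per-particle energies E_i, summing pair contributions within the cutoff."""
    diff = pos[:, None, :] - pos[None, :, :]  # pairwise displacements
    r = np.linalg.norm(diff, axis=-1)         # pairwise distances r_ij
    # Assumed kernel: w(r) = 1 - r/r_c inside the cutoff, zero outside and for i == j.
    w = np.where((r > 0.0) & (r < r_c), 1.0 - r / r_c, 0.0)
    return (feat * feat.T * w).sum(axis=1)    # E_i = sum_j f_i f_j w(r_ij)

energy = compute_energy(pos, feat)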

The training data, 500 particles placed randomly in a square box, is illustrated in the figure above. The color shows the energy $E_i$.

Let's suppose that we have data about the energy of particles in a given configuration: can we predict the energy of particles in a new configuration, without knowing the form of the energy?

The GNN architecture is set as follows:

$$\hat{E}_i = \sum_{j \in \mathcal{N}(i)} \phi\left(f_i, f_j, e_{ij}\right),$$

where $\hat{E}_i$ is the output at node $i$, $\mathcal{N}(i)$ the set of neighbors within the cutoff, and the edge features $e_{ij} = (\hat{\mathbf{r}}_{ij}, r_{ij})$ are the direction between particles and their distance. The function $\phi$ is a multi-layer perceptron. Note that in this case the direction between particles is not useful, since the energy depends only on distances, but we assume that we don't know this information.
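
For concreteness, here is one way the input graph could be assembled, using radius_graph from torch_cluster to connect particles closer than the cutoff; the make_graph helper and the cutoff value are illustrative:

import torch
from torch_cluster import radius_graph
from torch_geometric.data import Data

def make_graph(pos, feat, energy, r_c=0.2):
    pos = torch.as_tensor(pos, dtype=torch.float32)
    edge_index = radius_graph(pos, r=r_c, loop=False)  # edges between particles with r_ij < r_c
    src, dst = edge_index
    d = pos[dst] - pos[src]             # displacement between connected particles
    dist = d.norm(dim=1, keepdim=True)  # distance r_ij
    edge_attr = torch.cat([d / dist, dist], dim=1)  # unit direction and distance
    return Data(x=torch.as_tensor(feat, dtype=torch.float32),
                edge_index=edge_index,
                edge_attr=edge_attr,
                y=torch.as_tensor(energy, dtype=torch.float32))

In 2D the edge features have three components (two for the direction, one for the distance), so the model below takes x_channels=1 and e_channels=3.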

The architecture itself is easily implemented using pytorch-geometric:

import torch
import torch.nn as nn
from torch_geometric.nn import MessagePassing


class EdgeConv(MessagePassing):
    def __init__(self, x_channels, e_channels, out_channels):
        super().__init__(aggr='add')  # sum the messages over neighbors
        # MLP acting on (x_i, x_j, e_ij), concatenated along the feature axis.
        self.mlp = nn.Sequential(nn.Linear(2 * x_channels + e_channels, 16),
                                 nn.ReLU(),
                                 nn.Linear(16, out_channels))

    def forward(self, x, edge_index, edge_attr):
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_i, x_j, edge_attr):
        # One message per edge: phi(f_i, f_j, e_ij).
        z = torch.cat([x_i, x_j, edge_attr], dim=1)
        return self.mlp(z)


class GCN(torch.nn.Module):
    def __init__(self, x_channels, e_channels, out_channels):
        super().__init__()
        self.conv = EdgeConv(x_channels, e_channels, out_channels)

    def forward(self, data):
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        return self.conv(x, edge_index, edge_attr)

We train for 20,000 epochs using the Adam optimizer with a learning rate of 0.001.
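
A minimal training loop consistent with this setup could look as follows; the MSE loss and the single training graph (built here with the make_graph helper above) are assumptions:

train_data = make_graph(pos, feat, energy)  # graph built from the training configuration

model = GCN(x_channels=1, e_channels=3, out_channels=1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

for epoch in range(20000):
    optimizer.zero_grad()
    pred = model(train_data).squeeze(-1)  # predicted per-particle energies
    loss = torch.nn.functional.mse_loss(pred, train_data.y)
    loss.backward()
    optimizer.step()

Here are the results, with the model tested on a new random configuration: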

The prediction of the energy is very close to the ground truth on the test data.