On DeepMind’s recent paper on discovering new mathematics

Given below is a record of my failed attempt to replicate DeepMind’s recent mathematical breakthrough. The attempt failed mostly because I could not find the code for the funsearch module on DeepMind’s GitHub repository. If you can help me with this, I would love to hear from you.

Introduction

It is well known that checking whether a given mathematical argument is correct is much easier than producing a correct argument. Similarly, it is usually much harder to produce an example of a mathematical concept than to verify that a given object is a valid example of it. DeepMind’s recent paper leverages this asymmetry. An LLM writes programs that may find examples of mathematical objects, and a separate “evaluator” checks whether the results are indeed valid examples. Examinee and examiner, rolled into one.

Why LLMs?

Note that training a neural network to produce correct examples of a mathematical concept is also possible. However, acquiring the requisite hardware and retraining a network for every specific purpose can be expensive and time-consuming. An LLM is, in some sense, a “generalist neural network”, since it has been trained on a sizable fraction of the internet. FunSearch uses such a pretrained LLM without fine-tuning it. All the problem-specific information goes into the prompt instead.

How can a program solve a mathematical problem?

Normally, you’d think we need a machine that can write mathematical proofs. Or perhaps a neural network that does something magical in its formidable black box and comes up with a mind-bending example that mathematicians have failed at conjuring for decades. However, a program can also come up with examples. Simplistically speaking, it can list all elements of a well-defined set, and then check whether any of them are examples of the object we are looking for. In some sense, it is their computational speed and accuracy that make programs good at coming up with examples.
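For the cap set problem, “list everything and check” can be written out directly. Below is a minimal brute-force sketch. The function names are my own, not DeepMind’s. It tries every subset of Z_3^n, largest first, and returns the first subset in which no three points lie on a line. It relies on the fact that three distinct points of Z_3^n are collinear exactly when they sum to zero mod 3. The search only finishes quickly for n ≤ 2, which shows why raw speed alone does not get us far.

```python
import itertools

Vector = tuple[int, ...]


def is_cap_set_naive(points: list[Vector]) -> bool:
    """Return True if no three distinct points lie on a line in Z_3^n.

    Three distinct points a, b, c are collinear exactly when a + b + c == 0 (mod 3).
    """
    for a, b, c in itertools.combinations(points, 3):
        if all((x + y + z) % 3 == 0 for x, y, z in zip(a, b, c)):
            return False
    return True


def largest_cap_set_brute_force(n: int) -> list[Vector]:
    """Try every subset of Z_3^n, largest first, and return the first cap set found."""
    space = list(itertools.product((0, 1, 2), repeat=n))
    for size in range(len(space), 0, -1):
        for subset in itertools.combinations(space, size):
            if is_cap_set_naive(list(subset)):
                return list(subset)
    return []


print(len(largest_cap_set_brute_force(2)))  # 4
```

Already for n = 3 there are 2^27 subsets to consider, so this approach stops being practical almost immediately.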

So do we provide the program, and the LLM just runs it?

No, we don’t even know the correct program to run at the beginning. We provide the LLM with a very basic program and ask it to improve that program iteratively through “evolution”. We discard the bad programs and ask the LLM to further improve the good ones. To narrow the search space, we ask the LLM to evolve only the “priority function” part of the program and to leave everything else unchanged. The priority function is the function that makes the decision at every step. In the cap set code below, it scores every candidate vector, and a greedy loop repeatedly adds the highest-scoring vector that is still allowed.

As an example, imagine a sailor lost at sea on a lifeboat with paddles. Every day, based on the position of the stars, perhaps the direction of the wind, etc., they have to decide in which direction to paddle. This decision is made by the sailor’s priority function, which is perhaps their brain. The LLM is asked to give the sailor a better and better brain, iteratively, until they make the right decisions every day and reach land. Everything about the sailor that is inconsequential, like their clothes, their gender, etc., is left unchanged.
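The split between a fixed skeleton and an evolved priority function can be sketched in plain Python. This is my own simplified rewrite of the greedy solve function shown further down, not DeepMind’s code. The names greedy_cap_set, zero_priority and twos_first are hypothetical. Only the priority argument changes between the two runs, and the output changes with it.

```python
import itertools
from typing import Callable

Vector = tuple[int, ...]
Priority = Callable[[Vector, int], float]


def greedy_cap_set(n: int, priority: Priority) -> list[Vector]:
    """Fixed skeleton: keep adding the highest-priority vector that is not blocked."""
    candidates = list(itertools.product((0, 1, 2), repeat=n))
    scores = {v: priority(v, n) for v in candidates}
    blocked: set[Vector] = set()
    cap_set: list[Vector] = []
    while True:
        available = [v for v in candidates if v not in blocked]
        if not available:
            return cap_set
        best = max(available, key=lambda v: scores[v])  # ties: earliest vector wins
        # `best` and each earlier point span a line; block that line's third point.
        for other in cap_set:
            blocked.add(tuple((-a - b) % 3 for a, b in zip(best, other)))
        blocked.add(best)
        cap_set.append(best)


def zero_priority(el: Vector, n: int) -> float:
    """The trivial starting point: every vector is equally attractive."""
    return 0.0


def twos_first(el: Vector, n: int) -> float:
    """A different, hypothetical 'brain': prefer vectors containing many 2s."""
    return float(el.count(2))


for name, fn in [("zero_priority", zero_priority), ("twos_first", twos_first)]:
    print(name, len(greedy_cap_set(4, fn)))  # zero_priority gives 2**4 = 16
```

Whatever priority function is plugged in, the skeleton only ever adds unblocked vectors, so the output is always a valid cap set. The priority only decides how large it ends up.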

So is the output the example that we’re looking for?

No, it is a program that produces the example. This means we can tweak the program and perhaps generate even more complex examples. It’s akin to the difference between a mathematician stating that a conjecture is true and publishing a proof of it. Other mathematicians can modify a published proof to prove further impressive results.

Consider the following diagram:

First, we specify the task that the LLM is supposed to accomplish, like “find a large cap set”. Then a program, labelled “prompt” in the diagram, is fed into the LLM. The LLM returns three programs that are supposedly improvements on the input program, but only one of them is actually correct. This correct program is stored in a programs database. Programs from the database are then passed back into the LLM to be improved further. Since the LLM may take a correct program as input and incorrectly evolve it into an incorrect output program, we again run tests to see which outputs are correct, and so on.
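The filtering step can be sketched as follows, under simplifying assumptions of my own. Each “LLM output” is a source string defining a solve(n) function. The three candidate strings are made up to mirror the diagram: one returns a wrong answer, one does not even compile, and one is correct. The evaluator just calls exec. A real setup would run candidates in a sandbox with a time limit. None of the names here come from DeepMind’s code.

```python
import itertools
from typing import Optional

Vector = tuple[int, ...]


def is_cap_set_naive(vectors: list[Vector]) -> bool:
    """True if no three distinct vectors sum to zero mod 3, i.e. none are collinear."""
    return all(
        any((x + y + z) % 3 for x, y, z in zip(a, b, c))
        for a, b, c in itertools.combinations(vectors, 3)
    )


def evaluate_candidate(source: str, n: int) -> Optional[int]:
    """Score a candidate: the size of the cap set its `solve` returns, else None."""
    namespace: dict = {"itertools": itertools}
    try:
        exec(source, namespace)  # SyntaxError and runtime errors land in `except`
        vectors = [tuple(v) for v in namespace["solve"](n)]
    except Exception:
        return None  # the program is broken
    well_formed = all(len(v) == n and set(v) <= {0, 1, 2} for v in vectors)
    if not well_formed or len(set(vectors)) != len(vectors):
        return None  # not a set of points of Z_3^n
    if not is_cap_set_naive(vectors):
        return None  # runs fine, but the answer is wrong
    return len(vectors)


# Three made-up "LLM outputs", mirroring the diagram.
candidates = [
    "def solve(n):\n    return list(itertools.product((0, 1, 2), repeat=n))\n",
    "def solve(n)\n    return []\n",  # missing colon: does not compile
    "def solve(n):\n    return list(itertools.product((0, 1), repeat=n))\n",
]

programs_database: list[tuple[int, str]] = []
for source in candidates:
    score = evaluate_candidate(source, n=3)
    if score is not None:
        programs_database.append((score, source))

print([score for score, _ in programs_database])  # [8]
```

Only the third candidate survives. It is stored together with its score, so later rounds can prefer the better programs in the database.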

Specification

So what is the code that is run?

I found the Jupyter notebook on DeepMind’s GitHub page.

The first boilerplate input program is

"""Finds large cap sets."""
import itertools
import numpy as np


# @funsearch.run
def evaluate(n: int) -> int:
    """Returns the size of an `n`-dimensional cap set."""
    capset = solve(n)
    return len(capset)


def solve(n: int) -> np.ndarray:
    """Returns a large cap set in `n` dimensions."""
    all_vectors = np.array(list(itertools.product((0, 1, 2), repeat=n)), dtype=np.int32)

    # Powers in decreasing order for compatibility with `itertools.product`, so
    # that the relationship `i = all_vectors[i] @ powers` holds for all `i`.
    powers = 3 ** np.arange(n - 1, -1, -1)

    # Precompute all priorities.
    priorities = np.array([priority(tuple(vector), n) for vector in all_vectors])

    # Build `capset` greedily, using priorities for prioritization.
    capset = np.empty(shape=(0, n), dtype=np.int32)
    while np.any(priorities != -np.inf):
        # Add a vector with maximum priority to `capset`, and set priorities of
        # invalidated vectors to `-inf`, so that they never get selected.
        max_index = np.argmax(priorities)
        vector = all_vectors[None, max_index]  # [1, n]
        blocking = np.einsum('cn,n->c', (- capset - vector) % 3, powers)  # [C]
        priorities[blocking] = -np.inf
        priorities[max_index] = -np.inf
        capset = np.concatenate([capset, vector], axis=0)

    return capset


# @funsearch.evolve
def priority(el: tuple[int, ...], n: int) -> float:
    """Returns the priority with which we want to add `element` to the cap set."""
    return 0.0

This program was clearly not good. For example, in nine dimensions, the largest cap set the program found contains 512 elements, whereas the largest cap set known in nine dimensions contains 1082 elements.
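Where does the 512 come from? With the constant priority 0.0, np.argmax always returns the first remaining index, so solve adds vectors in the order produced by itertools.product. Working through that order, every vector containing a 2 gets blocked by two earlier vectors that contain no 2. The loop therefore ends with exactly the set {0,1}^n, which has 2^n elements. That set really is a cap set, because three distinct points of Z_3^n lie on a line exactly when they sum to zero mod 3:

$$
\begin{aligned}
& a, b, c \in \{0,1\}^n,\quad a + b + c \equiv 0 \pmod{3} \\
\implies & a_i + b_i + c_i \in \{0, 3\} \text{ for every } i \qquad (\text{since } 0 \le a_i + b_i + c_i \le 3) \\
\implies & a_i = b_i = c_i \text{ for every } i, \text{ i.e. } a = b = c. \\[4pt]
& \text{So no three distinct points of } \{0,1\}^n \text{ are collinear, and } |\{0,1\}^9| = 2^9 = 512.
\end{aligned}
$$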

FunSearch then discovered the following program:

def priority(el: tuple[int, ...], n: int) -> float:
    score = n
    in_el = 0
    el_count = el.count(0)

    if el_count == 0:
        score += n ** 2
        if el[1] == el[-1]:
            score *= 1.5
        if el[2] == el[-2]:
            score *= 1.5
        if el[3] == el[-3]:
            score *= 1.5
    else:
        if el[1] == el[-1]:
            score *= 0.5
        if el[2] == el[-2]:
            score *= 0.5

    for e in el:
        if e == 0:
            if in_el == 0:
                score *= n * 0.5
            elif in_el == el_count - 1:
                score *= 0.5
            else:
                score *= n * 0.5 ** in_el
            in_el += 1
        else:
            score += 1

    if el[1] == el[-1]:
        score *= 1.5
    if el[2] == el[-2]:
        score *= 1.5

    return score


# We call the `solve` function instead of `evaluate` so that we get access to
# the cap set itself (rather than just its size), for verification and
# inspection purposes.
cap_set_n8 = solve(8)
assert cap_set_n8.shape == (512, 8)

In 8 dimensions, it discovered a cap set containing 512 elements. Before this, the largest known cap set in 8 dimensions contained 496 elements. But how does one verify that the discovered cap set is indeed correct? The following function was used to verify this:

def is_cap_set(vectors: np.ndarray) -> bool:
    """Returns whether `vectors` form a valid cap set.

    Checking the cap set property naively takes O(c^3 n) time, where c is the size
    of the cap set. This function implements a faster check that runs in O(c^2 n).

    Args:
        vectors: [c, n] array containing c n-dimensional vectors over {0, 1, 2}.
    """
    _, n = vectors.shape

    # Convert `vectors` elements into raveled indices (numbers in [0, 3^n) ).
    powers = np.array([3 ** j for j in range(n - 1, -1, -1)], dtype=int)  # [n]
    raveled = np.einsum('in,n->i', vectors, powers)  # [c]

    # Starting from the empty set, we iterate through `vectors` one by one and at
    # each step check that the vector can be inserted into the set without
    # violating the defining property of cap set. To make this check fast we
    # maintain a vector `is_blocked` indicating for each element of Z_3^n whether
    # that element can be inserted into the growing set without violating the cap
    # set property.
    is_blocked = np.full(shape=3 ** n, fill_value=False, dtype=bool)
    for i, (new_vector, new_index) in enumerate(zip(vectors, raveled)):
        if is_blocked[new_index]:
            return False  # Inserting the i-th element violated the cap set property.
        if i >= 1:
            # Update which elements are blocked after the insertion of `new_vector`.
            blocking = np.einsum(
                'nk,k->n',
                (- vectors[:i, :] - new_vector[None, :]) % 3, powers)
            is_blocked[blocking] = True
        is_blocked[new_index] = True  # In case `vectors` contains duplicates.
    return True  # All elements inserted without violating the cap set property.


assert is_cap_set(cap_set_n8)
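Both solve and is_cap_set rely on the same bookkeeping trick, which is easy to miss. Vector number i in itertools.product order is just i written in base 3, so vector @ powers recovers its index. For two points a and b, the third point on their line is (-a - b) mod 3. Here is a tiny worked check for n = 2; the points a and b are my own choice for illustration.

```python
import itertools

import numpy as np

n = 2
all_vectors = np.array(list(itertools.product((0, 1, 2), repeat=n)), dtype=np.int32)
powers = 3 ** np.arange(n - 1, -1, -1)  # [3, 1]: base-3 place values

# Row i of `all_vectors` is i written in base 3, so the dot product gives back i.
assert all(int(all_vectors[i] @ powers) == i for i in range(3 ** n))

a = np.array([0, 1])
b = np.array([1, 1])
third = (-a - b) % 3         # the unique c with a + b + c = 0 (mod 3)
index = int(third @ powers)  # 2 * 3 + 1 * 1
print(third, index)          # [2 1] 7
assert all_vectors[index].tolist() == [2, 1]
```

This is why a single fancy-indexing assignment such as is_blocked[blocking] = True can mark every newly forbidden point at once.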

But how does FunSearch evolve the algorithm? What is the code for that?

It is given in the same GitHub repository, but the details of the LLM are not. We could perhaps try running it with ChatGPT instead.

What is the evolutionary method?

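Start several separate populations of programs, called islands, from the initial program, and let each island evolve separately: the LLM’s prompts for an island are built only from programs on that island. Evaluate each evolved program and assign it a score. In the cap set example, the score is simply the size of the cap set returned by the evaluate function. Periodically, discard all programs on the islands whose best programs have the lowest scores, and reseed each newly emptied island with the best program from one of the surviving islands. Keep repeating until termination. I suppose that, on average, the islands with the best programs will survive.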

This is like evolution in which God swoops in to strike dead all the unfit mutations, instead of a slow death that nature wreaks upon unfit species.
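Here is how I read the island-reset step, as a sketch on made-up data. Each island is a list of (score, program) pairs. The helper reset_weak_islands is hypothetical. It discards the worse-scoring half of the islands and reseeds each one with the best program from a randomly chosen surviving island. The schedule for when resets happen during a run is not modelled here.

```python
import random

Island = list[tuple[float, str]]  # (score, program source) pairs


def best_score(island: Island) -> float:
    return max(score for score, _ in island)


def reset_weak_islands(islands: list[Island], rng: random.Random) -> None:
    """Empty the worse half of the islands and reseed them from the survivors."""
    order = sorted(range(len(islands)), key=lambda i: best_score(islands[i]))
    num_reset = len(islands) // 2
    losers, survivors = order[:num_reset], order[num_reset:]
    for i in losers:
        donor = islands[rng.choice(survivors)]
        # Reseed with the donor island's single best program.
        islands[i] = [max(donor, key=lambda pair: pair[0])]


islands = [
    [(8.0, "prog_a"), (4.0, "prog_b")],
    [(2.0, "prog_c")],
    [(9.0, "prog_d")],
    [(1.0, "prog_e"), (3.0, "prog_f")],
]
reset_weak_islands(islands, random.Random(0))
print(islands)  # islands 1 and 3 now each hold a copy of prog_a or prog_d
```

The losing islands are not deleted outright. They restart from a strong program and can then drift in a different direction from their donor.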

What are the program clusters and islands here?

In the process of evolution, each program gives rise to multiple mutated programs. Within an island, these progeny are classified into clusters, and programs that achieve the same scores are placed in the same cluster. To pick an input for the LLM, first sample a cluster, favoring clusters with higher scores. Then, within that cluster, favor the shorter programs. Use the sampled program as input to the LLM to produce an output program. If it is correct, this output program is added to the cluster whose programs have the same scores, and a new cluster is created if there is none. Note that in this process clusters are not emptied; they only grow.
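The two-stage sampling can be sketched like this, as a simplified version under my own assumptions. Each cluster here is keyed by a single score; as I understand the paper, a cluster groups programs with the same scores on all inputs. Clusters are picked with a softmax over their scores, and within a cluster, shorter programs get exponentially more weight. The function sample_program and the temperature values are hypothetical.

```python
import math
import random
from collections import defaultdict
from typing import Optional


def sample_program(
    island: list[tuple[float, str]],
    cluster_temperature: float = 1.0,
    length_temperature: float = 1.0,
    rng: Optional[random.Random] = None,
) -> tuple[float, str]:
    """Pick one (score, program) from an island: first a cluster, then a program."""
    rng = rng or random.Random()
    clusters: dict[float, list[str]] = defaultdict(list)
    for score, source in island:
        clusters[score].append(source)  # same score -> same cluster

    # Stage 1: softmax over cluster scores, so better clusters are chosen more often.
    scores = sorted(clusters)
    top = scores[-1]
    cluster_weights = [math.exp((s - top) / cluster_temperature) for s in scores]
    chosen = rng.choices(scores, weights=cluster_weights, k=1)[0]

    # Stage 2: within the cluster, shorter programs get larger weights.
    members = clusters[chosen]
    lengths = [len(m) for m in members]
    shortest, longest = min(lengths), max(lengths)
    normalized = [(length - shortest) / (longest - shortest + 1e-6) for length in lengths]
    length_weights = [math.exp(-x / length_temperature) for x in normalized]
    return chosen, rng.choices(members, weights=length_weights, k=1)[0]


island = [
    (8.0, "def priority(el, n):\n    return 0.0\n"),
    (8.0, "def priority(el, n):\n    total = sum(el)\n    return 0.0 * total\n"),
    (4.0, "def priority(el, n):\n    return -1.0\n"),
]
score, program = sample_program(island, cluster_temperature=0.1, rng=random.Random(0))
print(score)
```

A low cluster temperature makes the search greedy towards the best cluster, and a higher one keeps weaker clusters in play. Preferring short programs within a cluster nudges the LLM towards simpler, more readable solutions.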

Published by -

Graduate student
