The snailfish numbers are, in essence, binary trees, and because each node in the tree always has exactly two children that are either other nodes or numbers (leaves), it is also a full binary tree.
The two operations, exploding and splitting, are very similar to the kinds of operations that self-balancing binary search trees perform whenever you insert or remove a node.
To 'explode' a node, you need to find the preceding and succeeding leaf nodes in in-order traversal. I chose to implement the binary tree using nodes and references, with a recursive `__iter__`
method to handle the traversal. While processing explosions and splits, I track the preceding and following nodes as well, so we don't need parent pointers in the nodes; from the previous or next node you can walk down to the exploding pair's parent and replace the pair with a 0 leaf.
from __future__ import annotations
from copy import deepcopy
from enum import IntEnum
from functools import reduce
from itertools import chain, islice
from operator import add
from typing import ClassVar, Final, Iterable, Iterator, TypeAlias
class Dir(IntEnum):
    """Direction of a child in a binary tree.

    The integer value doubles as the child's index: 0 selects the left child,
    1 the right child.
    """

    left = 0
    right = 1

    def __invert__(self) -> Dir:
        """Return the opposite direction (``~Dir.left`` is ``Dir.right``)."""
        # Flipping the single bit swaps 0 <-> 1.
        return Dir(self ^ 1)
# Short module-level names for the two child directions.
LEFT: Final[Dir] = Dir(0)
RIGHT: Final[Dir] = Dir(1)
# Nesting depth of a pair as reported by tree traversal.
Depth: TypeAlias = int
class SnailfishNumber:
    """A snailfish pair: an internal node of a full binary tree.

    Every pair has exactly two children, each either another pair or a
    regular number (represented by the ``Leaf`` subclass).
    """

    is_leaf: ClassVar[bool] = False

    def __init__(self, left: SnailfishNumber, right: SnailfishNumber) -> None:
        self.left = left
        self.right = right

    @classmethod
    def from_line(cls, line: str) -> SnailfishNumber:
        """Parse a snailfish number written like ``"[[1,2],3]"``."""
        return cls._parse(line.encode())[0]

    @classmethod
    def _parse(cls, line: bytes, i: int = 0) -> tuple[SnailfishNumber, int]:
        """Parse the element starting at ``line[i]``.

        Returns the parsed element and the index just past it. Regular
        numbers may span several digits, so unreduced intermediate values such
        as ``"[[[[0,7],4],[15,[0,13]]],[1,1]]"`` parse as well.
        """
        if 0x30 <= line[i] <= 0x39:  # digits
            end = i
            while end < len(line) and 0x30 <= line[end] <= 0x39:
                end += 1
            return Leaf(int(line[i:end])), end
        # "[" left "," right "]": i + 1 skips the "[" and then the ",";
        # the final + 1 skips the closing "]".
        left, i = cls._parse(line, i + 1)
        right, i = cls._parse(line, i + 1)
        return SnailfishNumber(left, right), i + 1

    def __str__(self) -> str:
        return f"[{self.left},{self.right}]"

    def __getitem__(self, index: Dir) -> SnailfishNumber:
        """Dynamic access to the node left and right children"""
        return (self.left, self.right)[index]

    def __setitem__(self, index: Dir, value: SnailfishNumber) -> None:
        """Replace the left (``index == 0``) or right child.

        Compares by value rather than identity, so a plain ``0``/``1`` selects
        the same child here as it does in ``__getitem__``.
        """
        if index == LEFT:
            self.left = value
        else:
            self.right = value

    @property
    def magnitude(self) -> int:
        return 3 * self.left.magnitude + 2 * self.right.magnitude

    def explode(
        self, prev: SnailfishNumber | None, next: SnailfishNumber | None
    ) -> None:
        """Explode this pair in place.

        ``prev`` and ``next`` are the pairs immediately before and after this
        one in an in-order traversal of the pairs (``None`` at either end).
        The left value is added to the nearest regular number on the left, the
        right value to the nearest regular number on the right, and this pair
        is replaced by ``Leaf(0)`` in its parent.
        """
        # Both children must be regular numbers. (``assert a, b`` would only
        # check ``a`` and use ``b`` as the failure message.)
        assert self.left.is_leaf and self.right.is_leaf
        parent = pdir = None
        for node, ldir in (prev, LEFT), (next, RIGHT):
            if node is None:
                continue
            # The neighbouring regular number is the leaf of node[ldir] that
            # lies closest to this pair: keep descending towards ~ldir.
            n = node[ldir]
            while not n.is_leaf:
                n = n[~ldir]
            n.value += self[ldir].value  # type: ignore
            if parent is None:
                # Start the search for our parent on the side facing us.
                parent, pdir = node[~ldir], ldir
        if parent is self:
            # The neighbour itself is our parent.
            parent, pdir = next if prev is None else prev, ~pdir
        assert parent is not None
        while parent[pdir] is not self:
            parent = parent[pdir]
        parent[pdir] = Leaf()

    def split(self) -> bool:
        """Split this pair's first leaf child (left, then right) that is >= 10.

        Returns True if a split was performed.
        """
        for d in (LEFT, RIGHT):
            if self[d].is_leaf and (v := self[d].value) >= 10:  # type: ignore
                self[d] = SnailfishNumber(Leaf(v // 2), Leaf((v + 1) // 2))
                return True
        return False

    def __iter__(self) -> Iterator[tuple[Depth, SnailfishNumber]]:
        """In-order traversal of nodes only, as (depth, node) tuples"""
        if not self.left.is_leaf:
            yield from ((depth + 1, n) for depth, n in self.left)
        yield 0, self
        if not self.right.is_leaf:
            yield from ((depth + 1, n) for depth, n in self.right)

    def __add__(self, other: SnailfishNumber) -> SnailfishNumber:
        """Return the reduced sum; both operands are deep-copied, not modified."""
        new = SnailfishNumber(deepcopy(self), deepcopy(other))
        while True:
            # explodes: walk the pairs with a one-step lookahead so each
            # exploding pair knows its in-order predecessor and successor.
            prev, lookahead = None, chain(islice(new, 1, None), [(None, None)])
            for (depth, node), (_, succ) in zip(new, lookahead):
                if depth == 4:
                    node.explode(prev, succ)
                else:
                    prev = node
            # splits: perform at most one split, then re-check for explosions
            for _, node in new:
                if node.split():
                    break
            else:
                break
        return new
class Leaf(SnailfishNumber):
    """A regular number: a terminal node holding a plain integer."""

    is_leaf = True

    def __init__(self, value: int = 0) -> None:
        # SnailfishNumber.__init__ is not called: a leaf has no children.
        self.value = value

    def __str__(self) -> str:
        return f"{self.value}"

    @property
    def magnitude(self) -> int:
        # A regular number's magnitude is simply its value.
        return self.value
# Parsing followed by printing must reproduce every example exactly.
testlines = [
    "[1,2]",
    "[[1,2],3]",
    "[9,[8,7]]",
    "[[1,9],[8,5]]",
    "[[[[1,2],[3,4]],[[5,6],[7,8]]],9]",
    "[[[9,[3,8]],[[0,9],6]],[[[3,7],[4,9]],3]]",
    "[[[[1,3],[5,3]],[[1,3],[8,7]]],[[[4,9],[6,9]],[[8,2],[7,3]]]]",
]
for line in testlines:
    rendered = str(SnailfishNumber.from_line(line))
    assert rendered == line
# Each key is a newline-separated list of numbers to sum; each value is the
# expected fully reduced result.
testsums = {
    "[[[[4,3],4],4],[7,[[8,4],9]]]\n[1,1]": "[[[[0,7],4],[[7,8],[6,0]]],[8,1]]",
    "[1,1]\n[2,2]\n[3,3]\n[4,4]": "[[[[1,1],[2,2]],[3,3]],[4,4]]",
    "[1,1]\n[2,2]\n[3,3]\n[4,4]\n[5,5]": "[[[[3,0],[5,3]],[4,4]],[5,5]]",
    "[1,1]\n[2,2]\n[3,3]\n[4,4]\n[5,5]\n[6,6]": "[[[[5,0],[7,4]],[5,5]],[6,6]]",
    (
        "[[[0,[4,5]],[0,0]],[[[4,5],[2,6]],[9,5]]]\n"
        "[7,[[[3,7],[4,3]],[[6,3],[8,8]]]]\n"
        "[[2,[[0,8],[3,4]]],[[[6,7],1],[7,[1,6]]]]\n"
        "[[[[2,4],7],[6,[0,5]]],[[[6,8],[2,8]],[[2,1],[4,5]]]]\n"
        "[7,[5,[[3,8],[1,4]]]]\n[[2,[2,2]],[8,[8,1]]]\n"
        "[2,9]\n[1,[[[9,3],9],[[9,0],[0,7]]]]\n[[[5,[7,4]],7],1]\n"
        "[[[[4,2],2],6],[8,7]]"
    ): "[[[[8,7],[7,7]],[[8,6],[7,7]]],[[[0,7],[6,6]],[8,7]]]",
}
for lines, expected_str in testsums.items():
    total = reduce(add, map(SnailfishNumber.from_line, lines.splitlines()))
    assert str(total) == expected_str
# Expected magnitude for each example number.
testmagnitudes = {
    "[[1,2],[[3,4],5]]": 143,
    "[[[[0,7],4],[[7,8],[6,0]]],[8,1]]": 1384,
    "[[[[1,1],[2,2]],[3,3]],[4,4]]": 445,
    "[[[[3,0],[5,3]],[4,4]],[5,5]]": 791,
    "[[[[5,0],[7,4]],[5,5]],[6,6]]": 1137,
    "[[[[8,7],[7,7]],[[8,6],[7,7]]],[[[0,7],[6,6]],[8,7]]]": 3488,
}
for testnum, expected_mag in testmagnitudes.items():
    mag = SnailfishNumber.from_line(testnum).magnitude
    assert mag == expected_mag
# The worked homework example: its final sum and that sum's magnitude.
testhomework_lines = """\
[[[0,[5,8]],[[1,7],[9,6]]],[[4,[1,2]],[[1,4],2]]]
[[[5,[2,8]],4],[5,[[9,9],0]]]
[6,[[[6,2],[5,6]],[[7,6],[4,7]]]]
[[[6,[0,7]],[0,9]],[4,[9,[9,0]]]]
[[[7,[6,4]],[3,[1,3]]],[[[5,5],1],9]]
[[6,[[7,3],[3,2]]],[[[3,8],[5,7]],4]]
[[[[5,4],[7,7]],8],[[8,3],8]]
[[9,3],[[9,9],[6,[4,9]]]]
[[2,[[7,7],7]],[[5,8],[[9,3],[0,2]]]]
[[[[5,2],5],[8,[3,7]]],[[5,[7,5]],[4,4]]]
""".splitlines()
testhomework = [SnailfishNumber.from_line(text) for text in testhomework_lines]
testsum = reduce(add, testhomework)
assert str(testsum) == "[[[[6,6],[7,6]],[[7,7],[7,0]]],[[[7,7],[7,7]],[[7,8],[9,9]]]]"
assert testsum.magnitude == 4140
import aocd
# Parse the real puzzle input and report the magnitude of its total sum.
homework = list(
    map(SnailfishNumber.from_line, aocd.get_data(day=18, year=2021).splitlines())
)
print("Part 1:", reduce(add, homework).magnitude)
Part 1: 4132
All we have to do for part 2 is loop over every ordered pair of input snailfish numbers (the 2-element permutations; snailfish addition is not commutative, so both orders must be tried) to find the highest magnitude.
This takes a bit more time for the puzzle input, as each addition can involve a large number of explosions and splits, which in turn require a lot of traversals.
from itertools import permutations
def maximize(
    numbers: Iterable[SnailfishNumber] | Iterable[ArraySnailfishNumber],
) -> int:
    """Return the largest magnitude of a sum of two different entries.

    Snailfish addition is not commutative, so every ordered pair (a
    2-element permutation) is tried. Works with either the node-based
    ``SnailfishNumber`` or the ``ArraySnailfishNumber`` implementation, and
    with any iterable of them.

    Raises:
        ValueError: if fewer than two numbers are given (no pair exists).
    """
    return max((a + b).magnitude for a, b in permutations(numbers, 2))
# Verify against the worked example first, then solve the real input.
assert 3993 == maximize(testhomework)
print("Part 2:", maximize(homework))
Part 2: 4685
My implementation for part 2 takes about 10 seconds to run, primarily because the number of explosions, splits and iterations is very large.
We can avoid most of the iteration, however, by implementing the binary tree as a fixed-size array of integers instead of a series of linked nodes (in Python, use a list), similar to what I did in Day 3. A reduced number has at most 5 levels (its leaves are at most 4 levels below the root), so the array size is bounded to $2^5 = 32$ elements. Most nodes are set to `None`; only leaf nodes have an integer value.
The advantages are that you can trivially find nodes that need exploding; those are all at the same level, stored in the second half of the (64-element) array that holds a sum. If there are only `None` values there, we look at the first half for integer values greater than 9 to split. Going up or down the tree to find the preceding and next leaf nodes is a matter of halving or doubling the index (plus or minus 1, depending on traversal direction).
Just as on Day 3, I've elected to leave index 0 of the array unused (`None`) so that the root node sits at index 1; it makes the node indices easier to work with (going up is an integer division by 2, going down is doubling the index then adding 0 or 1 for left and right, the directions I already defined as the `IntEnum` named `Dir`).
Exploding and splitting are easier than I thought in this model; finding the preceding or succeeding leaf node is handled by a single method, and as I was coding the loops I realised that after the initial set of explosions clearing the 5th level of the tree, any subsequent explosions are the direct result of a split on the 4th level. If you keep a priority queue of leaf nodes that may need splitting, you can avoid most of the traversals altogether. The priority queue needs to keep nodes in left-to-right order, which is what the `VISIT_ORDER` list is for; it is used to provide the priority value in the queue.
All this work paid off: part 2 now completes in under a second, as opposed to more than 10 seconds before.
from heapq import heappop, heappush
from itertools import count, repeat
# In an array-based btree, the bits in the node index encode the path to the
# leaf node. To calculate the tree magnitude, all leaf values are multiplied by
# 3s and 2s based on the bits in their node index, except for the most
# significant bit; e.g. node 4 is 100, each 0 is a left node in the tree so
# multiplied by 3, while each 1 (past the first 1) would be 2x. The following
# code pre-computes those factors.
def _factor(n: int) -> int:
    """Magnitude multiplier for index n: 3 per 0 bit, 2 per 1 bit below the MSB."""
    v = 1
    while n > 1:
        v *= 3 - (n % 2)
        n >>= 1
    return v


# FACTORS[n] is the magnitude multiplier for a leaf stored at index n. Index 0
# is unused and index 1 is the root (a pair for any snailfish number), so
# neither has a factor.
FACTORS: Final[list[int | None]] = [None, None, *map(_factor, range(2, 32))]

# OFFSETS[n] is the largest power of two <= n (0 for n == 0). Adding it to n
# moves node n of a tree to the same position in the left subtree of a new
# root; adding it twice moves it into the right subtree.
OFFSETS: Final[list[int]] = [0, *(1 << (n.bit_length() - 1) for n in range(1, 32))]


# VISIT_ORDER[n] is the position of node n in a pre-order traversal of a
# 32-slot tree (None for the unused index 0). Leaves rank in left-to-right
# order under it, so it is used as the priority in the heapq of nodes
# potentially needing splitting.
def _preorder(node: int = 1) -> Iterator[int]:
    """Yield node indices 1..31 in pre-order (node, left subtree, right subtree)."""
    yield node
    if node < 16:  # nodes 16..31 are the deepest level of a 32-slot tree
        yield from _preorder(node * 2)
        yield from _preorder(node * 2 + 1)


VISIT_ORDER: Final[list[int | None]] = [None] * 32
for _rank, _node in enumerate(_preorder()):
    VISIT_ORDER[_node] = _rank
# Keep the module namespace limited to the three tables.
del _rank, _node, _factor, _preorder
class ArraySnailfishNumber:
def __init__(self, btree: list[int | None]) -> None:
self.btree = btree
@classmethod
def from_line(cls, line: str) -> ArraySnailfishNumber:
"""Parse the btree from a line"""
btree, node = [None] * 32, 1
for b in line.encode():
match b:
case 0x5B: # [
node *= 2
case 0x5D: # ]
node //= 2
case 0x2C: # ,
node += 1
case d: # digits
btree[node] = d - 0x30
return cls(btree)
def __iter__(self) -> Iterator[int]:
"""Pre-order traversal iteration over indices to btree nodes"""
btree, node = self.btree, 1
while True:
yield node
value = btree[node]
if value is None:
node *= 2
continue
while node % 2 == RIGHT:
node //= 2
if not node:
return
node += 1
def __str__(self) -> str:
btree, chars = self.btree, []
for node in self:
value = btree[node]
if value is None:
chars.append("[")
continue
chars.append(str(value))
while node % 2 == RIGHT and node > 1:
node //= 2
chars.append("]")
if node % 2 == LEFT:
chars.append(",")
return "".join(chars)
@property
def magnitude(self) -> int:
btree = self.btree
return sum(v * FACTORS[n] for n in self if (v := btree[n]) is not None)
def _find_sibling(self, node: int, dir: Dir) -> int | None:
"""Find sibling; if dir is LEFT, preceding, otherwise succeeding"""
while node % 2 == dir:
node //= 2
if node <= 1: # at or before the root, no sibling
return None
# move to opposite sibling node at same depth, then go in opposite
# direction to next leaf
node, dir, btree = node // 2 * 2 + dir, ~dir, self.btree
while node < 32 and btree[node] is None:
node = node * 2 + dir
return node
def __add__(self, other: ArraySnailfishNumber) -> ArraySnailfishNumber:
btree, sb, ob = [None] * 64, self.btree, other.btree
new = ArraySnailfishNumber(btree)
pqueue, overflow = [], 0
def explode(node: int, l1: int, l2: int) -> None:
# two new leaves outside the tree, update preceding, succeeding, and
# parent; this yields the updated sibling nodes (if < 32).
# overflow only applies to the initial copy phase, when exploding
# can push values to a successor at an index > 32
nonlocal overflow
btree[node // 2] = 0
if prev := new._find_sibling(node, LEFT):
btree[prev] += l1 + overflow
overflow = 0
if btree[prev] > 9:
heappush(pqueue, (VISIT_ORDER[prev], prev))
if next := new._find_sibling(node + RIGHT, RIGHT):
if next > 32:
overflow = l2
return
btree[next] += l2
if btree[next] > 9:
heappush(pqueue, (VISIT_ORDER[next], next))
# copy the first 3 levels of the two trees into subtrees on the result
# and add nodes to our priority queue for splitting.
for i in range(16):
ns, vs = OFFSETS[i] + i, sb[i]
no, vo = ns + OFFSETS[i], ob[i]
btree[ns], btree[no] = vs, vo
# prime the priority queue for values needing splitting
if None is not vs > 9:
heappush(pqueue, (VISIT_ORDER[ns], ns))
if None is not vo > 9:
heappush(pqueue, (VISIT_ORDER[no], no))
# explode the bottom levels of the two trees; first sb, then ob
for i in range(16, 32, 2):
if sb[i] is not None:
explode(OFFSETS[i] + i, *sb[i : i + 2])
for i in range(16, 32, 2):
if ob[i] is not None:
explode(OFFSETS[i] * 2 + i, *ob[i : i + 2])
# for each entry on the queue:
# - re-verify it is still over 9, then split
# - if split added leaves at level 5, explode (which will queue up as needed)
# - if split values are large enough to need splitting again, add
# their nodes to the queue for further checks
while pqueue:
_, node = heappop(pqueue)
if (value := btree[node]) is None or value < 10:
continue
btree[node] = None
leaves = value // 2, (value + 1) // 2
if node >= 16: # new leaves at level 5 to explode
explode(node * 2, *leaves)
continue
for n, v in zip((node * 2, node * 2 + 1), leaves):
btree[n] = v
if value > 9:
heappush(pqueue, (VISIT_ORDER[n], n))
return new
# Re-run every example against the array-based implementation.
for line in testlines:
    assert line == str(ArraySnailfishNumber.from_line(line))
for lines, expected_str in testsums.items():
    total = reduce(add, map(ArraySnailfishNumber.from_line, lines.splitlines()))
    assert expected_str == str(total)
for testnum, expected_mag in testmagnitudes.items():
    assert expected_mag == ArraySnailfishNumber.from_line(testnum).magnitude
testhomework = [ArraySnailfishNumber.from_line(text) for text in testhomework_lines]
testsum = reduce(add, testhomework)
assert str(testsum) == "[[[[6,6],[7,6]],[[7,7],[7,0]]],[[[7,7],[7,7]],[[7,8],[9,9]]]]"
assert testsum.magnitude == 4140
assert maximize(testhomework) == 3993
# Solve both parts of the real puzzle input with the array implementation.
homework = list(
    map(ArraySnailfishNumber.from_line, aocd.get_data(day=18, year=2021).splitlines())
)
print("Part 1:", reduce(add, homework).magnitude)
print("Part 2:", maximize(homework))
Part 1: 4132
Part 2: 4685