Our task today is to find all paths between two nodes in a graph, with the twist that nodes with uppercase names can be visited more than once per path. Note that "start" and "end" are both lowercase, so each can appear only once in a path! This is one of the classic graph traversal problems, and I've implemented a typical depth-first search (DFS), except that uppercase node names are left out of the 'seen' set (the already-visited nodes on the current path).
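To make the visiting rule concrete, here is a minimal sketch of it as a standalone predicate. The name may_enter is hypothetical and not part of the solution below; it assumes, as the puzzle does, that every cave name is either all uppercase or all lowercase.

```python
def may_enter(cave: str, path: list[str]) -> bool:
    """Return True if 'cave' may be appended to 'path' under the part 1 rules.

    Hypothetical helper for illustration only: big caves (uppercase names)
    can be revisited freely, small caves (lowercase, including "start" and
    "end") at most once per path.
    """
    return cave.isupper() or cave not in path


path = ["start", "A", "c", "A"]
print(may_enter("A", path))    # → True  (big cave, revisiting is fine)
print(may_enter("c", path))    # → False (small cave already on the path)
print(may_enter("end", path))  # → True  (small cave not visited yet)
```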
I also like to avoid recursion. Recursion can usually be replaced by a loop with a stack or a queue; for DFS, recursion is just a loop with the call stack acting as the stack. Recursive DFS implementations pass the 'seen' set (or the path so far) along as a parameter, so in the iterative version you need to put it on the stack together with each node to visit.
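For comparison, a minimal sketch of the recursive form that the loop below replaces. It is not the code used in this notebook; parse and count_paths are hypothetical names. The path so far travels along as a parameter, and each call is one frame on the interpreter's call stack.

```python
Graph = dict[str, set[str]]


def parse(text: str) -> Graph:
    """Build an undirected adjacency mapping from 'a-b' lines."""
    graph: Graph = {}
    for line in text.splitlines():
        left, right = line.split("-")
        graph.setdefault(left, set()).add(right)
        graph.setdefault(right, set()).add(left)
    return graph


def count_paths(graph: Graph, node: str = "start", path: tuple[str, ...] = ("start",)) -> int:
    """Count paths from 'node' to "end", never re-entering a small cave on 'path'."""
    if node == "end":
        return 1
    seen = {name for name in path if name.islower()}
    # each recursive call carries its own copy of the path so far
    return sum(count_paths(graph, nxt, (*path, nxt)) for nxt in graph[node] - seen)


example = "start-A\nstart-b\nA-c\nA-b\nb-d\nA-end\nb-end"
print(count_paths(parse(example)))  # → 10
```

Like the iterative version, this relies on no two big caves being directly connected; otherwise the number of paths would be infinite.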
from __future__ import annotations

from collections import deque
from typing import Iterable, Iterator, TypeAlias

Graph: TypeAlias = dict[str, set[str]]


class CaveSystem:
    def __init__(self, graph: Graph):
        self.graph = graph

    @classmethod
    def from_lines(cls, lines: Iterable[str]) -> CaveSystem:
        graph: Graph = {}
        for line in lines:
            left, right = line.split("-")
            # caves are connected in both directions
            graph.setdefault(left, set()).add(right)
            graph.setdefault(right, set()).add(left)
        return cls(graph)

    def _edges(self, node: str, visited: tuple[str, ...]) -> Iterable[str]:
        # traverse 'visited' just once for all edges leading from this node;
        # small (lowercase) caves already on the path are off-limits
        seen = {name for name in visited if name.islower()}
        yield from self.graph[node] - seen

    def __iter__(self) -> Iterator[tuple[str, ...]]:
        # stack holds (node to visit, path so far including that node)
        stack: deque[tuple[str, tuple[str, ...]]] = deque([("start", ("start",))])
        while stack:
            node, visited = stack.pop()
            if node == "end":
                yield visited
                continue
            for neighbour in self._edges(node, visited):
                stack.append((neighbour, (*visited, neighbour)))

    def __len__(self) -> int:
        # number of distinct paths from start to end
        return sum(1 for _ in self)
tests: dict[str, int] = {
    "start-A\nstart-b\nA-c\nA-b\nb-d\nA-end\nb-end": 10,
    (
        "dc-end\nHN-start\nstart-kj\ndc-start\n"
        "dc-HN\nLN-dc\nHN-end\nkj-sa\nkj-HN\nkj-dc"
    ): 19,
    (
        "fs-end\nhe-DX\nfs-he\nstart-DX\npj-DX\nend-zg\nzg-sl\nzg-pj\npj-he\n"
        "RW-he\nfs-DX\npj-RW\nzg-RW\nstart-pj\nhe-WI\nzg-he\npj-fs\nstart-RW\n"
    ): 226,
}
for test_map, expected in tests.items():
    assert len(CaveSystem.from_lines(test_map.splitlines())) == expected
import aocd
cave_map = aocd.get_data(day=12, year=2021).splitlines()
print("Part 1:", len(CaveSystem.from_lines(cave_map)))
Part 1: 5756
We are now told that a path through the cave system may visit a single small cave twice ("start" and "end" still only once); all other small caves can be visited at most once. Instead of excluding every lowercase node name already on the current path-so-far, exclude only "start" as long as no lowercase node name appears twice yet; once one does, fall back to the part 1 rule. To figure out whether a node name appears twice, use a multi-set instead of a regular set. In the Python standard library, the collections.Counter class is such a multi-set.
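A minimal sketch of that rule as a standalone predicate, using Counter as the multi-set. The name may_enter_part2 is hypothetical; the notebook instead folds this logic into the _edges override.

```python
from collections import Counter


def may_enter_part2(cave: str, path: list[str]) -> bool:
    """Hypothetical part 2 predicate: one small cave may be visited twice."""
    if cave.isupper():
        return True  # big caves: unlimited visits
    if cave == "start":
        return False  # "start" is never re-entered
    small = Counter(name for name in path if name.islower())
    # Counter returns 0 for missing keys, so an unvisited small cave is fine;
    # a visited one is allowed only while no small cave appears twice yet.
    return small[cave] == 0 or 2 not in small.values()


print(may_enter_part2("b", ["start", "b", "A"]))       # → True
print(may_enter_part2("b", ["start", "b", "A", "b"]))  # → False
print(may_enter_part2("c", ["start", "b", "A", "b"]))  # → True
```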
To implement part 2, I factored out the generation of edges to follow into a separate method so I can reuse the rest of my DFS code.
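The same separation could also be expressed without inheritance, by passing the edge rule to the traversal as a function. This is a sketch of that alternative design under the same puzzle rules, not the code used here; parse, part1_edges, part2_edges and count_paths are hypothetical names.

```python
from collections import Counter
from typing import Callable, Iterable

Graph = dict[str, set[str]]
# An edge rule maps (graph, current node, path so far) to the nodes that may come next.
EdgeRule = Callable[[Graph, str, tuple[str, ...]], Iterable[str]]


def parse(text: str) -> Graph:
    graph: Graph = {}
    for line in text.splitlines():
        left, right = line.split("-")
        graph.setdefault(left, set()).add(right)
        graph.setdefault(right, set()).add(left)
    return graph


def part1_edges(graph: Graph, node: str, path: tuple[str, ...]) -> Iterable[str]:
    # small caves already on the path are excluded
    return graph[node] - {name for name in path if name.islower()}


def part2_edges(graph: Graph, node: str, path: tuple[str, ...]) -> Iterable[str]:
    seen = Counter(name for name in path if name.islower())
    if 2 not in seen.values():
        # no small cave visited twice yet: only "start" is off-limits
        return graph[node] - {"start"}
    return graph[node].difference(seen)


def count_paths(graph: Graph, edges: EdgeRule) -> int:
    # iterative DFS with an explicit stack of (node, path so far)
    stack = [("start", ("start",))]
    total = 0
    while stack:
        node, path = stack.pop()
        if node == "end":
            total += 1
            continue
        for nxt in edges(graph, node, path):
            stack.append((nxt, (*path, nxt)))
    return total


graph = parse("start-A\nstart-b\nA-c\nA-b\nb-d\nA-end\nb-end")
print(count_paths(graph, part1_edges), count_paths(graph, part2_edges))  # → 10 36
```

Subclassing keeps the traversal and the rule together on one object, while passing a function keeps them decoupled; either way the DFS loop itself stays unchanged.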
from collections import Counter


class RevisitCaveSystem(CaveSystem):
    def _edges(self, node: str, visited: tuple[str, ...]) -> Iterable[str]:
        seen = Counter(name for name in visited if name.islower())
        if 2 not in seen.values():
            # no small cave has been visited twice yet, only disallow "start"
            seen = {"start"}
        yield from self.graph[node].difference(seen)
revisit_tests: dict[str, int] = {
    "start-A\nstart-b\nA-c\nA-b\nb-d\nA-end\nb-end": 36,
    (
        "dc-end\nHN-start\nstart-kj\ndc-start\n"
        "dc-HN\nLN-dc\nHN-end\nkj-sa\nkj-HN\nkj-dc"
    ): 103,
    (
        "fs-end\nhe-DX\nfs-he\nstart-DX\npj-DX\nend-zg\nzg-sl\nzg-pj\npj-he\n"
        "RW-he\nfs-DX\npj-RW\nzg-RW\nstart-pj\nhe-WI\nzg-he\npj-fs\nstart-RW\n"
    ): 3509,
}
for test_map, expected in revisit_tests.items():
    assert len(RevisitCaveSystem.from_lines(test_map.splitlines())) == expected
print("Part 2:", len(RevisitCaveSystem.from_lines(cave_map)))
Part 2: 144603