from dataclasses import dataclass, field
from typing import Optional
@dataclass
class Program:
    """One program in the tower: its own weight and the names it holds up."""

    weight: int
    # Names of the programs directly supported by this one; these are keys
    # into the same name -> Program mapping passed to get_total_weight().
    dependents: set = field(default_factory=set)
    # Memoised result of get_total_weight(). It is a cache, not part of the
    # program's value, so it is excluded from __init__, __repr__ and __eq__
    # (otherwise two identical programs compare unequal once one is computed).
    _total_weight: Optional[int] = field(
        init=False, default=None, repr=False, compare=False
    )

    def get_total_weight(self, programs):
        """Return this program's weight plus the weight of everything it supports.

        ``programs`` maps names to Program objects and must contain every name
        reachable through ``dependents``; a missing name raises KeyError.
        The result is cached on the first call, so later changes to
        ``weight`` or ``dependents`` are not reflected.
        """
        if self._total_weight is None:
            self._total_weight = self.weight + sum(
                programs[d].get_total_weight(programs) for d in self.dependents
            )
        return self._total_weight
def read_programs(lines):
    """Parse lines like ``name (weight)`` or ``name (weight) -> a, b, c``.

    Blank (or whitespace-only) lines are skipped, and surrounding whitespace
    is ignored on every field. Returns a dict mapping each program name to a
    Program built from its integer weight and the set of names it supports.
    A line that does not match the format raises ValueError (from the
    tuple unpacking or from ``int``).
    """
    programs = {}
    for raw_line in lines:
        # Strip once up front so indentation or trailing newlines never leak
        # into the parsed name, weight or dependents.
        line = raw_line.strip()
        if not line:
            continue
        head, _, tail = line.partition(" -> ")
        name, weight = head.split(" (")
        programs[name.strip()] = Program(
            int(weight.rstrip(") ")),
            {d.strip() for d in tail.split(",") if d.strip()},
        )
    return programs
import aocd

# Fetch the 2017 day 7 puzzle input through aocd and parse it into a
# name -> Program table used by both puzzle parts below.
data = aocd.get_data(year=2017, day=7)
programs = read_programs(data.splitlines())
def find_root(program):
    """Return the name of the bottom program: the one nobody else holds up.

    ``program`` is the whole name -> program mapping (the parameter name is
    kept for compatibility); each value must expose a ``dependents`` set of
    names. Raises ValueError unless exactly one such root exists, e.g. for
    empty input or for several disconnected towers.
    """
    # set().union(...) also works when there are no programs at all, whereas
    # set.union(*()) raises TypeError for lack of a receiver.
    supported = set().union(*(p.dependents for p in program.values()))
    roots = program.keys() - supported
    if len(roots) != 1:
        raise ValueError(f"expected exactly one root program, found {sorted(roots)}")
    (root,) = roots
    return root
# Sample tower from the puzzle statement; "tknk" sits at the bottom.
test = read_programs(
    """\
pbga (66)
xhth (57)
ebii (61)
havc (66)
ktlj (57)
fwft (72) -> ktlj, cntj, xhth
qoyq (66)
padx (45) -> pbga, havc, qoyq
tknk (41) -> ugml, padx, fwft
jptl (61)
ugml (68) -> gyxo, ebii, jptl
gyxo (61)
cntj (57)
""".splitlines()
)
assert find_root(test) == "tknk"

# Expected total weights of the three sub-towers resting directly on tknk.
tests = dict(ugml=251, padx=243, fwft=243)
for name, expected in tests.items():
    assert expected == test[name].get_total_weight(test)
def correct_balance(name, programs):
    """Return the weight the single mis-weighted program under ``name`` should have.

    Walks down from ``name``, following the one sub-tower whose total weight
    differs from its siblings. It stops at a program whose own dependents are
    balanced; that program's weight is the one to fix.

    Returns 0 when the tower rooted at ``name`` is already balanced. The
    search itself uses a ``None`` sentinel, so a deeper corrected weight of 0
    is not mistaken for "balanced".

    Raises ValueError when the imbalance cannot be pinned on exactly one
    program: more than one distinct wrong total, or two dependents that
    disagree with nothing to tell which one is wrong.
    """
    correction = _find_correction(name, programs)
    return 0 if correction is None else correction


def _find_correction(name, programs):
    """Return the corrected weight for the culprit below ``name``, or None if balanced."""
    # Group direct dependents by the total weight of the sub-tower each one
    # supports. Sorting makes the result independent of set iteration order.
    groups = {}
    for dependent in sorted(programs[name].dependents):
        total = programs[dependent].get_total_weight(programs)
        groups.setdefault(total, []).append(dependent)

    # A leaf (no dependents) or identical sub-towers: nothing to fix here.
    if len(groups) <= 1:
        return None
    if len(groups) > 2:
        raise ValueError(f"{name!r}: sub-towers have {len(groups)} different weights")

    (total_a, names_a), (total_b, names_b) = groups.items()

    if len(names_a) == 1 and len(names_b) == 1:
        # Exactly two dependents and they disagree: with a single wrong
        # weight, the culprit must sit inside whichever sub-tower is itself
        # unbalanced.
        candidates = names_a + names_b
        found = [
            c
            for c in (_find_correction(n, programs) for n in candidates)
            if c is not None
        ]
        if len(found) == 1:
            return found[0]
        raise ValueError(f"{name!r}: cannot tell which of {candidates} is mis-weighted")

    # The odd one out is the sub-tower whose total appears exactly once.
    if len(names_a) == 1:
        odd_name, odd_total, target = names_a[0], total_a, total_b
    elif len(names_b) == 1:
        odd_name, odd_total, target = names_b[0], total_b, total_a
    else:
        raise ValueError(f"{name!r}: more than one sub-tower has the wrong weight")

    # If the odd sub-tower is itself unbalanced, the culprit is further down.
    deeper = _find_correction(odd_name, programs)
    if deeper is not None:
        return deeper
    # Otherwise its own children agree, so the odd program's own weight is
    # off by exactly the difference between its total and its siblings'.
    return programs[odd_name].weight + (target - odd_total)
assert correct_balance("tknk", test) == 60

# Part 1: the bottom program of the real input tower.
root = find_root(programs)
print("Part 1:", root)
# Output captured from a previous run: Part 1: gynfwly

# Part 2: the weight the single mis-weighted program would need.
print("Part 2:", correct_balance(root, programs))
# Output captured from a previous run: Part 2: 1526