Source code for src.core.model

from __future__ import annotations
from torch._prims_common import DeviceLikeType
import torchlens as tl
import torch, torch.nn as nn, torch.nn.functional as F
import random, copy
from src.utils.vector import get_mps_device

device = None


[docs] class NodeGene: def __init__(self, nid, ntype): self.id = nid self.type = ntype # input | hidden | output
[docs] class ConnGene: def __init__(self, in_id, out_id, w, enabled=True): self.in_id = in_id self.out_id = out_id self.w = w self.enabled = enabled
[docs] class Genome: def __init__(self, n_inputs, n_outputs, device: DeviceLikeType | None = None): self.nodes = [NodeGene(i, "input") for i in range(n_inputs)] + \ [NodeGene(n_inputs+i, "output") for i in range(n_outputs)] self.conns = [] self.next_node_id = n_inputs + n_outputs self.device = device
[docs] def mutate_weight(self): if self.conns: c = random.choice(self.conns) c.w += torch.randn(1, device=self.device).item() * 0.1
[docs] def mutate_add_conn(self): a, b = random.sample(self.nodes, 2) if a.type == "output" or b.type == "input": return self.conns.append(ConnGene(a.id, b.id, random.uniform(-1, 1)))
[docs] def mutate_add_node(self): if not self.conns: return conn = random.choice(self.conns) if not conn.enabled: return conn.enabled = False new_id = self.next_node_id; self.next_node_id += 1 new_node = NodeGene(new_id, "hidden") self.nodes.append(new_node) self.conns.append(ConnGene(conn.in_id, new_id, 1.0)) self.conns.append(ConnGene(new_id, conn.out_id, conn.w))
[docs] class EvolvedNet(nn.Module): def __init__(self, genome, device: DeviceLikeType | None = None): super().__init__() self.g = genome self.device = device self.params = nn.ParameterDict() for c in genome.conns: if c.enabled: self.params[f"w_{c.in_id}_{c.out_id}"] = nn.Parameter(torch.tensor(c.w, device=device)) self.inputs = [n.id for n in genome.nodes if n.type=="input"] self.outputs = [n.id for n in genome.nodes if n.type=="output"]
[docs] def forward(self, x): # map node IDs to torch tensors vals = {nid: torch.tensor(0.0, device=self.device) for nid in [n.id for n in self.g.nodes]} # assign input activations for i, nid in enumerate(self.inputs): vals[nid] = x[i] # propagate through connections iteratively for _ in range(len(self.g.nodes)): for (k, v) in self.params.items(): i, o = map(int, k[2:].split("_")) vals[o] = vals[o] + torch.tanh(vals[i] * v) # collect output activations as tensor return torch.stack([torch.tanh(vals[o]).to(device) for o in self.outputs])
if __name__ == "__main__": device = get_mps_device() def fitness(genome) -> float: net = EvolvedNet(genome) data = [([0,0],0),([0,1],1),([1,0],1),([1,1],0)] err = 0 for x,y in data: y_hat = net(torch.tensor(x, dtype=torch.float32)) err += (y_hat - y)**2 return -err.item() # higher is better # ------------------------- # Evolution loop # ------------------------- POP: list[Genome] = [Genome(2,1) for _ in range(30)] for g in POP: g.mutate_add_conn() # start with random links for gen in range(50): scores = [(fitness(g), g) for g in POP] scores.sort(reverse=True, key=lambda x:x[0]) #print(f"Gen {gen}: best fitness {scores[0][0]:.3f}") tl.show_model_graph(EvolvedNet(scores[0][1]), torch.tensor([0, 0], dtype=torch.float32), vis_opt="rolled") print(len(scores[0][1].conns)) if scores[0][0] > -0.05: print("Solved!") break # selection and mutation newpop = [copy.deepcopy(scores[0][1])] for _ in range(len(POP)-1): g = copy.deepcopy(random.choice(scores[:10])[1]) random.choice([g.mutate_weight, g.mutate_add_conn, g.mutate_add_node])() newpop.append(g) POP = newpop