from __future__ import annotations
from torch._prims_common import DeviceLikeType
import torchlens as tl
import torch, torch.nn as nn, torch.nn.functional as F
import random, copy
from src.utils.vector import get_mps_device
import dataclasses, json
from pathlib import Path
device = None
[docs]
class NodeGene:
def __init__(self, nid, ntype):
self.id = nid
self.type = ntype # input | hidden | output
def __eq__(self, other):
if not isinstance(other, NodeGene):
return False
return self.id == other.id and self.type == other.type
def __ne__(self, other):
return not self.__eq__(other)
def __hash__(self):
return hash((self.id, self.type))
[docs]
class ConnGene:
def __init__(self, in_id, out_id, w, enabled=True):
self.in_id = in_id
self.out_id = out_id
self.w = w
self.enabled = enabled
def __eq__(self, other):
if not isinstance(other, ConnGene):
return False
return self.in_id == other.in_id and self.out_id == other.out_id and self.w == other.w and self.enabled == other.enabled
def __ne__(self, other):
return not self.__eq__(other)
def __hash__(self):
return hash((self.in_id, self.out_id, self.w, self.enabled))
[docs]
class Genome:
def __init__(self, n_inputs, n_outputs, device: DeviceLikeType | None = None):
self.n_inputs = n_inputs
self.n_outputs = n_outputs
self.nodes = [NodeGene(i, "input") for i in range(n_inputs)] + \
[NodeGene(n_inputs+i, "output") for i in range(n_outputs)]
self.conns = []
self.next_node_id = n_inputs + n_outputs
self.device = device
[docs]
def mutate_weight(self):
if self.conns:
c = random.choice(self.conns)
c.w += torch.randn(1, device=self.device).item() * 0.1
[docs]
def mutate_add_conn(self):
a, b = random.sample(self.nodes, 2)
if a.type == "output" or b.type == "input": return
self.conns.append(ConnGene(a.id, b.id, random.uniform(-1, 1)))
[docs]
def mutate_add_node(self):
if not self.conns: return
conn = random.choice(self.conns)
if not conn.enabled: return
conn.enabled = False
new_id = self.next_node_id; self.next_node_id += 1
new_node = NodeGene(new_id, "hidden")
self.nodes.append(new_node)
self.conns.append(ConnGene(conn.in_id, new_id, 1.0))
self.conns.append(ConnGene(new_id, conn.out_id, conn.w))
def __eq__(self, other):
return set(self.nodes) == set(other.nodes) and set(self.conns) == set(other.conns) and self.next_node_id == other.next_node_id and self.n_inputs == other.n_inputs and self.n_outputs == other.n_outputs
def __hash__(self):
if len(self.conns) == 0:
return hash((len(self.nodes), len(self.conns), self.next_node_id, self.n_inputs, self.n_outputs))
return hash((self.conns[0], len(self.nodes), len(self.conns), self.next_node_id, self.n_inputs, self.n_outputs))
[docs]
class EvolvedNet(nn.Module):
def __init__(self, genome, device: DeviceLikeType | None = None):
super().__init__()
self.g = genome
self.device = device
self.params = nn.ParameterDict()
for c in genome.conns:
if c.enabled:
self.params[f"w_{c.in_id}_{c.out_id}"] = nn.Parameter(torch.tensor(c.w, device=device))
self.inputs = [n.id for n in genome.nodes if n.type=="input"]
self.outputs = [n.id for n in genome.nodes if n.type=="output"]
[docs]
def forward(self, x):
# map node IDs to torch tensors
vals = {nid: torch.tensor(0.0, device=self.device) for nid in [n.id for n in self.g.nodes]}
# assign input activations
for i, nid in enumerate(self.inputs):
vals[nid] = x[i]
# propagate through connections iteratively
for _ in range(len(self.g.nodes)):
for (k, v) in self.params.items():
i, o = map(int, k[2:].split("_"))
vals[o] = vals[o] + torch.tanh(vals[i] * v)
# collect output activations as tensor
return torch.stack([torch.tanh(vals[o]).to(device) for o in self.outputs])
[docs]
class JSONEncoder(json.JSONEncoder):
[docs]
def default(self, o):
if dataclasses.is_dataclass(o):
return dataclasses.asdict(o)
if isinstance(o, Genome):
return {
"n_inputs": o.n_inputs,
"n_outputs": o.n_outputs,
"nodes": o.nodes,
"conns": o.conns
}
if isinstance(o, NodeGene):
return {
"id": o.id,
"type": o.type
}
if isinstance(o, ConnGene):
return {
"enabled": o.enabled,
"in_id": o.in_id,
"out_id": o.out_id,
"w": o.w
}
return super().default(o)
[docs]
def load_genome(file_path: str | Path):
tmp_conns: list[ConnGene] = []
tmp_nodes: list[NodeGene] = []
tmp_object: None | Genome = None
def as_node(dct):
nonlocal tmp_object, tmp_nodes
if 'id' in dct and 'type' in dct:
node = NodeGene(dct['id'], dct['type'])
tmp_nodes.append(node)
return
def as_conn(dct):
nonlocal tmp_object, tmp_conns
if 'enabled' in dct and 'in_id' in dct and 'out_id' in dct and 'w' in dct:
conn = ConnGene(dct['in_id'], dct['out_id'], dct['w'], dct['enabled'])
tmp_conns.append(conn)
return
def as_genome(dct):
nonlocal tmp_object
if 'n_outputs' in dct and 'n_inputs' in dct and 'conns' in dct and 'nodes' in dct:
tmp_object = Genome(dct['n_inputs'], dct['n_outputs'])
tmp_object.nodes = tmp_nodes
tmp_object.conns = tmp_conns
if 'id' in dct and 'type' in dct:
as_node(dct)
if tmp_object is not None:
tmp_object.nodes = tmp_nodes
if 'enabled' in dct and 'in_id' in dct and 'out_id' in dct and 'w' in dct:
as_conn(dct)
if tmp_object is not None:
tmp_object.conns = tmp_conns
return tmp_object
with open("test_genome.json", 'r') as f:
data = f.read()
genome = json.loads(data, object_hook=as_genome)
return genome
[docs]
def save_genome(genome: Genome, file_path: str | Path):
with open(file_path, 'w') as f:
json.dump(genome, f, indent=4, cls=JSONEncoder)
if __name__ == "__main__":
device = get_mps_device()
POP: list[Genome] = [Genome(6,3) for _ in range(30)]
for g in POP: g.mutate_add_conn() # start with random links
save_genome(POP[0], "test_genome.json")
loaded_genome = load_genome("test_genome.json")
loaded_genome.mutate_add_conn()
print(hash(loaded_genome))