Source code for src.core.train

"""Parallel trainer for Mario Kart DS agents using DeSmuME, multiprocessing,
shared memory frame streaming, and optional live GTK visualization.

This module orchestrates end-to-end evaluation and evolution of a population of
neural network controllers (NEAT-style) for Mario Kart DS. It supports three
execution modes per evaluated individual:

  1) **Headless** (`run_process`) — fast evaluation with no display.
  2) **Display worker** (`run_window_process`) — renders frames and writes them
     into a per-process shared-memory buffer; no GTK loop.
  3) **Display host** (`run_window_host_process`) — renders frames, writes them
     into shared memory, and owns the GTK window that tiles and presents all
     *display-enabled* workers in real time.

Key concepts
------------
- **Shared memory frames:** Each display-enabled process writes an RGBX
  framebuffer of shape ``(SCREEN_HEIGHT, SCREEN_WIDTH, 4)`` (dtype
  ``np.uint8``) to a POSIX shared-memory segment named ``f"emu_frame_{id}"``.
  The host window process opens these buffers read-only for display tiling
  (see the sketch after this list).
- **Overlays:** Optional per-frame overlays are computed off the main emulation
  loop by a single background thread fed via a queue. Overlays are composited in
  the worker before writing to shared memory using
  :func:`src.visualization.window.on_draw_memoryview`.
- **Statistics / fitness:** Each process records split times and distances at
  track checkpoints as a ``dict[int, CheckpointRecord]`` that maps
  ``checkpoint_id`` to a record of per-split delta times and distances. A
  simple fitness function sums the recorded distances.
- **Batching & evolution:** :func:`run_training_session` evaluates a subset
  (batch) of the population in parallel (bounded by ``num_proc``), aggregates
  stats, then :func:`train` evolves the population.
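
A display worker's view of its frame buffer is simply a NumPy array over the
shared segment. A minimal sketch, using the naming convention above::

    shm = SharedMemory(name="emu_frame_0")
    frame = np.ndarray((SCREEN_HEIGHT, SCREEN_WIDTH, 4), dtype=np.uint8,
                       buffer=shm.buf)
    frame[:, :, :] = 0  # write RGBX pixels; the host tiles this view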

Threading & processes
---------------------
- The DeSmuME emulator is created and used **inside each process** that runs it.
- The GTK main loop **must** run in a single process. This module designates one
  display-enabled process per batch as the *host* that creates the window and
  drives GTK via `GLib.timeout_add`.
- Overlays are computed by a **single background thread** (daemon) within each
  display-enabled process to keep the emulation loop responsive.

Shared-memory lifetime
----------------------
- Creation: Call :func:`safe_shared_memory` to create (or replace) a named
  shared-memory segment.
- Ownership: Workers **open** their frame buffer (``emu_frame_{id}``) by name
  and keep a persistent ``SharedMemory`` handle as long as they render frames.
- Teardown: After processes finish, the parent should **close and unlink**
  per-process frame segments to avoid resource-tracker warnings (see the
  sketch below).
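
A teardown sketch, assuming the parent kept the ``SharedMemory`` handles it
created (the ``shm_handles`` list here is illustrative bookkeeping)::

    for shm in shm_handles:
        shm.close()    # release this process's mapping
        shm.unlink()   # destroy the named segment system-wide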

Examples
--------
Run 10 generations with a population of 32 in fully headless mode (the single
``show_samples`` entry is broadcast to each batch), with overlay IDs 0, 3,
and 4 configured:

    >>> if __name__ == "__main__":
    ...     train(
    ...         num_iters=10,
    ...         pop_size=32,
    ...         show_samples=[False],   # broadcast later per batch
    ...         overlay_ids=[0, 3, 4],
    ...     )

Notes
-----
- This module expects the ``mariokart_ds.nds`` ROM to be available in the working
  directory and a valid savestate at slot ``SAVE_STATE_ID`` (imported from
  :mod:`src.main`).
- ``on_draw_memoryview`` expects an emulator-provided RGBX memory buffer
  (4 bytes/pixel), and returns a premultiplied ARGB32 array suitable for Cairo.
- ``MODEL_KEY_MAP`` defines a simple thresholded policy:
  values >= 0.5 are pressed, and the accelerator button is held every frame
  regardless of the other outputs.

"""

# Builtin dependencies
from __future__ import annotations
import random, math, os, sys, copy
from multiprocessing.managers import DictProxy
from multiprocessing import Process, Manager, Lock
from multiprocessing.shared_memory import SharedMemory
from queue import Queue
from threading import Thread
from typing import TypedDict, TypeAlias, cast

# External dependencies
from desmume.emulator import DeSmuME, SCREEN_HEIGHT, SCREEN_WIDTH, SCREEN_PIXEL_SIZE
from desmume.frontend.gtk_drawing_area_desmume import AbstractRenderer
from desmume.controls import Keys, keymask
from torch._prims_common import DeviceLikeType
import numpy as np
import torch
import gi

from src.main import SAVE_STATE_ID

gi.require_version("Gtk", "3.0")
gi.require_version("Gdk", "3.0")
from gi.repository import Gtk, Gdk, GLib

# Local dependencies
from src.core.memory import *
from src.core.memory import read_clock
from src.core.model import Genome, EvolvedNet
from src.utils.vector import get_mps_device
from src.visualization.window import SharedEmulatorWindow, on_draw_memoryview
from src.visualization.overlay import AVAILABLE_OVERLAYS

class EmulatorProcessConfig(TypedDict):
    id: int
    sample: Genome
    host: bool
    show: bool

class EmulatorBatchConfig(TypedDict):
    size: int
    display_shm_names: list[str]
    device: DeviceLikeType | None
    overlay_ids: list[int]

class CheckpointRecord(TypedDict):
    id: int
    times: list[float]
    dists: list[float]

MODEL_KEY_MAP = {
    3: Keys.KEY_UP,
    2: Keys.KEY_DOWN,
    1: Keys.KEY_LEFT,
    0: Keys.KEY_RIGHT,
    4: Keys.KEY_B,
    5: Keys.KEY_A,
    6: Keys.KEY_X,
    7: Keys.KEY_Y,
    8: Keys.KEY_L,
    9: Keys.KEY_R,
    10: Keys.KEY_START,
    11: Keys.KEY_LEFT,
    12: Keys.KEY_RIGHT,
    13: Keys.KEY_UP,
    14: Keys.KEY_DOWN,
}

def safe_shared_memory(name: str, size: int) -> SharedMemory:
    """Create or replace a named POSIX shared-memory segment.

    This helper guarantees that a shared-memory block with the given `name`
    exists with the requested `size`. If a stale block exists (e.g., from an
    earlier crashed run), it is closed and unlinked before creating a fresh
    one.

    Args:
        name: Symbolic name of the shared memory region (e.g., ``"emu_frame_0"``).
        size: Size in **bytes** to allocate for the region.

    Returns:
        multiprocessing.shared_memory.SharedMemory: An opened handle to the
        new shared-memory block. The caller owns the handle and is responsible
        for closing it (and unlinking at teardown time).

    Raises:
        ValueError: If `size <= 0`.
        OSError: If the OS cannot allocate or map the segment.

    Side Effects:
        - May unlink an existing segment of the same name.
        - Creates a new segment in the system shared-memory namespace.
    """
    if size <= 0:
        raise ValueError("safe_shared_memory: size must be > 0")
    try:
        shm = SharedMemory(name=name, create=True, size=size)
    except FileExistsError:
        # A stale segment survived an earlier run: remove it and recreate.
        old = SharedMemory(name=name)
        old.close()
        old.unlink()
        shm = SharedMemory(name=name, create=True, size=size)
    return shm

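# A minimal usage sketch (illustrative snippet, kept in a comment so it is
# not executed on import):
#
#     shm = safe_shared_memory("emu_frame_0", SCREEN_HEIGHT * SCREEN_WIDTH * 4)
#     frame = np.ndarray((SCREEN_HEIGHT, SCREEN_WIDTH, 4), dtype=np.uint8,
#                        buffer=shm.buf)
#     frame[:] = 0       # render into the segment
#     shm.close()        # drop this handle
#     shm.unlink()       # destroy the segment at teardown
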
def initialize_emulator() -> DeSmuME:
    """Initialize and prime a DeSmuME emulator instance.

    Loads the MKDS ROM, restores the savestate at slot ``SAVE_STATE_ID``,
    mutes audio, and cycles once to ensure memory is initialized. Then spins
    until the emulator reports it is running.

    Returns:
        DeSmuME: A ready-to-use emulator instance positioned at the savestate.

    Notes:
        - This function blocks until ``emu.is_running()`` returns True.
        - The ROM path ``"mariokart_ds.nds"`` is hard-coded.
    """
    emu = DeSmuME()
    emu.open("mariokart_ds.nds")
    emu.savestate.load(SAVE_STATE_ID)
    emu.volume_set(0)
    emu.cycle()
    while not emu.is_running():
        print("Waiting for emulator...")
    return emu

def initialize_window(
    emu: DeSmuME,
    config: EmulatorProcessConfig,
    batch_config: EmulatorBatchConfig,
) -> SharedEmulatorWindow | None:
    """Create and initialize the tiled GTK window for live visualization.

    Computes a near-square grid (``n_rows`` × ``n_cols``) based on the number
    of display-enabled processes, instantiates a renderer bound to `emu`, and
    returns a :class:`SharedEmulatorWindow` configured to read from the
    shared-memory frame names listed in `batch_config`.

    Args:
        emu: Active :class:`DeSmuME` instance (used to build the renderer).
        config: Per-process configuration; only the host process
            (``config["host"]``) gets a window.
        batch_config: Batch-level configuration providing the shared-memory
            segment names (``"emu_frame_{id}"``) to tile.

    Returns:
        SharedEmulatorWindow | None: GTK window object ready to be shown, or
        ``None`` when this process is not the display host.

    Side Effects:
        - Initializes a GTK/Cairo renderer via :class:`AbstractRenderer`.
    """
    if not config["host"]:
        return None
    shm_names = batch_config["display_shm_names"]
    display_count = len(shm_names)

    # Fixed window width; height preserves the DS screen aspect ratio.
    width = 1000
    height = math.floor(width * (SCREEN_HEIGHT / SCREEN_WIDTH))

    # Near-square tiling grid for all display-enabled workers.
    n_cols = math.ceil(math.sqrt(display_count))
    n_rows = math.ceil(display_count / n_cols)

    renderer = AbstractRenderer.impl(emu)
    renderer.init()

    window = SharedEmulatorWindow(
        width=width,
        height=height,
        n_cols=n_cols,
        n_rows=n_rows,
        renderer=renderer,
        shm_names=shm_names,
    )
    return window

def initialize_overlays(
    config: EmulatorProcessConfig, batch_config: EmulatorBatchConfig
) -> Queue | None:
    """Start a background overlay thread and return its work queue.

    Looks up the configured overlay IDs in :data:`AVAILABLE_OVERLAYS` and
    starts a single daemon thread that consumes :class:`DeSmuME` instances
    from a queue and applies the overlays. The queue is returned to the caller
    to submit per-frame overlay requests.

    Args:
        config: Per-process configuration; overlays run only when
            ``config["show"]`` is set.
        batch_config: Batch-level configuration providing ``overlay_ids``
            (indexes into :data:`AVAILABLE_OVERLAYS`) and the torch device on
            which overlay computations (if any) should run.

    Returns:
        Queue | None: If overlays are enabled, a ``Queue`` into which the
        caller should ``put(emu)`` once per frame, and ``put(None)`` on
        shutdown. Returns ``None`` when display is off or no overlays are
        configured.

    Notes:
        - The overlay worker catches exceptions per overlay and propagates a
          summarized error message on failure via :func:`safe_thread`.
        - Overlays are executed off the emulation thread to avoid jitter.
    """
    overlay_ids = batch_config["overlay_ids"]
    if not config["show"] or len(overlay_ids) == 0:
        return None
    device = batch_config["device"]
    overlays = [AVAILABLE_OVERLAYS[overlay_id] for overlay_id in overlay_ids]
    emu_queue: Queue = Queue()

    def worker():
        while True:
            emu_instance = emu_queue.get()
            if emu_instance is None:  # shutdown sentinel
                break
            for overlay in overlays:
                safe_overlay = safe_thread(overlay, proc_id=config["id"])
                safe_overlay(emu_instance, device=device)
            emu_queue.task_done()

    thread = Thread(target=worker, daemon=True)
    thread.start()
    return emu_queue

def handle_controls(emu: DeSmuME, logits: torch.Tensor):
    """Apply model outputs to emulator controls with a simple threshold policy.

    All values ``>= 0.5`` are considered pressed for the corresponding
    ``MODEL_KEY_MAP`` entry. The accelerator (mapped to ``MODEL_KEY_MAP[5]``)
    is additionally pressed every frame to keep the kart moving.

    Args:
        emu: Active emulator instance whose keypad state will be updated.
        logits: 1D tensor of action activations aligned with ``MODEL_KEY_MAP``.

    Side Effects:
        - Calls ``emu.input.keypad_update(0)`` and
          ``emu.input.keypad_add_key(...)`` multiple times.
    """
    logits_list = logits.tolist()
    emu.input.keypad_update(0)  # release all keys from the previous frame
    for i, v in enumerate(logits_list):
        if v < 0.5:
            continue
        emu.input.keypad_add_key(keymask(MODEL_KEY_MAP[i]))
    # Always hold the accelerator so the kart keeps moving.
    emu.input.keypad_add_key(keymask(MODEL_KEY_MAP[5]))

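# Worked example (illustrative logits): entries 0 (KEY_RIGHT) and 2 (KEY_DOWN)
# exceed the 0.5 threshold and are pressed; entry 1 (KEY_LEFT) stays released;
# KEY_A is held regardless:
#
#     logits = torch.tensor([0.7, 0.2, 0.9])
#     handle_controls(emu, logits)  # presses RIGHT, DOWN, and A
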
def initialize_model(
    emu: DeSmuME,
    config: EmulatorProcessConfig,
    batch_config: EmulatorBatchConfig,
):
    """Build the evolved network for this process and bind it to the emulator.

    Returns:
        Callable: The per-frame forward closure from :func:`get_forward_func`.
    """
    device = batch_config["device"]
    sample = config["sample"]
    model = EvolvedNet(sample, device=device)
    forward = get_forward_func(emu, model, device)
    return forward

def _run_process(
    training_stats: dict[int, dict[int, CheckpointRecord]],
    training_stats_lock,
    config: EmulatorProcessConfig,
    batch_config: EmulatorBatchConfig,
):
    assert not config["host"] or config["show"], "Host processes must have display enabled"

    # Initialize emulator
    emu = initialize_emulator()

    # Initialize model
    forward = initialize_model(emu, config, batch_config)

    # Initialize display shared memory buffer as numpy array
    frame = None
    if config["show"]:
        id = config["id"]
        shm_frame = SharedMemory(name=f"emu_frame_{id}")
        frame = np.ndarray(
            shape=(SCREEN_HEIGHT, SCREEN_WIDTH, 4),
            dtype=np.uint8,
            buffer=shm_frame.buf,
        )

    # Initialize window (host process only)
    window = initialize_window(emu, config, batch_config)

    # Start overlay thread (display-enabled processes only)
    emu_queue = initialize_overlays(config, batch_config)

    stats: dict[int, CheckpointRecord] | None = None

    def tick():
        nonlocal stats
        emu.cycle()
        if frame is not None:
            # Copy display data to shared memory buffer
            buf = emu.display_buffer_as_rgbx()[: SCREEN_PIXEL_SIZE * 4]
            new_frame = on_draw_memoryview(buf, SCREEN_WIDTH, SCREEN_HEIGHT, 1.0, 1.0)
            np.copyto(frame, new_frame)

        # Inference / Game Update
        logits = forward()
        if isinstance(logits, dict):
            # Terminal condition: the closure returned the stats dict.
            if config["show"]:
                send_window_end_signal(config["id"])
            stats = logits
            return False
        handle_controls(emu, logits)

        if emu_queue is not None:
            # Queue Overlay Request
            emu_queue.put(emu)
        return True

    # Non-host processes drive their own loop; the host defers to GTK below.
    while not config["host"]:
        if not tick():
            break

    if window is not None:
        # Will incrementally check if the population has died
        def check_end():
            """Quit GTK when all visible frames are cleared to zeros."""
            for name in batch_config["display_shm_names"]:
                shm = SharedMemory(name=name)
                arr = np.ndarray(
                    (SCREEN_HEIGHT, SCREEN_WIDTH, 4), dtype=np.uint8, buffer=shm.buf
                )
                if arr.sum() != 0:
                    return True  # at least one worker is still rendering
            Gtk.main_quit()
            return False

        GLib.timeout_add(200, check_end)  # non-blocking
        GLib.timeout_add(33, tick)  # non-blocking, ~30 FPS
        window.show_all()
        Gtk.main()  # blocking

    # Safe thread shutdown for overlay
    if emu_queue is not None:
        emu_queue.put(None)

    # Log results
    with training_stats_lock:
        id = config["id"]
        if stats is None:
            stats = {}
        training_stats[id] = stats

def safe_thread(func, proc_id, thread_id=0):
    """Wrap a function for background execution with nicer error reporting.

    The returned wrapper calls `func(*args, **kwargs)` and converts any
    exception into a concise message identifying the logical process and
    thread of origin.

    Args:
        func: Callable to wrap.
        proc_id: Integer process identifier for error messages.
        thread_id: Integer thread identifier for error messages.

    Returns:
        Callable: A new callable with identical signature that raises a
        concise :class:`Exception` on failure.
    """
    def wrapper(*args, **kwargs):
        try:
            func(*args, **kwargs)
        except Exception as e:
            raise Exception(
                f"Error on thread {thread_id} of process {proc_id}: {e}"
            ) from e

    return wrapper

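# Usage sketch (hypothetical callable):
#
#     def flaky_overlay(emu, device=None):
#         raise RuntimeError("boom")
#
#     safe_overlay = safe_thread(flaky_overlay, proc_id=2)
#     safe_overlay(emu)  # raises: "Error on thread 0 of process 2: boom"
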
def send_window_end_signal(id):
    """Zero a per-process frame buffer to signal the host window to exit.

    Args:
        id: Process index whose frame buffer should be cleared.

    Side Effects:
        - Writes zeros into the shared frame ``emu_frame_{id}``, which is used
          by the host window's polling logic to detect end-of-batch.
    """
    shm = SharedMemory(name=f"emu_frame_{id}")
    arr = np.ndarray((SCREEN_HEIGHT, SCREEN_WIDTH, 4), dtype=np.uint8, buffer=shm.buf)
    arr[:] = 0

def get_forward_func(emu: DeSmuME, model: EvolvedNet, device):
    """Build a closure that performs one model step and checkpoint bookkeeping.

    The returned callable reads emulator memory for sensor inputs, constructs
    the model input vector, computes the control logits, and records
    per-checkpoint split times and distances. When a terminal condition is
    reached (e.g. clock > 10000), it returns the accumulated stats dict
    instead of logits.

    Args:
        emu: Active emulator instance to read game state from.
        model: Evolved network to evaluate (expects 6 inputs → action logits).
        device: Torch device on which tensors are constructed and the model runs.

    Returns:
        Callable[[], torch.Tensor | dict[int, CheckpointRecord]]: A
        no-argument function that returns either a 1D tensor of action logits
        or a stats dictionary signaling the end of this individual's run.

    Sensor model:
        - Distances: forward/left/right obstacle distances are read and mapped
          through ``tanh(1 - d / s1)`` with ``s1 = 60.0`` to compress range.
        - Angles: direction to the next checkpoint as (cos θ, sin θ, -sin θ).

    Notes:
        - Checkpoint bookkeeping appends the split's delta time and the
          distance recorded at the previous split to that checkpoint's record.
        - This function reads directly from emulator memory via utility helpers.
    """
    current_time = read_clock(emu)
    prev_time = current_time
    current_id = read_current_checkpoint(emu)
    prev_dist = read_checkpoint_distance_altitude(emu, device=device).item()
    times: dict[int, CheckpointRecord] = {}

    def forward() -> torch.Tensor | dict[int, CheckpointRecord]:
        nonlocal current_id, current_time, prev_time, prev_dist

        clock = read_clock(emu)
        if clock > 10000:
            # Terminal condition: time budget exhausted, return the stats.
            return times

        new_id = read_current_checkpoint(emu)
        if current_id != new_id:
            # Crossed into a new checkpoint: record the split for the one left.
            assert isinstance(current_id, int)
            current_time = clock
            if current_id not in times:
                times[current_id] = {"id": current_id, "times": [], "dists": []}
            times[current_id]["times"].append(current_time - prev_time)
            times[current_id]["dists"].append(prev_dist)
            prev_time = current_time
            prev_dist = read_checkpoint_distance_altitude(emu, device=device).item()
            current_id = new_id

        s1 = 60.0

        # Sensor inputs (obstacle distances)
        forward_d = read_forward_distance_obstacle(emu, device=device)
        left_d = read_left_distance_obstacle(emu, device=device)
        right_d = read_right_distance_obstacle(emu, device=device)
        inputs_dist1 = torch.tensor([forward_d, left_d, right_d], device=device)
        inputs_dist1 = torch.tanh(1 - inputs_dist1 / s1)

        # Angular relationship to next checkpoint
        angle = read_direction_to_checkpoint(emu, device=device)
        forward_a = torch.cos(angle)
        left_a = torch.sin(angle)
        right_a = -torch.sin(angle)
        inputs_dist2 = torch.tensor([forward_a, left_a, right_a], device=device)

        # Model inference
        inputs = torch.cat([inputs_dist1, inputs_dist2])
        logits = model(inputs)
        return logits

    return forward

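# Worked example of the distance compression (values approximate):
# with s1 = 60.0,
#
#     d = 0    -> tanh(1.0)  ≈  0.76   (obstacle touching)
#     d = 60   -> tanh(0.0)  =  0.00   (obstacle at the reference range)
#     d = 300  -> tanh(-4.0) ≈ -1.00   (far away, saturates)
#
# so nearby obstacles map to strongly positive inputs and open road to ~-1.
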
def run_training_batch(
    batch_population: list[Genome],
    show_samples: list[bool],
    training_stats: DictProxy[int, dict[int, CheckpointRecord]],
    training_stats_lock,
    batch_config: EmulatorBatchConfig,
):
    """Evaluate a batch of genomes concurrently, optionally with live display.

    One process in the batch is promoted to the *display host* (the first
    individual with a True entry in `show_samples`) and creates a tiled GTK
    window that reads from the per-process shared-memory frames listed in
    ``batch_config["display_shm_names"]``. Additional True entries run as
    display workers; False entries run headless.

    Args:
        batch_population: Slice of the population to evaluate in this batch.
        show_samples: Per-individual flags controlling display mode; the first
            True is chosen as the host, additional Trues are workers; all
            False means a fully headless batch.
        training_stats: Manager dict where each process writes a stats dict
            under its local batch index.
        training_stats_lock: IPC lock for synchronized writes to
            `training_stats`.
        batch_config: Batch-level configuration (size, shared-memory names,
            overlays, device).

    Side Effects:
        - Marks the shared-memory frames of display-enabled individuals as
          alive (non-zero).
        - Spawns processes with appropriate targets and joins them.
    """
    batch_size = batch_config["size"]
    processes = []
    host_proc_found = False
    for id, sample, show in zip(range(batch_size), batch_population, show_samples):
        config: EmulatorProcessConfig = {
            "id": id,
            "sample": sample,
            "host": False,
            "show": show,
        }
        if show and not host_proc_found:
            # The first display-enabled process hosts the GTK window.
            host_proc_found = True
            config["host"] = True
        if show:
            # Mark this worker's frame as alive so the host's end-of-batch
            # polling does not quit before the worker starts rendering.
            shm_frame = SharedMemory(name=f"emu_frame_{id}")
            frame = np.ndarray(
                shape=(SCREEN_HEIGHT, SCREEN_WIDTH, 4),
                dtype=np.uint8,
                buffer=shm_frame.buf,
            )
            frame[:] = 1
        process = Process(
            target=_run_process,
            args=(training_stats, training_stats_lock, config, batch_config),
            daemon=True,
        )
        processes.append(process)

    # Start processes
    for p in processes:
        p.start()

    # Join processes
    for p in processes:
        p.join()

def run_training_session(
    pop: list[Genome],
    num_proc: int | None = None,
    show_samples: list[bool] = [True],
    overlay_ids: list[int] = [],
    device: DeviceLikeType | None = None,
) -> dict[int, list[CheckpointRecord]]:
    """Evaluate the full population in parallel batches and collect statistics.

    The population is partitioned into batches of size
    ``min(num_proc, remaining)``. Each batch is launched via
    :func:`run_training_batch`, returning when all processes in the batch
    complete and their stats have been merged.

    Args:
        pop: Full population of genomes to evaluate.
        num_proc: Maximum number of concurrent processes. If ``None``, uses
            ``os.cpu_count() - 1``.
        show_samples: List of booleans determining which individuals in each
            batch should display; a one-element list (e.g., ``[False]``) is
            **broadcast** to the batch size on each iteration.
        overlay_ids: Overlay identifiers to enable in display-enabled processes.
        device: Torch device passed through to each process.

    Returns:
        dict[int, list[CheckpointRecord]]: Mapping of *global* population
        index to that individual's checkpoint records.

    Notes:
        - This function uses a :class:`multiprocessing.Manager` ``dict`` so
          that per-process stats can be retrieved without explicit pipes or
          queues.
        - Shared-memory frame buffers are currently **not** unlinked here;
          consider cleaning them in a higher-level teardown if needed.
    """
    # If no number of processes is specified, use the CPU count minus one.
    if num_proc is None:
        num_proc = os.cpu_count()
        assert num_proc is not None
        num_proc -= 1

    if len(show_samples) == 1:
        # Broadcast the single flag without mutating the caller's (or the
        # default argument's) list in place.
        show_samples = show_samples * num_proc

    pop_stats: dict[int, list[CheckpointRecord]] = {}
    pop_size = len(pop)
    count = 0

    # Create one frame segment per display-enabled slot and keep the handles
    # alive for the duration of the session.
    shm_names: list[str] = []
    shm_handles: list[SharedMemory] = []
    for i in range(num_proc):
        if not show_samples[i]:
            continue
        name = f"emu_frame_{i}"
        size = SCREEN_HEIGHT * SCREEN_WIDTH * 4
        shm_handles.append(safe_shared_memory(name=name, size=size))
        shm_names.append(name)

    while count < pop_size:
        batch_size = min(pop_size - count, num_proc)
        batch_config: EmulatorBatchConfig = {
            "overlay_ids": overlay_ids,
            "display_shm_names": shm_names,
            "size": batch_size,
            "device": device,
        }
        show_samples = show_samples[:batch_size]
        with Manager() as manager:
            # Create shared dict for stats (written under a lock)
            shared_pop_stats: DictProxy[int, dict[int, CheckpointRecord]] = manager.dict()
            lock = Lock()
            run_training_batch(
                pop[count : count + batch_size],
                show_samples=show_samples,
                training_stats=shared_pop_stats,
                training_stats_lock=lock,
                batch_config=batch_config,
            )
            for k, s in shared_pop_stats.items():
                pop_stats[count + k] = list(s.values())
        count += batch_size

    # TODO: Cleanup all shared memory buffers here

    return pop_stats

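# Broadcast semantics sketch (illustrative values): with num_proc = 4,
#
#     show_samples=[True]                       -> [True, True, True, True]
#     show_samples=[False]                      -> [False, False, False, False]
#     show_samples=[True, False, False, True]   -> used as-is (no broadcast)
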
def fitness(pop_stats: dict[int, list[CheckpointRecord]]) -> list[float]:
    """Compute scalar fitness from per-checkpoint stats.

    The current fitness heuristic sums the recorded distances across all
    checkpoint splits for each individual.

    Args:
        pop_stats: Mapping of population index to that individual's list of
            :class:`CheckpointRecord` entries.

    Returns:
        list[float]: Fitness values *ordered by population index*.
    """
    def total_dist(records: list[CheckpointRecord]) -> float:
        return sum(sum(r["dists"]) for r in records)

    pop_stats_list = [(k, total_dist(s)) for k, s in pop_stats.items()]
    pop_stats_list.sort(key=lambda x: x[0])  # order by population index
    return [x[1] for x in pop_stats_list]

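# Worked example (illustrative data): two individuals, summed distances.
#
#     stats = {
#         1: [{"id": 0, "times": [1.2], "dists": [50.0]}],
#         0: [{"id": 0, "times": [0.9, 1.1], "dists": [40.0, 45.0]}],
#     }
#     fitness(stats)  # -> [85.0, 50.0]  (ordered by population index)
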
def train(
    num_iters: int,
    pop_size: int,
    log_interval: int = 1,
    top_k: int | float = 0.1,
    device: DeviceLikeType | None = None,
    **simulation_kwargs,
):
    """Main evolutionary training loop (selection + mutation).

    Repeats:
        1) Evaluate the current population via :func:`run_training_session`.
        2) Rank by fitness.
        3) Keep the best, and refill the population by mutating uniformly
           sampled parents from the top-k set.

    Args:
        num_iters: Number of generations to run.
        pop_size: Number of individuals per generation.
        log_interval: Print progress every N generations.
        top_k: Either the number of top individuals to sample parents from,
            or a fraction in ``(0, 1]`` interpreted as a proportion of the
            population.
        device: Torch device used for genome construction and evaluation.
        **simulation_kwargs: Passed through to :func:`run_training_session`.

    Side Effects:
        - Prints best fitness per `log_interval`.
        - Mutates and replaces the population in place each generation.
    """
    pop = [Genome(6, 2, device=device) for _ in range(pop_size)]
    for g in pop:
        g.mutate_add_conn()  # start with random links

    if isinstance(top_k, float):
        assert 0 < top_k <= 1, "a float top_k must be a fraction in (0, 1]"
        top_k = int(round(len(pop) * top_k))

    for n in range(num_iters):
        stats = run_training_session(pop, device=device, **simulation_kwargs)
        scores = fitness(stats)
        scores = [(p, s) for p, s in zip(pop, scores)]
        scores.sort(reverse=True, key=lambda x: x[1])

        if n % log_interval == 0:
            os.system("clear")
            print(f"Best Fitness: {scores[0][1]}")

        # Elitism: carry the best individual over unchanged, then refill by
        # mutating parents sampled from the top-k set.
        newpop = [copy.deepcopy(scores[0][0])]
        for _ in range(len(pop) - 1):
            g = copy.deepcopy(random.choice(scores[:top_k])[0])
            random.choice([g.mutate_weight, g.mutate_add_conn, g.mutate_add_node])()
            newpop.append(g)
        pop = newpop

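# top_k semantics (illustrative): with pop_size = 32,
#
#     top_k=0.1  -> parents sampled from the best round(32 * 0.1) = 3
#     top_k=5    -> parents sampled from the best 5 individuals
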
if __name__ == "__main__":
    device = get_mps_device()
    train(
        num_iters=1000,
        pop_size=13,
        device=device,
        top_k=5,
        show_samples=[True],
        overlay_ids=[0, 3, 4],
        num_proc=13,
    )