src.core package¶
Submodules¶
src.core.memory module¶
Mario Kart DS (MKDS) Emulator I/O & Geometry Utilities¶
This module provides a high-level, vectorized interface for reading game state from a running DeSmuME emulator and performing common geometric operations used in visualization, control, and RL policy features for Mario Kart DS.
It wraps low-level memory reads (positions, directions, camera data, objects, checkpoints, clock) and exposes a compact API for tasks like projecting world-space points to screen-space, computing distances to checkpoints and obstacles, and deriving view matrices.
- The implementation favors:
Deterministic caching at frame and game lifetimes to minimize emulator I/O, and
Torch tensor computations (CPU / CUDA / MPS) for fast, batched math.
Quick Start¶
>>> from desmume.emulator import DeSmuME
>>> import torch
>>> from src.core.memory import (
... read_position, read_direction, project_to_screen, read_forward_distance_checkpoint
... )
>>> emu = DeSmuME()
>>> # ... open ROM, load state, etc.
>>> device = torch.device("cpu") # or "cuda", "mps"
>>> pos = read_position(emu, device=device) # (3,)
>>> dir = read_direction(emu, device=device) # (3,)
>>> screen = project_to_screen(emu, pos.unsqueeze(0), device=device) # (1, 4)
>>> fwd_to_cp = read_forward_distance_checkpoint(emu, device=device) # scalar tensor
Key Concepts & Conventions¶
Coordinate Systems¶
World Space: Right-handed, with Y as up. Many functions assume a canonical “up” vector of (0, 1, 0) and construct a right-handed orthonormal basis [right, up, forward].
Camera Space / Clip Space / NDC / Screen Space: _compute_model_view builds a model-view matrix from camera position/target; _project_to_screen applies perspective and viewport transforms to return pixel coordinates for a 256×192 screen (Nintendo DS top display).
Screen Origin: (0, 0) is top-left. X grows to the right; Y grows downward. This follows the standard raster convention and matches the (1 - ndc_y) transform used in projection.
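The screen-origin convention can be illustrated with a minimal sketch. It assumes NDC coordinates in [-1, 1] on both axes and uses a hypothetical helper name, ndc_to_pixel; the module's own _project_to_screen may differ in detail.

```python
SCREEN_WIDTH = 256   # Nintendo DS top display width in pixels
SCREEN_HEIGHT = 192  # Nintendo DS top display height in pixels


def ndc_to_pixel(ndc_x: float, ndc_y: float) -> tuple[float, float]:
    """Map NDC (assumed range [-1, 1]) to pixels with a top-left origin."""
    x_px = (ndc_x + 1.0) * 0.5 * SCREEN_WIDTH
    # NDC Y points up, screen Y points down, hence the (1 - ndc_y) flip.
    y_px = (1.0 - ndc_y) * 0.5 * SCREEN_HEIGHT
    return x_px, y_px
```

Under these assumptions the NDC corner (-1, 1) lands on the top-left pixel and the NDC origin lands on the screen center.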
Units¶
Positions & Scalars returned from memory are derived from MKDS fixed-point formats (FX32, etc.) via helpers like read_vector_3d_fx32 and read_fx32, and are exposed as Python floats / Torch tensors.
Angles: Camera FOV is read from a 16-bit fixed-point angle and converted to radians using value * (2π / 0x10000).
Time: read_clock() returns centiseconds (10 ms units).
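The unit conversions above can be sketched as follows. The 12-bit fractional shift is the usual Nintendo DS FX32 layout and is an assumption here, not something taken from this module; fx32_to_float and angle16_to_radians are hypothetical names, not the module's read_fx32 helpers.

```python
import math

FX32_SHIFT = 12  # assumption: 20.12 fixed-point, so raw 4096 == 1.0


def fx32_to_float(raw: int) -> float:
    """Convert a signed fixed-point integer to a Python float."""
    return raw / (1 << FX32_SHIFT)


def angle16_to_radians(value: int) -> float:
    """Convert a 16-bit angle, where 0x10000 is a full turn, to radians."""
    return value * (2 * math.pi / 0x10000)
```

For instance, a raw angle of 0x4000 is a quarter turn, which this conversion maps to π/2.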
Memory Map & Assets¶
Addresses: The module uses static addresses for key pointers (racer, course, objects, checkpoints, camera, clock). See constants: RACER_PTR_ADDR, COURSE_ID_ADDR, OBJECTS_PTR_ADDR, CHECKPOINT_PTR_ADDR, CLOCK_DATA_PTR, CAMERA_PTR_ADDR.
Course Files: courses.json maps course IDs to directory names. KCL (course_collision.kcl) and NKM (course_map.nkm) are loaded via KCLTensor.from_file(…) and NKMTensor.from_file(…).
Caching Model¶
Two decorators reduce emulator I/O:
@frame_cache — Caches the function’s single return value per emulator tick (emu.get_ticks()). Recomputes only when the tick changes.
@game_cache — Caches the function’s single return value for the process lifetime (until interpreter exit).
⚠ Important: Both caches ignore argument values. If you call a cached function with different arguments within the same lifetime (same frame or same run), the first computed result is reused. In practice, pass stable arguments (e.g., a constant device) to avoid surprises.
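The caching behavior, including the argument-ignoring caveat, can be sketched as below. This is an illustrative reimplementation under stated assumptions, not the module's source: FakeEmu is a hypothetical stand-in that only exposes get_ticks(), and read_value and load_table are made-up cached functions.

```python
import functools


def frame_cache(func):
    """Cache one result per emulator tick; arguments are deliberately ignored."""
    state = {"tick": None, "value": None}

    @functools.wraps(func)
    def wrapper(emu, *args, **kwargs):
        tick = emu.get_ticks()
        if state["tick"] != tick:  # recompute only when the tick changes
            state["value"] = func(emu, *args, **kwargs)
            state["tick"] = tick
        return state["value"]

    return wrapper


def game_cache(func):
    """Cache the first result for the rest of the process lifetime."""
    missing = object()
    state = {"value": missing}

    @functools.wraps(func)
    def wrapper(emu, *args, **kwargs):
        if state["value"] is missing:
            state["value"] = func(emu, *args, **kwargs)
        return state["value"]

    return wrapper


class FakeEmu:
    """Hypothetical stand-in for DeSmuME; only get_ticks() is modeled."""

    def __init__(self):
        self.ticks = 0

    def get_ticks(self):
        return self.ticks


calls = []


@frame_cache
def read_value(emu, scale=1):
    calls.append(scale)  # records every real (uncached) execution
    return emu.get_ticks() * scale + scale


@game_cache
def load_table(emu, device="cpu"):
    return {"device": device}
```

Calling read_value twice in the same tick with different scale values returns the first result both times, which is exactly the surprise the warning above describes.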
Device Handling¶
Many functions accept a Torch device and return tensors allocated there. For best performance, use cuda (GPU) or mps (Apple Silicon) when available, and keep devices consistent across the call sites, especially for @game_cache results (KCL/NKM tensors are created on the device used at first call).
Public API Overview¶
Clock & Course¶
read_clock_ptr(emu) — Base pointer to clock data (cached for game lifetime).
read_clock(emu) — Current clock in 10 ms units (cached per frame).
get_current_course_id(emu) — Current course ID (byte).
get_course_path(id) — Course directory name from courses.json.
load_current_kcl(emu, device) — Parsed KCL collision mesh (game-cached).
load_current_nkm(emu, device) — Parsed NKM map data (game-cached).
Player & Objects¶
read_racer_ptr(emu) — Pointer to the player racer struct.
read_position(emu, device) — Player world position (3,).
read_direction(emu, device) — Player forward direction (3,).
read_objects(…), read_object_* helpers — Scan and query the object table.
safe_object decorator returns None for deleted objects.
Camera & Projection¶
read_camera_ptr(emu) — Pointer to camera struct.
read_camera_fov(emu) — FOV in radians.
read_camera_aspect(emu) — Aspect ratio (W/H).
read_camera_position(emu, device) — Camera world pos (3,) with elevation.
read_camera_target_position(emu, device) — Camera look-at (3,).
read_model_view(emu, device) — 4×4 model-view matrix.
project_to_screen(emu, points, device) — Projects (N,3) to (N,4): [x_px, y_px, clip_z, normalized_depth].
z_clip_mask(x) — Mask for points within Z-near/far bounds (camera space).
Checkpoints¶
read_checkpoint_ptr(emu) — Pointer to checkpoint manager.
read_current_checkpoint(emu), read_current_key_checkpoint(emu), read_current_lap(emu) — Indices for current progress.
read_ghost_checkpoint(emu), read_ghost_key_checkpoint(emu) — Ghost state.
read_checkpoint_positions(emu, device) — (C, 2, 3) segment endpoints.
read_next_checkpoint(emu, checkpoint_count) — Next index (wraps).
read_next_checkpoint_position(emu, device), read_current_checkpoint_position(emu, device) — (2,3) endpoints.
read_facing_point_checkpoint(emu, direction, device) — Intersection of a ray (from player, given direction) with next checkpoint line in XZ.
read_forward_distance_checkpoint(emu, device), read_left_distance_checkpoint(emu, device), read_direction_to_checkpoint(emu, device) — Distances/steering angle.
Obstacles (Walls / Offroad)¶
read_facing_point_obstacle(emu, position, direction, device) — Samples a cone of rays around the forward direction to find the nearest hit against wall/offroad triangles. Returns a point or None.
read_forward_distance_obstacle(emu, device), read_left_distance_obstacle(emu, device), read_right_distance_obstacle(emu, device) — Scalar distances to nearest obstacles along canonical forward/left/right rays. Return +inf when no hit.
Return Types & Shapes¶
Positions / Directions: torch.Tensor with shape (3,).
Batches of points: (N, 3).
Screen Projection: (N, 4) → [x_px, y_px, clip_z, normalized_depth].
Checkpoints: (C, 2, 3) → per checkpoint two endpoints [p1, p2].
Distances / Angles: 0-D or 1-D scalar torch.Tensor (depending on operation).
Errors & Edge Cases¶
Deleted / Ignored Objects: safe_object-wrapped functions return None when the object is deleted; callers must handle None.
No Geometry: When there are no wall/offroad triangles or raycasts miss, obstacle distance functions return +inf (as a tensor).
Empty Projections: _project_to_screen returns an empty tensor when given no points; invalid (behind-camera) points may still project with negative clip_w.
Performance Notes¶
Caching eliminates redundant memory reads across a frame / game run.
Geometry routines (ray casting, distances, projection) are vectorized in Torch; prefer GPU/MPS devices when available.
Keep devices consistent across calls that share cached state (e.g., KCL/NKM).
Examples¶
Project player and next checkpoint endpoints to screen:
>>> pts = torch.vstack([read_position(emu, device), # (3,), stacked as a (1, 3) row
... read_next_checkpoint_position(emu, device)]).reshape(-1, 3)
>>> screen_pts = project_to_screen(emu, pts, device)
>>> screen_pts[:, :2] # pixel coordinates
Compute lateral vs forward distance to next checkpoint:
>>> d_left = read_left_distance_checkpoint(emu, device)
>>> d_front = read_forward_distance_checkpoint(emu, device)
Find nearest obstacle straight ahead:
>>> d_obs = read_forward_distance_obstacle(emu, device)
>>> float(d_obs) if torch.isfinite(d_obs) else float("inf")
Implementation Notes¶
_compute_orthonormal_basis builds a right-handed frame from a forward vector and an up-like reference (default (0,1,0)), normalizing each axis.
_compute_model_view constructs a 4×4 model-view matrix in row-major with basis rows [right, up, forward] and a translated origin.
_project_to_screen creates a simple perspective matrix using vertical FOV and aspect ratio; returns pixel coordinates using constants SCREEN_WIDTH = 256 and SCREEN_HEIGHT = 192.
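The basis construction described for _compute_orthonormal_basis can be sketched with plain tuples. This is an assumption-laden illustration: the cross-product order below is one way to obtain a right-handed [right, up, forward] frame, and the module's Torch implementation may order or sign the axes differently. The function name orthonormal_basis is hypothetical.

```python
import math


def cross(a, b):
    """3D cross product of two 3-tuples."""
    return (a[1] * b[2] - a[2] * b[1],
            a[2] * b[0] - a[0] * b[2],
            a[0] * b[1] - a[1] * b[0])


def normalize(v):
    """Scale a 3-tuple to unit length; reject zero-length input."""
    length = math.sqrt(sum(c * c for c in v))
    if length == 0.0:
        raise ValueError("cannot normalize a zero-length vector")
    return tuple(c / length for c in v)


def orthonormal_basis(forward, up_ref=(0.0, 1.0, 0.0)):
    """Return (right, up, forward) with right x up == forward."""
    f = normalize(forward)
    r = normalize(cross(up_ref, f))  # fails if forward is parallel to up_ref
    u = cross(f, r)                  # already unit length: f and r are orthonormal
    return r, u, f
```

Note the degenerate case: a forward vector parallel to the up reference has no well-defined right axis, so this sketch raises instead of guessing.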
Compatibility¶
Tested with DeSmuME Python bindings and Torch. Some ops may vary by backend (e.g., MPS lacks a few linear algebra kernels); this module sticks to widely supported APIs.
- src.core.memory.frame_cache(func: Callable[[Concatenate[desmume.emulator.DeSmuME, P]], R]) Callable[[Concatenate[desmume.emulator.DeSmuME, P]], R] [source]¶
Decorator that caches a function’s return value once per emulator tick.
The wrapped function will only be re-executed when emu.get_ticks() changes. Useful for expensive reads that don’t change within a single frame.
- Parameters:
func – A function whose first argument is a DeSmuME instance.
- Returns:
A wrapped function with identical signature that returns a cached result per tick.
- src.core.memory.game_cache(func: Callable[[Concatenate[desmume.emulator.DeSmuME, P]], R]) Callable[[Concatenate[desmume.emulator.DeSmuME, P]], R] [source]¶
Decorator that caches a function’s return value for the process lifetime.
The wrapped function executes once and its result is reused thereafter. Appropriate for data that remains constant across a run (e.g., course files).
- Parameters:
func – A function whose first argument is a DeSmuME instance.
- Returns:
A wrapped function with identical signature that returns a cached result.
- src.core.memory.z_clip_mask(x: torch.Tensor) torch.Tensor [source]¶
Compute a boolean mask for points within the view frustum Z range.
- Parameters:
x – Tensor of shape (N, 3+) where x[:, 2] is the camera-space Z.
- Returns:
A boolean tensor of shape (N,) where True indicates Z is between -Z_FAR and -Z_NEAR.
- src.core.memory.read_clock_ptr(emu: desmume.emulator.DeSmuME)[source]¶
Read the base pointer to the game’s clock data structure.
- Parameters:
emu – Emulator instance.
- Returns:
Integer address of the clock data struct.
- src.core.memory.read_clock(emu: desmume.emulator.DeSmuME)[source]¶
Read the current game clock value.
The value is read from the clock data structure and multiplied by 10, resulting in units of 10 ms (centiseconds).
- Parameters:
emu – Emulator instance.
- Returns:
Integer time in 10 ms units.
- src.core.memory.get_current_course_id(emu: desmume.emulator.DeSmuME)[source]¶
Read the current course ID from memory.
- Parameters:
emu – Emulator instance.
- Returns:
Integer course ID (byte).
- src.core.memory.get_course_path(id: int, lookup_path: str = './src/misc/courses.json')[source]¶
Resolve a course ID to the local filesystem path for its assets.
- Parameters:
id – Course ID.
- Returns:
String path relative to ./courses/ for the given course.
- Raises:
AssertionError – If the course ID is not present in the lookup table.
- src.core.memory.load_current_kcl(emu: desmume.emulator.DeSmuME, device)[source]¶
Load and parse the KCL collision file for the current course.
Cached for the lifetime of the process.
- Parameters:
emu – Emulator instance.
device – Torch device (e.g., ‘cpu’, ‘cuda’, ‘mps’) to store tensors on.
- Returns:
KCLTensor with triangle and prism data on the specified device.
- src.core.memory.load_current_nkm(emu: desmume.emulator.DeSmuME, device)[source]¶
Load and parse the NKM map file for the current course.
Cached for the lifetime of the process.
- Parameters:
emu – Emulator instance.
device – Torch device to store tensors on.
- Returns:
NKMTensor with NKM section tensors (e.g., checkpoints) on the specified device.
- src.core.memory.read_racer_ptr(emu: desmume.emulator.DeSmuME, addr: int = 35106040)[source]¶
Read the pointer to the player’s racer object.
- Parameters:
emu – Emulator instance.
addr – Memory address where the racer pointer is stored.
- Returns:
Integer address of the racer structure.
- src.core.memory.read_position(emu: desmume.emulator.DeSmuME, device)[source]¶
Read the player’s world-space position.
- Parameters:
emu – Emulator instance.
device – Torch device for the returned tensor.
- Returns:
torch.Tensor of shape (3,) representing (x, y, z) in world units.
- src.core.memory.read_direction(emu: desmume.emulator.DeSmuME, device)[source]¶
Read the player’s forward direction vector (world-space).
- Parameters:
emu – Emulator instance.
device – Torch device for the returned tensor.
- Returns:
torch.Tensor of shape (3,) representing the forward direction.
- src.core.memory.read_objects_array_max_count(emu: desmume.emulator.DeSmuME, addr: int = 35108232)[source]¶
Read the maximum number of objects in the global object array.
- Parameters:
emu – Emulator instance.
addr – Base address of the object array metadata.
- Returns:
Signed integer max count.
- src.core.memory.read_objects_array_ptr(emu: desmume.emulator.DeSmuME, addr: int = 35108232)[source]¶
Read the pointer to the global object pointer array.
- Parameters:
emu – Emulator instance.
addr – Base address of the object array metadata.
- Returns:
Signed integer address of the object pointer array.
- src.core.memory.read_object_offset(emu: desmume.emulator.DeSmuME, id: int)[source]¶
Compute the memory offset of an object entry within the array.
- Parameters:
emu – Emulator instance.
id – Object index.
- Returns:
Integer byte offset to the object’s metadata entry.
- src.core.memory.read_object_ptr(emu: desmume.emulator.DeSmuME, id: int)[source]¶
Read the object instance pointer for a given object ID.
- Parameters:
emu – Emulator instance.
id – Object index.
- Returns:
Integer address of the object struct (0 if null).
- src.core.memory.read_object_flags(emu: desmume.emulator.DeSmuME, id: int)[source]¶
Read the object’s flags (type/category bits, state, etc.).
- Parameters:
emu – Emulator instance.
id – Object index.
- Returns:
Unsigned short flags value.
- src.core.memory.read_object_position_ptr(emu: desmume.emulator.DeSmuME, id: int)[source]¶
Read the pointer to an object’s position vector in memory.
- Parameters:
emu – Emulator instance.
id – Object index.
- Returns:
Integer address for the object’s position struct (0 if deleted).
- src.core.memory.read_object_is_ignored(emu: desmume.emulator.DeSmuME, id: int)[source]¶
Determine if an object should be ignored (null or ignored-flag set).
- Parameters:
emu – Emulator instance.
id – Object index.
- Returns:
True if object ptr is 0 or ignored bit is set; False otherwise.
- src.core.memory.read_object_is_deleted(emu: desmume.emulator.DeSmuME, id: int)[source]¶
Check if the object has been deleted (position pointer is null).
- Parameters:
emu – Emulator instance.
id – Object index.
- Returns:
True if deleted; False otherwise.
- src.core.memory.safe_object(func)[source]¶
Decorator that skips object reads when the object appears deleted.
The wrapped function receives (emu, id, *args, **kwargs). If the object is deleted (null position pointer), the wrapper returns None.
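A minimal sketch of this guard pattern follows. The read_object_is_deleted stub and the DELETED_IDS set are hypothetical stand-ins for the real emulator memory check, and the decorated function returns a placeholder value.

```python
import functools

DELETED_IDS = {2}  # hypothetical: object indices treated as deleted


def read_object_is_deleted(emu, id):
    """Stub; the real helper checks for a null position pointer."""
    return id in DELETED_IDS


def safe_object(func):
    """Return None instead of calling func when the object is deleted."""

    @functools.wraps(func)
    def wrapper(emu, id, *args, **kwargs):
        if read_object_is_deleted(emu, id):
            return None
        return func(emu, id, *args, **kwargs)

    return wrapper


@safe_object
def read_object_position(emu, id):
    return (float(id), 0.0, 0.0)  # placeholder instead of a memory read
```

Callers then need an explicit None check before using the result, for example `pos = read_object_position(emu, i)` followed by `if pos is not None: ...`.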
- src.core.memory.read_object_position(emu: desmume.emulator.DeSmuME, id: int, *args, **kwargs)[source]¶
Internal wrapper used by safe_object to guard deleted objects.
- src.core.memory.read_map_object_type_id(emu: desmume.emulator.DeSmuME, id: int, *args, **kwargs)[source]¶
Internal wrapper used by safe_object to guard deleted objects.
- src.core.memory.read_map_object_is_coin_collected(emu: desmume.emulator.DeSmuME, id: int, *args, **kwargs)[source]¶
Internal wrapper used by safe_object to guard deleted objects.
- src.core.memory.read_racer_object_is_ghost(emu: desmume.emulator.DeSmuME, id: int, *args, **kwargs)[source]¶
Internal wrapper used by safe_object to guard deleted objects.
- src.core.memory.read_objects(emu: desmume.emulator.DeSmuME)[source]¶
Scan the global object table and group object indices by category.
- Categories:
‘map_objects’
‘racer_objects’
‘item_objects’
‘dynamic_objects’
- Returns:
Dict[str, list[int]] mapping category name to list of indices.
- src.core.memory.read_camera_ptr(emu: desmume.emulator.DeSmuME, addr: int = 35105356)[source]¶
Read the pointer to the active camera structure.
- Parameters:
emu – Emulator instance.
addr – Address where the camera pointer is stored.
- Returns:
Integer address of the camera struct.
- src.core.memory.read_camera_fov(emu: desmume.emulator.DeSmuME)[source]¶
Read the current camera field-of-view (radians).
The FOV value is stored as a 16-bit fixed-point angle; it is converted to radians.
- Parameters:
emu – Emulator instance.
- Returns:
Floating-point FOV in radians.
- src.core.memory.read_camera_aspect(emu: desmume.emulator.DeSmuME)[source]¶
Read the camera aspect ratio from memory.
- Parameters:
emu – Emulator instance.
- Returns:
Float aspect ratio (width/height).
- src.core.memory.read_camera_position(emu: desmume.emulator.DeSmuME, device)[source]¶
Read the camera world position, including elevation offset.
- Parameters:
emu – Emulator instance.
device – Torch device for the returned tensor.
- Returns:
torch.Tensor shape (3,) representing camera (x, y, z).
- src.core.memory.read_camera_target_position(emu: desmume.emulator.DeSmuME, device)[source]¶
Read the camera’s target/look-at position in world space.
- Parameters:
emu – Emulator instance.
device – Torch device for the returned tensor.
- Returns:
torch.Tensor shape (3,) target (x, y, z).
- src.core.memory.read_model_view(emu: desmume.emulator.DeSmuME, device)[source]¶
Compute and cache the camera model-view matrix for the current frame.
- Parameters:
emu – Emulator instance.
device – Torch device for returned matrix.
- Returns:
torch.Tensor shape (4,4) model-view matrix.
- src.core.memory.project_to_screen(emu: desmume.emulator.DeSmuME, points: torch.Tensor, device, screen_dim=(256, 192))[source]¶
Convenience wrapper to project points using the current camera state.
- Parameters:
emu – Emulator instance.
points – Tensor shape (N,3) of world-space points.
device – Torch device.
- Returns:
Tensor shape (N,4) in screen space (see _project_to_screen).
- src.core.memory.read_checkpoint_ptr(emu: desmume.emulator.DeSmuME, addr: int = 35083772)[source]¶
Read the pointer to the checkpoint manager/state.
- Parameters:
emu – Emulator instance.
addr – Address where the checkpoint pointer is stored.
- Returns:
Integer address for checkpoint data.
- src.core.memory.read_current_checkpoint(emu: desmume.emulator.DeSmuME)[source]¶
Read the index of the current checkpoint.
- Parameters:
emu – Emulator instance.
- Returns:
Unsigned byte checkpoint index.
- src.core.memory.read_current_key_checkpoint(emu: desmume.emulator.DeSmuME)[source]¶
Read the current key checkpoint index (special/lap-related).
- Parameters:
emu – Emulator instance.
- Returns:
Signed byte key checkpoint index.
- src.core.memory.read_ghost_checkpoint(emu: desmume.emulator.DeSmuME)[source]¶
Read the recorded ghost’s current checkpoint index.
- Parameters:
emu – Emulator instance.
- Returns:
Signed byte ghost checkpoint index.
- src.core.memory.read_ghost_key_checkpoint(emu: desmume.emulator.DeSmuME)[source]¶
Read the recorded ghost’s current key checkpoint index.
- Parameters:
emu – Emulator instance.
- Returns:
Signed byte ghost key checkpoint index.
- src.core.memory.read_current_lap(emu: desmume.emulator.DeSmuME)[source]¶
Read the current lap number.
- Parameters:
emu – Emulator instance.
- Returns:
Signed byte lap index (0-based).
- src.core.memory.read_next_checkpoint(emu: desmume.emulator.DeSmuME, checkpoint_count: int)[source]¶
Compute the next checkpoint index (wrapping to 0 at the end).
- Parameters:
emu – Emulator instance.
checkpoint_count – Total number of checkpoints.
- Returns:
Integer index of the next checkpoint.
- src.core.memory.read_checkpoint_positions(emu: desmume.emulator.DeSmuME, device)[source]¶
Build a tensor of checkpoint segment endpoints in 3D.
Reads NKM and KCL, extracts floor geometry, and converts checkpoint pairs from 2D to 3D using nearest floor elevation.
- Parameters:
emu – Emulator instance.
device – Torch device.
- Returns:
Tensor shape (C, 2, 3) where C is number of checkpoints, containing [p1, p2] endpoints per checkpoint.
- src.core.memory.read_next_checkpoint_position(emu: desmume.emulator.DeSmuME, device)[source]¶
Get the 3D endpoints of the next checkpoint segment.
- Parameters:
emu – Emulator instance.
device – Torch device.
- Returns:
Tensor shape (2,3) representing the next checkpoint’s [p1, p2].
- src.core.memory.read_current_checkpoint_position(emu: desmume.emulator.DeSmuME, device)[source]¶
Get the 3D endpoints of the current checkpoint segment.
- Parameters:
emu – Emulator instance.
device – Torch device.
- Returns:
Tensor shape (2,3) representing current checkpoint’s [p1, p2].
- src.core.memory.read_facing_point_checkpoint(emu: desmume.emulator.DeSmuME, direction: torch.Tensor, device)[source]¶
Raycast from the player along a direction to the next checkpoint line (XZ).
- Parameters:
emu – Emulator instance.
direction – Tensor shape (3,) direction vector.
device – Torch device.
- Returns:
Tensor shape (3,) point of intersection in world coordinates.
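The XZ-plane raycast can be sketched as a 2D ray-versus-line intersection. This is an illustration only: it treats the checkpoint as an infinite line through its two endpoints, keeps the origin's Y for the returned point, and returns None for parallel or behind-the-ray cases. All of these are assumptions, and ray_line_intersection_xz is a hypothetical name.

```python
def ray_line_intersection_xz(origin, direction, p1, p2, eps=1e-9):
    """Intersect a ray with the line through p1 and p2, using only X and Z."""
    ox, oy, oz = origin
    dx, dz = direction[0], direction[2]
    ex, ez = p2[0] - p1[0], p2[2] - p1[2]   # line direction in XZ
    denom = dx * ez - dz * ex               # 2D cross product d x e
    if abs(denom) < eps:
        return None                         # ray is parallel to the line
    wx, wz = p1[0] - ox, p1[2] - oz
    t = (wx * ez - wz * ex) / denom         # distance parameter along the ray
    if t < 0.0:
        return None                         # intersection lies behind the ray
    return (ox + t * dx, oy, oz + t * dz)
```

With a unit-length direction, the parameter t is also the distance from the origin to the hit point.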
- src.core.memory.read_forward_distance_checkpoint(emu, device)[source]¶
Compute forward distance from player to the next checkpoint line.
- Parameters:
emu – Emulator instance.
device – Torch device.
- Returns:
Scalar torch.Tensor distance.
- src.core.memory.read_left_distance_checkpoint(emu, device)[source]¶
Compute leftward distance from player to the next checkpoint line.
- Parameters:
emu – Emulator instance.
device – Torch device.
- Returns:
Scalar torch.Tensor distance.
- src.core.memory.read_direction_to_checkpoint(emu: desmume.emulator.DeSmuME, device)[source]¶
Compute a steering angle toward the next checkpoint from forward/left distances.
Angle is computed as atan(forward / left).
- Parameters:
emu – Emulator instance.
device – Torch device.
- Returns:
Scalar torch.Tensor angle in radians.
- src.core.memory.read_facing_point_obstacle(emu: DeSmuME, position: torch.Tensor | None = None, direction: torch.Tensor | None = None, device=None)[source]¶
Raycast toward walls/offroad and return the nearest hit point.
Samples a cone of directions around the provided (or player) direction, and finds the nearest intersection against wall and offroad triangles.
- Parameters:
emu – Emulator instance.
position – Optional world position (3,). Defaults to player’s position.
direction – Optional direction (3,). Defaults to player’s forward vector.
device – Torch device.
- Returns:
torch.Tensor shape (3,) hit point, or None if no intersections.
- src.core.memory.read_forward_distance_obstacle(emu: desmume.emulator.DeSmuME, device) torch.Tensor [source]¶
Compute forward distance to the nearest wall/offroad obstacle.
- Parameters:
emu – Emulator instance.
device – Torch device.
- Returns:
Scalar torch.Tensor distance; +inf if no hit.
- src.core.memory.read_left_distance_obstacle(emu: desmume.emulator.DeSmuME, device) torch.Tensor [source]¶
Compute leftward distance to the nearest wall/offroad obstacle.
- Parameters:
emu – Emulator instance.
device – Torch device.
- Returns:
Scalar torch.Tensor distance; +inf if no hit.
- src.core.memory.read_right_distance_obstacle(emu: desmume.emulator.DeSmuME, device) torch.Tensor [source]¶
Compute rightward distance to the nearest wall/offroad obstacle.
- Parameters:
emu – Emulator instance.
device – Torch device.
- Returns:
Scalar torch.Tensor distance; +inf if no hit.
- src.core.memory.read_checkpoint_distance_altitude(emu: desmume.emulator.DeSmuME, device) torch.Tensor [source]¶
Compute the altitude (height) of the triangle formed by player and checkpoint endpoints.
Uses the two checkpoint endpoints and the player’s position to form sides a and b, then returns the triangle altitude via triangle_altitude(a, b).
- Parameters:
emu – Emulator instance.
device – Torch device.
- Returns:
Scalar torch.Tensor altitude value.
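One plausible formula for such an altitude is sketched below. The module's triangle_altitude(a, b) is not shown in this reference, so this is an assumption: a and b are taken to be the vectors from the player to the two checkpoint endpoints, and the altitude is measured from the player onto the checkpoint segment's line.

```python
import math


def triangle_altitude(a, b):
    """Altitude from the shared vertex onto the side opposite it.

    a and b are 3-tuples from the shared vertex to the other two vertices.
    Twice the triangle area is |a x b|; dividing by the base |a - b| gives
    the height.
    """
    cx = a[1] * b[2] - a[2] * b[1]
    cy = a[2] * b[0] - a[0] * b[2]
    cz = a[0] * b[1] - a[1] * b[0]
    double_area = math.sqrt(cx * cx + cy * cy + cz * cz)
    base = math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))
    if base == 0.0:
        raise ValueError("endpoints coincide; altitude is undefined")
    return double_area / base
```

For a player at the origin and endpoints (-1, 0, 2) and (1, 0, 2), the altitude is the perpendicular distance to the line z = 2.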
src.core.model module¶
src.core.train module¶
Parallel trainer for Mario Kart DS agents using DeSmuME, multiprocessing, shared memory frame streaming, and optional live GTK visualization.
This module orchestrates end-to-end evaluation and evolution of a population of neural network controllers (NEAT-style) for Mario Kart DS. It supports three execution modes per evaluated individual:
Headless (run_process) — fast evaluation with no display.
Display worker (run_window_process) — renders frames and writes them into a per-process shared-memory buffer; no GTK loop.
Display host (run_window_host_process) — renders frames, writes them into shared memory, and owns the GTK window that tiles and presents all display-enabled workers in real time.
Key concepts¶
Shared memory frames: Each display-enabled process writes an RGBX framebuffer of shape (SCREEN_HEIGHT, SCREEN_WIDTH, 4) (dtype np.uint8) to a POSIX shared-memory segment named f"emu_frame_{id}". The host window process opens these buffers read-only for display tiling.
Overlays: Optional per-frame overlays are computed off the main emulation loop by a single background thread fed via a queue. Overlays are composited in the worker before writing to shared memory using src.visualization.window.on_draw_memoryview().
Statistics / fitness: Each process records split times and distances at track checkpoints as a dict[int, list[tuple[float, float]]] mapping checkpoint_id -> [(delta_time, distance_at_split), ...]. A simple fitness function sums the recorded distances.
Batching & evolution: run_training_session() evaluates a subset (batch) of the population in parallel (bounded by num_proc), aggregates stats, then train() evolves the population.
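The statistics layout and the distance-summing fitness can be sketched as follows. The function name fitness and the sample numbers are illustrative assumptions; only the dict shape and the "sum of recorded distances" rule come from the description above.

```python
Stats = dict[int, list[tuple[float, float]]]


def fitness(stats: Stats) -> float:
    """Sum distance_at_split over every recorded split of every checkpoint."""
    return sum(dist for splits in stats.values() for _delta, dist in splits)


# checkpoint_id -> [(delta_time, distance_at_split), ...]
example_stats: Stats = {
    0: [(1.5, 10.0), (1.4, 12.0)],
    1: [(2.0, 20.0)],
}
```

An individual that never reaches a checkpoint produces an empty dict and therefore a fitness of zero under this rule.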
Threading & processes¶
The DeSmuME emulator is created and used inside each process that runs it.
The GTK main loop must run in a single process. This module designates one display-enabled process per batch as the host that creates the window and drives GTK via GLib.timeout_add.
Overlays are computed by a single background thread (daemon) within each display-enabled process to keep the emulation loop responsive.
- class src.core.train.EmulatorProcessConfig[source]¶
Bases:
TypedDict
- id: int¶
- host: bool¶
- show: bool¶
- class src.core.train.EmulatorBatchConfig[source]¶
Bases:
TypedDict
- size: int¶
- display_shm_names: list[str]¶
- device: DeviceLikeType | None¶
- overlay_ids: list[int]¶
- class src.core.train.CheckpointRecord[source]¶
Bases:
object
- id: int¶
- times: list[float]¶
- dists: list[float]¶
Create or replace a named POSIX shared-memory segment.
This helper guarantees that a shared-memory block with the given name exists with the requested size. If a stale block exists (e.g., from an earlier crashed run), it is closed and unlinked before creating a fresh one.
- Parameters:
name – Symbolic name of the shared memory region (e.g., "emu_frame_0").
size – Size in bytes to allocate for the region.
- Returns:
An opened handle to the new shared-memory block. The caller owns the handle and is responsible for closing it (and unlinking at teardown time).
- Return type:
multiprocessing.shared_memory.SharedMemory
- Raises:
ValueError – If size <= 0.
OSError – If the OS cannot allocate or map the segment.
- Side Effects:
May unlink an existing segment of the same name.
Creates a new segment in the system shared-memory namespace.
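The create-or-replace behavior can be sketched with the standard library. The function name create_or_replace_shm is hypothetical, since the signature of the documented helper is not shown here, and the actual implementation may handle stale segments differently.

```python
from multiprocessing import shared_memory


def create_or_replace_shm(name: str, size: int) -> shared_memory.SharedMemory:
    """Return a fresh shared-memory block, unlinking any stale one first."""
    if size <= 0:
        raise ValueError("size must be a positive number of bytes")
    try:
        stale = shared_memory.SharedMemory(name=name)  # attach if it exists
    except FileNotFoundError:
        pass  # nothing to clean up
    else:
        stale.close()
        stale.unlink()  # remove the old segment from the namespace
    return shared_memory.SharedMemory(name=name, create=True, size=size)
```

The caller still owns the returned handle: close() it when done, and unlink() it once at teardown so the segment does not outlive the run.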
- src.core.train.initialize_emulator() desmume.emulator.DeSmuME [source]¶
Initialize and prime a DeSmuME emulator instance.
Loads the MKDS ROM, restores a savestate (slot 3), mutes audio, and cycles once to ensure memory is initialized. Then spins until the emulator reports it is running.
- Returns:
A ready-to-use emulator instance positioned at the savestate.
- Return type:
DeSmuME
Notes
This function blocks until emu.is_running() returns True.
The ROM path "mariokart_ds.nds" and savestate index are hard-coded.
- src.core.train.initialize_window(emu, config: EmulatorProcessConfig, batch_config: EmulatorBatchConfig) SharedEmulatorWindow | None [source]¶
Create and initialize the tiled GTK window for live visualization.
Computes a near-square grid (n_rows × n_cols) based on the number of display-enabled processes, instantiates a renderer bound to emu, and returns a SharedEmulatorWindow configured to read from the provided shared-memory frame names.
- Parameters:
emu – Active DeSmuME instance (used to build the renderer).
display_count – Number of display-enabled workers to tile.
shm_names – List of shared-memory segment names ("emu_frame_{id}").
- Returns:
GTK window object ready to be shown.
- Return type:
SharedEmulatorWindow
- Side Effects:
Initializes a GTK/Cairo renderer via AbstractRenderer.
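A near-square grid for tiling can be computed as in the sketch below. The exact rounding used by initialize_window is not documented here, so this is one reasonable choice (columns first, then just enough rows), and grid_shape is a hypothetical name.

```python
import math


def grid_shape(display_count: int) -> tuple[int, int]:
    """Return (n_rows, n_cols) for a near-square grid holding display_count tiles."""
    if display_count <= 0:
        return (0, 0)
    n_cols = math.isqrt(display_count - 1) + 1  # ceil(sqrt(n)) using integers only
    n_rows = -(-display_count // n_cols)        # ceil(n / n_cols)
    return n_rows, n_cols
```

Five display workers, for example, fit a 2 × 3 grid with one empty tile.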
- src.core.train.initialize_overlays(config: EmulatorProcessConfig, batch_config: EmulatorBatchConfig) Queue | None [source]¶
Start a background overlay thread and return its work queue.
Given a list of overlay IDs, looks them up in AVAILABLE_OVERLAYS, starts a single daemon thread that consumes DeSmuME instances from a queue and applies the overlays. The queue is returned to the caller to submit per-frame overlay requests.
- Parameters:
overlay_ids – List of overlay identifiers to enable (indexes into AVAILABLE_OVERLAYS).
device – Torch device on which overlay computations (if any) should run.
- Returns:
If overlay_ids is non-empty, a Queue into which the caller should put(emu) once per frame, and put(None) on shutdown. Returns None when overlay_ids is empty.
- Return type:
Queue | None
Notes
The overlay worker catches exceptions per overlay and propagates a summarized error message on failure via safe_thread().
Overlays are executed off the emulation thread to avoid jitter.
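The queue-plus-daemon-thread pattern can be sketched as follows. This is a simplified illustration: start_overlay_worker is a hypothetical name, the overlays are plain callables rather than entries of AVAILABLE_OVERLAYS, and error wrapping via safe_thread() is left out.

```python
import queue
import threading


def start_overlay_worker(overlays):
    """Start one daemon thread that applies every overlay to each queued item."""
    if not overlays:
        return None  # nothing to do, mirror the "None when empty" contract

    work = queue.Queue()

    def worker():
        while True:
            item = work.get()
            try:
                if item is None:  # shutdown sentinel
                    return
                for overlay in overlays:
                    overlay(item)
            finally:
                work.task_done()  # lets callers wait with work.join()

    threading.Thread(target=worker, daemon=True).start()
    return work
```

A caller would put one item per frame and put(None) at shutdown; work.join() then blocks until everything queued so far has been processed.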
- src.core.train.handle_controls(emu: desmume.emulator.DeSmuME, logits: torch.Tensor)[source]¶
Apply model outputs to emulator controls with a simple threshold policy.
All values >= 0.5 are considered pressed for the corresponding MODEL_KEY_MAP entry. Additionally, when any action is pressed, the accelerator (mapped to MODEL_KEY_MAP[5]) is also pressed to keep the kart moving.
- Parameters:
emu – Active emulator instance whose keypad state will be updated.
logits – 1D tensor of action activations aligned with MODEL_KEY_MAP.
- Side Effects:
Calls emu.input.keypad_update(0) and emu.input.keypad_add_key(...) multiple times.
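The threshold policy can be sketched without an emulator by returning the list of keys that would be pressed. The KEY_MAP contents below are hypothetical placeholders; only the 0.5 threshold and the "index 5 is the accelerator" rule come from the description above.

```python
THRESHOLD = 0.5
ACCELERATE_INDEX = 5

# Hypothetical key names; the real MODEL_KEY_MAP holds emulator key codes.
KEY_MAP = ["left", "right", "up", "down", "b", "a"]


def select_keys(activations, key_map=KEY_MAP):
    """Return the keys to press for one step of the threshold policy."""
    pressed = [key_map[i] for i, value in enumerate(activations) if value >= THRESHOLD]
    accelerator = key_map[ACCELERATE_INDEX]
    # Any pressed action also holds the accelerator to keep the kart moving.
    if pressed and accelerator not in pressed:
        pressed.append(accelerator)
    return pressed
```

In the real function each selected key would be forwarded to emu.input.keypad_add_key(...) after the keypad state has been reset.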
- src.core.train.initialize_model(emu: desmume.emulator.DeSmuME, config: EmulatorProcessConfig, batch_config: EmulatorBatchConfig)[source]¶
- src.core.train.safe_thread(func, proc_id, thread_id=0)[source]¶
Wrap a function for background execution with nicer error reporting.
The returned wrapper calls func(*args, **kwargs) and converts any exception into a concise message identifying the logical process and thread of origin.
- Parameters:
func – Callable to wrap.
proc_id – Integer process identifier for error messages.
thread_id – Integer thread identifier for error messages.
- Returns:
A new callable with identical signature that raises a concise Exception on failure.
- Return type:
Callable
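A wrapper of this kind can be sketched as below. The message format is an assumption chosen for illustration; the actual wording produced by the module is not shown in this reference.

```python
import functools


def safe_thread(func, proc_id, thread_id=0):
    """Wrap func so failures name the logical process and thread of origin."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as exc:
            # Assumed message layout; chaining keeps the original traceback.
            raise Exception(
                f"[process {proc_id}, thread {thread_id}] "
                f"{type(exc).__name__}: {exc}"
            ) from exc

    return wrapper
```

A wrapped callable can be handed directly to threading.Thread(target=...), so a crash in a background overlay thread is easy to attribute.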
- src.core.train.send_window_end_signal(id)[source]¶
Zero a per-process frame buffer to signal the host window to exit.
- Parameters:
id – Process index whose frame buffer should be cleared.
- Side Effects:
Writes zeros into the shared frame emu_frame_{id}, which is used by the host window’s polling logic to detect end-of-batch.
- src.core.train.get_forward_func(emu: desmume.emulator.DeSmuME, model: EvolvedNet, device)[source]¶
Build a closure that performs one model step and checkpoint bookkeeping.
The returned callable reads emulator memory for sensor inputs, constructs the model input vector, computes the control logits, and records per-checkpoint split times and distances. When a terminal condition is reached (e.g. clock > 10000), it returns the accumulated stats dict instead of logits.
- Parameters:
emu – Active emulator instance to read game state from.
model – Evolved network to evaluate (expects 6 inputs → action logits).
device – Torch device on which tensors are constructed and the model runs.
- Returns:
A no-argument function that returns either a 1D tensor of action logits or a stats dictionary signaling the end of this individual’s run.
- Return type:
Callable[[], torch.Tensor | dict[int, list[tuple[float, float]]]]
- Sensor model:
Distances: forward/left/right obstacle distances are read and mapped through tanh(1 - d / s1) with s1 = 60.0 to compress range.
Angles: direction to the next checkpoint as (cos θ, sin θ, -sin θ).
Notes
Checkpoint bookkeeping appends tuples of (delta_time, distance_at_split).
This function reads directly from emulator memory via utility helpers.
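The sensor model described above can be illustrated with plain Python floats. This is a sketch under assumptions: the real function builds torch tensors from emulator reads, and the names compress_distance and sensor_vector, as well as the ordering of the six inputs, are hypothetical.

```python
import math

S1 = 60.0  # distance scale from the sensor model description


def compress_distance(d: float, s1: float = S1) -> float:
    """Map a raw obstacle distance through tanh(1 - d / s1)."""
    return math.tanh(1.0 - d / s1)


def sensor_vector(fwd: float, left: float, right: float, theta: float) -> list[float]:
    """Assemble the six model inputs: three distances, three angle terms."""
    return [
        compress_distance(fwd),
        compress_distance(left),
        compress_distance(right),
        math.cos(theta),
        math.sin(theta),
        -math.sin(theta),
    ]
```

With this mapping, an obstacle at distance s1 maps to 0, closer obstacles map to positive values up to tanh(1), and distant obstacles approach -1.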
- src.core.train.run_training_batch(batch_population: list[Genome], show_samples: list[bool], training_stats: DictProxy[int, dict[int, CheckpointRecord]], training_stats_lock, batch_config: EmulatorBatchConfig)[source]¶
Evaluate a batch of genomes concurrently, optionally with live display.
One process in the batch is promoted to the display host (first True in show_samples) and creates a tiled GTK window that reads from the per-process shared-memory frames listed in shm_names. Additional True entries run as display workers; False entries run headless.
- Parameters:
batch_population – Slice of the population to evaluate in this batch.
show_samples – Per-individual flags controlling display mode. The first True is chosen as the host, any additional True entries are workers, and an all-False list means a fully headless batch.
overlay_ids – Enabled overlay identifiers.
training_stats_lock – IPC lock for synchronized writes to training_stats.
training_stats – Manager dict where each process writes a stats dict under its local batch index.
- Side Effects:
Creates per-process shared-memory frame segments for display-enabled individuals.
Spawns processes with appropriate targets and joins them.
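The role assignment implied by show_samples can be sketched as a small pure function. The function name and the string labels "host", "worker", and "headless" are illustrative assumptions; the actual code selects process targets rather than returning labels.

```python
def assign_display_roles(show_samples: list[bool]) -> list[str]:
    """Label each batch slot: first True is the host, later Trues are workers."""
    roles = []
    host_taken = False
    for flag in show_samples:
        if not flag:
            roles.append("headless")
        elif not host_taken:
            roles.append("host")
            host_taken = True
        else:
            roles.append("worker")
    return roles
```

For example, flags of [False, True, True] would yield one headless process, one host that owns the window, and one worker that only publishes frames.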
- src.core.train.run_training_session(pop: list[Genome], num_proc: int | None = None, show_samples: list[bool] = [True], overlay_ids: list[int] = [], device: DeviceLikeType | None = None) dict[int, list[CheckpointRecord]] [source]¶
Evaluate the full population in parallel batches and collect statistics.
The population is partitioned into batches of size min(num_proc, remaining). Each batch is launched via run_training_batch(), returning when all processes in the batch complete and their stats have been merged.
- Parameters:
pop – Full population of genomes to evaluate.
num_proc – Maximum number of concurrent processes. If None, uses os.cpu_count() - 1.
show_samples – List of booleans determining which individuals in each batch should display; a one-element list (e.g., [False]) is broadcast to the batch size on each iteration.
overlay_ids – Overlay identifiers to enable in display-enabled processes.
- Returns:
Mapping of global population index to that individual’s checkpoint stats dict.
- Return type:
dict[int, dict[int, list[tuple[float, float]]]]
Notes
This function uses a multiprocessing.Manager dict so that per-process stats can be retrieved without explicit pipes or queues.
Shared-memory frame buffers are currently not unlinked here; consider cleaning them in a higher-level teardown if needed.
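The batching rule, min(num_proc, remaining) with a broadcast one-element show_samples, can be sketched as a generator. The name iter_batches is hypothetical, and truncating a longer show_samples list to the batch size is an assumption; the description above only specifies the one-element case.

```python
from typing import Iterator


def iter_batches(
    pop_size: int, num_proc: int, show_samples: list[bool]
) -> Iterator[tuple[list[int], list[bool]]]:
    """Yield (global indices, display flags) for each batch in order."""
    start = 0
    while start < pop_size:
        size = min(num_proc, pop_size - start)
        if len(show_samples) == 1:
            flags = show_samples * size  # broadcast the single flag
        else:
            flags = show_samples[:size]  # assumption: truncate to batch size
        yield list(range(start, start + size)), flags
        start += size
```

Because each process writes its stats under a local batch index, the global indices yielded here are what a caller would use to merge per-batch results back into a population-wide mapping.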
- src.core.train.fitness(pop_stats: dict[int, list[CheckpointRecord]]) list[float] [source]¶
Compute scalar fitness from per-checkpoint stats.
The current fitness heuristic sums the recorded distances across all checkpoint splits for each individual.
- Parameters:
pop_stats – Mapping of population index to that individual’s stats dict (checkpoint_id -> [(delta_time, distance_at_split), ...]).
- Returns:
Fitness values ordered by population index.
- Return type:
list[float]
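As a rough illustration of the heuristic, the sketch below sums distance_at_split over every checkpoint split, following the stats layout described in the parameter above (checkpoint_id -> list of (delta_time, distance_at_split) tuples). The name fitness_sketch is hypothetical and the real CheckpointRecord type may differ.

```python
def fitness_sketch(
    pop_stats: dict[int, dict[int, list[tuple[float, float]]]]
) -> list[float]:
    """Sum recorded split distances per individual, ordered by population index."""
    scores = []
    for index in sorted(pop_stats):
        total = 0.0
        for splits in pop_stats[index].values():
            for _delta_time, distance in splits:
                total += distance
        scores.append(total)
    return scores
```

Note that delta_time is ignored in this sketch, which matches the description of the current heuristic as a sum of distances only.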
- src.core.train.train(num_iters: int, pop_size: int, log_interval: int = 1, top_k: int | float = 0.1, device: DeviceLikeType | None = None, **simulation_kwargs)[source]¶
Main evolutionary training loop (selection + mutation).
- Repeats:
Evaluate the current population via run_training_session().
Rank by fitness.
Keep the best, and refill the population by mutating uniformly sampled parents from the top-k set.
- Parameters:
num_iters – Number of generations to run.
pop_size – Number of individuals per generation.
log_interval – Print progress every N generations.
top_k – Either the number of top individuals to sample parents from, or a fraction in (0, 1] interpreted as a proportion of the population.
**simulation_kwargs – Passed through to run_training_session().
- Side Effects:
Prints best fitness per log_interval.
Mutates and replaces the population in place each generation.
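One generation of the selection and mutation loop might look like the sketch below. Several details are assumptions: higher fitness is treated as better, a fractional top_k is rounded down with a floor of one parent, and the helper names (resolve_top_k, next_generation) and the mutate callable are hypothetical.

```python
import random


def resolve_top_k(top_k, pop_size: int) -> int:
    """Turn top_k into a parent count; floats are treated as a fraction."""
    if isinstance(top_k, float):
        return max(1, int(top_k * pop_size))  # assumption: floor, at least 1
    return top_k


def next_generation(pop, scores, top_k, mutate, rng=random):
    """Keep the best individual and refill with mutated top-k parents."""
    order = sorted(range(len(pop)), key=lambda i: scores[i], reverse=True)
    ranked = [pop[i] for i in order]
    parents = ranked[: resolve_top_k(top_k, len(pop))]
    # Uniformly sample a parent for every remaining slot and mutate it.
    children = [mutate(rng.choice(parents)) for _ in range(len(pop) - 1)]
    return [ranked[0]] + children
```

Carrying the best individual over unchanged (elitism) means the best fitness cannot drop between generations, provided evaluation is deterministic.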