Source code for src.core.memory

"""
Mario Kart DS (MKDS) Emulator I/O & Geometry Utilities
======================================================

This module provides a high-level, vectorized interface for reading game state
from a running DeSmuME emulator and performing common geometric operations
used in visualization, control, and RL policy features for Mario Kart DS.

It wraps low-level memory reads (positions, directions, camera data, objects,
checkpoints, clock) and exposes a compact API for tasks like projecting
world-space points to screen-space, computing distances to checkpoints and
obstacles, and deriving view matrices.

The implementation favors:
  * **Deterministic caching** at frame and game lifetimes to minimize emulator
    I/O, and
  * **Torch tensor** computations (CPU / CUDA / MPS) for fast, batched math.

-------------------------------------------------------------------------------
Quick Start
-------------------------------------------------------------------------------

>>> from desmume.emulator import DeSmuME
>>> import torch
>>> from src.core.memory import (
...     read_position, read_direction, project_to_screen, read_forward_distance_checkpoint
... )
>>> emu = DeSmuME()
>>> # ... open ROM, load state, etc.
>>> device = torch.device("cpu")  # or "cuda", "mps"
>>> pos = read_position(emu, device=device)               # (3,)
>>> dir = read_direction(emu, device=device)              # (3,)
>>> screen = project_to_screen(emu, pos.unsqueeze(0), device=device)  # (1, 4)
>>> fwd_to_cp = read_forward_distance_checkpoint(emu, device=device)  # scalar tensor

-------------------------------------------------------------------------------
Key Concepts & Conventions
-------------------------------------------------------------------------------

Coordinate Systems
------------------
* **World Space**: Right-handed, with **Y as up**. Many functions assume a
  canonical "up" vector of (0, 1, 0) and construct a right-handed orthonormal
  basis `[right, up, forward]`.
* **Camera Space / Clip Space / NDC / Screen Space**: `_compute_model_view`
  builds a model-view matrix from camera position/target; `_project_to_screen`
  applies perspective and viewport transforms to return pixel coordinates for a
  256×192 screen (Nintendo DS top display).
* **Screen Origin**: (0, 0) is **top-left**. X grows to the right; Y grows
  downward. This follows the standard raster convention and matches the
  `(1 - ndc_y)` transform used in projection.
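
The screen-origin convention above can be sketched in plain Python. This is a
minimal illustration only: `ndc_to_pixels` is a hypothetical helper, not part
of this module, and it assumes NDC coordinates in the range [-1, 1].

```python
SCREEN_WIDTH, SCREEN_HEIGHT = 256, 192  # Nintendo DS top display, as in this module


def ndc_to_pixels(ndc_x, ndc_y, width=SCREEN_WIDTH, height=SCREEN_HEIGHT):
    # Hypothetical helper: map NDC in [-1, 1] to pixel coordinates.
    x_px = (ndc_x + 1.0) / 2.0 * width
    # Flip Y so that NDC "up" (+1) lands on the top row of the screen (y = 0).
    y_px = (1.0 - ndc_y) / 2.0 * height
    return x_px, y_px


center = ndc_to_pixels(0.0, 0.0)
top_left = ndc_to_pixels(-1.0, 1.0)
bottom_right = ndc_to_pixels(1.0, -1.0)
```

Here the NDC origin maps to the screen center and `(-1, 1)` maps to the
top-left pixel corner, which is the same `(1 - ndc_y)` flip used in projection.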

Units
-----
* **Positions & Scalars** returned from memory are derived from MKDS fixed-point
  formats (FX32, etc.) via helpers like `read_vector_3d_fx32` and `read_fx32`,
  and are exposed as Python floats / Torch tensors.
* **Angles**:
  - Camera FOV is read from a 16-bit fixed-point angle and converted to radians
    using: ``value * (2π / 0x10000)``.
* **Time**:
  - `read_clock()` returns **centiseconds** (10 ms units).
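
As a rough sketch of the unit conversions above: the angle formula is the one
quoted in this section, while the FX32 layout (12 fractional bits) is an
assumption based on the usual Nintendo DS fixed-point format, not something
confirmed by `read_fx32` itself. Both helper names are hypothetical.

```python
import math

FX32_SHIFT = 12  # assumption: 12 fractional bits, the usual DS fx32 layout


def fx32_to_float(raw):
    # Hypothetical helper: interpret a signed 32-bit fixed-point value as a float.
    return raw / (1 << FX32_SHIFT)


def angle16_to_radians(raw):
    # 16-bit angle: 0x10000 units make one full turn (2*pi radians).
    return raw * (2 * math.pi / 0x10000)


one = fx32_to_float(0x1000)
minus_half = fx32_to_float(-0x800)
quarter_turn = angle16_to_radians(0x4000)
```

Under these assumptions `0x1000` decodes to `1.0` and the angle value `0x4000`
is a quarter turn.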

Memory Map & Assets
-------------------
* **Addresses**: The module uses static addresses for key pointers (racer,
  course, objects, checkpoints, camera, clock). See constants:
  `RACER_PTR_ADDR`, `COURSE_ID_ADDR`, `OBJECTS_PTR_ADDR`, `CHECKPOINT_PTR_ADDR`,
  `CLOCK_DATA_PTR`, `CAMERA_PTR_ADDR`.
* **Course Files**:
  - `courses.json` maps course IDs to directory names.
  - KCL (`course_collision.kcl`) and NKM (`course_map.nkm`) are loaded via
    `KCLTensor.from_file(...)` and `NKMTensor.from_file(...)`.

Caching Model
-------------
Two decorators reduce emulator I/O:

* ``@frame_cache`` — Caches the function's **single** return value per
  emulator tick (`emu.get_ticks()`). Recomputes only when the tick changes.

* ``@game_cache`` — Caches the function's **single** return value for the
  process lifetime (until interpreter exit).

⚠ **Important**: Both caches **ignore argument values**. If you call a cached
function with different arguments within the same lifetime (same frame or same
run), the **first computed result is reused**. In practice, pass stable
arguments (e.g., a constant `device`) to avoid surprises.
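
The argument-ignoring behavior can be reproduced without an emulator. The
sketch below uses a simplified copy of `frame_cache` (type hints omitted) and a
hypothetical `FakeEmu` class that only exposes `get_ticks()`; `describe` is
likewise a made-up function for illustration.

```python
from functools import wraps


class FakeEmu:
    # Hypothetical stand-in for DeSmuME: only provides a tick counter.
    def __init__(self):
        self.ticks = 1

    def get_ticks(self):
        return self.ticks


def frame_cache(func):
    # Simplified copy of this module's decorator: one cached value per tick.
    val = None
    frame_count = 0

    @wraps(func)
    def wrapper(emu, *args, **kwargs):
        nonlocal frame_count, val
        if emu.get_ticks() != frame_count or val is None:
            frame_count = emu.get_ticks()
            val = func(emu, *args, **kwargs)
        return val

    return wrapper


@frame_cache
def describe(emu, label):
    return f"{label}@{emu.get_ticks()}"


emu = FakeEmu()
first = describe(emu, "a")   # computed on tick 1
second = describe(emu, "b")  # same tick: cached value reused, "b" is ignored
emu.ticks = 2
third = describe(emu, "b")   # new tick: recomputed with the new argument
```

Note also that a result of `None` is never treated as cached, so a wrapped
function that returns `None` is re-executed on every call.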

Device Handling
---------------
Many functions accept a Torch `device` and return tensors allocated there. For
best performance, use `cuda` (GPU) or `mps` (Apple Silicon) when available, and
keep devices **consistent** across the call sites, especially for `@game_cache`
results (KCL/NKM tensors are created on the device used at first call).

-------------------------------------------------------------------------------
Public API Overview
-------------------------------------------------------------------------------

Clock & Course
--------------
* `read_clock_ptr(emu)` — Base pointer to clock data (cached for game lifetime).
* `read_clock(emu)` — Current clock in 10 ms units (cached per frame).
* `get_current_course_id(emu)` — Current course ID (byte).
* `get_course_path(id)` — Course directory name from `courses.json`.
* `load_current_kcl(emu, device)` — Parsed KCL collision mesh (game-cached).
* `load_current_nkm(emu, device)` — Parsed NKM map data (game-cached).

Player & Objects
----------------
* `read_racer_ptr(emu)` — Pointer to the player racer struct.
* `read_position(emu, device)` — Player world position `(3,)`.
* `read_direction(emu, device)` — Player forward direction `(3,)`.
* `read_objects(...)`, `read_object_*` helpers — Scans and queries object table.
  - `safe_object` decorator returns `None` for deleted objects.

Camera & Projection
-------------------
* `read_camera_ptr(emu)` — Pointer to camera struct.
* `read_camera_fov(emu)` — FOV in radians.
* `read_camera_aspect(emu)` — Aspect ratio (W/H).
* `read_camera_position(emu, device)` — Camera world pos `(3,)` with elevation.
* `read_camera_target_position(emu, device)` — Camera look-at `(3,)`.
* `read_model_view(emu, device)` — 4×4 model-view matrix.
* `project_to_screen(emu, points, device)` — Projects `(N,3)` to `(N,4)`:
  `[x_px, y_px, clip_z, normalized_depth]`.
* `z_clip_mask(x)` — Mask for points within Z-near/far bounds (camera space).

Checkpoints
~~~~~~~~~~~
* `read_checkpoint_ptr(emu)` — Pointer to checkpoint manager.
* `read_current_checkpoint(emu)`, `read_current_key_checkpoint(emu)`,
  `read_current_lap(emu)` — Indices for current progress.
* `read_ghost_checkpoint(emu)`, `read_ghost_key_checkpoint(emu)` — Ghost state.
* `read_checkpoint_positions(emu, device)` — `(C, 2, 3)` segment endpoints.
* `read_next_checkpoint(emu, checkpoint_count)` — Next index (wraps).
* `read_next_checkpoint_position(emu, device)`,
  `read_current_checkpoint_position(emu, device)` — `(2,3)` endpoints.
* `read_facing_point_checkpoint(emu, direction, device)` — Intersection of a
  ray (from player, given direction) with next checkpoint line in XZ.
* `read_forward_distance_checkpoint(emu, device)`,
  `read_left_distance_checkpoint(emu, device)`,
  `read_direction_to_checkpoint(emu, device)` — Distances/steering angle.
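
The ray/line intersection in XZ behind `read_facing_point_checkpoint` can be
sketched in plain Python. This is not the module's `intersect_ray_line_2d`
(whose implementation is not shown here); `ray_line_2d_sketch` is a
hypothetical stand-in that treats the checkpoint as an infinite line through
its two endpoints.

```python
import math


def ray_line_2d_sketch(origin, direction, p1, p2):
    # Solve origin + t * direction == p1 + u * (p2 - p1) for t using 2D cross products.
    ex, ez = p2[0] - p1[0], p2[1] - p1[1]
    denom = direction[0] * ez - direction[1] * ex
    if abs(denom) < 1e-12:
        return None  # ray is parallel to the checkpoint line
    wx, wz = p1[0] - origin[0], p1[1] - origin[1]
    t = (wx * ez - wz * ex) / denom
    point = (origin[0] + t * direction[0], origin[1] + t * direction[1])
    return point, t


position = (0.0, 2.0, 0.0)   # player (x, y, z)
direction = (0.0, 0.0, 1.0)  # facing +Z
cp1, cp2 = (-1.0, 5.0), (1.0, 5.0)  # checkpoint endpoints in (x, z)

point_xz, t = ray_line_2d_sketch(
    (position[0], position[2]), (direction[0], direction[2]), cp1, cp2
)
# Lift back to 3D by reusing the player's height, as the module does.
facing_point = (point_xz[0], position[1], point_xz[1])
forward_distance = math.dist(facing_point, position)
parallel = ray_line_2d_sketch((0.0, 0.0), (1.0, 0.0), cp1, cp2)
```

The forward distance is then simply the Euclidean distance between the player
and the lifted intersection point.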

Obstacles (Walls / Offroad)
~~~~~~~~~~~~~~~~~~~~~~~~~~~
* `read_facing_point_obstacle(emu, position, direction, device)` — Samples a
  cone of rays around the forward direction to find the **nearest** hit against
  wall/offroad triangles. Returns a point or `None`.
* `read_forward_distance_obstacle(emu, device)`,
  `read_left_distance_obstacle(emu, device)`,
  `read_right_distance_obstacle(emu, device)` — Scalar distances to nearest
  obstacles along canonical forward/left/right rays. Return `+inf` when no hit.

-------------------------------------------------------------------------------
Return Types & Shapes
-------------------------------------------------------------------------------

* Positions / Directions: `torch.Tensor` with shape `(3,)`.
* Batches of points: `(N, 3)`.
* Screen Projection: `(N, 4)` → `[x_px, y_px, clip_z, normalized_depth]`.
* Checkpoints: `(C, 2, 3)` → per checkpoint two endpoints `[p1, p2]`.
* Distances / Angles: 0-D or 1-D scalar `torch.Tensor` (depending on operation).

-------------------------------------------------------------------------------
Errors & Edge Cases
-------------------------------------------------------------------------------

* **Deleted / Ignored Objects**: `safe_object`-wrapped functions return `None`
  when the object is deleted; callers must handle `None`.
* **No Geometry**: When there are no wall/offroad triangles or raycasts miss,
  obstacle distance functions return `+inf` (as a tensor).
* **Empty Projections**: `_project_to_screen` returns an empty tensor when
  given no points; invalid (behind-camera) points may still project with
  negative `clip_w`.

-------------------------------------------------------------------------------
Performance Notes
-------------------------------------------------------------------------------

* Caching eliminates redundant memory reads across a frame / game run.
* Geometry routines (ray casting, distances, projection) are vectorized in
  Torch; prefer GPU/MPS devices when available.
* Keep devices consistent across calls that share cached state (e.g., KCL/NKM).

-------------------------------------------------------------------------------
Examples
-------------------------------------------------------------------------------

Project player and next checkpoint endpoints to screen:

>>> pts = torch.vstack([read_position(emu, device),    # (3,), stacked as one row
...                     read_next_checkpoint_position(emu, device)]).reshape(-1, 3)
>>> screen_pts = project_to_screen(emu, pts, device)
>>> screen_pts[:, :2]  # pixel coordinates

Compute lateral vs forward distance to next checkpoint:

>>> d_left  = read_left_distance_checkpoint(emu, device)
>>> d_front = read_forward_distance_checkpoint(emu, device)

Find nearest obstacle straight ahead:

>>> d_obs = read_forward_distance_obstacle(emu, device)
>>> float(d_obs) if torch.isfinite(d_obs) else float("inf")

-------------------------------------------------------------------------------
Implementation Notes
-------------------------------------------------------------------------------

* `_compute_orthonormal_basis` builds a right-handed frame from a forward
  vector and an up-like reference (default `(0,1,0)`), normalizing each axis.
* `_compute_model_view` constructs a 4×4 row-major model-view matrix whose
  upper-left 3×3 block `R` has rows `[right, up, forward]` and whose last
  column holds the translation `-(R @ camera_pos)`.
* `_project_to_screen` creates a simple perspective matrix using vertical FOV
  and aspect ratio; returns pixel coordinates using constants
  `SCREEN_WIDTH = 256` and `SCREEN_HEIGHT = 192`.
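
The basis and model-view construction described above can be sketched without
Torch. The helper names below (`cross`, `normalize`, `look_at_rows`,
`transform`) are hypothetical and only mirror the steps of
`_compute_orthonormal_basis` and `_compute_model_view` in plain Python.

```python
import math


def cross(a, b):
    return (a[1] * b[2] - a[2] * b[1],
            a[2] * b[0] - a[0] * b[2],
            a[0] * b[1] - a[1] * b[0])


def normalize(v):
    n = math.sqrt(sum(c * c for c in v))
    return tuple(c / n for c in v)


def look_at_rows(camera_pos, target, up_ref=(0.0, 1.0, 0.0)):
    # Rows [right, up, forward], with translation -(R @ camera_pos) in the last column.
    forward = normalize(tuple(t - c for t, c in zip(target, camera_pos)))
    right = normalize(cross(forward, up_ref))
    up = normalize(cross(right, forward))
    rows = (right, up, forward)
    translation = [-sum(r * c for r, c in zip(row, camera_pos)) for row in rows]
    matrix = [list(row) + [t] for row, t in zip(rows, translation)]
    matrix.append([0.0, 0.0, 0.0, 1.0])
    return matrix


def transform(matrix, point):
    # Apply the 4x4 matrix to a 3D point with homogeneous coordinate 1.
    p = tuple(point) + (1.0,)
    return tuple(sum(m * c for m, c in zip(row, p)) for row in matrix[:3])


camera = (0.0, 1.0, 5.0)
target = (0.0, 1.0, 0.0)
model_view = look_at_rows(camera, target)
camera_in_view = transform(model_view, camera)
target_in_view = transform(model_view, target)
```

With this construction the camera position maps to the origin, and the target
lies on the third (forward) axis at a distance equal to the camera-to-target
distance.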

-------------------------------------------------------------------------------
Compatibility
-------------------------------------------------------------------------------

* Tested with DeSmuME Python bindings and Torch. Some ops may vary by backend
  (e.g., MPS lacks a few linear algebra kernels); this module sticks to widely
  supported APIs.

"""
from __future__ import annotations
import sys, os
from desmume.emulator import SCREEN_WIDTH, DeSmuME
from mkds.kcl import read_fx32
from src.mkds_extensions.kcl_torch import KCLTensor
from src.mkds_extensions.nkm_torch import NKMTensor
from mkds.utils import read_vector_3d_fx32
import torch
import json
import math
from src.utils.vector import (
    pairwise_distances_cross,
    intersect_ray_line_2d,
    triangle_raycast_batch,
    sample_cone,
    triangle_altitude,
)
from typing import Callable, Concatenate, TypeVar, ParamSpec
from functools import wraps

P = ParamSpec("P")
R = TypeVar("R")

SCREEN_WIDTH, SCREEN_HEIGHT = 256, 192
Z_FAR = 1000.0
Z_NEAR = 0.0
Z_SCALE = 10.0

RACER_PTR_ADDR = 0x0217ACF8
COURSE_ID_ADDR = 0x23CDCD8
OBJECTS_PTR_ADDR = 0x0217B588
CHECKPOINT_PTR_ADDR = 0x021755FC
CLOCK_DATA_PTR = 0x0217AA34
CAMERA_PTR_ADDR = 0x0217AA4C

# Object flags
FLAG_DYNAMIC = 0x1000
FLAG_MAPOBJ = 0x2000
FLAG_ITEM = 0x4000
FLAG_RACER = 0x8000


def frame_cache(
    func: Callable[Concatenate[DeSmuME, P], R],
) -> Callable[Concatenate[DeSmuME, P], R]:
    """Decorator that caches a function's return value once per emulator tick.

    The wrapped function will only be re-executed when `emu.get_ticks()` changes.
    Useful for expensive reads that don't change within a single frame.

    Args:
        func: A function whose first argument is a `DeSmuME` instance.

    Returns:
        A wrapped function with identical signature that returns a cached
        result per tick.
    """
    val = None
    frame_count = 0

    @wraps(func)
    def wrapper(emu: DeSmuME, *args: P.args, **kwargs: P.kwargs) -> R:
        wrapper.__doc__ = func.__doc__
        nonlocal frame_count, val
        if emu.get_ticks() != frame_count or val is None:
            frame_count = emu.get_ticks()
            val = func(emu, *args, **kwargs)
        return val

    return wrapper


def game_cache(
    func: Callable[Concatenate[DeSmuME, P], R],
) -> Callable[Concatenate[DeSmuME, P], R]:
    """Decorator that caches a function's return value for the process lifetime.

    The wrapped function executes once and its result is reused thereafter.
    Appropriate for data that remains constant across a run (e.g., course files).

    Args:
        func: A function whose first argument is a `DeSmuME` instance.

    Returns:
        A wrapped function with identical signature that returns a cached result.
    """
    val = None

    @wraps(func)
    def wrapper(emu: DeSmuME, *args: P.args, **kwargs: P.kwargs) -> R:
        wrapper.__doc__ = func.__doc__
        nonlocal val
        if val is None:
            val = func(emu, *args, **kwargs)
        return val

    return wrapper


def z_clip_mask(x: torch.Tensor) -> torch.Tensor:
    """Compute a boolean mask for points within the view frustum Z range.

    Args:
        x: Tensor of shape (N, 3+) where x[:, 2] is the camera-space Z.

    Returns:
        A boolean tensor of shape (N,) where True indicates Z is between
        -Z_FAR and -Z_NEAR.
    """
    return (x[:, 2] < -Z_NEAR) & (x[:, 2] > -Z_FAR)


@game_cache
def read_clock_ptr(emu: DeSmuME):
    """Read the base pointer to the game's clock data structure.

    Args:
        emu: Emulator instance.

    Returns:
        Integer address of the clock data struct.
    """
    return emu.memory.unsigned.read_long(CLOCK_DATA_PTR)


@frame_cache
def read_clock(emu: DeSmuME):
    """Read the current game clock value.

    The value is read from the clock data structure and multiplied by 10,
    resulting in units of 10 ms (centiseconds).

    Args:
        emu: Emulator instance.

    Returns:
        Integer time in 10 ms units.
    """
    addr = read_clock_ptr(emu)
    return emu.memory.signed.read_long(addr + 0x08) * 10


def get_current_course_id(emu: DeSmuME):
    """Read the current course ID from memory.

    Args:
        emu: Emulator instance.

    Returns:
        Integer course ID (byte).
    """
    return emu.memory.unsigned.read_byte(COURSE_ID_ADDR)


def get_course_path(id: int, lookup_path: str = "./src/misc/courses.json"):
    """
    Resolve a course ID to the local filesystem path for its assets.

    Args:
        id: Course ID.

    Returns:
        String path relative to ./courses/ for the given course.

    Raises:
        AssertionError: If the course ID is not present in the lookup table.
    """
    course_id_lookup = None
    with open(lookup_path, "r") as f:
        course_id_lookup = json.load(f)

    assert course_id_lookup is not None
    assert str(id) in course_id_lookup
    return course_id_lookup[str(id)]


@game_cache
def load_current_kcl(emu: DeSmuME, device):
    """Load and parse the KCL collision file for the current course.

    Cached for the lifetime of the process.

    Args:
        emu: Emulator instance.
        device: Torch device (e.g., 'cpu', 'cuda', 'mps') to store tensors on.

    Returns:
        `KCLTensor` with triangle and prism data on the specified device.
    """
    assert device is not None
    id = get_current_course_id(emu)
    path = get_course_path(id)
    path = f"./courses/{path}/course_collision.kcl"
    kcl = KCLTensor.from_file(path, device=device)
    return kcl


@game_cache
def load_current_nkm(emu: DeSmuME, device):
    """Load and parse the NKM map file for the current course.

    Cached for the lifetime of the process.

    Args:
        emu: Emulator instance.
        device: Torch device to store tensors on.

    Returns:
        `NKMTensor` with NKM section tensors (e.g., checkpoints) on the
        specified device.
    """
    id = get_current_course_id(emu)
    path = get_course_path(id)
    path = f"./courses/{path}/course_map.nkm"
    nkm = NKMTensor.from_file(path, device=device)
    return nkm


def read_racer_ptr(emu: DeSmuME, addr: int = RACER_PTR_ADDR):
    """Read the pointer to the player's racer object.

    Args:
        emu: Emulator instance.
        addr: Memory address where the racer pointer is stored.

    Returns:
        Integer address of the racer structure.
    """
    return emu.memory.unsigned.read_long(addr)


@frame_cache
def read_position(emu: DeSmuME, device):
    """Read the player's world-space position.

    Args:
        emu: Emulator instance.
        device: Torch device for the returned tensor.

    Returns:
        torch.Tensor of shape (3,) representing (x, y, z) in world units.
    """
    data = emu.memory.unsigned
    addr = read_racer_ptr(emu)
    pos = read_vector_3d_fx32(data, addr + 0x80)
    return torch.tensor(pos, dtype=torch.float32, device=device)


@frame_cache
def read_direction(emu: DeSmuME, device):
    """Read the player's forward direction vector (world-space).

    Args:
        emu: Emulator instance.
        device: Torch device for the returned tensor.

    Returns:
        torch.Tensor of shape (3,) representing the forward direction.
    """
    data = emu.memory.unsigned
    addr = read_racer_ptr(emu)
    pos = read_vector_3d_fx32(data, addr + 0x68)
    return torch.tensor(pos, dtype=torch.float32, device=device)


def read_objects_array_max_count(emu: DeSmuME, addr: int = OBJECTS_PTR_ADDR):
    """Read the maximum number of objects in the global object array.

    Args:
        emu: Emulator instance.
        addr: Base address of the object array metadata.

    Returns:
        Signed integer max count.
    """
    return emu.memory.signed.read_long(addr + 0x08)


def read_objects_array_ptr(emu: DeSmuME, addr: int = OBJECTS_PTR_ADDR):
    """Read the pointer to the global object pointer array.

    Args:
        emu: Emulator instance.
        addr: Base address of the object array metadata.

    Returns:
        Signed integer address of the object pointer array.
    """
    return emu.memory.signed.read_long(addr + 0x10)


def read_object_offset(emu: DeSmuME, id: int):
    """Compute the memory offset of an object entry within the array.

    Args:
        emu: Emulator instance.
        id: Object index.

    Returns:
        Integer byte offset to the object's metadata entry.
    """
    return read_objects_array_ptr(emu) + id * 0x1C


def read_object_ptr(emu: DeSmuME, id: int):
    """Read the object instance pointer for a given object ID.

    Args:
        emu: Emulator instance.
        id: Object index.

    Returns:
        Integer address of the object struct (0 if null).
    """
    offset = read_object_offset(emu, id)
    return emu.memory.unsigned.read_long(offset + 0x18)


def read_object_flags(emu: DeSmuME, id: int):
    """Read the object's flags (type/category bits, state, etc.).

    Args:
        emu: Emulator instance.
        id: Object index.

    Returns:
        Unsigned short flags value.
    """
    offset = read_object_offset(emu, id)
    return emu.memory.unsigned.read_short(offset + 0x14)


def read_object_position_ptr(emu: DeSmuME, id: int):
    """Read the pointer to an object's position vector in memory.

    Args:
        emu: Emulator instance.
        id: Object index.

    Returns:
        Integer address for the object's position struct (0 if deleted).
    """
    offset = read_object_offset(emu, id)
    return emu.memory.unsigned.read_long(offset + 0x0C)


def read_object_is_ignored(emu: DeSmuME, id: int):
    """Determine if an object should be ignored (null or ignored-flag set).

    Args:
        emu: Emulator instance.
        id: Object index.

    Returns:
        True if object ptr is 0 or ignored bit is set; False otherwise.
    """
    obj_ptr = read_object_ptr(emu, id)
    flags = read_object_flags(emu, id)
    return obj_ptr == 0 or flags & 0x200


def read_object_is_deleted(emu: DeSmuME, id: int):
    """Check if the object has been deleted (position pointer is null).

    Args:
        emu: Emulator instance.
        id: Object index.

    Returns:
        True if deleted; False otherwise.
    """
    pos_ptr = read_object_position_ptr(emu, id)
    return pos_ptr == 0


def safe_object(func):
    """Decorator that skips object reads when the object appears deleted.

    The wrapped function receives `(emu, id, *args, **kwargs)`. If the object is
    deleted (null position pointer), the wrapper returns `None`.
    """

    def wrapper(emu: DeSmuME, id: int, *args, **kwargs):
        """Internal wrapper used by `safe_object` to guard deleted objects."""
        if read_object_is_deleted(emu, id):
            return None
        return func(emu, id, *args, **kwargs)

    return wrapper


@frame_cache
@safe_object
def read_object_position(emu: DeSmuME, id: int, device):
    """Read an object's world-space position.

    Args:
        emu: Emulator instance.
        id: Object index.
        device: Torch device for the returned tensor.

    Returns:
        torch.Tensor of shape (3,) in world coordinates, or None if deleted.
    """
    pos_ptr = read_object_position_ptr(emu, id)
    pos = read_vector_3d_fx32(emu.memory.unsigned, pos_ptr)
    return torch.tensor(pos, device=device)


@frame_cache
@safe_object
def read_map_object_type_id(emu: DeSmuME, id: int):
    """Read a map object's type ID (e.g., coin, tree, etc.).

    Args:
        emu: Emulator instance.
        id: Object index.

    Returns:
        Signed short type ID, or None if object is deleted.
    """
    obj_ptr = read_object_ptr(emu, id)
    return emu.memory.signed.read_short(obj_ptr)


@frame_cache
@safe_object
def read_map_object_is_coin_collected(emu: DeSmuME, id: int):
    """Check if a coin-type map object has been collected.

    Args:
        emu: Emulator instance.
        id: Object index.

    Returns:
        True if collected; False otherwise; or None if object is deleted.
    """
    obj_ptr = read_object_ptr(emu, id)
    return emu.memory.unsigned.read_short(obj_ptr + 0x02) & 0x01 != 0


@frame_cache
@safe_object
def read_racer_object_is_ghost(emu: DeSmuME, id: int):
    """Check if a racer object is currently in ghost state.

    Args:
        emu: Emulator instance.
        id: Object index.

    Returns:
        True if ghosted; False otherwise; or None if object is deleted.
    """
    obj_ptr = read_object_ptr(emu, id)
    ghost_flag = emu.memory.unsigned.read_byte(obj_ptr + 0x7C)
    return ghost_flag & 0x04 != 0


@frame_cache
def read_objects(emu: DeSmuME):
    """Scan the global object table and group object indices by category.

    Categories:
        - 'map_objects'
        - 'racer_objects'
        - 'item_objects'
        - 'dynamic_objects'

    Returns:
        Dict[str, list[int]] mapping category name to list of indices.
    """
    obj_ids: dict[str, list[int]] = {
        "map_objects": [],
        "racer_objects": [],
        "item_objects": [],
        "dynamic_objects": [],
    }
    max_count = read_objects_array_max_count(emu)
    count = 0

    # A for loop advances idx on every path, including the `continue` branches,
    # so deleted or ignored entries cannot stall the scan.
    for idx in range(255):
        if count == max_count:
            break

        if read_object_is_deleted(emu, idx):
            continue
        count += 1

        if read_object_is_ignored(emu, idx):
            continue

        flags = read_object_flags(emu, idx)
        if flags & FLAG_MAPOBJ != 0:
            obj_ids["map_objects"].append(idx)
        elif flags & FLAG_RACER != 0:
            obj_ids["racer_objects"].append(idx)
        elif flags & FLAG_ITEM != 0:
            obj_ids["item_objects"].append(idx)
        elif flags & FLAG_DYNAMIC == 0:
            obj_ids["dynamic_objects"].append(idx)

    return obj_ids


@frame_cache
def read_camera_ptr(emu: DeSmuME, addr: int = CAMERA_PTR_ADDR):
    """Read the pointer to the active camera structure.

    Args:
        emu: Emulator instance.
        addr: Address where the camera pointer is stored.

    Returns:
        Integer address of the camera struct.
    """
    return emu.memory.unsigned.read_long(addr)


@frame_cache
def read_camera_fov(emu: DeSmuME):
    """Read the current camera field-of-view (radians).

    The FOV value is stored as a 16-bit fixed-point angle; it is converted to
    radians.

    Args:
        emu: Emulator instance.

    Returns:
        Floating-point FOV in radians.
    """
    addr = read_camera_ptr(emu)
    return emu.memory.unsigned.read_short(addr + 0x60) * (2 * math.pi / 0x10000)


@frame_cache
def read_camera_aspect(emu: DeSmuME):
    """Read the camera aspect ratio from memory.

    Args:
        emu: Emulator instance.

    Returns:
        Float aspect ratio (width/height).
    """
    addr = read_camera_ptr(emu)
    return read_fx32(emu.memory.unsigned, addr + 0x6C)


@frame_cache
def read_camera_position(emu: DeSmuME, device):
    """Read the camera world position, including elevation offset.

    Args:
        emu: Emulator instance.
        device: Torch device for the returned tensor.

    Returns:
        torch.Tensor shape (3,) representing camera (x, y, z).
    """
    addr = read_camera_ptr(emu)
    pos = read_vector_3d_fx32(emu.memory.unsigned, addr + 0x24)
    elevation = read_fx32(emu.memory.unsigned, addr + 0x178)
    pos = (pos[0], pos[1] + elevation, pos[2])
    return torch.tensor(pos, device=device)


def read_camera_target_position(emu: DeSmuME, device):
    """Read the camera's target/look-at position in world space.

    Args:
        emu: Emulator instance.
        device: Torch device for the returned tensor.

    Returns:
        torch.Tensor shape (3,) target (x, y, z).
    """
    addr = read_camera_ptr(emu)
    pos = read_vector_3d_fx32(emu.memory.unsigned, addr + 0x18)
    return torch.tensor(pos, device=device)


def _compute_orthonormal_basis(
    forward_vector_3d: torch.Tensor,
    reference_vector_3d: torch.Tensor | None = None,
    device=None,
):
    """Compute a right-handed orthonormal basis given a forward vector.

    Args:
        forward_vector_3d: Tensor shape (3,) forward direction.
        reference_vector_3d: Optional up-like reference; defaults to (0,1,0).
        device: Unused (kept for signature parity).

    Returns:
        torch.Tensor shape (3,3) with rows [right, up, forward].
    """
    if reference_vector_3d is None:
        reference_vector_3d = torch.tensor(
            [0.0, 1.0, 0.0],
            dtype=forward_vector_3d.dtype,
            device=forward_vector_3d.device,
        )

    right_vector_3d = torch.cross(forward_vector_3d, reference_vector_3d, dim=0)
    right_vector_3d /= right_vector_3d.norm()

    up_vector_3d = torch.cross(right_vector_3d, forward_vector_3d, dim=0)
    up_vector_3d /= up_vector_3d.norm()

    basis = torch.stack(
        [
            right_vector_3d,
            up_vector_3d,
            forward_vector_3d,
        ],
        dim=0,
    )
    return basis


def _compute_model_view(
    camera_pos: torch.Tensor, camera_target_pos: torch.Tensor, device
):
    """Build a 4x4 model-view matrix from camera position and target.

    Args:
        camera_pos: Tensor shape (3,) camera world position.
        camera_target_pos: Tensor shape (3,) target look-at position.
        device: Torch device for the returned matrix.

    Returns:
        torch.Tensor shape (4,4) model-view matrix.
    """
    forward = camera_target_pos - camera_pos
    forward /= torch.norm(forward, dim=-1)
    rot = _compute_orthonormal_basis(forward, device=device)
    pos_proj = rot @ camera_pos.unsqueeze(-2).transpose(-1, -2)

    model_view = torch.eye(4, dtype=rot.dtype, device=device)
    model_view[:3, :3] = rot
    model_view[:3, 3] = -pos_proj.squeeze(-1)
    return model_view


@frame_cache
def read_model_view(emu: DeSmuME, device):
    """Compute and cache the camera model-view matrix for the current frame.

    Args:
        emu: Emulator instance.
        device: Torch device for returned matrix.

    Returns:
        torch.Tensor shape (4,4) model-view matrix.
    """
    camera_pos = read_camera_position(emu, device=device)
    camera_target_pos = read_camera_target_position(emu, device=device)
    return _compute_model_view(camera_pos, camera_target_pos, device=device)


def _project_to_screen(
    world_points, model_view, fov, aspect, screen_dim: tuple[int, int], device=None
):
    """Project world-space points to screen coordinates using perspective projection.

    Args:
        world_points: Tensor shape (N,3) of world-space points.
        model_view: Tensor shape (4,4) model-view matrix.
        fov: Field-of-view in radians (vertical half-angle usage within projection).
        aspect: Aspect ratio (width/height).
        screen_dim: Viewport size in pixels as (width, height).
        device: Torch device for intermediate/return tensors.

    Returns:
        Tensor shape (N,4): [x_px, y_px, clip_z, normalized_depth], where x/y are
        in pixel coordinates for the `screen_dim` viewport.
    """
    N = world_points.shape[0]
    if N == 0:
        # Keep the documented (N, 4) layout even when there is nothing to project.
        return torch.empty((0, 4), device=device)

    # Homogenize points
    ones = torch.ones((N, 1), device=device)
    world_points = torch.cat([world_points, ones], dim=-1)
    cam_space = (model_view @ world_points.T).T

    # Perspective projection
    fov_h = math.tan(fov)
    fov_w = math.tan(fov) * aspect

    projection_matrix = torch.zeros((4, 4), device=device)
    projection_matrix[0, 0] = 1 / fov_w
    projection_matrix[1, 1] = 1 / fov_h
    projection_matrix[2, 2] = (Z_FAR + Z_NEAR) / (Z_NEAR - Z_FAR)
    projection_matrix[2, 3] = -(2 * Z_FAR * Z_NEAR) / (Z_NEAR - Z_FAR)
    projection_matrix[3, 2] = 1

    clip_space = (projection_matrix @ cam_space.T).T
    ndc = clip_space[:, :3] / clip_space[:, 3, None]

    # Viewport transform: origin at the top-left, Y growing downward.
    screen_width, screen_height = screen_dim
    screen_x = (ndc[:, 0] + 1) / 2 * screen_width
    screen_y = (1 - ndc[:, 1]) / 2 * screen_height
    screen_depth = clip_space[:, 2]
    screen_depth_norm = -Z_FAR / (-Z_FAR + Z_SCALE * clip_space[:, 2])
    return torch.stack([screen_x, screen_y, screen_depth, screen_depth_norm], dim=-1)


def project_to_screen(
    emu: DeSmuME,
    points: torch.Tensor,
    device,
    screen_dim=(SCREEN_WIDTH, SCREEN_HEIGHT),
):
    """Convenience wrapper to project points using the current camera state.

    Args:
        emu: Emulator instance.
        points: Tensor shape (N,3) of world-space points.
        device: Torch device.

    Returns:
        Tensor shape (N,4) in screen space (see `_project_to_screen`).
    """
    model_view = read_model_view(emu, device=device)
    fov = read_camera_fov(emu)
    aspect = read_camera_aspect(emu)
    return _project_to_screen(points, model_view, fov, aspect, screen_dim, device=device)


# CHECKPOINT INFO #


@game_cache
def read_checkpoint_ptr(emu: DeSmuME, addr: int = CHECKPOINT_PTR_ADDR):
    """Read the pointer to the checkpoint manager/state.

    Args:
        emu: Emulator instance.
        addr: Address where the checkpoint pointer is stored.

    Returns:
        Integer address for checkpoint data.
    """
    return emu.memory.unsigned.read_long(addr)


@frame_cache
def read_current_checkpoint(emu: DeSmuME):
    """Read the index of the current checkpoint.

    Args:
        emu: Emulator instance.

    Returns:
        Unsigned byte checkpoint index.
    """
    addr = read_checkpoint_ptr(emu)
    return emu.memory.unsigned.read_byte(addr + 0x46)


@frame_cache
def read_current_key_checkpoint(emu: DeSmuME):
    """Read the current key checkpoint index (special/lap-related).

    Args:
        emu: Emulator instance.

    Returns:
        Signed byte key checkpoint index.
    """
    addr = read_checkpoint_ptr(emu)
    return emu.memory.signed.read_byte(addr + 0x48)


@frame_cache
def read_ghost_checkpoint(emu: DeSmuME):
    """Read the recorded ghost's current checkpoint index.

    Args:
        emu: Emulator instance.

    Returns:
        Signed byte ghost checkpoint index.
    """
    addr = read_checkpoint_ptr(emu)
    return emu.memory.signed.read_byte(addr + 0xD2)


@frame_cache
def read_ghost_key_checkpoint(emu: DeSmuME):
    """Read the recorded ghost's current key checkpoint index.

    Args:
        emu: Emulator instance.

    Returns:
        Signed byte ghost key checkpoint index.
    """
    addr = read_checkpoint_ptr(emu)
    return emu.memory.signed.read_byte(addr + 0xD4)


@frame_cache
def read_current_lap(emu: DeSmuME):
    """Read the current lap number.

    Args:
        emu: Emulator instance.

    Returns:
        Signed byte lap index (0-based).
    """
    addr = read_checkpoint_ptr(emu)
    return emu.memory.signed.read_byte(addr + 0x38)


@frame_cache
def read_next_checkpoint(emu: DeSmuME, checkpoint_count: int):
    """Compute the next checkpoint index (wrapping to 0 at the end).

    Args:
        emu: Emulator instance.
        checkpoint_count: Total number of checkpoints.

    Returns:
        Integer index of the next checkpoint.
    """
    current_checkpoint = read_current_checkpoint(emu)
    next_checkpoint = current_checkpoint + 1
    if next_checkpoint != checkpoint_count:
        return next_checkpoint
    else:
        return 0


def _convert_2d_checkpoints(P: torch.Tensor, source: torch.Tensor, dim=0):
    """Lift 2D checkpoint endpoints into 3D by sampling the missing dimension.

    Given 2D endpoints (e.g., XZ) and a set of floor points, fills in the missing
    axis by nearest-neighbor on that axis.

    Args:
        P: Tensor shape (N,2) endpoints (with one missing dimension).
        source: Tensor shape (M,3) of reference points (e.g., floor vertices).
        dim: Dimension index (0/1/2) to fill in.

    Returns:
        Tensor shape (N,3) with the missing axis populated.
    """
    # Boolean mask selecting every axis except `dim`. `torch.arange` replaces the
    # deprecated `torch.range` and keeps the mask on the same device as `source`.
    dim_mask = torch.arange(source.shape[1], device=source.device) != dim

    # Nearest floor point for each endpoint, measured on the known axes only.
    D = pairwise_distances_cross(P, source[:, dim_mask])
    min_idx = D.argmin(dim=1)

    result = torch.ones(P.shape[0], P.shape[1] + 1, device=P.device)
    result[:, dim_mask] = P
    result[:, dim] = source[min_idx, dim]
    return result
[docs] @game_cache def read_checkpoint_positions(emu: DeSmuME, device): """Build a tensor of checkpoint segment endpoints in 3D. Reads NKM and KCL, extracts floor geometry, and converts checkpoint pairs from 2D to 3D using nearest floor elevation. Args: emu: Emulator instance. device: Torch device. Returns: Tensor shape (C, 2, 3) where C is number of checkpoints, containing [p1, p2] endpoints per checkpoint. """ nkm = load_current_nkm(emu, device=device) kcl = load_current_kcl(emu, device=device) floor_mask = kcl.prisms.is_floor == 1 floor_points = kcl.triangles[floor_mask] floor_points = floor_points.reshape(floor_points.shape[0] * 3, 3) return torch.stack( [ _convert_2d_checkpoints(nkm._CPOI.position1, floor_points, dim=1), _convert_2d_checkpoints(nkm._CPOI.position2, floor_points, dim=1), ], dim=1, )
@frame_cache
def read_next_checkpoint_position(emu: DeSmuME, device):
    """Get the 3D endpoints of the next checkpoint segment.

    Args:
        emu: Emulator instance.
        device: Torch device.

    Returns:
        Tensor shape (2,3) representing the next checkpoint's [p1, p2].
    """
    nkm = load_current_nkm(emu, device=device)
    checkpoints = read_checkpoint_positions(emu, device)
    checkpoint_count = nkm._CPOI.entry_count
    next_checkpoint = read_next_checkpoint(emu, checkpoint_count)
    return checkpoints[next_checkpoint]

@frame_cache
def read_current_checkpoint_position(emu: DeSmuME, device):
    """Get the 3D endpoints of the current checkpoint segment.

    Args:
        emu: Emulator instance.
        device: Torch device.

    Returns:
        Tensor shape (2,3) representing the current checkpoint's [p1, p2].
    """
    checkpoints = read_checkpoint_positions(emu, device=device)
    current_checkpoint = read_current_checkpoint(emu)
    return checkpoints[current_checkpoint]

@frame_cache
def read_facing_point_checkpoint(emu: DeSmuME, direction: torch.Tensor, device):
    """Raycast from the player along a direction to the next checkpoint line (XZ).

    The intersection is computed in the XZ plane only; the Y component of the
    result is copied from the player's position.

    Args:
        emu: Emulator instance.
        direction: Tensor shape (3,) direction vector.
        device: Torch device.

    Returns:
        Tensor shape (3,) point of intersection in world coordinates.
    """
    position = read_position(emu, device=device)
    checkpoint = read_next_checkpoint_position(emu, device)

    # Project everything onto the XZ plane (indices 0 and 2).
    mask_xz = torch.tensor([0, 2], dtype=torch.int32, device=device)
    pos_xz = position[mask_xz]
    dir_xz = direction[mask_xz]
    pxz_1, pxz_2 = checkpoint[:, mask_xz].chunk(2, dim=0)
    pxz_1 = pxz_1.squeeze(0)
    pxz_2 = pxz_2.squeeze(0)

    intersect, _ = intersect_ray_line_2d(pos_xz, dir_xz, pxz_1, pxz_2)

    # Rebuild a 3D point, reusing the player's height for Y.
    intersect = torch.tensor([intersect[0], position[1], intersect[1]], device=device)
    return intersect

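# Illustrative sketch, not part of the module: one common way to intersect a
# 2D ray with the infinite line through two points, using the 2D cross
# product. The module's intersect_ray_line_2d is not shown here, so its exact
# behavior (parallel rays, hits behind the origin, the meaning of its second
# return value) may differ; the names below are hypothetical.

```python
import torch


def cross_2d(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Scalar z-component of the cross product of two 2D vectors."""
    return a[0] * b[1] - a[1] * b[0]


def ray_line_intersection_2d_sketch(origin, direction, p1, p2, eps=1e-8):
    """Hypothetical helper: return (point, t) with point = origin + t * direction.

    Returns (None, None) when the ray is parallel to the line or the hit
    lies behind the ray origin.
    """
    edge = p2 - p1
    denom = cross_2d(direction, edge)
    if torch.abs(denom) < eps:
        return None, None  # ray is parallel to the line
    t = cross_2d(p1 - origin, edge) / denom
    if t < 0:
        return None, None  # the line is behind the ray origin
    return origin + t * direction, t


origin = torch.tensor([0.0, 0.0])
p1 = torch.tensor([-1.0, 2.0])
p2 = torch.tensor([1.0, 2.0])
point, t = ray_line_intersection_2d_sketch(origin, torch.tensor([0.0, 1.0]), p1, p2)
miss, _ = ray_line_intersection_2d_sketch(origin, torch.tensor([1.0, 0.0]), p1, p2)
```

# Here the ray pointing along +Z meets the line z = 2 at (0, 2) with t = 2,
# while the ray pointing along +X is parallel to the line and reports no hit.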
@frame_cache
def read_forward_distance_checkpoint(emu, device):
    """Compute forward distance from player to the next checkpoint line.

    Args:
        emu: Emulator instance.
        device: Torch device.

    Returns:
        Scalar torch.Tensor distance.
    """
    direction = read_direction(emu, device=device)
    position = read_position(emu, device=device)
    ray_point = read_facing_point_checkpoint(emu, direction, device=device)
    return torch.norm(ray_point - position)

@frame_cache
def read_left_distance_checkpoint(emu, device):
    """Compute leftward distance from player to the next checkpoint line.

    Args:
        emu: Emulator instance.
        device: Torch device.

    Returns:
        Scalar torch.Tensor distance.
    """
    direction = read_direction(emu, device=device)
    position = read_position(emu, device=device)

    # Left is taken as direction x (-up); `dim` is passed explicitly because
    # calling torch.cross without it is deprecated.
    up_basis = -torch.tensor([0, 1.0, 0], device=device, dtype=torch.float32)
    left_basis = torch.cross(direction, up_basis, dim=-1)

    ray_point = read_facing_point_checkpoint(emu, left_basis, device=device)
    return torch.norm(ray_point - position)

@frame_cache
def read_direction_to_checkpoint(emu: DeSmuME, device):
    """Compute a steering angle toward the next checkpoint from forward/left distances.

    Angle is computed as atan(forward / left).

    Args:
        emu: Emulator instance.
        device: Torch device.

    Returns:
        Scalar torch.Tensor angle in radians.
    """
    forward = read_forward_distance_checkpoint(emu, device=device)
    left = read_left_distance_checkpoint(emu, device=device)
    angle = torch.atan(forward / left)
    return angle

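# Illustrative sketch, not part of the module: atan(forward / left) yields NaN
# when both distances are zero. For non-negative inputs, torch.atan2 gives the
# same angle wherever the quotient is defined and stays finite otherwise.
# `steering_angle_sketch` is a hypothetical name, and whether the module
# should switch to atan2 is a design choice left open here.

```python
import torch


def steering_angle_sketch(forward: torch.Tensor, left: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: angle in radians, well defined even when left == 0."""
    return torch.atan2(forward, left)


one = torch.tensor(1.0)
zero = torch.tensor(0.0)

same_as_atan = steering_angle_sketch(one, one)  # equals atan(1 / 1)
left_is_zero = steering_angle_sketch(one, zero)  # quarter turn, no division
both_zero = steering_angle_sketch(zero, zero)    # 0 rather than NaN
naive_both_zero = torch.atan(zero / zero)        # NaN from 0 / 0
```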
# OBSTACLE INFO #
@frame_cache
def read_facing_point_obstacle(
    emu: DeSmuME,
    position: torch.Tensor | None = None,
    direction: torch.Tensor | None = None,
    device=None,
):
    """Raycast toward walls/offroad and return the nearest hit point.

    Samples a cone of directions around the provided (or player) direction,
    and finds the nearest intersection against wall and offroad triangles.

    Args:
        emu: Emulator instance.
        position: Optional world position (3,). Defaults to player's position.
        direction: Optional direction (3,). Defaults to player's forward vector.
        device: Torch device.

    Returns:
        torch.Tensor shape (3,) hit point, or None if no intersections.
    """
    assert device is not None
    kcl = load_current_kcl(emu, device=device)
    triangles = kcl.triangles

    # Select wall triangles plus offroad collision types (2, 3, and 5).
    wall_mask = kcl.prisms.is_wall == 1
    offroad_mask = (
        (kcl.prisms.collision_type == 5)
        | (kcl.prisms.collision_type == 3)
        | (kcl.prisms.collision_type == 2)
    )
    triangles = triangles[wall_mask | offroad_mask]
    B = triangles.shape[0]
    if B == 0:
        return None

    # Split each triangle into its three vertices.
    v1, v2, v3 = triangles.chunk(3, dim=1)
    v1 = v1.squeeze(1)
    v2 = v2.squeeze(1)
    v3 = v3.squeeze(1)

    if position is None:
        position = read_position(emu, device=device)
    if direction is None:
        direction = read_direction(emu, device=device)

    # Central ray plus 50 samples in a cone of half-angle pi / 24 around it.
    ray_dir = direction / torch.norm(direction, keepdim=True)
    angle = torch.tensor(torch.pi / 24, device=device)
    ray_samples = sample_cone(ray_dir, angle, 50)
    ray_dir = ray_dir.reshape(1, 3)
    ray_samples = torch.cat([ray_dir, ray_samples], dim=0)

    # All rays share the same origin.
    ray_origin = position.reshape(1, 3)
    ray_origin_samples = ray_origin.repeat(ray_samples.shape[0], 1)

    points = triangle_raycast_batch(ray_origin_samples, ray_samples, v1, v2, v3)
    N, M, C = points.shape
    points = points.reshape(N * M, C)
    if points.shape[0] == 0:
        return None

    # Keep the hit closest to the ray origin.
    dist = torch.cdist(points, ray_origin)
    min_id = torch.argmin(dist)
    current_point_min = points[min_id]
    return current_point_min

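# Illustrative sketch, not part of the module: a single-ray version of the
# "nearest hit against a triangle set" step, written with the standard
# Möller-Trumbore ray/triangle test. The module's triangle_raycast_batch and
# sample_cone are not shown here and may work differently (for example, in
# how misses are represented); the name below is hypothetical.

```python
import torch


def ray_triangles_nearest_hit_sketch(origin, direction, v1, v2, v3, eps=1e-8):
    """Hypothetical helper: nearest hit of one ray against T triangles.

    origin, direction: tensors of shape (3,). v1, v2, v3: tensors of shape (T, 3).
    Returns the hit point of shape (3,), or None if the ray misses everything.
    """
    e1 = v2 - v1
    e2 = v3 - v1
    d = direction.expand_as(e1)

    pvec = torch.linalg.cross(d, e2, dim=-1)
    det = (e1 * pvec).sum(dim=-1)
    parallel = det.abs() < eps
    # Avoid dividing by ~0; parallel triangles are masked out below.
    inv_det = 1.0 / torch.where(parallel, torch.ones_like(det), det)

    tvec = origin - v1
    u = (tvec * pvec).sum(dim=-1) * inv_det   # first barycentric coordinate
    qvec = torch.linalg.cross(tvec, e1, dim=-1)
    v = (d * qvec).sum(dim=-1) * inv_det      # second barycentric coordinate
    t = (e2 * qvec).sum(dim=-1) * inv_det     # distance along the ray

    hit = (~parallel) & (u >= 0) & (v >= 0) & (u + v <= 1) & (t > eps)
    if not bool(hit.any()):
        return None
    t = torch.where(hit, t, torch.full_like(t, float("inf")))
    return origin + t.min() * direction


# Two parallel triangles straddling the Z axis, at z = 9 and z = 5.
v1 = torch.tensor([[-1.0, -1.0, 9.0], [-1.0, -1.0, 5.0]])
v2 = torch.tensor([[1.0, -1.0, 9.0], [1.0, -1.0, 5.0]])
v3 = torch.tensor([[0.0, 1.0, 9.0], [0.0, 1.0, 5.0]])
origin = torch.zeros(3)

nearest = ray_triangles_nearest_hit_sketch(origin, torch.tensor([0.0, 0.0, 1.0]), v1, v2, v3)
behind = ray_triangles_nearest_hit_sketch(origin, torch.tensor([0.0, 0.0, -1.0]), v1, v2, v3)
```

# The forward ray crosses both triangles and the closer one (z = 5) wins; the
# backward ray only finds intersections at negative t, so it reports no hit.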
@frame_cache
def read_forward_distance_obstacle(emu: DeSmuME, device) -> torch.Tensor:
    """Compute forward distance to the nearest wall/offroad obstacle.

    Args:
        emu: Emulator instance.
        device: Torch device.

    Returns:
        Scalar torch.Tensor distance; +inf if no hit.
    """
    position = read_position(emu, device=device)
    ray_point = read_facing_point_obstacle(emu, device=device)
    if ray_point is None:
        # Return a 0-dim tensor so the shape matches the hit case.
        return torch.tensor(float("inf"), device=device)
    dist = torch.sqrt(torch.sum((position - ray_point) ** 2, dim=0))
    return dist

@frame_cache
def read_left_distance_obstacle(emu: DeSmuME, device) -> torch.Tensor:
    """Compute leftward distance to the nearest wall/offroad obstacle.

    Args:
        emu: Emulator instance.
        device: Torch device.

    Returns:
        Scalar torch.Tensor distance; +inf if no hit.
    """
    position = read_position(emu, device=device)
    direction = read_direction(emu, device=device)

    # Left is taken as direction x (-up); `dim` is passed explicitly because
    # calling torch.cross without it is deprecated.
    up_basis = -torch.tensor([0, 1.0, 0], device=device, dtype=torch.float32)
    left_basis = torch.cross(direction, up_basis, dim=-1)

    ray_point = read_facing_point_obstacle(emu, direction=left_basis, device=device)
    if ray_point is None:
        # Return a 0-dim tensor so the shape matches the hit case.
        return torch.tensor(float("inf"), device=device)
    dist = torch.sqrt(torch.sum((position - ray_point) ** 2, dim=0))
    return dist

@frame_cache
def read_right_distance_obstacle(emu: DeSmuME, device) -> torch.Tensor:
    """Compute rightward distance to the nearest wall/offroad obstacle.

    Args:
        emu: Emulator instance.
        device: Torch device.

    Returns:
        Scalar torch.Tensor distance; +inf if no hit.
    """
    position = read_position(emu, device=device)
    direction = read_direction(emu, device=device)

    # Right is taken as direction x up; `dim` is passed explicitly because
    # calling torch.cross without it is deprecated.
    up_basis = torch.tensor([0, 1.0, 0], device=device, dtype=torch.float32)
    right_basis = torch.cross(direction, up_basis, dim=-1)

    ray_point = read_facing_point_obstacle(emu, direction=right_basis, device=device)
    if ray_point is None:
        # Return a 0-dim tensor so the shape matches the hit case.
        return torch.tensor(float("inf"), device=device)
    dist = torch.sqrt(torch.sum((position - ray_point) ** 2, dim=0))
    return dist

@frame_cache
def read_checkpoint_distance_altitude(emu: DeSmuME, device) -> torch.Tensor:
    """Compute the altitude (height) of the triangle formed by player and checkpoint endpoints.

    Uses the two checkpoint endpoints and the player's position to form sides
    a and b, then returns the triangle altitude via `triangle_altitude(a, b)`.

    Args:
        emu: Emulator instance.
        device: Torch device.

    Returns:
        Scalar torch.Tensor altitude value.
    """
    next_checkpoint = read_next_checkpoint_position(emu, device=device)
    p1, p2 = next_checkpoint.chunk(2, dim=0)
    position = read_position(emu, device=device)

    # Side lengths from the player to each checkpoint endpoint.
    a = torch.norm(p1 - position)
    b = torch.norm(p2 - position)
    return triangle_altitude(a, b)

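# Illustrative sketch, not part of the module: the altitude of a triangle
# dropped onto side c can be computed from the three side lengths with Heron's
# formula, h = 2 * area / c. The module's triangle_altitude takes only a and b
# and is not shown here, so how it obtains the base is unknown; this
# hypothetical helper takes the base length c as an explicit third argument.

```python
import torch


def triangle_altitude_sketch(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: altitude onto side c of a triangle with sides a, b, c."""
    s = (a + b + c) / 2  # semi-perimeter
    # Clamp guards against tiny negative values from rounding on thin triangles.
    area_sq = torch.clamp(s * (s - a) * (s - b) * (s - c), min=0.0)
    return 2.0 * torch.sqrt(area_sq) / c


# 3-4-5 right triangle: area is 6, so the altitude onto the hypotenuse is 12 / 5.
h = triangle_altitude_sketch(torch.tensor(3.0), torch.tensor(4.0), torch.tensor(5.0))
```

# In the checkpoint setting, a and b would be the player-to-endpoint distances
# and c the length of the checkpoint segment, which makes h the perpendicular
# distance from the player to the checkpoint line.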