from __future__ import annotations
import torch
import torch.nn.functional as F
import math
#from utils.emulator import SCREEN_WIDTH, SCREEN_HEIGHT
# Output screen size used by project_to_screen to map normalized device
# coordinates to screen coordinates. Defined locally instead of through the
# commented-out utils.emulator import above.
SCREEN_WIDTH = 256
SCREEN_HEIGHT = 192
[docs]
def get_mps_device() -> torch.device:
    """Return the Apple MPS device, or terminate the process if it is unavailable.

    When MPS cannot be used, a message explaining why (PyTorch built without
    MPS vs. unsupported OS/hardware) is printed. Then ``SystemExit(1)`` is
    raised.

    Returns:
        ``torch.device("mps")``.

    Raises:
        SystemExit: with exit status 1 when MPS is not available.
    """
    if not torch.backends.mps.is_available():
        if not torch.backends.mps.is_built():
            print(
                "MPS not available because the current PyTorch install was not "
                "built with MPS enabled."
            )
        else:
            print(
                "MPS not available because the current MacOS version is not 12.3+ "
                "and/or you do not have an MPS-enabled device on this machine."
            )
        # The bare `exit()` builtin is injected by the `site` module (missing
        # under `python -S`). It also exits with status 0, which hides the
        # failure from the shell. Raise SystemExit with a non-zero code instead.
        raise SystemExit(1)
    return torch.device("mps")
def _determinant_2d(x0, y0, x1, y1):
    """Return the 2x2 determinant of rows (x0, y0) and (x1, y1).

    That is ``x0 * y1 - y0 * x1``, evaluated elementwise for scalars or
    broadcastable tensors.
    """
    main_diagonal = x0 * y1
    anti_diagonal = y0 * x1
    return main_diagonal - anti_diagonal
[docs]
def cross_product_2d(p0: torch.Tensor, p1: torch.Tensor) -> torch.Tensor:
    """Return the scalar (z-component) cross product of 2D vectors p0 and p1.

    The last axis of each input is read as (x, y); any leading axes are
    combined elementwise.
    """
    ax, ay = p0[..., 0], p0[..., 1]
    bx, by = p1[..., 0], p1[..., 1]
    return _determinant_2d(ax, ay, bx, by)
[docs]
def triangle_raycast(
    ray_origin: torch.Tensor,
    ray_dir: torch.Tensor,
    v1: torch.Tensor,
    v2: torch.Tensor,
    v3: torch.Tensor,
    epsilon=1e-8
) -> torch.Tensor | None:
    """Intersect one ray with one triangle using Möller–Trumbore.

    Args:
        ray_origin: ray origin, a 3-vector.
        ray_dir: ray direction, a 3-vector (need not be unit length).
        v1, v2, v3: triangle vertices, 3-vectors.
        epsilon: tolerance used for the parallel test, for the barycentric
            bounds (points up to ``epsilon`` outside an edge still count), and
            as the minimum accepted ray parameter ``t``.

    Returns:
        The hit point ``ray_origin + t * ray_dir`` when the ray hits the
        triangle with ``t > epsilon``; otherwise ``None``.
    """
    edge1 = v2 - v1
    edge2 = v3 - v1
    ray_cross_e2 = torch.linalg.cross(ray_dir, edge2, dim=-1)
    det = torch.dot(edge1, ray_cross_e2)
    # Ray (nearly) parallel to the triangle plane.
    if torch.abs(det) < epsilon:
        return None
    inv_det = 1.0 / det
    s = ray_origin - v1
    u = inv_det * torch.dot(s, ray_cross_e2)
    if u < -epsilon or u - 1 > epsilon:
        return None
    # Möller–Trumbore needs the *cross* product here; the dot product yields a
    # scalar and makes the next torch.dot fail.
    s_cross_e1 = torch.linalg.cross(s, edge1, dim=-1)
    v = inv_det * torch.dot(ray_dir, s_cross_e1)
    if v < -epsilon or u + v - 1 > epsilon:
        return None
    # Ray parameter of the intersection point.
    t = inv_det * torch.dot(edge2, s_cross_e1)
    if t <= epsilon:
        return None
    return ray_origin + ray_dir * t
import torch
[docs]
def intersect_ray_line_2d(O, D, P1, P2, eps=1e-8):
    """
    Find the intersection of a ray ``O + t*D`` (t >= 0) with the line segment
    between P1 and P2.

    All inputs are torch tensors of shape (..., 2); leading axes broadcast,
    so the function can be batched.

    Returns:
        intersection: (..., 2) point ``O + t*D`` where the ray's line meets
            the segment's line. Only meaningful where ``valid`` is True.
        valid: (...,) bool tensor, True when the ray and segment are not
            (nearly) parallel (``|cross(D, P2 - P1)| > eps``), ``t >= 0`` and
            the hit lies on the segment (``0 <= u <= 1``).
    """
    # Direction of the segment
    v = P2 - P1

    # 2D cross product helper (scalar z-component)
    def cross(a, b):
        return a[..., 0] * b[..., 1] - a[..., 1] * b[..., 0]

    denom = cross(D, v)  # zero when the ray is parallel to the segment
    non_parallel = denom.abs() > eps
    # Divide by the true determinant wherever it is usable. Adding eps to it
    # would bias t and u, badly so when |denom| is only a few eps. For
    # (near-)parallel entries substitute 1 so no inf/NaN is produced; those
    # entries are rejected through `valid`.
    safe_denom = torch.where(non_parallel, denom, torch.ones_like(denom))
    w = P1 - O
    t = cross(w, v) / safe_denom  # position along the ray
    u = cross(w, D) / safe_denom  # position along the segment (0..1)
    intersection = O + t.unsqueeze(-1) * D
    valid = non_parallel & (t >= 0) & (u >= 0) & (u <= 1)
    return intersection, valid
[docs]
def triangle_raycast_batch(
    ray_origin: torch.Tensor,  # (N, 3)
    ray_dir: torch.Tensor,  # (N, 3)
    v1: torch.Tensor,  # (M, 3)
    v2: torch.Tensor,  # (M, 3)
    v3: torch.Tensor,  # (M, 3)
    epsilon=1e-8,
    fill_invalid: bool = False,
) -> torch.Tensor:
    """
    Batched ray-triangle intersection using Möller–Trumbore.

    Every ray is tested against every triangle.

    Args:
        ray_origin: (N, 3) ray origins.
        ray_dir: (N, 3) ray directions. They need not be unit length; ``t`` is
            then measured in multiples of each direction's length.
        v1, v2, v3: (M, 3) triangle vertices.
        epsilon: tolerance for the parallel test and the minimum hit
            distance ``t``.
        fill_invalid: selects the result layout (see Returns). Defaults to the
            compact layout for backward compatibility.

    Returns:
        If ``fill_invalid`` is False: a (1, K, 3) tensor holding only the K
        hit points, in row-major (ray, triangle) order. Which ray/triangle
        produced each hit is not reported.
        If ``fill_invalid`` is True: an (N, M, 3) tensor whose entry [i, j]
        is the hit of ray i with triangle j, or NaN if they do not intersect.
    """
    # Broadcast rays against triangles: rays along axis 0, triangles along axis 1.
    ro = ray_origin[:, None, :]  # (N, 1, 3)
    rd = ray_dir[:, None, :]  # (N, 1, 3)
    p1 = v1[None, :, :]  # (1, M, 3)
    p2 = v2[None, :, :]
    p3 = v3[None, :, :]
    edge1 = p2 - p1
    edge2 = p3 - p1
    h = torch.linalg.cross(rd, edge2, dim=-1)  # (N, M, 3)
    a = torch.sum(edge1 * h, dim=-1)  # (N, M)
    mask_parallel = a.abs() < epsilon
    # Parallel pairs are discarded below; divide by 1 there so no inf/NaN is
    # created. Non-parallel pairs divide by their exact determinant.
    f = 1.0 / torch.where(mask_parallel, torch.ones_like(a), a)
    s = ro - p1
    u = f * torch.sum(s * h, dim=-1)
    mask_u = (u < 0) | (u > 1)
    q = torch.linalg.cross(s, edge1, dim=-1)
    v = f * torch.sum(rd * q, dim=-1)
    mask_v = (v < 0) | (u + v > 1)
    t = f * torch.sum(edge2 * q, dim=-1)
    mask_t = t <= epsilon
    # Valid hits
    valid = ~(mask_parallel | mask_u | mask_v | mask_t)
    # Intersection points
    pts = ro + rd * t[..., None]  # (N, M, 3)
    if fill_invalid:
        return pts.masked_fill(~valid[..., None], float("nan"))
    # Compact layout: boolean-mask the (N, M) grid, keeping a leading axis of 1.
    return pts[None, valid]
[docs]
def pairwise_distances(pts):
    """Return Euclidean distances between every pair of rows of ``pts``.

    Uses the expansion ``|a - b|^2 = |a|^2 + |b|^2 - 2 a.b``. Small negative
    squared distances caused by round-off are clamped to zero before the
    square root.

    Args:
        pts: (N, D) tensor.

    Returns:
        (N, N) tensor of distances.
    """
    sq_norms = pts.pow(2).sum(dim=-1)
    gram = pts @ pts.T
    sq_dist = sq_norms.unsqueeze(1) + sq_norms.unsqueeze(0) - 2 * gram
    return sq_dist.clamp(min=0.0).sqrt()
[docs]
def pairwise_distances_cross(A, B):
    """Return Euclidean distances between each row of A and each row of B.

    Args:
        A: (n, d) tensor.
        B: (m, d) tensor.

    Returns:
        (n, m) tensor. Squared distances are clamped at zero before the
        square root to absorb round-off.
    """
    cross_term = A @ B.T
    a_sq = A.pow(2).sum(dim=1)[:, None]  # (n, 1)
    b_sq = B.pow(2).sum(dim=1)[None, :]  # (1, m)
    return torch.sqrt((a_sq + b_sq - 2 * cross_term).clamp(min=0))
[docs]
def compute_orthonormal_basis(forward_vector_3d, reference_vector_3d=None, device=None):
    """Build an orthonormal [right, up, forward] basis from a forward direction.

    Args:
        forward_vector_3d: (3,) viewing direction. It is normalized here, so
            any non-zero length works.
        reference_vector_3d: (3,) approximate "up" direction used to derive the
            right vector; defaults to +Y. If it is (anti)parallel to forward,
            the coordinate axis least aligned with forward is used instead, so
            the result stays finite.
        device: unused; kept for backward compatibility.

    Returns:
        (3, 3) tensor whose rows are the unit vectors right, up and forward,
        with right = forward x reference and up = right x forward.
    """
    forward = forward_vector_3d / forward_vector_3d.norm()
    if reference_vector_3d is None:
        reference_vector_3d = torch.tensor([0., 1., 0.],
                                           dtype=forward.dtype,
                                           device=forward.device)
    right = torch.cross(forward, reference_vector_3d, dim=0)
    right_norm = right.norm()
    if right_norm <= torch.finfo(right.dtype).eps:
        # forward is (anti)parallel to the reference, so the cross product
        # vanishes and normalizing it would give NaN. Use the coordinate axis
        # along forward's smallest component, which cannot be parallel to it.
        fallback = torch.zeros_like(forward)
        fallback[forward.abs().argmin()] = 1.0
        right = torch.cross(forward, fallback, dim=0)
        right_norm = right.norm()
    right = right / right_norm
    up = torch.cross(right, forward, dim=0)
    up = up / up.norm()
    return torch.stack([right, up, forward], dim=0)
[docs]
def compute_model_view(camera_pos, camera_target_pos, device=None):
    """Build a 4x4 world-to-camera matrix for a camera looking at a target.

    Args:
        camera_pos: (3,) camera position.
        camera_target_pos: (3,) point the camera looks at. It must differ from
            ``camera_pos``, otherwise the direction is undefined (NaN).
        device: device of the returned matrix; defaults to ``camera_pos.device``
            so the result lives with its inputs.

    Returns:
        (4, 4) matrix ``[[R, -R @ camera_pos], [0, 0, 0, 1]]``, where the rows
        of R are the camera's right, up and forward axes from
        ``compute_orthonormal_basis``.
    """
    if device is None:
        device = camera_pos.device
    forward = camera_target_pos - camera_pos
    forward = forward / torch.norm(forward, dim=-1)
    rot = compute_orthonormal_basis(forward, device=device)
    model_view = torch.eye(4, dtype=rot.dtype, device=device)
    model_view[:3, :3] = rot
    # Translate so the camera sits at the origin of camera space.
    model_view[:3, 3] = -(rot @ camera_pos)
    return model_view
[docs]
def project_to_screen(world_points, model_view, fov, aspect, far, near, z_scale, device=None):
    """Project world-space points into screen coordinates.

    Args:
        world_points: (N, 3) points in world space.
        model_view: (4, 4) world-to-camera matrix (e.g. from compute_model_view).
        fov: angle in radians. The vertical half-extent of the view is
            ``tan(fov)``, not ``tan(fov / 2)``. NOTE(review): this suggests
            ``fov`` is a half-angle; confirm with callers.
        aspect: width / height ratio (horizontal extent = vertical * aspect).
        far, near: clip-plane distances used in the projection matrix.
        z_scale: scale applied to clip-space z in the normalized depth.
        device: device of the homogeneous coordinates; defaults to
            ``world_points.device``.

    Returns:
        (N, 4) tensor with columns
        ``[screen_x, screen_y, clip_z, -far / (-far + z_scale * clip_z)]``.
        screen_x/screen_y map NDC [-1, 1] onto [0, SCREEN_WIDTH] and
        [SCREEN_HEIGHT, 0] (y flipped). Points with clip w == 0 give inf/NaN.
        An empty input yields an empty (0, 4) tensor.
    """
    if device is None:
        device = world_points.device
    n = world_points.shape[0]
    # Homogenize points
    ones = torch.ones((n, 1), device=device)
    homogeneous = torch.cat([world_points, ones], dim=-1)
    cam_space = (model_view @ homogeneous.T).T
    if cam_space.shape[0] == 0:
        # Same column count as the non-empty result below.
        return torch.empty((0, 4), dtype=cam_space.dtype, device=cam_space.device)
    half_h = math.tan(fov)
    half_w = half_h * aspect
    # Build the projection in cam_space's dtype/device so the matmul below
    # also works for float64 inputs.
    projection = torch.zeros((4, 4), dtype=cam_space.dtype, device=cam_space.device)
    projection[0, 0] = 1 / half_w
    projection[1, 1] = 1 / half_h
    projection[2, 2] = (far + near) / (near - far)
    projection[2, 3] = -(2 * far * near) / (near - far)
    projection[3, 2] = 1
    clip_space = (projection @ cam_space.T).T
    ndc = clip_space[:, :3] / clip_space[:, 3, None]
    screen_x = (ndc[:, 0] + 1) / 2 * SCREEN_WIDTH
    screen_y = (1 - ndc[:, 1]) / 2 * SCREEN_HEIGHT
    screen_depth = clip_space[:, 2]
    screen_depth_norm = -far / (-far + z_scale * clip_space[:, 2])
    return torch.stack([screen_x, screen_y, screen_depth, screen_depth_norm], dim=-1)
[docs]
def sample_cone(x, theta, k):
    """Draw ``k`` random directions inside a cone around ``x``.

    The angle alpha between each sample and ``x`` satisfies
    ``cos(alpha) in [cos(theta), 1)``. cos(alpha) and the azimuth are drawn
    uniformly, which spreads the samples uniformly over the spherical cap.

    Args:
        x: (3,) cone axis. It should be unit length; otherwise the returned
            vectors are not unit length either.
        theta: cone half-angle in radians.
        k: number of samples.

    Returns:
        (k, 3) tensor of directions with the dtype and device of ``x``.
    """
    device = x.device
    dtype = x.dtype
    # Match x's dtype so the final matmul with `basis` does not mix dtypes.
    phi = torch.rand(k, device=device, dtype=dtype) * torch.pi * 2
    cos_theta = math.cos(theta)
    # Map U[0, 1) to [cos(theta), 1).
    cos_alpha = torch.rand(k, device=device, dtype=dtype) * (1 - cos_theta) + cos_theta
    sin_alpha = torch.sin(torch.arccos(cos_alpha))
    cos_phi = torch.cos(phi)
    sin_phi = torch.sin(phi)
    # Samples in a local frame whose z axis is the cone axis.
    local = torch.stack([
        sin_alpha * cos_phi,
        sin_alpha * sin_phi,
        cos_alpha,
    ], dim=-1)
    # Pick a helper vector that is not (nearly) parallel to x.
    if torch.abs(x[2]) < 0.99:
        ref = torch.tensor([0., 0., 1.], device=device, dtype=dtype)
    else:
        ref = torch.tensor([0., 1., 0.], device=device, dtype=dtype)
    u = torch.nn.functional.normalize(torch.cross(ref, x, dim=0), dim=0)
    v = torch.cross(x, u, dim=0)
    # Columns u, v, x span the local frame.
    basis = torch.stack([u, v, x], dim=1)  # (3, 3)
    # Rotate local vectors into world space.
    return local @ basis.T
[docs]
def clipped_mean(points: torch.Tensor, std_thresh: float = 2.0):
    """Return an outlier-robust mean of ``points`` together with the inlier mask.

    A point is an inlier when every coordinate lies within ``std_thresh``
    (population) standard deviations of that coordinate's mean.

    Args:
        points: (N, D) tensor, e.g. (N, 3).
        std_thresh: threshold in standard deviations.

    Returns:
        (mean, mask): ``mean`` is the (D,) mean of the inliers, or the plain
        mean of all points if no point qualifies. ``mask`` is the (N,) bool
        tensor of inliers.
    """
    center = points.mean(dim=0, keepdim=True)  # (1, D)
    spread = points.std(dim=0, unbiased=False, keepdim=True)  # (1, D)
    # Absolute z-score per coordinate; the small constant guards zero spread.
    scores = (points - center).abs() / (spread + 1e-8)
    within = (scores <= std_thresh).all(dim=1)  # (N,)
    kept = points[within]
    if kept.numel() > 0:
        return kept.mean(dim=0), within
    return center.squeeze(0), within
[docs]
def interpolate(x0, x1, alpha):
    """Linearly blend from ``x0`` (alpha = 0) to ``x1`` (alpha = 1).

    Computes ``(1 - alpha) * x0 + alpha * x1`` for numbers or tensors.
    """
    weight0 = 1 - alpha
    return weight0 * x0 + alpha * x1
[docs]
def smooth_mean(x0: torch.Tensor, sx1: torch.Tensor, alpha: float, std_threshold: float):
    """Move ``x0`` a fraction ``alpha`` toward the outlier-clipped mean of ``sx1``.

    Equivalent to ``interpolate(x0, clipped_mean(sx1, std_threshold)[0], alpha)``;
    the inlier mask from ``clipped_mean`` is discarded.
    """
    robust_center = clipped_mean(sx1, std_threshold)[0]
    return interpolate(x0, robust_center, alpha)
[docs]
def project(a: torch.Tensor, b: torch.Tensor):
    """Return the vector projection of ``a`` onto ``b``: ``(a.b / b.b) * b``.

    Both inputs are 1-D tensors (``Tensor.dot``). A zero ``b`` gives NaN.
    """
    scale = a.dot(b) / b.dot(b)
    return scale * b
[docs]
def triangle_altitude(a, b):
    """Return the altitude onto the hypotenuse of a right triangle with legs a and b.

    Mathematically ``a * b / sqrt(a**2 + b**2)``. It is evaluated as
    ``a * (b / hypot(a, b))`` so that squaring very large or very small legs
    cannot overflow (Python floats raise OverflowError) or underflow to a zero
    denominator.

    Args:
        a, b: both Python numbers or both tensors (broadcast elementwise).
            Integer tensors are promoted to the default floating dtype.

    Returns:
        A float, or a floating-point tensor.
    """
    # Both arguments must be of the same kind (tensor vs. plain number).
    assert isinstance(a, torch.Tensor) == isinstance(b, torch.Tensor)
    if isinstance(a, torch.Tensor):
        # torch.hypot needs matching floating-point inputs.
        common = torch.result_type(a, b)
        if not common.is_floating_point:
            common = torch.get_default_dtype()
        a = a.to(common)
        b = b.to(common)
        return a * (b / torch.hypot(a, b))
    return a * (b / math.hypot(a, b))