feat: extract linalg solver + add least-squares helper

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
sjat 2026-06-11 09:09:05 +02:00
parent 04fc642137
commit 72d32db516
2 changed files with 62 additions and 0 deletions

View file

@ -0,0 +1,24 @@
import { describe, it, expect } from 'vitest';
import { solveLinear, leastSquares } from './linalg';
describe('linalg', () => {
it('solveLinear solves a 2x2 system', () => {
// 2x + y = 5 ; x + 3y = 10 → x = 1, y = 3
const x = solveLinear([[2, 1], [1, 3]], [5, 10]);
expect(x[0]!).toBeCloseTo(1, 9);
expect(x[1]!).toBeCloseTo(3, 9);
});
it('solveLinear throws on a singular system', () => {
expect(() => solveLinear([[1, 2], [2, 4]], [3, 6])).toThrow();
});
it('leastSquares recovers exact coefficients of an over-determined linear fit', () => {
// model: t = 2*a + 3*b ; rows are [a, b]
const rows = [[1, 0], [0, 1], [1, 1], [2, 1]];
const targets = rows.map((r) => 2 * r[0]! + 3 * r[1]!);
const c = leastSquares(rows, targets);
expect(c[0]!).toBeCloseTo(2, 9);
expect(c[1]!).toBeCloseTo(3, 9);
});
});

38
src/geometry/linalg.ts Normal file
View file

@ -0,0 +1,38 @@
/** Solve a square linear system A x = b by Gaussian elimination with partial pivoting. */
export function solveLinear(A: number[][], b: number[]): number[] {
const n = b.length;
const M = A.map((row, i) => [...row, b[i]!]);
for (let col = 0; col < n; col++) {
let pivot = col;
for (let r = col + 1; r < n; r++) {
if (Math.abs(M[r]![col]!) > Math.abs(M[pivot]![col]!)) pivot = r;
}
if (Math.abs(M[pivot]![col]!) < 1e-12) throw new Error('Singular system');
[M[col], M[pivot]] = [M[pivot]!, M[col]!];
const pivRow = M[col]!;
for (let r = 0; r < n; r++) {
if (r === col) continue;
const factor = M[r]![col]! / pivRow[col]!;
for (let k = col; k <= n; k++) M[r]![k]! -= factor * pivRow[k]!;
}
}
return M.map((row, i) => row[n]! / row[i]!);
}
/**
* Least-squares solve for c minimizing rows·c targets via the normal equations
* (rowsᵀ rows) c = rowsᵀ targets. `rows` is M×N (M N), `targets` is length M.
*/
export function leastSquares(rows: number[][], targets: number[]): number[] {
const n = rows[0]!.length;
const ata: number[][] = Array.from({ length: n }, () => new Array(n).fill(0));
const atb: number[] = new Array(n).fill(0);
for (let r = 0; r < rows.length; r++) {
const row = rows[r]!;
for (let i = 0; i < n; i++) {
atb[i]! += row[i]! * targets[r]!;
for (let j = 0; j < n; j++) ata[i]![j]! += row[i]! * row[j]!;
}
}
return solveLinear(ata, atb);
}