feat: extract linalg solver + add least-squares helper
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
04fc642137
commit
72d32db516
2 changed files with 62 additions and 0 deletions
24
src/geometry/linalg.test.ts
Normal file
24
src/geometry/linalg.test.ts
Normal file
|
|
@ -0,0 +1,24 @@
|
||||||
|
import { describe, it, expect } from 'vitest';
|
||||||
|
import { solveLinear, leastSquares } from './linalg';
|
||||||
|
|
||||||
|
describe('linalg', () => {
|
||||||
|
it('solveLinear solves a 2x2 system', () => {
|
||||||
|
// 2x + y = 5 ; x + 3y = 10 → x = 1, y = 3
|
||||||
|
const x = solveLinear([[2, 1], [1, 3]], [5, 10]);
|
||||||
|
expect(x[0]!).toBeCloseTo(1, 9);
|
||||||
|
expect(x[1]!).toBeCloseTo(3, 9);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('solveLinear throws on a singular system', () => {
|
||||||
|
expect(() => solveLinear([[1, 2], [2, 4]], [3, 6])).toThrow();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('leastSquares recovers exact coefficients of an over-determined linear fit', () => {
|
||||||
|
// model: t = 2*a + 3*b ; rows are [a, b]
|
||||||
|
const rows = [[1, 0], [0, 1], [1, 1], [2, 1]];
|
||||||
|
const targets = rows.map((r) => 2 * r[0]! + 3 * r[1]!);
|
||||||
|
const c = leastSquares(rows, targets);
|
||||||
|
expect(c[0]!).toBeCloseTo(2, 9);
|
||||||
|
expect(c[1]!).toBeCloseTo(3, 9);
|
||||||
|
});
|
||||||
|
});
|
||||||
38
src/geometry/linalg.ts
Normal file
38
src/geometry/linalg.ts
Normal file
|
|
@ -0,0 +1,38 @@
|
||||||
|
/** Solve a square linear system A x = b by Gaussian elimination with partial pivoting. */
|
||||||
|
export function solveLinear(A: number[][], b: number[]): number[] {
|
||||||
|
const n = b.length;
|
||||||
|
const M = A.map((row, i) => [...row, b[i]!]);
|
||||||
|
for (let col = 0; col < n; col++) {
|
||||||
|
let pivot = col;
|
||||||
|
for (let r = col + 1; r < n; r++) {
|
||||||
|
if (Math.abs(M[r]![col]!) > Math.abs(M[pivot]![col]!)) pivot = r;
|
||||||
|
}
|
||||||
|
if (Math.abs(M[pivot]![col]!) < 1e-12) throw new Error('Singular system');
|
||||||
|
[M[col], M[pivot]] = [M[pivot]!, M[col]!];
|
||||||
|
const pivRow = M[col]!;
|
||||||
|
for (let r = 0; r < n; r++) {
|
||||||
|
if (r === col) continue;
|
||||||
|
const factor = M[r]![col]! / pivRow[col]!;
|
||||||
|
for (let k = col; k <= n; k++) M[r]![k]! -= factor * pivRow[k]!;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return M.map((row, i) => row[n]! / row[i]!);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Least-squares solve for c minimizing ‖rows·c − targets‖ via the normal equations
|
||||||
|
* (rowsᵀ rows) c = rowsᵀ targets. `rows` is M×N (M ≥ N), `targets` is length M.
|
||||||
|
*/
|
||||||
|
export function leastSquares(rows: number[][], targets: number[]): number[] {
|
||||||
|
const n = rows[0]!.length;
|
||||||
|
const ata: number[][] = Array.from({ length: n }, () => new Array(n).fill(0));
|
||||||
|
const atb: number[] = new Array(n).fill(0);
|
||||||
|
for (let r = 0; r < rows.length; r++) {
|
||||||
|
const row = rows[r]!;
|
||||||
|
for (let i = 0; i < n; i++) {
|
||||||
|
atb[i]! += row[i]! * targets[r]!;
|
||||||
|
for (let j = 0; j < n; j++) ata[i]![j]! += row[i]! * row[j]!;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return solveLinear(ata, atb);
|
||||||
|
}
|
||||||
Loading…
Add table
Reference in a new issue