feat: extract linalg solver + add least-squares helper
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
04fc642137
commit
72d32db516
2 changed files with 62 additions and 0 deletions
24
src/geometry/linalg.test.ts
Normal file
24
src/geometry/linalg.test.ts
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
import { describe, it, expect } from 'vitest';
|
||||
import { solveLinear, leastSquares } from './linalg';
|
||||
|
||||
describe('linalg', () => {
|
||||
it('solveLinear solves a 2x2 system', () => {
|
||||
// 2x + y = 5 ; x + 3y = 10 → x = 1, y = 3
|
||||
const x = solveLinear([[2, 1], [1, 3]], [5, 10]);
|
||||
expect(x[0]!).toBeCloseTo(1, 9);
|
||||
expect(x[1]!).toBeCloseTo(3, 9);
|
||||
});
|
||||
|
||||
it('solveLinear throws on a singular system', () => {
|
||||
expect(() => solveLinear([[1, 2], [2, 4]], [3, 6])).toThrow();
|
||||
});
|
||||
|
||||
it('leastSquares recovers exact coefficients of an over-determined linear fit', () => {
|
||||
// model: t = 2*a + 3*b ; rows are [a, b]
|
||||
const rows = [[1, 0], [0, 1], [1, 1], [2, 1]];
|
||||
const targets = rows.map((r) => 2 * r[0]! + 3 * r[1]!);
|
||||
const c = leastSquares(rows, targets);
|
||||
expect(c[0]!).toBeCloseTo(2, 9);
|
||||
expect(c[1]!).toBeCloseTo(3, 9);
|
||||
});
|
||||
});
|
||||
38
src/geometry/linalg.ts
Normal file
38
src/geometry/linalg.ts
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
/** Solve a square linear system A x = b by Gaussian elimination with partial pivoting. */
|
||||
export function solveLinear(A: number[][], b: number[]): number[] {
|
||||
const n = b.length;
|
||||
const M = A.map((row, i) => [...row, b[i]!]);
|
||||
for (let col = 0; col < n; col++) {
|
||||
let pivot = col;
|
||||
for (let r = col + 1; r < n; r++) {
|
||||
if (Math.abs(M[r]![col]!) > Math.abs(M[pivot]![col]!)) pivot = r;
|
||||
}
|
||||
if (Math.abs(M[pivot]![col]!) < 1e-12) throw new Error('Singular system');
|
||||
[M[col], M[pivot]] = [M[pivot]!, M[col]!];
|
||||
const pivRow = M[col]!;
|
||||
for (let r = 0; r < n; r++) {
|
||||
if (r === col) continue;
|
||||
const factor = M[r]![col]! / pivRow[col]!;
|
||||
for (let k = col; k <= n; k++) M[r]![k]! -= factor * pivRow[k]!;
|
||||
}
|
||||
}
|
||||
return M.map((row, i) => row[n]! / row[i]!);
|
||||
}
|
||||
|
||||
/**
|
||||
* Least-squares solve for c minimizing ‖rows·c − targets‖ via the normal equations
|
||||
* (rowsᵀ rows) c = rowsᵀ targets. `rows` is M×N (M ≥ N), `targets` is length M.
|
||||
*/
|
||||
export function leastSquares(rows: number[][], targets: number[]): number[] {
|
||||
const n = rows[0]!.length;
|
||||
const ata: number[][] = Array.from({ length: n }, () => new Array(n).fill(0));
|
||||
const atb: number[] = new Array(n).fill(0);
|
||||
for (let r = 0; r < rows.length; r++) {
|
||||
const row = rows[r]!;
|
||||
for (let i = 0; i < n; i++) {
|
||||
atb[i]! += row[i]! * targets[r]!;
|
||||
for (let j = 0; j < n; j++) ata[i]![j]! += row[i]! * row[j]!;
|
||||
}
|
||||
}
|
||||
return solveLinear(ata, atb);
|
||||
}
|
||||
Loading…
Add table
Reference in a new issue