К оглавлению

Метод наименьших квадратов

С.И.Хашин


LeastSq.h - заголовочный файл
LeastSq.cpp - реализация.

(Подключается файл dMatrix.h)

В модуле LeastSq(.h,.cpp) решается методом наименьших квадратов следующая задача. В k-мерном пространстве с координатами

        X=(x[0],…x[k-1])

найти линейную функцию

        L(X)=b[0]*x[0] + ...+ b[k-1]*x[k-1]

которая в данных точках максимально близка к данным значеним. Более точно, пусть даны точки X[1], …, X[n]:

    
    X[1] = (x[1,0], …, x[1,k-1])
         …
    X[n] = (x[n,0], …, x[n,k-1])
и ожидаемые значения в них F[1], …, F[n]. Будем подбирать функцию L(X) так, чтобы квадратичное отклонение
    S = Σ (L(X[i]) – F[i])2

было минимальным.

При заданных векторах X[i] и значениях F[i] величина S будет квадратичной функцией от b[i]. Приравняв частные производные от S по всем b[i] нулю, получим систему из k линейных уравнений от k переменных b[0],…,b[k-1]. Эта система будет иметь единственное решение тогда и только тогда, когда исходные вектора X[1],…,X[n] порождают все пространство Rk.

Класс хранит следующие данные:

    dmatr *a;
    double *sol;    // решение
    int cnt;        // количество узлов для поиска формулы
    double sum2;    // сумма квадратов f[i]
Размерность пространства k отдельно не хранится, её можно получить как a.getSize(). Добавляемые точки хранить не надо, только копим уравнения в матрице a.
Основные методы:
void Clear();                       // обнулить
void setSize(int sz);               // изменить размерность пространства
void add( double*fi, double fx);    // добавить точку размерности getSize()
void add1(double f0, double fx);    // добавить точку dim=1
void add2(double f0, double f1, double fx);             // добавить точку dim=2
void add3(double f0, double f1, double f2, double fx);  // добавить точку dim=3
void add4(double f0, double f1, double f2, double f3, double fx);   // добавить точку dim=4
void add5(double f0, …, double f4, double fx);          // добавить точку dim=5

В любой момент можно получить оптимальное текущие коэффициенты искомой функции с помощью метода

double solve(); // найти решение в вектор sol. Возвращает ср.кв.отклонение

То же самое среднеквадратичнтое отклонение можно получить (после вызова solve) и с помощью метода MSE (Mean Square Error), без аргументов. Если же мы хотим получить MSE (Среднеквадратичное отклонение) для некоторой другой функции L, т.е. для некоторого другого набора коэффициентов b[0],…,b[k-1], можно воспользоваться функцией MSE(double *fi), передав ей этот самый массив коэффициентов.

Примеры

Пример 1. Пусть на плоскости даны N точек:

    (x[0],y[0]), ..., (x[N-1],y[N-1]).
Нахождение прямой вида y = k*x+b как можно ближе проходящей к этим точкам:
    const int N = 4;
    double x[N] = { 1,3,7,11 };
    double y[N] = { 5,6,9,12 };
    LeastSq m(2);
    for (int i = 0; i < N; i++)
        m.add2(1, x[i], y[i]);
    m.solve();
    double k = m.sol[1];
    double b = m.sol[0];
    printf(" y = %10.5f *x + %10.5f, err=%10.5f\n", k, b, m.MSE());

Получим ответ:

 y = 0.71186 *x + 4.08475, err = 0.15945

Пример 2. Найти кубическую параболу f(x) = f0 + x*(f1 + x*(f2 + x*f3)) аппроксимирующую функцию 1/(1+x*x) на отрезке [0,1]:

    LeastSq m(4);
    for (double x = 0; x<=1.0000001; x+=0.01)
        m.add4(1, x, x*x, x*x*x, 1/(1+x*x));
    m.solve();
    printf(" y = %8.5f + x*( %8.5f + x*( %8.5f + x*%8.5f)), err=%8.5f\n", 
        m.sol[0],m.sol[1], m.sol[2], m.sol[3], m.MSE()); 

Получим ответ:

    y = 1.00219 + x*(-0.02794 + x*(-1.02316 + x* 0.55283)), err = 0.00128

Пример 3. Аппроксимировать функцию двух переменных

    f(x,y) = 1/(1 + x*x + y*y)

квадратичной функцией

    g(x,y) = g0 + g1*x + g2*y + g3*x*x + g4*x*y + g5*y*y

на единичном квадрате.

    LeastSq m(6);
    double r[6];
    for (double x = 0; x <= 1.01; x += 0.1)
        for (double y = 0; y <= 1.01; y += 0.1) {
            r[0] = 1;
            r[1] = x;
            r[2] = y;
            r[3] = x*x;
            r[4] = x*y;
            r[5] = y*y;
            m.add(r, 1/(1 + x*x + y*y));
        }
    m.solve();
    printf(" g(x,y) = %8.5f + %8.5f*x + %8.5f*y + %8.5f*x*x + %8.5f*x*y + %8.5f*y*y, err=%8.5f\n",
        m.sol[0], m.sol[1], m.sol[2], m.sol[3], m.sol[4], m.sol[5], m.MSE());

Получим ответ:

 g(x,y) =  1.05502 + -0.40254*x + -0.40254*y + -0.16860*x*x +  0.39670*x*y + -0.16860*y*y, err = 0.01314

К оглавлению


free counters