#ifndef PKMEDIAN_H
#define PKMEDIAN_H

#include <functional>

#include "Helper.hpp"
#include "Metric.hpp"
#include "Point.hpp"
#include "WeightedPoint.hpp"
#include "ProbabilisticPoint.hpp"

/**
 * @brief Probabilistic k-median evaluator
 */
/**
 * @brief Probabilistic k-median evaluator.
 *
 * Evaluates the probabilistic k-median cost of probabilistic points against a
 * set of centers, using the metric held in @c dist for every distance.
 */
class PKMedian
{
public:
    /**
     * @brief Constructs the evaluator.
     * @param createMetric Factory for the metric used to measure distances
     *        (presumably invoked here with its result stored in @c dist —
     *        TODO confirm against the implementation file).
     */
    PKMedian(std::function<Metric<Point>*() > createMetric);
    ~PKMedian();

    // PKMedian holds a raw Metric pointer and declares its own destructor
    // (presumably releasing that pointer). The implicitly generated copy
    // operations would duplicate the pointer and lead to a double release,
    // so copying is disabled (Rule of Three).
    PKMedian(PKMedian const &) = delete;
    PKMedian & operator=(PKMedian const &) = delete;

    /**
     * @brief Probabilistic k-median cost of a single probabilistic point.
     *
     * For every center c in [beginC, endC) computes
     *   pp.getWeight() * sum_r ( r.getWeight() * distance(r, c) )
     * over the realizations r of @p pp, and returns the smallest such value.
     *
     * @param pp     Probabilistic point whose realizations are evaluated.
     * @param beginC Start of the center range.
     * @param endC   End of the center range.
     * @return Minimum weighted cost over all centers; 0 if the center range is empty.
     */
    template<typename ForwardIteratorCenter>
    double weightedCost(ProbabilisticPoint const & pp, ForwardIteratorCenter beginC, ForwardIteratorCenter endC);

    /**
     * @brief Probabilistic k-median cost of a set of probabilistic points.
     *
     * Sums, over every point in [beginP, endP), that point's minimum weighted
     * cost over the centers in [beginC, endC), computed as for the
     * single-point overload.
     *
     * @param beginP Start of the probabilistic point range.
     * @param endP   End of the probabilistic point range.
     * @param beginC Start of the center range.
     * @param endC   End of the center range.
     * @return Total cost; 0 if the point range is empty.
     */
    template<typename ForwardIteratorPoint, typename ForwardIteratorCenter>
    double weightedCost(ForwardIteratorPoint beginP, ForwardIteratorPoint endP, ForwardIteratorCenter beginC, ForwardIteratorCenter endC);

private:
    /// Metric used for all distance computations.
    Metric<Point>* dist;
};

/**
 * Cost of a single probabilistic point against the centers in [beginC, endC).
 *
 * For each center, the weight-scaled distances of all realizations of @p pp
 * are accumulated, the sum is scaled by the point's own weight, and the
 * cheapest center's value is kept. An empty center range yields 0.
 */
template<typename ForwardIteratorCenter>
double PKMedian::weightedCost(ProbabilisticPoint const & pp, ForwardIteratorCenter beginC, ForwardIteratorCenter endC)
{
    double best = 0;
    bool first = true;

    for (auto center = beginC; center != endC; ++center)
    {
        // Weighted sum of distances from every realization to this center.
        double expectedDistance = 0;
        for (auto r = pp.cbegin(); r != pp.cend(); ++r)
        {
            auto realization = toPointer(*r);
            expectedDistance += realization->getWeight() * dist->distance(realization, toPointer(*center));
        }

        double const cost = expectedDistance * pp.getWeight();

        // The first center always seeds the minimum; later ones replace it only when strictly cheaper.
        if (first || cost < best)
        {
            best = cost;
            first = false;
        }
    }

    return best;
}

/**
 * Total probabilistic k-median cost of the points in [beginP, endP) against
 * the centers in [beginC, endC).
 *
 * Each point is scored with the single-point overload (its minimum weighted
 * cost over all centers) and the per-point minima are summed. An empty point
 * range yields 0.
 *
 * Elements are dereferenced through toPointer() rather than used via itP->
 * directly, so the range may hold ProbabilisticPoint objects or pointers to
 * them (assuming toPointer() is overloaded for both, as its use on centers
 * and realizations in this file suggests — confirm in Helper.hpp).
 */
template<typename ForwardIteratorPoint, typename ForwardIteratorCenter>
double PKMedian::weightedCost(ForwardIteratorPoint beginP, ForwardIteratorPoint endP, ForwardIteratorCenter beginC, ForwardIteratorCenter endC)
{
    double sum = 0;
    for (auto itP = beginP; itP != endP; ++itP)
        sum += weightedCost(*toPointer(*itP), beginC, endC);
    return sum;
}


#endif