#ifndef FASTCORESET_HPP
#define	FASTCORESET_HPP

#include <cmath>
#include <cstddef>
#include <functional>
#include <iterator>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>

#include "Metric.hpp"
#include "Sampling.hpp"
#include "Weiszfeld.hpp"
#include "CenterOfGravity.hpp"
#include "KumarMedian.hpp"
#include "LloydMedian.hpp"
#include "PKMedian.hpp"
#include "ProbabilisticPoint.hpp"

/**
 * @brief Fast implementation of PROBI
 *
 * PROBI is a clustering algorithm for the probabilistic Euclidean k-median problem.
 * This class builds a weighted coreset of probabilistic points by partitioning the
 * input into rings around cluster centers and sampling from each ring.
 */
class FastCoreset
{
private:
    /// One ring member: (index of the input point, index of its nearest center).
    using RingEntry = std::pair<size_t, size_t>;
    /// All members of a single ring (l, h, a).
    using Ring = std::vector<RingEntry>;
    /// Rings of one (l, h) cell, indexed by the width ring a.
    using RingsByWidth = std::vector<Ring>;
    /// Rings of one center l, indexed by the distance ring h.
    using RingsByDistance = std::vector<RingsByWidth>;
    /// Complete partition, indexed by the center l.
    using RingPartition = std::vector<RingsByDistance>;
    /// Sampled coreset: (index of the input point, coreset weight).
    using WeightedSample = std::vector<std::pair<size_t, double>>;

    /// Approximates the 1-median of a probabilistic point's realizations.
    Weiszfeld weiszfeld;
    /// Used instead of the 1-median when the file is compiled with KMEANS set.
    CenterOfGravity centerOfGravity;
    /// Fallback 1-median approximation when Weiszfeld reports IterationFailed.
    KumarMedian kumar;
    /// Clusters the 1-medians into k centers.
    LloydMedian lloyd;
    /// Evaluates the weighted cost of the input with respect to the centers.
    PKMedian pkmedian;
    /// Distance function used for ring assignment.
    /// NOTE(review): ownership/lifetime is decided in the .cpp (constructor/destructor) - confirm there.
    Metric<Point>* metric;
    /// Number of centers.
    int k;
    /// Maximum number of Weiszfeld iterations.
    int weiszfeldMedianIterations;
    /// Maximum number of iterations of Kumar's algorithm (fallback).
    int kumarMedianIterations;
    /// Maximum number of "probabilistic Lloyd" iterations.
    int maxLloydClusteringIterations;
    /// Total sample budget that is split evenly among the non-empty rings.
    int allSamplesSize;

public:
    /**
     * @brief Constructs the coreset algorithm from factories for metric and norm
     * @param createMetric Factory returning a metric on points
     * @param createNorm Factory returning a norm on points
     */
    FastCoreset(std::function<Metric<Point>*() > createMetric, std::function<Norm<Point>*() > createNorm);
    virtual ~FastCoreset();

    /**
     * @brief Sets number of centers
     * @param k Number of centers
     */
    void setK(int k);

    /**
     * @brief Gets number of centers
     * @return Number of centers
     */
    int getK() const;

    /**
     * @brief Sets number of Weiszfeld iterations (approximation of 1-median)
     * @param weiszfeldMedianIterations Maximum number of iterations
     */
    void setWeiszfeldMedianIterations(int weiszfeldMedianIterations);

    /**
     * @brief Gets number of Weiszfeld iterations (approximation of 1-median)
     * @return Maximum number of iterations
     */
    int getWeiszfeldMedianIterations() const;

    /**
     * @brief Sets number of "probabilistic Lloyd" iterations
     * @param maxLloydClusteringIterations Maximum number of iterations
     */
    void setMaxLloydClusteringIterations(int maxLloydClusteringIterations);

    /**
     * @brief Gets number of "probabilistic Lloyd" iterations
     * @return Maximum number of iterations
     */
    int getMaxLloydClusteringIterations() const;

    /**
     * @brief Sets number of iterations in Kumar's k-median algorithm (fallback)
     * @param kumarMedianIterations Maximum number of iterations
     */
    void setKumarMedianIterations(int kumarMedianIterations);

    /**
     * @brief Gets number of iterations in Kumar's k-median algorithm (fallback)
     * @return Maximum number of iterations
     */
    int getKumarMedianIterations() const;

    /**
     * @brief Sets the ring sample size
     * @param allSamplesSize Ring sample size
     */
    void setAllSamplesSize(int allSamplesSize);

    /**
     * @brief Gets the ring sample size
     * @return Ring sample size
     */
    int getAllSamplesSize() const;

    /**
     * @brief Computes a k-median coreset
     *
     * The coreset is not returned; its weighted points are written to @p output.
     *
     * @param inputBegin Input point set: begin
     * @param inputEnd Input point set: end
     * @param output Output iterator receiving the weighted coreset points
     * @param n Size of input (optional); 0 means "count the range"
     */
    template<typename Iterator1, typename Iterator2>
    void computeCoreset(Iterator1 inputBegin, Iterator1 inputEnd, Iterator2 output, size_t n = 0);

private:
    /**
     * @brief Assigns every input point to a ring (l, h, a)
     * @param begin Input point set: begin
     * @param end Input point set: end
     * @param medians 1-medians of the input points, in input order
     * @param centers Cluster centers
     * @param radius Base radius of the rings
     * @param n Size of input (optional)
     * @return Partition indexed as [l][h][a]
     */
    template<typename RandomAccessIterator>
    std::unique_ptr<RingPartition> partition(RandomAccessIterator begin, RandomAccessIterator end, std::vector<WeightedPoint>& medians, std::vector<Point>& centers, double radius, size_t n = 0);

    /**
     * @brief Samples points from every ring and assigns coreset weights
     * @param begin Input point set: begin
     * @param end Input point set: end
     * @param partitions Partition indexed as [l][h][a]
     * @param sampleSize Number of samples drawn per ring
     * @return Pairs of (input index, coreset weight)
     */
    template<typename RandomAccessIterator>
    std::unique_ptr<WeightedSample> sampleCoreset(RandomAccessIterator begin, RandomAccessIterator end, RingPartition const & partitions, size_t sampleSize);

    /**
     * @brief Accesses the input point at a given offset
     * @param begin Input point set: begin
     * @param index Offset from @p begin
     * @return Reference to the input point at that offset
     */
    template<typename RandomAccessIterator>
    ProbabilisticPoint const & input(RandomAccessIterator begin, size_t index)
    {
        RandomAccessIterator const position = begin + index;
        return *position;
    }
};

template<typename RandomAccessIterator1, typename Iterator2>
void FastCoreset::computeCoreset(RandomAccessIterator1 inputBegin, RandomAccessIterator1 inputEnd, Iterator2 output, size_t n)
{
    if (n == 0)
        for (RandomAccessIterator1 it = inputBegin; it != inputEnd; ++it)
            ++n;
    double W = 0;

    // 1. Compute 1-median of each (probabilistic) input point
    std::vector<WeightedPoint> medians;
    medians.reserve(n);
    for (RandomAccessIterator1 it = inputBegin; it != inputEnd; ++it)
    {
        Point p;
#if KMEANS
        p = centerOfGravity.cog(it->cbegin(), it->cend());
#else
        try
        {
            p = weiszfeld.approximateOneMedian(it->cbegin(), it->cend(), weiszfeldMedianIterations);
        }
        catch (Weiszfeld::IterationFailed err)
        {
            p = kumar.approximateOneMedianRounds(it->cbegin(), it->cend(), 0.9999999999, kumarMedianIterations);
        }
#endif
        medians.push_back(p);
        medians[medians.size() - 1].setWeight(it->getWeight());
    }

    // 2. Cluster 1-medians
    std::vector<Point> centers;
    centers.reserve(k);
    lloyd.computeCenterSet(medians.begin(), medians.end(), std::back_inserter(centers), k, maxLloydClusteringIterations, n);
    double cost = pkmedian.weightedCost(inputBegin, inputEnd, centers.begin(), centers.end());

    // 3. Partitioning and sampling
    double radius = cost / W; //TODO value of radius
    std::unique_ptr < std::vector < std::vector < std::vector < std::vector < std::pair<size_t, size_t >> >> >> partitions =
            partition(inputBegin, inputEnd, medians, centers, radius, n);
    size_t usedRings = 0;
    for (size_t l = 0; l < partitions->size(); ++l)
    {
        std::vector < std::vector < std::vector < std::pair<size_t, size_t >> >> const & L = (*partitions)[l];
        for (size_t h = 0; h < L.size(); ++h)
        {
            std::vector < std::vector < std::pair<size_t, size_t >> > const & H = L[h];
            for (size_t a = 0; a < H.size(); ++a)
            {
                std::vector < std::pair<size_t, size_t >> const & P = H[a];
                if (P.size() > 0)
                    ++usedRings;
            }
        }
    }
    int sampleSize = allSamplesSize / usedRings;
    if(sampleSize == 0)
        sampleSize = 1;
    std::unique_ptr < std::vector < std::pair<size_t, double >> > samples =
            sampleCoreset(inputBegin, inputEnd, *partitions, sampleSize);

    // Write coreset to output iterator
    for (size_t i = 0; i < samples->size(); ++i)
    {
        ProbabilisticPoint pp(*(inputBegin + i));
        pp.setWeight((*samples)[i].second);
        *output = pp;
        ++output;
    }
}

/**
 * @brief Assigns every input point to a ring (l, h, a)
 *
 * For input point i (whose 1-median is medians[i]):
 *  - l is the index of the center closest to the median,
 *  - h is 0 if the distance to that center is at most @p radius, otherwise
 *    ceil(log2(distance / radius)),
 *  - a is computed the same way from the point's "width": the weight-averaged
 *    distance of its realizations to the median, divided by the realization
 *    probability.
 *
 * A non-positive or NaN @p radius puts every point into ring h = a = 0, because
 * no finite ring index exists for it.
 *
 * @param begin Input point set: begin (random access); point i is *(begin + i)
 * @param end Input point set: end (unused)
 * @param medians 1-medians of the input points, in input order
 * @param centers Cluster centers
 * @param radius Base radius of the rings
 * @param n Size of input (unused)
 * @return Partition indexed as [l][h][a]; each entry is (input index, center index)
 */
template<typename RandomAccessIterator>
std::unique_ptr<std::vector<std::vector<std::vector<std::vector<std::pair<size_t, size_t >> >> >> FastCoreset::partition(RandomAccessIterator begin, RandomAccessIterator end, std::vector<WeightedPoint>& medians, std::vector<Point>& centers, double radius, size_t n)
{
    // Kept for interface compatibility; the partition is driven by medians.size().
    (void) end;
    (void) n;

    // Ring number of a non-negative quantity relative to the base radius.
    // The guard on radius avoids converting an infinite or NaN value to size_t.
    auto const ringIndex = [radius](double value) -> size_t
    {
        if (!(radius > 0) || !(value > radius))
            return 0;
        return static_cast<size_t>(std::ceil(std::log2(value / radius)));
    };

    std::unique_ptr < std::vector < std::vector < std::vector < std::vector < std::pair<size_t, size_t >> >> >> partitions(new std::vector < std::vector < std::vector < std::vector < std::pair<size_t, size_t >> >> >());

    for (size_t i = 0; i < medians.size(); ++i)
    {
        // Determine l: nearest center (the first one wins ties)
        size_t l = 0;
        double dist = 0;
        for (size_t j = 0; j < centers.size(); ++j)
        {
            double const candidate = metric->distance(medians[i], centers[j]);
            if (j == 0 || candidate < dist)
            {
                dist = candidate;
                l = j;
            }
        }

        // Determine h: ring by distance of the median to its center
        size_t const h = ringIndex(dist);

        // Determine a: ring by width of the probabilistic point
        double width = 0;
        ProbabilisticPoint const & v(input(begin, i));
        for (auto it = v.cbegin(); it != v.cend(); ++it)
            width += it->getWeight() * metric->distance(*it, medians[i]);
        width /= v.getRealizationProbability();
        size_t const a = ringIndex(width);

        // Allocate space and push index to ring (l,h,a)
        auto & ringsOfAllCenters = *partitions;
        if (ringsOfAllCenters.size() < l + 1)
            ringsOfAllCenters.resize(l + 1);
        auto & ringsOfCenter = ringsOfAllCenters[l];
        if (ringsOfCenter.size() < h + 1)
            ringsOfCenter.resize(h + 1);
        auto & ringsOfDistance = ringsOfCenter[h];
        if (ringsOfDistance.size() < a + 1)
            ringsOfDistance.resize(a + 1);
        ringsOfDistance[a].push_back(std::pair<size_t, size_t>(i, l));
    }

    return partitions;
}

template<typename RandomAccessIterator>
std::unique_ptr<std::vector<std::pair<size_t, double >> > FastCoreset::sampleCoreset(RandomAccessIterator begin, RandomAccessIterator end, std::vector<std::vector<std::vector<std::vector<std::pair<size_t, size_t >> >> > const & partitions, size_t sampleSize)
{
    std::unordered_map<size_t, size_t> sampledPoints;
    std::unique_ptr < std::vector < std::pair<size_t, double >> > coreset(new std::vector < std::pair<size_t, double >>);
    for (size_t l = 0; l < partitions.size(); ++l)
    {
        std::vector < std::vector < std::vector < std::pair<size_t, size_t >> >> const & L = partitions[l];
        for (size_t h = 0; h < L.size(); ++h)
        {
            std::vector < std::vector < std::pair<size_t, size_t >> > const & H = L[h];
            for (size_t a = 0; a < H.size(); ++a)
            {
                std::vector < std::pair<size_t, size_t >> const & P = H[a];

                // Sampling probabilities
                double sum = 0;
                for (size_t p = 0; p < P.size(); ++p)
                    sum += input(begin, P[p].first).getWeight() * input(begin, P[p].first).getRealizationProbability();
                std::vector<double> prob(P.size());
                for (size_t p = 0; p < P.size(); ++p)
                    prob[p] = input(begin, P[p].first).getWeight() * input(begin, P[p].first).getRealizationProbability();

                // Do sampling
                std::unique_ptr < std::vector < std::pair<size_t, size_t >> > sample(Sampling::sampleWithReplacement < std::vector < std::pair<size_t, size_t >> ::const_iterator, std::pair<size_t, size_t >> (P.begin(), P.end(), prob, sampleSize));

                // Return point and weight
                for (size_t s = 0; s < sample->size(); ++s)
                {
                    size_t index = (*sample)[s].first;
                    double weight = sum / (input(begin, (*sample)[s].first).getRealizationProbability() * sampleSize);
                    // Check if point was sampled before
                    if (sampledPoints.count(index) > 0)
                    {
                        (*coreset)[sampledPoints[index]].second += weight;
                    }
                    else
                    {
                        sampledPoints.insert(std::pair<size_t, size_t>(index, coreset->size()));
                        coreset->push_back(std::pair<size_t, double>(index, weight));
                    }
                }
            }
        }
    }
    return coreset;
}

#endif	/* FASTCORESET_HPP */

