#include "pqnnsearch.h"

#include "abstractpquantizer.h"
#include "pqcodeloader.h"
#include "evaluator.hpp"
#include "vstring.h"

#include <cassert>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <queue>
#include <utility>

using namespace std;

// Out-of-class definition of the static constant topk0 (value 128).
// NOTE(review): presumably a default number of neighbours; it is not used in the code visible here.
const unsigned int PQnnSearch::topk0 = 128;

// Heap of (distance, reference index) pairs used by performADCQuery().
// No comparator is given, so std::priority_queue falls back to std::less and keeps
// the largest pair on top; popping while the size exceeds k leaves the k smallest pairs.
using PriorityQType =
            std::priority_queue<std::pair<float, unsigned>, std::vector<std::pair<float, unsigned>>>;

/**
 * Builds a searcher from a PQ vocabulary file and a file of PQ codes.
 *
 * @param vocabFn  path passed to AbstractPQuantizer::loadPQVocab(); that call
 *                 also receives pqDim, pqNum and nSeg0 to fill in
 * @param refPQFn  path passed to PQCodeLoader::loadPQCodes(); that call also
 *                 receives imgNum and ftDim
 *
 * Prints a short summary of the loaded dimensions to stdout. If either the
 * vocabulary or the reference set cannot be loaded, an error is written to
 * stderr and the process terminates with a failure status.
 */
PQnnSearch::PQnnSearch(string vocabFn, string refPQFn)
{
    // Start from a well-defined state so every pointer is either null or owned.
    this->refPQCodes = nullptr;
    this->ADCTab     = nullptr;

    this->pqVocab = AbstractPQuantizer::loadPQVocab(vocabFn, this->pqDim, this->pqNum, this->nSeg0);
    if(this->pqVocab == nullptr)
    {
        // Without a vocabulary the dimensions used below to size ADCTab are meaningless.
        std::cerr << "Error: no PQ vocabulary has been loaded!\n";
        exit(EXIT_FAILURE);
    }

    this->ftDim = this->pqDim*this->nSeg0;

    // One table entry per (segment, codeword); "()" value-initialises every entry to 0.
    const size_t tabSize = static_cast<size_t>(this->pqNum)*this->nSeg0;
    this->ADCTab = new float[tabSize]();

    std::cout << "Vector dim ....................... " << this->ftDim << std::endl;
    std::cout << "PQ dim ........................... " << this->pqDim << std::endl;
    std::cout << "PQ Size .......................... " << this->pqNum << std::endl;
    std::cout << "No. of PQ segments ............... " << this->nSeg0 << std::endl;

    this->refPQCodes = PQCodeLoader::loadPQCodes(refPQFn, this->imgNum, this->ftDim);

    if(this->imgNum == 0 || this->refPQCodes == nullptr)
    {
        std::cerr << "Error: no reference set has been loaded!\n";
        // A missing reference set is a failure; report it through the exit status.
        exit(EXIT_FAILURE);
    }

    std::cout << "Referece size .................... " << this->imgNum << std::endl;
}

/**
 * Fills the distance lookup table ADCTab for one query vector.
 *
 * For each segment s and each codeword i of that segment, ADCTab[s*pqNum + i]
 * is set to the squared Euclidean distance between the s-th sub-vector of the
 * query (pqDim floats starting at qry + s*pqDim) and codeword i of segment s
 * (pqDim floats starting at pqVocab + (s*pqNum + i)*pqDim).
 *
 * @param qry    query vector; must provide at least nSeg0*pqDim readable floats
 * @param nSeg0  number of segments to process (shadows the member of the same
 *               name); ADCTab is written up to index nSeg0*pqNum - 1
 */
void PQnnSearch::updateADCTab(const float *qry,  const unsigned int nSeg0)
{
    for(unsigned long s = 0; s < nSeg0; s++)
    {
        // Codewords of segment s and the matching slice of the query.
        const float *segVocab = this->pqVocab + s*this->pqDim*this->pqNum;
        const float *subQry   = qry + s*this->pqDim;

        for(unsigned long i = 0; i < this->pqNum; i++)
        {
            const float *ppq = segVocab + i*this->pqDim;
            float dst = 0;

            // Squared L2 distance; no square root is taken so that per-segment
            // entries can simply be summed to obtain a whole-vector distance.
            for(unsigned long di = 0; di < this->pqDim; di++)
            {
                const float delta = subQry[di] - ppq[di];
                dst += delta*delta;
            }

            this->ADCTab[s*this->pqNum + i] = dst;
        }
    }
}

/**
 * Runs a top-k search over the loaded reference codes for every query in a file.
 *
 * For each query the lookup table is refreshed via updateADCTab(); the score of
 * a reference item is the sum, over its nSeg0 code bytes, of the table entries
 * those bytes select. A bounded max-heap keeps the topk lowest-scoring items.
 *
 * @param srcFn  query file, loaded with AbstractPQuantizer::loadfvecs()
 * @param dstFn  accepted for interface compatibility; nothing is written by
 *               this function
 * @param topk   maximum number of neighbours kept per query
 * @return one vector of reference indices per query, ordered from the lowest
 *         score (nearest) to the highest; empty if the queries cannot be used
 */
std::vector<std::vector<size_t> > PQnnSearch::performADCQuery(string srcFn, string dstFn, const unsigned topk)
{
    (void)dstFn;

    std::vector<std::vector<size_t> > nbsAllQry;
    size_t qDim = 0, qRow = 0;

    float *queries = AbstractPQuantizer::loadfvecs(srcFn, qRow, qDim);
    std::cout << "Query size ....................... " << qRow << "x" << qDim << std::endl;

    if(queries == nullptr)
    {
        std::cerr << "Error: no query has been loaded!\n";
        return nbsAllQry;
    }

    // updateADCTab() reads nSeg0*pqDim floats per query; refuse shorter rows
    // instead of reading past the end of the query buffer.
    const size_t vecDim = static_cast<size_t>(this->pqDim)*this->nSeg0;
    if(qDim < vecDim)
    {
        std::cerr << "Error: query dimension " << qDim
                  << " is smaller than the PQ vector dimension " << vecDim << "!\n";
        delete [] queries;
        return nbsAllQry;
    }

    nbsAllQry.reserve(qRow);

    PriorityQType topRank;
    for(size_t i = 0; i < qRow; i++)
    {
        this->updateADCTab(queries + i*qDim, this->nSeg0);

        for(unsigned int ri = 0; ri < this->imgNum; ri++)
        {
            // Offset computed in size_t so large reference sets cannot wrap a 32-bit index.
            const unsigned char *ptcodes = this->refPQCodes + static_cast<size_t>(ri)*this->nSeg0;

            // Sum the table entry selected by each segment's code.
            float dist = 0;
            for(unsigned int s = 0; s < this->nSeg0; s++)
            {
                dist += this->ADCTab[static_cast<size_t>(s)*this->pqNum + ptcodes[s]];
            }

            // The heap keeps the largest score on top, so dropping the top
            // whenever the heap overflows leaves the topk smallest scores.
            topRank.emplace(dist, ri);
            if(topRank.size() > topk)
            {
                topRank.pop();
            }
        }//for(ri)

        // The heap yields the worst candidate first; fill the result from the
        // back so that the nearest neighbour ends up at index 0.
        std::vector<size_t> nbs(topRank.size());
        for(size_t j = nbs.size(); j-- > 0; )
        {
            nbs[j] = topRank.top().second;
            topRank.pop();
        }
        nbsAllQry.emplace_back(std::move(nbs));
        std::cout << "\r\r\r\r\t\t\t" << i;
    }//(for i)
    std::cout << std::endl;

    // NOTE(review): released with delete[] to match how the destructor frees the
    // other loader-returned buffers -- confirm loadfvecs() allocates with new[].
    delete [] queries;

    return nbsAllQry;
}

/**
 * Releases the vocabulary, the distance lookup table and the reference codes,
 * then resets each pointer to null.
 */
PQnnSearch::~PQnnSearch()
{
    // delete[] on a null pointer does nothing, so no explicit guards are needed.
    delete [] pqVocab;
    pqVocab = nullptr;

    delete [] ADCTab;
    ADCTab = nullptr;

    delete [] refPQCodes;
    refPQCodes = nullptr;
}

/**
 * Smoke test: loads a vocabulary and reference codes, runs a top-32 search
 * over a query file and prints the recall reported by Evaluator.
 *
 * The file paths below are hard-coded and must be adapted to the local machine.
 */
void PQnnSearch::test()
{
    /***** modify the following paths to your appropriate paths */
    std::string vocabFn = "/home/wlzhao/datasets/bignn/sift1m/pq/vocab_pqk256m8.txt";
    std::string refpqFn = "/home/wlzhao/datasets/bignn/sift1m/pq/sift1m_pq1.txt";
    std::string queryFn = "/home/wlzhao/datasets/bignn/sift1m/sift1m_query.fvecs";
    std::string gtFn    = "/home/wlzhao/datasets/bignn/sift1m/pq/sift1m_gt.txt";
    std::string dstFn   = "/home/wlzhao/datasets/bignn/sift1m/pq/sift1m_rslt.txt";

    // Same k for the search and for the recall computation.
    constexpr int topk = 32;

    Evaluator  eval(gtFn);
    // Automatic storage: the searcher's buffers are released when test() returns
    // (the previous heap-allocated instance was never deleted).
    PQnnSearch pqsrch(vocabFn, refpqFn);

    std::vector<std::vector<size_t> > nbsAllQry = pqsrch.performADCQuery(queryFn, dstFn, topk);
    float recall = eval.getRecall(nbsAllQry, topk);

    std::cout << "Recall: " << recall << std::endl;
}
