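/**
 * In-browser ColBERT scoring: tokenizes a query and a document, encodes both
 * with an ONNX ColBERT model through onnxruntime-web (WebGPU when available,
 * otherwise WASM), then computes the MaxSim score and per-word highlight scores.
 */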
/*global BigInt */
/*global BigInt64Array */

import * as wasmFeatureDetect from 'wasm-feature-detect';
import { BertTokenizer, env } from '@xenova/transformers';
import * as ort from 'onnxruntime-web/webgpu';

// Tokenizer settings: tokenizer files are resolved only from the app's public
// path (process.env.PUBLIC_URL); fetching models remotely is disabled.
Object.assign(env, {
  allowRemoteModels: false,
  allowLocalModels: true,
  localModelPath: process.env.PUBLIC_URL,
});

// WebGPU is usable only when the browser exposes navigator.gpu; otherwise the
// user is told that inference falls back to WASM with a smaller checkpoint.
const gpuAvailable = Boolean(navigator.gpu);
if (gpuAvailable) {
  console.log("WebGPU is supported!");
} else {
  alert('WebGPU is not supported by your browser. Using WASM instead and a smaller model checkpoint.');
}

ort.env.debug = true;
// Optional: ort.env.wasm.wasmPaths can point ORT at self-hosted .wasm files
// (e.g. process.env.PUBLIC_URL).

// Multi-threaded WASM requires Cross-Origin-*-Policy headers, see
// https://web.dev/coop-coep/.
// NOTE(review): the thread count is keyed off SIMD support, although threading
// really depends on cross-origin isolation — confirm this is intended.
const wasmConfigured = wasmFeatureDetect.simd()
  .then((simdSupported) => {
    console.log(`simd is supported? ${simdSupported}`);
    const simd = Boolean(simdSupported);
    ort.env.wasm.numThreads = simd ? 4 : 1;
    ort.env.wasm.simd = simd;
  })
  .catch((err) => {
    // Feature detection is best-effort: keep ORT's defaults and still load.
    console.error("WASM feature detection failed; using onnxruntime defaults", err);
  });

// Full ColBERTv2 checkpoint on WebGPU (WASM listed as fallback provider);
// a smaller quantized checkpoint when only WASM is available.
const model = gpuAvailable
  ? "https://huggingface.co/colbert-ir/colbertv2.0/resolve/main/model.onnx"
  : "https://huggingface.co/vespa-engine/col-minilm/resolve/main/onnx/model_quantized.onnx";
const executionProviders = gpuAvailable ? ['webgpu', 'wasm'] : ['wasm'];

const options = {
  executionProviders,
  graphOptimizationLevel: 'all',
};

// Read by isDownloading(); cleared once the session has loaded.
let downLoadingModel = true;
console.log(`downloading model ${model}`);
// Create the session only after the WASM flags above have been applied, so the
// backend is initialized with them instead of racing the feature detection.
const session = wasmConfigured.then(() => ort.InferenceSession.create(model, options));
session
  .then(() => {
    console.log("onnx model loaded");
    downLoadingModel = false;
  })
  .catch((err) => {
    // encode() awaits `session` and surfaces the failure per call; this handler
    // only prevents an unhandled rejection on this bookkeeping chain.
    console.error("failed to load onnx model", err);
  });

const bert_tokenizer = BertTokenizer.from_pretrained('/tokenizer/', { local_files_only: true });

/**
 * Reports whether the ONNX model is still loading.
 * @returns {boolean} the current value of the module-level `downLoadingModel`
 *   flag, which starts true and is cleared when the model session has loaded.
 */
function isDownloading() {
  const stillLoading = downLoadingModel;
  return stillLoading;
}

// ColBERT document-side filter tokens: ids that create_model_input drops from
// document inputs. NOTE(review): presumably the BERT uncased vocab ids of the
// ASCII punctuation characters — confirm against the tokenizer vocab.
// Written as three inclusive id ranges, expanded into one BigInt array.
const punctuations = [
  [999, 1013],
  [1024, 1036],
  [1063, 1066],
].flatMap(([first, last]) =>
  Array.from({ length: last - first + 1 }, (_, offset) => BigInt(first + offset)),
);

// Sub-word tokens that explain() never attributes a highlight score to.
// Each word maps to `true`; only membership (`has`) is consulted.
const stopWordsMap = new Map(
  [
    'a', 'all', 'an', 'and', 'are', 'as', 'at', 'be', 'by', 'for', 'from',
    'has', 'he', 'in', 'is', 'it', 'its', 'of', 'on', 'that', 'this', 'the',
    'to', 'was', 'were', 'with',
  ].map((stopWord) => [stopWord, true]),
);

// Total length of every encoded query, including the 3 framing tokens.
const max_colbert_query_length = 32;

/**
 * Frames tokenizer ids into ColBERT model inputs.
 *
 * Layout: [101, marker, ...ids, 102], where marker is 1 for queries and 2 for
 * documents (NOTE(review): presumably [CLS]/[SEP] and ColBERT's [unused0]/
 * [unused1] markers in the BERT vocab — confirm against the tokenizer).
 * Documents have every id listed in `punctuations` removed. Queries are
 * truncated so the framed sequence never exceeds max_colbert_query_length,
 * then padded up to exactly that length with id 103 and attention mask 0.
 *
 * @param {Array<bigint|number>} encoded - token ids without special tokens.
 * @param {boolean} isQuery - query framing/padding vs. document filtering.
 * @returns {{model_input: {input_ids: ort.Tensor, attention_mask: ort.Tensor},
 *   input: bigint[]}} int64 tensors of shape [1, length] plus the framed ids
 *   as a plain BigInt array (used to map document positions back to tokens).
 */
function create_model_input(encoded, isQuery) {
  const CLS_ID = 101n;
  const SEP_ID = 102n;
  const PAD_ID = 103n;
  const marker = isQuery ? 1n : 2n;
  const framingTokens = 3;

  // Normalize to BigInt so Number and BigInt ids both match `punctuations`.
  let ids = Array.from(encoded, (id) => BigInt(id));
  if (isQuery) {
    // Leave room for CLS, marker and SEP within the fixed query length.
    ids = ids.slice(0, max_colbert_query_length - framingTokens);
  } else {
    ids = ids.filter((id) => !punctuations.includes(id));
  }

  const input_ids = [CLS_ID, marker, ...ids, SEP_ID];
  const attention_mask = input_ids.map(() => 1n);

  if (isQuery) {
    while (input_ids.length < max_colbert_query_length) {
      input_ids.push(PAD_ID);
      attention_mask.push(0n);
    }
  }

  const sequence_length = input_ids.length;
  return {
    model_input: {
      input_ids: new ort.Tensor('int64', BigInt64Array.from(input_ids), [1, sequence_length]),
      attention_mask: new ort.Tensor('int64', BigInt64Array.from(attention_mask), [1, sequence_length]),
    },
    input: input_ids,
  };
}

/**
 * Tokenizes `text` and runs the ColBERT ONNX session on it.
 *
 * @param {string} text - raw query or document text.
 * @param {boolean} isQuery - forwarded to create_model_input to pick query
 *   framing/padding or document framing/filtering.
 * @returns {Promise<{model_input: bigint[], embeddings: *}>} the framed token
 *   ids fed to the model and the session's 'contextual' output.
 */
async function encode(text, isQuery) {
  const tokenizer = await bert_tokenizer;
  const tokenized = await tokenizer(text, { add_special_tokens: false });
  const tokenIds = Array.from(tokenized.input_ids.data);

  const prepared = create_model_input(tokenIds, isQuery);
  const inferenceSession = await session;
  const outputs = await inferenceSession.run(prepared.model_input, ['contextual']);

  return {
    model_input: prepared.input,
    embeddings: outputs['contextual'],
  };
}

/**
 * ColBERT MaxSim between query and document token embeddings, also recording
 * the best-matching document tokens for every query token.
 *
 * Each argument is read as a tensor-like object with `dims` and a flat
 * row-major `data` buffer; only dims[1] (token count) and dims[2] (embedding
 * size) are used, so only the first batch row is scored.
 * NOTE(review): assumes batch size 1 — confirm with callers.
 *
 * @param {{dims: number[], data: ArrayLike<number>}} queryVectors
 * @param {{dims: number[], data: ArrayLike<number>}} documentVectors
 * @param {number} [n=5] - how many top document tokens to keep per query token.
 * @returns {Promise<{querySum: number,
 *   topIndices: Array<{indices: number[], dotProducts: number[]}>}>}
 *   querySum is the sum over query tokens of their maximum dot product;
 *   topIndices[i] lists up to n document positions for query token i, best
 *   first, with equal scores kept in document order.
 */
async function maxSimWithTopIndices(queryVectors, documentVectors, n = 5) {
  const queryTokenCount = queryVectors.dims[1];
  const queryDim = queryVectors.dims[2];
  const docTokenCount = documentVectors.dims[1];
  const docDim = documentVectors.dims[2];
  const queryData = queryVectors.data;
  const docData = documentVectors.data;

  let querySum = 0;
  const topIndices = [];

  for (let i = 0; i < queryTokenCount; i++) {
    const queryStart = i * queryDim;
    let max = -Infinity;
    const topDotProducts = [];
    const topDotProductIndices = [];

    for (let j = 0; j < docTokenCount; j++) {
      const docStart = j * docDim;
      let dotProduct = 0;
      for (let k = 0; k < queryDim; k++) {
        dotProduct += queryData[queryStart + k] * docData[docStart + k];
      }

      if (dotProduct > max) {
        max = dotProduct;
      }

      // Keep the n best scores sorted descending; a new score goes after any
      // equal ones, so earlier document positions win ties.
      let position = topDotProducts.findIndex((value) => dotProduct > value);
      if (position === -1) {
        position = topDotProducts.length;
      }
      if (position < n) {
        topDotProducts.splice(position, 0, dotProduct);
        topDotProductIndices.splice(position, 0, j);
        if (topDotProducts.length > n) {
          topDotProducts.pop();
          topDotProductIndices.pop();
        }
      }
    }

    querySum += max;
    topIndices.push({
      indices: topDotProductIndices,
      dotProducts: topDotProducts,
    });
  }

  return { querySum, topIndices };
}



// Framing ids that create_model_input puts around every document (101 first,
// document marker 2, 102 last; presumably [CLS]/[unused1]/[SEP]). They are
// left out of the explanation. Compared as Numbers so BigInt and Number ids
// both match (a BigInt is never === a Number).
const skippedDocTokenIds = new Set([101, 102, 2]);

// Splits text on whitespace and at every word/non-word boundary, so runs of
// punctuation become their own "words".
const wordSplitPattern = /\s+|(?<!\w)(?=\w)|(?<=\w)(?!\w)/;

/**
 * Attributes a MaxSim result to the words of the document for highlighting.
 *
 * @param {{model_input: Array<bigint|number>}} documentVectors - encode()
 *   output for the document; model_input maps document positions to token ids.
 * @param {{topIndices: Array<{indices: number[], dotProducts: number[]}>}} result
 *   - maxSimWithTopIndices output for the same document.
 * @param {string} documentText - the original document text.
 * @param {number} [tokensPerQueryToken=4] - how many of each query token's
 *   best document tokens contribute scores.
 * @returns {Promise<Array<{word: string, max: number}>>} one entry per split
 *   word; `max` is the highest score among its sub-word tokens, or 0.
 */
async function explain(documentVectors, result, documentText, tokensPerQueryToken = 4) {
  const decoder = await bert_tokenizer;
  const tokenScoresMap = collectTokenScores(
    decoder,
    documentVectors.model_input,
    result.topIndices,
    tokensPerQueryToken,
  );
  return scoreDocumentWords(decoder, tokenScoresMap, documentText);
}

// Builds decoded-token -> [{score, doc_position}] (sorted by position), summing
// scores when several query tokens pick the same document position.
function collectTokenScores(decoder, docTokenIds, topIndices, tokensPerQueryToken) {
  const tokenScoresMap = new Map();

  for (const { indices, dotProducts } of topIndices) {
    // Top-k lists can be shorter than tokensPerQueryToken for short documents.
    const limit = Math.min(tokensPerQueryToken, indices.length);
    for (let j = 0; j < limit; j++) {
      const doc_position = indices[j];
      const score = dotProducts[j];
      const docTokenId = docTokenIds[doc_position];

      if (skippedDocTokenIds.has(Number(docTokenId))) {
        continue;
      }

      const docWord = decoder.decode_single([docTokenId], { skip_special_tokens: false });
      const entries = tokenScoresMap.get(docWord);
      if (!entries) {
        tokenScoresMap.set(docWord, [{ score, doc_position }]);
        continue;
      }
      const existing = entries.find((entry) => entry.doc_position === doc_position);
      if (existing) {
        existing.score += score;
      } else {
        entries.push({ score, doc_position });
      }
    }
  }

  for (const entries of tokenScoresMap.values()) {
    entries.sort((a, b) => a.doc_position - b.doc_position);
  }
  return tokenScoresMap;
}

// Re-tokenizes each document word and matches the n-th occurrence of a
// sub-word token to its n-th scored document position.
function scoreDocumentWords(decoder, tokenScoresMap, documentText) {
  const occurrences = new Map();

  return documentText.split(wordSplitPattern).map((word) => {
    const subWordScores = [];
    // NOTE(review): _encode_text is an underscore-prefixed (presumably
    // internal) tokenizer method — may break on library upgrades.
    for (const subWord of decoder._encode_text(word)) {
      if (stopWordsMap.has(subWord)) {
        continue;
      }
      const scores = tokenScoresMap.get(subWord);
      if (!scores) {
        continue;
      }
      const index = occurrences.has(subWord) ? occurrences.get(subWord) + 1 : 0;
      occurrences.set(subWord, index);
      if (index < scores.length) {
        subWordScores.push(scores[index].score);
      }
    }
    // Spread is required: Math.max(array) is NaN for arrays of 2+ elements.
    const max = subWordScores.length > 0 ? Math.max(...subWordScores) : 0;
    return { word, max };
  });
}


/**
 * Scores a document against a query with ColBERT and explains the score.
 *
 * Query and document are encoded one after the other; timings are logged.
 *
 * @param {string} queryText
 * @param {string} documentText
 * @returns {Promise<[number, number, Array<{word: string, max: number}>]>}
 *   [MaxSim score, 0.0, per-word highlight scores]. The second element is
 *   always 0.0 here — NOTE(review): presumably a placeholder callers expect.
 *   On any failure the error is logged and [0.0, 0.0, []] is returned, so the
 *   result always has the same three-element shape.
 */
async function lm_inference(queryText, documentText) {
  try {
    let start = performance.now();
    const queryVectors = await encode(queryText, true);
    console.log(`Query encoding duration ${performance.now() - start} milliseconds.`);

    start = performance.now();
    const docVectors = await encode(documentText, false);
    console.log(`Document encoding duration ${performance.now() - start} milliseconds.`);

    start = performance.now();
    const result = await maxSimWithTopIndices(queryVectors.embeddings, docVectors.embeddings);
    const highlights = await explain(docVectors, result, documentText);
    console.log(`MaxSims + explain took ${performance.now() - start} milliseconds.`);
    return [result.querySum, 0.0, highlights];
  } catch (e) {
    console.error(e);
    return [0.0, 0.0, []];
  }
}

// Public API: run query/document scoring, and poll model-loading state.
export const inference = lm_inference;
export const modelDownloadInProgress = isDownloading;
