--
-- Copyright (c) 2021-2023 Zeping Lee
-- Released under the MIT license.
-- Repository: https://github.com/zepinglee/citeproc-lua
--

local core = {}

local citeproc = require("citeproc")
local bibtex2csl  -- = require("citeproc-bibtex-parser")  -- load on demand
local util = citeproc.util
require("lualibs")
local latex_parser = require("citeproc-latex-parser")


-- Pattern for locale file names; `retrieveLocale` (built in
-- `core.make_citeproc_sys`) fills `%s` with its `lang` argument.
core.locale_file_format = "csl-locales-%s.xml"
-- NOTE(review): not read or written anywhere else in the visible source;
-- presumably kept for external callers -- confirm before removing.
core.uncited_ids = {}
-- Set to true by `core.update_cited_and_uncited_ids` once a `*` entry of an
-- `@nocite` citation has been expanded to every item in `core.item_list`.
core.uncite_all_items = false

-- Database entries loaded by `core.make_citeproc_sys`: `item_list` keeps them
-- in reading order, `item_dict` maps `item.id` to the item.
core.item_list = {}
core.item_dict = {}

-- Locate a file with `kpse.find_file` and return its entire contents.
--
-- `file_name`: name passed to `kpse.find_file`.
-- `ftype`:     optional second argument of `kpse.find_file` (this file only
--              ever passes "bib").
-- `file_info`: optional description used in error messages; it is passed
--              through `util.capitalize` and defaults to "File".
--
-- Returns the contents as a string, or nil after reporting the problem with
-- `util.error` when the file cannot be found or opened.
function core.read_file(file_name, ftype, file_info)
  if file_info then
    file_info = util.capitalize(file_info)
  else
    file_info = "File"
  end
  local path = kpse.find_file(file_name, ftype)
  if not path then
    if ftype then
      -- Show the name with its extension in the message. The separator dot
      -- was previously missing, so "refs" + "bib" was reported as "refsbib".
      -- Assumes `ftype` doubles as the file extension, as "bib" does.
      local suffix = "." .. ftype
      if not util.endswith(file_name, suffix) then
        file_name = file_name .. suffix
      end
    end
    util.error(string.format('%s "%s" not found', file_info, file_name))
    return nil
  end
  local file = io.open(path, "r")
  if not file then
    util.error(string.format('Cannot open %s "%s"', file_info, path))
    return nil
  end
  local contents = file:read("*a")
  file:close()
  return contents
end


-- Load one bibliography database and return its entries as CSL items.
--
-- `data_file` may end in ".json" or ".bib". Without either extension the
-- ".json" variant is looked up first and the "bib" file type second.
--
-- Returns `file_name, csl_items`: the name used for messages (with the
-- extension that was resolved) and a list of items. The list is always a
-- table; it is empty when the file could not be found, read or decoded, so
-- callers can iterate over it unconditionally.
local function read_data_file(data_file)
  local file_name = data_file
  local extension = nil
  local contents = nil

  if util.endswith(data_file, ".json") then
    extension = ".json"
    contents = core.read_file(data_file, nil, "database file")
  elseif util.endswith(data_file, ".bib") then
    extension = ".bib"
    contents = core.read_file(data_file, "bib", "database file")
  else
    local path = kpse.find_file(data_file .. ".json")
    if path then
      file_name = data_file .. ".json"
      extension = ".json"
      contents = core.read_file(data_file .. ".json", nil, "database file")
    else
      path = kpse.find_file(data_file, "bib")
      if path then
        file_name = data_file .. ".bib"
        extension = ".bib"
        contents = core.read_file(data_file, "bib", "database file")
      else
        util.error(string.format('Cannot find database file "%s"', data_file .. ".json"))
      end
    end
  end

  local csl_items = {}

  -- Every path that leaves `contents` empty has already reported an error
  -- above (directly or inside `core.read_file`); do not hand nil to the
  -- decoders below.
  if not contents then
    return file_name, csl_items
  end

  if extension == ".json" then
    local status, res = pcall(utilities.json.tolua, contents)
    -- Only a table can be iterated as an item list.
    if status and type(res) == "table" then
      csl_items = res
      for _, item in ipairs(csl_items) do
        -- Journal abbreviations
        if item.type == "article-journal" or item.type == "article-magazine"
            or item.type == "article-newspaper" then
          util.check_journal_abbreviations(item)
        end
      end
    else
      util.error(string.format('JSON decoding error in file "%s"', data_file))
    end
  elseif extension == ".bib" then
    -- The BibTeX converter is only loaded when a ".bib" file is actually used.
    bibtex2csl = bibtex2csl or require("citeproc-bibtex2csl")
    csl_items = bibtex2csl.parse_bibtex_to_csl(contents, true, true, true, true)
  end

  return file_name, csl_items
end


-- Load every database in `data_files` and merge the entries.
--
-- Returns `item_list, item_dict`: the items in reading order and a lookup
-- table keyed by `item.id`. When an id occurs more than once the first
-- occurrence wins and a warning is emitted for each later one.
local function read_data_files(data_files)
  local item_list = {}
  local item_dict = {}
  for _, data_file in ipairs(data_files) do
    local file_name, csl_items = read_data_file(data_file)

    -- TODO: parse item_dict entries on demand
    -- A database that could not be loaded yields no items; tolerate a nil
    -- result instead of failing inside `ipairs`.
    for _, item in ipairs(csl_items or {}) do
      local id = item.id
      if id == nil then
        -- nil cannot be used as a table key; skip the entry instead of
        -- raising "table index is nil".
        util.warning(string.format('Entry without "id" in "%s" is ignored.', file_name))
      elseif item_dict[id] then
        util.warning(string.format('Duplicate entry key "%s" in "%s".', id, file_name))
      else
        item_dict[id] = item
        table.insert(item_list, item)
      end
    end
  end
  return item_list, item_dict
end



-- Read the databases in `data_files` into `core.item_list` / `core.item_dict`
-- and return the table of callbacks (`retrieveLocale`, `retrieveItem`) that
-- is handed to `citeproc.new`.
function core.make_citeproc_sys(data_files)
  core.item_list, core.item_dict = read_data_files(data_files)

  -- Contents of the locale file for `lang`, as returned by `core.read_file`.
  local function retrieve_locale(lang)
    local pattern = core.locale_file_format or "locales-%s.xml"
    return core.read_file(string.format(pattern, lang))
  end

  -- Entry stored under `id`, or nil when there is none.
  local function retrieve_item(id)
    local entry = core.item_dict[id]
    return entry
  end

  return {
    retrieveLocale = retrieve_locale,
    retrieveItem = retrieve_item,
  }
end

-- Create a citeproc engine for the style `style_name` (read from
-- "<style_name>.csl") and the databases listed in `data_files`.
--
-- A non-empty `lang` is passed on to `citeproc.new` together with a true
-- `force_lang` flag; otherwise both arguments are nil.
--
-- Returns the engine, or nil when `style_name` is empty, `data_files` is
-- empty, or the style cannot be read.
function core.init(style_name, data_files, lang)
  local nothing_to_do = style_name == "" or #data_files == 0
  if nothing_to_do then
    return nil
  end

  local style_file = style_name .. ".csl"
  local style = core.read_file(style_file, nil, "style")
  if not style then
    util.error(string.format('Failed to load style "%s.csl"', style_name))
    return nil
  end

  local force_lang = nil
  if not lang or lang == "" then
    lang = nil
  else
    force_lang = true
  end

  local sys = core.make_citeproc_sys(data_files)
  local engine = citeproc.new(sys, style, lang, force_lang)
  return engine
end

-- Split `s` into its top-level balanced `{...}` groups and return their
-- contents (outer braces removed) in order of appearance. Text outside the
-- groups is ignored.
local function parse_latex_seq(s)
  local groups = {}
  local pos = 1
  while true do
    local open_at, close_at = string.find(s, "%b{}", pos)
    if not open_at then
      break
    end
    -- Keep what lies between the outer braces.
    groups[#groups + 1] = string.sub(s, open_at + 1, close_at - 1)
    pos = close_at + 1
  end
  return groups
end

-- Parse a `key={value}` list into a table. Keys consist of alphanumerics
-- and hyphens; each value is the content of the balanced brace group after
-- `=`. The exact values "true" and "false" are stored as booleans, anything
-- else as a string.
local function parse_latex_prop(s)
  local booleans = { ["true"] = true, ["false"] = false }
  local props = {}
  for key, braced in string.gmatch(s, "([%w%-]+)%s*=%s*(%b{})") do
    local text = string.sub(braced, 2, -2)
    local flag = booleans[text]
    -- Compare against nil explicitly: `false` is a legitimate result.
    if flag ~= nil then
      props[key] = flag
    else
      props[key] = text
    end
  end
  return props
end

-- Build a citation table from the key-value string written by the LaTeX
-- side.
--
-- `citation_info`: "citationID={ITEM-1@2},citationItems={{id={ITEM-1},label={page},locator={6}}},properties={noteIndex={3}}"
--
-- In the returned table:
--   * `citationItems` is a list of tables, one per cite item; the `prefix`
--     and `suffix` fields are converted with
--     `latex_parser.latex_to_pseudo_html`;
--   * `properties.noteIndex` is always a number: 0 when it is missing or
--     empty, and 0 (after `util.error`) when it is not a string of digits.
-- A missing `citationItems` or `properties` field is treated as empty rather
-- than raising inside the parsers.
function core.make_citation(citation_info)
  local citation = parse_latex_prop(citation_info)

  -- `parse_latex_seq` needs a string; a missing field (or one that was
  -- turned into a boolean by `parse_latex_prop`) yields no cite items.
  local raw_items = citation.citationItems
  if type(raw_items) ~= "string" then
    raw_items = ""
  end
  citation.citationItems = parse_latex_seq(raw_items)

  for i, item in ipairs(citation.citationItems) do
    local citation_item = parse_latex_prop(item)
    if citation_item.prefix then
      citation_item.prefix = latex_parser.latex_to_pseudo_html(citation_item.prefix, true, false)
    end
    if citation_item.suffix then
      citation_item.suffix = latex_parser.latex_to_pseudo_html(citation_item.suffix, true, false)
    end
    citation.citationItems[i] = citation_item
  end

  local raw_properties = citation.properties
  if type(raw_properties) ~= "string" then
    raw_properties = ""
  end
  citation.properties = parse_latex_prop(raw_properties)

  local note_index = citation.properties.noteIndex
  if not note_index or note_index == "" then
    citation.properties.noteIndex = 0
  elseif type(note_index) == "string" and string.match(note_index, "^%d+$") then
    citation.properties.noteIndex = tonumber(note_index)
  else
    util.error(string.format('Invalid note index "%s".', tostring(note_index)))
    -- Do not hand the unparsable value on to later processing.
    citation.properties.noteIndex = 0
  end

  return citation
end


-- Render every citation in `citations` with `engine`.
--
-- The engine's cited and uncited item lists are refreshed first through
-- `core.update_cited_and_uncited_ids`. Citations whose `citationID` is
-- "@nocite" only contribute to that refresh and are not rendered.
--
-- Returns a table that maps each `citationID` to the value returned by
-- `engine:process_citation`.
function core.process_citations(engine, citations)
  local citation_strings = {}

  core.update_cited_and_uncited_ids(engine, citations)

  -- The former `citations_pre` accumulator was dropped: it was filled on
  -- every iteration but only ever consumed by a disabled
  -- `engine:processCitationCluster` code path.
  for _, citation in ipairs(citations) do
    if citation.citationID ~= "@nocite" then
      local citation_str = engine:process_citation(citation)
      citation_strings[citation.citationID] = citation_str
    end
  end

  return citation_strings
end

-- Collect the item ids referenced by `citations` and pass them to the engine.
--
-- Ids from ordinary citations go to `engine:updateItems`; ids from citations
-- whose `citationID` is "@nocite" go to `engine:updateUncitedItems`. Both
-- lists keep first-seen order without duplicates. A "*" id inside "@nocite"
-- stands for every entry of `core.item_list`.
--
-- NOTE(review): `core.uncite_all_items` stays true after the first "*"
-- expansion, so a later call skips the expansion although its lists start
-- out empty -- confirm that `updateUncitedItems` does not need them again.
function core.update_cited_and_uncited_ids(engine, citations)
  local cited, cited_seen = {}, {}
  local uncited, uncited_seen = {}, {}

  -- Append `id` to `list` unless `seen` already records it.
  local function remember(list, seen, id)
    if not seen[id] then
      table.insert(list, id)
      seen[id] = true
    end
  end

  for _, citation in ipairs(citations) do
    local is_nocite = citation.citationID == "@nocite"
    for _, cite_item in ipairs(citation.citationItems) do
      local id = cite_item.id
      if not is_nocite then
        -- Real citation
        remember(cited, cited_seen, id)
      elseif id ~= "*" then
        remember(uncited, uncited_seen, id)
      elseif not core.uncite_all_items then
        for _, item in ipairs(core.item_list) do
          remember(uncited, uncited_seen, item.id)
        end
        core.uncite_all_items = true
      end
    end
  end

  engine:updateItems(cited)
  engine:updateUncitedItems(uncited)
end

-- Parse a bibliography filter string.
--
-- The top level is read with `latex_parser.parse_prop`. Each of its values
-- is then split with `latex_parser.parse_seq`, and every element of the
-- resulting list is parsed with `latex_parser.parse_prop` again.
-- Returns the top-level table with its values replaced by those lists.
function core.parser_filter(filter_str)
  local filter = latex_parser.parse_prop(filter_str)
  -- Only existing keys are reassigned, which is permitted during `pairs`.
  for filter_type, raw_conditions in pairs(filter) do
    local condition_list = latex_parser.parse_seq(raw_conditions)
    filter[filter_type] = condition_list
    for index, raw_condition in ipairs(condition_list) do
      condition_list[index] = latex_parser.parse_prop(raw_condition)
    end
  end
  return filter
end

-- Produce the LaTeX source of the bibliography.
--
-- `filter_str`, when given, is parsed with `core.parser_filter` and passed
-- to `engine:makeBibliography`; otherwise the engine receives nil.
--
-- The returned string consists of
--   1. a `\cslsetup{...}` block holding the style class and those of the
--      `hangingindent`, `linespacing` and `entryspacing` parameters the
--      engine reported (in that fixed order),
--   2. `params.bibstart`, if present,
--   3. every bibliography item, each preceded by a newline,
--   4. `params.bibend`, if present, preceded by a newline.
function core.make_bibliography(engine, filter_str)
  local filter
  if filter_str then
    filter = core.parser_filter(filter_str)
  end
  local result = engine:makeBibliography(filter)

  local params = result[1]
  local bib_items = result[2]

  -- Option name / value pairs in output order. A nil or false value means
  -- the option is left out.
  local options = {
    { "class", engine:get_style_class() },
    { "hanging-indent", params["hangingindent"] },
    { "line-spacing", params["linespacing"] },
    { "entry-spacing", params["entryspacing"] },
  }

  -- Collect the pieces and join them once at the end; appending to a string
  -- inside the loops would copy the growing result on every iteration.
  local parts = { "\\cslsetup{\n" }
  for _, option in ipairs(options) do
    local name, value = option[1], option[2]
    if value then
      parts[#parts + 1] = string.format("  %s = %s,\n", name, tostring(value))
    end
  end
  parts[#parts + 1] = "}\n\n"

  if params.bibstart then
    parts[#parts + 1] = params.bibstart
  end

  for _, bib_item in ipairs(bib_items) do
    parts[#parts + 1] = "\n"
    parts[#parts + 1] = bib_item
  end

  if params.bibend then
    parts[#parts + 1] = "\n"
    parts[#parts + 1] = params.bibend
  end

  return table.concat(parts)
end


-- Attach category names to items registered in the engine.
--
-- `categories_str` is parsed with `latex_parser.parse_prop` into
-- category -> key list (each list split with `latex_parser.parse_seq`).
-- For every key found in `engine.registry.registry` the category is added to
-- the item's `categories` list unless already present; unknown keys are
-- reported with `util.error`.
function core.set_categories(engine, categories_str)
  local keys_by_category = latex_parser.parse_prop(categories_str)

  -- First pass: turn every raw key string into a list of keys.
  for category, raw_keys in pairs(keys_by_category) do
    keys_by_category[category] = latex_parser.parse_seq(raw_keys)
  end

  -- Second pass: tag the registered items.
  for category, keys in pairs(keys_by_category) do
    for _, key in ipairs(keys) do
      local item = engine.registry.registry[key]
      if not item then
        util.error(string.format("Invalid citation key '%s'.", key))
      else
        if not item.categories then
          item.categories = {}
        end
        local already_tagged = util.in_list(category, item.categories)
        if not already_tagged then
          table.insert(item.categories, category)
        end
      end
    end
  end
end


-- Export the module table.
return core
