#! /usr/bin/lua
-- $NetBSD: fmt-list,v 1.6 2022/09/08 05:05:08 rillig Exp $

--[[

Align the lines of a file list so that all lines from the same directory
have the other fields at the same indentation.

Sort the lines and remove duplicate lines.

usage: ./fmt-list [-n] */*/{mi,ad.*,md.*}

]]

-- Run the given self-test function immediately.
--
-- func: a function that takes no arguments; any error it raises is
-- propagated to the caller unchanged.
local function test(func)
  func()
end

-- Fail with a descriptive message unless the two values are equal.
--
-- got: the value that was actually produced
-- expected: the value that should have been produced
--
-- The message is only formatted when the values differ.
local function assert_equals(got, expected)
  if got == expected then
    return
  end
  local msg = ("got %q, expected %q"):format(got, expected)
  assert(false, msg)
end


-- Measure how many screen columns the given string occupies.
--
-- Tab stops are 8 columns apart, and the string is assumed to start at
-- a tab stop.  Every byte other than a tab counts as one column.
local function tabwidth(str)
  local tab = ("\t"):byte()
  local col = 0
  for pos = 1, #str do
    if str:byte(pos) ~= tab then
      col = col + 1
    else
      -- Advance to the next multiple of 8.
      col = (col // 8 + 1) * 8
    end
  end
  return col
end

test(function()
  -- Each case pairs an input string with its expected width.
  local cases = {
    {"", 0},
    {"1234", 4},
    {"\t", 8},
    {"1234567\t", 8},
    {"\t1234\t", 16},
    {"\t1234\t1", 17},
  }
  for _, case in ipairs(cases) do
    assert_equals(tabwidth(case[1]), case[2])
  end
end)


-- Build the run of tab characters that pads the string out to the
-- desired column.
--
-- str: the text that precedes the padding
-- width: the column at which the next field should start
--
-- At least one tab is always required; if the string already reaches or
-- exceeds the desired column, an error describing the string, its width
-- and the desired width is raised.
local function tabs(str, width)
  local have = tabwidth(str)
  local count = (width - have + 7) // 8
  local padding = ("\t"):rep(count)
  if padding == "" then
    error(("%q\t%d\t%d"):format(str, have, width))
  end
  return padding
end

test(function()
  -- Each case is: string, desired column, expected padding.
  local cases = {
    {"", 8, "\t"},
    {"1234567", 8, "\t"},
    {"", 64, "\t\t\t\t\t\t\t\t"},
  }
  for _, case in ipairs(cases) do
    assert_equals(tabs(case[1], case[2]), case[3])
  end
end)


-- Split the items into runs of consecutive items that share a key, and
-- call the action once per run.
--
-- items: array of items, visited in order
-- get_key: function(item) returning the item's key; must not return nil
--   or false
-- action: function(group, key), called with the array of items in the
--   run and the key they share
--
-- Only adjacent items are grouped; a key that reappears later starts a
-- new group.
local function foreach_group(items, get_key, action)
  local current_key
  local members = {}

  -- Hand the collected run to the action, unless nothing was collected.
  local function flush()
    if #members > 0 then
      action(members, current_key)
    end
  end

  for _, item in ipairs(items) do
    local item_key = assert(get_key(item))
    if item_key ~= current_key then
      flush()
      current_key = item_key
      members = {}
    end
    members[#members + 1] = item
  end
  flush()
end

test(function()
  local numbers = {
    {"prime", 2},
    {"prime", 3},
    {"not prime", 4},
    {"prime", 5},
    {"prime", 7}
  }

  local function kind_of(item)
    return item[1]
  end

  -- Record one summary line per group, in the order the groups arrive.
  local summary = {}
  local function summarize(group, key)
    summary[#summary + 1] = ("%d %s\n"):format(#group, key)
  end

  foreach_group(numbers, kind_of, summarize)
  assert_equals(table.concat(summary), "2 prime\n1 not prime\n2 prime\n")
end)


-- Split a line from a file list into its fields.
--
-- Three line formats are recognized, tried in this order:
--   [#|-]name  category  flags
--   [#|-]name  category
--   -name
-- where the name starts with a dot and the fields are separated by
-- whitespace.
--
-- Returns a table with the fields prefix, fullname, dirname, basename,
-- category_col, category, flags_col and flags; fields that are absent
-- from the line are nil.  category_col is the width of the name
-- including the whitespace after it; flags_col is the width of the
-- category including the whitespace after it.  Returns nothing if the
-- line has none of the recognized formats.
local function parse_entry(line)

  -- Most specific format first: name, category and flags.
  local name_part, prefix, fullname, category_part, category, flags =
    line:match("^(([#%-]?)(%.%S*)%s+)((%S+)%s+)(%S+)$")

  if not fullname then
    -- Name and category only.
    name_part, prefix, fullname, category =
      line:match("^(([#%-]?)(%.%S*)%s+)(%S+)$")
  end

  if not fullname then
    -- A bare name, which is only accepted with the "-" prefix.
    prefix, fullname = line:match("^(%-)(%.%S*)$")
  end

  if not fullname then
    return
  end

  local dirname, basename = fullname:match("^(.+)/([^/]+)$")
  if not dirname then
    dirname = ""
    basename = fullname
  end

  return {
    prefix = prefix,
    fullname = fullname,
    dirname = dirname,
    basename = basename,
    category_col = name_part and tabwidth(name_part),
    category = category,
    flags_col = category_part and tabwidth(category_part),
    flags = flags
  }
end

test(function()
  -- The fields are checked in this fixed order.
  local fields = {
    "prefix", "fullname", "dirname", "basename",
    "category_col", "category", "flags_col", "flags"
  }

  local function check(line, want)
    local entry = parse_entry(line)
    for _, field in ipairs(fields) do
      assert_equals(entry[field], want[field])
    end
  end

  check("./dirname/filename\t\t\tcategory\tflags", {
    prefix = "",
    fullname = "./dirname/filename",
    dirname = "./dirname",
    basename = "filename",
    category_col = 40,
    category = "category",
    flags_col = 16,
    flags = "flags"
  })

  check("#./dirname/filename\tcat\tflags", {
    prefix = "#",
    fullname = "./dirname/filename",
    dirname = "./dirname",
    basename = "filename",
    category_col = 24,
    category = "cat",
    flags_col = 8,
    flags = "flags"
  })
end)


-- Pick the smaller of two values, where either may be nil.
--
-- A nil argument is ignored in favor of the other one; if both are nil,
-- the result is nil.
local function min(curr, value)
  if curr == nil then
    return value
  end
  if value == nil then
    return curr
  end
  if value < curr then
    return value
  end
  return curr
end

test(function()
  -- Fields left out of a case are nil.
  local cases = {
    {},
    {curr = 0, want = 0},
    {value = 0, want = 0},
    {curr = 0, value = 0, want = 0},
    {curr = 1, value = -1, want = -1},
    {curr = -1, value = 1, want = -1},
  }
  for _, case in ipairs(cases) do
    assert_equals(min(case.curr, case.value), case.want)
  end
end)


-- Pick the larger of two values, where either may be nil.
--
-- A nil argument is ignored in favor of the other one; if both are nil,
-- the result is nil.
local function max(curr, value)
  if curr == nil then
    return value
  end
  if value == nil then
    return curr
  end
  if value > curr then
    return value
  end
  return curr
end

test(function()
  -- Fields left out of a case are nil.
  local cases = {
    {},
    {curr = 0, want = 0},
    {value = 0, want = 0},
    {curr = 0, value = 0, want = 0},
    {curr = 1, value = -1, want = 1},
    {curr = -1, value = 1, want = 1},
  }
  for _, case in ipairs(cases) do
    assert_equals(max(case.curr, case.value), case.want)
  end
end)


-- Determine the column at which a field of the entries should start.
--
-- entries: array of parsed entries
-- get_width_before: function(entry) returning the width of the text that
--   precedes the field, or nil if the entry does not take part
-- colname: name of the entry field holding the field's current column
--
-- Returns the column and either "aligned" or "unaligned".  If all
-- participating entries currently have the field at the same column,
-- that column is kept and reported as "aligned"; this includes the case
-- that no entry has the field at all, in which the column is nil.
-- Otherwise the result is the first tab stop after the widest preceding
-- text, reported as "unaligned".
local function column(entries, get_width_before, colname)

  local lowest, highest, widest

  for _, entry in ipairs(entries) do
    local before = get_width_before(entry)
    if before ~= nil then
      local current = entry[colname]
      widest = max(widest, before)
      lowest = min(lowest, current)
      highest = max(highest, current)
    end
  end

  if lowest ~= highest then
    -- Round up to the next tab stop, always leaving room for one tab.
    return widest // 8 * 8 + 8, "unaligned"
  end
  return lowest, "aligned"
end

test(function()

  local function name_width(entry)
    return tabwidth(entry.prefix .. entry.fullname)
  end

  local function category_width(entry)
    return tabwidth(entry.category)
  end

  local function parse_all(lines)
    local parsed = {}
    for i, line in ipairs(lines) do
      parsed[i] = parse_entry(line)
    end
    return parsed
  end

  -- Entries that are already nicely aligned are left as they are.
  local entries = parse_all {
    "./file1\tcategory",
    "./file2\tcategory"
  }
  assert_equals(entries[2].category_col, 8)
  assert_equals(name_width(entries[2]), 7)
  assert_equals(column(entries, name_width, "category_col"), 8)

  -- Entries that are not aligned with each other are moved to the
  -- minimum required column.
  entries = parse_all {
    "./file1\tcategory",
    "./directory/file2\tcategory"
  }
  assert_equals(entries[2].category_col, 24)
  assert_equals(column(entries, name_width, "category_col"), 24)

  -- An existing common alignment is preserved, even if it is further to
  -- the right than the minimum required column of 8.  There are probably
  -- reasons for the large indentation.
  entries = parse_all {
    "./file1\t\t\tcategory",
    "./file2\t\t\tcategory"
  }
  assert_equals(column(entries, name_width, "category_col"), 24)

  -- The flags are already aligned at column 32, counted from the start
  -- of the category.  The entry without flags does not disturb this.
  entries = parse_all {
    "./file1\tcategory\t\t\tflags",
    "./file2\tcategory",
    "./file3\tcat\t\t\t\tflags"
  }
  assert_equals(column(entries, category_width, "flags_col"), 32)

end)


-- Store in each entry the tabs that separate its fields.
--
-- entries: array of parsed entries that are to be aligned with each
--   other
--
-- Entries with a category get the field category_tabs, and those that
-- also have flags get the field flags_tabs.
local function add_tabs(entries)

  -- Width of the text before the category; every entry takes part.
  local function name_width(entry)
    return tabwidth(entry.prefix .. entry.fullname)
  end

  -- Width of the text before the flags; only entries with flags take
  -- part.
  local function category_width(entry)
    if entry.flags == nil then
      return nil
    end
    return tabwidth(entry.category)
  end

  local category_col, alignment =
    column(entries, name_width, "category_col")
  local flags_col = column(entries, category_width, "flags_col")

  -- To avoid horizontal jumps for the category column, the minimum column is
  -- set to 56.  This way, the third column is usually set to 72, which is
  -- still visible on an 80-column screen.
  if alignment == "unaligned" then
    category_col = max(category_col, 56)
  end

  for _, entry in ipairs(entries) do
    if entry.category ~= nil then
      local name = entry.prefix .. entry.fullname
      entry.category_tabs = tabs(name, category_col)
      if entry.flags ~= nil then
        entry.flags_tabs = tabs(entry.category, flags_col)
      end
    end
  end
end

test(function()
  -- The third entry is indented less than the others, so the category
  -- column counts as unaligned.
  local entries = {
    parse_entry("./file1\t\t\t\tcategory\t\tflags"),
    parse_entry("./file2\t\t\t\tcategory\t\tflags"),
    parse_entry("./file3\t\t\tcategory\t\tflags")
  }
  add_tabs(entries)

  for _, entry in ipairs(entries) do
    assert_equals(entry.category_tabs, "\t\t\t\t\t\t\t")
  end
  for _, entry in ipairs(entries) do
    assert_equals(entry.flags_tabs, "\t\t")
  end
end)


-- Normalize the order and the alignment of the fields of the entries.
--
-- entries: array of parsed entries; it is sorted in place, and each run
--   of consecutive entries from the same directory is then passed to
--   add_tabs.
--
-- The entries are ordered by name, then category, then flags, then
-- prefix, where a missing category or missing flags sort before any
-- present one.  This order leaves no ties between entries that differ in
-- any written field, so entries that produce identical lines always end
-- up next to each other, no matter how table.sort arranges equal
-- elements.
local function normalize(entries)

  -- Three-way comparison of two strings, either of which may be nil.
  -- Returns a negative number, zero or a positive number; nil sorts
  -- first.
  local function compare(x, y)
    if x == y then
      return 0
    elseif x == nil then
      return -1
    elseif y == nil then
      return 1
    elseif x < y then
      return -1
    end
    return 1
  end

  local function less(a, b)
    if a.fullname ~= b.fullname then
      -- To sort by directory first, comment out the following line.
      return a.fullname < b.fullname
    end
    if a.dirname ~= b.dirname then
      return a.dirname < b.dirname
    end
    if a.basename ~= b.basename then
      return a.basename < b.basename
    end

    -- Same file name: fall back to the remaining fields, so that the
    -- comparison is a consistent total order.
    local order = compare(a.category, b.category)
    if order == 0 then
      order = compare(a.flags, b.flags)
    end
    if order == 0 then
      order = compare(a.prefix, b.prefix)
    end
    return order < 0
  end
  table.sort(entries, less)

  local function by_dirname(entry)
    return entry.dirname
  end
  foreach_group(entries, by_dirname, add_tabs)

end


-- Load a file list into memory, classifying each of its lines.
--
-- fname: path of the file to read; an error is raised if it cannot be
--   opened
--
-- Returns three arrays:
--   head: the lines that start with "#" and are not entries
--   entries: the parsed entries, in file order
--   errors: one message of the form "file:line: unknown line format"
--     for each remaining line
local function read_list(fname)
  local head, entries, errors = {}, {}, {}

  local f = assert(io.open(fname, "r"))
  local lineno = 0
  for line in f:lines() do
    lineno = lineno + 1

    local entry = parse_entry(line)
    if entry ~= nil then
      entries[#entries + 1] = entry
    elseif line:match("^#") then
      head[#head + 1] = line
    else
      errors[#errors + 1] =
        ("%s:%d: unknown line format %q"):format(fname, lineno, line)
    end
  end
  f:close()

  return head, entries, errors
end


-- Write the normalized list file back to disk.
--
-- Duplicate lines are skipped.  This allows appending arbitrary lines to
-- the end of the file and having them cleaned up automatically.
--
-- fname: path of the file to overwrite
-- head: array of lines that are written verbatim before the entries
-- entries: array of entries; each needs prefix and fullname, plus
--   category_tabs if it has a category and flags_tabs if it has flags
--
-- Raises an error if the file cannot be opened, or if writing to it or
-- closing it fails.  The file is truncated when it is opened, so a
-- failed write must not go unnoticed.
local function write_list(fname, head, entries)
  local f = assert(io.open(fname, "w"))

  -- Turn a failed I/O result into an error that names the file.
  local function check(ok, err)
    if not ok then
      error(("%s: %s"):format(fname, tostring(err)), 0)
    end
  end

  for _, line in ipairs(head) do
    check(f:write(line, "\n"))
  end

  local prev_line = ""
  for _, entry in ipairs(entries) do
    local line = entry.prefix .. entry.fullname
    if entry.category ~= nil then
      line = line .. entry.category_tabs .. entry.category
    end
    if entry.flags ~= nil then
      line = line .. entry.flags_tabs .. entry.flags
    end

    -- Only adjacent duplicates are detected; they are silently dropped.
    if line ~= prev_line then
      prev_line = line
      check(f:write(line, "\n"))
    end
  end

  -- Buffered data is flushed on close, so write errors may only show up
  -- here.
  check(f:close())
end


-- Read a file list, normalize it and optionally write it back.
--
-- fname: path of the file list
-- write_back: whether to overwrite the file with the normalized list
--
-- Returns true on success.  If the file contains lines in an unknown
-- format, the messages are printed, the file is left untouched and the
-- result is false.
local function format_list(fname, write_back)
  local head, entries, errors = read_list(fname)

  local ok = #errors == 0
  if ok then
    normalize(entries)
    if write_back then
      write_list(fname, head, entries)
    end
  else
    for i = 1, #errors do
      print(errors[i])
    end
  end
  return ok
end


-- Format each of the file lists named on the command line.
--
-- arg: array of command line arguments.  The argument "-n" turns off
--   writing back for the files that follow it; every other argument is
--   taken as the name of a file list.
--
-- Returns true if all file lists were formatted successfully, false
-- otherwise.
local function main(arg)
  local all_ok = true
  local write_back = true

  for _, fname in ipairs(arg) do
    if fname ~= "-n" then
      if not format_list(fname, write_back) then
        all_ok = false
      end
    else
      write_back = false
    end
  end

  return all_ok
end

-- Exit successfully if all file lists could be formatted, and with a
-- failure status otherwise.
local success = main(arg)
os.exit(success)