#include "win.h"

#ifdef USE_TOUCH

#include <algorithm>
#include <cstdlib>
#include <exception>
#include <map>
#include <memory>
#include <new>
#include <optional>
#include <stdexcept>
#include <unordered_map>
#include <vector>
#include <Windows.h>
#include <hidsdi.h>
#include <hidpi.h>
#include <hidusage.h>

#include "win.h"

#pragma comment(lib,"hid")

// Digitizer-page HID usages (Contact Identifier 0x51, Contact Count 0x54)
// that older <hidusage.h> versions lack. Newer SDK headers may already
// define them, so each one is guarded to avoid a macro-redefinition warning.
#ifndef HID_USAGE_DIGITIZER_CONTACT_ID
#define HID_USAGE_DIGITIZER_CONTACT_ID 0x51
#endif
#ifndef HID_USAGE_DIGITIZER_CONTACT_COUNT
#define HID_USAGE_DIGITIZER_CONTACT_COUNT 0x54
#endif

namespace mswin {

  // Subscribes hWnd to raw input from HID touchpads (usage page Digitizer,
  // usage Touch Pad). RIDEV_INPUTSINK asks for the input even while the
  // window is not in the foreground, which is why hwndTarget is required.
  // Returns true when RegisterRawInputDevices succeeds.
  bool EnableTouchpadInput(HWND hWnd)
  {
    RAWINPUTDEVICE touchpad = {
      HID_USAGE_PAGE_DIGITIZER,      // usUsagePage
      HID_USAGE_DIGITIZER_TOUCH_PAD, // usUsage
      RIDEV_INPUTSINK,               // dwFlags
      hWnd                           // hwndTarget
    };
    const BOOL registered = RegisterRawInputDevices(&touchpad, 1, sizeof(touchpad));
    return registered != FALSE;
  }

  // C++ exception carrying a Win32 error code (as returned by GetLastError()).
  // The std::exception base is public: with the default (private)
  // inheritance of `class`, a `catch (std::exception&)` handler could not
  // catch this type.
  class win32_error : public std::exception
  {
  public:
    // Wraps an explicit Win32 error code.
    win32_error(DWORD errorCode) : m_errorCode(errorCode) { }
    // Captures the calling thread's last-error value; construct it right
    // after the failing API call, before anything can overwrite that value.
    win32_error() : win32_error(GetLastError()) { }

    // The wrapped Win32 error code.
    DWORD code() const { return m_errorCode; }

    const char *what() const noexcept override { return "Win32 API call failed"; }

  private:
    DWORD m_errorCode;
  };

  // C++ exception carrying the HIDP_STATUS_* code returned by a failing
  // HidP_* call. The std::exception base is public so that generic
  // `catch (std::exception&)` handlers also see it (the default private
  // inheritance of `class` hid it from them).
  class hid_error : public std::exception
  {
  public:
    // Wraps the NTSTATUS value a HidP_* function returned.
    hid_error(NTSTATUS status) : m_errorCode(status) { }

    // The wrapped HIDP_STATUS_* code.
    NTSTATUS code() const { return m_errorCode; }

    const char *what() const noexcept override { return "HID parser (HidP_*) call failed"; }

  private:
    NTSTATUS m_errorCode;
  };

  // unique_ptr deleter that releases its pointer with free(). Paired with
  // malloc() it gives owning pointers to variable-sized structures whose
  // real size exceeds sizeof(T).
  struct free_deleter
  {
    void operator()(void *block) { std::free(block); }
  };

  // Owning pointer to a malloc()-allocated T, released with free().
  template<typename T>
  using malloc_ptr = std::unique_ptr<T, free_deleter>;

  // Per-contact layout parsed from the HID report descriptor.
  // Members are value-initialized so a default-constructed instance never
  // holds indeterminate data.
  struct at_contact_info
  {
    USHORT link = 0;     // Link collection index that holds this contact's usages
    RECT touchArea = {}; // left/right: X physical min/max; top/bottom: Y physical min/max
  };

  // One touch contact read from a HID input report.
  // Members are value-initialized so a default-constructed instance never
  // holds indeterminate data.
  struct at_contact
  {
    at_contact_info info = {}; // Descriptor info of the link collection it came from
    ULONG id = 0;              // Contact ID value reported by the device
    POINT point = {};          // X/Y in physical units (read via HidP_GetScaledUsageValue)
  };

  // Per-device information parsed once from the HID report descriptor and
  // cached in g_devices, so later raw input events from the same device
  // can reuse it.
  struct at_device_info
  {
    malloc_ptr<_HIDP_PREPARSED_DATA> preparsedData; // Preparsed descriptor passed to the HidP_* calls
    USHORT linkContactCount = 0; // Link collection holding the Contact Count usage
    std::vector<at_contact_info> contactInfo; // Link collection and touch area for each contact
    //std::optional<RECT> touchAreaOverride; // Override touch area for all points if set
  };

  // Parsed descriptor info keyed by raw-input device handle, so each
  // device's HID report descriptor is only parsed once (AT_GetDeviceInfo).
  static std::unordered_map<HANDLE, at_device_info> g_devices{};

  // Whether absolute input mode is enabled (starts out false).
  static bool g_enabled{};

  // Handle of the device that most recently delivered HID raw input
  // (written by HandleTouchInput).
  static HANDLE g_lastDevice{};

  // ID of the contact currently treated as the primary touch point. It is
  // thread_local, so each thread tracks its own (see AT_GetPrimaryContact).
  static thread_local ULONG t_primaryContactID{};

  // Allocates `size` bytes with malloc() and returns them owned by a
  // malloc_ptr<T>. `size` must be at least sizeof(T); a larger size leaves
  // room for a variable-length tail.
  // Throws std::bad_alloc if malloc() returns null.
  template<typename T> static malloc_ptr<T> make_malloc(size_t size)
  {
    void *raw = malloc(size);
    if (!raw) {
      throw std::bad_alloc();
    }
    return malloc_ptr<T>(static_cast<T *>(raw));
  }

  // Fetches only the RAWINPUTHEADER (RID_HEADER) of a raw input event; it
  // carries the event type, the source device handle and the full packet size.
  // Throws win32_error if GetRawInputData fails.
  static RAWINPUTHEADER
    AT_GetRawInputHeader(HRAWINPUT hInput)
  {
    RAWINPUTHEADER header;
    UINT bytes = sizeof(header);
    const UINT copied = GetRawInputData(hInput, RID_HEADER, &header, &bytes, sizeof(RAWINPUTHEADER));
    if (copied == static_cast<UINT>(-1)) {
      throw win32_error();
    }
    return header;
  }

  // Copies the complete RAWINPUT packet (RID_INPUT) for a raw input event
  // into a malloc'd buffer of hdr.dwSize bytes, where hdr is the packet's
  // header as returned by AT_GetRawInputHeader.
  // Throws win32_error if GetRawInputData fails.
  static malloc_ptr<RAWINPUT>
    AT_GetRawInput(HRAWINPUT hInput, RAWINPUTHEADER hdr)
  {
    UINT bytes = hdr.dwSize;
    malloc_ptr<RAWINPUT> packet = make_malloc<RAWINPUT>(bytes);
    const UINT copied = GetRawInputData(hInput, RID_INPUT, packet.get(), &bytes, sizeof(RAWINPUTHEADER));
    if (copied == static_cast<UINT>(-1)) {
      throw win32_error();
    }
    return packet;
  }

  // Gets a list of raw input devices attached to the system.
  static std::vector<RAWINPUTDEVICELIST>
    AT_GetRawInputDeviceList()
  {
    std::vector<RAWINPUTDEVICELIST> devices(64);
    while (true) {
      UINT numDevices = (UINT)devices.size();
      UINT ret = GetRawInputDeviceList(&devices[0], &numDevices, sizeof(RAWINPUTDEVICELIST));
      if (ret != (UINT)-1) {
        devices.resize(ret);
        return devices;
      }
      else if (GetLastError() == ERROR_INSUFFICIENT_BUFFER) {
        devices.resize(numDevices);
      }
      else {
        throw win32_error();
      }
    }
  }

  // Queries the RID_DEVICE_INFO record (RIDI_DEVICEINFO) of a raw input device.
  // Throws win32_error if GetRawInputDeviceInfoW fails.
  static RID_DEVICE_INFO
    AT_GetRawInputDeviceInfo(HANDLE hDevice)
  {
    RID_DEVICE_INFO deviceInfo;
    UINT bytes = sizeof(RID_DEVICE_INFO);
    deviceInfo.cbSize = bytes;
    const UINT result = GetRawInputDeviceInfoW(hDevice, RIDI_DEVICEINFO, &deviceInfo, &bytes);
    if (result == static_cast<UINT>(-1)) {
      throw win32_error();
    }
    return deviceInfo;
  }

  // Reads the preparsed HID report descriptor (RIDI_PREPARSEDDATA) of a raw
  // input device, which is what the HidP_* functions take as PreparsedData.
  // The first call only asks for the required size; the second fills a
  // malloc'd buffer of that size.
  // Throws win32_error if either GetRawInputDeviceInfoW call fails.
  static malloc_ptr<_HIDP_PREPARSED_DATA>
    AT_GetHidPreparsedData(HANDLE hDevice)
  {
    const UINT failed = static_cast<UINT>(-1);
    UINT bytes = 0;
    if (GetRawInputDeviceInfoW(hDevice, RIDI_PREPARSEDDATA, nullptr, &bytes) == failed) {
      throw win32_error();
    }
    auto descriptor = make_malloc<_HIDP_PREPARSED_DATA>(bytes);
    if (GetRawInputDeviceInfoW(hDevice, RIDI_PREPARSEDDATA, descriptor.get(), &bytes) == failed) {
      throw win32_error();
    }
    return descriptor;
  }

  // Returns every input button capability declared by the given preparsed
  // HID report descriptor. The result is empty if the descriptor declares none.
  // Throws hid_error if HidP_GetCaps or HidP_GetButtonCaps fails.
  static std::vector<HIDP_BUTTON_CAPS>
    AT_GetHidInputButtonCaps(PHIDP_PREPARSED_DATA preparsedData)
  {
    HIDP_CAPS caps;
    NTSTATUS status = HidP_GetCaps(preparsedData, &caps);
    if (status != HIDP_STATUS_SUCCESS) {
      throw hid_error(status);
    }
    USHORT numCaps = caps.NumberInputButtonCaps;
    // With no caps the vector would be empty, and &buttonCaps[0] on an empty
    // vector is undefined behavior (and asserts in MSVC debug builds).
    if (numCaps == 0) {
      return {};
    }
    std::vector<HIDP_BUTTON_CAPS> buttonCaps(numCaps);
    status = HidP_GetButtonCaps(HidP_Input, buttonCaps.data(), &numCaps, preparsedData);
    if (status != HIDP_STATUS_SUCCESS) {
      throw hid_error(status);
    }
    // numCaps was updated to the number of entries actually written.
    buttonCaps.resize(numCaps);
    return buttonCaps;
  }

  // Returns every input value capability declared by the given preparsed
  // HID report descriptor. The result is empty if the descriptor declares none.
  // Throws hid_error if HidP_GetCaps or HidP_GetValueCaps fails.
  static std::vector<HIDP_VALUE_CAPS>
    AT_GetHidInputValueCaps(PHIDP_PREPARSED_DATA preparsedData)
  {
    HIDP_CAPS caps;
    NTSTATUS status = HidP_GetCaps(preparsedData, &caps);
    if (status != HIDP_STATUS_SUCCESS) {
      throw hid_error(status);
    }
    USHORT numCaps = caps.NumberInputValueCaps;
    // With no caps the vector would be empty, and &valueCaps[0] on an empty
    // vector is undefined behavior (and asserts in MSVC debug builds).
    if (numCaps == 0) {
      return {};
    }
    std::vector<HIDP_VALUE_CAPS> valueCaps(numCaps);
    status = HidP_GetValueCaps(HidP_Input, valueCaps.data(), &numCaps, preparsedData);
    if (status != HIDP_STATUS_SUCCESS) {
      throw hid_error(status);
    }
    // numCaps was updated to the number of entries actually written.
    valueCaps.resize(numCaps);
    return valueCaps;
  }

  // Reads the pressed status of a single HID report button. Returns true if
  // `usage` on `usagePage` is among the usages HidP_GetUsages reports as set
  // in link collection `linkCollection` of `report`.
  // Returns false without reading the report when HidP_MaxUsageListLength
  // reports no usages for `usagePage`.
  // Throws hid_error if HidP_GetUsages fails.
  static bool
    AT_GetHidUsageButton(
      HIDP_REPORT_TYPE reportType,
      USAGE usagePage,
      USHORT linkCollection,
      USAGE usage,
      PHIDP_PREPARSED_DATA preparsedData,
      PBYTE report,
      ULONG reportLen)
  {
    ULONG numUsages = HidP_MaxUsageListLength(
      reportType,
      usagePage,
      preparsedData);
    // No usages on this page means the button cannot be set. It also avoids
    // &usages[0] on an empty vector, which is undefined behavior.
    if (numUsages == 0) {
      return false;
    }
    std::vector<USAGE> usages(numUsages);
    NTSTATUS status = HidP_GetUsages(
      reportType,
      usagePage,
      linkCollection,
      usages.data(),
      &numUsages,
      preparsedData,
      reinterpret_cast<PCHAR>(report),
      reportLen);
    if (status != HIDP_STATUS_SUCCESS) {
      throw hid_error(status);
    }
    // numUsages now holds how many entries were actually filled in.
    usages.resize(numUsages);
    return std::find(usages.begin(), usages.end(), usage) != usages.end();
  }

  // Reads the unscaled (logical) value of `usage` on `usagePage` in link
  // collection `linkCollection` of a HID report, using HidP_GetUsageValue.
  // Throws hid_error if HidP_GetUsageValue fails.
  static ULONG
    AT_GetHidUsageLogicalValue(
      HIDP_REPORT_TYPE reportType,
      USAGE usagePage,
      USHORT linkCollection,
      USAGE usage,
      PHIDP_PREPARSED_DATA preparsedData,
      PBYTE report,
      ULONG reportLen)
  {
    ULONG logical = 0;
    const NTSTATUS status = HidP_GetUsageValue(
      reportType, usagePage, linkCollection, usage,
      &logical,
      preparsedData, reinterpret_cast<PCHAR>(report), reportLen);
    if (status == HIDP_STATUS_SUCCESS) {
      return logical;
    }
    throw hid_error(status);
  }

  // Reads the value of `usage` on `usagePage` in link collection
  // `linkCollection` of a HID report, scaled to physical units by
  // HidP_GetScaledUsageValue.
  // Throws hid_error if HidP_GetScaledUsageValue fails.
  static LONG
    AT_GetHidUsagePhysicalValue(
      HIDP_REPORT_TYPE reportType,
      USAGE usagePage,
      USHORT linkCollection,
      USAGE usage,
      PHIDP_PREPARSED_DATA preparsedData,
      PBYTE report,
      ULONG reportLen)
  {
    LONG physical = 0;
    const NTSTATUS status = HidP_GetScaledUsageValue(
      reportType, usagePage, linkCollection, usage,
      &physical,
      preparsedData, reinterpret_cast<PCHAR>(report), reportLen);
    if (status == HIDP_STATUS_SUCCESS) {
      return physical;
    }
    throw hid_error(status);
  }

  // Returns the descriptor layout for hDevice. On first use it parses the
  // device's HID report descriptor and caches the result in g_devices.
  //
  // A link collection is kept as a touch contact only if it declares a
  // Contact ID value, a Tip Switch button, and absolute X and Y values, as
  // specified by:
  // https://docs.microsoft.com/en-us/windows-hardware/design/component-guidelines/windows-precision-touchpad-required-hid-top-level-collections
  // A contact's touch area is the physical range of its X and Y values.
  //
  // Contacts are stored in ascending link-collection index. For a report
  // whose Contact Count is N, AT_GetContacts reads contactInfo[0..N), so the
  // order must be deterministic and follow the descriptor. This assumes the
  // device fills its finger collections in that order. An unordered
  // container would produce an unspecified order here.
  //
  // Throws win32_error / hid_error if the descriptor cannot be read or
  // parsed, and tool::error if it has no Contact Count usage. Nothing is
  // cached on failure, so a later call retries the parse.
  static at_device_info &
    AT_GetDeviceInfo(HANDLE hDevice)
  {
    // Single hash lookup instead of count() followed by at().
    auto cached = g_devices.find(hDevice);
    if (cached != g_devices.end()) {
      return cached->second;
    }

    at_device_info dev;
    uint_v         linkContactCount;
    dev.preparsedData = AT_GetHidPreparsedData(hDevice);

    // What has been found so far for one link collection.
    struct at_contact_info_tmp
    {
      bool hasContactID = false;
      bool hasTip = false;
      bool hasX = false;
      bool hasY = false;
      RECT touchArea = {};
    };
    // Ordered by link collection index; see the function comment.
    std::map<USHORT, at_contact_info_tmp> contacts;

    for (const HIDP_VALUE_CAPS &cap : AT_GetHidInputValueCaps(dev.preparsedData.get())) {
      // Only single, absolute usages describe contact positions and IDs.
      if (cap.IsRange || !cap.IsAbsolute) {
        continue;
      }

      if (cap.UsagePage == HID_USAGE_PAGE_GENERIC) {
        if (cap.NotRange.Usage == HID_USAGE_GENERIC_X) {
          at_contact_info_tmp &contact = contacts[cap.LinkCollection];
          contact.touchArea.left = cap.PhysicalMin;
          contact.touchArea.right = cap.PhysicalMax;
          contact.hasX = true;
        }
        else if (cap.NotRange.Usage == HID_USAGE_GENERIC_Y) {
          at_contact_info_tmp &contact = contacts[cap.LinkCollection];
          contact.touchArea.top = cap.PhysicalMin;
          contact.touchArea.bottom = cap.PhysicalMax;
          contact.hasY = true;
        }
      }
      else if (cap.UsagePage == HID_USAGE_PAGE_DIGITIZER) {
        if (cap.NotRange.Usage == HID_USAGE_DIGITIZER_CONTACT_COUNT) {
          linkContactCount = cap.LinkCollection;
        }
        else if (cap.NotRange.Usage == HID_USAGE_DIGITIZER_CONTACT_ID) {
          contacts[cap.LinkCollection].hasContactID = true;
        }
      }
    }

    for (const HIDP_BUTTON_CAPS &cap : AT_GetHidInputButtonCaps(dev.preparsedData.get())) {
      if (cap.UsagePage == HID_USAGE_PAGE_DIGITIZER &&
          cap.NotRange.Usage == HID_USAGE_DIGITIZER_TIP_SWITCH) {
        contacts[cap.LinkCollection].hasTip = true;
      }
    }

    if (linkContactCount.is_undefined()) {
      throw tool::error("No contact count usage found");
    }
    dev.linkContactCount = static_cast<USHORT>(linkContactCount.val());

    for (const auto &kvp : contacts) {
      USHORT link = kvp.first;
      const at_contact_info_tmp &info = kvp.second;
      if (info.hasContactID && info.hasTip && info.hasX && info.hasY) {
        /*dbg_printf("Contact for device %p: link=%d, touchArea={%d,%d,%d,%d}\n",
          hDevice,
          link,
          info.touchArea.left,
          info.touchArea.top,
          info.touchArea.right,
          info.touchArea.bottom);*/
        dev.contactInfo.push_back({ link, info.touchArea });
      }
    }

    // hDevice is known to be absent (checked above), so this inserts. The
    // reference stays valid because unordered_map never relocates elements.
    return g_devices.emplace(hDevice, std::move(dev)).first->second;
  }

  // Decodes the touch contacts in a raw input HID packet, using the layout
  // cached in `dev`.
  //
  // Only the first HID report in the packet is read. Its Contact Count
  // value N is clamped to the number of known contact collections and
  // selects dev.contactInfo[0..N). Contacts whose Tip Switch is off are
  // skipped. The spec is ambiguous about whether Contact Count includes
  // released contacts; it is treated here as including them, which may
  // need more testing.
  // Returns an empty vector when the packet contains no reports.
  // Throws hid_error if a usage cannot be read from the report.
  static std::vector<at_contact>
    AT_GetContacts(at_device_info &dev, RAWINPUT *input)
  {
    std::vector<at_contact> found;

    RAWHID &hid = input->data.hid;
    if (hid.dwCount == 0) {
      //dbg_printf("Raw input contained no HID events\n");
      return found;
    }

    PHIDP_PREPARSED_DATA layout = dev.preparsedData.get();
    BYTE *report = hid.bRawData;
    const DWORD reportLen = hid.dwSizeHid;

    const ULONG reported = AT_GetHidUsageLogicalValue(
      HidP_Input, HID_USAGE_PAGE_DIGITIZER, dev.linkContactCount,
      HID_USAGE_DIGITIZER_CONTACT_COUNT, layout, report, reportLen);

    // Never index past the contact collections found in the descriptor.
    const ULONG numContacts = (reported > dev.contactInfo.size())
      ? static_cast<ULONG>(dev.contactInfo.size())
      : reported;

    for (ULONG slotIndex = 0; slotIndex < numContacts; ++slotIndex) {
      const at_contact_info &slot = dev.contactInfo[slotIndex];

      const bool touching = AT_GetHidUsageButton(
        HidP_Input, HID_USAGE_PAGE_DIGITIZER, slot.link,
        HID_USAGE_DIGITIZER_TIP_SWITCH, layout, report, reportLen);
      if (!touching) {
        //dbg_printf("Contact has tip = 0, ignoring\n");
        continue;
      }

      const ULONG contactId = AT_GetHidUsageLogicalValue(
        HidP_Input, HID_USAGE_PAGE_DIGITIZER, slot.link,
        HID_USAGE_DIGITIZER_CONTACT_ID, layout, report, reportLen);
      const LONG px = AT_GetHidUsagePhysicalValue(
        HidP_Input, HID_USAGE_PAGE_GENERIC, slot.link,
        HID_USAGE_GENERIC_X, layout, report, reportLen);
      const LONG py = AT_GetHidUsagePhysicalValue(
        HidP_Input, HID_USAGE_PAGE_GENERIC, slot.link,
        HID_USAGE_GENERIC_Y, layout, report, reportLen);

      found.push_back({ slot, contactId, { px, py } });
    }

    return found;
  }

  // Picks the contact that drives the single mouse position when several
  // touches are present. If a contact with the remembered
  // t_primaryContactID is still present, it is returned. Otherwise the
  // first contact is returned and its ID becomes the new primary ID.
  // Throws std::invalid_argument if `contacts` is empty; the previous
  // contacts[0] access on an empty vector was undefined behavior.
  static at_contact
    AT_GetPrimaryContact(const std::vector<at_contact> &contacts)
  {
    if (contacts.empty()) {
      throw std::invalid_argument("AT_GetPrimaryContact: no contacts");
    }
    auto primary = std::find_if(contacts.begin(), contacts.end(),
      [](const at_contact &contact) { return contact.id == t_primaryContactID; });
    if (primary != contacts.end()) {
      return *primary;
    }
    t_primaryContactID = contacts.front().id;
    return contacts.front();
  }

  // Declared here for HandleTouchInput; the definition is not in this file.
  uint get_alts(LPARAM lParam = 0);

  // Handles a WM_INPUT event; lParam is the event's HRAWINPUT handle.
  //
  // HID packets are decoded into touch points and passed to
  // pw->handle_touch() along with get_alts(). An empty set is passed when
  // no contact is touching.
  //
  // Returns true if the event is handled entirely and should not be
  // delivered to the real WndProc. Returns false if the real WndProc should
  // be called. Only raw mouse input returns true; HID packets, other input
  // types and every failure return false.
  //
  // No exception escapes this function. It runs on behalf of a window
  // procedure, and C++ exceptions must not unwind through the Win32
  // message-dispatch code, so every failure falls back to default processing.
  bool HandleTouchInput(WPARAM wParam, LPARAM lParam, window* pw)
  {
    HRAWINPUT hInput = reinterpret_cast<HRAWINPUT>(lParam);
    try {
      RAWINPUTHEADER hdr = AT_GetRawInputHeader(hInput);
      if (hdr.dwType != RIM_TYPEHID) {
        //dbg_printf("Got raw input for device %p with event type != HID: %u\n", hdr.hDevice, hdr.dwType);

        // Suppress mouse input events to prevent it from getting
        // mixed in with our absolute movement events. Unfortunately
        // this has the side effect of disabling all non-touchpad
        // input. One solution might be to determine the device that
        // sent the event and check if it's also a touchpad, and only
        // filter out events from such devices.
        return hdr.dwType == RIM_TYPEMOUSE;
      }

      g_lastDevice = hdr.hDevice;
      //dbg_printf("Got HID raw input event for device %p\n", hdr.hDevice);

      at_device_info &dev = AT_GetDeviceInfo(hdr.hDevice);
      malloc_ptr<RAWINPUT> input = AT_GetRawInput(hInput, hdr);
      std::vector<at_contact> contacts = AT_GetContacts(dev, input.get());

      if (contacts.empty()) {
        //dbg_printf("Found no contacts in input event\n");
        pw->handle_touch(slice<window::touch_point>(), get_alts());
        return false;
      }

      array<window::touch_point> touch_points;
      touch_points.reserve(contacts.size());

      for (const at_contact &c : contacts) {
        window::touch_point tp;
        tp.id = c.id;
        tp.pos.x = c.point.x;
        tp.pos.y = c.point.y;
        touch_points.push(tp);
      }
      pw->handle_touch(touch_points(), get_alts());
    }
    catch (const hid_error &) {
      // A HidP_* call failed while parsing the descriptor or the report.
      return false;
    }
    catch (const win32_error &) {
      // A raw input API call (GetRawInputData / GetRawInputDeviceInfoW) failed.
      return false;
    }
    catch (...) {
      // Anything else, e.g. std::bad_alloc from make_malloc or tool::error
      // from AT_GetDeviceInfo; it must not propagate into the WndProc.
      return false;
    }

    return false;
  }

}

#endif