
#include "tool.h"
#include "websockets.h"

#ifdef WEBSOCKETS_SUPPORT

namespace tool { namespace async {


    // http://tools.ietf.org/html/rfc6455#section-5.2  Base Framing Protocol
    //
    //  0                   1                   2                   3
    //  0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
    // +-+-+-+-+-------+-+-------------+-------------------------------+
    // |F|R|R|R| opcode|M| Payload len |    Extended payload length    |
    // |I|S|S|S|  (4)  |A|     (7)     |             (16/64)           |
    // |N|V|V|V|       |S|             |   (if payload len==126/127)   |
    // | |1|2|3|       |K|             |                               |
    // +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
    // |     Extended payload length continued, if payload len == 127  |
    // + - - - - - - - - - - - - - - - +-------------------------------+
    // |                               |Masking-key, if MASK set to 1  |
    // +-------------------------------+-------------------------------+
    // | Masking-key (continued)       |          Payload Data         |
    // +-------------------------------- - - - - - - - - - - - - - - - +
    // :                     Payload Data continued ...                :
    // + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
    // |                     Payload Data continued ...                |
    // +---------------------------------------------------------------+

    // Feeds bytes received from the socket into the connection.
    //
    // Until the opening handshake has been accepted (is_receiving is false) the
    // data is handed to handle_initial_read(); afterwards it is appended to
    // rxbuf. Every complete frame found in rxbuf is then decoded, unmasked if
    // the MASK bit is set, and removed from rxbuf:
    //   - data frames are accumulated in rxdata until a frame with FIN set
    //     arrives, then the whole message is delivered through on_text() /
    //     on_data();
    //   - control frames (opcode bit 3 set) are handled with their own payload
    //     copy, so one arriving between the fragments of a message neither ends
    //     up in rxdata nor replaces rxopcode.
    // Returns as soon as rxbuf no longer holds a complete frame.
    void websocket_connection::handle_read(bytes data_in) {

      if (!is_receiving)
        handle_initial_read(data_in);
      else
        rxbuf.push(data_in);

      while (true) {
        wsheader_t ws = { 0 };
        if (rxbuf.size() < 2) {
          return; // need at least the 2 mandatory header bytes
        }
        const byte *data = (byte *)&rxbuf[0]; // peek, but don't consume
        ws.fin = (data[0] & 0x80) == 0x80;
        ws.opcode = (opcode_t)(data[0] & 0x0f);
        // opcode values 0x8..0xF (bit 3 set) are control frames per RFC 6455
        const bool is_control = (data[0] & 0x08) == 0x08;

        ws.mask = (data[1] & 0x80) == 0x80;
        ws.N0 = (data[1] & 0x7f);

        ws.header_size = 2;
        if (ws.N0 == 127)
          ws.header_size += 8;
        else if (ws.N0 == 126)
          ws.header_size += 2;

        if (ws.mask)
          ws.header_size += 4;

        if (rxbuf.length() < ws.header_size) {
          return; /* Need: ws.header_size - rxbuf.size() */
        }

        // Decode the payload length; i ends up as the offset of the masking key.
        int i = 0;
        if (ws.N0 < 126) {
          ws.N = ws.N0;
          i = 2;
        }
        else if (ws.N0 == 126) {
          // 16-bit big-endian extended length
          ws.N = 0;
          ws.N |= ((uint64)data[2]) << 8;
          ws.N |= ((uint64)data[3]) << 0;
          i = 4;
        }
        else if (ws.N0 == 127) {
          // 64-bit big-endian extended length
          ws.N = 0;
          ws.N |= ((uint64)data[2]) << 56;
          ws.N |= ((uint64)data[3]) << 48;
          ws.N |= ((uint64)data[4]) << 40;
          ws.N |= ((uint64)data[5]) << 32;
          ws.N |= ((uint64)data[6]) << 24;
          ws.N |= ((uint64)data[7]) << 16;
          ws.N |= ((uint64)data[8]) << 8;
          ws.N |= ((uint64)data[9]) << 0;
          i = 10;
        }
        for (int k = 0; k < 4; ++k)
          ws.masking_key[k] = ws.mask ? byte(data[i + k]) : byte(0);

        // The header is complete here, so rxbuf holds at least ws.header_size
        // bytes and the subtraction below cannot underflow. ws.N comes straight
        // from the wire (up to 64 bits): adding it to ws.header_size could wrap
        // around and let a short buffer pass as a complete frame, so compare it
        // against the bytes that remain after the header instead.
        const size_t buffered = rxbuf.length();
        if (ws.N > uint64(buffered - ws.header_size)) {
          return; /* Need: ws.header_size+ws.N - rxbuf.size() */
        }

        auto opcode = ws.opcode;

        if (ws.mask) {
          // unmask the payload in place; ws.N is bounded by the buffer size here
          for (uint k = 0; k != ws.N; ++k) {
            rxbuf[k + ws.header_size] ^= ws.masking_key[k & 0x3];
          }
        }

        const index_t frame_end = index_t(ws.header_size + ws.N);
        bytes frame_data = bytes(rxbuf(index_t(ws.header_size), frame_end));

        // frame_data points into rxbuf, so copy the payload out before the frame
        // is removed. Control payloads are kept apart from rxdata, which may
        // hold the earlier fragments of a message that is still incomplete.
        array<byte> control_data;
        if (is_control)
          control_data.push(frame_data);
        else
          rxdata.push(frame_data);
        rxbuf.remove(0, frame_end);

        if (!is_control) {
          if (opcode == CONTINUATION) {
            if (!ws.fin)
              continue;        // more fragments to come
            opcode = rxopcode; // last fragment: message type is that of the first
          }
          else {
            rxopcode = opcode; // remember the type for later CONTINUATION frames
            if (!ws.fin)
              continue;
          }
        }

        // we've got a whole message (or a control frame), now do something with it:

        switch (opcode) {
        case TEXT_FRAME: {
          array<wchar> text;
          if (u8::to_utf16(rxdata(), text, true))
            on_text(text());
          else {
            on_error(CHARS("invalid utf-8 sequence"));
            close();
          }
        } break;
        case BINARY_FRAME:
          on_data(rxdata());
          break;
        case PING:
          send_message(control_data(), PONG); // "still alive" response
          break;
        case PONG:
          // nothing to do
          break;
        case CLOSE:
          close();
          // on_close();
          break;
        default: {
          string msg = string::format(
            "unexpected websocket message, opcode %d (0x%x)", opcode, opcode);
          on_error(msg());
          close();
        } break;
        }
        // a control frame leaves any partially assembled message untouched
        if (!is_control)
          rxdata.clear();
      }
    }

    // Wraps `message` into a single final (FIN = 1) frame with opcode `code`,
    // appends the frame to txbuf and kicks off transmission via handle_write().
    // When use_mask is set, a 4-byte key drawn from rand() is written into the
    // header and the queued payload is XOR-ed with it.
    // Returns false (and queues nothing) if the connection is not live,
    // true otherwise.
    bool  websocket_connection::send_message(bytes message, opcode_t code)
    {
      if (!is_live())
        return false;

      // consider acquiring a lock on txbuf if threads involved ...

      static uint rseed = (uint)time(0);
      uint        key_bits = rand(rseed);
      const byte *masking_key = (const byte *)&key_bits;

      const uint64 message_size = message.size();

      // Payload length encoding: the 7-bit field holds the size itself, or the
      // marker 126 / 127 followed by a 16-bit / 64-bit big-endian size.
      int  extended_bytes = 0;
      byte length_field = 0;
      if (message_size < 126) {
        length_field = byte(message_size & 0xff);
      }
      else if (message_size < 65536) {
        length_field = 126;
        extended_bytes = 2;
      }
      else { // TODO: coverage testing here
        length_field = 127;
        extended_bytes = 8;
      }

      const int key_offset = 2 + extended_bytes; // masking key follows the length
      uint      header_size = uint(key_offset) + (use_mask ? 4 : 0);

      array<byte> header;
      header.push(byte(0), header_size);
      header[0] = 0x80 | byte(code);
      header[1] = length_field | (use_mask ? 0x80 : 0);

      // extended length, most significant byte first
      for (int n = 0; n < extended_bytes; ++n)
        header[2 + n] = (message_size >> (8 * (extended_bytes - 1 - n))) & 0xff;

      if (use_mask)
        for (int n = 0; n < 4; ++n)
          header[key_offset + n] = masking_key[n];

      // N.B. - txbuf will keep growing until it can be transmitted over the socket:
      txbuf.push(header());
      txbuf.push(message);

      if (use_mask) {
        // mask the payload copy that was just appended to txbuf
        auto payload = txbuf.end() - message.size();
        for (int n = 0; n < message.size(); ++n)
          payload[n] ^= masking_key[n & 0x3];
      }

      handle_write();
      return true;
    }
      
    // Called on the "send complete" event (and directly after data has been
    // queued): if txbuf holds data, it is passed to super::send() and txbuf is
    // emptied. is_sending records whether this call started a send.
    void websocket_connection::handle_write(void)
    {
      is_sending = !txbuf.is_empty();
      if (!is_sending)
        return; // nothing queued

      super::send(txbuf());
      txbuf.clear();
    }

    // Processes the server's reply to the upgrade request.
    //
    // The status line must read "HTTP/1.1 101 ..."; anything else is reported
    // through on_error() and the connection stays in the non-receiving state.
    // The response header lines are skipped without being inspected, whatever
    // follows them is stored in rxbuf, is_receiving is set and on_connected()
    // is invoked.
    void websocket_connection::handle_initial_read(bytes data)
    {
      // Cuts the first "\r\n"-terminated line off `rest` and returns it without
      // the terminator; with no terminator left, the whole remainder is returned
      // and `rest` becomes empty.
      auto take_line = [](bytes &rest) -> bytes {
        const int eol = rest.index_of(CHAR_BYTES("\r\n"));
        if (eol >= 0) {
          bytes line(rest.start, eol);
          rest.prune(eol + 2);
          return line;
        }
        bytes line = rest;
        rest.start += rest.length;
        rest.length = 0;
        return line;
      };

      string first_line = take_line(data);

      // nothing after the status line
      if (data.length == 0) {
        on_error(CHARS("data reading error"));
        return;
      }

      int        status = 0;
      const bool upgraded =
        sscanf(first_line, "HTTP/1.1 %d", &status) == 1 && status == 101;
      if (!upgraded) {
        on_error(string::format("got bad status connecting to %s: %s",
          url.src.c_str(), first_line.c_str()));
        return;
      }

      // TODO: verify response headers,
      // for now they are skipped up to and including the first empty line
      while (take_line(data).length != 0) {
      }

      // store the rest for further reading
      rxbuf.push(data);
      is_receiving = true;
      on_connected();
    }

    void websocket_connection::handle_connect(void)
    {
      txbuf.push(string::format("GET %s HTTP/1.1\r\n", url.compose_object().c_str()).chars_as_bytes());
      if (!url.port)
        txbuf.push(string::format("Host: %s\r\n", url.hostname.c_str()).chars_as_bytes());
      else
        txbuf.push(string::format("Host: %s:%d\r\n", url.hostname.c_str(), url.port).chars_as_bytes());

      txbuf.push(CHAR_BYTES("Upgrade: websocket\r\n"));
      txbuf.push(CHAR_BYTES("Connection: Upgrade\r\n"));
      txbuf.push(CHAR_BYTES("Sec-WebSocket-Key: x3JJHMbDL1EzLkh9GBhXDw==\r\n"));
      txbuf.push(CHAR_BYTES("Sec-WebSocket-Version: 13\r\n"));
      txbuf.push(CHAR_BYTES("\r\n"));

      handle_write();
    }

    // Runs the base-class close handling first, then reports the closure
    // through on_closed().
    void websocket_connection::handle_close(void) {
      super::handle_close();
      on_closed();
    }


} // namespace async
} // namespace tool

#endif
