// socket/socket.js

import {
  global,
  comboWindow,
  CHANNEL_EVENTS,
  DEFAULT_TIMEOUT,
  DEFAULT_VSN,
  SOCKET_STATES,
  TRANSPORTS,
  WS_CLOSE_NORMAL,
  AUTH_TOKEN_PREFIX,
} from "./constants"

import { closure } from "./utils"

import Ajax from "./ajax"
import Channel from "./channel"
import LongPoll from "./longpoll"
import Serializer from "./serializer"
import Timer from "./timer"

/** Initializes the Socket.
 *
 * For IE8 support use an ES5-shim (https://github.com/es-shims/es5-shim)
 *
 * @param {string} endPoint - The string WebSocket endpoint, e.g. `"ws://example.com/socket"`,
 *                                               `"wss://example.com"`
 *                                               `"/socket"` (inherited host & protocol)
 * @param {Object} [opts] - Optional configuration
 * @param {Function} [opts.transport] - The transport, for example WebSocket or LongPoll.
 *
 * Defaults to WebSocket with automatic LongPoll fallback if WebSocket is not defined.
 * To fallback to LongPoll when WebSocket attempts fail, use `longPollFallbackMs: 2500`.
 *
 * @param {number} [opts.longPollFallbackMs] - The time, in milliseconds, to attempt the primary
 * transport before falling back to the LongPoll transport. Disabled by default.
 *
 * @param {boolean} [opts.debug] - When true, enables debug logging. Default false.
 *
 * @param {Function} [opts.encode] - The function to encode outgoing messages.
 *
 * Defaults to JSON encoder.
 *
 * @param {Function} [opts.decode] - The function to decode incoming messages.
 *
 * Defaults to JSON:
 *
 * ```javascript
 * (payload, callback) => callback(JSON.parse(payload))
 * ```
 *
 * @param {number} [opts.timeout] - The default timeout in milliseconds to trigger push timeouts.
 *
 * Defaults to `DEFAULT_TIMEOUT`.
 *
 * @param {number} [opts.heartbeatIntervalMs] - The interval, in milliseconds, between heartbeat
 * messages. Defaults to 30000.
 *
 * @param {Function} [opts.reconnectAfterMs] - The optional function that returns the
 * socket reconnect interval, in milliseconds.
 *
 * Defaults to stepped backoff of:
 *
 * ```javascript
 * function(tries){
 *   return [10, 50, 100, 150, 200, 250, 500, 1000, 2000][tries - 1] || 5000
 * }
 * ```
 *
 * @param {Function} [opts.rejoinAfterMs] - The optional function that returns the rejoin
 * interval, in milliseconds, for individual channels.
 *
 * Defaults to:
 *
 * ```javascript
 * function(tries){
 *   return [1000, 2000, 5000][tries - 1] || 10000
 * }
 * ```
 *
 * @param {Function} [opts.logger] - The optional function for specialized logging, e.g.:
 *
 * ```javascript
 * function(kind, msg, data) {
 *   console.log(`${kind}: ${msg}`, data)
 * }
 * ```
 *
 * @param {number} [opts.longpollerTimeout] - The maximum timeout of a long poll AJAX request.
 *
 * Defaults to 20s (double the server long poll timer).
 *
 * @param {(Object|function)} [opts.params] - The optional params to pass when connecting
 * @param {string} [opts.authToken] - The optional authentication token to be exposed on the server
 * under the `:auth_token` connect_info key.
 * @param {string} [opts.binaryType] - The binary type to use for binary WebSocket frames.
 *
 * Defaults to "arraybuffer"
 *
 * @param {string} [opts.vsn] - The serializer's protocol version to send on connect.
 *
 * Defaults to DEFAULT_VSN.
 *
 * @param {Object} [opts.sessionStorage] - An optional `Storage`-compatible object.
 * Combo uses sessionStorage for LongPoll fallback history. Overriding the store is
 * useful when Combo won't have access to `sessionStorage`. For example, this could
 * happen if a site loads a cross-domain channel in an iframe. Example usage:
 *
 *     class InMemoryStorage {
 *       constructor() { this.storage = {} }
 *       getItem(keyName) { return this.storage[keyName] || null }
 *       removeItem(keyName) { delete this.storage[keyName] }
 *       setItem(keyName, keyValue) { this.storage[keyName] = keyValue }
 *     }
 *
 */
export default class Socket {
  constructor(endPoint, opts = {}) {
    this.stateChangeCallbacks = { open: [], close: [], error: [], message: [] }
    this.channels = []
    // Functions queued by push() while disconnected; run by flushSendBuffer().
    this.sendBuffer = []
    this.ref = 0
    // Refs of the onOpen/onError handlers registered by the latest
    // connectWithFallback() call, so a newer call can unregister them.
    this.fallbackRef = null
    this.fallbackErrorRef = null
    this.timeout = opts.timeout || DEFAULT_TIMEOUT
    this.transport = opts.transport || global.WebSocket || LongPoll
    this.primaryPassedHealthCheck = false
    this.longPollFallbackMs = opts.longPollFallbackMs
    this.fallbackTimer = null
    this.sessionStore = opts.sessionStorage || (global && global.sessionStorage)
    this.establishedConnections = 0
    this.defaultEncoder = Serializer.encode.bind(Serializer)
    this.defaultDecoder = Serializer.decode.bind(Serializer)
    // Set to true to avoid the visibility change logic from connecting if the
    // socket was never connected in the first place.
    this.closeWasClean = true
    this.disconnecting = false
    this.binaryType = opts.binaryType || "arraybuffer"
    // Bumped by transportConnect(), disconnect() and replaceTransport(); the
    // pageshow handler uses it to tell whether anything happened since pagehide.
    this.connectClock = 1
    this.pageHidden = false
    // Custom encode/decode functions are only honoured for non-LongPoll
    // transports; LongPoll always uses the default serializer.
    if (this.transport !== LongPoll) {
      this.encode = opts.encode || this.defaultEncoder
      this.decode = opts.decode || this.defaultDecoder
    } else {
      this.encode = this.defaultEncoder
      this.decode = this.defaultDecoder
    }
    let awaitingConnectionOnPageShow = null
    // All page lifecycle listeners are guarded: comboWindow may be missing
    // (e.g. outside a browser), and registering any listener unguarded would
    // make the constructor throw.
    if (comboWindow && comboWindow.addEventListener) {
      comboWindow.addEventListener("pagehide", (_e) => {
        if (this.conn) {
          this.disconnect()
          awaitingConnectionOnPageShow = this.connectClock
        }
      })
      comboWindow.addEventListener("pageshow", (_e) => {
        // Only reconnect if nothing connected/disconnected since pagehide.
        if (awaitingConnectionOnPageShow === this.connectClock) {
          awaitingConnectionOnPageShow = null
          this.connect()
        }
      })
      comboWindow.addEventListener("visibilitychange", () => {
        if (document.visibilityState === "hidden") {
          this.pageHidden = true
        } else {
          this.pageHidden = false
          // reconnect immediately if the connection was not closed gracefully
          if (!this.isConnected() && !this.closeWasClean) {
            this.teardown(() => this.connect())
          }
        }
      })
    }
    this.heartbeatIntervalMs = opts.heartbeatIntervalMs || 30000
    this.rejoinAfterMs = (tries) => {
      if (opts.rejoinAfterMs) {
        return opts.rejoinAfterMs(tries)
      } else {
        return [1000, 2000, 5000][tries - 1] || 10000
      }
    }
    this.reconnectAfterMs = (tries) => {
      if (opts.reconnectAfterMs) {
        return opts.reconnectAfterMs(tries)
      } else {
        return [10, 50, 100, 150, 200, 250, 500, 1000, 2000][tries - 1] || 5000
      }
    }
    this.logger = opts.logger || null
    if (!this.logger && opts.debug) {
      this.logger = (kind, msg, data) => {
        console.log(`${kind}: ${msg}`, data)
      }
    }
    this.longpollerTimeout = opts.longpollerTimeout || 20000
    this.params = closure(opts.params || {})
    this.endPoint = `${endPoint}/${TRANSPORTS.websocket}`
    this.vsn = opts.vsn || DEFAULT_VSN
    this.heartbeatTimeoutTimer = null
    this.heartbeatTimer = null
    this.pendingHeartbeatRef = null
    this.reconnectTimer = new Timer(() => {
      if (this.pageHidden) {
        // log() takes (kind, msg, data); the kind must come first.
        this.log("transport", "Not reconnecting as page is hidden!")
        this.teardown()
        return
      }
      this.teardown(() => this.connect())
    }, this.reconnectAfterMs)
    this.authToken = opts.authToken
  }

  /**
   * Returns the LongPoll transport reference
   */
  getLongPollTransport() {
    return LongPoll
  }

  /**
   * Disconnects and replaces the active transport.
   *
   * Marks the close as clean, cancels any pending fallback and reconnect
   * timers, closes the current connection (if any) and forgets it. The new
   * transport is used by the next transportConnect().
   *
   * @param {Function} newTransport - The new transport class to instantiate
   */
  replaceTransport(newTransport) {
    this.connectClock++
    this.closeWasClean = true
    clearTimeout(this.fallbackTimer)
    this.reconnectTimer.reset()
    if (this.conn) {
      this.conn.close()
      this.conn = null
    }
    this.transport = newTransport
  }

  /**
   * Returns the socket protocol: "wss" when the page's `location.protocol`
   * starts with "https", otherwise "ws".
   *
   * @returns {string}
   */
  protocol() {
    return location.protocol.match(/^https/) ? "wss" : "ws"
  }

  /**
   * The fully qualified socket url.
   *
   * Appends the socket params and `vsn` to the endpoint. Absolute URLs are
   * returned as-is, protocol-relative ones (`//host/...`) get `protocol()`
   * prepended, and path-only ones (`/socket`) get protocol and `location.host`.
   *
   * @returns {string}
   */
  endPointURL() {
    let uri = Ajax.appendParams(Ajax.appendParams(this.endPoint, this.params()), {
      vsn: this.vsn,
    })
    if (uri.charAt(0) !== "/") {
      return uri
    }
    if (uri.charAt(1) === "/") {
      return `${this.protocol()}:${uri}`
    }

    return `${this.protocol()}://${location.host}${uri}`
  }

  /**
   * Disconnects the socket
   *
   * See https://developer.mozilla.org/en-US/docs/Web/API/CloseEvent#Status_codes for valid status codes.
   *
   * @param {Function} callback - Optional callback which is called after socket is disconnected.
   * @param {integer} code - A status code for disconnection (Optional).
   * @param {string} reason - A textual description of the reason to disconnect. (Optional)
   */
  disconnect(callback, code, reason) {
    this.connectClock++
    this.disconnecting = true
    this.closeWasClean = true
    clearTimeout(this.fallbackTimer)
    this.reconnectTimer.reset()
    this.teardown(
      () => {
        this.disconnecting = false
        callback && callback()
      },
      code,
      reason,
    )
  }

  /**
   * Connects the socket. Does nothing if a connection already exists and is
   * not being torn down by disconnect().
   *
   * @param {Object} params - The params to send when connecting, for example `{user_id: userToken}`
   *
   * Passing params to connect is deprecated; pass them in the Socket constructor instead:
   * `new Socket("/socket", {params: {user_id: userToken}})`.
   */
  connect(params) {
    if (params) {
      console &&
        console.log(
          "passing params to connect is deprecated. Instead pass :params to the Socket constructor",
        )
      this.params = closure(params)
    }
    if (this.conn && !this.disconnecting) {
      return
    }
    if (this.longPollFallbackMs && this.transport !== LongPoll) {
      this.connectWithFallback(LongPoll, this.longPollFallbackMs)
    } else {
      this.transportConnect()
    }
  }

  /**
   * Logs the message. Override `this.logger` for specialized logging. noops by default
   * @param {string} kind
   * @param {string} msg
   * @param {Object} data
   */
  log(kind, msg, data) {
    this.logger && this.logger(kind, msg, data)
  }

  /**
   * Returns true if a logger has been set on this socket.
   */
  hasLogger() {
    return this.logger !== null
  }

  /**
   * Registers callbacks for connection open events
   *
   * @example socket.onOpen(function(){ console.info("the socket was opened") })
   *
   * @param {Function} callback
   * @returns {string} A ref that can be passed to `off()`.
   */
  onOpen(callback) {
    let ref = this.makeRef()
    this.stateChangeCallbacks.open.push([ref, callback])
    return ref
  }

  /**
   * Registers callbacks for connection close events
   * @param {Function} callback - Called with the close event.
   * @returns {string} A ref that can be passed to `off()`.
   */
  onClose(callback) {
    let ref = this.makeRef()
    this.stateChangeCallbacks.close.push([ref, callback])
    return ref
  }

  /**
   * Registers callbacks for connection error events
   *
   * @example socket.onError(function(error){ alert("An error occurred") })
   *
   * @param {Function} callback - Called with `(error, transportBefore, establishedBefore)`.
   * @returns {string} A ref that can be passed to `off()`.
   */
  onError(callback) {
    let ref = this.makeRef()
    this.stateChangeCallbacks.error.push([ref, callback])
    return ref
  }

  /**
   * Registers callbacks for connection message events
   * @param {Function} callback - Called with every decoded message.
   * @returns {string} A ref that can be passed to `off()`.
   */
  onMessage(callback) {
    let ref = this.makeRef()
    this.stateChangeCallbacks.message.push([ref, callback])
    return ref
  }

  /**
   * Pings the server and invokes the callback with the RTT in milliseconds
   * @param {Function} callback
   *
   * Returns true if the ping was pushed or false if unable to be pushed.
   */
  ping(callback) {
    if (!this.isConnected()) {
      return false
    }
    let ref = this.makeRef()
    let startTime = Date.now()
    this.push({ topic: "combo", event: "heartbeat", payload: {}, ref: ref })
    let onMsgRef = this.onMessage((msg) => {
      if (msg.ref === ref) {
        this.off([onMsgRef])
        callback(Date.now() - startTime)
      }
    })
    return true
  }

  /**
   * @private
   *
   * Returns a stable name for a transport class.
   *
   * JavaScript minification renames symbols, including transport class names,
   * but the names are used as (for example) session-storage keys. LongPoll is
   * therefore mapped explicitly; native transports such as WebSocket keep
   * their built-in `name`.
   *
   * @param {Function} transport
   * @returns {string}
   */
  transportName(transport) {
    switch (transport) {
      case LongPoll:
        return "LongPoll"
      default:
        return transport.name
    }
  }

  /**
   * @private
   *
   * Instantiates the current transport and wires its events to this socket.
   */
  transportConnect() {
    this.connectClock++
    this.closeWasClean = false
    let protocols = undefined
    // Sec-WebSocket-Protocol based token
    // (longpoll uses Authorization header instead)
    if (this.authToken) {
      protocols = ["combo", `${AUTH_TOKEN_PREFIX}${btoa(this.authToken).replace(/=/g, "")}`]
    }
    const conn = new this.transport(this.endPointURL(), protocols)
    this.conn = conn
    conn.binaryType = this.binaryType
    conn.timeout = this.longpollerTimeout
    conn.onopen = () => this.onConnOpen()
    conn.onerror = (error) => this.onConnError(error)
    conn.onmessage = (event) => this.onConnMessage(event)
    // Hand the connection itself to onConnClose so it detaches the handler of
    // the connection that closed, not whichever connection is current by then.
    conn.onclose = (event) => this.onConnClose(event, conn)
  }

  /**
   * Reads `key` from the configured session store (falsy when there is none).
   * @param {string} key
   */
  getSession(key) {
    return this.sessionStore && this.sessionStore.getItem(key)
  }

  /**
   * Writes `key` to the configured session store; no-op when there is none.
   * @param {string} key
   * @param {string} val
   */
  storeSession(key, val) {
    this.sessionStore && this.sessionStore.setItem(key, val)
  }

  /**
   * @private
   *
   * Connects with the current (primary) transport and falls back to
   * `fallbackTransport` if the primary errors before opening, or does not
   * open and answer a ping within `fallbackThreshold` ms. A fallback that
   * opens is memorized in the session store (unless the primary ever passed a
   * ping health check), so later calls go straight to the fallback.
   *
   * @param {Function} fallbackTransport - The transport class to fall back to.
   * @param {number} [fallbackThreshold=2500] - Milliseconds granted to the primary.
   */
  connectWithFallback(fallbackTransport, fallbackThreshold = 2500) {
    clearTimeout(this.fallbackTimer)
    // A new attempt supersedes the previous one. Drop the previous attempt's
    // handlers: their closures carry their own established/primaryTransport
    // state and would otherwise trigger a second, concurrent fallback.
    // (null refs never match, so off() is safe on the first call.)
    this.off([this.fallbackRef, this.fallbackErrorRef])
    this.fallbackRef = null
    this.fallbackErrorRef = null

    let established = false
    let primaryTransport = true
    let errorRef = null
    const fallbackTransportName = this.transportName(fallbackTransport)
    const sessionKey = `combo:fallback:${fallbackTransportName}`
    const fallback = (reason) => {
      this.log("transport", `falling back to ${fallbackTransportName}...`, reason)
      // Only the error handler is removed: the open handler stays registered
      // so it can memorize the fallback once the fallback transport opens.
      this.off([errorRef])
      primaryTransport = false
      this.replaceTransport(fallbackTransport)
      this.transportConnect()
    }

    if (this.getSession(sessionKey)) {
      return fallback("memorized")
    }

    this.fallbackTimer = setTimeout(fallback, fallbackThreshold)

    errorRef = this.onError((reason) => {
      this.log("transport", "error", reason)
      if (primaryTransport && !established) {
        clearTimeout(this.fallbackTimer)
        fallback(reason)
      }
    })
    this.fallbackErrorRef = errorRef
    this.fallbackRef = this.onOpen(() => {
      established = true
      if (!primaryTransport) {
        // only memorize LP if we never connected to primary
        if (!this.primaryPassedHealthCheck) {
          this.storeSession(sessionKey, "true")
        }
        return this.log("transport", `established ${fallbackTransportName} fallback`)
      }
      // if we've established primary, give the fallback a new period to attempt ping
      clearTimeout(this.fallbackTimer)
      this.fallbackTimer = setTimeout(fallback, fallbackThreshold)
      this.ping((rtt) => {
        this.log("transport", "connected to primary after", rtt)
        this.primaryPassedHealthCheck = true
        clearTimeout(this.fallbackTimer)
      })
    })
    this.transportConnect()
  }

  /**
   * Cancels both the next-heartbeat timer and the heartbeat-timeout timer.
   */
  clearHeartbeats() {
    clearTimeout(this.heartbeatTimer)
    clearTimeout(this.heartbeatTimeoutTimer)
  }

  /**
   * @private
   *
   * Handles the transport's open event: flushes buffered pushes, resets the
   * reconnect timer, restarts heartbeats and notifies open callbacks.
   */
  onConnOpen() {
    if (this.hasLogger()) {
      this.log(
        "transport",
        `${this.transportName(this.transport)} connected to ${this.endPointURL()}`,
      )
    }

    this.closeWasClean = false
    this.disconnecting = false
    this.establishedConnections++
    this.flushSendBuffer()
    this.reconnectTimer.reset()
    this.resetHeartbeat()
    this.stateChangeCallbacks.open.forEach(([, callback]) => callback())
  }

  /**
   * @private
   *
   * Fires when a heartbeat reply did not arrive in time: errors the channels,
   * tears the connection down and schedules a reconnect.
   */
  heartbeatTimeout() {
    if (this.pendingHeartbeatRef) {
      this.pendingHeartbeatRef = null
      if (this.hasLogger()) {
        this.log("transport", "heartbeat timeout. Attempting to re-establish connection")
      }
      this.triggerChanError()
      this.closeWasClean = false
      this.teardown(
        () => this.reconnectTimer.scheduleTimeout(),
        WS_CLOSE_NORMAL,
        "heartbeat timeout",
      )
    }
  }

  /**
   * Restarts the heartbeat cycle unless the connection opts out via
   * `skipHeartbeat`.
   */
  resetHeartbeat() {
    if (this.conn && this.conn.skipHeartbeat) {
      return
    }
    this.pendingHeartbeatRef = null
    this.clearHeartbeats()
    this.heartbeatTimer = setTimeout(() => this.sendHeartbeat(), this.heartbeatIntervalMs)
  }

  /**
   * Closes the current connection and then invokes `callback`.
   *
   * Waits (bounded) for the send buffer to drain, closes the connection, waits
   * (bounded) for it to report closed, and, if it is still the current
   * connection, detaches its handlers and forgets it. Invokes `callback`
   * immediately when there is no connection.
   *
   * @param {Function} [callback]
   * @param {number} [code] - Close code passed to `close()`.
   * @param {string} [reason] - Close reason passed along with `code`.
   */
  teardown(callback, code, reason) {
    if (!this.conn) {
      return callback && callback()
    }

    // If someone calls connect before we finish tearing down,
    // we create a new connection, but we still want to finish tearing down the old one.
    const connToClose = this.conn

    this.waitForBufferDone(connToClose, () => {
      if (code) {
        connToClose.close(code, reason || "")
      } else {
        connToClose.close()
      }

      this.waitForSocketClosed(connToClose, () => {
        if (this.conn === connToClose) {
          this.conn.onopen = function () {} // noop
          this.conn.onerror = function () {} // noop
          this.conn.onmessage = function () {} // noop
          this.conn.onclose = function () {} // noop
          this.conn = null
        }

        callback && callback()
      })
    })
  }

  /**
   * Invokes `callback` once `conn.bufferedAmount` is falsy, re-checking after
   * 150ms * attempt and giving up (calling `callback` anyway) on the 5th check.
   */
  waitForBufferDone(conn, callback, tries = 1) {
    if (tries === 5 || !conn.bufferedAmount) {
      callback()
      return
    }

    setTimeout(() => {
      this.waitForBufferDone(conn, callback, tries + 1)
    }, 150 * tries)
  }

  /**
   * Invokes `callback` once `conn.readyState` is closed, re-checking after
   * 150ms * attempt and giving up (calling `callback` anyway) on the 5th check.
   */
  waitForSocketClosed(conn, callback, tries = 1) {
    if (tries === 5 || conn.readyState === SOCKET_STATES.closed) {
      callback()
      return
    }

    setTimeout(() => {
      this.waitForSocketClosed(conn, callback, tries + 1)
    }, 150 * tries)
  }

  /**
   * @private
   *
   * Handles a transport close event: errors channels, stops heartbeats,
   * schedules a reconnect for unclean closes (code other than 1000) and
   * notifies close callbacks.
   *
   * @param {Object} event - The close event.
   * @param {Object} [conn=this.conn] - The connection that closed.
   */
  onConnClose(event, conn = this.conn) {
    // noop the closing connection's handler to prevent recursive calls in
    // teardown. This must target the connection that closed: after
    // replaceTransport() or a disconnect/connect race, `this.conn` may already
    // be a newer connection whose close must still be observed.
    if (conn) conn.onclose = () => {}
    let closeCode = event && event.code
    if (this.hasLogger()) this.log("transport", "close", event)
    this.triggerChanError()
    this.clearHeartbeats()
    if (!this.closeWasClean && closeCode !== 1000) {
      this.reconnectTimer.scheduleTimeout()
    }
    this.stateChangeCallbacks.close.forEach(([, callback]) => callback(event))
  }

  /**
   * @private
   *
   * Notifies error callbacks with `(error, transportBefore, establishedBefore)`.
   * Channels are errored only if no callback swapped the transport (e.g. a
   * fallback), or if a connection had been established before.
   */
  onConnError(error) {
    if (this.hasLogger()) this.log("transport", "error", error)
    let transportBefore = this.transport
    let establishedBefore = this.establishedConnections
    this.stateChangeCallbacks.error.forEach(([, callback]) => {
      callback(error, transportBefore, establishedBefore)
    })
    if (transportBefore === this.transport || establishedBefore > 0) {
      this.triggerChanError()
    }
  }

  /**
   * @private
   *
   * Triggers the error event on every channel that is not already errored,
   * leaving or closed.
   */
  triggerChanError() {
    this.channels.forEach((channel) => {
      if (!(channel.isErrored() || channel.isLeaving() || channel.isClosed())) {
        channel.trigger(CHANNEL_EVENTS.error)
      }
    })
  }

  /**
   * @returns {string} "connecting", "open", "closing" or "closed" (also when
   *   there is no connection).
   */
  connectionState() {
    switch (this.conn && this.conn.readyState) {
      case SOCKET_STATES.connecting:
        return "connecting"
      case SOCKET_STATES.open:
        return "open"
      case SOCKET_STATES.closing:
        return "closing"
      default:
        return "closed"
    }
  }

  /**
   * @returns {boolean}
   */
  isConnected() {
    return this.connectionState() === "open"
  }

  /**
   * @private
   *
   * Unregisters the channel's socket callbacks and drops it from `channels`.
   *
   * @param {Channel}
   */
  remove(channel) {
    this.off(channel.stateChangeRefs)
    this.channels = this.channels.filter((c) => c !== channel)
  }

  /**
   * Removes `onOpen`, `onClose`, `onError,` and `onMessage` registrations.
   *
   * Each callback list is replaced by a filtered copy.
   *
   * @param {refs} - list of refs returned by calls to
   *                 `onOpen`, `onClose`, `onError,` and `onMessage`
   */
  off(refs) {
    for (const key in this.stateChangeCallbacks) {
      this.stateChangeCallbacks[key] = this.stateChangeCallbacks[key].filter(
        ([ref]) => !refs.includes(ref),
      )
    }
  }

  /**
   * Initiates a new channel for the given topic
   *
   * @param {string} topic
   * @param {Object} chanParams - Parameters for the channel
   * @returns {Channel}
   */
  channel(topic, chanParams = {}) {
    let chan = new Channel(topic, chanParams, this)
    this.channels.push(chan)
    return chan
  }

  /**
   * Encodes and sends `data` when connected; otherwise queues the send until
   * flushSendBuffer() runs (on open).
   *
   * @param {Object} data
   */
  push(data) {
    if (this.hasLogger()) {
      let { topic, event, payload, ref, join_ref } = data
      this.log("push", `${topic} ${event} (${join_ref}, ${ref})`, payload)
    }

    if (this.isConnected()) {
      this.encode(data, (result) => this.conn.send(result))
    } else {
      this.sendBuffer.push(() => this.encode(data, (result) => this.conn.send(result)))
    }
  }

  /**
   * Return the next message ref, accounting for overflows: wraps to 0 once
   * incrementing no longer changes the number.
   * @returns {string}
   */
  makeRef() {
    let newRef = this.ref + 1
    if (newRef === this.ref) {
      this.ref = 0
    } else {
      this.ref = newRef
    }

    return this.ref.toString()
  }

  /**
   * Pushes a heartbeat and arms heartbeatTimeout(). Skipped when a heartbeat
   * is already pending and the socket is not connected.
   */
  sendHeartbeat() {
    if (this.pendingHeartbeatRef && !this.isConnected()) {
      return
    }
    this.pendingHeartbeatRef = this.makeRef()
    this.push({
      topic: "combo",
      event: "heartbeat",
      payload: {},
      ref: this.pendingHeartbeatRef,
    })
    this.heartbeatTimeoutTimer = setTimeout(
      () => this.heartbeatTimeout(),
      this.heartbeatIntervalMs,
    )
  }

  /**
   * Runs every queued send when connected, then empties the queue.
   */
  flushSendBuffer() {
    if (this.isConnected() && this.sendBuffer.length > 0) {
      this.sendBuffer.forEach((callback) => callback())
      this.sendBuffer = []
    }
  }

  /**
   * @private
   *
   * Decodes an incoming message, handles heartbeat replies, then dispatches it
   * to matching channels and to every onMessage callback.
   */
  onConnMessage(rawMessage) {
    this.decode(rawMessage.data, (msg) => {
      let { topic, event, payload, ref, join_ref } = msg
      if (ref && ref === this.pendingHeartbeatRef) {
        this.clearHeartbeats()
        this.pendingHeartbeatRef = null
        this.heartbeatTimer = setTimeout(() => this.sendHeartbeat(), this.heartbeatIntervalMs)
      }

      if (this.hasLogger())
        this.log(
          "receive",
          `${payload.status || ""} ${topic} ${event} ${(ref && "(" + ref + ")") || ""}`,
          payload,
        )

      // Iterate the arrays captured up front: off() and remove() replace these
      // arrays, so re-reading the properties mid-loop (e.g. when ping()'s
      // handler unregisters itself) would shift indices and skip an entry.
      const channels = this.channels
      for (const channel of channels) {
        if (channel.isMember(topic, event, payload, join_ref)) {
          channel.trigger(event, payload, ref, join_ref)
        }
      }

      const messageCallbacks = this.stateChangeCallbacks.message
      for (const [, callback] of messageCallbacks) {
        callback(msg)
      }
    })
  }

  /**
   * Leaves the first joined or joining channel subscribed to `topic`, if any.
   * @param {string} topic
   */
  leaveOpenTopic(topic) {
    let dupChannel = this.channels.find(
      (c) => c.topic === topic && (c.isJoined() || c.isJoining()),
    )
    if (dupChannel) {
      if (this.hasLogger()) this.log("transport", `leaving duplicate topic "${topic}"`)
      dupChannel.leave()
    }
  }
}