import {
global,
comboWindow,
CHANNEL_EVENTS,
DEFAULT_TIMEOUT,
DEFAULT_VSN,
SOCKET_STATES,
TRANSPORTS,
WS_CLOSE_NORMAL,
AUTH_TOKEN_PREFIX,
} from './constants'
import { closure } from './utils'
import Ajax from './ajax'
import Channel from './channel'
import LongPoll from './longpoll'
import Serializer from './serializer'
import Timer from './timer'
/** Initializes the Socket
*
* For IE8 support use an ES5-shim (https://github.com/es-shims/es5-shim)
*
* @param {string} endPoint - The string WebSocket endpoint, ie, `"ws://example.com/socket"`,
* `"wss://example.com"`
* `"/socket"` (inherited host & protocol)
* @param {Object} [opts] - Optional configuration
* @param {Function} [opts.transport] - The transport, for example WebSocket or LongPoll.
*
* Defaults to WebSocket with automatic LongPoll fallback if WebSocket is not defined.
* To fallback to LongPoll when WebSocket attempts fail, use `longPollFallbackMs: 2500`.
*
* @param {number} [opts.longPollFallbackMs] - The millisecond time to attempt the primary transport
* before falling back to the LongPoll transport. Disabled by default.
*
* @param {boolean} [opts.debug] - When true, enables debug logging. Default false.
*
* @param {Function} [opts.encode] - The function to encode outgoing messages.
*
* Defaults to JSON encoder.
*
* @param {Function} [opts.decode] - The function to decode incoming messages.
*
* Defaults to JSON:
*
* ```javascript
* (payload, callback) => callback(JSON.parse(payload))
* ```
*
* @param {number} [opts.timeout] - The default timeout in milliseconds to trigger push timeouts.
*
* Defaults to `DEFAULT_TIMEOUT`
* @param {number} [opts.heartbeatIntervalMs] - The millisec interval to send a heartbeat message
* @param {Function} [opts.reconnectAfterMs] - The optional function that returns the
* socket reconnect interval, in milliseconds.
*
* Defaults to stepped backoff of:
*
* ```javascript
* function(tries){
* return [10, 50, 100, 150, 200, 250, 500, 1000, 2000][tries - 1] || 5000
* }
* ```
*
* @param {Function} [opts.rejoinAfterMs] - The optional function that returns the millisec
* rejoin interval for individual channels.
*
* ```javascript
* function(tries){
* return [1000, 2000, 5000][tries - 1] || 10000
* }
* ```
*
* @param {Function} [opts.logger] - The optional function for specialized logging, ie:
*
* ```javascript
* function(kind, msg, data) {
* console.log(`${kind}: ${msg}`, data)
* }
* ```
*
* @param {number} [opts.longpollerTimeout] - The maximum timeout of a long poll AJAX request.
*
* Defaults to 20s (double the server long poll timer).
*
* @param {(Object|function)} [opts.params] - The optional params to pass when connecting
* @param {string} [opts.authToken] - the optional authentication token to be exposed on the server
* under the `:auth_token` connect_info key.
* @param {string} [opts.binaryType] - The binary type to use for binary WebSocket frames.
*
* Defaults to "arraybuffer"
*
* @param {vsn} [opts.vsn] - The serializer's protocol version to send on connect.
*
* Defaults to DEFAULT_VSN.
*
* @param {Object} [opts.sessionStorage] - An optional Storage compatible object
* Combo uses sessionStorage for longpoll fallback history. Overriding the store is
* useful when Combo won't have access to `sessionStorage`. For example, this could
* happen if a site loads a cross-domain channel in an iframe. Example usage:
*
* class InMemoryStorage {
* constructor() { this.storage = {} }
* getItem(keyName) { return this.storage[keyName] || null }
* removeItem(keyName) { delete this.storage[keyName] }
* setItem(keyName, keyValue) { this.storage[keyName] = keyValue }
* }
*
*/
export default class Socket {
  constructor(endPoint, opts = {}) {
    // Lifecycle callbacks are stored as [ref, callback] pairs so `off` can remove them by ref.
    this.stateChangeCallbacks = { open: [], close: [], error: [], message: [] }
    this.channels = []
    // Thunks queued by `push` while the connection is not open; run by `flushSendBuffer`.
    this.sendBuffer = []
    this.ref = 0
    this.timeout = opts.timeout || DEFAULT_TIMEOUT
    this.transport = opts.transport || global.WebSocket || LongPoll
    this.primaryPassedHealthCheck = false
    this.longPollFallbackMs = opts.longPollFallbackMs
    this.fallbackTimer = null
    // Ref of the `onOpen` handler installed by the latest `connectWithFallback` call.
    this.fallbackRef = null
    this.sessionStore = opts.sessionStorage || (global && global.sessionStorage)
    this.establishedConnections = 0
    this.defaultEncoder = Serializer.encode.bind(Serializer)
    this.defaultDecoder = Serializer.decode.bind(Serializer)
    this.closeWasClean = false
    this.disconnecting = false
    this.binaryType = opts.binaryType || 'arraybuffer'
    // Bumped on every connect/disconnect/transport swap; async teardown steps compare
    // against the value they captured and bail out when it has changed.
    this.connectClock = 1
    // Custom encode/decode options are only honored when the transport is not LongPoll.
    if (this.transport !== LongPoll) {
      this.encode = opts.encode || this.defaultEncoder
      this.decode = opts.decode || this.defaultDecoder
    } else {
      this.encode = this.defaultEncoder
      this.decode = this.defaultDecoder
    }
    // Disconnect on `pagehide` and reconnect on `pageshow`, but only if nothing else
    // touched the connection (i.e. changed `connectClock`) in between.
    let awaitingConnectionOnPageShow = null
    if (comboWindow && comboWindow.addEventListener) {
      comboWindow.addEventListener('pagehide', (_e) => {
        if (this.conn) {
          this.disconnect()
          awaitingConnectionOnPageShow = this.connectClock
        }
      })
      comboWindow.addEventListener('pageshow', (_e) => {
        if (awaitingConnectionOnPageShow === this.connectClock) {
          awaitingConnectionOnPageShow = null
          this.connect()
        }
      })
    }
    this.heartbeatIntervalMs = opts.heartbeatIntervalMs || 30000
    this.rejoinAfterMs = (tries) => {
      if (opts.rejoinAfterMs) {
        return opts.rejoinAfterMs(tries)
      } else {
        return [1000, 2000, 5000][tries - 1] || 10000
      }
    }
    this.reconnectAfterMs = (tries) => {
      if (opts.reconnectAfterMs) {
        return opts.reconnectAfterMs(tries)
      } else {
        return [10, 50, 100, 150, 200, 250, 500, 1000, 2000][tries - 1] || 5000
      }
    }
    this.logger = opts.logger || null
    if (!this.logger && opts.debug) {
      this.logger = (kind, msg, data) => {
        console.log(`${kind}: ${msg}`, data)
      }
    }
    this.longpollerTimeout = opts.longpollerTimeout || 20000
    this.params = closure(opts.params || {})
    this.endPoint = `${endPoint}/${TRANSPORTS.websocket}`
    this.vsn = opts.vsn || DEFAULT_VSN
    this.heartbeatTimeoutTimer = null
    this.heartbeatTimer = null
    this.pendingHeartbeatRef = null
    // On each firing: tear down whatever connection exists, then connect again.
    this.reconnectTimer = new Timer(() => {
      this.teardown(() => this.connect())
    }, this.reconnectAfterMs)
    this.authToken = opts.authToken
  }

  /**
   * Returns the LongPoll transport reference
   */
  getLongPollTransport() {
    return LongPoll
  }

  /**
   * Disconnects and replaces the active transport
   *
   * @param {Function} newTransport - The new transport class to instantiate
   *
   */
  replaceTransport(newTransport) {
    this.connectClock++
    this.closeWasClean = true
    clearTimeout(this.fallbackTimer)
    this.reconnectTimer.reset()
    if (this.conn) {
      this.conn.close()
      this.conn = null
    }
    this.transport = newTransport
  }

  /**
   * Returns the socket protocol
   *
   * @returns {string} `'wss'` when the page was loaded over https, `'ws'` otherwise
   */
  protocol() {
    return location.protocol.match(/^https/) ? 'wss' : 'ws'
  }

  /**
   * The fully qualified socket url
   *
   * The connect params and `vsn` are appended to the endpoint. An endpoint that does not
   * start with `/` is returned as is; one starting with `//` only gets the protocol
   * prepended; any other `/path` gets the protocol and `location.host` prepended.
   *
   * @returns {string}
   */
  endPointURL() {
    let uri = Ajax.appendParams(Ajax.appendParams(this.endPoint, this.params()), {
      vsn: this.vsn,
    })
    if (uri.charAt(0) !== '/') {
      return uri
    }
    if (uri.charAt(1) === '/') {
      return `${this.protocol()}:${uri}`
    }
    return `${this.protocol()}://${location.host}${uri}`
  }

  /**
   * Disconnects the socket
   *
   * See https://developer.mozilla.org/en-US/docs/Web/API/CloseEvent#Status_codes for valid status codes.
   *
   * @param {Function} callback - Optional callback which is called after socket is disconnected.
   * @param {integer} code - A status code for disconnection (Optional).
   * @param {string} reason - A textual description of the reason to disconnect. (Optional)
   */
  disconnect(callback, code, reason) {
    this.connectClock++
    this.disconnecting = true
    this.closeWasClean = true
    clearTimeout(this.fallbackTimer)
    this.reconnectTimer.reset()
    this.teardown(
      () => {
        this.disconnecting = false
        callback && callback()
      },
      code,
      reason,
    )
  }

  /**
   * Opens the connection. Does nothing when a connection object already exists,
   * unless a disconnect is currently in progress.
   *
   * @param {Object} params - The params to send when connecting, for example `{user_id: userToken}`
   *
   * Passing params to connect is deprecated; pass them in the Socket constructor instead:
   * `new Socket("/socket", {params: {user_id: userToken}})`.
   */
  connect(params) {
    if (params) {
      console
        && console.log(
          'passing params to connect is deprecated. Instead pass :params to the Socket constructor',
        )
      this.params = closure(params)
    }
    if (this.conn && !this.disconnecting) {
      return
    }
    if (this.longPollFallbackMs && this.transport !== LongPoll) {
      this.connectWithFallback(LongPoll, this.longPollFallbackMs)
    } else {
      this.transportConnect()
    }
  }

  /**
   * Logs the message. Override `this.logger` for specialized logging. noops by default
   * @param {string} kind
   * @param {string} msg
   * @param {Object} data
   */
  log(kind, msg, data) {
    this.logger && this.logger(kind, msg, data)
  }

  /**
   * Returns true if a logger has been set on this socket.
   */
  hasLogger() {
    return this.logger !== null
  }

  /**
   * Registers callbacks for connection open events
   *
   * @example socket.onOpen(function(){ console.info("the socket was opened") })
   *
   * @param {Function} callback
   * @returns {string} ref that can be passed to `off` to remove the callback
   */
  onOpen(callback) {
    let ref = this.makeRef()
    this.stateChangeCallbacks.open.push([ref, callback])
    return ref
  }

  /**
   * Registers callbacks for connection close events
   * @param {Function} callback
   * @returns {string} ref that can be passed to `off` to remove the callback
   */
  onClose(callback) {
    let ref = this.makeRef()
    this.stateChangeCallbacks.close.push([ref, callback])
    return ref
  }

  /**
   * Registers callbacks for connection error events
   *
   * @example socket.onError(function(error){ alert("An error occurred") })
   *
   * @param {Function} callback
   * @returns {string} ref that can be passed to `off` to remove the callback
   */
  onError(callback) {
    let ref = this.makeRef()
    this.stateChangeCallbacks.error.push([ref, callback])
    return ref
  }

  /**
   * Registers callbacks for connection message events
   * @param {Function} callback
   * @returns {string} ref that can be passed to `off` to remove the callback
   */
  onMessage(callback) {
    let ref = this.makeRef()
    this.stateChangeCallbacks.message.push([ref, callback])
    return ref
  }

  /**
   * Pings the server and invokes the callback with the RTT in milliseconds
   * @param {Function} callback
   *
   * Returns true if the ping was pushed or false if unable to be pushed.
   */
  ping(callback) {
    if (!this.isConnected()) {
      return false
    }
    let ref = this.makeRef()
    let startTime = Date.now()
    this.push({ topic: 'combo', event: 'heartbeat', payload: {}, ref: ref })
    // One-shot listener: removes itself once the message carrying our ref arrives.
    let onMsgRef = this.onMessage((msg) => {
      if (msg.ref === ref) {
        this.off([onMsgRef])
        callback(Date.now() - startTime)
      }
    })
    return true
  }

  /**
   * Instantiates the current transport and wires its event handlers to this socket.
   *
   * @private
   */
  transportConnect() {
    this.connectClock++
    this.closeWasClean = false
    let protocols = undefined
    // Sec-WebSocket-Protocol based token
    // (longpoll uses Authorization header instead)
    if (this.authToken) {
      protocols = ['combo', `${AUTH_TOKEN_PREFIX}${btoa(this.authToken).replace(/=/g, '')}`]
    }
    this.conn = new this.transport(this.endPointURL(), protocols)
    this.conn.binaryType = this.binaryType
    this.conn.timeout = this.longpollerTimeout
    this.conn.onopen = () => this.onConnOpen()
    this.conn.onerror = error => this.onConnError(error)
    this.conn.onmessage = event => this.onConnMessage(event)
    this.conn.onclose = event => this.onConnClose(event)
  }

  /**
   * Reads a value from the session store, if one is available.
   *
   * @param {string} key
   * @returns {*} the stored value, or a falsy value when there is no store
   */
  getSession(key) {
    return this.sessionStore && this.sessionStore.getItem(key)
  }

  /**
   * Writes a value to the session store, if one is available.
   *
   * @param {string} key
   * @param {string} val
   */
  storeSession(key, val) {
    this.sessionStore && this.sessionStore.setItem(key, val)
  }

  /**
   * Connects with the current (primary) transport and switches to `fallbackTransport`
   * when the primary errors before opening, does not open within `fallbackThreshold`
   * ms, or opens but does not answer a ping within a further `fallbackThreshold` ms.
   *
   * When the session store already holds a fallback marker for this transport, the
   * fallback is used immediately.
   *
   * @param {Function} fallbackTransport - The transport class to fall back to
   * @param {number} [fallbackThreshold=2500] - Milliseconds to wait before falling back
   */
  connectWithFallback(fallbackTransport, fallbackThreshold = 2500) {
    clearTimeout(this.fallbackTimer)
    let established = false
    let primaryTransport = true
    let errorRef
    let fallback = (reason) => {
      this.log('transport', `falling back to ${fallbackTransport.name}...`, reason)
      // Only the error handler is dropped here. The open handler must stay registered
      // so it can record and log the fallback once the fallback transport opens.
      this.off([errorRef])
      primaryTransport = false
      this.replaceTransport(fallbackTransport)
      this.transportConnect()
    }
    if (this.getSession(`combo:fallback:${fallbackTransport.name}`)) {
      return fallback('memorized')
    }
    this.fallbackTimer = setTimeout(fallback, fallbackThreshold)
    errorRef = this.onError((reason) => {
      this.log('transport', 'error', reason)
      if (primaryTransport && !established) {
        clearTimeout(this.fallbackTimer)
        fallback(reason)
      }
    })
    // Replace the open handler of any previous attempt; otherwise every reconnect
    // would add another handler, each re-arming the fallback timer and pinging.
    if (this.fallbackRef) {
      this.off([this.fallbackRef])
    }
    this.fallbackRef = this.onOpen(() => {
      established = true
      if (!primaryTransport) {
        // only memorize LP if we never connected to primary
        if (!this.primaryPassedHealthCheck) {
          this.storeSession(`combo:fallback:${fallbackTransport.name}`, 'true')
        }
        return this.log('transport', `established ${fallbackTransport.name} fallback`)
      }
      // if we've established primary, give the fallback a new period to attempt ping
      clearTimeout(this.fallbackTimer)
      this.fallbackTimer = setTimeout(fallback, fallbackThreshold)
      this.ping((rtt) => {
        this.log('transport', 'connected to primary after', rtt)
        this.primaryPassedHealthCheck = true
        clearTimeout(this.fallbackTimer)
      })
    })
    this.transportConnect()
  }

  /**
   * Cancels both the pending heartbeat send and the pending heartbeat timeout.
   */
  clearHeartbeats() {
    clearTimeout(this.heartbeatTimer)
    clearTimeout(this.heartbeatTimeoutTimer)
  }

  /**
   * Transport `onopen` handler: flushes buffered pushes, resets the reconnect and
   * heartbeat timers, then notifies the registered `onOpen` callbacks.
   */
  onConnOpen() {
    if (this.hasLogger()) {
      this.log('transport', `${this.transport.name} connected to ${this.endPointURL()}`)
    }
    this.closeWasClean = false
    this.disconnecting = false
    this.establishedConnections++
    this.flushSendBuffer()
    this.reconnectTimer.reset()
    this.resetHeartbeat()
    this.stateChangeCallbacks.open.forEach(([, callback]) => callback())
  }

  /**
   * Runs when a heartbeat was not answered in time: errors the channels, tears the
   * connection down and schedules a reconnect.
   *
   * @private
   */
  heartbeatTimeout() {
    if (this.pendingHeartbeatRef) {
      this.pendingHeartbeatRef = null
      if (this.hasLogger()) {
        this.log('transport', 'heartbeat timeout. Attempting to re-establish connection')
      }
      this.triggerChanError()
      this.closeWasClean = false
      this.teardown(
        () => this.reconnectTimer.scheduleTimeout(),
        WS_CLOSE_NORMAL,
        'heartbeat timeout',
      )
    }
  }

  /**
   * Restarts the heartbeat cycle, unless the connection sets `skipHeartbeat`.
   */
  resetHeartbeat() {
    if (this.conn && this.conn.skipHeartbeat) {
      return
    }
    this.pendingHeartbeatRef = null
    this.clearHeartbeats()
    this.heartbeatTimer = setTimeout(() => this.sendHeartbeat(), this.heartbeatIntervalMs)
  }

  /**
   * Closes the current connection, waiting first for its send buffer to drain and then
   * for the socket to report closed, detaches its handlers and finally invokes `callback`.
   *
   * If `connectClock` changes while waiting, the teardown is abandoned and `callback`
   * is NOT invoked. With no connection, `callback` is invoked immediately.
   *
   * @param {Function} [callback] - Invoked once the connection has been torn down
   * @param {integer} [code] - Close status code passed to the transport
   * @param {string} [reason] - Close reason passed to the transport
   */
  teardown(callback, code, reason) {
    if (!this.conn) {
      return callback && callback()
    }
    let connectClock = this.connectClock
    this.waitForBufferDone(() => {
      if (connectClock !== this.connectClock) {
        return
      }
      if (this.conn) {
        if (code) {
          this.conn.close(code, reason || '')
        } else {
          this.conn.close()
        }
      }
      this.waitForSocketClosed(() => {
        if (connectClock !== this.connectClock) {
          return
        }
        if (this.conn) {
          this.conn.onopen = function () {} // noop
          this.conn.onerror = function () {} // noop
          this.conn.onmessage = function () {} // noop
          this.conn.onclose = function () {} // noop
          this.conn = null
        }
        callback && callback()
      })
    })
  }

  /**
   * Invokes `callback` once the connection has no buffered outgoing data, there is no
   * connection, or the fifth try is reached. Retries are spaced `150 * tries` ms apart.
   *
   * @param {Function} callback
   * @param {number} [tries=1]
   */
  waitForBufferDone(callback, tries = 1) {
    if (tries === 5 || !this.conn || !this.conn.bufferedAmount) {
      callback()
      return
    }
    setTimeout(() => {
      this.waitForBufferDone(callback, tries + 1)
    }, 150 * tries)
  }

  /**
   * Invokes `callback` once the connection reports the closed ready state, there is no
   * connection, or the fifth try is reached. Retries are spaced `150 * tries` ms apart.
   *
   * @param {Function} callback
   * @param {number} [tries=1]
   */
  waitForSocketClosed(callback, tries = 1) {
    if (tries === 5 || !this.conn || this.conn.readyState === SOCKET_STATES.closed) {
      callback()
      return
    }
    setTimeout(() => {
      this.waitForSocketClosed(callback, tries + 1)
    }, 150 * tries)
  }

  /**
   * Transport `onclose` handler: errors the channels, stops heartbeats, schedules a
   * reconnect unless the close was clean or used code 1000, then notifies the
   * registered `onClose` callbacks.
   *
   * @param {Object} event - The close event from the transport
   */
  onConnClose(event) {
    let closeCode = event && event.code
    if (this.hasLogger()) {
      this.log('transport', 'close', event)
    }
    this.triggerChanError()
    this.clearHeartbeats()
    if (!this.closeWasClean && closeCode !== 1000) {
      this.reconnectTimer.scheduleTimeout()
    }
    this.stateChangeCallbacks.close.forEach(([, callback]) => callback(event))
  }

  /**
   * @private
   */
  onConnError(error) {
    if (this.hasLogger()) {
      this.log('transport', error)
    }
    // Snapshot first: an error callback may swap the transport (see connectWithFallback).
    let transportBefore = this.transport
    let establishedBefore = this.establishedConnections
    this.stateChangeCallbacks.error.forEach(([, callback]) => {
      callback(error, transportBefore, establishedBefore)
    })
    if (transportBefore === this.transport || establishedBefore > 0) {
      this.triggerChanError()
    }
  }

  /**
   * Triggers the error event on every channel that is not already errored, leaving or closed.
   *
   * @private
   */
  triggerChanError() {
    this.channels.forEach((channel) => {
      if (!(channel.isErrored() || channel.isLeaving() || channel.isClosed())) {
        channel.trigger(CHANNEL_EVENTS.error)
      }
    })
  }

  /**
   * @returns {string} one of `'connecting'`, `'open'`, `'closing'` or `'closed'`
   */
  connectionState() {
    switch (this.conn && this.conn.readyState) {
      case SOCKET_STATES.connecting:
        return 'connecting'
      case SOCKET_STATES.open:
        return 'open'
      case SOCKET_STATES.closing:
        return 'closing'
      default:
        return 'closed'
    }
  }

  /**
   * @returns {boolean}
   */
  isConnected() {
    return this.connectionState() === 'open'
  }

  /**
   * Unregisters the channel's state-change callbacks and drops it from the channel list.
   *
   * @private
   *
   * @param {Channel}
   */
  remove(channel) {
    this.off(channel.stateChangeRefs)
    this.channels = this.channels.filter(c => c !== channel)
  }

  /**
   * Removes `onOpen`, `onClose`, `onError,` and `onMessage` registrations.
   *
   * @param {refs} - list of refs returned by calls to
   * `onOpen`, `onClose`, `onError,` and `onMessage`
   */
  off(refs) {
    for (let key in this.stateChangeCallbacks) {
      this.stateChangeCallbacks[key] = this.stateChangeCallbacks[key].filter(([ref]) => {
        return refs.indexOf(ref) === -1
      })
    }
  }

  /**
   * Initiates a new channel for the given topic
   *
   * @param {string} topic
   * @param {Object} chanParams - Parameters for the channel
   * @returns {Channel}
   */
  channel(topic, chanParams = {}) {
    let chan = new Channel(topic, chanParams, this)
    this.channels.push(chan)
    return chan
  }

  /**
   * Encodes and sends the message when connected; otherwise buffers it until
   * the send buffer is flushed.
   *
   * @param {Object} data
   */
  push(data) {
    if (this.hasLogger()) {
      let { topic, event, payload, ref, join_ref } = data
      this.log('push', `${topic} ${event} (${join_ref}, ${ref})`, payload)
    }
    if (this.isConnected()) {
      this.encode(data, result => this.conn.send(result))
    } else {
      this.sendBuffer.push(() => this.encode(data, result => this.conn.send(result)))
    }
  }

  /**
   * Return the next message ref, accounting for overflows
   * @returns {string}
   */
  makeRef() {
    let newRef = this.ref + 1
    // Once the counter is too large for `+ 1` to change it, wrap around to 0.
    if (newRef === this.ref) {
      this.ref = 0
    } else {
      this.ref = newRef
    }
    return this.ref.toString()
  }

  /**
   * Pushes a heartbeat message and arms the timer that calls `heartbeatTimeout`
   * after `heartbeatIntervalMs`. Skipped when a heartbeat is still pending and
   * the socket is not connected.
   */
  sendHeartbeat() {
    if (this.pendingHeartbeatRef && !this.isConnected()) {
      return
    }
    this.pendingHeartbeatRef = this.makeRef()
    this.push({
      topic: 'combo',
      event: 'heartbeat',
      payload: {},
      ref: this.pendingHeartbeatRef,
    })
    this.heartbeatTimeoutTimer = setTimeout(
      () => this.heartbeatTimeout(),
      this.heartbeatIntervalMs,
    )
  }

  /**
   * Runs and clears every buffered push, provided the socket is connected.
   */
  flushSendBuffer() {
    if (this.isConnected() && this.sendBuffer.length > 0) {
      this.sendBuffer.forEach(callback => callback())
      this.sendBuffer = []
    }
  }

  /**
   * Transport `onmessage` handler: decodes the raw message, settles a pending
   * heartbeat when the ref matches, dispatches the message to every channel it is a
   * member of, then notifies the registered `onMessage` callbacks.
   *
   * @param {Object} rawMessage - The transport message event; its `data` is decoded
   */
  onConnMessage(rawMessage) {
    this.decode(rawMessage.data, (msg) => {
      let { topic, event, payload, ref, join_ref } = msg
      if (ref && ref === this.pendingHeartbeatRef) {
        this.clearHeartbeats()
        this.pendingHeartbeatRef = null
        this.heartbeatTimer = setTimeout(() => this.sendHeartbeat(), this.heartbeatIntervalMs)
      }
      if (this.hasLogger()) {
        // Guarded so a null/undefined payload cannot throw while logging and thereby
        // keep the message from reaching channels and `onMessage` callbacks.
        let status = (payload && payload.status) || ''
        this.log(
          'receive',
          `${status} ${topic} ${event} ${(ref && '(' + ref + ')') || ''}`,
          payload,
        )
      }
      for (let i = 0; i < this.channels.length; i++) {
        const channel = this.channels[i]
        if (!channel.isMember(topic, event, payload, join_ref)) {
          continue
        }
        channel.trigger(event, payload, ref, join_ref)
      }
      for (let i = 0; i < this.stateChangeCallbacks.message.length; i++) {
        let [, callback] = this.stateChangeCallbacks.message[i]
        callback(msg)
      }
    })
  }

  /**
   * Leaves the first channel found for `topic` that is joined or joining, if any.
   *
   * @param {string} topic
   */
  leaveOpenTopic(topic) {
    let dupChannel = this.channels.find(
      c => c.topic === topic && (c.isJoined() || c.isJoining()),
    )
    if (dupChannel) {
      if (this.hasLogger()) {
        this.log('transport', `leaving duplicate topic "${topic}"`)
      }
      dupChannel.leave()
    }
  }
}