Commit 15963688 authored by David Fifield's avatar David Fifield
Browse files

Remove support for the base64 WebSocket subprotocol.

This was only needed for very very old Firefox before WebSockets were
properly standardized.
parent 3e782517
Loading
Loading
Loading
Loading
+9 −40
Original line number Diff line number Diff line
@@ -10,7 +10,6 @@ package main

import (
	"crypto/tls"
	"encoding/base64"
	"errors"
	"flag"
	"fmt"
@@ -31,8 +30,7 @@ import (
const ptMethodName = "snowflake"
const requestTimeout = 10 * time.Second

// "4/3+1" accounts for possible base64 encoding.
const maxMessageSize = 64*1024*4/3 + 1
const maxMessageSize = 64*1024

var ptInfo pt.ServerInfo

@@ -50,11 +48,9 @@ func usage() {
}

// An abstraction that makes an underlying WebSocket connection look like an
// io.ReadWriteCloser. It internally takes care of things like base64 encoding
// and decoding.
// io.ReadWriteCloser.
type webSocketConn struct {
	Ws         *websocket.WebSocket
	Base64     bool
	messageBuf []byte
}

@@ -70,26 +66,12 @@ func (conn *webSocketConn) Read(b []byte) (n int, err error) {
			err = io.EOF
			return
		}
		if conn.Base64 {
			if m.Opcode != 1 {
				err = errors.New(fmt.Sprintf("got non-text opcode %d with the base64 subprotocol", m.Opcode))
				return
			}
			conn.messageBuf = make([]byte, base64.StdEncoding.DecodedLen(len(m.Payload)))
			var num int
			num, err = base64.StdEncoding.Decode(conn.messageBuf, m.Payload)
			if err != nil {
				return
			}
			conn.messageBuf = conn.messageBuf[:num]
		} else {
		if m.Opcode != 2 {
				err = errors.New(fmt.Sprintf("got non-binary opcode %d with no subprotocol", m.Opcode))
			err = errors.New(fmt.Sprintf("got non-binary opcode %d", m.Opcode))
			return
		}
		conn.messageBuf = m.Payload
	}
	}

	n = copy(b, conn.messageBuf)
	conn.messageBuf = conn.messageBuf[n:]
@@ -98,20 +80,9 @@ func (conn *webSocketConn) Read(b []byte) (n int, err error) {
}

// Implements io.Writer.
func (conn *webSocketConn) Write(b []byte) (n int, err error) {
	if conn.Base64 {
		buf := make([]byte, base64.StdEncoding.EncodedLen(len(b)))
		base64.StdEncoding.Encode(buf, b)
		err = conn.Ws.WriteMessage(1, buf)
		if err != nil {
			return
		}
		n = len(b)
	} else {
		err = conn.Ws.WriteMessage(2, b)
		n = len(b)
	}
	return
func (conn *webSocketConn) Write(b []byte) (int, error) {
	err := conn.Ws.WriteMessage(2, b)
	return len(b), err
}

// Implements io.Closer.
@@ -125,7 +96,6 @@ func (conn *webSocketConn) Close() error {
func newWebSocketConn(ws *websocket.WebSocket) webSocketConn {
	var conn webSocketConn
	conn.Ws = ws
	conn.Base64 = (ws.Subprotocol == "base64")
	return conn
}

@@ -233,7 +203,6 @@ func startServer(ln net.Listener) (net.Listener, error) {
	go func() {
		defer ln.Close()
		var config websocket.Config
		config.Subprotocols = []string{"base64"}
		config.MaxMessageSize = maxMessageSize
		s := &http.Server{
			Handler:     config.Handler(webSocketHandler),