// Snowflake-specific websocket server plugin. It reports the transport name as
// "snowflake".
package main

import (
6
	"crypto/tls"
7
8
9
	"flag"
	"fmt"
	"io"
10
	"io/ioutil"
11
12
13
14
15
	"log"
	"net"
	"net/http"
	"os"
	"os/signal"
16
	"path/filepath"
17
	"strings"
Arlo Breault's avatar
Arlo Breault committed
18
	"sync"
19
20
21
	"syscall"
	"time"

22
	pt "git.torproject.org/pluggable-transports/goptlib.git"
23
	"git.torproject.org/pluggable-transports/snowflake.git/common/safelog"
24
	"git.torproject.org/pluggable-transports/snowflake.git/common/websocketconn"
25
	"github.com/gorilla/websocket"
26
	"golang.org/x/crypto/acme/autocert"
27
	"golang.org/x/net/http2"
28
29
)

30
const ptMethodName = "snowflake"
31
32
const requestTimeout = 10 * time.Second

33
34
35
36
// How long to wait for ListenAndServe or ListenAndServeTLS to return an error
// before deciding that it's not going to return.
const listenAndServeErrorTimeout = 100 * time.Millisecond

37
38
39
var ptInfo pt.ServerInfo

// usage prints the program's help text to standard error, followed by the
// defaults of every registered flag. main installs it as flag.Usage.
func usage() {
	fmt.Fprintf(os.Stderr, `Usage: %s [OPTIONS]

WebSocket server pluggable transport for Snowflake. Works only as a managed
proxy. Uses TLS with ACME (Let's Encrypt) by default. Set the certificate
hostnames with the --acme-hostnames option. Use ServerTransportListenAddr in
torrc to choose the listening port. When using TLS, this program will open an
additional HTTP listener on port 80 to work with ACME.

`, os.Args[0])
	flag.PrintDefaults()
}

Arlo Breault's avatar
Arlo Breault committed
52
// Copy from WebSocket to socket and vice versa.
53
func proxy(local *net.TCPConn, conn *websocketconn.Conn) {
Arlo Breault's avatar
Arlo Breault committed
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
	var wg sync.WaitGroup
	wg.Add(2)

	go func() {
		if _, err := io.Copy(conn, local); err != nil {
			log.Printf("error copying ORPort to WebSocket %v", err)
		}
		if err := local.CloseRead(); err != nil {
			log.Printf("error closing read after copying ORPort to WebSocket %v", err)
		}
		conn.Close()
		wg.Done()
	}()
	go func() {
		if _, err := io.Copy(local, conn); err != nil {
69
			log.Printf("error copying WebSocket to ORPort %v", err)
Arlo Breault's avatar
Arlo Breault committed
70
71
72
73
74
75
76
77
78
79
80
		}
		if err := local.CloseWrite(); err != nil {
			log.Printf("error closing write after copying WebSocket to ORPort %v", err)
		}
		conn.Close()
		wg.Done()
	}()

	wg.Wait()
}

// clientAddr converts the client_ip query parameter into an address string
// suitable for pt.DialOr. It returns "" when the parameter is empty, is not
// a bare IP address, or is the unspecified address (0.0.0.0 or ::).
func clientAddr(clientIPParam string) string {
	// net.ParseIP yields nil both for "" and for anything that is not an IP.
	ip := net.ParseIP(clientIPParam)
	switch {
	case ip == nil:
		return ""
	case ip.IsUnspecified():
		// Some proxies erroneously report an address of 0.0.0.0:
		// https://bugs.torproject.org/33157.
		return ""
	}
	// USERADDR requires a port number, so attach a placeholder port of 1.
	addr := net.TCPAddr{IP: ip, Port: 1}
	return addr.String()
}

100
101
102
var upgrader = websocket.Upgrader{
	CheckOrigin: func(r *http.Request) bool { return true },
}
103
104
105
106
107
108
109
110

type HTTPHandler struct{}

func (handler *HTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	ws, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		log.Println(err)
		return
111
	}
112

113
	conn := websocketconn.New(ws)
Arlo Breault's avatar
Arlo Breault committed
114
	defer conn.Close()
115

116
	// Pass the address of client as the remote address of incoming connection
117
	clientIPParam := r.URL.Query().Get("client_ip")
118
	addr := clientAddr(clientIPParam)
David Fifield's avatar
David Fifield committed
119
	statsChannel <- addr != ""
120
	or, err := pt.DialOr(&ptInfo, addr, ptMethodName)
121
	if err != nil {
122
		log.Printf("failed to connect to ORPort: %s", err)
123
124
125
126
		return
	}
	defer or.Close()

127
	proxy(or, conn)
128
129
}

130
131
func initServer(addr *net.TCPAddr,
	getCertificate func(*tls.ClientHelloInfo) (*tls.Certificate, error),
132
	listenAndServe func(*http.Server, chan<- error)) (*http.Server, error) {
133
134
135
136
137
	// We're not capable of listening on port 0 (i.e., an ephemeral port
	// unknown in advance). The reason is that while the net/http package
	// exposes ListenAndServe and ListenAndServeTLS, those functions never
	// return, so there's no opportunity to find out what the port number
	// is, in between the Listen and Serve steps.
138
	// https://groups.google.com/d/msg/Golang-nuts/3F1VRCCENp8/3hcayZiwYM8J
139
140
141
	if addr.Port == 0 {
		return nil, fmt.Errorf("cannot listen on port %d; configure a port using ServerTransportListenAddr", addr.Port)
	}
142

143
	var handler HTTPHandler
144
145
	server := &http.Server{
		Addr:        addr.String(),
146
		Handler:     &handler,
147
148
149
150
151
152
153
154
155
156
157
		ReadTimeout: requestTimeout,
	}
	// We need to override server.TLSConfig.GetCertificate--but first
	// server.TLSConfig needs to be non-nil. If we just create our own new
	// &tls.Config, it will lack the default settings that the net/http
	// package sets up for things like HTTP/2. Therefore we first call
	// http2.ConfigureServer for its side effect of initializing
	// server.TLSConfig properly. An alternative would be to make a dummy
	// net.Listener, call Serve on it, and let it return.
	// https://github.com/golang/go/issues/16588#issuecomment-237386446
	err := http2.ConfigureServer(server, nil)
158
	if err != nil {
159
		return server, err
160
	}
161
	server.TLSConfig.GetCertificate = getCertificate
162

163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
	// Another unfortunate effect of the inseparable net/http ListenAndServe
	// is that we can't check for Listen errors like "permission denied" and
	// "address already in use" without potentially entering the infinite
	// loop of Serve. The hack we apply here is to wait a short time,
	// listenAndServeErrorTimeout, to see if an error is returned (because
	// it's better if the error message goes to the tor log through
	// SMETHOD-ERROR than if it only goes to the snowflake log).
	errChan := make(chan error)
	go listenAndServe(server, errChan)
	select {
	case err = <-errChan:
		break
	case <-time.After(listenAndServeErrorTimeout):
		break
	}
178

179
	return server, err
180
181
}

182
func startServer(addr *net.TCPAddr) (*http.Server, error) {
183
	return initServer(addr, nil, func(server *http.Server, errChan chan<- error) {
184
185
186
187
188
		log.Printf("listening with plain HTTP on %s", addr)
		err := server.ListenAndServe()
		if err != nil {
			log.Printf("error in ListenAndServe: %s", err)
		}
189
		errChan <- err
190
	})
191
192
}

193
func startServerTLS(addr *net.TCPAddr, getCertificate func(*tls.ClientHelloInfo) (*tls.Certificate, error)) (*http.Server, error) {
194
	return initServer(addr, getCertificate, func(server *http.Server, errChan chan<- error) {
195
196
		log.Printf("listening with HTTPS on %s", addr)
		err := server.ListenAndServeTLS("", "")
197
		if err != nil {
198
			log.Printf("error in ListenAndServeTLS: %s", err)
199
		}
200
		errChan <- err
201
	})
202
203
}

204
205
206
207
208
209
210
211
func getCertificateCacheDir() (string, error) {
	stateDir, err := pt.MakeStateDir()
	if err != nil {
		return "", err
	}
	return filepath.Join(stateDir, "snowflake-certificate-cache"), nil
}

212
func main() {
David Fifield's avatar
David Fifield committed
213
	var acmeEmail string
214
	var acmeHostnamesCommas string
215
	var disableTLS bool
216
	var logFilename string
Arlo Breault's avatar
Arlo Breault committed
217
	var unsafeLogging bool
218
219

	flag.Usage = usage
David Fifield's avatar
David Fifield committed
220
	flag.StringVar(&acmeEmail, "acme-email", "", "optional contact email for Let's Encrypt notifications")
221
	flag.StringVar(&acmeHostnamesCommas, "acme-hostnames", "", "comma-separated hostnames for TLS certificate")
222
	flag.BoolVar(&disableTLS, "disable-tls", false, "don't use HTTPS")
223
	flag.StringVar(&logFilename, "log", "", "log file to write to")
Arlo Breault's avatar
Arlo Breault committed
224
	flag.BoolVar(&unsafeLogging, "unsafe-logging", false, "prevent logs from being scrubbed")
225
226
	flag.Parse()

227
	log.SetFlags(log.LstdFlags | log.LUTC)
228
229

	var logOutput io.Writer = os.Stderr
230
231
232
	if logFilename != "" {
		f, err := os.OpenFile(logFilename, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0600)
		if err != nil {
233
			log.Fatalf("can't open log file: %s", err)
234
		}
David Fifield's avatar
David Fifield committed
235
		defer f.Close()
236
237
		logOutput = f
	}
Arlo Breault's avatar
Arlo Breault committed
238
239
240
241
242
243
	if unsafeLogging {
		log.SetOutput(logOutput)
	} else {
		// We want to send the log output through our scrubber first
		log.SetOutput(&safelog.LogScrubber{Output: logOutput})
	}
244

245
246
	if !disableTLS && acmeHostnamesCommas == "" {
		log.Fatal("the --acme-hostnames option is required")
247
	}
248
	acmeHostnames := strings.Split(acmeHostnamesCommas, ",")
249

250
251
252
253
	log.Printf("starting")
	var err error
	ptInfo, err = pt.ServerSetup(nil)
	if err != nil {
254
		log.Fatalf("error in setup: %s", err)
255
256
	}

257
258
	go statsThread()

259
	var certManager *autocert.Manager
260
261
	if !disableTLS {
		log.Printf("ACME hostnames: %q", acmeHostnames)
262
263

		var cache autocert.Cache
264
265
		var cacheDir string
		cacheDir, err = getCertificateCacheDir()
266
267
268
269
270
271
272
		if err == nil {
			log.Printf("caching ACME certificates in directory %q", cacheDir)
			cache = autocert.DirCache(cacheDir)
		} else {
			log.Printf("disabling ACME certificate cache: %s", err)
		}

273
274
275
276
		certManager = &autocert.Manager{
			Prompt:     autocert.AcceptTOS,
			HostPolicy: autocert.HostWhitelist(acmeHostnames...),
			Email:      acmeEmail,
277
			Cache:      cache,
278
		}
279
280
	}

281
282
283
284
285
	// The ACME HTTP-01 responder only works when it is running on port 80.
	// We actually open the port in the loop below, so that any errors can
	// be reported in the SMETHOD-ERROR of some bindaddr.
	// https://github.com/ietf-wg-acme/acme/blob/master/draft-ietf-acme-acme.md#http-challenge
	needHTTP01Listener := !disableTLS
286

287
	servers := make([]*http.Server, 0)
288
	for _, bindaddr := range ptInfo.Bindaddrs {
289
		if bindaddr.MethodName != ptMethodName {
290
			pt.SmethodError(bindaddr.MethodName, "no such method")
291
			continue
292
		}
293

294
		if needHTTP01Listener {
295
			addr := *bindaddr.Addr
296
297
			addr.Port = 80
			log.Printf("Starting HTTP-01 ACME listener")
298
299
			var lnHTTP01 *net.TCPListener
			lnHTTP01, err = net.ListenTCP("tcp", &addr)
300
			if err != nil {
301
				log.Printf("error opening HTTP-01 ACME listener: %s", err)
302
				pt.SmethodError(bindaddr.MethodName, "HTTP-01 ACME listener: "+err.Error())
303
304
				continue
			}
305
306
307
308
			server := &http.Server{
				Addr:    addr.String(),
				Handler: certManager.HTTPHandler(nil),
			}
309
			go func() {
310
				log.Fatal(server.Serve(lnHTTP01))
311
			}()
312
			servers = append(servers, server)
313
			needHTTP01Listener = false
314
315
		}

316
		var server *http.Server
317
318
319
		args := pt.Args{}
		if disableTLS {
			args.Add("tls", "no")
320
			server, err = startServer(bindaddr.Addr)
321
322
		} else {
			args.Add("tls", "yes")
323
324
325
			for _, hostname := range acmeHostnames {
				args.Add("hostname", hostname)
			}
326
			server, err = startServerTLS(bindaddr.Addr, certManager.GetCertificate)
327
328
		}
		if err != nil {
329
			log.Printf("error opening listener: %s", err)
330
			pt.SmethodError(bindaddr.MethodName, err.Error())
331
			continue
332
		}
333
334
		pt.SmethodArgs(bindaddr.MethodName, bindaddr.Addr, args)
		servers = append(servers, server)
335
336
337
338
	}
	pt.SmethodsDone()

	sigChan := make(chan os.Signal, 1)
339
	signal.Notify(sigChan, syscall.SIGTERM)
340

341
342
343
344
	if os.Getenv("TOR_PT_EXIT_ON_STDIN_CLOSE") == "1" {
		// This environment variable means we should treat EOF on stdin
		// just like SIGTERM: https://bugs.torproject.org/15435.
		go func() {
345
346
347
			if _, err := io.Copy(ioutil.Discard, os.Stdin); err != nil {
				log.Printf("error copying os.Stdin to ioutil.Discard: %v", err)
			}
348
349
350
351
352
			log.Printf("synthesizing SIGTERM because of stdin close")
			sigChan <- syscall.SIGTERM
		}()
	}

David Fifield's avatar
David Fifield committed
353
	// Wait for a signal.
354
	sig := <-sigChan
355

David Fifield's avatar
David Fifield committed
356
	// Signal received, shut down.
357
	log.Printf("caught signal %q, exiting", sig)
358
359
	for _, server := range servers {
		server.Close()
360
361
	}
}