main.go 4.5 KB
Newer Older
1
2
3
package main

import (
4
	"context"
5
	"flag"
6
	"fmt"
7
8
9
10
	"io"
	"log"
	"net/http"
	"os"
11
12
	"os/signal"
	"syscall"
13
14
15
16
17
18
19
20
21
22
23
24
25
	"time"

	"git.torproject.org/pluggable-transports/snowflake.git/common/safelog"
	"github.com/gorilla/mux"
)

type Route struct {
	Name        string
	Method      string
	Pattern     string
	HandlerFunc http.HandlerFunc
}

26
27
var torCtx *TorContext

28
29
30
31
type Routes []Route

var routes = Routes{
	Route{
32
33
34
35
		"BridgeState",
		"GET",
		"/bridge-state",
		BridgeState,
36
	},
37
38
39
40
41
42
	Route{
		"BridgeStateWeb",
		"GET",
		"/result",
		BridgeStateWeb,
	},
43
44
}

45
46
47
// tmpDataDir contains the path to Tor's data directory.
var tmpDataDir string

48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
// Logger wraps inner so that every request is logged after it has been
// handled.  The log line contains the request method, the URL path, the
// route name, and the time it took to handle the request.
//
// Client IP addresses are not logged, and neither are the obfs4 parameters a
// client submits.  That is why only r.URL.Path is logged and not
// r.RequestURI: the latter includes the query string, which may carry those
// parameters for GET requests.
func Logger(inner http.Handler, name string) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		start := time.Now()

		inner.ServeHTTP(w, r)

		log.Printf(
			"%s\t%s\t%s\t%s",
			r.Method,
			r.URL.Path, // Never the raw URI: keep query parameters out of the log.
			name,
			time.Since(start),
		)
	})
}

// NewRouter builds the request router.  It registers one mux route per entry
// in routes, restricted to that entry's method and path and named after it.
// Each handler is wrapped in Logger so the request gets logged.
// StrictSlash(true) is enabled on the returned router.
func NewRouter() *mux.Router {
	r := mux.NewRouter().StrictSlash(true)
	for _, rt := range routes {
		wrapped := Logger(rt.HandlerFunc, rt.Name)
		r.Methods(rt.Method).Path(rt.Pattern).Name(rt.Name).Handler(wrapped)
	}
	return r
}

87
88
func printPrettyCache() {
	var shortError string
89
90
	var numFunctional int

91
92
93
94
95
96
	for bridgeLine, cacheEntry := range cache {
		shortError = cacheEntry.Error
		maxChars := 50
		if len(cacheEntry.Error) > maxChars {
			shortError = cacheEntry.Error[:maxChars]
		}
97
98
99
		if cacheEntry.Error == "" {
			numFunctional++
		}
100
101
		fmt.Printf("%-22s %-50s %s\n", bridgeLine, shortError, cacheEntry.Time)
	}
102
103
104
105
	if len(cache) > 0 {
		log.Printf("Found %d (%.2f%%) out of %d functional.\n", numFunctional,
			float64(numFunctional)/float64(len(cache))*100.0, len(cache))
	}
106
107
}

108
109
func main() {

110
	var err error
111
	var addr string
112
	var web, printCache bool
113
	var certFilename, keyFilename string
114
	var cacheFile string
115
	var templatesDir string
116
	var numSecs int
117

Philipp Winter's avatar
Philipp Winter committed
118
	flag.StringVar(&addr, "addr", ":5000", "Address to listen on.")
119
	flag.BoolVar(&web, "web", false, "Enable the web interface (in addition to the JSON API).")
120
	flag.BoolVar(&printCache, "print-cache", false, "Print the given cache file and exit.")
121
122
	flag.StringVar(&certFilename, "cert", "", "TLS certificate file.")
	flag.StringVar(&keyFilename, "key", "", "TLS private key file.")
123
	flag.StringVar(&cacheFile, "cache", "bridgestrap-cache.bin", "Cache file that contains test results.")
124
	flag.StringVar(&templatesDir, "templates", "templates", "Path to directory that contains our web templates.")
125
	flag.IntVar(&numSecs, "seconds", 0, "Number of seconds after two subsequent requests are handled.")
126
127
128
129
	flag.Parse()

	var logOutput io.Writer = os.Stderr
	// Send the log output through our scrubber first.
130
131
132
	if !printCache {
		log.SetOutput(&safelog.LogScrubber{Output: logOutput})
	}
133
134
	log.SetFlags(log.LstdFlags | log.LUTC)

135
136
	if web {
		log.Println("Enabling web interface.")
137
		LoadHtmlTemplates(templatesDir)
138
139
140
141
142
143
144
145
146
		routes = append(routes,
			Route{
				"Index",
				"GET",
				"/",
				Index,
			})
	}

147
	if err = cache.ReadFromDisk(cacheFile); err != nil {
Philipp Winter's avatar
Philipp Winter committed
148
		log.Printf("Could not read cache: %s", err)
149
	}
150
151
152
153
	if printCache {
		printPrettyCache()
		return
	}
154

155
156
157
158
159
160
	torCtx = &TorContext{}
	if err = torCtx.Start(); err != nil {
		log.Printf("Failed to start Tor process: %s", err)
		return
	}

161
162
163
164
	var srv http.Server
	srv.Addr = addr
	srv.Handler = NewRouter()
	log.Printf("Starting service on port %s.", addr)
165
	go func() {
166
167
168
169
		if certFilename != "" && keyFilename != "" {
			srv.ListenAndServeTLS(certFilename, keyFilename)
		} else {
			srv.ListenAndServe()
170
171
172
		}
	}()

173
174
175
	signalChan := make(chan os.Signal, 1)
	signal.Notify(signalChan, syscall.SIGINT)
	signal.Notify(signalChan, syscall.SIGTERM)
176

177
178
179
	log.Printf("Waiting for signal to shut down.")
	<-signalChan
	log.Printf("Received signal to shut down.")
180
181
182
183
184

	if err := torCtx.Stop(); err != nil {
		log.Printf("Failed to clean up after Tor: %s", err)
	}

185
186
	// Give our Web server a maximum of a minute to finish handling open
	// connections and shut down gracefully.
187
188
189
190
191
192
193
194
195
	t := time.Now().Add(time.Minute)
	ctx, cancel := context.WithDeadline(context.Background(), t)
	defer cancel()
	if err := srv.Shutdown(ctx); err != nil {
		log.Printf("Failed to shut down Web server: %s", err)
	}

	if err := cache.WriteToDisk(cacheFile); err != nil {
		log.Printf("Failed to write cache to disk: %s", err)
196
	}
197
}