Commit ac9d49b8 authored by Serene H's avatar Serene H
Browse files

ensure closing stale remotes from the client side

parent ea2e052a
...@@ -20,6 +20,7 @@ import ( ...@@ -20,6 +20,7 @@ import (
const ( const (
ReconnectTimeout = 10 ReconnectTimeout = 10
DefaultSnowflakeCapacity = 1 DefaultSnowflakeCapacity = 1
SnowflakeTimeout = 30
) )
// When a connection handler starts, +1 is written to this channel; when it // When a connection handler starts, +1 is written to this channel; when it
...@@ -81,7 +82,7 @@ func handler(socks SocksConnector, snowflakes SnowflakeCollector) error { ...@@ -81,7 +82,7 @@ func handler(socks SocksConnector, snowflakes SnowflakeCollector) error {
return errors.New("handler: Received invalid Snowflake") return errors.New("handler: Received invalid Snowflake")
} }
defer socks.Close() defer socks.Close()
defer snowflake.Reset() defer snowflake.Close()
log.Println("---- Handler: snowflake assigned ----") log.Println("---- Handler: snowflake assigned ----")
err := socks.Grant(&net.TCPAddr{IP: net.IPv4zero, Port: 0}) err := socks.Grant(&net.TCPAddr{IP: net.IPv4zero, Port: 0})
if err != nil { if err != nil {
......
...@@ -29,6 +29,7 @@ type WebRTCPeer struct { ...@@ -29,6 +29,7 @@ type WebRTCPeer struct {
errorChannel chan error errorChannel chan error
recvPipe *io.PipeReader recvPipe *io.PipeReader
writePipe *io.PipeWriter writePipe *io.PipeWriter
lastReceive time.Time
buffer bytes.Buffer buffer bytes.Buffer
reset chan struct{} reset chan struct{}
...@@ -37,6 +38,28 @@ type WebRTCPeer struct { ...@@ -37,6 +38,28 @@ type WebRTCPeer struct {
BytesLogger BytesLogger
} }
// Construct a WebRTC PeerConnection.
func NewWebRTCPeer(config *webrtc.Configuration,
broker *BrokerChannel) *WebRTCPeer {
connection := new(WebRTCPeer)
connection.id = "snowflake-" + uniuri.New()
connection.config = config
connection.broker = broker
connection.offerChannel = make(chan *webrtc.SessionDescription, 1)
connection.answerChannel = make(chan *webrtc.SessionDescription, 1)
// Error channel is mostly for reporting during the initial SDP offer
// creation & local description setting, which happens asynchronously.
connection.errorChannel = make(chan error, 1)
connection.reset = make(chan struct{}, 1)
// Override with something that's not NullLogger to have real logging.
connection.BytesLogger = &BytesNullLogger{}
// Pipes remain the same even when DataChannel gets switched.
connection.recvPipe, connection.writePipe = io.Pipe()
return connection
}
// Read bytes from local SOCKS. // Read bytes from local SOCKS.
// As part of |io.ReadWriter| // As part of |io.ReadWriter|
func (c *WebRTCPeer) Read(b []byte) (int, error) { func (c *WebRTCPeer) Read(b []byte) (int, error) {
...@@ -47,6 +70,7 @@ func (c *WebRTCPeer) Read(b []byte) (int, error) { ...@@ -47,6 +70,7 @@ func (c *WebRTCPeer) Read(b []byte) (int, error) {
// As part of |io.ReadWriter| // As part of |io.ReadWriter|
func (c *WebRTCPeer) Write(b []byte) (int, error) { func (c *WebRTCPeer) Write(b []byte) (int, error) {
c.BytesLogger.AddOutbound(len(b)) c.BytesLogger.AddOutbound(len(b))
// TODO: Buffering could be improved / separated out of WebRTCPeer.
if nil == c.transport { if nil == c.transport {
log.Printf("Buffered %d bytes --> WebRTC", len(b)) log.Printf("Buffered %d bytes --> WebRTC", len(b))
c.buffer.Write(b) c.buffer.Write(b)
...@@ -61,45 +85,42 @@ func (c *WebRTCPeer) Close() error { ...@@ -61,45 +85,42 @@ func (c *WebRTCPeer) Close() error {
if c.closed { // Skip if already closed. if c.closed { // Skip if already closed.
return nil return nil
} }
log.Printf("WebRTC: Closing")
c.cleanup()
// Mark for deletion. // Mark for deletion.
c.closed = true c.closed = true
c.cleanup()
c.Reset()
log.Printf("WebRTC: Closing")
return nil return nil
} }
// As part of |Resetter| // As part of |Resetter|
func (c *WebRTCPeer) Reset() { func (c *WebRTCPeer) Reset() {
c.Close() if nil == c.reset {
go func() { return
}
c.reset <- struct{}{} c.reset <- struct{}{}
log.Println("WebRTC resetting...")
}()
} }
// As part of |Resetter| // As part of |Resetter|
func (c *WebRTCPeer) WaitForReset() { <-c.reset } func (c *WebRTCPeer) WaitForReset() { <-c.reset }
// Construct a WebRTC PeerConnection. // Prevent long-lived broken remotes.
func NewWebRTCPeer(config *webrtc.Configuration, // Should also update the DataChannel in underlying go-webrtc's to make Closes
broker *BrokerChannel) *WebRTCPeer { // more immediate / responsive.
connection := new(WebRTCPeer) func (c *WebRTCPeer) checkForStaleness() {
connection.id = "snowflake-" + uniuri.New() c.lastReceive = time.Now()
connection.config = config for {
connection.broker = broker if c.closed {
connection.offerChannel = make(chan *webrtc.SessionDescription, 1) return
connection.answerChannel = make(chan *webrtc.SessionDescription, 1) }
// Error channel is mostly for reporting during the initial SDP offer if time.Since(c.lastReceive).Seconds() > SnowflakeTimeout {
// creation & local description setting, which happens asynchronously. log.Println("WebRTC: No messages received for", SnowflakeTimeout,
connection.errorChannel = make(chan error, 1) "seconds -- closing stale connection.")
connection.reset = make(chan struct{}, 1) c.Close()
return
// Override with something that's not NullLogger to have real logging. }
connection.BytesLogger = &BytesNullLogger{} <-time.After(time.Second)
}
// Pipes remain the same even when DataChannel gets switched.
connection.recvPipe, connection.writePipe = io.Pipe()
return connection
} }
// As part of |Connector| interface. // As part of |Connector| interface.
...@@ -119,6 +140,7 @@ func (c *WebRTCPeer) Connect() error { ...@@ -119,6 +140,7 @@ func (c *WebRTCPeer) Connect() error {
if err != nil { if err != nil {
return err return err
} }
go c.checkForStaleness()
return nil return nil
} }
...@@ -208,7 +230,7 @@ func (c *WebRTCPeer) establishDataChannel() error { ...@@ -208,7 +230,7 @@ func (c *WebRTCPeer) establishDataChannel() error {
// Disable the DataChannel as a write destination. // Disable the DataChannel as a write destination.
log.Println("WebRTC: DataChannel.OnClose [remotely]") log.Println("WebRTC: DataChannel.OnClose [remotely]")
c.transport = nil c.transport = nil
c.Reset() c.Close()
} }
dc.OnMessage = func(msg []byte) { dc.OnMessage = func(msg []byte) {
if len(msg) <= 0 { if len(msg) <= 0 {
...@@ -225,6 +247,7 @@ func (c *WebRTCPeer) establishDataChannel() error { ...@@ -225,6 +247,7 @@ func (c *WebRTCPeer) establishDataChannel() error {
log.Println("Error: short write") log.Println("Error: short write")
panic("short write") panic("short write")
} }
c.lastReceive = time.Now()
} }
log.Println("WebRTC: DataChannel created.") log.Println("WebRTC: DataChannel created.")
return nil return nil
...@@ -257,7 +280,7 @@ func (c *WebRTCPeer) exchangeSDP() error { ...@@ -257,7 +280,7 @@ func (c *WebRTCPeer) exchangeSDP() error {
} }
case err := <-c.errorChannel: case err := <-c.errorChannel:
log.Println("Failed to prepare offer", err) log.Println("Failed to prepare offer", err)
c.Reset() c.Close()
return err return err
} }
// Keep trying the same offer until a valid answer arrives. // Keep trying the same offer until a valid answer arrives.
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment