From 12cd3f06dbe777033cc6ce98f3131219bf36315c Mon Sep 17 00:00:00 2001 From: WofWca <wofwca@protonmail.com> Date: Wed, 15 Jan 2025 00:04:45 +0400 Subject: [PATCH] WIP: change API to expose smux.OpenStream() --- client/lib/snowflake.go | 32 ++++++++++++++------------ client/snowflake.go | 8 ++++++- server/lib/snowflake.go | 50 +++++++++++++++++------------------------ server/server.go | 15 ++----------- 4 files changed, 48 insertions(+), 57 deletions(-) diff --git a/client/lib/snowflake.go b/client/lib/snowflake.go index f1a3bada..6281df1a 100644 --- a/client/lib/snowflake.go +++ b/client/lib/snowflake.go @@ -172,7 +172,7 @@ func NewSnowflakeClient(config ClientConfig) (*Transport, error) { // Dial starts the collection of snowflakes and returns a SnowflakeConn that is a // wrapper around a smux.Stream that will reliably deliver data to a Snowflake // server through one or more snowflake proxies. -func (t *Transport) Dial() (net.Conn, error) { +func (t *Transport) Dial() (*SnowflakeConn, error) { // Cleanup functions to run before returning, in case of an error. var cleanup []func() defer func() { @@ -207,17 +207,17 @@ func (t *Transport) Dial() (net.Conn, error) { }) // On the smux session we overlay a stream. - stream, err := sess.OpenStream() - if err != nil { - return nil, err - } - // Begin exchanging data. - log.Printf("---- SnowflakeConn: begin stream %v ---", stream.ID()) - cleanup = append(cleanup, func() { stream.Close() }) + // stream, err := sess.OpenStream() + // if err != nil { + // return nil, err + // } + // // Begin exchanging data. + // log.Printf("---- SnowflakeConn: begin stream %v ---", stream.ID()) + // cleanup = append(cleanup, func() { stream.Close() }) // All good, clear the cleanup list. cleanup = nil - return &SnowflakeConn{Stream: stream, sess: sess, pconn: pconn, snowflakes: snowflakes}, nil + return &SnowflakeConn{Sess: sess, pconn: pconn, snowflakes: snowflakes}, nil } func (t *Transport) AddSnowflakeEventListener(receiver event.SnowflakeEventReceiver) { @@ -235,26 +235,30 @@ func (t *Transport) SetRendezvousMethod(r RendezvousMethod) { // SnowflakeConn is a reliable connection to a snowflake server that implements net.Conn. type SnowflakeConn struct { - *smux.Stream - sess *smux.Session + // *smux.Stream + Sess *smux.Session pconn net.PacketConn snowflakes *Peers } +func (conn *SnowflakeConn) OpenStream() (*smux.Stream, error) { + return conn.Sess.OpenStream() +} + // Close closes the connection. // // The collection of snowflake proxies for this connection is stopped. func (conn *SnowflakeConn) Close() error { var err error - log.Printf("---- SnowflakeConn: closed stream %v ---", conn.ID()) - err = conn.Stream.Close() + // log.Printf("---- SnowflakeConn: closed stream %v ---", conn.ID()) + // err = conn.Stream.Close() log.Printf("---- SnowflakeConn: end collecting snowflakes ---") conn.snowflakes.End() if inerr := conn.pconn.Close(); err == nil { err = inerr } log.Printf("---- SnowflakeConn: discarding finished session ---") - if inerr := conn.sess.Close(); err == nil { + if inerr := conn.Sess.Close(); err == nil { err = inerr } return err diff --git a/client/snowflake.go b/client/snowflake.go index 648481fa..1130bbf0 100644 --- a/client/snowflake.go +++ b/client/snowflake.go @@ -144,9 +144,15 @@ func socksAcceptLoop(ln *pt.SocksListener, config sf.ClientConfig, shutdown chan log.Printf("dial error: %s", err) return } + stream, err := sconn.OpenStream() + if err != nil { + log.Printf("sconn.OpenStream: %s", err) + return + } + defer sconn.Close() // copy between the created Snowflake conn and the SOCKS conn - copyLoop(conn, sconn) + copyLoop(conn, stream) }() select { case <-shutdown: diff --git a/server/lib/snowflake.go b/server/lib/snowflake.go index bcf9dd68..1e2fefd6 100644 --- a/server/lib/snowflake.go +++ b/server/lib/snowflake.go @@ -74,7 +74,7 @@ func NewSnowflakeServer(getCertificate func(*tls.ClientHelloInfo) (*tls.Certific func (t *Transport) Listen(addr net.Addr, numKCPInstances int) (*SnowflakeListener, error) { listener := &SnowflakeListener{ addr: addr, - queue: make(chan net.Conn, 65534), + queue: make(chan *SnowflakeClientConn, 65534), closed: make(chan struct{}), ln: make([]*kcp.Listener, 0, numKCPInstances), } @@ -167,7 +167,7 @@ func (t *Transport) Listen(addr net.Addr, numKCPInstances int) (*SnowflakeListen type SnowflakeListener struct { addr net.Addr - queue chan net.Conn + queue chan *SnowflakeClientConn server *http.Server ln []*kcp.Listener closed chan struct{} @@ -177,7 +177,7 @@ type SnowflakeListener struct { // Accept allows the caller to accept incoming Snowflake connections. // We accept connections from a queue to accommodate both incoming // smux Streams and legacy non-turbotunnel connections. -func (l *SnowflakeListener) Accept() (net.Conn, error) { +func (l *SnowflakeListener) Accept() (*SnowflakeClientConn, error) { select { case <-l.closed: // channel has been closed, no longer accepting connections @@ -228,16 +228,9 @@ func (l *SnowflakeListener) acceptStreams(conn *kcp.UDPSession) error { return err } - for { - stream, err := sess.AcceptStream() - if err != nil { - if err, ok := err.(net.Error); ok && err.Temporary() { - continue - } - return err - } - l.queueConn(&SnowflakeClientConn{stream: stream, address: addr}) - } + l.queueConn(&SnowflakeClientConn{Sess: sess, address: addr}) + <-sess.CloseChan() + return nil } // acceptSessions listens for incoming KCP connections and passes them to @@ -275,7 +268,7 @@ func (l *SnowflakeListener) acceptSessions(ln *kcp.Listener) error { } } -func (l *SnowflakeListener) queueConn(conn net.Conn) error { +func (l *SnowflakeListener) queueConn(conn *SnowflakeClientConn) error { select { case <-l.closed: return fmt.Errorf("accepted connection on closed listener") @@ -289,23 +282,22 @@ func (l *SnowflakeListener) queueConn(conn net.Conn) error { // RemoteAddr method is overridden to refer to a real IP address, looked up from // the client address map, rather than an abstract client ID. type SnowflakeClientConn struct { - stream *smux.Stream + Sess *smux.Session address net.Addr } // Forward net.Conn methods, other than RemoteAddr, to the inner stream. -func (conn *SnowflakeClientConn) Read(b []byte) (int, error) { return conn.stream.Read(b) } -func (conn *SnowflakeClientConn) Write(b []byte) (int, error) { return conn.stream.Write(b) } -func (conn *SnowflakeClientConn) Close() error { return conn.stream.Close() } -func (conn *SnowflakeClientConn) LocalAddr() net.Addr { return conn.stream.LocalAddr() } -func (conn *SnowflakeClientConn) SetDeadline(t time.Time) error { return conn.stream.SetDeadline(t) } -func (conn *SnowflakeClientConn) SetReadDeadline(t time.Time) error { - return conn.stream.SetReadDeadline(t) -} +func (conn *SnowflakeClientConn) Close() error { return conn.Sess.Close() } +func (conn *SnowflakeClientConn) LocalAddr() net.Addr { return conn.Sess.LocalAddr() } +func (conn *SnowflakeClientConn) SetDeadline(t time.Time) error { return conn.Sess.SetDeadline(t) } -func (conn *SnowflakeClientConn) SetWriteDeadline(t time.Time) error { - return conn.stream.SetWriteDeadline(t) -} +// func (conn *SnowflakeClientConn) SetReadDeadline(t time.Time) error { +// return conn.sess.SetReadDeadline(t) +// } + +// func (conn *SnowflakeClientConn) SetWriteDeadline(t time.Time) error { +// return conn.sess.SetWriteDeadline(t) +// } // RemoteAddr returns the mapped client address of the Snowflake connection. func (conn *SnowflakeClientConn) RemoteAddr() net.Addr { @@ -314,6 +306,6 @@ func (conn *SnowflakeClientConn) RemoteAddr() net.Addr { // WriteTo implements the io.WriterTo interface by passing the call to the // underlying smux.Stream. -func (conn *SnowflakeClientConn) WriteTo(w io.Writer) (int64, error) { - return conn.stream.WriteTo(w) -} +// func (conn *SnowflakeClientConn) WriteTo(w io.Writer) (int64, error) { +// return conn.stream.WriteTo(w) +// } diff --git a/server/server.go b/server/server.go index 3bd624f6..d04ed214 100644 --- a/server/server.go +++ b/server/server.go @@ -271,17 +271,6 @@ func main() { // Are we requested to use source addresses from a particular // range when dialing the ORPort for this transport? - var orPortSrcAddr *net.IPNet - if orPortSrcAddrCIDR, ok := bindaddr.Options.Get("orport-srcaddr"); ok { - ipnet, err := parseIPCIDR(orPortSrcAddrCIDR) - if err != nil { - err = fmt.Errorf("parsing srcaddr: %w", err) - log.Println(err) - pt.SmethodError(bindaddr.MethodName, err.Error()) - continue - } - orPortSrcAddr = ipnet - } numKCPInstances := 1 // Are we requested to run a certain number of KCP state @@ -307,9 +296,9 @@ func main() { continue } defer ln.Close() - go acceptLoop(ln, orPortSrcAddr) + // go acceptLoop(ln, orPortSrcAddr) pt.SmethodArgs(bindaddr.MethodName, bindaddr.Addr, args) - listeners = append(listeners, ln) + // listeners = append(listeners, ln) } pt.SmethodsDone() -- GitLab