From 12cd3f06dbe777033cc6ce98f3131219bf36315c Mon Sep 17 00:00:00 2001
From: WofWca <wofwca@protonmail.com>
Date: Wed, 15 Jan 2025 00:04:45 +0400
Subject: [PATCH] WIP: change API to expose smux.OpenStream()

---
 client/lib/snowflake.go | 32 ++++++++++++++------------
 client/snowflake.go     |  8 ++++++-
 server/lib/snowflake.go | 50 +++++++++++++++++------------------------
 server/server.go        | 15 ++-----------
 4 files changed, 48 insertions(+), 57 deletions(-)

diff --git a/client/lib/snowflake.go b/client/lib/snowflake.go
index f1a3bada..6281df1a 100644
--- a/client/lib/snowflake.go
+++ b/client/lib/snowflake.go
@@ -172,7 +172,7 @@ func NewSnowflakeClient(config ClientConfig) (*Transport, error) {
 // Dial starts the collection of snowflakes and returns a SnowflakeConn that is a
 // wrapper around a smux.Stream that will reliably deliver data to a Snowflake
 // server through one or more snowflake proxies.
-func (t *Transport) Dial() (net.Conn, error) {
+func (t *Transport) Dial() (*SnowflakeConn, error) {
 	// Cleanup functions to run before returning, in case of an error.
 	var cleanup []func()
 	defer func() {
@@ -207,17 +207,17 @@ func (t *Transport) Dial() (net.Conn, error) {
 	})
 
 	// On the smux session we overlay a stream.
-	stream, err := sess.OpenStream()
-	if err != nil {
-		return nil, err
-	}
-	// Begin exchanging data.
-	log.Printf("---- SnowflakeConn: begin stream %v ---", stream.ID())
-	cleanup = append(cleanup, func() { stream.Close() })
+	// stream, err := sess.OpenStream()
+	// if err != nil {
+	// 	return nil, err
+	// }
+	// // Begin exchanging data.
+	// log.Printf("---- SnowflakeConn: begin stream %v ---", stream.ID())
+	// cleanup = append(cleanup, func() { stream.Close() })
 
 	// All good, clear the cleanup list.
 	cleanup = nil
-	return &SnowflakeConn{Stream: stream, sess: sess, pconn: pconn, snowflakes: snowflakes}, nil
+	return &SnowflakeConn{Sess: sess, pconn: pconn, snowflakes: snowflakes}, nil
 }
 
 func (t *Transport) AddSnowflakeEventListener(receiver event.SnowflakeEventReceiver) {
@@ -235,26 +235,30 @@ func (t *Transport) SetRendezvousMethod(r RendezvousMethod) {
 
 // SnowflakeConn is a reliable connection to a snowflake server that implements net.Conn.
 type SnowflakeConn struct {
-	*smux.Stream
-	sess       *smux.Session
+	// *smux.Stream
+	Sess       *smux.Session
 	pconn      net.PacketConn
 	snowflakes *Peers
 }
 
+func (conn *SnowflakeConn) OpenStream() (*smux.Stream, error) {
+	return conn.Sess.OpenStream()
+}
+
 // Close closes the connection.
 //
 // The collection of snowflake proxies for this connection is stopped.
 func (conn *SnowflakeConn) Close() error {
 	var err error
-	log.Printf("---- SnowflakeConn: closed stream %v ---", conn.ID())
-	err = conn.Stream.Close()
+	// log.Printf("---- SnowflakeConn: closed stream %v ---", conn.ID())
+	// err = conn.Stream.Close()
 	log.Printf("---- SnowflakeConn: end collecting snowflakes ---")
 	conn.snowflakes.End()
 	if inerr := conn.pconn.Close(); err == nil {
 		err = inerr
 	}
 	log.Printf("---- SnowflakeConn: discarding finished session ---")
-	if inerr := conn.sess.Close(); err == nil {
+	if inerr := conn.Sess.Close(); err == nil {
 		err = inerr
 	}
 	return err
diff --git a/client/snowflake.go b/client/snowflake.go
index 648481fa..1130bbf0 100644
--- a/client/snowflake.go
+++ b/client/snowflake.go
@@ -144,9 +144,15 @@ func socksAcceptLoop(ln *pt.SocksListener, config sf.ClientConfig, shutdown chan
 					log.Printf("dial error: %s", err)
 					return
 				}
+				stream, err := sconn.OpenStream()
+				if err != nil {
+					log.Printf("sconn.OpenStream: %s", err)
+					return
+				}
+
 				defer sconn.Close()
 				// copy between the created Snowflake conn and the SOCKS conn
-				copyLoop(conn, sconn)
+				copyLoop(conn, stream)
 			}()
 			select {
 			case <-shutdown:
diff --git a/server/lib/snowflake.go b/server/lib/snowflake.go
index bcf9dd68..1e2fefd6 100644
--- a/server/lib/snowflake.go
+++ b/server/lib/snowflake.go
@@ -74,7 +74,7 @@ func NewSnowflakeServer(getCertificate func(*tls.ClientHelloInfo) (*tls.Certific
 func (t *Transport) Listen(addr net.Addr, numKCPInstances int) (*SnowflakeListener, error) {
 	listener := &SnowflakeListener{
 		addr:   addr,
-		queue:  make(chan net.Conn, 65534),
+		queue:  make(chan *SnowflakeClientConn, 65534),
 		closed: make(chan struct{}),
 		ln:     make([]*kcp.Listener, 0, numKCPInstances),
 	}
@@ -167,7 +167,7 @@ func (t *Transport) Listen(addr net.Addr, numKCPInstances int) (*SnowflakeListen
 
 type SnowflakeListener struct {
 	addr      net.Addr
-	queue     chan net.Conn
+	queue     chan *SnowflakeClientConn
 	server    *http.Server
 	ln        []*kcp.Listener
 	closed    chan struct{}
@@ -177,7 +177,7 @@ type SnowflakeListener struct {
 // Accept allows the caller to accept incoming Snowflake connections.
 // We accept connections from a queue to accommodate both incoming
 // smux Streams and legacy non-turbotunnel connections.
-func (l *SnowflakeListener) Accept() (net.Conn, error) {
+func (l *SnowflakeListener) Accept() (*SnowflakeClientConn, error) {
 	select {
 	case <-l.closed:
 		// channel has been closed, no longer accepting connections
@@ -228,16 +228,9 @@ func (l *SnowflakeListener) acceptStreams(conn *kcp.UDPSession) error {
 		return err
 	}
 
-	for {
-		stream, err := sess.AcceptStream()
-		if err != nil {
-			if err, ok := err.(net.Error); ok && err.Temporary() {
-				continue
-			}
-			return err
-		}
-		l.queueConn(&SnowflakeClientConn{stream: stream, address: addr})
-	}
+	l.queueConn(&SnowflakeClientConn{Sess: sess, address: addr})
+	<-sess.CloseChan()
+	return nil
 }
 
 // acceptSessions listens for incoming KCP connections and passes them to
@@ -275,7 +268,7 @@ func (l *SnowflakeListener) acceptSessions(ln *kcp.Listener) error {
 	}
 }
 
-func (l *SnowflakeListener) queueConn(conn net.Conn) error {
+func (l *SnowflakeListener) queueConn(conn *SnowflakeClientConn) error {
 	select {
 	case <-l.closed:
 		return fmt.Errorf("accepted connection on closed listener")
@@ -289,23 +282,22 @@ func (l *SnowflakeListener) queueConn(conn net.Conn) error {
 // RemoteAddr method is overridden to refer to a real IP address, looked up from
 // the client address map, rather than an abstract client ID.
 type SnowflakeClientConn struct {
-	stream  *smux.Stream
+	Sess    *smux.Session
 	address net.Addr
 }
 
 // Forward net.Conn methods, other than RemoteAddr, to the inner stream.
-func (conn *SnowflakeClientConn) Read(b []byte) (int, error)    { return conn.stream.Read(b) }
-func (conn *SnowflakeClientConn) Write(b []byte) (int, error)   { return conn.stream.Write(b) }
-func (conn *SnowflakeClientConn) Close() error                  { return conn.stream.Close() }
-func (conn *SnowflakeClientConn) LocalAddr() net.Addr           { return conn.stream.LocalAddr() }
-func (conn *SnowflakeClientConn) SetDeadline(t time.Time) error { return conn.stream.SetDeadline(t) }
-func (conn *SnowflakeClientConn) SetReadDeadline(t time.Time) error {
-	return conn.stream.SetReadDeadline(t)
-}
+func (conn *SnowflakeClientConn) Close() error                  { return conn.Sess.Close() }
+func (conn *SnowflakeClientConn) LocalAddr() net.Addr           { return conn.Sess.LocalAddr() }
+func (conn *SnowflakeClientConn) SetDeadline(t time.Time) error { return conn.Sess.SetDeadline(t) }
 
-func (conn *SnowflakeClientConn) SetWriteDeadline(t time.Time) error {
-	return conn.stream.SetWriteDeadline(t)
-}
+// func (conn *SnowflakeClientConn) SetReadDeadline(t time.Time) error {
+// 	return conn.sess.SetReadDeadline(t)
+// }
+
+// func (conn *SnowflakeClientConn) SetWriteDeadline(t time.Time) error {
+// 	return conn.sess.SetWriteDeadline(t)
+// }
 
 // RemoteAddr returns the mapped client address of the Snowflake connection.
 func (conn *SnowflakeClientConn) RemoteAddr() net.Addr {
@@ -314,6 +306,6 @@ func (conn *SnowflakeClientConn) RemoteAddr() net.Addr {
 
 // WriteTo implements the io.WriterTo interface by passing the call to the
 // underlying smux.Stream.
-func (conn *SnowflakeClientConn) WriteTo(w io.Writer) (int64, error) {
-	return conn.stream.WriteTo(w)
-}
+// func (conn *SnowflakeClientConn) WriteTo(w io.Writer) (int64, error) {
+// 	return conn.stream.WriteTo(w)
+// }
diff --git a/server/server.go b/server/server.go
index 3bd624f6..d04ed214 100644
--- a/server/server.go
+++ b/server/server.go
@@ -271,17 +271,6 @@ func main() {
 
 		// Are we requested to use source addresses from a particular
 		// range when dialing the ORPort for this transport?
-		var orPortSrcAddr *net.IPNet
-		if orPortSrcAddrCIDR, ok := bindaddr.Options.Get("orport-srcaddr"); ok {
-			ipnet, err := parseIPCIDR(orPortSrcAddrCIDR)
-			if err != nil {
-				err = fmt.Errorf("parsing srcaddr: %w", err)
-				log.Println(err)
-				pt.SmethodError(bindaddr.MethodName, err.Error())
-				continue
-			}
-			orPortSrcAddr = ipnet
-		}
 
 		numKCPInstances := 1
 		// Are we requested to run a certain number of KCP state
@@ -307,9 +296,9 @@ func main() {
 			continue
 		}
 		defer ln.Close()
-		go acceptLoop(ln, orPortSrcAddr)
+		// go acceptLoop(ln, orPortSrcAddr)
 		pt.SmethodArgs(bindaddr.MethodName, bindaddr.Addr, args)
-		listeners = append(listeners, ln)
+		// listeners = append(listeners, ln)
 	}
 	pt.SmethodsDone()
 
-- 
GitLab