Unverified Commit 75eb8e9e authored by 3andne's avatar 3andne Committed by GitHub
Browse files

feat: add an option to skip resumption on nil ext & update examples (#239)

* feat: add an option to skip resumption on nil ext
feat: update examples

* fix: clone unit test
parent df6e4c82
Loading
Loading
Loading
Loading
+15 −0
Original line number Diff line number Diff line
@@ -700,6 +700,19 @@ type Config struct {
	// This field is ignored when InsecureSkipVerify is true.
	InsecureServerNameToVerify string // [uTLS]

	// PreferSkipResumptionOnNilExtension controls the behavior when session resumption is enabled but the corresponding session extensions are nil.
	//
	// To successfully use session resumption, ensure that the following requirements are met:
	//  - SessionTicketsDisabled is set to false
	//  - ClientSessionCache is non-nil
	//  - For TLS 1.2, SessionTicketExtension is non-nil
	//  - For TLS 1.3, PreSharedKeyExtension is non-nil
	//
	// There may be cases where users enable session resumption (SessionTicketsDisabled: false && ClientSessionCache: non-nil), but they do not provide SessionTicketExtension or PreSharedKeyExtension in the ClientHelloSpec. This could be intentional or accidental.
	//
	// By default, utls throws an exception in such scenarios. Set this to true to skip the resumption and suppress the exception.
	PreferSkipResumptionOnNilExtension bool // [uTLS]

	// CipherSuites is a list of enabled TLS 1.0–1.2 cipher suites. The order of
	// the list is ignored. Note that TLS 1.3 ciphersuites are not configurable.
	//
@@ -906,6 +919,8 @@ func (c *Config) Clone() *Config {
		KeyLogWriter:                c.KeyLogWriter,
		sessionTicketKeys:           c.sessionTicketKeys,
		autoSessionTicketKeys:       c.autoSessionTicketKeys,

		PreferSkipResumptionOnNilExtension: c.PreferSkipResumptionOnNilExtension, // [UTLS]
	}
}

+49 −11
Original line number Diff line number Diff line
@@ -38,7 +38,16 @@ func (csc *ClientSessionCache) Put(sessionKey string, cs *tls.ClientSessionState
	}
}

func runResumptionCheck(helloID tls.ClientHelloID, serverAddr string, retry int, verbose bool) {
type ResumptionType int

const (
	noResumption     ResumptionType = 0
	pskResumption    ResumptionType = 1
	ticketResumption ResumptionType = 2
)

func runResumptionCheck(helloID tls.ClientHelloID, getCustomSpec func() *tls.ClientHelloSpec, expectResumption ResumptionType, serverAddr string, retry int, verbose bool) {
	fmt.Printf("checking: hello [%s], expectResumption [%v], serverAddr [%s]\n", helloID.Client, expectResumption, serverAddr)
	csc := NewClientSessionCache()
	tcpConn, err := net.Dial("tcp", serverAddr)
	if err != nil {
@@ -55,6 +64,10 @@ func runResumptionCheck(helloID tls.ClientHelloID, serverAddr string, retry int,
		OmitEmptyPsk:       true,
	}, helloID)

	if getCustomSpec != nil {
		tlsConn.ApplyPreset(getCustomSpec())
	}

	// HS
	err = tlsConn.Handshake()
	if err != nil {
@@ -96,6 +109,7 @@ func runResumptionCheck(helloID tls.ClientHelloID, serverAddr string, retry int,
	}
	tlsConn.Close()

	resumption := noResumption
	for i := 0; i < retry; i++ {
		tcpConnPSK, err := net.Dial("tcp", serverAddr)
		if err != nil {
@@ -108,6 +122,10 @@ func runResumptionCheck(helloID tls.ClientHelloID, serverAddr string, retry int,
			OmitEmptyPsk:       true,
		}, helloID)

		if getCustomSpec != nil {
			tlsConnPSK.ApplyPreset(getCustomSpec())
		}

		// HS
		err = tlsConnPSK.Handshake()
		if verbose {
@@ -133,27 +151,47 @@ func runResumptionCheck(helloID tls.ClientHelloID, serverAddr string, retry int,

			if tlsVer == tls.VersionTLS13 && tlsConnPSK.HandshakeState.State13.UsingPSK {
				fmt.Println("[PSK used]")
				return
				resumption = pskResumption
				break
			} else if tlsVer == tls.VersionTLS12 && tlsConnPSK.DidTls12Resume() {
				fmt.Println("[session ticket used]")
				return
				resumption = ticketResumption
				break
			}
		}
		time.Sleep(700 * time.Millisecond)
	}
	panic(fmt.Sprintf("PSK or session ticket not used for a resumption session, server %s, helloID: %s", serverAddr, helloID.Client))

	if resumption != expectResumption {
		panic(fmt.Sprintf("Expecting resumption type: %v, actual %v; session, server %s, helloID: %s", expectResumption, resumption, serverAddr, helloID.Client))
	} else {
		fmt.Println("[expected]")
	}
}

func main() {
	tls13Url := "www.microsoft.com:443"
	tls12Url1 := "spocs.getpocket.com:443"
	tls12Url2 := "marketplace.visualstudio.com:443"
	runResumptionCheck(tls.HelloChrome_100_PSK, tls13Url, 1, false) // psk + utls
	runResumptionCheck(tls.HelloGolang, tls13Url, 1, false)         // psk + crypto/tls

	runResumptionCheck(tls.HelloChrome_100_PSK, tls12Url1, 10, false) // session ticket + utls
	runResumptionCheck(tls.HelloGolang, tls12Url1, 10, false)         // session ticket + crypto/tls
	runResumptionCheck(tls.HelloChrome_100_PSK, tls12Url2, 10, false) // session ticket + utls
	runResumptionCheck(tls.HelloGolang, tls12Url2, 10, false)         // session ticket + crypto/tls
	runResumptionCheck(tls.HelloChrome_100, nil, noResumption, tls13Url, 3, false) // no-resumption + utls
	func() {
		defer func() {
			if err := recover(); err == nil {
				panic("must throw")
			}
		}()

		runResumptionCheck(tls.HelloCustom, func() *tls.ClientHelloSpec {
			spec, _ := tls.UTLSIdToSpec(tls.HelloChrome_100)
			return &spec
		}, noResumption, tls13Url, 3, false) // no-resumption + utls custom + no psk extension
	}()
	runResumptionCheck(tls.HelloChrome_100_PSK, nil, pskResumption, tls13Url, 1, false) // psk + utls
	runResumptionCheck(tls.HelloGolang, nil, pskResumption, tls13Url, 1, false)         // psk + crypto/tls

	runResumptionCheck(tls.HelloChrome_100_PSK, nil, ticketResumption, tls12Url1, 10, false) // session ticket + utls
	runResumptionCheck(tls.HelloGolang, nil, ticketResumption, tls12Url1, 10, false)         // session ticket + crypto/tls
	runResumptionCheck(tls.HelloChrome_100_PSK, nil, ticketResumption, tls12Url2, 10, false) // session ticket + utls
	runResumptionCheck(tls.HelloGolang, nil, ticketResumption, tls12Url2, 10, false)         // session ticket + crypto/tls

}
+1 −1
Original line number Diff line number Diff line
@@ -854,7 +854,7 @@ func TestCloneNonFuncFields(t *testing.T) {
			f.Set(reflect.ValueOf("b"))
		case "ClientAuth":
			f.Set(reflect.ValueOf(VerifyClientCertIfGiven))
		case "InsecureSkipVerify", "InsecureSkipTimeVerify", "SessionTicketsDisabled", "DynamicRecordSizingDisabled", "PreferServerCipherSuites", "OmitEmptyPsk":
		case "InsecureSkipVerify", "InsecureSkipTimeVerify", "SessionTicketsDisabled", "DynamicRecordSizingDisabled", "PreferServerCipherSuites", "OmitEmptyPsk", "PreferSkipResumptionOnNilExtension":
			f.Set(reflect.ValueOf(true))
		case "InsecureServerNameToVerify":
			f.Set(reflect.ValueOf("c"))
+6 −0
Original line number Diff line number Diff line
@@ -39,6 +39,11 @@ type UConn struct {

	omitSNIExtension bool

	// skipResumptionOnNilExtension is copied from `Config.PreferSkipResumptionOnNilExtension`.
	//
	// By default, if ClientHelloSpec is predefined or utls-generated (as opposed to HelloCustom), this flag will be updated to true.
	skipResumptionOnNilExtension bool

	// certCompressionAlgs represents the set of advertised certificate compression
	// algorithms, as specified in the ClientHello. This is only relevant client-side, for the
	// server certificate. All other forms of certificate compression are unsupported.
@@ -58,6 +63,7 @@ func UClient(conn net.Conn, config *Config, clientHelloID ClientHelloID) *UConn
	uconn.handshakeFn = uconn.clientHandshake
	uconn.sessionController = newSessionController(&uconn)
	uconn.utls.sessionController = uconn.sessionController
	uconn.skipResumptionOnNilExtension = config.PreferSkipResumptionOnNilExtension || clientHelloID.Client != helloCustom
	return &uconn
}

+16 −3
Original line number Diff line number Diff line
@@ -121,6 +121,12 @@ func (s *sessionController) assertNotLocked(caller string) {
	}
}

func (s *sessionController) assertCanSkip(caller, extensionName string) {
	if !s.uconnRef.skipResumptionOnNilExtension {
		panic(fmt.Sprintf("tls: %s failed: session resumption is enabled, but there is no %s in the ClientHelloSpec; Please consider provide one in the ClientHelloSpec; If this is intentional, you may consider disable resumption by setting Config.SessionTicketsDisabled to true, or set Config.PreferSkipResumptionOnNilExtension to true to suppress this exception", caller, extensionName))
	}
}

// finalCheck performs a comprehensive check on the updated state to ensure the correctness of the changes.
// If the checks pass successfully, the sessionController's state will be locked.
// Any failure in passing the tests indicates incorrect implementations in the utls, which will result in triggering a panic.
@@ -141,7 +147,11 @@ func (s *sessionController) initSessionTicketExt(session *SessionState, ticket [
	s.assertNotLocked("initSessionTicketExt")
	s.assertHelloNotBuilt("initSessionTicketExt")
	s.assertControllerState("initSessionTicketExt", NoSession)
	panicOnNil("initSessionTicketExt", s.sessionTicketExt, session, ticket)
	panicOnNil("initSessionTicketExt", session, ticket)
	if s.sessionTicketExt == nil {
		s.assertCanSkip("initSessionTicketExt", "session ticket extension")
		return
	}
	initializationGuard(s.sessionTicketExt, func(e ISessionTicketExtension) {
		s.sessionTicketExt.InitializeByUtls(session, ticket)
	})
@@ -155,8 +165,11 @@ func (s *sessionController) initPskExt(session *SessionState, earlySecret []byte
	s.assertNotLocked("initPskExt")
	s.assertHelloNotBuilt("initPskExt")
	s.assertControllerState("initPskExt", NoSession)
	panicOnNil("initPskExt", s.pskExtension, session, earlySecret, pskIdentities)

	panicOnNil("initPskExt", session, earlySecret, pskIdentities)
	if s.pskExtension == nil {
		s.assertCanSkip("initPskExt", "pre-shared key extension")
		return
	}
	initializationGuard(s.pskExtension, func(e PreSharedKeyExtension) {
		publicPskIdentities := mapSlice(pskIdentities, func(private pskIdentity) PskIdentity {
			return PskIdentity{