Commit 2296d907 authored by George Kadianakis's avatar George Kadianakis
Browse files

Use temporary files instead of "/tmp" in scramblesuit unittests.

Conflicts:
	obfsproxy/test/transports/test_scramblesuit.py
parent ee185b89
Loading
Loading
Loading
Loading
+34 −17
Original line number Diff line number Diff line
@@ -13,6 +13,7 @@ import obfsproxy.network.buffer as obfs_buf
import obfsproxy.common.transport_config as transport_config
import obfsproxy.transports.base as base

import obfsproxy.transports.scramblesuit.state as state
import obfsproxy.transports.scramblesuit.util as util
import obfsproxy.transports.scramblesuit.const as const
import obfsproxy.transports.scramblesuit.mycrypto as mycrypto
@@ -222,18 +223,16 @@ class UtilTest( unittest.TestCase ):
    def test4_setStateLocation( self ):
        name = (const.TRANSPORT_NAME).lower()

        util.setStateLocation("/tmp")
        self.failUnless(const.STATE_LOCATION == "/tmp/%s/" % name)

        # Nothing should change if we pass "None".
        util.setStateLocation(None)
        self.failUnless(const.STATE_LOCATION == "/tmp/%s/" % name)

        # Check if function creates non-existant directories.
        d = tempfile.mkdtemp()
        util.setStateLocation(d)
        self.failUnless(const.STATE_LOCATION == "%s/%s/" % (d, name))
        self.failUnless(os.path.exists("%s/%s/" % (d, name)))

        # Nothing should change if we pass "None".
        util.setStateLocation(None)
        self.failUnless(const.STATE_LOCATION == "%s/%s/" % (d, name))

        shutil.rmtree(d)

    def test5_getEpoch( self ):
@@ -256,15 +255,14 @@ class UtilTest( unittest.TestCase ):
        self.failUnless(util.readFromFile("/etc/shadow") == None)

class StateTest( unittest.TestCase ):

    def setUp( self ):
        const.STATE_LOCATION = "/tmp/"
        self.stateFile = const.STATE_LOCATION + const.SERVER_STATE_FILE
        const.STATE_LOCATION = tempfile.mkdtemp()
        self.stateFile = os.path.join(const.STATE_LOCATION, const.SERVER_STATE_FILE)
        self.state = state.State()

    def tearDown( self ):
        try:
            os.unlink(self.stateFile)
            shutil.rmtree(const.STATE_LOCATION)
        except OSError:
            pass

@@ -322,6 +320,14 @@ class ScrambleSuitTransportTest( unittest.TestCase ):
        self.validSecret = base64.b32encode( 'A' * const.SHARED_SECRET_LENGTH )
        self.invalidSecret = 'a' * const.SHARED_SECRET_LENGTH

        self.statefile = tempfile.mkdtemp()

    def tearDown( self ):
        try:
            shutil.rmtree(self.statefile)
        except OSError:
            pass

    def test1_validateExternalModeCli( self ):
        """Test with valid scramblesuit args and valid obfsproxy args."""
        self.args.uniformDHSecret = self.validSecret
@@ -340,7 +346,10 @@ class ScrambleSuitTransportTest( unittest.TestCase ):
            self.suit.validate_external_mode_cli( self.args )

    def test3_get_public_server_options( self ):
        scramblesuit.ScrambleSuitTransport.setup(transport_config.TransportConfig())
        transCfg = transport_config.TransportConfig()
        transCfg.setStateLocation(self.statefile)

        scramblesuit.ScrambleSuitTransport.setup(transCfg)
        options = scramblesuit.ScrambleSuitTransport.get_public_server_options("")
        self.failUnless("password" in options)

@@ -388,13 +397,21 @@ class MessageTest( unittest.TestCase ):
                          message.ProtocolMessage, "1", paddingLen=const.MPU)

class TicketTest( unittest.TestCase ):
    def setUp( self ):
        const.STATE_LOCATION = tempfile.mkdtemp()
        self.stateFile = os.path.join(const.STATE_LOCATION, const.SERVER_STATE_FILE)
        self.state = state.State()
        self.state.genState()

    def test1_authentication( self ):
        srvState = state.State()
        srvState.genState()
    def tearDown( self ):
        try:
            shutil.rmtree(const.STATE_LOCATION)
        except OSError:
            pass

    def test1_authentication( self ):
        ss = scramblesuit.ScrambleSuitTransport()
        ss.srvState = srvState
        ss.srvState = self.state

        realEpoch = util.getEpoch

@@ -404,7 +421,7 @@ class TicketTest( unittest.TestCase ):
            util.getEpoch = lambda: epoch

            # Prepare ticket message.
            blurb = ticket.issueTicketAndKey(srvState)
            blurb = ticket.issueTicketAndKey(self.state)
            rawTicket = blurb[const.MASTER_KEY_LENGTH:]
            masterKey = blurb[:const.MASTER_KEY_LENGTH]
            ss.deriveSecrets(masterKey)