GitLab is used only for code review, issue tracking and project management. Canonical locations for source code are still https://gitweb.torproject.org/, https://git.torproject.org/ and git-rw.torproject.org.

state.py 3.33 KB
Newer Older
1 2 3 4
from sbws.util.filelock import FileLock
import os
import json

juga  's avatar
juga committed
5 6
from .json import CustomDecoder, CustomEncoder

7 8

class State:
    """
    `json` wrapper to read a json file every time it gets a key and to write
    to the file every time a key is set.

    Every time a key is got or set, a ``FileLock`` for the file is held, to
    atomically access and update the file across threads and across
    processes. Operations that modify the state (setting or deleting a key)
    hold the lock for the whole read-modify-write cycle, so two writers
    cannot overwrite each other's keys.

    Example (not a doctest: the comments describe what happens when other
    processes update the same file concurrently)::

        state = State('foo.state')
        # state == {}

        state['linux'] = True
        # 'foo.state' now exists on disk with the JSON for {'linux': True}

        # We read 'foo.state' from disk in order to get the most up-to-date
        # state info. Pretend another process has updated 'linux' to be
        # False
        state['linux']
        # returns False

        # Pretend another process has added the user's age to the state
        # file. As before, we read the state file from disk for the most
        # up-to-date info.
        state['age']
        # Returns 14

        # We now set their name. We read the state file first, set the
        # option, and then write it out.
        state['name'] = 'John'

        # We can do many of the same things with a State object as with a
        # dict
        for key in state: print(key)
        # Prints 'linux', 'age', and 'name'
    """

    def __init__(self, fname):
        # Path of the JSON file backing this state.
        self._fname = fname
        # Last content read from (or written to) the file. It is refreshed
        # from disk before every access, so the file is the source of truth.
        self._state = self._read()

    def _load(self):
        """
        Return the decoded content of the file, or ``{}`` if it is missing.

        Does not lock: the caller must already hold ``FileLock(self._fname)``.
        """
        try:
            fd = open(self._fname, 'rt')
        except FileNotFoundError:
            # The file may disappear between an existence check and the
            # ``open`` call; treat that as an empty state.
            return {}
        with fd:
            return json.load(fd, cls=CustomDecoder)

    def _dump(self):
        """
        Serialize ``self._state`` to the file, replacing its content.

        Does not lock: the caller must already hold ``FileLock(self._fname)``.
        """
        with open(self._fname, 'wt') as fd:
            json.dump(self._state, fd, indent=4, cls=CustomEncoder)

    def _read(self):
        """Return the file content read under the lock (``{}`` if missing)."""
        # Fast path: a missing file is an empty state, and the lock is not
        # acquired at all in that case.
        if not os.path.exists(self._fname):
            return {}
        with FileLock(self._fname):
            return self._load()

    def _write(self):
        """Write ``self._state`` to the file while holding the lock."""
        with FileLock(self._fname):
            self._dump()

    def __len__(self):
        self._state = self._read()
        return len(self._state)

    def get(self, key, d=None):
        """
        Implements a dictionary ``get`` method reading and locking
        a json file.

        :param key: key to look up.
        :param d: value returned when ``key`` is not in the state.
        """
        self._state = self._read()
        return self._state.get(key, d)

    def __getitem__(self, key):
        self._state = self._read()
        return self._state[key]

    def __delitem__(self, key):
        """Delete ``key`` from the file; raises ``KeyError`` if missing."""
        # Hold the lock across read, delete and write, so that no other
        # writer can interleave and have its changes overwritten.
        with FileLock(self._fname):
            self._state = self._load()
            del self._state[key]
            self._dump()

    def __setitem__(self, key, value):
        # NOTE: important, read the file before setting the key,
        # otherwise if other instances are creating other keys, they're lost.
        # The lock is held for the whole read-modify-write so that another
        # writer cannot update the file between our read and our write.
        with FileLock(self._fname):
            self._state = self._load()
            self._state[key] = value
            self._dump()

    def __iter__(self):
        self._state = self._read()
        return iter(self._state)

    def __contains__(self, item):
        self._state = self._read()
        return item in self._state

    def count(self, k):
        """
        Summarize the value stored under ``k``.

        Returns:

        - the sum of the second element of every item, if the value is a
          list of lists;
        - the length of the list, if the value is any other list;
        - the value itself, otherwise;
        - None if the state doesn't have the key, or if its value is falsy
          (e.g. an empty list or 0).
        """
        # A single read, so that every check below sees the same value.
        value = self.get(k)
        if not value:
            return None
        if isinstance(value, list):
            # ``value`` is non-empty here, so ``value[0]`` is safe.
            if isinstance(value[0], list):
                return sum(item[1] for item in value)
            return len(value)
        return value