SortedDict: add functions for getting ranges of keys, values, items

This commit is contained in:
Nathan Binkert 2011-04-15 10:38:02 -07:00
parent 1f7f79781e
commit 12446e9659

View file

@ -24,6 +24,8 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from bisect import bisect_left, bisect_right
class SortedDict(dict):
def _get_sorted(self):
return getattr(self, '_sorted', sorted)
@ -41,6 +43,42 @@ class SortedDict(dict):
self._sorted_keys = _sorted_keys
return _sorted_keys
def _left_eq(self, key):
index = self._left_ge(self, key)
if self._keys[index] != key:
raise KeyError(key)
return index
def _right_eq(self, key):
index = self._right_le(self, key)
if self._keys[index] != key:
raise KeyError(key)
return index
def _right_lt(self, key):
index = bisect_left(self._keys, key)
if index:
return index - 1
raise KeyError(key)
def _right_le(self, key):
index = bisect_right(self._keys, key)
if index:
return index - 1
raise KeyError(key)
def _left_gt(self, key):
index = bisect_right(self._keys, key)
if index != len(self._keys):
return index
raise KeyError(key)
def _left_ge(self, key):
index = bisect_left(self._keys, key)
if index != len(self._keys):
return index
raise KeyError(key)
def _del_keys(self):
try:
del self._sorted_keys
@ -86,6 +124,26 @@ class SortedDict(dict):
for k in self._keys:
yield k, self[k]
def keyrange(self, start=None, end=None, inclusive=False):
if start is not None:
start = self._left_ge(start)
if end is not None:
if inclusive:
end = self._right_le(end)
else:
end = self._right_lt(end)
return iter(self._keys[start:end+1])
def valuerange(self, *args, **kwargs):
for k in self.keyrange(*args, **kwargs):
yield self[k]
def itemrange(self, *args, **kwargs):
for k in self.keyrange(*args, **kwargs):
yield k, self[k]
def update(self, *args, **kwargs):
dict.update(self, *args, **kwargs)
self._del_keys()
@ -157,3 +215,6 @@ if __name__ == '__main__':
print `d`
print d.copy()
for k,v in d.itemrange('d', 'z', inclusive=True):
print k,v