diff --git a/src/python/m5/util/sorteddict.py b/src/python/m5/util/sorteddict.py index c91bd943d..ef32be3af 100644 --- a/src/python/m5/util/sorteddict.py +++ b/src/python/m5/util/sorteddict.py @@ -24,6 +24,8 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +from bisect import bisect_left, bisect_right + class SortedDict(dict): def _get_sorted(self): return getattr(self, '_sorted', sorted) @@ -41,6 +43,42 @@ class SortedDict(dict): self._sorted_keys = _sorted_keys return _sorted_keys + def _left_eq(self, key): + index = self._left_ge(self, key) + if self._keys[index] != key: + raise KeyError(key) + return index + + def _right_eq(self, key): + index = self._right_le(self, key) + if self._keys[index] != key: + raise KeyError(key) + return index + + def _right_lt(self, key): + index = bisect_left(self._keys, key) + if index: + return index - 1 + raise KeyError(key) + + def _right_le(self, key): + index = bisect_right(self._keys, key) + if index: + return index - 1 + raise KeyError(key) + + def _left_gt(self, key): + index = bisect_right(self._keys, key) + if index != len(self._keys): + return index + raise KeyError(key) + + def _left_ge(self, key): + index = bisect_left(self._keys, key) + if index != len(self._keys): + return index + raise KeyError(key) + def _del_keys(self): try: del self._sorted_keys @@ -86,6 +124,26 @@ class SortedDict(dict): for k in self._keys: yield k, self[k] + def keyrange(self, start=None, end=None, inclusive=False): + if start is not None: + start = self._left_ge(start) + + if end is not None: + if inclusive: + end = self._right_le(end) + else: + end = self._right_lt(end) + + return iter(self._keys[start:end+1]) + + def valuerange(self, *args, **kwargs): + for k in self.keyrange(*args, **kwargs): + yield self[k] + + def itemrange(self, *args, **kwargs): + for k in self.keyrange(*args, **kwargs): + yield k, self[k] + def update(self, *args, **kwargs): dict.update(self, *args, **kwargs) self._del_keys() @@ -157,3 +215,6 @@ if __name__ == '__main__': print `d` print d.copy() + + for k,v in d.itemrange('d', 'z', inclusive=True): + print k,v