Source code for epipack.mock_samplable_set

Contains a MockSamplableSet class that mimics the behavior of SamplableSet but is less efficient.

import numpy as np

def choice(arr, p):
    """
    Returns a random element from ``arr`` with probability proportional to
    the corresponding entry of ``p``. If ``arr`` cannot be indexed (e.g. it is
    an ``int`` holding the number of elements), the function returns the
    index of the chosen element instead.

    Parameters
    ==========
    arr : indexable or int
        Collection to draw from, or anything that does not support indexing,
        in which case the drawn index itself is returned.
    p : array_like of float
        Non-negative weights, one per element. They do not have to sum to
        one; they are normalized by their total.

    Returns
    =======
    element
        ``arr[ndx]`` for the drawn index ``ndx``, or ``ndx`` if indexing
        ``arr`` raises ``TypeError``.

    Raises
    ======
    ValueError
        If ``p`` is empty or its total is not positive.
    """
    cumulative = np.cumsum(p, dtype=float)
    if len(cumulative) == 0:
        raise ValueError("cannot choose from an empty set of probabilities")
    total = cumulative[-1]
    if not total > 0:
        # Also rejects NaN totals (e.g. from dividing by a zero weight sum).
        raise ValueError("probabilities must have a positive sum")

    # Draw a threshold in [0, total) and pick the first element whose
    # cumulative weight exceeds it. Scaling by ``total`` instead of assuming
    # sum(p) == 1 prevents silently falling back to index 0 when rounding
    # leaves the cumulative sum slightly below the drawn number.
    ndx = int(np.searchsorted(cumulative, np.random.rand() * total, side='right'))
    if ndx == len(cumulative):
        # rand() * total can round up to exactly ``total``; fall back to the
        # last element that carries positive weight.
        ndx = int(np.searchsorted(cumulative, total, side='left'))

    try:
        return arr[ndx]
    except TypeError:
        return ndx
class MockSamplableSet:
    """
    A set of items that can be sampled with probability
    proportional to a corresponding item weight.

    Mimics the behavior of SamplableSet without being as efficient.
    Works similar to Python's set, with ``__getitem__``, ``__setitem__``,
    ``__delitem__``, ``__iter__``, ``__len__``, ``__contains__``.

    Parameters
    ==========
    min_weight : float
        minimum possible weight
    max_weight : float
        maximum possible weight
    weighted_elements : dict or iterable of 2-tuples, default = None
        Either a dict mapping items to weights or an iterable of
        ``(item, weight)`` pairs. ``None`` creates an empty set.
    cpp_type : str, default = 'int'
        The type of the items (used as the numpy dtype of ``items``).

    Attributes
    ==========
    min_weight : float
        minimum possible weight
    max_weight : float
        maximum possible weight
    items : numpy.ndarray
        items in this set, kept sorted so lookups can use binary search
    weights : numpy.ndarray
        weights aligned with ``items``
    """

    def __init__(self, min_weight, max_weight, weighted_elements=None, cpp_type='int'):
        self.min_weight = min_weight
        self.max_weight = max_weight

        if weighted_elements is None:
            pairs = []
        elif isinstance(weighted_elements, dict):
            pairs = list(weighted_elements.items())
        else:
            # Materialize once so one-shot iterators (e.g. generators) are not
            # exhausted by the first of the two passes below.
            pairs = list(weighted_elements)

        self.items = np.array([e[0] for e in pairs], dtype=cpp_type)
        self.weights = np.array([e[1] for e in pairs], dtype=float)

        # Keep items sorted so that _find_key can use np.searchsorted.
        sort_ndx = np.argsort(self.items)
        self.items = self.items[sort_ndx]
        self.weights = self.weights[sort_ndx]

        self._total_weight = self.weights.sum()

        if np.any(self.weights < self.min_weight):
            raise ValueError("There are weights below the limit.")
        if np.any(self.weights > self.max_weight):
            raise ValueError("There are weights above the limit.")

    def sample(self):
        """
        Random sample from the set, sampled
        with probability proportional to items' weight.

        Returns
        =======
        item : cpp_type
            An item from the set
        weight : float
            The weight of the item

        Raises
        ======
        ValueError
            If the set is empty.
        """
        if len(self.items) == 0:
            raise ValueError("Cannot sample from an empty set.")
        ndx = choice(len(self.items), p=self.weights / self._total_weight)
        return self.items[ndx], self.weights[ndx]

    def __getitem__(self, key):
        """Return the weight of ``key``; raise ``KeyError`` if it is absent."""
        found_key, ndx = self._find_key(key)
        if not found_key:
            raise KeyError(f"`{key}` is not in this set.")
        return self.weights[ndx]

    def __delitem__(self, key):
        """Remove ``key`` from the set. Missing keys are silently ignored."""
        found_key, ndx = self._find_key(key)
        if found_key:
            self.items = np.delete(self.items, ndx)
            self.weights = np.delete(self.weights, ndx)
            self._total_weight = self.weights.sum()

    def __setitem__(self, key, value):
        """
        Insert ``key`` with weight ``value``, or update its weight if present.

        Raises
        ======
        ValueError
            If ``value`` lies outside ``[min_weight, max_weight]``.
        """
        if value < self.min_weight or value > self.max_weight:
            raise ValueError('Inserting element-weight pair ' + str(key) + " " + str(value) + " \n" +
                             'has weight value out of bounds of ' + str(self.min_weight) + " " +
                             str(self.max_weight))
        found_key, ndx = self._find_key(key)
        if not found_key:
            # ndx is the insertion point that keeps ``items`` sorted.
            self.items = np.insert(self.items, ndx, key)
            self.weights = np.insert(self.weights, ndx, value)
        else:
            self.weights[ndx] = value
        self._total_weight = self.weights.sum()

    def _find_key(self, key):
        """
        Locate ``key`` in the sorted ``items`` array.

        Returns ``(found, ndx)`` where ``found`` tells whether ``key`` is
        present and ``ndx`` is its position, or the position at which it
        would be inserted to keep ``items`` sorted.
        """
        ndx = np.searchsorted(self.items, key)
        found = bool(ndx < len(self.items) and self.items[ndx] == key)
        return found, ndx

    def __iter__(self):
        """Iterate over ``(item, weight)`` pairs in sorted item order."""
        # Reset the legacy cursor so the old ``iter(s); next(s)`` pattern
        # keeps working.
        self._ndx = 0
        # Return an independent iterator so nested loops over the same set
        # do not share a single cursor.
        return zip(self.items, self.weights)

    def __next__(self):
        # Retained for backward compatibility with code that calls
        # ``next(s)`` after ``iter(s)``; ``for`` loops use the iterator
        # returned by ``__iter__``.
        if self._ndx < len(self.items):
            i, w = self.items[self._ndx], self.weights[self._ndx]
            self._ndx += 1
            return (i, w)
        raise StopIteration

    def __len__(self):
        return len(self.items)

    def __contains__(self, key):
        return self._find_key(key)[0]

    def total_weight(self):
        """Obtain the total weight of the set"""
        return self._total_weight

    def clear(self):
        """Remove all items from the set, keeping the item dtype."""
        self.items = np.empty(0, dtype=self.items.dtype)
        self.weights = np.empty(0, dtype=float)
        self._total_weight = self.weights.sum()
if __name__ == "__main__":  # pragma: no cover
    divider = '==========='

    def _dump(mock_set, show_length=False):
        """Print a divider, then the set's items, weights, optional length and total weight."""
        print(divider)
        print(mock_set.items)
        print(mock_set.weights)
        if show_length:
            print(len(mock_set))
        print(mock_set.total_weight())

    demo = MockSamplableSet(1.0, 2.0, {0: 1.2, 3: 1.8})
    _dump(demo, show_length=True)

    # Seed the global numpy RNG so the samples below are reproducible.
    np.random.seed(1)
    print(divider)
    for _ in range(5):
        print(demo.sample())

    # Look up the same key twice, then overwrite its weight.
    print(divider)
    for _ in range(2):
        print(demo[0])
    demo[0] = 2
    _dump(demo)

    # Insert a new key, then remove an existing one.
    demo[1] = 1.3
    _dump(demo)
    del demo[3]
    _dump(demo)

    # Walk over the (item, weight) pairs.
    print(divider)
    for item, weight in demo:
        print(item, weight)

    # Membership tests for a present and an absent key.
    print(divider)
    for probe in (0, 45):
        print(probe in demo)