|
| 1 | +from itertools import cycle |
| 2 | + |
| 3 | + |
| 4 | +class Partitioner(object): |
| 5 | + """ |
| 6 | + Base class for a partitioner |
| 7 | + """ |
| 8 | + def __init__(self, partitions): |
| 9 | + """ |
| 10 | + Initialize the partitioner |
| 11 | +
|
| 12 | + partitions - A list of available partitions (during startup) |
| 13 | + """ |
| 14 | + self.partitions = partitions |
| 15 | + |
| 16 | + def partition(self, key, partitions): |
| 17 | + """ |
| 18 | + Takes a string key and num_partitions as argument and returns |
| 19 | + a partition to be used for the message |
| 20 | +
|
| 21 | + partitions - The list of partitions is passed in every call. This |
| 22 | + may look like an overhead, but it will be useful |
| 23 | + (in future) when we handle cases like rebalancing |
| 24 | + """ |
| 25 | + raise NotImplemented('partition function has to be implemented') |
| 26 | + |
| 27 | + |
| 28 | +class RoundRobinPartitioner(Partitioner): |
| 29 | + """ |
| 30 | + Implements a round robin partitioner which sends data to partitions |
| 31 | + in a round robin fashion |
| 32 | + """ |
| 33 | + def __init__(self, partitions): |
| 34 | + self._set_partitions(partitions) |
| 35 | + |
| 36 | + def _set_partitions(self, partitions): |
| 37 | + self.partitions = partitions |
| 38 | + self.iterpart = cycle(partitions) |
| 39 | + |
| 40 | + def partition(self, key, partitions): |
| 41 | + # Refresh the partition list if necessary |
| 42 | + if self.partitions != partitions: |
| 43 | + self._set_partitions(partitions) |
| 44 | + |
| 45 | + return self.iterpart.next() |
| 46 | + |
| 47 | + |
| 48 | +class HashedPartitioner(Partitioner): |
| 49 | + """ |
| 50 | + Implements a partitioner which selects the target partition based on |
| 51 | + the hash of the key |
| 52 | + """ |
| 53 | + def partition(self, key, partitions): |
| 54 | + size = len(partitions) |
| 55 | + idx = hash(key) % size |
| 56 | + return partitions[idx] |
0 commit comments