73 lines
2.2 KiB
Python
73 lines
2.2 KiB
Python
import six
|
|
|
|
from collections import *
|
|
from marisa_trie import BytesTrie
|
|
from geodata.encoding import safe_encode, safe_decode
|
|
|
|
SENTINEL = None
|
|
|
|
|
|
class PhraseFilter(object):
|
|
def __init__(self, phrases):
|
|
self.trie = [(key, self.serialize(val)) for key, val in six.iteritems(phrases)]
|
|
|
|
serialize = safe_encode
|
|
deserialize = safe_decode
|
|
|
|
def filter(self, tokens):
|
|
def return_item(item):
|
|
return False, item, []
|
|
|
|
if not tokens:
|
|
return
|
|
|
|
ent = []
|
|
ent_tokens = []
|
|
|
|
queue = deque(tokens + [(SENTINEL,) * 2])
|
|
skip_until = 0
|
|
|
|
trie = self.trie
|
|
|
|
while queue:
|
|
item = queue.popleft()
|
|
t, c = item
|
|
|
|
if t is not SENTINEL and trie.has_keys_with_prefix(u' '.join(ent_tokens + [t])):
|
|
ent.append(item)
|
|
ent_tokens.append(item[0])
|
|
elif ent_tokens:
|
|
res = trie.get(u' '.join(ent_tokens)) or None
|
|
if res is not None:
|
|
yield (True, ent, map(self.deserialize, res))
|
|
queue.appendleft(item)
|
|
ent = []
|
|
ent_tokens = []
|
|
elif len(ent_tokens) == 1:
|
|
yield return_item(ent[0])
|
|
ent = []
|
|
ent_tokens = []
|
|
queue.appendleft(item)
|
|
else:
|
|
have_phrase = False
|
|
|
|
for i in xrange(len(ent) - 1, 0, -1):
|
|
remainder = ent[i:]
|
|
res = trie.get(u' '.join([e[0] for e in ent[:i]])) or None
|
|
if res is not None:
|
|
yield (True, ent[:i], map(self.deserialize, res))
|
|
have_phrase = True
|
|
break
|
|
|
|
if not have_phrase:
|
|
yield return_item(ent[0])
|
|
|
|
todos = list(remainder)
|
|
todos.append(item)
|
|
queue.extendleft(reversed(todos))
|
|
|
|
ent = []
|
|
ent_tokens = []
|
|
elif t is not SENTINEL:
|
|
yield return_item(item)
|