diff --git a/scripts/geodata/address_expansions/abbreviations.py b/scripts/geodata/address_expansions/abbreviations.py
index e3567faf..00aa54b6 100644
--- a/scripts/geodata/address_expansions/abbreviations.py
+++ b/scripts/geodata/address_expansions/abbreviations.py
@@ -1,9 +1,11 @@
 import random
+import re
 import six
 
 from geodata.address_expansions.gazetteers import *
 from geodata.encoding import safe_decode, safe_encode
 from geodata.text.tokenize import tokenize_raw, token_types
+from geodata.text.utils import non_breaking_dash_regex
 
 LOWER, UPPER, TITLE, MIXED = range(4)
 
@@ -20,30 +22,45 @@ def token_capitalization(s):
     return MIXED
 
 
-def recase_abbreviation(expansion, tokens):
-    expansion_tokens = expansion.split()
+expansion_token_regex = re.compile('([^ \-\.]+)([\.\- ]+|$)')
+
+
+def recase_abbreviation(expansion, tokens, space_token=six.u(' ')):
+    expansion_tokens = expansion_token_regex.findall(expansion)
+
     if len(tokens) > len(expansion_tokens) and all((token_capitalization(t) != LOWER for t, c in tokens)):
-        return expansion.upper()
+        expansion_tokens = tokenize(expansion)
+        is_acronym = len(expansion_tokens) == 1 and expansion_tokens[0][1] == token_types.ACRONYM
+        if len(expansion) <= 3 or is_acronym:
+            return expansion.upper()
+        else:
+            return expansion.title()
     elif len(tokens) == len(expansion_tokens):
         strings = []
 
-        for (t, c), e in zip(tokens, expansion_tokens):
+        for (t, c), (e, suf) in zip(tokens, expansion_tokens):
             cap = token_capitalization(t)
+            if suf == six.u(' '):
+                suf = space_token
             if cap == LOWER:
-                strings.append(e.lower())
+                strings.append(six.u('').join((e.lower(), suf)))
             elif cap == UPPER:
-                strings.append(e.upper())
+                strings.append(six.u('').join((e.upper(), suf)))
             elif cap == TITLE:
-                strings.append(e.title())
+                strings.append(six.u('').join((e.title(), suf)))
             elif t.lower() == e.lower():
                 strings.append(t)
             else:
-                strings.append(e.title())
-        return six.u(' ').join(strings)
+                strings.append(six.u('').join((e.title(), suf)))
+
+        if suf == six.u(' '):
+            strings.append(space_token)
+        return six.u('').join(strings)
     else:
-        return six.u(' ').join([t.title() for t in expansion_tokens])
+        return space_token.join([t.title() for t in expansion_tokens])
 
 
-def abbreviate(gazetteer, s, language, abbreviate_prob=0.3, separate_prob=0.2):
+def abbreviate(gazetteer, s, language, abbreviate_prob=0.3, separate_prob=0.2, add_period_hyphen_prob=0.3):
     '''
     Abbreviations
     -------------
@@ -63,101 +80,141 @@ def abbreviate(gazetteer, s, language, abbreviate_prob=0.3, separate_prob=0.2):
 
     i = 0
 
-    for t, c, length, data in gazetteer.filter(norm_tokens):
-        if c == token_types.PHRASE:
-            valid = []
-            data = [d.split(six.b('|')) for d in data]
+    def abbreviated_tokens(i, tokens, t, c, length, data, space_token=six.u(' ')):
+        data = [d.split(six.b('|')) for d in data]
 
-            added = False
+        # local copy
+        abbreviated = []
 
-            # Append the original tokens with whitespace if there is any
-            if random.random() > abbreviate_prob:
-                for j, (t_i, c_i) in enumerate(t):
-                    abbreviated.append(tokens[i + j][0])
-                    if i + j < n - 1 and raw_tokens[i + j + 1][0] > sum(raw_tokens[i + j][:2]):
-                        abbreviated.append(six.u(' '))
-                i += len(t)
+        # Append the original tokens with whitespace if there is any
+        if random.random() > abbreviate_prob:
+            for j, (t_i, c_i) in enumerate(t):
+                abbreviated.append(tokens[i + j][0])
+                if i + j < n - 1 and raw_tokens[i + j + 1][0] > sum(raw_tokens[i + j][:2]):
+                    abbreviated.append(space_token)
+            return abbreviated
+
+        for lang, dictionary, is_canonical, canonical in data:
+            if lang not in (language, 'all'):
                 continue
-            for lang, dictionary, is_canonical, canonical in data:
-                if lang not in (language, 'all'):
-                    continue
+            is_canonical = int(is_canonical)
+            is_stopword = dictionary == 'stopword'
+            is_prefix = dictionary.startswith('concatenated_prefixes')
+            is_suffix = dictionary.startswith('concatenated_suffixes')
+            is_separable = is_prefix or is_suffix and dictionary.endswith('_separable') and len(t[0][0]) > length
 
-                is_canonical = int(is_canonical)
-                is_stopword = dictionary == 'stopword'
-                is_prefix = dictionary.startswith('concatenated_prefixes')
-                is_suffix = dictionary.startswith('concatenated_suffixes')
-                is_separable = is_prefix or is_suffix and dictionary.endswith('_separable') and len(t[0][0]) > length
+            suffix = None
+            prefix = None
 
-                suffix = None
-                prefix = None
+            if not is_canonical:
+                continue
 
-                if not is_canonical:
-                    continue
+            if not is_prefix and not is_suffix:
+                abbreviations = gazetteer.canonicals.get((canonical, lang, dictionary))
+                token = random.choice(abbreviations) if abbreviations else canonical
+                token = recase_abbreviation(token, tokens[i:i + len(t)], space_token=space_token)
+                abbreviated.append(token)
+                if i + len(t) < n and raw_tokens[i + len(t)][0] > sum(raw_tokens[i + len(t) - 1][:2]):
+                    abbreviated.append(space_token)
+                break
+            elif is_prefix:
+                token = tokens[i][0]
+                prefix, token = token[:length], token[length:]
 
-                if not is_prefix and not is_suffix:
-                    abbreviations = gazetteer.canonicals.get((canonical, lang, dictionary))
-                    token = random.choice(abbreviations) if abbreviations else canonical
-                    token = recase_abbreviation(token, tokens[i:i + len(t)])
+                abbreviated.append(prefix)
+                if random.random() < separate_prob:
+                    sub_tokens = tokenize(token)
+                    if sub_tokens and sub_tokens[0][1] in (token_types.HYPHEN, token_types.DASH):
+                        token = six.u('').join((t for t, c in sub_tokens[1:]))
+
+                    abbreviated.append(space_token)
+                if token.islower():
+                    abbreviated.append(token.title())
+                else:
                     abbreviated.append(token)
-                    if i + len(t) < n and raw_tokens[i + len(t)][0] > sum(raw_tokens[i + len(t) - 1][:2]):
-                        abbreviated.append(six.u(' '))
-                    break
-                elif is_prefix:
-                    token = tokens[i][0]
-                    prefix, token = token[:length], token[length:]
-                    abbreviated.append(prefix)
-                    if random.random() < separate_prob:
-                        abbreviated.append(six.u(' '))
-                    if token.islower():
-                        abbreviated.append(token.title())
-                    else:
-                        abbreviated.append(token)
-                    abbreviated.append(six.u(' '))
-                    break
-                elif is_suffix:
-                    token = tokens[i][0]
+                abbreviated.append(space_token)
+                break
+            elif is_suffix:
+                token = tokens[i][0]
 
-                    token, suffix = token[:-length], token[-length:]
+                token, suffix = token[:-length], token[-length:]
 
-                    concatenated_abbreviations = gazetteer.canonicals.get((canonical, lang, dictionary), [])
+                concatenated_abbreviations = gazetteer.canonicals.get((canonical, lang, dictionary), [])
 
-                    separated_abbreviations = []
-                    phrase = gazetteer.trie.get(suffix.rstrip('.'))
-                    suffix_data = [safe_decode(d).split(six.u('|')) for d in (phrase or [])]
-                    for l, d, _, c in suffix_data:
-                        if l == lang and c == canonical:
-                            separated_abbreviations.extend(gazetteer.canonicals.get((canonical, lang, d)))
+                separated_abbreviations = []
+                phrase = gazetteer.trie.get(suffix.rstrip('.'))
+                suffix_data = [safe_decode(d).split(six.u('|')) for d in (phrase or [])]
+                for l, d, _, c in suffix_data:
+                    if l == lang and c == canonical:
+                        separated_abbreviations.extend(gazetteer.canonicals.get((canonical, lang, d)))
 
-                    separate = random.random() < separate_prob
+                separate = random.random() < separate_prob
 
-                    if concatenated_abbreviations and not separate:
-                        abbreviation = random.choice(concatenated_abbreviations)
-                    elif separated_abbreviations:
-                        abbreviation = random.choice(separated_abbreviations)
-                    else:
-                        abbreviation = canonical
+                if concatenated_abbreviations and not separate:
+                    abbreviation = random.choice(concatenated_abbreviations)
+                elif separated_abbreviations:
+                    abbreviation = random.choice(separated_abbreviations)
+                else:
+                    abbreviation = canonical
 
-                    abbreviated.append(token)
-                    if separate:
-                        abbreviated.append(six.u(' '))
-                    if suffix.isupper():
-                        abbreviated.append(abbreviation.upper())
-                    elif separate:
-                        abbreviated.append(abbreviation.title())
-                    else:
-                        abbreviated.append(abbreviation)
-                    abbreviated.append(six.u(' '))
-                    break
+                if separate:
+                    sub_tokens = tokenize(token)
+                    if sub_tokens and sub_tokens[-1][1] in (token_types.HYPHEN, token_types.DASH):
+                        token = six.u('').join((t for t, c in sub_tokens[:-1]))
+
+                abbreviated.append(token)
+                if separate:
+                    abbreviated.append(space_token)
+                if suffix.isupper():
+                    abbreviated.append(abbreviation.upper())
+                elif separate:
+                    abbreviated.append(abbreviation.title())
+                else:
+                    abbreviated.append(abbreviation)
+                abbreviated.append(space_token)
+                break
         else:
             for j, (t_i, c_i) in enumerate(t):
                 abbreviated.append(tokens[i + j][0])
                 if i + j < n - 1 and raw_tokens[i + j + 1][0] > sum(raw_tokens[i + j][:2]):
                     abbreviated.append(six.u(' '))
-            i += len(t)
+            return abbreviated
+        return abbreviated
 
+    for t, c, length, data in gazetteer.filter(norm_tokens):
+        if c == token_types.PHRASE:
+            abbrev_tokens = abbreviated_tokens(i, tokens, t, c, length, data)
+            abbreviated.extend(abbrev_tokens)
+            i += len(t)
         else:
-            abbreviated.append(tokens[i][0])
+            token = tokens[i][0]
+            if not non_breaking_dash_regex.search(token):
+                abbreviated.append(token)
+            else:
+                sub_tokens = tokenize(non_breaking_dash_regex.sub(six.u(' '), token))
+                sub_tokens_norm = [(t.lower() if c in token_types.WORD_TOKEN_TYPES else t, c) for t, c in sub_tokens]
+
+                sub_token_abbreviated = []
+                sub_i = 0
+                sub_n = len(sub_tokens)
+                for t, c, length, data in gazetteer.filter(sub_tokens_norm):
+                    if c == token_types.PHRASE:
+                        abbrev_tokens = abbreviated_tokens(sub_i, sub_tokens, t, c, length, data, space_token=six.u('-'))
+                        sub_token_abbreviated.extend(abbrev_tokens)
+                        sub_i += len(t)
+                        if sub_i < sub_n:
+                            if abbrev_tokens and random.random() < add_period_hyphen_prob and not abbrev_tokens[-1].endswith(six.u('.')) and not abbrev_tokens[-1].lower().endswith(sub_tokens_norm[sub_i - 1][0]):
+                                sub_token_abbreviated.append(six.u('.'))
+                            sub_token_abbreviated.append(six.u('-'))
+                    else:
+                        sub_token_abbreviated.append(sub_tokens[sub_i][0])
+                        sub_i += 1
+                        if sub_i < sub_n:
+                            sub_token_abbreviated.append(six.u('-'))
+
+                abbreviated.append(six.u('').join(sub_token_abbreviated))
+
         if i < n - 1 and raw_tokens[i + 1][0] > sum(raw_tokens[i][:2]):
             abbreviated.append(six.u(' '))
         i += 1
diff --git a/scripts/geodata/text/utils.py b/scripts/geodata/text/utils.py
index 8f73b4ad..7459c33b 100644
--- a/scripts/geodata/text/utils.py
+++ b/scripts/geodata/text/utils.py
@@ -1,6 +1,10 @@
+import re
+
 from geodata.text.tokenize import tokenize
 from geodata.text.token_types import token_types
 
+non_breaking_dash_regex = re.compile(u'[\-\u058a\u05be\u1400\u1806\u2010-\u2013\u2212\u2e17\u2e1a\ufe32\ufe63\uff0d]', re.UNICODE)
+
 
 def is_numeric(s):
     tokens = tokenize(s)
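Not part of the patch: a minimal standalone sketch of the two regexes introduced above, assuming they are used the way this diff uses them (the sample strings are invented for illustration). expansion_token_regex splits an expansion into (token, trailing separator) pairs so recase_abbreviation() can re-attach the original '.', '-' and ' ' separators after recasing each token, and non_breaking_dash_regex lets abbreviate() rewrite Unicode dashes to spaces before re-tokenizing a hyphenated token and abbreviating its components separately.

# -*- coding: utf-8 -*-
# Illustrative sketch only; not part of the diff above.
import re

expansion_token_regex = re.compile('([^ \-\.]+)([\.\- ]+|$)')
non_breaking_dash_regex = re.compile(u'[\-\u058a\u05be\u1400\u1806\u2010-\u2013\u2212\u2e17\u2e1a\ufe32\ufe63\uff0d]', re.UNICODE)

# findall() yields (token, trailing separator) pairs, which is what allows
# recase_abbreviation() to rebuild the string with its original separators.
print(expansion_token_regex.findall(u'saint-nicolas ave.'))
# [('saint', '-'), ('nicolas', ' '), ('ave', '.')]  (u'' strings under Python 2)

# Unicode dashes (here an en dash, U+2013) and the ASCII hyphen become plain
# spaces, so the token can be re-tokenized and each piece abbreviated on its
# own; abbreviate() then re-joins the pieces with '-' as the space_token.
print(non_breaking_dash_regex.sub(u' ', u'Bad Homburg\u2013Ober-Eschbach'))
# Bad Homburg Ober Eschbach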