From 8aada7086f54ae8d8c52bbb7a801bb4c1bf5ce69 Mon Sep 17 00:00:00 2001 From: Al Date: Mon, 30 May 2016 21:50:45 -0400 Subject: [PATCH] [intersections] intersections training data --- scripts/geodata/osm/extract.py | 4 + scripts/geodata/osm/formatter.py | 106 +++++++++++++++++- .../geodata/osm/osm_address_training_data.py | 8 ++ 3 files changed, 116 insertions(+), 2 deletions(-) diff --git a/scripts/geodata/osm/extract.py b/scripts/geodata/osm/extract.py index 5cda49bc..45f6884e 100644 --- a/scripts/geodata/osm/extract.py +++ b/scripts/geodata/osm/extract.py @@ -44,6 +44,10 @@ OSM_NAME_TAGS = ( 'short_name', ) +OSM_BASE_NAME_TAGS = ( + 'tiger:name_base', +) + def parse_osm(filename, allowed_types=ALL_OSM_TAGS, dependencies=False): ''' diff --git a/scripts/geodata/osm/formatter.py b/scripts/geodata/osm/formatter.py index 0749ba57..c65d9897 100644 --- a/scripts/geodata/osm/formatter.py +++ b/scripts/geodata/osm/formatter.py @@ -6,6 +6,7 @@ import sys import yaml from collections import OrderedDict +from six.itertools import combinations, ifilter this_dir = os.path.realpath(os.path.dirname(__file__)) sys.path.append(os.path.realpath(os.path.join(os.pardir, os.pardir))) @@ -24,6 +25,7 @@ from geodata.configs.utils import nested_get from geodata.countries.country_names import * from geodata.language_id.disambiguation import * from geodata.i18n.languages import * +from geodata.intersections.query import Intersection, IntersectionQuery from geodata.address_formatting.formatter import AddressFormatter from geodata.osm.extract import * from geodata.polygons.language_polys import * @@ -39,6 +41,8 @@ OSM_PARSER_DATA_DEFAULT_CONFIG = os.path.join(this_dir, os.pardir, os.pardir, os ADDRESS_FORMAT_DATA_TAGGED_FILENAME = 'formatted_addresses_tagged.tsv' ADDRESS_FORMAT_DATA_FILENAME = 'formatted_addresses.tsv' ADDRESS_FORMAT_DATA_LANGUAGE_FILENAME = 'formatted_addresses_by_language.tsv' +INTERSECTIONS_FILENAME = 'intersections.tsv' +INTERSECTIONS_TAGGED_FILENAME = 'intersections_tagged.tsv' ALL_LANGUAGES = 'all' @@ -113,7 +117,7 @@ class OSMAddressFormatter(object): } } - def __init__(self, components, subdivisions_rtree, buildings_rtree): + def __init__(self, components, subdivisions_rtree=None, buildings_rtree=None): # Instance of AddressComponents, contains structures for reverse geocoding, etc. self.components = components self.language_rtree = components.language_rtree @@ -543,7 +547,7 @@ class OSMAddressFormatter(object): if tag_components: row = (language, country, formatted_address) else: - row = formatted_address + row = (formatted_address,) writer.writerow(row) @@ -551,6 +555,104 @@ class OSMAddressFormatter(object): if i % 1000 == 0 and i > 0: print('did {} formatted addresses'.format(i)) + def build_intersections_training_data(self, infile, out_dir, tag_components=True): + ''' + Intersection addresses like "4th & Main Street" are represented in OSM + by ways that share at least one node. + + This creates formatted strings using the name of each way (sometimes the base name + for US addresses thanks to Tiger tags). + + Example: + + en us 34th/road Street/road &/intersection 8th/road Ave/road + ''' + i = 0 + + if tag_components: + formatted_tagged_file = open(os.path.join(out_dir, INTERSECTIONS_TAGGED_FILENAME), 'w') + writer = csv.writer(formatted_tagged_file, 'tsv_no_quote') + else: + formatted_file = open(os.path.join(out_dir, INTERSECTIONS_FILENAME), 'w') + writer = csv.writer(formatted_file, 'tsv_no_quote') + + all_name_tags = set(OSM_NAME_TAGS) + base_name_tags = set(OSM_BASE_NAME_TAGS) + + replace_with_base_name_prob = float(nested_get(self.config, ('intersections', 'replace_with_base_name_probability'), default=0.0)) + + reader = OSMIntersectionsReader(infile) + for node_id, latitude, longitude, ways in reader.intersections(): + if not ways or len(ways) < 2: + continue + + tags = ways[0] + namespaced_language = None + + language_components = {} + + base_name_tags = [t for t in all_base_name_tags if t in tags] + if not base_names: + base_name_tag = None + else: + base_name_tag = base_name_tags[0] + + for tag in tags: + if tag.rsplit(':', 1)[0] in all_name_tags and all((tag in w for w in ways)): + way_names = [(w[tag], w.get(base_name_tag) if base_name_tag else None) for w in ways] + if ':' in tag: + namespaced_language = tag.rsplit(':')[-1] + + if namespaced_language not in language_components: + address_components, country, language = self.components.expanded({}, latitude, longitude, language=namespaced_language) + language_components[namespaced_language] = (address_components, country, language) + else: + address_components, country, language = language_components[namespaced_language] + + intersection_phrase = Intersection.phrase(language, country=country) + if not intersection_phrase: + continue + + formatted_intersections = [] + + for (w1, w1_base), (w2, w2_base) in combinations(way_names, 2): + intersection = IntersectionQuery(road1=w1, intersection_phrase=intersection_phrase, road2=w2) + formatted = self.formatter.format_intersection(intersection, address_components, country, language, tag_components=tag_components) + formatted_intersections.append(formatted) + + if w1_base and random.random() < replace_with_base_name_prob: + w1 = w1_base + + intersection = IntersectionQuery(road1=w1, intersection_phrase=intersection_phrase, road2=w2) + formatted = self.formatter.format_intersection(intersection, address_components, country, language, tag_components=tag_components) + formatted_intersections.append(formatted) + + if w2_base and random.random() < replace_with_base_name_prob: + w2 = w2_base + + intersection = IntersectionQuery(road1=w1, intersection_phrase=intersection_phrase, road2=w2) + formatted = self.formatter.format_intersection(intersection, address_components, country, language, tag_components=tag_components) + formatted_intersections.append(formatted) + + for formatted in formatted_intersections: + if not formatted or not formatted.strip(): + continue + + formatted = tsv_string(formatted) + if not formatted or not formatted.strip(): + continue + + if tag_components: + row = (language, country, formatted) + else: + row = (formatted,) + + writer.writerow(row) + + i += 1 + if i % 1000 == 0 and i > 0: + print('did {} intersections'.format(i)) + def build_limited_training_data(self, infile, out_dir): ''' Creates a special kind of formatted address training data from OSM's addr:* tags diff --git a/scripts/geodata/osm/osm_address_training_data.py b/scripts/geodata/osm/osm_address_training_data.py index 2469eae0..f0909587 100644 --- a/scripts/geodata/osm/osm_address_training_data.py +++ b/scripts/geodata/osm/osm_address_training_data.py @@ -444,6 +444,9 @@ if __name__ == '__main__': default=tempfile.gettempdir(), help='Temp directory to use') + parser.add_argument('-x', '--intersections-file', + help='Path to planet-ways-latlons.osm') + parser.add_argument('--language-rtree-dir', required=True, help='Language RTree directory') @@ -530,3 +533,8 @@ if __name__ == '__main__': osm_formatter.build_limited_training_data(args.address_file, args.out_dir) if args.venues_file: build_venue_training_data(language_rtree, args.venues_file, args.out_dir) + + if args.intersections_file and args.format: + components = AddressComponents(osm_rtree, language_rtree, neighborhoods_rtree, quattroshapes_rtree, geonames) + osm_formatter = OSMAddressFormatter(components, subdivisions_rtree, buildings_rtree) + osm_formatter.build_intersections_training_data(args.address_file, args.out_dir, tag_components=not args.untagged)