[intersections] intersections training data

This commit is contained in:
Al
2016-05-30 21:50:45 -04:00
parent 5075128ada
commit 8aada7086f
3 changed files with 116 additions and 2 deletions

View File

@@ -44,6 +44,10 @@ OSM_NAME_TAGS = (
'short_name',
)
# OSM tag keys that carry a way's "base" street name. The only entry is a
# TIGER-namespaced key; presumably it is only present on US ways that came
# from the TIGER import -- confirm against the data.
OSM_BASE_NAME_TAGS = (
'tiger:name_base',
)
def parse_osm(filename, allowed_types=ALL_OSM_TAGS, dependencies=False):
'''

View File

@@ -6,6 +6,7 @@ import sys
import yaml
from collections import OrderedDict
from six.itertools import combinations, ifilter
this_dir = os.path.realpath(os.path.dirname(__file__))
sys.path.append(os.path.realpath(os.path.join(os.pardir, os.pardir)))
@@ -24,6 +25,7 @@ from geodata.configs.utils import nested_get
from geodata.countries.country_names import *
from geodata.language_id.disambiguation import *
from geodata.i18n.languages import *
from geodata.intersections.query import Intersection, IntersectionQuery
from geodata.address_formatting.formatter import AddressFormatter
from geodata.osm.extract import *
from geodata.polygons.language_polys import *
@@ -39,6 +41,8 @@ OSM_PARSER_DATA_DEFAULT_CONFIG = os.path.join(this_dir, os.pardir, os.pardir, os
ADDRESS_FORMAT_DATA_TAGGED_FILENAME = 'formatted_addresses_tagged.tsv'
ADDRESS_FORMAT_DATA_FILENAME = 'formatted_addresses.tsv'
ADDRESS_FORMAT_DATA_LANGUAGE_FILENAME = 'formatted_addresses_by_language.tsv'
# Output file names for the intersections training data, created inside the
# caller-supplied output directory. The tagged name is used when
# tag_components is true (rows are language, country, formatted string);
# the plain name is used otherwise (rows hold only the formatted string).
INTERSECTIONS_FILENAME = 'intersections.tsv'
INTERSECTIONS_TAGGED_FILENAME = 'intersections_tagged.tsv'
ALL_LANGUAGES = 'all'
@@ -113,7 +117,7 @@ class OSMAddressFormatter(object):
}
}
def __init__(self, components, subdivisions_rtree, buildings_rtree):
def __init__(self, components, subdivisions_rtree=None, buildings_rtree=None):
# Instance of AddressComponents, contains structures for reverse geocoding, etc.
self.components = components
self.language_rtree = components.language_rtree
@@ -543,7 +547,7 @@ class OSMAddressFormatter(object):
if tag_components:
row = (language, country, formatted_address)
else:
row = formatted_address
row = (formatted_address,)
writer.writerow(row)
@@ -551,6 +555,104 @@ class OSMAddressFormatter(object):
if i % 1000 == 0 and i > 0:
print('did {} formatted addresses'.format(i))
def build_intersections_training_data(self, infile, out_dir, tag_components=True):
    '''
    Intersection addresses like "4th & Main Street" are represented in OSM
    by ways that share at least one node.

    This creates formatted strings using the name of each way (sometimes the base name
    for US addresses thanks to Tiger tags).

    Example:

    en  us  34th/road Street/road &/intersection 8th/road Ave/road

    :param infile: path passed to OSMIntersectionsReader
    :param out_dir: directory in which the output TSV file is created
    :param tag_components: if true, write INTERSECTIONS_TAGGED_FILENAME with
                           (language, country, formatted) rows; otherwise write
                           INTERSECTIONS_FILENAME with single-column rows
    '''
    if tag_components:
        out_filename = INTERSECTIONS_TAGGED_FILENAME
    else:
        out_filename = INTERSECTIONS_FILENAME

    all_name_tags = set(OSM_NAME_TAGS)

    replace_with_base_name_prob = float(nested_get(self.config, ('intersections', 'replace_with_base_name_probability'), default=0.0))

    reader = OSMIntersectionsReader(infile)

    i = 0

    # The context manager guarantees the output file is flushed and closed
    # even if the reader or the formatter raises part-way through.
    with open(os.path.join(out_dir, out_filename), 'w') as out_file:
        writer = csv.writer(out_file, 'tsv_no_quote')

        for node_id, latitude, longitude, ways in reader.intersections():
            # An intersection needs at least two ways meeting at the node.
            if not ways or len(ways) < 2:
                continue

            # Only tags on the first way are candidates; a tag is used only
            # if every way at this node has it (checked below).
            tags = ways[0]

            # Per-node cache of components.expanded() results keyed by
            # language (None for un-namespaced name tags).
            language_components = {}

            # First configured base-name tag present on the first way, if any.
            # Iterating the tuple (not a set) keeps the choice deterministic.
            base_name_tag = next((t for t in OSM_BASE_NAME_TAGS if t in tags), None)

            for tag in tags:
                if tag.rsplit(':', 1)[0] not in all_name_tags:
                    continue
                if not all(tag in w for w in ways):
                    continue

                way_names = [(w[tag], w.get(base_name_tag) if base_name_tag else None) for w in ways]

                # Reset for every tag so a plain "name" tag visited after
                # "name:xx" is not attributed to language xx.
                namespaced_language = tag.rsplit(':', 1)[-1] if ':' in tag else None

                if namespaced_language not in language_components:
                    address_components, country, language = self.components.expanded({}, latitude, longitude, language=namespaced_language)
                    language_components[namespaced_language] = (address_components, country, language)
                else:
                    address_components, country, language = language_components[namespaced_language]

                intersection_phrase = Intersection.phrase(language, country=country)
                if not intersection_phrase:
                    continue

                def format_pair(road1, road2):
                    # Called immediately inside this iteration, so the closed-over
                    # loop variables always hold the current values.
                    intersection = IntersectionQuery(road1=road1, intersection_phrase=intersection_phrase, road2=road2)
                    return self.formatter.format_intersection(intersection, address_components, country, language, tag_components=tag_components)

                formatted_intersections = []

                for (w1, w1_base), (w2, w2_base) in combinations(way_names, 2):
                    formatted_intersections.append(format_pair(w1, w2))

                    # Randomly emit extra variants with the base name substituted.
                    # A substituted first road carries over into the second variant.
                    if w1_base and random.random() < replace_with_base_name_prob:
                        w1 = w1_base
                        formatted_intersections.append(format_pair(w1, w2))

                    if w2_base and random.random() < replace_with_base_name_prob:
                        w2 = w2_base
                        formatted_intersections.append(format_pair(w1, w2))

                for formatted in formatted_intersections:
                    if not formatted or not formatted.strip():
                        continue

                    # tsv_string may reduce the value to nothing, so check again.
                    formatted = tsv_string(formatted)
                    if not formatted or not formatted.strip():
                        continue

                    if tag_components:
                        row = (language, country, formatted)
                    else:
                        row = (formatted,)

                    writer.writerow(row)

                    i += 1
                    if i % 1000 == 0:
                        print('did {} intersections'.format(i))
def build_limited_training_data(self, infile, out_dir):
'''
Creates a special kind of formatted address training data from OSM's addr:* tags

View File

@@ -444,6 +444,9 @@ if __name__ == '__main__':
default=tempfile.gettempdir(),
help='Temp directory to use')
# Optional input for building intersection training data; it is only acted on
# when args.format is also set (see the end of this script).
parser.add_argument('-x', '--intersections-file',
help='Path to planet-ways-latlons.osm')
parser.add_argument('--language-rtree-dir',
required=True,
help='Language RTree directory')
@@ -530,3 +533,8 @@ if __name__ == '__main__':
osm_formatter.build_limited_training_data(args.address_file, args.out_dir)
if args.venues_file:
build_venue_training_data(language_rtree, args.venues_file, args.out_dir)
if args.intersections_file and args.format:
    # Build intersection training data from the ways file given with
    # -x/--intersections-file. The reader must receive that path, not
    # args.address_file, which belongs to a different input.
    components = AddressComponents(osm_rtree, language_rtree, neighborhoods_rtree, quattroshapes_rtree, geonames)
    osm_formatter = OSMAddressFormatter(components, subdivisions_rtree, buildings_rtree)
    osm_formatter.build_intersections_training_data(args.intersections_file, args.out_dir, tag_components=not args.untagged)