[intersections] intersections training data
This commit is contained in:
@@ -44,6 +44,10 @@ OSM_NAME_TAGS = (
|
||||
'short_name',
|
||||
)
|
||||
|
||||
# OSM tag keys holding a way's "base" name (here only the TIGER-import tag).
OSM_BASE_NAME_TAGS = ('tiger:name_base',)
|
||||
|
||||
|
||||
def parse_osm(filename, allowed_types=ALL_OSM_TAGS, dependencies=False):
|
||||
'''
|
||||
|
||||
@@ -6,6 +6,7 @@ import sys
|
||||
import yaml
|
||||
|
||||
from collections import OrderedDict
|
||||
from six.itertools import combinations, ifilter
|
||||
|
||||
this_dir = os.path.realpath(os.path.dirname(__file__))
|
||||
sys.path.append(os.path.realpath(os.path.join(os.pardir, os.pardir)))
|
||||
@@ -24,6 +25,7 @@ from geodata.configs.utils import nested_get
|
||||
from geodata.countries.country_names import *
|
||||
from geodata.language_id.disambiguation import *
|
||||
from geodata.i18n.languages import *
|
||||
from geodata.intersections.query import Intersection, IntersectionQuery
|
||||
from geodata.address_formatting.formatter import AddressFormatter
|
||||
from geodata.osm.extract import *
|
||||
from geodata.polygons.language_polys import *
|
||||
@@ -39,6 +41,8 @@ OSM_PARSER_DATA_DEFAULT_CONFIG = os.path.join(this_dir, os.pardir, os.pardir, os
|
||||
ADDRESS_FORMAT_DATA_TAGGED_FILENAME = 'formatted_addresses_tagged.tsv'
|
||||
ADDRESS_FORMAT_DATA_FILENAME = 'formatted_addresses.tsv'
|
||||
ADDRESS_FORMAT_DATA_LANGUAGE_FILENAME = 'formatted_addresses_by_language.tsv'
|
||||
INTERSECTIONS_FILENAME = 'intersections.tsv'
|
||||
INTERSECTIONS_TAGGED_FILENAME = 'intersections_tagged.tsv'
|
||||
|
||||
ALL_LANGUAGES = 'all'
|
||||
|
||||
@@ -113,7 +117,7 @@ class OSMAddressFormatter(object):
|
||||
}
|
||||
}
|
||||
|
||||
def __init__(self, components, subdivisions_rtree, buildings_rtree):
|
||||
def __init__(self, components, subdivisions_rtree=None, buildings_rtree=None):
|
||||
# Instance of AddressComponents, contains structures for reverse geocoding, etc.
|
||||
self.components = components
|
||||
self.language_rtree = components.language_rtree
|
||||
@@ -543,7 +547,7 @@ class OSMAddressFormatter(object):
|
||||
if tag_components:
|
||||
row = (language, country, formatted_address)
|
||||
else:
|
||||
row = formatted_address
|
||||
row = (formatted_address,)
|
||||
|
||||
writer.writerow(row)
|
||||
|
||||
@@ -551,6 +555,104 @@ class OSMAddressFormatter(object):
|
||||
if i % 1000 == 0 and i > 0:
|
||||
print('did {} formatted addresses'.format(i))
|
||||
|
||||
def build_intersections_training_data(self, infile, out_dir, tag_components=True):
    '''
    Intersection addresses like "4th & Main Street" are represented in OSM
    by ways that share at least one node.

    This creates formatted strings using the name of each way (sometimes the base name
    for US addresses thanks to Tiger tags).

    Example:

    en  us  34th/road Street/road &/intersection 8th/road Ave/road

    :param infile: path handed to OSMIntersectionsReader
    :param out_dir: directory in which the output TSV is created
    :param tag_components: when true, rows are (language, country, formatted)
                           written to INTERSECTIONS_TAGGED_FILENAME; otherwise
                           rows are (formatted,) written to INTERSECTIONS_FILENAME
    '''
    i = 0

    filename = INTERSECTIONS_TAGGED_FILENAME if tag_components else INTERSECTIONS_FILENAME

    all_name_tags = set(OSM_NAME_TAGS)

    # The with-block guarantees the output is flushed and closed even if
    # reading or formatting raises part-way through.
    with open(os.path.join(out_dir, filename), 'w') as out_file:
        # NOTE(review): the 'tsv_no_quote' dialect is not registered in this
        # block; presumably it is registered elsewhere in the project.
        writer = csv.writer(out_file, 'tsv_no_quote')

        # NOTE(review): self.config and self.formatter are not set anywhere
        # visible from here; confirm they are initialised on the instance.
        replace_with_base_name_prob = float(nested_get(self.config, ('intersections', 'replace_with_base_name_probability'), default=0.0))

        reader = OSMIntersectionsReader(infile)

        for node_id, latitude, longitude, ways in reader.intersections():
            # An intersection needs at least two ways meeting at the node.
            if not ways or len(ways) < 2:
                continue

            # The first way's tags drive which name tags are considered.
            tags = ways[0]

            # Cache of expanded components for this node, keyed by language
            # (None for tags without a language suffix).
            language_components = {}

            # Walk the tuple (not a set) so the chosen tag is deterministic.
            base_name_tag = next((t for t in OSM_BASE_NAME_TAGS if t in tags), None)

            for tag in tags:
                # Only use a name tag ("name", "name:en", ...) that every
                # way at this node carries.
                if tag.rsplit(':', 1)[0] not in all_name_tags or not all(tag in w for w in ways):
                    continue

                way_names = [(w[tag], w.get(base_name_tag) if base_name_tag else None) for w in ways]

                # Derived per tag so that a plain "name" tag visited after
                # "name:xx" does not inherit the language suffix "xx".
                namespaced_language = tag.rsplit(':', 1)[-1] if ':' in tag else None

                if namespaced_language not in language_components:
                    address_components, country, language = self.components.expanded({}, latitude, longitude, language=namespaced_language)
                    language_components[namespaced_language] = (address_components, country, language)
                else:
                    address_components, country, language = language_components[namespaced_language]

                intersection_phrase = Intersection.phrase(language, country=country)
                if not intersection_phrase:
                    continue

                formatted_intersections = []

                for (w1, w1_base), (w2, w2_base) in combinations(way_names, 2):
                    intersection = IntersectionQuery(road1=w1, intersection_phrase=intersection_phrase, road2=w2)
                    formatted = self.formatter.format_intersection(intersection, address_components, country, language, tag_components=tag_components)
                    formatted_intersections.append(formatted)

                    # With the configured probability, also emit a variant
                    # with the first road replaced by its base name.
                    if w1_base and random.random() < replace_with_base_name_prob:
                        w1 = w1_base

                        intersection = IntersectionQuery(road1=w1, intersection_phrase=intersection_phrase, road2=w2)
                        formatted = self.formatter.format_intersection(intersection, address_components, country, language, tag_components=tag_components)
                        formatted_intersections.append(formatted)

                    # Same for the second road; w1 may already be a base name.
                    if w2_base and random.random() < replace_with_base_name_prob:
                        w2 = w2_base

                        intersection = IntersectionQuery(road1=w1, intersection_phrase=intersection_phrase, road2=w2)
                        formatted = self.formatter.format_intersection(intersection, address_components, country, language, tag_components=tag_components)
                        formatted_intersections.append(formatted)

                for formatted in formatted_intersections:
                    if not formatted or not formatted.strip():
                        continue

                    formatted = tsv_string(formatted)
                    if not formatted or not formatted.strip():
                        continue

                    if tag_components:
                        row = (language, country, formatted)
                    else:
                        row = (formatted,)

                    writer.writerow(row)

                # NOTE(review): the original nesting of this counter was lost;
                # here it advances once per usable name tag. Confirm intent.
                i += 1
                if i % 1000 == 0 and i > 0:
                    print('did {} intersections'.format(i))
|
||||
|
||||
def build_limited_training_data(self, infile, out_dir):
|
||||
'''
|
||||
Creates a special kind of formatted address training data from OSM's addr:* tags
|
||||
|
||||
@@ -444,6 +444,9 @@ if __name__ == '__main__':
|
||||
default=tempfile.gettempdir(),
|
||||
help='Temp directory to use')
|
||||
|
||||
parser.add_argument('-x', '--intersections-file',
|
||||
help='Path to planet-ways-latlons.osm')
|
||||
|
||||
parser.add_argument('--language-rtree-dir',
|
||||
required=True,
|
||||
help='Language RTree directory')
|
||||
@@ -530,3 +533,8 @@ if __name__ == '__main__':
|
||||
osm_formatter.build_limited_training_data(args.address_file, args.out_dir)
|
||||
if args.venues_file:
|
||||
build_venue_training_data(language_rtree, args.venues_file, args.out_dir)
|
||||
|
||||
# Build intersection training data when an intersections file was supplied
# together with the format option.
if args.intersections_file and args.format:
    components = AddressComponents(osm_rtree, language_rtree, neighborhoods_rtree, quattroshapes_rtree, geonames)
    osm_formatter = OSMAddressFormatter(components, subdivisions_rtree, buildings_rtree)
    # Read the intersections file guarded above; the addresses file may be
    # unset on this path and holds different data.
    osm_formatter.build_intersections_training_data(args.intersections_file, args.out_dir, tag_components=not args.untagged)
|
||||
|
||||
Reference in New Issue
Block a user