diff --git a/scripts/geodata/osm/osm_address_training_data.py b/scripts/geodata/osm/osm_address_training_data.py index 56ddc532..66b3c8d2 100644 --- a/scripts/geodata/osm/osm_address_training_data.py +++ b/scripts/geodata/osm/osm_address_training_data.py @@ -525,6 +525,13 @@ def osm_reverse_geocoded_components(admin_rtree, country, latitude, longitude): class OSMAddressFormatter(object): alpha3_codes = {c.alpha2: c.alpha3 for c in pycountry.countries} + rare_components = { + AddressFormatter.SUBURB, + AddressFormatter.CITY_DISTRICT, + AddressFormatter.STATE_DISTRICT, + AddressFormatter.COUNTRY + } + def __init__(self, admin_rtree, language_rtree, neighborhoods_rtree, quattroshapes_rtree, geonames, splitter=None): self.admin_rtree = admin_rtree self.language_rtree = language_rtree @@ -1126,7 +1133,7 @@ class OSMAddressFormatter(object): return address_components, country, language - def formatted_addresses(self, value, dropout_prob=0.5, tag_components=True): + def formatted_addresses(self, value, dropout_prob=0.5, rare_component_dropout_prob=0.7, tag_components=True): ''' Formatted addresses ------------------- @@ -1178,7 +1185,9 @@ class OSMAddressFormatter(object): component_set = component_bitset(address_components.keys()) for component in current_components: - if component_set ^ OSM_ADDRESS_COMPONENT_VALUES[component] in OSM_ADDRESS_COMPONENTS_VALID and random.random() < dropout_prob: + prob = rare_component_dropout_prob if component in self.rare_components else dropout_prob + + if component_set ^ OSM_ADDRESS_COMPONENT_VALUES[component] in OSM_ADDRESS_COMPONENTS_VALID and random.random() < prob: address_components.pop(component) component_set ^= OSM_ADDRESS_COMPONENT_VALUES[component] if not address_components: