diff --git a/scripts/geodata/osm/intersections.py b/scripts/geodata/osm/intersections.py index 831539da..f5978451 100644 --- a/scripts/geodata/osm/intersections.py +++ b/scripts/geodata/osm/intersections.py @@ -1,26 +1,37 @@ import array import logging +import numpy import six +import ujson as json from bisect import bisect_left from collections import defaultdict, OrderedDict -from itertools import izip, combinations +from leveldb import LevelDB +from itertools import izip, groupby from geodata.coordinates.conversion import latlon_to_decimal +from geodata.file_utils import ensure_dir from geodata.osm.extract import * +from geodata.encoding import safe_decode, safe_encode class OSMIntersectionReader(object): - def __init__(self, filename): + def __init__(self, filename, db_dir): self.filename = filename self.node_ids = array.array('l') self.node_coordinates = array.array('d') - # Store these in memory, could be LevelDB if needed - self.way_props = {} - self.intersections_graph = defaultdict(list) + self.logger = logging.getLogger('osm.intersections') + # Store these in memory, could be LevelDB if needed + ensure_dir(db_dir) + ways_dir = os.path.join(db_dir, 'ways') + ensure_dir(ways_dir) + self.way_props = LevelDB(ways_dir) + # These form a graph and should always have the same length + self.intersection_edges_nodes = array.array('l') + self.intersection_edges_ways = array.array('l') def binary_search(self, a, x): '''Locate the leftmost value exactly equal to x''' @@ -62,7 +73,7 @@ class OSMIntersectionReader(object): node_counts[node_index] += 1 if i % 1000 == 0 and i > 0: - print('doing {}s, at {}'.format(element_id.split(':')[0], i)) + self.logger.info('doing {}s, at {}'.format(element_id.split(':')[0], i)) i += 1 for i, count in enumerate(node_counts): @@ -76,6 +87,7 @@ class OSMIntersectionReader(object): for element_id, props, deps in parse_osm(self.filename, dependencies=True): if element_id.startswith('node'): + node_id = long(element_id.split(':')[-1]) node_index = self.binary_search(self.node_ids, node_id) if node_index is not None: lat = props.get('lat') @@ -89,14 +101,45 @@ class OSMIntersectionReader(object): for node_id in deps: node_index = self.binary_search(self.node_ids, node_id) if node_index is not None: - self.intersections_graph[node_index].append(way_id) - self.way_props[way_id] = props + way_ids.append(way_id) + self.intersection_edges_nodes.append(node_id) + self.intersection_edges_ways.append(way_id) + self.way_props.Put(safe_encode(way_id), json.dumps(props)) if i % 1000 == 0 and i > 0: - print('second pass, doing {}s, at {}'.format(element_id.split(':')[0], i)) + self.logger.info('second pass, doing {}s, at {}'.format(element_id.split(':')[0], i)) i += 1 - for node_index, way_indices in six.iteritems(self.intersections_graph): - lat, lon = self.node_coordinates[node_index * 2], self.node_coordinates[node_index * 2 + 1] - ways = [self.way_props[w] for w in way_indices] - yield self.node_ids[node_index], lat, lon, ways + i = 0 + + indices = numpy.argsort(self.intersection_edges_nodes) + self.intersection_edges_nodes = numpy.fromiter((self.intersection_edges_nodes[i] for i in indices), dtype=numpy.uint64) + self.intersection_edges_ways = numpy.fromiter((self.intersection_edges_ways[i] for i in indices), dtype=numpy.uint64) + del indices + + idx = 0 + + # Need to make a copy here otherwise will change dictionary size during iteration + for node_id, g in groupby(self.intersection_edges_nodes): + group_len = sum((1 for j in g)) + + way_indices = self.intersection_edges_ways[idx:idx + group_len] + all_ways = [json.loads(reader.way_props.Get(safe_encode(w))) for w in way_indices] + way_names = set() + ways = [] + for way in all_ways: + if way['name'] in way_names: + continue + ways.append(way) + way_names.add(way['name']) + + idx += group_len + + if i % 1000 == 0 and i > 0: + self.logger.info('checking intersections, did {}'.format(i)) + i += 1 + + if len(ways) > 1: + node_index = self.binary_search(self.node_ids, node_id) + lat, lon = self.node_coordinates[node_index * 2], self.node_coordinates[node_index * 2 + 1] + yield self.node_ids[node_index], lat, lon, ways