diff --git a/scripts/geodata/language_id/create_language_training_data.py b/scripts/geodata/language_id/create_language_training_data.py index 7975add4..70d69079 100644 --- a/scripts/geodata/language_id/create_language_training_data.py +++ b/scripts/geodata/language_id/create_language_training_data.py @@ -39,14 +39,19 @@ def create_language_training_data(osm_dir, split_data=True, train_split=0.8, cv_ if split_data: languages_test_path = os.path.join(osm_dir, LANGUAGES_TEST_FILE) - subprocess.check_call(['split -l $[ $(wc -l', languages_random_path, '| cut -d" " -f1) *', str(int(train_split * 100)), '/ 100 + 1 ]', languages_random_path]) + num_lines = sum((1 for line in open(languages_random_path))) + train_lines = int(train_split * num_lines) + + test_lines = num_lines - train_lines + cv_lines = test_lines * (cv_split / (1.0 - train_split)) + 1 + + subprocess.check_call(['split -l', str(train_lines), languages_random_path]) subprocess.check_call(['mv xaa', languages_train_path]) subprocess.check_call(['mv xab', languages_test_path]) - cv_split = cv_split / (1 - (cv_split + train_split)) languages_cv_path = os.path.join(osm_dir, LANGUAGES_CV_FILE) - subprocess.check_call(['split -l $[ $(wc -l', languages_test_path, '| cut -d" " -f1) *', str(int(cv_split * 100)), '/ 100 + 1 ]', languages_test_path]) + subprocess.check_call(['split', '-l', str(cv_lines), languages_test_path]) subprocess.check_call(['mv', 'xaa', languages_cv_path]) subprocess.check_call(['mv', 'xab', languages_test_path]) else: