#!/usr/bin/env python3
"""Interactive curses editor for building a diceware-style word list.

Words are keyed by their three-letter lowercase prefix and scored
against ConceptNet Numberbatch embedding vectors so the operator can
spot (and replace) words that are semantically too close to each other.
The list is addressed by four six-sided dice rolls (6**4 entries).
"""

import curses
import curses.ascii
import curses.textpad
import heapq
import numpy
import numpy.linalg
import os.path
import pickle
import random
import sys
import textwrap

# A 6**4-entry list is addressable with four six-sided dice rolls.
TARGET_LIST_LENGTH = 6**4
MAX_WORD_LENGTH = 25
# Minimum acceptable vector distance for a suggested candidate word.
SUGGESTION_DISTANCE_THRESHOLD = 1.2


def vectorize(word):
    """Normalize *word* to the underscore form used as a vector key."""
    return word.replace(' ', '_').replace('-', '_')


# ConceptNet Numberbatch embedding vectors, loaded in main().
vectors = {}
# Two-level cache of pairwise distances, keyed (min_key, max_key).
distance_cache = {}


def memoized_distance(w1, w2):
    """Euclidean distance between the vectors of two words.

    Returns None when either word has no known vector.  Results are
    cached under the lexicographically ordered key pair so that
    (a, b) and (b, a) share one cache entry.
    """
    w1 = vectorize(w1)
    w2 = vectorize(w2)
    if w1 not in vectors or w2 not in vectors:
        return None
    v1 = min(w1, w2)
    v2 = max(w1, w2)
    if v1 not in distance_cache:
        distance_cache[v1] = {}
    if v2 not in distance_cache[v1]:
        distance_cache[v1][v2] = numpy.linalg.norm(vectors[v1] - vectors[v2])
    return distance_cache[v1][v2]


def prefix(word):
    """First three lowercase ASCII letters of *word* (its dict key)."""
    return ''.join(c for c in word.lower() if curses.ascii.islower(c))[:3]


def dice_label(i):
    """Four-dice label (digits 1-6) for list position *i*."""
    return "{}{}{}{}".format(
        i // 6**3 + 1, (i // 6**2) % 6 + 1, (i // 6) % 6 + 1, i % 6 + 1)


def main(stdscr):
    """Run the interactive editor; *stdscr* is supplied by curses.wrapper."""
    global vectors
    stdscr.clear()
    stdscr.leaveok(False)
    filename = sys.argv[1]
    list_pad = curses.newpad(TARGET_LIST_LENGTH*2, MAX_WORD_LENGTH)
    status_line = curses.newwin(1, curses.COLS, curses.LINES - 1, 0)
    info_box = curses.newwin(curses.LINES - 1, MAX_WORD_LENGTH*2, 0, MAX_WORD_LENGTH)

    loading_message = "Loading conceptnet numberbatch vectors..."
    status_line.addstr(0, 0, loading_message)
    # NOTE(security): pickle is only acceptable here because
    # numberbatch.pkl is a local, trusted file; never load an
    # untrusted pickle.
    with open('numberbatch.pkl', 'rb') as f:
        vectors = pickle.load(f)
    status_line.clear()
    status_line.addstr(0, 0, "Vectors loaded!")
    status_line.refresh()

    # Load an existing list: each line is "<dice> <word>"; words whose
    # prefix is shorter than 3 letters cannot be keyed and are skipped.
    words = {}
    if os.path.isfile(filename):
        with open(filename) as f:
            for line in f:
                word = line.strip().split(maxsplit=1)[-1]
                if len(prefix(word)) < 3:
                    continue
                words[prefix(word)] = word

    # Warm the distance cache up front so the UI loop stays responsive.
    pairs = {(w1, w2) for w1 in words.values() for w2 in words.values() if w2 > w1}
    loading_message = "Precomputing word vector distances: "
    status_line.clear()
    status_line.addstr(0, 0, loading_message)
    for (i, (w1, w2)) in enumerate(pairs):
        if i % 1000 == 0:
            status_line.addstr(0, len(loading_message), "{}/{} pairs".format(i, len(pairs)))
            status_line.refresh()
        memoized_distance(w1, w2)

    suggestion_candidates = list(vectors.keys())
    random.shuffle(suggestion_candidates)
    suggestion = ""
    pos = 0
    unsaved = False
    while True:
        # --- render status line, word list, and info box -------------
        status_line.clear()
        status_line.addstr(0, 0, "[a]dd/[d]elete/[j]down/[k]up/[/]search/[s]uggest/[w]rite/[q]uit{}".format('*' if unsaved else ''))
        status_line.noutrefresh()
        list_pad.clear()
        sorted_words = sorted(words.values())
        for (i, w) in enumerate(sorted_words):
            list_pad.addstr(i, 0, dice_label(i) + " " + w)
        # Keep the cursor roughly centered while scrolling the pad.
        scroll_pos = max(0, min(len(words)-(curses.LINES-1), pos - curses.LINES//2))
        list_pad.noutrefresh(scroll_pos, 0, 0, 0, curses.LINES-2, MAX_WORD_LENGTH-1)

        unknown_cn_words = {w for w in words.values() if vectorize(w) not in vectors}

        # BUG FIX: the original only appended a pair when its distance
        # was smaller than the most recently appended one, which could
        # skip genuinely-close pairs seen after a rejection.  Collect
        # every known distance and take the true 8 smallest.
        all_pair_distances = []
        for w1 in sorted_words:
            for w2 in sorted_words:
                if w1 >= w2:
                    continue
                d = memoized_distance(w1, w2)
                if d is not None:
                    all_pair_distances.append((w1, w2, d))
        worst_distances = heapq.nsmallest(8, all_pair_distances, key=lambda x: x[2])

        current_word = None if len(words) == 0 else sorted_words[pos]
        # Same fix for the distances measured from the selected word.
        current_distances = []
        for w in words.values():
            if w == current_word:
                continue
            d = memoized_distance(current_word, w)
            if d is not None:
                current_distances.append((w, d))
        worst_current_distances = heapq.nsmallest(8, current_distances, key=lambda x: x[1])

        info_box.clear()
        info_box.addstr(0, 0, """{count}/{target} words;
Worst overall distances:
{worst}
Worst distances from current word:
{worstc}
Unknown (to ConceptNet) words:
 {unk_c}
Suggestion: {sug}""".format(
            count=len(words),
            target=TARGET_LIST_LENGTH,
            worst='\n'.join(' {} to {}, {:.2}'.format(*x) for x in worst_distances),
            worstc='\n'.join(' {}, {:.2}'.format(*x) for x in worst_current_distances),
            unk_c='\n '.join(unknown_cn_words if len(unknown_cn_words) <= 3
                             else list(unknown_cn_words)[:2] + ["..."]),
            sug=suggestion))
        info_box.noutrefresh()
        curses.doupdate()
        stdscr.move(pos - scroll_pos, 0)

        # --- handle one keystroke ------------------------------------
        ch = stdscr.getch()
        if ch == ord('-') or ch == ord('d'):
            if current_word:
                del words[prefix(current_word)]
                pos = min(max(0, len(words)-1), pos)
                unsaved = True
        elif ch == ord('+') or ch == ord('a'):
            status_line.clear()
            status_line.refresh()
            input_box = curses.textpad.Textbox(status_line)
            input_box.edit()
            word = input_box.gather().strip()
            if len(prefix(word)) >= 3:
                old = words.get(prefix(word), None)
                if old:
                    # A word with the same prefix exists; confirm overwrite.
                    status_line.clear()
                    status_line.addstr(0, 0, "Replace {}? [y/n]".format(old))
                    status_line.refresh()
                if not old or stdscr.getch() == ord('y'):
                    words[prefix(word)] = word
                    pos = sorted(words.values()).index(word)
                    unsaved = True
        elif ch == curses.KEY_DOWN or ch == ord('j'):
            pos = min(max(0, len(words)-1), pos+1)
        elif ch == curses.KEY_UP or ch == ord('k'):
            pos = max(0, pos-1)
        elif ch == ord('/'):
            status_line.clear()
            status_line.refresh()
            input_box = curses.textpad.Textbox(status_line)
            input_box.edit()
            word = input_box.gather()
            word = ''.join(c for c in word.lower() if curses.ascii.islower(c))
            if len(prefix(word)) >= 3 and prefix(word) in words:
                pos = sorted(words.values()).index(words[prefix(word)])
        elif ch == ord('s'):
            # Pop shuffled candidates until one is far enough from every
            # existing word.  BUG FIX: guard against running out of
            # candidates instead of raising IndexError on pop().
            while suggestion_candidates:
                candidate = suggestion_candidates.pop()
                if len(prefix(candidate)) < 3 or prefix(candidate) in words:
                    continue
                min_dist = None
                for word in words.values():
                    d = memoized_distance(word, candidate)
                    if d is not None and (min_dist is None or d < min_dist):
                        min_dist = d
                if min_dist is None or min_dist > SUGGESTION_DISTANCE_THRESHOLD:
                    suggestion = candidate
                    break
        elif ch == ord('w'):
            with open(filename, "w") as f:
                for (i, w) in enumerate(sorted(words.values())):
                    f.write(dice_label(i) + " " + w + "\n")
            unsaved = False
        elif ch == ord('q'):
            if unsaved:
                status_line.clear()
                status_line.addstr(0, 0, "Quit with unsaved data? [y/n]")
                status_line.refresh()
            if not unsaved or stdscr.getch() == ord('y'):
                return


if __name__ == "__main__":
    # BUG FIX: usage message was missing the argument placeholder, and
    # the script previously ran on import; guard the entry point.
    if len(sys.argv) != 2:
        print("usage: {} <wordlist>".format(sys.argv[0]), file=sys.stderr)
        exit(1)
    curses.wrapper(main)