diceware/diceware-editor.py

215 lines
7.8 KiB
Python
Executable File

#!/usr/bin/env python3
import curses
import curses.ascii
import curses.textpad
import numpy
import numpy.linalg
import os.path
import pickle
import random
import sys
import textwrap
# Size of a complete list: one word per roll of four six-sided dice (6**4 = 1296).
# main() also uses it to size the word-list pad and as the target shown in the info box.
TARGET_LIST_LENGTH = 6**4
# Column width of the word-list pad; the info box starts at this column and is twice as wide.
MAX_WORD_LENGTH = 25
# A suggestion is accepted only if its smallest known vector distance to the
# words already in the list is greater than this value (or no distance is known).
SUGGESTION_DISTANCE_THRESHOLD = 1.2
def vectorize(word):
    """Return `word` in the form used as a key into `vectors`.

    Every space and every hyphen is replaced by an underscore; all other
    characters are left unchanged.
    """
    underscore_for = {ord(' '): '_', ord('-'): '_'}
    return word.translate(underscore_for)
# Word vectors keyed by vectorized word; main() rebinds this to the unpickled data.
vectors = {}
# Nested cache: distance_cache[smaller_key][larger_key] -> distance between the two.
distance_cache = {}


def memoized_distance(w1, w2):
    """Return the Euclidean distance between the vectors of two words.

    Both words are passed through vectorize() first. Returns None when
    either word has no entry in `vectors`. Results are cached in
    `distance_cache` under the ordered pair of keys, so the argument order
    does not matter.
    """
    key_a = vectorize(w1)
    key_b = vectorize(w2)
    if not (key_a in vectors and key_b in vectors):
        return None
    # Order the keys so (a, b) and (b, a) share one cache slot.
    low, high = sorted((key_a, key_b))
    row = distance_cache.setdefault(low, {})
    if high not in row:
        row[high] = numpy.linalg.norm(vectors[low] - vectors[high])
    return row[high]
def prefix(word):
    """Return the first three ASCII letters of `word`, lower-cased.

    Characters that are not ASCII lowercase letters after lower-casing are
    skipped. The result is shorter than three characters when the word has
    fewer than three such letters.
    """
    letters = []
    for ch in word.lower():
        if not curses.ascii.islower(ch):
            continue
        letters.append(ch)
        if len(letters) == 3:
            break
    return ''.join(letters)
# How many of the closest pairs each section of the info box lists.
_CLOSEST_SHOWN = 8


def _dice_code(index):
    """Return the four-digit dice roll (each digit 1-6) for a zero-based list index."""
    return "{}{}{}{}".format(
        index // 6**3 + 1, (index // 6**2) % 6 + 1, (index // 6) % 6 + 1, index % 6 + 1)


def _insert_closest(best, entry, limit=_CLOSEST_SHOWN):
    """Keep `best` as the ascending list of the `limit` smallest-distance entries.

    The distance is the last element of each entry tuple.
    """
    if len(best) < limit or entry[-1] < best[-1][-1]:
        best.append(entry)
        best.sort(key=lambda e: e[-1])
        del best[limit:]


def _prompt(status_line):
    """Clear the status line, let the user type into it, and return the raw text."""
    status_line.clear()
    status_line.refresh()
    input_box = curses.textpad.Textbox(status_line)
    input_box.edit()
    return input_box.gather()


def main(stdscr):
    """Run the interactive word-list editor.

    Loads word vectors from 'numberbatch.pkl', reads the word list named by
    sys.argv[1] (if that file exists), then loops: draw the list, the
    closest word pairs and the current suggestion, and react to one key press.

    Args:
        stdscr: the curses screen window, as supplied by curses.wrapper().
    """
    global vectors
    stdscr.clear()
    stdscr.leaveok(False)
    filename = sys.argv[1]
    list_pad = curses.newpad(TARGET_LIST_LENGTH*2, MAX_WORD_LENGTH)
    status_line = curses.newwin(1, curses.COLS, curses.LINES - 1, 0)
    info_box = curses.newwin(curses.LINES - 1, MAX_WORD_LENGTH*2, 0, MAX_WORD_LENGTH)

    loading_message = "Loading conceptnet numberbatch vectors..."
    status_line.addstr(0, 0, loading_message)
    # Without this refresh the message never reaches the screen before the slow load.
    status_line.refresh()
    # NOTE: pickle.load can execute arbitrary code; only use a numberbatch.pkl
    # that you generated yourself or otherwise trust.
    with open('numberbatch.pkl', 'rb') as f:
        vectors = pickle.load(f)
    status_line.clear()
    status_line.addstr(0, 0, "Vectors loaded!")
    status_line.refresh()

    # words maps a three-letter prefix to the single word chosen for it.
    words = {}
    if os.path.isfile(filename):
        with open(filename) as f:
            for line in f:
                fields = line.strip().split(maxsplit=1)
                if not fields:
                    # Blank line: nothing to load (indexing would raise IndexError).
                    continue
                # The word is the last field, whether or not a dice code precedes it.
                word = fields[-1]
                if len(prefix(word)) < 3:
                    continue
                words[prefix(word)] = word

    # Warm the distance cache so the first screen draw is not slow.
    pairs = {(w1, w2) for w1 in words.values() for w2 in words.values() if w2 > w1}
    loading_message = "Precomputing word vector distances: "
    status_line.clear()
    status_line.addstr(0, 0, loading_message)
    for (i, (w1, w2)) in enumerate(pairs):
        if i % 1000 == 0:
            status_line.addstr(0, len(loading_message), "{}/{} pairs".format(i, len(pairs)))
            status_line.refresh()
        memoized_distance(w1, w2)

    suggestion_candidates = list(vectors.keys())
    random.shuffle(suggestion_candidates)
    suggestion = ""
    pos = 0
    unsaved = False

    while True:
        status_line.clear()
        status_line.addstr(0, 0, "[a]dd/[d]elete/[j]down/[k]up/[/]search/[s]uggest/[w]write/[q]uit{}".format('*' if unsaved else ''))
        status_line.noutrefresh()

        sorted_words = sorted(words.values())
        list_pad.clear()
        for (i, w) in enumerate(sorted_words):
            list_pad.addstr(i, 0, _dice_code(i) + " " + w)
        # Keep the cursor roughly centred without scrolling past either end.
        scroll_pos = max(0, min(len(words)-(curses.LINES-1), pos - curses.LINES//2))
        list_pad.noutrefresh(scroll_pos,0, 0,0, curses.LINES-2,MAX_WORD_LENGTH-1)

        unknown_cn_words = {w for w in words.values() if vectorize(w) not in vectors}

        # The closest pairs overall: these are the words most easily confused.
        worst_distances = []
        for w1 in words.values():
            for w2 in words.values():
                if w1 >= w2:
                    continue
                d = memoized_distance(w1, w2)
                if d is not None:
                    _insert_closest(worst_distances, (w1, w2, d))

        current_word = sorted_words[pos] if sorted_words else None
        worst_current_distances = []
        for w in words.values():
            if w == current_word:
                continue
            d = memoized_distance(current_word, w)
            if d is not None:
                _insert_closest(worst_current_distances, (w, d))

        report = (
            "{count}/{target} words;\n"
            "Worst overall distances:\n"
            "{worst}\n"
            "Worst distances from current word:\n"
            "{worstc}\n"
            "Unknown (to ConceptNet) words:\n"
            "{unk_c}\n"
            "Suggestion:\n"
            "{sug}"
        ).format(
            count=len(words),
            target=TARGET_LIST_LENGTH,
            worst='\n'.join(' {} to {}, {:.2}'.format(*x) for x in worst_distances),
            worstc='\n'.join(' {}, {:.2}'.format(*x) for x in worst_current_distances),
            unk_c='\n '.join(unknown_cn_words if len(unknown_cn_words) <= 3 else list(unknown_cn_words)[:2] + ["..."]),
            sug=suggestion
        )
        info_box.clear()
        try:
            info_box.addstr(0, 0, report)
        except curses.error:
            # The report did not fit in the window; show whatever was drawn
            # rather than crashing the editor.
            pass
        info_box.noutrefresh()
        curses.doupdate()
        stdscr.move(pos - scroll_pos, 0)

        ch = stdscr.getch()
        if ch == ord('-') or ch == ord('d'):
            if current_word:
                del words[prefix(current_word)]
                pos = min(max(0, len(words)-1), pos)
                unsaved = True
        elif ch == ord('+') or ch == ord('a'):
            word = _prompt(status_line).strip()
            if len(prefix(word)) >= 3:
                old = words.get(prefix(word), None)
                if old:
                    status_line.clear()
                    status_line.addstr(0,0, "Replace {}? [y/n]".format(old))
                    status_line.refresh()
                # Only ask for a key press when there is something to replace.
                if not old or stdscr.getch() == ord('y'):
                    words[prefix(word)] = word
                    pos = sorted(words.values()).index(word)
                    unsaved = True
        elif ch == curses.KEY_DOWN or ch == ord('j'):
            pos = min(max(0, len(words)-1), pos+1)
        elif ch == curses.KEY_UP or ch == ord('k'):
            pos = max(0, pos-1)
        elif ch == ord('/'):
            word = _prompt(status_line)
            word = ''.join(c for c in word.lower() if curses.ascii.islower(c))
            if len(prefix(word)) >= 3 and prefix(word) in words:
                pos = sorted(words.values()).index(words[prefix(word)])
        elif ch == ord('s'):
            while suggestion_candidates:
                candidate = suggestion_candidates.pop()
                if len(prefix(candidate)) < 3 or prefix(candidate) in words:
                    continue
                # Reject the candidate as soon as one existing word is too close.
                too_close = False
                for word in words.values():
                    d = memoized_distance(word, candidate)
                    if d is not None and not d > SUGGESTION_DISTANCE_THRESHOLD:
                        too_close = True
                        break
                if not too_close:
                    suggestion = candidate
                    break
            else:
                # Every candidate has been tried; popping again would raise IndexError.
                suggestion = "(no more suggestions)"
        elif ch == ord('w'):
            with open(filename, "w") as f:
                for (i, w) in enumerate(sorted(words.values())):
                    f.write(_dice_code(i) + " " + w + "\n")
            unsaved = False
        elif ch == ord('q'):
            if unsaved:
                status_line.clear()
                status_line.addstr(0,0, "Quit with unsaved data? [y/n]")
                status_line.refresh()
            if not unsaved or stdscr.getch() == ord('y'):
                return
# Script entry point: expects exactly one argument, the word-list file to edit.
if __name__ == "__main__":
    if len(sys.argv) != 2:
        print("usage: {} <filename>".format(sys.argv[0]), file=sys.stderr)
        # sys.exit rather than the site-provided exit(), which is meant for
        # interactive use and is absent when the site module is not loaded.
        sys.exit(1)
    # curses.wrapper sets up the terminal, passes the screen to main and
    # restores the terminal afterwards, even if main raises.
    curses.wrapper(main)