#!/usr/bin/env python3

import argparse
import asyncio
import getpass
import json
import math
import pathlib
import sys
from tqdm import tqdm
from websockets import connect

def file_loader(files):
    """Yield the contents of each file in chunks while drawing progress bars.

    Two tqdm bars are shown: one per file (position 0) and one for the
    combined total of all files (position 1).

    Args:
        files: Sequence of ``(path, size)`` pairs. ``path`` must provide
            ``.name`` and ``.open("rb")`` (e.g. a ``pathlib.Path``); ``size``
            is the number of bytes to read from that file. The sequence is
            iterated twice, so it must not be a one-shot iterator.

    Yields:
        ``bytes`` chunks of at most 16384 bytes each; exactly ``size`` bytes
        are yielded per file in total.

    Raises:
        SystemExit: With status 1 if a file runs out of data before ``size``
            bytes have been read.
    """
    chunk_size = 16384
    with tqdm(desc="Total", total=sum(size for _, size in files), unit='B', unit_scale=True, leave=True, position=1) as total_progress:
        for path, size in files:
            with tqdm(desc=path.name, total=size, unit='B', unit_scale=True, leave=True, position=0) as file_progress:
                with path.open("rb") as f:
                    remaining = size
                    while remaining > 0:
                        data = f.read(min(chunk_size, remaining))
                        # A binary read returns b"" at EOF; comparing against
                        # the str "" never matched and left this loop spinning
                        # forever on a file that shrank after it was sized.
                        if not data:
                            tqdm.write("file ended early!")
                            sys.exit(1)
                        remaining -= len(data)
                        total_progress.update(len(data))
                        file_progress.update(len(data))
                        yield data

async def send(paths, uri, password, lifetime, collection_name=None):
    """Upload the regular files among *paths* over a websocket connection.

    A JSON manifest is sent first, then the function waits for a reply whose
    ``"type"`` is ``"ready"`` and prints the ``"code"`` it carries. After that
    every chunk produced by ``file_loader`` is sent, and the literal reply
    ``"ack"`` is awaited before the next chunk goes out.

    Args:
        paths: Iterable of path objects; entries for which ``is_file()`` is
            false are silently dropped.
        uri: Websocket URI to connect to.
        password: Value placed in the manifest's ``"password"`` field.
        lifetime: Value placed in the manifest's ``"lifetime"`` field (the
            command-line help describes it as a number of days).
        collection_name: Optional value for the manifest's
            ``"collection_name"`` field; omitted from the manifest when None.

    Raises:
        SystemExit: With status 1 if the server's reply to the manifest is
            not a ``ready`` message, or if any chunk is not answered with
            ``"ack"``.
    """
    paths = [path for path in paths if path.is_file()]
    file_metadata = []
    for path in paths:
        # One stat() per file so size and modtime describe the same snapshot.
        info = path.stat()
        file_metadata.append({
            "name": path.name,
            "size": info.st_size,
            "modtime": math.floor(info.st_mtime * 1000),
        })
    manifest = {
        "files": file_metadata,
        "lifetime": lifetime,
        "password": password,
    }
    if collection_name is not None:
        manifest["collection_name"] = collection_name

    async with connect(uri) as ws:
        await ws.send(json.dumps(manifest))
        resp = json.loads(await ws.recv())
        # Treat a non-object reply or one without "type" as unexpected rather
        # than failing with a TypeError/KeyError traceback.
        if not isinstance(resp, dict) or resp.get("type") != "ready":
            print("unexpected response: {}".format(resp))
            sys.exit(1)
        print("Download code: {}".format(resp["code"]))
        # file_loader iterates its argument twice, so hand it a real list.
        loader = file_loader([(path, meta["size"]) for path, meta in zip(paths, file_metadata)])
        for data in loader:
            await ws.send(data)
            resp = await ws.recv()
            if resp != "ack":
                # tqdm.write keeps the message from corrupting the live bars.
                tqdm.write("unexpected response: {}".format(resp))
                sys.exit(1)

# Command-line interface: optional lifetime, server URI and collection name,
# followed by one or more file paths (parsed into pathlib.Path objects).
parser = argparse.ArgumentParser(
    description="Upload files to transbeam",
)
parser.add_argument(
    "-l",
    "--lifetime",
    type=int,
    default=7,
    help="Lifetime in days for files (default 7)",
)
parser.add_argument(
    "-u",
    "--uri",
    type=str,
    default="wss://transbeam.link/upload",
    help="Websocket URI for transbeam (default wss://transbeam.link/upload)",
)
parser.add_argument(
    "-n",
    "--collection-name",
    type=str,
    help="Name for a collection of multiple files",
)
parser.add_argument(
    "files",
    type=pathlib.Path,
    nargs="+",
    help="Files to upload",
)

async def main():
    """Parse the command line, prompt for a password and run the upload.

    Raises:
        SystemExit: With status 1 if ``--collection-name`` is given together
            with exactly one file argument.
    """
    args = parser.parse_args()
    if len(args.files) == 1 and args.collection_name is not None:
        print("--collection-name is only applicable when multiple files are being uploaded")
        # sys.exit rather than the site-provided exit(), which is meant for
        # interactive sessions and is absent when site is not loaded.
        sys.exit(1)
    # Prompt without echo so the password is not passed on the command line.
    password = getpass.getpass()
    await send(args.files, args.uri, password, args.lifetime, args.collection_name)

# Only start the upload when executed as a script, so importing this module
# does not parse sys.argv or prompt for a password as a side effect.
if __name__ == "__main__":
    asyncio.run(main())