#!/usr/bin/env python3

# BASIC transplr. Transforms well-formed C64 BASIC V2 to something weird.
# https://github.com/c1570/Reblitz64

import csv
import os.path
import re
import sys

# Path of the tokenized BASIC program to read: first command-line argument,
# falling back to "input.prg" when none is given.
input_file = sys.argv[1] if len(sys.argv) > 1 else "input.prg"
# Path for a re-emitted PRG; in the visible code it is only referenced by the
# commented-out PRG writer at the end of the file.
output_file = "output.prg"
# Path the generated text is written to.
output_txt_file = "output.txt"


# Replacement text for every C64 BASIC V2 token, listed in token order starting
# at byte value 128 (END). Statements the generated text cannot execute are
# emitted as "//" comments; GOSUB (141) maps to "" because the call itself is
# produced from the branch target.
# Byte 132 is INPUT# and byte 133 is INPUT in BASIC V2 (the two were swapped
# here before).
TOKENS = ["console.log('END')", "//FOR", "}", "DATA", "//INPUT#", "//INPUT", "//DIM", "READ",
          "LET", "GOTO", "RUN", "if(", "RESTORE", "",  # GOSUB
          "return", "//REM", "STOP", "ON", "WAIT", "LOAD", "SAVE", "VERIFY", "DEF", "POKE",
          "//PRINT#", "//console.log", "CONT", "LIST", "//CLR", "CMD", "//SYS", "//OPEN",
          "//CLOSE", "//GET", "NEW", "TAB(", "TO", "FN", "SPC(", "THEN", "NOT",
          "STEP", "+", "-", "*", "/", "^", "&", "|", ">", "=", "<",
          "SGN", "Math.floor", "Math.abs", "USR", "FRE", "POS", "SQR", "RND", "LOG",
          "EXP", "COS", "SIN", "TAN", "ATN", "PEEK", "LEN", "STR$", "c64_parse_float",
          "ASC", "String.fromCharCode", "LEFT$", "RIGHT$", "MID$", "GO"]
# Token byte value (128 and up) -> replacement text.
TOK_TO_STR = dict(enumerate(TOKENS, start=128))


def prg_line_to_ascii(data: bytearray):
    """Render tokenized line content as text.

    Bytes outside double quotes that appear in TOK_TO_STR are replaced by
    their token text; every other byte is converted with chr(). The quote
    state is toggled before the quote byte itself is emitted.
    """
    pieces = []
    quoted = False
    for byte in data:
        if byte == 34:  # "
            quoted = not quoted
        if not quoted and byte in TOK_TO_STR:
            pieces.append(TOK_TO_STR[byte])
        else:
            pieces.append(chr(byte))
    return "".join(pieces)


class Element:
    """Thin wrapper around the data of one fragment of a program line."""

    def __init__(self, data):
        self.data = data

    def __str__(self):
        return "Element {}".format(self.data)

    def get_bytes(self):
        """Return the wrapped data object unchanged."""
        return self.data


def create_element(data):
    """Wrap *data* in a new Element and return it."""
    return Element(data)


# Registry of every translated variable id ("var_...") seen while translating.
# The value is "" until a DIM statement records an element count for the id.
all_variables = dict()


def get_var_branch_info(data: bytearray, branch_targets: list, jump_sources: list, var_table: set, close_block_before: list, curly_stack: list) -> list:
    """Translate one tokenized BASIC line into JavaScript-like text fragments.

    The line is scanned byte by byte. Variable names become ``var_...``
    identifiers, GOSUB targets become function calls, IF conditions open a
    ``{`` block, and ``REM "FLAG,...`` annotations (ELSEIF, ELSE, IFBEGIN,
    BREAK, CONTINUE, DO, WHILE(TRUE)) rewrite the GOTO in front of them into
    structured control flow.

    Args:
        data: Line number (two bytes, little endian) followed by the tokenized
            line content.
        branch_targets: Output list. Every line number found after GOTO, RUN,
            GOSUB, LIST or THEN is appended; GOSUB targets are stored negated.
        jump_sources: Line numbers that branch to this line, negative for
            GOSUB callers. May be None or empty.
        var_table: Output set receiving every variable name as spelled in the
            BASIC source (array names keep their trailing "(").
        close_block_before: Line numbers in front of which an IF/ELSE block
            has to be closed. Entries matching this line are consumed; an
            ELSE annotation appends a new one.
        curly_stack: Stack of currently open blocks, each entry being
            ``[kind, closing_line_or_None, opening_line]``. It is shared
            between calls and updated in place.

    Returns:
        The text fragments for this line, block-closing and function-opening
        fragments first. A line without content yields an empty list.

    Raises:
        AssertionError: When braces would be closed in the wrong order or an
            annotation is used in an unsupported place.

    Besides its arguments the function reads the module-level ``labels`` and
    ``TOK_TO_STR`` tables and records variables and DIM sizes in
    ``all_variables``.
    """
    if len(data) <= 2:
        return []
    in_data = False
    in_quotes = False
    in_rem = False
    in_branch = False
    in_gosub = False
    line_if_clauses = 0
    in_if_condition = False
    in_array_index = False
    starting_token = -1
    # Recorded while scanning; nothing below consumes it at the moment.
    line_contains_unconditional_branch = False
    line_number = data[0] + (data[1] << 8)
    var_name = bytearray()
    branch_target = bytearray()
    j_sources = ", ".join([f"GOTO from {j}" if j >= 0 else f"GOSUB from {abs(j)}" for j in jump_sources]) if jump_sources is not None else ""
    txt = ""
    pre_txts = []
    txts = []
    # Close IF/ELSE blocks that an earlier ELSE annotation scheduled to end here.
    while line_number in close_block_before:
        pre_txts.append(" }")
        close_block_before.remove(line_number)
        reason = curly_stack.pop()
        assert reason[0] in ["IF", "ELSE"] and reason[1] == line_number, f"{line_number}: Closing {reason[0]} brace mismatch, would close {reason[0]} opened in line {reason[2]}"
    # A line that is the target of at least one GOSUB starts a new function.
    if j_sources != "" and min(jump_sources) < 0:
        if len(curly_stack) > 0:
            # Pop outside the assert so the stack stays correct under ``python -O``.
            enclosing = curly_stack.pop()
            assert enclosing[0] == "SUB", f"{line_number}: subroutine starts inside {enclosing[0]} block opened in line {enclosing[2]}"
        pre_txts.insert(0, "}")
        target_s = labels["sub"].get(str(line_number), f"sub_{str(line_number)}")
        pre_txts.append(f"\nfunction {target_s}(){{")
        curly_stack.append(["SUB", line_number, line_number])
    # Remembered so a DO annotation can insert its block below everything
    # this line itself pushes onto the stack.
    curly_stack_len = len(curly_stack)

    def get_var_id(label):
        """Turn a BASIC variable spelling into the identifier used in the output."""
        return "var_" + label.replace(" ", "_").replace("$", "str").replace("%", "int").replace("(", "arr")

    def add_var_name_to_txt():
        """Flush the collected variable name (if any) into txt and register it."""
        nonlocal txt, in_array_index
        if len(var_name) > 0:
            var_label = var_name.decode('ascii')
            var_name.clear()
            var_id = get_var_id(var_label)
            if var_id not in all_variables:
                all_variables[var_id] = ""
            var_table.add(var_label)
            if var_id.endswith("arr"):
                # The "(" of an array access was collected as part of the name.
                txt = txt + f" {var_id}["
                in_array_index = True
            else:
                txt = txt + f" {var_id} "

    def add_branch_target_to_txt():
        """Flush the collected line-number digits (if any) as a branch target."""
        nonlocal txt
        if len(branch_target) > 0:
            target = int(branch_target.decode('ascii'))
            if in_gosub:
                target_s = labels["sub"].get(str(target), f"sub_{str(target)}")
                txt = txt + f" {target_s}() "
                target = -target  # negative marks a GOSUB target
            else:
                txt = txt + f"{target} "
            branch_targets.append(target)
            branch_target.clear()

    def add_char_to_txt(c: int, detokenize: bool):
        """Append byte c to txt, expanding tokens when detokenize is set.

        Bytes outside 32..90 that are not expanded as a token are dropped.
        """
        nonlocal txt, in_array_index
        if c == 41 and in_array_index:
            txt = txt + "]"  # wrong in case the index uses functions but who cares
            in_array_index = False
            return
        if 31 < c < 91:
            txt = txt + chr(c)
        elif detokenize:
            if c in TOK_TO_STR:
                token = TOK_TO_STR[c]
                if len(token) > 1:
                    txt = txt + f" {token} "
                else:
                    txt = txt + token

    for c in data[2:]:
        # Any token ends a pending branch target.
        if in_branch and c > 127:
            add_branch_target_to_txt()
            in_branch = False
            in_gosub = False
        if not in_quotes and not in_data and not in_rem and c > 127:
            if c == 137 and not starting_token == 145:  # GOTO but not ON...GOTO
                line_contains_unconditional_branch = True
            if c == 128:  # END
                line_contains_unconditional_branch = True
            if c == 142:  # RETURN
                line_contains_unconditional_branch = True
            if starting_token == -1:
                starting_token = c

        if in_quotes or c == 34:  # "
            if c == 34:
                add_var_name_to_txt()
                in_quotes = not in_quotes
            add_char_to_txt(c, False)
        elif in_data:
            starting_token = -1
            if c == 58:  # :
                in_data = False
                txts.append(txt)
                txt = ""
                continue
            add_char_to_txt(c, False)
        elif in_rem:
            add_char_to_txt(c, True)
            continue
        elif c == 131:  # DATA
            add_var_name_to_txt()
            add_char_to_txt(c, True)
            in_data = True
        elif c == 143:  # REM
            add_var_name_to_txt()
            add_char_to_txt(c, True)
            in_rem = True
        elif c in [137, 138, 141, 155, 167]:  # GOTO, RUN, GOSUB, LIST, THEN
            add_var_name_to_txt()
            in_gosub = c == 141
            in_branch = True
            if in_if_condition:
                # The condition ends here: emit it and open the IF block.
                txts.append(txt + ") {")
                txt = ""
                line_if_clauses += 1
                in_if_condition = False
                curly_stack.append(["IF", None, line_number])
            if c != 167:
                add_char_to_txt(c, True)
        elif c == 40:  # (
            if len(var_name) > 0:
                var_name.append(c)
                add_var_name_to_txt()
            else:
                add_char_to_txt(c, True)
        elif (64 < c < 91) and len(var_name) == 0:  # A..Z
            add_branch_target_to_txt()
            in_branch = False
            in_gosub = False
            var_name.append(c)
        elif ((64 < c < 91) or (47 < c < 58)) and len(var_name) > 0:  # A..Z0..9
            if var_name[-1] in [36, 37]:  # $, %
                add_var_name_to_txt()
            var_name.append(c)
        elif (47 < c < 58) and in_branch:
            branch_target.append(c)
        elif (c in [36, 37]) and len(var_name) > 0:  # $, %
            var_name.append(c)
            add_branch_target_to_txt()
        elif c == 32:  # spaces are not significant but possibly _important_ ("F N")
            if len(var_name) > 0:
                var_name.append(c)
            continue
        else:
            if in_branch:
                if c == 44:  # ,
                    add_branch_target_to_txt()
                elif c == 58:  # :
                    add_branch_target_to_txt()
                    in_branch = False
                    in_gosub = False
            add_var_name_to_txt()
            if in_if_condition:
                # Inside a condition the BASIC operators map to JS comparisons.
                if c == 178 and not txt.endswith("<") and not txt.endswith(">"):  # = token
                    txt = txt + "=="
                    continue
                if c == 177:  # > token
                    if txt.endswith("<"):
                        txt = txt[:-1] + "!="
                        continue
                if c == 176:  # OR token
                    txt = txt + "||"
                    continue
                if c == 175:  # AND token
                    txt = txt + "&&"
                    continue
            if c == 58:  # :
                starting_token = -1
                txts.append(txt)
                txt = ""
                continue
            elif c == 139:  # IF
                in_if_condition = True
            if c == 129:  # FOR
                curly_stack.append(["FOR", None, line_number])
            if c == 130:  # NEXT
                reason = curly_stack.pop()
                assert reason[0] == "FOR", f"{line_number}: NEXT brace mismatch, would close {reason[0]} opened in line {reason[2]}"
            add_char_to_txt(c, True)

    # flush everything
    add_var_name_to_txt()
    add_branch_target_to_txt()

    # handling annotations
    # Every pop() below is done before its assert so that running with
    # ``python -O`` (asserts stripped) still produces the same output.
    if 'REM "' in txt:
        flags = re.search('REM "(.*)', txt).group(1).split(",")
        if 'ELSEIF' in flags:
            goto_line = txts.pop()
            assert "GOTO" in goto_line
            reason = curly_stack.pop()
            assert reason[0] == "IF", f"{line_number}: Misplaced ELSEIF would close non-IF block opened by {reason[0]} in line {reason[2]}"
            txts.append(" } else")
            txt = ""
            assert line_if_clauses == 1
            line_if_clauses = 0
            line_contains_unconditional_branch = False
        if 'ELSE' in flags:
            txt = ""
            goto_line = txts.pop()
            assert "GOTO" in goto_line
            reason = curly_stack.pop()
            assert reason[0] == "IF", f"{line_number}: Misplaced ELSE would close non-IF block opened by {reason[0]} in line {reason[2]}"
            goto_target = int(re.search(r'GOTO (\d*)', goto_line).group(1))
            if len(txts) > 0 and txts[-1].startswith(" if"):
                # special case: IF...GOTOXXXX:REM"ELSE - this is actually just a multiline THEN with inverted condition
                if_line = txts.pop()
                if_line = f" if(!{if_line[3:].replace(') {', ')) {')}"
                txts.append(if_line)
                curly_stack.append(["IF", goto_target, line_number])
            else:
                txts.append(" } else {")
                curly_stack.append(["ELSE", goto_target, line_number])
            line_if_clauses -= 1
            line_contains_unconditional_branch = False
            # The block opened here ends right before the GOTO target line.
            close_block_before.append(goto_target)
        if 'IFBEGIN' in flags:
            goto_line = txts.pop()
            assert "GOTO" in goto_line
            if_line = txts.pop()
            assert if_line.startswith(" if"), f"{line_number}: IFBEGIN only supports the IF...GOTO pattern"
            txt = ""
            if_line = f" if(!{if_line[3:].replace(') {', ')) {')}"
            assert line_if_clauses == 1
            line_if_clauses = 0
            txts.append(if_line)
            # curly_braces is already taken care of in IF token handling
        if 'BREAK' in flags:
            goto_line = txts.pop()
            assert "GOTO" in goto_line
            txts.append(" break")
            txt = ""
        if 'CONTINUE' in flags:
            goto_line = txts.pop()
            assert "GOTO" in goto_line
            txts.append(" continue")
            txt = ""
        if 'DO' in flags:
            txts.insert(0, " do {")
            txt = ""
            curly_stack.insert(curly_stack_len, ["DO", None, line_number])
            j_sources = ""  # disable mixed GOTO/GOSUB warning
        if 'WHILE(TRUE)' in flags:
            goto_line = txts.pop()
            assert "GOTO" in goto_line
            reason = curly_stack.pop()
            assert reason[0] == "DO", f"{line_number}: Misplaced WHILE would close non-DO block opened by {reason[0]} in line {reason[2]}"
            txts.append(" } while(true)")
            txt = ""
    if txt != "":
        txts.append(txt)

    def it_is_integer(s):
        """Tell whether s is an integer literal or a single integer variable id."""
        s = s.strip()
        res = s.isdigit() or (s.startswith("-") and s[1:].isdigit()) or (" " not in s and s.startswith("var_") and s.endswith("int"))
        return res

    # handle DIM/FOR/int var assignments
    for idx, fragment in enumerate(txts):
        if "//DIM" in fragment and "arr[" in fragment:
            m = re.match(r".*DIM\s*(.*)\[(.*)\]", fragment)
            if m:
                # Record the element count (highest index + 1) for the array.
                all_variables[m.group(1)] = int(m.group(2)) + 1
        elif "//FOR" in fragment:
            assert "STEP" not in fragment, f"{line_number}: FOR...STEP is not supported"
            m = re.match(r".*FOR\s*(.*)\=\s*(.*)\s*TO\s*(.*)\s*", fragment)
            assert m.group(3).strip() != m.group(1).strip(), f"{line_number}: FOR stop must be constant"
            loop_var = m.group(1)
            start_val = m.group(2)
            if loop_var.endswith("int"):
                if not it_is_integer(start_val):
                    start_val = f"Math.floor({start_val})"
            txts[idx] = f" for({loop_var}={start_val}; {loop_var}<={m.group(3)}; {loop_var}++) {{"
        elif "int = " in fragment:
            m = re.match(r".*(var_[A-Z0-9]{1,2}int) \= (.*)\s*", fragment)
            if m and not it_is_integer(m.group(2)):
                txts[idx] = f" {m.group(1)} = Math.floor({m.group(2)})"

    # IF blocks opened on this line and not turned into multi-line blocks by
    # an annotation end with the line.
    while line_if_clauses > 0:
        reason = curly_stack.pop()
        assert reason[0] in ["IF", "ELSE"], f"{line_number}: Implicit {reason[0]} closing brace would close block opened by {reason[0]} in line {reason[2]}"
        txts.append(" }")
        line_if_clauses -= 1
    if "GOTO" in j_sources and "GOSUB" in j_sources:
        print(f"line {line_number} is target of both GOTO and GOSUB. Check/fix! {j_sources}")
    return pre_txts + txts


class Line:
    """One BASIC program line, split into elements.

    The content is cut at every ':' outside double quotes (the colon becomes
    an element of its own) and directly after every THEN token that is
    outside double quotes and outside a DATA statement.
    """

    def __init__(self, data):
        """Parse *data*: two bytes of line number (little endian), then content."""
        self.number = data[0] + (data[1] << 8)
        self.elements = []
        data = data[2:]
        off = 0
        in_quotes = False
        in_data = False
        while off < len(data):
            if data[off] == 34:  # "
                in_quotes = not in_quotes
            elif data[off] == 131 and not in_quotes:  # DATA (byte 131 inside a string is just a character)
                in_data = True
            elif data[off] == 58:  # :
                if not in_quotes:
                    self.elements.append(create_element(data[:off]))
                    self.elements.append(create_element(data[off:off + 1]))  # TODO should include whitespace here
                    data = data[off + 1:]
                    in_data = False
                    off = -1  # restart at the beginning of the remaining data
            elif data[off] == 167:  # THEN
                if not in_quotes and not in_data:
                    self.elements.append(create_element(data[:off + 1]))
                    data = data[off + 1:]
                    off = -1  # restart at the beginning of the remaining data
            off += 1
        if len(data) > 0:
            self.elements.append(create_element(data))

    def get_bytes(self, branch_table=None):
        """Return the line as bytes: line number, element data, zero terminator.

        A line whose elements are all empty or ':' is dropped (an empty
        bytearray is returned) unless *branch_table* lists at least one jump
        source for its number; in that case a line without any element bytes
        is kept alive with a single REM token.

        Args:
            branch_table: Optional mapping of line number to a collection of
                jump sources. Defaults to no known branches.
        """
        if branch_table is None:
            branch_table = {}
        out = bytearray([self.number & 255, (self.number >> 8) & 255])
        empty_line = True
        for el in self.elements:
            el_bytes = bytes(el.get_bytes())
            if el_bytes != b':' and el_bytes != b'':
                empty_line = False
            out.extend(el_bytes)
        if empty_line:
            if self.number in branch_table and len(branch_table[self.number]) > 0:
                print(f"{self.number} is in branch_table! len(out) is {len(out)}")
                if len(out) == 2:
                    out.extend(b'\x8f')  # all elements got removed but line is target of a branch: add REM
            else:
                return bytearray()  # empty line that's not target of a branch: just ignore
        out.extend(b'\x00')
        return out

    def __str__(self):
        return f"{self.number} {prg_line_to_ascii(self.get_bytes()[2:-1])} --- {len(self.elements)} elements"


class Program:
    """A BASIC program parsed from the content of a PRG file."""

    def __init__(self, data):
        """Split *data* into Line objects.

        *data* starts with the two-byte load address (little endian). Each
        line follows as a two-byte next-line pointer, the two-byte line
        number, the content and a zero terminator. A zero next-line pointer
        ends the program.

        Truncated input is tolerated: a dangling single byte where a pointer
        is expected ends parsing, and a last line without terminator is taken
        up to the end of the data.
        """
        self.startaddr = data[0] + (data[1] << 8)
        self.lines = []
        curoff = 2
        while curoff + 1 < len(data):
            if data[curoff] == 0 and data[curoff + 1] == 0:
                break
            line_data = data[curoff + 2:]
            # Search behind the line number, whose bytes may be zero themselves.
            terminator = line_data.find(b'\x00', 2)
            if terminator < 0:
                if len(line_data) >= 2:
                    self.lines.append(Line(line_data))
                break
            line_data = line_data[:terminator]
            self.lines.append(Line(line_data))
            # Skip the pointer (2 bytes), the line itself and its terminator.
            curoff += len(line_data) + 3

    def get_bytes(self, branch_table=None):
        """Serialize the program back into PRG format.

        Next-line pointers are recomputed from the start address. Lines whose
        get_bytes() result is empty are left out.

        Args:
            branch_table: Optional mapping of line number to jump sources,
                passed on to every line. Defaults to no known branches.
        """
        if branch_table is None:
            branch_table = {}
        outprg = bytearray([self.startaddr & 255, (self.startaddr >> 8) & 255])
        curadr = self.startaddr
        for line in self.lines:
            line_bytes = line.get_bytes(branch_table)
            if len(line_bytes) == 0:
                continue
            curadr += len(line_bytes) + 2  # +2 for the pointer itself
            outprg.append(curadr & 255)
            outprg.append((curadr >> 8) & 255)
            outprg.extend(line_bytes)
        outprg.extend([0, 0])
        return outprg

    def __str__(self):
        return f"Program with {len(self.lines)} lines"


def analyze_program(program: Program) -> (dict, str, set):
    """Translate a whole program in two passes.

    The first pass runs every line through get_var_branch_info only to learn
    which lines are branched to and which variables exist. The second pass
    repeats the translation with the jump sources known and keeps the text.

    Returns a tuple of
    - a dict mapping line number to the list of line numbers jumping there
      (negative entries are GOSUB callers),
    - the generated text: one ``let`` declaration per registered variable,
      then ``function main() {`` and the translated lines,
    - the set of variable names collected during the first pass.
    """
    jumps_by_target = {}
    seen_variables = set()

    # Pass 1: collect branch targets and variables; the text is discarded.
    scheduled_closes = []
    open_blocks = []
    for line in program.lines:
        found_targets = []
        get_var_branch_info(
            line.get_bytes()[:-1],
            branch_targets=found_targets,
            jump_sources=[],
            var_table=seen_variables,
            close_block_before=scheduled_closes,
            curly_stack=open_blocks,
        )
        for target in found_targets:
            # Keep the GOSUB marker (negative sign) on the source side.
            origin = -line.number if target < 0 else line.number
            jumps_by_target.setdefault(abs(target), []).append(origin)

    # Pass 2: translate again, now telling each line who jumps to it.
    body = ["function main() {"]
    scheduled_closes = []
    open_blocks = []
    for line in program.lines:
        body.extend(get_var_branch_info(
            line.get_bytes()[:-1],
            branch_targets=[],
            jump_sources=jumps_by_target.setdefault(line.number, []),
            var_table=set(),
            close_block_before=scheduled_closes,
            curly_stack=open_blocks,
        ))

    # build variable prefix
    declarations = []
    for var_id in sorted(all_variables):
        size = all_variables[var_id]
        if size == "":
            # No array size recorded: plain initial value.
            initial = '""' if "str" in var_id else "0"
            declarations.append(f"let {var_id} = {initial}")
            continue
        if var_id.endswith("intarr"):
            declarations.append(f"let {var_id} = new Int16Array({size})")
        elif var_id.endswith("strarr"):
            declarations.append(f"let {var_id} = Array.apply('', Array({size})).map(()=>'')")
        elif var_id.endswith("arr"):
            declarations.append(f"let {var_id} = new Float32Array({size})")
        else:
            raise Exception(f"Bug: Array size given for plain variable {var_id}")

    return jumps_by_target, "\n".join(declarations + body), seen_variables


# Read the complete tokenized program.
with open(input_file, "rb") as infile:
    inprg = infile.read()
# Optional user-supplied names from labels.csv; each row is
# category ("var" or "sub"), key, label.
labels = {"var": {}, "sub": {}}
if os.path.isfile("labels.csv"):
    with open("labels.csv", newline="") as csvfile:
        csvreader = csv.reader(csvfile, delimiter=',')
        for row in csvreader:
            # Blank or short rows carry no label; skip them instead of
            # failing on the missing third column.
            if len(row) < 3 or row[2] == "":
                continue
            labels[row[0]][row[1]] = row[2]
input_program = Program(inprg)
branch_table, output_txt, var_table = analyze_program(input_program)
#crank_the_print(input_program, branch_table)
#outprg = input_program.get_bytes(branch_table)
#with open(output_file, "wb") as outfile:
#    outfile.write(outprg)
# The trailing "}" closes the last function opened in the generated text.
with open(output_txt_file, "w") as outfile:
    outfile.write(output_txt+"\n}\n")
