#!/usr/bin/env python3
"""
cpc_bas_tokenize.py - Amstrad CPC Locomotive BASIC tokenizer
Converts plain ASCII .bas listing to tokenized AMSDOS .bas (or .dsk).

Usage:
    python3 cpc_bas_tokenize.py input.bas output.bas
    python3 cpc_bas_tokenize.py input.bas output.dsk --dsk [--name PROGNAME]
"""
import struct, sys, os, re, math, argparse

# ── KEYWORD TABLE (token = index + 0x80) ──────────────────────────────────────
# Statement/operator tokens for Locomotive BASIC: token value = 0x80 + index.
# The trailing hex comments on each row give the token range it covers.
# `None` marks token slots this tokenizer never emits (0xFF is the prefix
# byte introducing an entry from the FUNCTIONS table below).
KEYWORDS = [
    "AFTER","AUTO","BORDER","CALL","CAT","CHAIN","CLEAR","CLG",          # 80-87
    "CLOSEIN","CLOSEOUT","CLS","CONT","DATA","DEF","DEFINT",              # 88-8E
    "DEFREAL","DEFSTR","DEG","DELETE","DIM","DRAW","DRAWR","EDIT",        # 8F-96
    "ELSE","END","ENT","ENV","ERASE","ERROR","EVERY","FOR",               # 97-9E
    "GOSUB","GOTO","IF","INK","INPUT","KEY","LET","LINE","LIST",          # 9F-A7
    "LOAD","LOCATE","MEMORY","MERGE","MID$","MODE","MOVE","MOVER",        # A8-AF
    "NEXT","NEW","ON",None,None,"SQ","OPENIN",           # B0-B6
    "OPENOUT","ORIGIN","OUT","PAPER","PEN","PLOT","PLOTR","POKE",         # B7-BE
    "PRINT","'","RAD","RANDOMIZE","READ","RELEASE","REM","RENUM",         # BF-C6
    "RESTORE","RESUME","RETURN","RUN","SAVE","SOUND","SPEED","STOP",      # C7-CE
    "SYMBOL","TAG","TAGOFF","TROFF","TRON","WAIT","WEND","WHILE",         # CF-D6
    "WIDTH","WINDOW","WRITE","ZONE","DI","EI","FILL","GRAPHICS",          # D7-DE
    "MASK","FRAME","CURSOR",None,"ERL","FN","SPC","STEP","SWAP",          # DF-E7
    None,None,"TAB","THEN","TO","USING",                                   # E8-ED
    ">","=",">=","<","<>","<=",                                             # EE-F3
    "+","-","*","/","^","\\","AND","MOD","OR","XOR","NOT",                 # F4-FE
    None,                                                                   # FF (function prefix)
]

# ── FUNCTION TABLE (token = 0xFF + index) ─────────────────────────────────────
# Verified against actual CPC tokenized output.
# Function tokens are emitted as the 0xFF prefix byte followed by the key
# below.  Gaps in the key space (0x1E-0x3F, 0x4A-0x73) are not present in
# this table and so are never produced by the tokenizer.
FUNCTIONS = {
    0x00:"ABS",   0x01:"ASC",    0x02:"ATN",    0x03:"CHR$",
    0x04:"CINT",  0x05:"COS",    0x06:"CREAL",  0x07:"EXP",
    0x08:"FIX",   0x09:"FRE",    0x0A:"INKEY",  0x0B:"INP",
    0x0C:"INT",   0x0D:"JOY",    0x0E:"LEN",    0x0F:"LOG",
    0x10:"LOG10", 0x11:"LOWER$", 0x12:"PEEK",   0x13:"REMAIN",
    0x14:"SGN",   0x15:"SIN",    0x16:"SPACE$", 0x17:"SQ",
    0x18:"SQR",   0x19:"STR$",   0x1A:"TAN",    0x1B:"UNT",
    0x1C:"UPPER$",0x1D:"VAL",
    # Functions taking no argument / system queries.
    0x40:"EOF",   0x41:"ERR",    0x42:"HIMEM",  0x43:"INKEY$",
    0x44:"PI",    0x45:"RND",    0x46:"TIME",   0x47:"XPOS",
    0x48:"YPOS",  0x49:"DERR",
    # Extended functions (verified against binary: INSTR=0x74, LEFT$=0x75, etc.)
    0x74:"INSTR", 0x75:"LEFT$",  0x76:"MAX",    0x77:"MIN",
    0x78:"POS",   0x79:"RIGHT$", 0x7A:"ROUND",  0x7B:"STRING$",
    0x7C:"TEST",  0x7D:"TESTR",  0x7E:"COPYCHR$",0x7F:"VPOS",
}

# Reverse lookup tables derived from the static token tables above.
_kw_map = {}
for _idx, _name in enumerate(KEYWORDS):
    if _name is not None:
        _kw_map[_name] = 0x80 + _idx
_fn_map = {name: code for code, name in FUNCTIONS.items()}

# Longest match first so greedy scanning prefers e.g. "DEFINT" over "DEF".
_sorted_kw = sorted(_kw_map, key=len, reverse=True)
_sorted_fn = sorted(_fn_map, key=len, reverse=True)

# Characters that can start a tokenized (non-alphabetic) operator.
_OP_CHARS = set('><=+-*/^\\')

# Statement tokens whose numeric argument is a line-number reference (0x1E).
_LINENUM_KW = {
    _kw_map.get("GOTO"), _kw_map.get("GOSUB"), _kw_map.get("RESTORE"),
    _kw_map.get("LIST"), _kw_map.get("RUN"),
}
_THEN_TOK = _kw_map.get("THEN")
_ON_TOK   = _kw_map.get("ON")

# ── NUMBER ENCODING ───────────────────────────────────────────────────────────

def encode_number(s, as_lineref=False):
    """Encode a numeric literal as Locomotive BASIC number tokens.

    Token forms used (Locomotive BASIC tokenized-number encoding):
      0x0E-0x18  integers 0..10, single byte
      0x19       8-bit integer + 1 operand byte
      0x1A       16-bit integer + 2 operand bytes (little-endian)
      0x1B       16-bit binary literal (&X...)
      0x1C       16-bit hex literal (&...)
      0x1E       16-bit line-number reference (when as_lineref is True)
      0x1F       5-byte real (via encode_float)

    Falls back to raw ASCII bytes if `s` is not parseable as a number.
    """
    s = s.strip()
    if re.match(r'^&[0-9A-Fa-f]+$', s):
        return bytes([0x1C]) + struct.pack('<H', int(s[1:], 16) & 0xFFFF)
    if re.match(r'^&[Xx][01]+$', s):
        return bytes([0x1B]) + struct.pack('<H', int(s[2:], 2) & 0xFFFF)
    try:
        if '.' in s or 'E' in s.upper():
            return encode_float(float(s))
        v = int(s)
        if as_lineref:
            # Line numbers always use the 16-bit reference token; mask so an
            # out-of-range literal cannot raise struct.error.
            return bytes([0x1E]) + struct.pack('<H', v & 0xFFFF)
        # 0..10 have dedicated single-byte tokens 0x0E-0x18 (the original
        # stopped at 9, emitting a non-canonical 2-byte form for 10).
        if 0 <= v <= 10:  return bytes([0x0E + v])
        if v <= 255:      return bytes([0x19, v])                       # 11..255
        if v <= 65535:    return bytes([0x1A]) + struct.pack('<H', v)   # 16-bit
        return encode_float(float(v))                                   # too big: store as real
    except ValueError:
        return s.encode('ascii')

def encode_float(f):
    """Encode `f` as a Locomotive BASIC 5-byte real: 0x1F + mantissa + exponent.

    Format: 32-bit mantissa, little-endian, normalized to [0.5, 1) with the
    implicit top bit dropped and replaced by the sign bit; exponent byte is
    biased by 128, so value = mantissa_fraction * 2**(exp_byte - 128).
    Reference points: 1.0 -> 00 00 00 00 81, PI -> A2 DA 0F 49 82.

    Fixes two defects in the original: the exponent was written with bias
    129 (every float came out doubled), and the mantissa was truncated
    instead of rounded.
    """
    if f == 0.0:
        return bytes([0x1F, 0, 0, 0, 0, 0])
    sign = f < 0
    m, e = math.frexp(abs(f))       # abs(f) == m * 2**e, with 0.5 <= m < 1
    mant = round(m * (1 << 32))     # round to nearest, not truncate
    if mant >= (1 << 32):           # rounding carried past the top bit
        mant >>= 1
        e += 1
    mant &= 0x7FFFFFFF              # drop the implicit leading 1
    if sign:
        mant |= 0x80000000          # bit 31 holds the sign
    return bytes([0x1F]) + struct.pack('<I', mant) + bytes([(e + 128) & 0xFF])

# ── LINE TOKENIZER ────────────────────────────────────────────────────────────

def tokenize_line(text):
    """Tokenize one BASIC line body (no line number prefix).

    Returns the tokenized bytes: keywords become single-byte tokens
    (>= 0x80), functions become 0xFF + code, numbers are encoded by
    encode_number(), variables get a type byte plus a 2-byte placeholder
    offset, and string/REM text is copied through verbatim.
    """
    out = bytearray()
    i = 0; in_str = False; in_rem = False
    upper = text.upper()
    # State for line number context
    expect_linenum = False     # next number should use 0x1E
    after_on_goto  = False     # in ON..GOTO/GOSUB number list (commas also give line refs)
    last_tok       = None      # last token emitted (for context tracking)

    while i < len(text):
        c = text[i]

        # Inside a string literal: copy verbatim until the closing quote.
        if in_str:
            out.append(ord(c));
            if c == '"': in_str = False
            i += 1; continue

        # After REM or ': the rest of the line is stored untokenized.
        if in_rem:
            out.append(ord(c)); i += 1; continue

        if c == '"':
            out.append(ord(c)); in_str = True; i += 1; continue

        # Statement separator is stored as token 0x01, and it ends any
        # line-number context.
        if c == ':':
            out.append(0x01)
            expect_linenum = False; after_on_goto = False
            i += 1; continue

        if c == ',':
            out.append(0x2C)
            # In ON..GOTO/GOSUB list, commas separate more line refs
            # (expect_linenum stays True if we're already in that mode)
            i += 1; continue

        # RSX call: '|' marker, then 0x00 and the name with the top bit of
        # its last character set.
        if c == '|':
            out.append(0x7C); i += 1
            while i < len(text) and text[i] == ' ': i += 1
            j = i
            while j < len(text) and (text[j].isalnum() or text[j] in '_.'): j += 1
            name = text[i:j].upper()
            if name:
                out.append(0x00)
                nb = bytearray(name.encode('ascii')); nb[-1] |= 0x80; out += bytes(nb)
            i = j; expect_linenum = False; continue

        # Numbers (decimal, hex &.., binary &X..)
        if c.isdigit() or (c == '&' and i+1 < len(text) and text[i+1].upper() in '0123456789ABCDEFXx'):
            use_lineref = expect_linenum
            if c == '&':
                j = i+1
                if j < len(text) and text[j].upper() == 'X':
                    j += 1
                    while j < len(text) and text[j] in '01': j += 1
                else:
                    while j < len(text) and text[j].upper() in '0123456789ABCDEF': j += 1
                out += encode_number(text[i:j], as_lineref=False)  # hex/bin never lineref
            else:
                # Scan digits, optional decimal point and exponent part.
                j = i
                while j < len(text) and (text[j].isdigit() or text[j] == '.'): j += 1
                if j < len(text) and text[j].upper() == 'E':
                    j += 1
                    if j < len(text) and text[j] in '+-': j += 1
                    while j < len(text) and text[j].isdigit(): j += 1
                is_float = '.' in text[i:j] or 'E' in text[i:j].upper()
                out += encode_number(text[i:j], as_lineref=(use_lineref and not is_float))
            # After ON..GOTO line numbers, commas keep expect_linenum True
            # but after any other context, reset
            if not after_on_goto:
                expect_linenum = False
            i = j; continue

        # Space: pass through, don't reset linenum state
        if c == ' ':
            out.append(0x20); i += 1; continue

        # Try keywords (greedy, longest first)
        if c.isalpha() or c == "'":
            ru = upper[i:]
            matched_kw = None
            for kw in _sorted_kw:
                if ru.startswith(kw):
                    ep = i + len(kw)
                    lc = kw[-1]
                    # Reject a match that is only a prefix of a longer
                    # identifier (e.g. "IF" inside "IFX").
                    if lc.isalpha() or lc == '$':
                        if ep < len(text) and (text[ep].isalnum() or text[ep] in '$%!_'):
                            continue
                    matched_kw = kw; break

            if matched_kw:
                tok = _kw_map[matched_kw]
                out.append(tok)
                i += len(matched_kw)
                if matched_kw in ("REM", "'"): in_rem = True; continue
                # Update line-number-context state
                expect_linenum = tok in _LINENUM_KW
                if tok == _THEN_TOK:
                    # THEN + number -> line ref; THEN + keyword -> no
                    j = i
                    while j < len(text) and text[j] == ' ': j += 1
                    if j < len(text) and text[j].isdigit():
                        expect_linenum = True
                    else:
                        expect_linenum = False
                # ON..GOTO/GOSUB only when GOTO/GOSUB directly follows ON.
                after_on_goto = (tok in (_kw_map.get("GOTO"), _kw_map.get("GOSUB"))
                                 and last_tok == _ON_TOK)
                last_tok = tok
                continue

            # Try functions (emitted as 0xFF prefix + function code)
            matched_fn = None
            for fn in _sorted_fn:
                if ru.startswith(fn):
                    ep = i + len(fn)
                    lc = fn[-1]
                    if lc.isalpha():
                        if ep < len(text) and (text[ep].isalnum() or text[ep] in '$%!_'):
                            continue
                    matched_fn = fn; break
            if matched_fn:
                out.append(0xFF); out.append(_fn_map[matched_fn])
                i += len(matched_fn); expect_linenum = False; continue

            # Variable name: type byte (0x02 int %, 0x03 string $, 0x04
            # real !, 0x0D untyped), 2-byte placeholder offset, then the
            # name with the top bit of its last character set.
            j = i
            while j < len(text) and (text[j].isalnum() or text[j] == '_'): j += 1
            vname = text[i:j]
            if j < len(text) and text[j] in '$%!':
                vname += text[j]; j += 1
            if vname.endswith('$'):   tb, core = 0x03, vname[:-1]
            elif vname.endswith('%'): tb, core = 0x02, vname[:-1]
            elif vname.endswith('!'): tb, core = 0x04, vname[:-1]
            else:                     tb, core = 0x0D, vname
            nb = bytearray(core.encode('ascii')); nb[-1] |= 0x80
            out += bytes([tb, 0x00, 0x00]) + bytes(nb)
            expect_linenum = False; last_tok = None
            i = j; continue

        # Operator characters (tokenized)
        if c in _OP_CHARS:
            # Try 2-char operators first
            two = upper[i:i+2]
            one = upper[i:i+1]
            matched_op = None
            if two in _kw_map:
                matched_op = two
            elif one in _kw_map:
                matched_op = one
            if matched_op:
                out.append(_kw_map[matched_op])
                i += len(matched_op)
                expect_linenum = False; continue
            # Not a known token, pass through
            if 0x20 <= ord(c) <= 0x7E: out.append(ord(c))
            i += 1; continue

        # Other printable characters
        if 0x20 <= ord(c) <= 0x7E: out.append(ord(c))
        i += 1

    return bytes(out)

# ── FULL PROGRAM TOKENIZER ────────────────────────────────────────────────────

def tokenize_basic(text):
    """
    Tokenize a full BASIC program listing.

    Each line record is [2-byte length][2-byte line number][content][0x00],
    where the length field counts the whole record including itself.
    The program is terminated by a 0x00 0x00 word.
    """
    numbered = []
    for raw in text.splitlines():
        raw = raw.strip()
        if not raw:
            continue
        m = re.match(r'^(\d+)\s?(.*)', raw)
        if m:
            numbered.append((int(m.group(1)), m.group(2)))
    numbered.sort()

    out = bytearray()
    for num, src in numbered:
        toks = tokenize_line(src)
        rec_len = 2 + 2 + len(toks) + 1        # length word + linenum + body + NUL
        out += struct.pack('<H', rec_len)
        out += struct.pack('<H', num)
        out += toks
        out += b'\x00'
    out += b'\x00\x00'
    return bytes(out)

# ── AMSDOS HEADER ─────────────────────────────────────────────────────────────

def make_amsdos_header(filename, data_len):
    """Build a 128-byte AMSDOS header for `data_len` bytes of file data.

    Layout (AMSDOS header spec):
      0x00        user number (0)
      0x01-0x08   filename, space-padded
      0x09-0x0B   extension, space-padded (defaults to BAS)
      0x12        file type (0 = tokenized BASIC)
      0x18-0x19   logical length
      0x1A-0x1B   entry address (unused for BASIC)
      0x40-0x42   real file length, 24-bit little-endian
      0x43-0x44   checksum: 16-bit sum of bytes 0x00-0x42

    Fix: the original wrote the 24-bit length at 0x43-0x45 (the checksum
    field), which was then partly clobbered by the checksum and left a
    stray high byte at 0x45; the length now goes to 0x40-0x42.
    """
    h = bytearray(128)
    stem, dot_ext = os.path.splitext(filename)
    base = stem[:8].upper()
    ext = (dot_ext[1:4] if dot_ext else 'BAS').upper()
    h[0x00] = 0x00                                       # user 0
    h[0x01:0x09] = (base + '        ')[:8].encode('ascii')
    h[0x09:0x0C] = (ext + '   ')[:3].encode('ascii')
    # NOTE(review): 0x0E/0x1F = 0xFF kept from the original; nonstandard
    # flag bytes — confirm against the AMSDOS spec.
    h[0x0C] = 0; h[0x0D] = 0; h[0x0E] = 0xFF
    h[0x12] = 0                                          # type 0 = tokenized BASIC
    struct.pack_into('<H', h, 0x18, data_len & 0xFFFF)   # logical length
    struct.pack_into('<H', h, 0x1A, 0x0000)              # entry address
    h[0x1F] = 0xFF
    # 24-bit real file length at 0x40-0x42.
    h[0x40] = data_len & 0xFF
    h[0x41] = (data_len >> 8) & 0xFF
    h[0x42] = (data_len >> 16) & 0xFF
    chk = sum(h[:0x43]) & 0xFFFF                         # checksum over 0x00-0x42
    struct.pack_into('<H', h, 0x43, chk)
    return bytes(h)

# ── DSK CREATION ─────────────────────────────────────────────────────────────

# Disk geometry: standard single-sided 40-track CPC "data" format.
TRACKS = 40
SIDES = 1
SECTORS = 9
SEC_SZ = 512
TRACK_SZ = SECTORS * SEC_SZ        # 4608 bytes of sector data per track
TRACK_REC = 256 + TRACK_SZ         # track record = 256-byte header + data (4864)
BLOCK_SZ = 1024                    # CP/M allocation block size
DIR_BLOCKS = 2                     # first two blocks hold the directory
# Interleaved sector IDs (standard CPC format 1, 2:1 interleave)
_SECTOR_IDS = [0xC1, 0xC6, 0xC2, 0xC7, 0xC3, 0xC8, 0xC4, 0xC9, 0xC5]

def make_empty_dsk():
    """Create a blank Extended CPC DSK image (CPCemu/iDSK compatible).

    Layout: 256-byte Disk-Info header, then one track record per track
    consisting of a 256-byte Track-Info header followed by nine 512-byte
    sectors filled with 0xE5.
    """
    image = bytearray(256)
    image[:34] = b"EXTENDED CPC DSK File\r\nDisk-Info\r\n"
    image[0x22:0x30] = b"cpc_bas_tok   "        # 14-byte creator field
    image[0x30] = TRACKS
    image[0x31] = SIDES
    image[0x32] = 0
    image[0x33] = 0
    size_hi = TRACK_REC // 256                  # per-track size table entry (0x13)
    for tr in range(TRACKS):
        image[0x34 + tr] = size_hi
    filler = bytes([0xE5]) * TRACK_SZ
    for tr in range(TRACKS):
        info = bytearray(256)
        info[:12] = b"Track-Info\r\n"
        info[0x10] = tr                          # track number
        info[0x11] = 0                           # side number
        info[0x14] = 2                           # sector size code (2 -> 512)
        info[0x15] = SECTORS
        info[0x16] = 0x4E                        # GAP#3 length
        info[0x17] = 0xE5                        # filler byte
        for idx, sid in enumerate(_SECTOR_IDS):
            rec = 0x18 + idx * 8                 # 8-byte sector info record
            info[rec] = tr
            info[rec + 1] = 0
            info[rec + 2] = sid
            info[rec + 3] = 2
            info[rec + 4] = 0                    # FDC status registers
            info[rec + 5] = 0
            struct.pack_into('<H', info, rec + 6, SEC_SZ)  # actual data length
        image += info
        image += filler
    return image

# Interleave map: physical sector slot for each logical sector 0..8
# (logical N has ID 0xC1+N; IDs appear on the track in _SECTOR_IDS order).
_LOG_ORDER = [0, 2, 4, 6, 8, 1, 3, 5, 7]

def log_sec_offset(log_sec):
    """Byte offset inside the DSK image of logical 512-byte sector `log_sec`."""
    track, within = divmod(log_sec, SECTORS)
    phys = _LOG_ORDER[within]
    return 256 + track * TRACK_REC + 256 + phys * SEC_SZ

def write_block(dsk, blk, data_1k):
    """Write one 1K CP/M block (two consecutive logical sectors) into `dsk`.

    `data_1k` may be shorter than 1024 bytes; each sector is zero-padded
    to its full 512 bytes.
    """
    for half in range(2):
        offset = log_sec_offset(blk * 2 + half)
        piece = bytes(data_1k[half * SEC_SZ:(half + 1) * SEC_SZ])
        if len(piece) < SEC_SZ:
            piece += bytes(SEC_SZ - len(piece))
        dsk[offset:offset + SEC_SZ] = piece

def write_file_to_dsk(dsk, filename, data):
    """Add `data` as file `filename` to a CPC data-format DSK image.

    Returns a new bytes object; the input `dsk` is not modified in place.
    Clears the directory area, then writes CP/M directory extents
    (32 bytes each, 16 per 512-byte sector, one extent per 16KB of data)
    and the file's data blocks starting after the directory blocks.
    """
    dsk = bytearray(dsk)
    data = bytearray(data)
    while len(data) % 128:                  # pad to whole 128-byte CP/M records
        data += b'\x00'
    base = os.path.splitext(filename)[0][:8].upper()
    ext = (os.path.splitext(filename)[1][1:4] if os.path.splitext(filename)[1] else 'BAS').upper()
    # Wipe the directory (the first DIR_BLOCKS blocks = logical sectors 0-3).
    for ls in range(DIR_BLOCKS * 2):
        off = log_sec_offset(ls)
        dsk[off:off + SEC_SZ] = bytes([0xE5] * SEC_SZ)
    blk = DIR_BLOCKS                        # first data block after directory
    ext_num = 0
    pos = 0
    dir_idx = 0
    while pos < len(data) and dir_idx < 64:
        entry = bytearray(32)
        entry[0] = 0                        # user number 0
        entry[1:9] = (base + '        ')[:8].encode('ascii')
        entry[9:12] = (ext + '   ')[:3].encode('ascii')
        entry[12] = ext_num & 0x1F          # extent number, low bits
        entry[13] = 0
        entry[14] = (ext_num >> 5) & 0x3F   # extent number, high bits
        rem = len(data) - pos
        recs = min(128, (rem + 127) // 128) # 128-byte records in this extent
        entry[15] = recs
        blks = min(16, (recs * 128 + BLOCK_SZ - 1) // BLOCK_SZ)
        for b in range(blks):
            entry[16 + b] = blk + b         # allocation map (8-bit block numbers)
        # Map the entry index to (sector, slot): logical sectors are not
        # contiguous in the image, so indexing straight past the first
        # sector (as the original did) would land entries >= 16 in the
        # wrong physical sector.
        d_sec, d_slot = divmod(dir_idx, SEC_SZ // 32)
        de_off = log_sec_offset(d_sec) + d_slot * 32
        dsk[de_off:de_off + 32] = entry
        for b in range(blks):
            chunk = bytes(data[pos + b * BLOCK_SZ:pos + (b + 1) * BLOCK_SZ])
            write_block(dsk, blk + b, chunk)
        blk += blks
        pos += blks * BLOCK_SZ
        ext_num += 1
        dir_idx += 1
    return bytes(dsk)

# ── MAIN ──────────────────────────────────────────────────────────────────────

def main():
    """CLI entry point: tokenize a BASIC listing to an AMSDOS file or DSK."""
    parser = argparse.ArgumentParser(description='CPC Locomotive BASIC tokenizer')
    parser.add_argument('input')
    parser.add_argument('output')
    parser.add_argument('--dsk', action='store_true', help='Create DSK image')
    parser.add_argument('--name', default=None, help='8-char filename for DSK/header')
    args = parser.parse_args()

    with open(args.input, 'r', encoding='latin-1') as fh:
        source = fh.read()
    print(f"Tokenizing {args.input}...")
    body = tokenize_basic(source)
    print(f"  Body: {len(body)} bytes")

    # Program name: --name, or the input's basename, truncated to 8 chars.
    prog = (args.name or os.path.splitext(os.path.basename(args.input))[0]).upper()[:8]
    bas_fn = prog + '.BAS'
    amsdos = make_amsdos_header(bas_fn, len(body)) + body
    print(f"  AMSDOS file: {len(amsdos)} bytes")

    if not args.dsk:
        with open(args.output, 'wb') as fh:
            fh.write(amsdos)
        print(f"  Done: {args.output}")
        print(f"\nAdd to DSK:  iDSK disk.dsk -i {args.output} -t 0")
        print(f"List result: iDSK disk.dsk -b {bas_fn}")
    else:
        print(f"Creating DSK: {args.output}")
        image = write_file_to_dsk(make_empty_dsk(), bas_fn, amsdos)
        with open(args.output, 'wb') as fh:
            fh.write(image)
        print(f"  Done: '{bas_fn}' in {args.output} ({len(image)} bytes)")
        print(f"\nOn CPC: RUN \"{prog}\"")


if __name__ == '__main__':
    main()
