#!/usr/bin/env python3
"""
Tiny expression compiler.

Source language:
    let x = 1 + 2 * 3;
    print(x);

Targets:
    - tokens
    - AST
    - stack-machine IR
    - optional execution on a tiny VM
"""

from __future__ import annotations

import argparse
import json
import re
from dataclasses import dataclass
from typing import Any


# Lexical rules as (group name, regex) pairs. Python's regex alternation tries
# alternatives left to right and takes the first that matches, so order matters:
# the keywords come before IDENT (their trailing \b keeps e.g. "letter" a single
# identifier), SKIP swallows whitespace, and the catch-all MISMATCH lets the
# tokenizer report any character no other rule accepts.
TOKEN_SPEC = [
    ("LET", r"let\b"),
    ("PRINT", r"print\b"),
    ("INT", r"\d+"),
    ("IDENT", r"[A-Za-z_][A-Za-z0-9_]*"),
    ("ASSIGN", r"="),
    ("PLUS", r"\+"),
    ("STAR", r"\*"),
    ("LPAREN", r"\("),
    ("RPAREN", r"\)"),
    ("SEMI", r";"),
    ("SKIP", r"[ \t\r\n]+"),
    ("MISMATCH", r"."),
]

# One alternation of named groups; Match.lastgroup then names the rule that fired.
TOKEN_RE = re.compile(
    "|".join("(?P<%s>%s)" % (group, regex) for group, regex in TOKEN_SPEC)
)


@dataclass(init=True, repr=True, eq=True)
class Token:
    """A single lexeme produced by ``tokenize``."""

    kind: str  # a TOKEN_SPEC group name, or "EOF" for the end-of-input marker
    text: str  # the exact source text matched ("" for EOF)
    pos: int  # character offset of the token's first character in the source


def tokenize(source: str) -> list[Token]:
    """Split *source* into tokens and terminate the list with an ``EOF`` token.

    Whitespace (``SKIP``) is dropped. The ``EOF`` token carries an empty text
    and the position ``len(source)``.

    Raises:
        SyntaxError: on the first character that matches only ``MISMATCH``.
    """
    result: list[Token] = []
    for m in TOKEN_RE.finditer(source):
        kind, text, offset = m.lastgroup, m.group(), m.start()
        if kind == "MISMATCH":
            raise SyntaxError(f"unexpected character {text!r} at {offset}")
        if kind != "SKIP":
            result.append(Token(kind, text, offset))
    result.append(Token("EOF", "", len(source)))
    return result


class Parser:
    """Recursive-descent parser producing a dict-based AST.

    Grammar, from lowest to highest precedence::

        program := stmt* EOF
        stmt    := "let" IDENT "=" expr ";"
                 | "print" "(" expr ")" ";"
        expr    := term ("+" term)*        (left-associative -> Add nodes)
        term    := factor ("*" factor)*    (left-associative -> Mul nodes)
        factor  := INT | IDENT | "(" expr ")"

    The token list is expected to end with an ``EOF`` token, as ``tokenize``
    produces; nothing ever consumes ``EOF``, so ``pos`` never runs past it.
    Every syntax problem is reported as ``SyntaxError``.
    """

    def __init__(self, tokens: list[Token]) -> None:
        self.tokens = tokens
        self.pos = 0  # index of the next token not yet consumed

    def peek(self) -> Token:
        """Return the next token without consuming it."""
        return self.tokens[self.pos]

    def consume(self, kind: str) -> Token:
        """Consume and return the next token, which must be of ``kind``.

        Raises:
            SyntaxError: if the next token is of a different kind.
        """
        current = self.peek()
        if current.kind == kind:
            self.pos += 1
            return current
        raise SyntaxError(f"expected {kind}, got {current.kind} at {current.pos}")

    def parse_program(self) -> dict[str, Any]:
        """Parse statements until ``EOF`` and wrap them in a Program node."""
        body: list[dict[str, Any]] = []
        while self.peek().kind != "EOF":
            body.append(self.parse_stmt())
        return {"type": "Program", "statements": body}

    def parse_stmt(self) -> dict[str, Any]:
        """Parse one ``let`` or ``print`` statement, including its ``;``."""
        head = self.peek()
        if head.kind not in ("LET", "PRINT"):
            raise SyntaxError(f"unexpected token {head.kind} at {head.pos}")

        if head.kind == "LET":
            self.consume("LET")
            target = self.consume("IDENT").text
            self.consume("ASSIGN")
            value = self.parse_expr()
            self.consume("SEMI")
            return {"type": "LetStmt", "name": target, "expr": value}

        self.consume("PRINT")
        self.consume("LPAREN")
        value = self.parse_expr()
        self.consume("RPAREN")
        self.consume("SEMI")
        return {"type": "PrintStmt", "expr": value}

    def parse_expr(self) -> dict[str, Any]:
        """Parse a ``+``-chain of terms, folding to the left."""
        left = self.parse_term()
        while self.peek().kind == "PLUS":
            self.consume("PLUS")
            left = {"type": "Add", "left": left, "right": self.parse_term()}
        return left

    def parse_term(self) -> dict[str, Any]:
        """Parse a ``*``-chain of factors, folding to the left."""
        left = self.parse_factor()
        while self.peek().kind == "STAR":
            self.consume("STAR")
            left = {"type": "Mul", "left": left, "right": self.parse_factor()}
        return left

    def parse_factor(self) -> dict[str, Any]:
        """Parse an integer literal, a variable reference, or a parenthesised expr.

        Parentheses only group; they produce no node of their own.
        """
        tok = self.peek()
        kind = tok.kind
        if kind == "LPAREN":
            self.consume("LPAREN")
            inner = self.parse_expr()
            self.consume("RPAREN")
            return inner
        if kind == "INT":
            self.consume("INT")
            return {"type": "Int", "value": int(tok.text)}
        if kind == "IDENT":
            self.consume("IDENT")
            return {"type": "Var", "name": tok.text}
        raise SyntaxError(f"unexpected token {kind} at {tok.pos}")


class Compiler:
    """Lower the dict-based AST into a flat list of stack-machine instructions.

    Each instruction is an ``(opcode, operand)`` tuple; the operand is ``None``
    for ADD, MUL and PRINT. Expressions are emitted in post-order (operands
    before their operator), so each compiled expression pushes exactly one value.
    """

    def __init__(self) -> None:
        # Output buffer; compile_program hands back this same list object.
        self.instructions: list[tuple[str, Any | None]] = []

    def emit(self, op: str, arg: Any | None = None) -> None:
        """Append one instruction to the output buffer."""
        self.instructions.append((op, arg))

    def compile_program(self, program: dict[str, Any]) -> list[tuple[str, Any | None]]:
        """Compile every statement of a Program node, in order.

        NOTE: the buffer lives on the instance, so reusing one Compiler for a
        second program appends to the first program's instructions.
        """
        for statement in program["statements"]:
            self.compile_stmt(statement)
        return self.instructions

    def compile_stmt(self, stmt: dict[str, Any]) -> None:
        """Compile a LetStmt (expr + STORE) or a PrintStmt (expr + PRINT).

        Raises:
            ValueError: for any other statement type.
        """
        kind = stmt["type"]
        if kind not in ("LetStmt", "PrintStmt"):
            raise ValueError(f"unknown statement type: {kind}")
        self.compile_expr(stmt["expr"])
        if kind == "LetStmt":
            self.emit("STORE", stmt["name"])
        else:
            self.emit("PRINT")

    def compile_expr(self, expr: dict[str, Any]) -> None:
        """Compile an Int, Var, Add or Mul node.

        Raises:
            ValueError: for any other expression type.
        """
        kind = expr["type"]
        if kind == "Int":
            self.emit("PUSH_CONST", expr["value"])
        elif kind == "Var":
            self.emit("LOAD", expr["name"])
        elif kind in ("Add", "Mul"):
            self.compile_expr(expr["left"])
            self.compile_expr(expr["right"])
            self.emit("ADD" if kind == "Add" else "MUL")
        else:
            raise ValueError(f"unknown expression type: {kind}")


class VM:
    """Interpret ``(opcode, operand)`` instructions on an integer operand stack.

    Attributes:
        instructions: the program to execute.
        stack: the operand stack.
        env: variable name -> value; written by STORE, read by LOAD.
        output: the string form of every PRINTed value, in order.
    """

    def __init__(self, instructions: list[tuple[str, Any | None]]) -> None:
        self.instructions = instructions
        self.stack: list[int] = []
        self.env: dict[str, int] = {}
        self.output: list[str] = []

    def _pop(self, index: int, op: str) -> int:
        """Pop the top operand, naming the offending instruction on underflow.

        IR produced by ``Compiler`` always pushes operands before consuming
        them, so underflow indicates hand-written or corrupted IR.

        Raises:
            IndexError: if the stack is empty. This is the same exception type
                a bare ``list.pop`` raises, so existing handlers still match.
        """
        if not self.stack:
            # Index formatted like the IR listing so it can be matched up.
            raise IndexError(f"stack underflow at instruction {index:02d}: {op}")
        return self.stack.pop()

    def run(self) -> list[str]:
        """Execute every instruction in order and return the collected output.

        State (``stack``, ``env``, ``output``) lives on the instance. Calling
        ``run`` again continues from that state and appends to ``output``.

        Raises:
            NameError: LOAD of a variable that was never STOREd.
            IndexError: an instruction needs more operands than the stack holds.
            ValueError: unknown opcode.
        """
        for index, (op, arg) in enumerate(self.instructions):
            if op == "PUSH_CONST":
                self.stack.append(int(arg))
            elif op == "LOAD":
                # Normalise the key exactly as STORE does. Previously the
                # membership test used the raw operand while the store/read
                # used str(arg), so non-str names could never be loaded back.
                name = str(arg)
                if name not in self.env:
                    raise NameError(f"undefined variable: {arg}")
                self.stack.append(self.env[name])
            elif op == "STORE":
                self.env[str(arg)] = self._pop(index, op)
            elif op == "ADD":
                rhs = self._pop(index, op)
                lhs = self._pop(index, op)
                self.stack.append(lhs + rhs)
            elif op == "MUL":
                rhs = self._pop(index, op)
                lhs = self._pop(index, op)
                self.stack.append(lhs * rhs)
            elif op == "PRINT":
                self.output.append(str(self._pop(index, op)))
            else:
                raise ValueError(f"unknown instruction: {op}")
        return self.output


def ast_to_pretty_json(ast: dict[str, Any]) -> str:
    """Render *ast* as two-space-indented JSON, leaving non-ASCII text unescaped."""
    encoder = json.JSONEncoder(ensure_ascii=False, indent=2)
    return encoder.encode(ast)


def instructions_to_text(instructions: list[tuple[str, Any | None]]) -> str:
    """Format IR as one ``NN: OP [ARG]`` line per instruction.

    The index is zero-padded to at least two digits. The operand is omitted
    only when it is ``None``, so falsy operands such as ``0`` are still shown.
    """

    def render(position: int, opcode: str, operand: Any | None) -> str:
        head = f"{position:02d}: {opcode}"
        return head if operand is None else f"{head} {operand}"

    return "\n".join(
        render(position, opcode, operand)
        for position, (opcode, operand) in enumerate(instructions)
    )


def main() -> None:
    """Command-line entry point: compile a program and print the requested stages.

    The program text comes from ``--code`` (which wins when both are given) or
    from the positional source file. Each ``--emit-*`` flag prints its stage as
    soon as that stage succeeds. ``--emit-tokens`` therefore still shows the
    token stream when the program later fails to parse. ``--run`` executes the
    IR and prints each line the program printed. With no stage flags at all,
    the IR listing is printed.

    Unreadable input and compile/runtime errors are reported as a one-line
    message with a non-zero exit status (via ``SystemExit``) instead of a
    traceback.
    """
    # Named arg_parser to avoid confusion with the language's Parser class.
    arg_parser = argparse.ArgumentParser(description="Tiny expression compiler")
    arg_parser.add_argument("source", nargs="?", help="source file path")
    arg_parser.add_argument("--code", help="source code inline")
    arg_parser.add_argument("--emit-tokens", action="store_true")
    arg_parser.add_argument("--emit-ast", action="store_true")
    arg_parser.add_argument("--emit-ir", action="store_true")
    arg_parser.add_argument("--run", action="store_true")
    args = arg_parser.parse_args()

    if args.code is not None:
        source = args.code
    elif args.source:
        try:
            with open(args.source, "r", encoding="utf-8") as fh:
                source = fh.read()
        except (OSError, UnicodeDecodeError) as err:
            raise SystemExit(f"cannot read {args.source}: {err}") from err
    else:
        raise SystemExit("pass a source file or --code")

    # Decided up front: with no stage flag, the IR listing is the default output.
    show_ir = args.emit_ir or not any(
        [args.emit_tokens, args.emit_ast, args.emit_ir, args.run]
    )

    try:
        tokens = tokenize(source)
        if args.emit_tokens:
            print([(t.kind, t.text) for t in tokens if t.kind != "EOF"])

        ast = Parser(tokens).parse_program()
        if args.emit_ast:
            print(ast_to_pretty_json(ast))

        instructions = Compiler().compile_program(ast)
        if show_ir:
            print(instructions_to_text(instructions))

        if args.run:
            # One print per line, so a program that prints nothing produces
            # no output (the old "\n".join printed a stray blank line).
            for line in VM(instructions).run():
                print(line)
    except RecursionError as err:
        # Each nested "(" recurses parse_expr -> parse_term -> parse_factor.
        raise SystemExit("error: expression nested too deeply") from err
    except (SyntaxError, NameError, ValueError) as err:
        # SyntaxError: lexer/parser; NameError: VM undefined variable;
        # ValueError: compiler/VM unknown node or opcode, or int() conversion.
        raise SystemExit(f"error: {err}") from err


if __name__ == "__main__":
    # Start the CLI only when this file is executed as a script; importing it
    # as a module (to reuse tokenize/Parser/Compiler/VM) does not run main().
    main()
