From 56b2b36ad3212fb0c6546af86631f8cdab14f769 Mon Sep 17 00:00:00 2001 From: ts4f <17165475+ts4f@users.noreply.github.com> Date: Fri, 8 May 2026 16:09:57 +0200 Subject: [PATCH] claude refactor --- .gitignore | 26 + CLAUDE.md | 55 ++ pyproject.toml | 5 + starlet.py | 1494 +------------------------------------ starlet/__init__.py | 15 + starlet/__main__.py | 55 ++ starlet/c_backend.py | 35 + starlet/ir.py | 64 ++ starlet/lexer.py | 136 ++++ starlet/mips_backend.py | 270 +++++++ starlet/parser.py | 539 +++++++++++++ starlet/symtable.py | 210 ++++++ tests/__init__.py | 0 tests/test_integration.py | 77 ++ tests/test_lexer.py | 130 ++++ uv.lock | 8 + 16 files changed, 1627 insertions(+), 1492 deletions(-) create mode 100644 .gitignore create mode 100644 CLAUDE.md create mode 100644 pyproject.toml create mode 100644 starlet/__init__.py create mode 100644 starlet/__main__.py create mode 100644 starlet/c_backend.py create mode 100644 starlet/ir.py create mode 100644 starlet/lexer.py create mode 100644 starlet/mips_backend.py create mode 100644 starlet/parser.py create mode 100644 starlet/symtable.py create mode 100644 tests/__init__.py create mode 100644 tests/test_integration.py create mode 100644 tests/test_lexer.py create mode 100644 uv.lock diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ea40a57 --- /dev/null +++ b/.gitignore @@ -0,0 +1,26 @@ +# IDE +.idea/ +.vscode/ +*.swp +*.swo + +# Claude Code +.claude/ +memory/ + +# Generated compiler output +*.int +*.asm +*.c + +# Python +__pycache__/ +*.py[cod] +*.egg-info/ +dist/ +build/ + +# OS +.DS_Store +Thumbs.db + diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..4215a69 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,55 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. 
+ +## Running the Compiler + +```bash +python3 starlet.py +``` + +This produces three output files alongside the input: +- `.int` — intermediate quad representation +- `.c` — C backend (goto-based) +- `.asm` — MIPS32 assembly backend + +Example files are in `Examples/` (`.stl` extension). + +## Architecture + +The entire compiler lives in `starlet.py` as a single file with heavy use of global state. The four stages flow linearly: lex → parse → IR → code gen. + +### Lexer (`lex()`) +FSM tokenizer. Reads `data` (the open source file) one character at a time. States 1–9 handle identifiers, integers, multi-char operators (`<=`, `:=`, etc.), and two comment styles (`//` line, `/* */` block). Returns the next token string; the parser calls `lex()` to advance and stores the result in the global `token`. + +### Parser +Recursive descent. `program()` is the entry point. All grammar rules are functions (`block`, `statements`, `statement`, `expression`, `term`, `factor`, etc.). The parser both validates syntax and drives IR generation in the same pass — no separate AST is built. + +### Intermediate Representation (IR) +Quads stored in `quadDict: {int → [op, x, y, z]}`, numbered by `nextLabel`. Key functions: +- `gen_quad(op, x, y, z)` — emits a quad +- `make_list(label)`, `merge(l1, l2)`, `backpatch(labellist, target)` — implement boolean short-circuit and control flow via forward-reference patching; jump destinations are filled in after the target quad is known + +Quad operations: `:=`, `+`, `-`, `*`, `/`, comparison ops (`=`, `<>`, `<`, `<=`, `>`, `>=`), `jump`, `par`/`call`/`retv`, `out`, `inp`, `halt`, `begin_block`, `end_block`. + +### Symbol Table +`scopes_list` is a global stack of `Scope` objects. Each `Scope` tracks its `nesting_level`, a reference to its `enclosing_scope`, and an `entities` list. Stack pointer offsets begin at 12 and grow by 4 per entity (`Scope.get_sp()`). 
+ +Entity subclasses: `Variable` (VAR), `Function` (FUNC), `Parameter` (PAR), `TempVariable` (TMPVAR). Functions store their `arguments` list (parameter modes), `start_quad`, and `framelength`. + +Key lookup: `testing(name)` walks the static chain via `enclosing_scope` links. `search_entity(name, type)` scans `scopes_list` front-to-back. + +### MIPS Backend (`write_to_asm`) +Called once per quad during `block()` (not after the full parse). MIPS frame layout: +- `$sp` → saved `$ra` +- `-4($sp)` → static link (access link to enclosing frame) +- `-8($sp)` → return value address +- `-12($sp)` onward → parameters and locals + +`$s0` holds the base of the global (level-0) activation record. `gnvlcode(v)` walks the static chain through `-4($t0)` links to compute the address of a non-local variable. `loadvr(v, r)` / `storerv(r, v)` dispatch on entity type and nesting level to emit the correct load/store. + +### Parameter Passing +Three modes throughout the compiler: +- `in` / `CV` — call by value +- `inout` / `REF` — call by reference (address passed) +- `inandout` / `RET` — call by return (caller passes address of result variable) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..ec40f47 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,5 @@ +[project] +name = "compiler" +version = "0.1.0" +requires-python = ">=3.13" +dependencies = [] diff --git a/starlet.py b/starlet.py index 12db84b..6038ce2 100644 --- a/starlet.py +++ b/starlet.py @@ -1,1495 +1,5 @@ #!/usr/bin/env python3 +from starlet.__main__ import main -import sys - -new_exit_list = None # Used for loop/exit function -nextLabel = 0 # Label counter -quadDict = {} # A dict with key: Line number, and value: a quad(list) -tCounter = 1 # Token counter -lineno = 1 # Current line number of input file. 
-token_captivated = ['program', 'endprogram', 'declare', 'if', 'then', 'else', - 'endif', 'while', 'endwhile', 'dowhile', 'enddowhile', 'loop', 'endloop', 'exit', - 'forcase', 'endforcase', 'incase', 'endincase', 'when', 'default', 'enddefault', - 'function', 'endfunction', 'return', 'in', 'inout', 'inandout', 'and', 'or', 'not', - 'input', 'print'] -scopes_list = list() # A list of scopes -loop_enabled = False -func_enabled = False -ret_enabled = False -asm_file = None -parlist = list() -lmain_flag = True - -######################################################### -# LEX # -######################################################### - - -# State: 1-> words 3-> '<' 5-> ':' -# 2-> digits 4-> '>' 6-> '/' -def lex(): - buffer = [] - state = 0 # Initial FSM state (starting point) - ok = -2 # Final State - getback = False # True if we need to reposition file pointer - - # Lexical analyzer's FSM implementation - while state != ok: - char = data.read(1) # Reading one character at a time - buffer.append(char) - if state == 0: - if char.isalpha(): - state = 1 - elif char.isdigit(): - state = 2 - elif char == '<': - state = 3 - elif char == '>': - state = 4 - elif char == ':': - state = 5 - elif char == '/': - state = 6 - elif char in ('+', '-', '*', '=', ',', ';', '(', ')', '[', ']'): - state = ok - elif char == '': - state = ok - elif char.isspace(): - state = 0 - else: - print("(lex) invalid character: " + char) - sys.exit() - - elif state == 1: - if not char.isalnum(): - getback = True - state = ok - elif state == 2: - if not char.isdigit(): - if char.isalpha(): - error(" (lex) not valid integer: ") - getback = True - state = ok - elif state == 3: - if char != '=' and char != '>': - getback = True - state = ok - elif state == 4: - if char != '=': - getback = True - state = ok - elif state == 5: - if char != '=': - getback = True - state = ok - elif state == 6: - if char == '/': - state = 7 - elif char == '*': - state = 8 - else: - getback = True - state = ok - elif state 
== 7: - if char == '\n': - del buffer[:] - state = 0 - - elif state == 8: - - if char == '*': - state = 9 - elif char == '': - print("No closing comment found") - sys.exit() - - elif state == 9: - - if char == '/': - del buffer[:] - state = 0 - else: - state = 8 - - # Ignoring spaces and counting lines - if char.isspace(): - if char == '\n': - global lineno - lineno += 1 - if len(buffer) != 0: - del buffer[-1] - getback = False - - # Repositioning file pointer - if getback: - data.seek(data.tell() - 1) - del buffer[-1] - - ret = ''.join(buffer) - - # Checking if digit is out of bounds - if ret.isdigit(): - if abs(int(ret)) > 32767: - error("(lex) digit is out of bounds! ") - - # Emptying the buffer - del buffer[:] - - # Returning the first 30 characters - return ret[:30] - - -################################################################## -# SYNTAX # -################################################################## - -def program(): - global token, program_name, scopes_list - token = lex() - - if token == 'program': - token = lex() - if is_valid_id(token): - program_name = name = token - token = lex() - # Creating and adding Scope to scopes_list(default nesting_level = 0) - scopes_list.append(Scope()) - - block(name) - - if token != 'endprogram': - error("expected endprogram, found: ") - - else: - error(" (program) expected an 'id' found: ") - else: - error("(program) expected 'program' found: ") - - -def block(name): - declarations() - subprograms() - - # Setting start_quad(LineNo of .int where the current function starts) for each Function(Entity) - start_quad = next_quad() - if name != program_name: - en = search_entity(name, 'FUNC')[0] - en.set_start_quad(start_quad) - - gen_quad('begin_block', name) - statements() - - if program_name == name: - gen_quad('halt') - else: - # Update framelength - f_entity = search_entity(name, 'FUNC')[0] - f_entity.framelength = scopes_list[len(scopes_list) - 1].get_sp() - - gen_quad('end_block', name) - - # Printing Scopes 
and entities before deleting it - print(scopes_list[-1]) - for en in scopes_list[-1].entities: - print(en) - print("--------------------------------------------") - - for i in range(start_quad, nextLabel): - write_to_asm(quadDict[i], name, i) - - # Block is done, deleting Scope - del scopes_list[-1] - - -def declarations(): - global token - while token == 'declare': - token = lex() - varlist() - if token != ';': - error(" (declarations) expected ';' found: ") - token = lex() - - -def varlist(): - global token - - if is_valid_id(token): - - # Adding (declarations)token as a Variable entity - add_var_entity(token) - - token = lex() - while token == ',': - token = lex() - if not is_valid_id(token): - error(" (varlist) expected 'id', found: ") - - # Adding token as a Variable entity - add_var_entity(token) - token = lex() - elif token != ';': - error(" (varlist) not valid id: ") - - # Storing the declarations of the current scope to the previous (Enclosing scope) - # if len(scopes_list) >= 2: - # scopes_list[-2].set_enclosing_scope(scopes_list[-1].entities) - - -def subprograms(): - global token, func_enabled - - while token == 'function': - func_enabled = True - token = lex() - subprogram() - - -def subprogram(): - global token, func_enabled, ret_enabled - - if is_valid_id(token): - name = token - # Creating and adding new scope (nesting level depends on the length of our scopes_list) - new_scope = Scope(scopes_list.__len__(), scopes_list[-1]) - - scopes_list.append(new_scope) - - token = lex() - - # Adding Func entity - scopes_list[-2].add_entity(Function(name, 0)) - - funcbody(name) - if token != 'endfunction': - error(" (subprogram) expected endfuction, found: ") - - if not ret_enabled: - error(" No return in function/") - func_enabled = False - ret_enabled = False - token = lex() - else: - error("(subprogram) expected 'id', found: ") - - -def funcbody(name): - formalpars(name) - block(name) - - -def formalpars(name): - global token - - if token == '(': - token = 
lex() - - formalparlist(name) - if token != ')': - error(" (formalpars) expected ')', found: ") - else: - error(" (formalpars) expected '(', found: ") - - token = lex() - - -def formalparlist(name): - global token - if token != ')': - formalparitem(name) - while token == ',': - token = lex() - formalparitem(name) - - -def formalparitem(name): - global token - - if token in ('in', 'inout', 'inandout'): - par_mode = token - token = lex() - if is_valid_id(token): - # Add Argument(in/inout/inandout) to func - add_arg_to_func(par_mode, name) - - # Add Parameter(Entity) to Scope - add_param_entity(token, par_mode) - - token = lex() - else: - error("(formalparitem) expected id, found: ") - else: - error("expected 'in'/'inout'/'inandout', found: ") - - -def statements(): - global token - statement() - - while token == ';': - token = lex() - statement() - - -def statement(): - global token, loop_enabled, ret_enabled - - if token == 'if': - token = lex() - if_stat() - - elif token == 'while': - token = lex() - while_stat() - - elif token == 'dowhile': - token = lex() - do_while_stat() - - elif token == 'loop': - loop_enabled = True - token = lex() - loop_stat() - elif token == 'exit': - token = lex() - if loop_enabled: - exit_stat() - else: - error(" 'exit' must be declared inside a loop ") - elif token == 'forcase': - token = lex() - forcase_stat() - elif token == 'incase': - token = lex() - incase_stat() - elif token == 'return': - ret_enabled = True - token = lex() - exp = expression() - - gen_quad('retv', exp, '_', '_') - elif token == 'print': - token = lex() - exp = expression() - gen_quad('out', exp) - elif token == 'input': - token = lex() - id_place = input_stat() - gen_quad('inp', id_place) - elif is_valid_id(token): - assignment_stat() - - -def assignment_stat(): - global token - - if token.isalnum(): - t1 = token - - if not exists(t1): - error(' Not declared: ') - - token = lex() - - if token == ':=': - op = token - token = lex() - - if not exists(token) and not 
token.isdigit(): - error(' Not declared: ') - - exp = expression() - - gen_quad(op, exp, '_', t1) - - else: - error("(assignment_stat) expected ':=' found ") - - -def if_stat(): - global token - - if token == '(': - token = lex() - else: - error("(if_stat) expected '(' found: ") - - (b_true, b_false) = condition() - - if token == ')': - token = lex() - - backpatch(b_true, next_quad()) - else: - error("(if_stat) expected ')' found: ") - - if token == 'then': - - token = lex() - statements() - skip = make_list(next_quad()) - gen_quad('jump') - backpatch(b_false, next_quad()) - elsepart() - backpatch(skip, next_quad()) - - if token != 'endif': - error("(if_stat) expected 'endif' found ") - - token = lex() - - else: - error("(if_stat) expected 'then' found ") - - -def elsepart(): - global token - - if token == 'else': - token = lex() - statements() - - -def while_stat(): - global token - quad = next_quad() - if token == '(': - token = lex() - else: - error("(while_stat) expected '(' found: ") - - (b_true, b_false) = condition() - - if token == ')': - token = lex() - else: - error("(while_stat) expected ')' found: ") - backpatch(b_true, next_quad()) - statements() - gen_quad('jump', '_', '_', str(quad)) - backpatch(b_false, next_quad()) - - if token != 'endwhile': - error(" (while_stat) expected 'endwhile', found: ") - token = lex() - - -def do_while_stat(): - global token - quad = next_quad() - statements() - - if token == 'enddowhile': - token = lex() - - if token == '(': - token = lex() - else: - error("(do_while_stat) expected '(' found: ") - - (b_true, b_false) = condition() - - if token == ')': - backpatch(b_true, quad) - n_quad = next_quad() - backpatch(b_false, n_quad) - token = lex() - else: - error("(do_while_stat) expected ')' found: ") - else: - error(" (do_while_stat) expected 'enddowhile', found: ") - - -def loop_stat(): - global token, new_exit_list, loop_enabled - quad = next_quad() - statements() - - gen_quad('jump', '_', '_', str(quad)) - if token != 
'endloop': - error(" (loop_stat) expected 'endloop', found: ") - - loop_enabled = False - token = lex() - - if new_exit_list is not None: - backpatch(new_exit_list, next_quad()) - new_exit_list = None - - -def exit_stat(): - global token, new_exit_list - new_exit_list = make_list(next_quad()) - gen_quad('jump') - - -def forcase_stat(): - global token - flag_quad = next_quad() - - exit_list = empty_list() - - while token == 'when': - token = lex() - if token == '(': - token = lex() - else: - error("(forcase_stat) expected '(' found: ") - - (b_true, b_false) = condition() - - backpatch(b_true, next_quad()) - - if token == ')': - token = lex() - else: - error("(forcase_stat) expected ')' found: ") - - if token != ':': - error("(forcase_stat) expected ':', found: ") - token = lex() - statements() - t = make_list(next_quad()) - gen_quad('jump') - exit_list = merge(exit_list, t) - backpatch(b_false, next_quad()) - - if token != 'default': - error("(forcase_stat) expecred 'default',found: ") - token = lex() - if token != ':': - error("(forcase_stat) expected ':', found: ") - token = lex() - statements() - gen_quad('jump', '_', '_', str(flag_quad)) - backpatch(exit_list, next_quad()) - if token != 'enddefault': - error("(forcase_stat) expected 'enddefault',found: ") - token = lex() - - if token != 'endforcase': - error("(forcase_stat) expected 'endforcase',found: ") - token = lex() - - -def incase_stat(): - global token - t = new_temp() - flag_quad = next_quad() - gen_quad(':=', '0', '_', t) - - while token == 'when': - token = lex() - if token == '(': - token = lex() - else: - error("(incase_stat) expected '(' found: ") - - (b_true, b_false) = condition() - backpatch(b_true, next_quad()) - gen_quad(':=', '1', '_', t) - - if token == ')': - token = lex() - else: - error("(incase_stat) expected ')' found: ") - - if token != ':': - error("(incase_stat) expected ':', found: ") - token = lex() - statements() - backpatch(b_false, next_quad()) - - if token != 'endincase': - 
error("(incase_stat) expected 'endincase',found: ") - token = lex() - gen_quad('=', '1', t, str(flag_quad)) - - -def input_stat(): - global token - - if not is_valid_id(token): - error(" (input_stat) expected an 'id', found: ") - ret = token - token = lex() - return ret - - -def actualpars(): - global token - - if token == '(': - token = lex() - actualparlist() - - if token != ')': - error(" (actualpars) expected ')', found: ") - - token = lex() - return True - - -def actualparlist(): - global token - - actualparitem() - - while token == ',': - token = lex() - actualparitem() - - -def actualparitem(): - global token - - if token == 'in': - token = lex() - if not exists(token): - error(' Not declared: ') - exp = expression() - gen_quad('par', exp, 'CV') - elif token == 'inout': - token = lex() - if not exists(token): - error(' Not declared: ') - t1 = token - if not is_valid_id(token): - error("(actualparitem) expected 'id' found ") - token = lex() - gen_quad('par', t1, 'REF') - elif token == 'inandout': - token = lex() - if not exists(token): - error(' Not declared: ') - t2 = token - if not is_valid_id(token): - error("(actualparitem) expected 'id' found ") - token = lex() - gen_quad('par', t2, 'RET') - - -def condition(): - global token - - (b_true, b_false) = boolterm() - - while token == 'or': - backpatch(b_false, next_quad()) - token = lex() - - (c_true, c_false) = boolterm() - b_true = merge(b_true, c_true) - b_false = c_false - return b_true, b_false - - -def boolterm(): - global token - - (b_true, b_false) = boolfactor() - - while token == 'and': - backpatch(b_true, next_quad()) - token = lex() - (c_true, c_false) = boolfactor() - b_false = merge(b_false, c_false) - b_true = c_true - return b_true, b_false - - -def boolfactor(): - global token - - if token == 'not': - token = lex() - if token == '[': - token = lex() - - ret = condition() - - if token != ']': - error(" (boolfactor) expected ']', found: ") - - token = lex() - else: - error(" (boolfactor) 
expected '[', found: ") - - elif token == '[': - token = lex() - - ret = condition() - - if token != ']': - error(" (boolfactor) expected ']', found: ") - - token = lex() - - else: - - exp1 = expression() - op = relational_oper() - exp2 = expression() - - b_true = make_list(next_quad()) - gen_quad(op, exp1, exp2) - b_false = make_list(next_quad()) - gen_quad('jump') - ret = (b_true, b_false) - return ret - - -def expression(): - global token - - op = optional_sign() - t1 = term() - - while add_oper(): - op = token - token = lex() - - if not exists(token) and not token.isdigit(): - error(' Not declared: ') - - t2 = term() - - tmp = new_temp() - gen_quad(op, t1, t2, tmp) - t1 = tmp - return t1 - - -def term(): - global token - - f1 = factor() - while mul_oper(): - op = token - token = lex() - f2 = factor() - - tmp = new_temp() - gen_quad(op, f1, f2, tmp) - f1 = tmp - return f1 - - -def factor(): - global token - - if token.isdigit(): - ret = token - token = lex() - elif token == '(': - token = lex() - ret = expression() - if token != ')': - error("(factor) expected ')' found ") - token = lex() - - elif is_valid_id(token): - ret = token - token = lex() - tail = idtail() - - if tail is not None: - new = new_temp() - gen_quad('par', new, 'RET') - gen_quad('call', ret) - ret = new - else: - error("(factor) expected something found ") - - return ret - - -def idtail(): - global token - - if token == '(': - return actualpars() - - -def optional_sign(): - global token - - if add_oper(): - token = lex() - return token - - -def relational_oper(): - global token - - if token not in ('=', '<=', '>=', '>', '<', '<>'): - error(" (relational_oper) expected a relation sign ") - ret = token - token = lex() - return ret - - -def add_oper(): - global token - if token == '+' or token == '-': - # token = lex() tin vgalame kai tin kaloume meta tin add_oper() - return True - return False - - -def mul_oper(): - global token - - if token == '*' or token == '/': - # token = lex() gia ton idio 
logo me ad_oper() - return True - return False - - -################################################################# -# # -# INT # -# # -################################################################# - - -def next_quad(): - return nextLabel - - -def gen_quad(op=None, x='_', y='_', z='_'): - global nextLabel - currentlabel = nextLabel - nextLabel += 1 - quad = [op, x, y, z] - quadDict[currentlabel] = quad - - -def new_temp(): - global tCounter - - temp = "T_" + str(tCounter) - tCounter += 1 - - # Create/Add new TempVariable to the current list of entities - scopes_list[-1].add_entity(TempVariable(temp, scopes_list[-1].get_sp())) - - return temp - - -def empty_list(): - return list() - - -def make_list(x): - newlist = list() - newlist.append(x) - return newlist - - -def merge(list1, list2): - return list1 + list2 - - -def backpatch(labellist, z): - global quadDict - - for key in labellist: - if key in quadDict: - quadDict[key].pop() - quadDict[key].append(z) - - -################################################################## -# # -# SYMBOL TABLE # -# # -################################################################## - -class Entity: - - def __init__(self, name, entity_type): - self.name = name - self.entity_type = entity_type - self.next = None - - def __str__(self): - return self.name + ': ' + self.entity_type - - -# Variable inherits from Entity -class Variable(Entity): - - def __init__(self, name, offset=0): - # Calling 'fathers' constructor and setting offset - Entity.__init__(self, name, 'VAR') - self.offset = offset - - def __str__(self): - return Entity.__str__(self) + '\toffset: ' + str(self.offset) - - -# Function inherits from Entity -class Function(Entity): - - def __init__(self, name, ret_val, start_quad=-1): - Entity.__init__(self, name, 'FUNC') - self.ret_val = ret_val - self.start_quad = start_quad - self.arguments = list() - self.framelength = -1 - - def set_framelength(self, x): - self.framelength = x - - def set_start_quad(self, x): - 
self.start_quad = x - - def set_ret_val(self, x): - self.ret_val = x - - def __str__(self): - return Entity.__str__(self) + \ - ',\tStart_quad: ' + self.start_quad.__str__() + \ - ',\tFramelength:' + str(self.framelength) + \ - ',\tArgs:' + self.arguments.__str__() - - -# Parameter inherits from Entity -class Parameter(Entity): - - def __init__(self, name, par_mode, offset=0): - Entity.__init__(self, name, 'PAR') - if par_mode == 'in': - self.par_mode = 'cv' - elif par_mode == 'inout': - self.par_mode = 'ref' - else: - self.par_mode = 'ret' - self.offset = offset - - def __str__(self): - return Entity.__str__(self) + ',\tpar_mode: ' + self.par_mode + ',\toffset: ' + str(self.offset) - - -# TempVariable inherits from Entity -class TempVariable(Entity): - - def __init__(self, name, offset=0): - Entity.__init__(self, name, 'TMPVAR') - self.offset = offset - - def __str__(self): - return Entity.__str__(self) + '\toffset: ' + str(self.offset) - - -class Scope: - - def __init__(self, nesting_level=0, enclosing_scope=None): - self.entities = list() - self.nesting_level = nesting_level - self.enclosing_scope = enclosing_scope - self.sp = 12 - - def get_sp(self): - ret = self.sp - self.sp += 4 - return ret - - # Adding entity to list - def add_entity(self, ent): - self.entities.append(ent) - - def set_enclosing_scope(self, x): - self.enclosing_scope = x - - def __str__(self): - return self.__repr__() + \ - '\nNesting lvl: ' + self.nesting_level.__repr__() + \ - '\nEnclosing Scope: ' + self.enclosing_scope.__repr__() - - -class Argument: - - # Initializing an Argument --> par_mode:(ret/cv/ref), pointer next_argument, default int type - def __init__(self, par_mode, next_argument): - self.par_mode = par_mode - self.type = 'Int' - self.next_argument = next_argument - - -# Adding arguments to function -def add_arg_to_func(par_mode, f_name): - func_en = search_entity(f_name, 'FUNC')[0] - if func_en is None: - error(' No definition: ') - - func_en.arguments.append(par_mode) - - -# 
Adding A Var(Entity) to the current scope and -# checking if the given Var already exists in the current nesting level(as a Parameter etc.) -def add_var_entity(var_name): - var_lvl = scopes_list[-1].nesting_level - var_off = scopes_list[-1].get_sp() - - if not unique(var_name, 'VAR', var_lvl): - error(' var not declared') - - if exists_as_param(var_name, var_lvl): - error(' Symbol is already declared as a parameter: ') - - scopes_list[-1].add_entity(Variable(var_name, var_off)) - - -# Adding a Parameter(Entity) to the current scope -def add_param_entity(par_name, par_mode): - par_lvl = scopes_list[-1].nesting_level - par_off = scopes_list[-1].get_sp() - - if not unique(par_name, 'PAR', par_lvl): - error(' (unique)') - - scopes_list[-1].add_entity(Parameter(par_name, par_mode, par_off)) - - -# Check if var entity already exists as a parameter -def exists_as_param(name, level): - for entity in scopes_list[level].entities: - if entity.entity_type == 'PAR' and entity.name == name: - return True - return False - - -def unique(name, entity_type, nesting_level): - for i in range(len(scopes_list[nesting_level].entities)): - for j in range(len(scopes_list[nesting_level].entities)): - x = scopes_list[nesting_level].entities[i] - y = scopes_list[nesting_level].entities[j] - - if x.name == y.name and x.entity_type == y.entity_type and x.name == name and x.entity_type == entity_type: - return False - return True - - -def exists(name): - for scope in scopes_list: - for entity in scope.entities: - if name == entity.name: - return True - - return False - - -# Searching(in scope_list) a specific entity by its given name and type -# and returning the first one we find -def search_entity(name, entity_type): - for scope in scopes_list: - for entity in scope.entities: - if entity.entity_type == entity_type and entity.name == name: - return entity, scope.nesting_level - - -# search entity by name anapoda -def testing(name): - global scopes_list - - scope = scopes_list[-1] - while scope 
is not None: - for entity in scope.entities: - if entity.name == name: - return entity, scope.nesting_level - scope = scope.enclosing_scope - - return None - -################################################################## -# # -# FINAL CODE # -# # -################################################################## - - -def gnvlcode(v): - global asm_file, scopes_list - - en, lvl = testing(v) - - if en is None: - print('Undeclared variable ' + v) - sys.exit() - - if en.entity_type == 'FUNC': - print('Undeclared variable ' + v) - sys.exit() - - current_lvl = scopes_list[-1].nesting_level - diff_of_lvl = current_lvl - lvl - - asm_file.write('\tlw $t0, -4($sp)\n') - - while diff_of_lvl > 1: - asm_file.write('\tlw $t0, -4($t0)\n') - diff_of_lvl -= 1 - - asm_file.write('\tadd $t0, $t0, - %d\n' % en.offset) - - -# Data transfer (v) to register #tr -def loadvr(v, r): - global asm_file, scopes_list - - if str(v).isdigit(): - asm_file.write('\tli $t%s, %s\n' % (r, v)) - else: - en, lvl = testing(v) - - if en is None: - print('Undeclared variable ' + v) - sys.exit() - - current_lvl = scopes_list[-1].nesting_level - - if en.entity_type == 'VAR' and lvl == 0: - asm_file.write('\tlw $t%s, -%d($s0)\n' % (r, en.offset)) - elif (en.entity_type == 'VAR' and lvl == current_lvl) or \ - (en.entity_type == 'PAR' and lvl == current_lvl and en.par_mode == 'cv') or \ - (en.entity_type == 'TMPVAR'): - - asm_file.write('\tlw $t%s, -%d($sp)\n' % (r, en.offset)) - elif en.entity_type == 'PAR' and lvl == current_lvl and en.par_mode == 'ref': - asm_file.write('\tlw $t0, -%d($sp)\n' % en.offset) - asm_file.write('\tlw $t%s, ($t0)\n' % r) - elif (en.entity_type == 'VAR' and lvl < current_lvl) or \ - (en.entity_type == 'PAR' and lvl < current_lvl and en.par_mode == 'cv'): - - gnvlcode(v) - asm_file.write('\tlw $t%s, ($t0)\n' % r) - elif en.entity_type == 'PAR' and lvl < current_lvl and en.par_mode == 'ref': - - gnvlcode(v) - asm_file.write('\tlw $t0, ($t0)\n') - asm_file.write('\tlw $t%s, 
($t0)\n' % r) - else: - - print("ERROR: (loadvr) couldn't transfer data to register...: " + v) - sys.exit() - - -# Transfer data from register ($tr) to memory (variable v) -def storerv(r, v): - global asm_file, scopes_list - - en, lvl = testing(v) - - if en is None: - print('Undeclared variable ' + v) - sys.exit() - - current_lvl = scopes_list[-1].nesting_level - - if en.entity_type == 'VAR' and lvl == 0: - asm_file.write('\tsw $t%s, -%d($s0)\n' % (r, en.offset)) - elif (en.entity_type == 'VAR' and lvl == current_lvl) or \ - (en.entity_type == 'PAR' and lvl == current_lvl and en.par_mode == 'cv') or \ - (en.entity_type == 'TMPVAR'): - - asm_file.write('\tsw $t%s, -%d($sp)\n' % (r, en.offset)) - elif en.entity_type == 'PAR' and lvl == current_lvl and en.par_mode == 'ref': - asm_file.write('\tlw $t0, -%d($sp)\n' % en.offset) - asm_file.write('\tsw $t%s, ($t0)\n' % r) - elif (en.entity_type == 'VAR' and lvl < current_lvl) or \ - (en.entity_type == 'PAR' and lvl < current_lvl and en.par_mode == 'cv'): - - gnvlcode(v) - asm_file.write('\tsw $t%s, ($t0)\n' % r) - elif en.entity_type == 'PAR' and lvl < current_lvl and en.par_mode == 'ref': - - gnvlcode(v) - asm_file.write('\tlw $t0, ($t0)\n') - asm_file.write('\tsw $t%s, ($t0)\n' % r) - else: - print("ERROR: (storerv) couldn't transfer data from register to memory: " + v) - sys.exit() - - -def write_to_asm(quad, name, labelno): - global asm_file, quadDict, program_name, parlist, lmain_flag - - if name == program_name and lmain_flag: - asm_file.write('Lmain:\n') - asm_file.write('\tadd $sp,$sp, %d\n' % scopes_list[0].sp) - asm_file.write('\tmove $s0,$sp\n') - lmain_flag = False - - asm_file.write('L_' + str(labelno) + ':\n') - - if quad[0] == 'jump': - asm_file.write('\tj L_%s\n' % quad[3]) - elif quad[0] == ':=': - loadvr(quad[1], '1') - storerv('1', quad[3]) - elif quad[0] in ('=', '<>', '<', '<=', '>', '>='): - loadvr(quad[1], '1') - loadvr(quad[2], '2') - if quad[0] == '=': - relop = 'beq' - elif quad[0] == '<>': - 
relop = 'bne' - elif quad[0] == '<': - relop = 'blt' - elif quad[0] == '<=': - relop = 'ble' - elif quad[0] == '>': - relop = 'bgt' - elif quad[0] == '>=': - relop = 'bge' - - asm_file.write('\t%s $t1, $t2, L_%s\n' % (relop, quad[3])) - elif quad[0] in ('+', '-', '*', '/'): - loadvr(quad[1], '1') - loadvr(quad[2], '2') - - if quad[0] == '+': - op = 'add' - elif quad[0] == '-': - op = 'sub' - elif quad[0] == '*': - op = 'mul' - elif quad[0] == '/': - op = 'div' - - asm_file.write('\t%s $t1, $t1, $t2\n' % op) - storerv('1', quad[3]) - elif quad[0] == 'out': - asm_file.write('\tli $v0, 1\n') - asm_file.write('\tli $a0, %s\n' % quad[3]) - asm_file.write('\tsyscall\n') - elif quad[0] == 'in': - asm_file.write('\tli $v0, 5\n') - asm_file.write('\tsyscall\n') - elif quad[0] == 'retv': - - loadvr(quad[1], '1') - asm_file.write('\tlw $t0, -8($sp)\n') - asm_file.write('\tsw $t1, ($t0)\n') - - elif quad[0] == 'par': - - if name == program_name: - en_lvl = 0 - framelength = scopes_list[0].sp - else: - en, en_lvl = search_entity(name, 'FUNC') - framelength = en.framelength - - if not parlist: - asm_file.write('\tadd $fp, $sp, %d\n' % framelength) - - parlist.append(quad) - par_offset = 12 + 4*parlist.index(quad) - if quad[2] == 'CV': - loadvr(quad[1], '0') - asm_file.write('\tsw $t0, -%d($fp)\n' % par_offset) - elif quad[2] == 'REF': - var_en, var_lvl = testing(quad[1]) - - if var_en is None: - error('Undeclared variable') - if en_lvl == var_lvl: - if var_en.entity_type == 'VAR' or (var_en.entity_type == 'PAR' and var_en.par_mode == 'cv'): - asm_file.write('\tadd $t0, $sp, -%d\n' % var_en.offset) - asm_file.write('\tsw $t0, -%d($fp)\n' % par_offset) - elif var_en.entity_type == 'PAR' and var_en.par_mode == 'ref': - asm_file.write('\t$t0, -%d($sp)\n' % var_en.offset) - asm_file.write('\tsw $t0, -%d($fp)\n' % par_offset) - else: - if var_en.entity_type == 'VAR' or (var_en.entity_type == 'PAR' and var_en.par_mode == 'cv'): - gnvlcode(quad[1]) - asm_file.write('\tsw $t0, 
-%d($fp)\n' % par_offset) - elif var_en.entity_type == 'PAR' and var_en.par_mode == 'ref': - gnvlcode(quad[1]) - asm_file.write('\tlw $t0, ($t0)\n') - asm_file.write('\tsw $t0, -%d($fp)\n' % par_offset) - elif quad[2] == 'RET': - var_en, var_lvl = testing(quad[1]) - - if var_en is None: - error('Undeclared variable') - asm_file.write('\tadd $t0, $sp, -%d\n' % var_en.offset) - asm_file.write('\tsw $t0, -8($fp)\n') - - elif quad[0] == 'call': - - if name == program_name: - en_lvl = 0 - framelength = scopes_list[0].sp - else: - en, en_lvl = search_entity(name, 'FUNC') - framelength = en.framelength - - cn, cn_lvl = search_entity(quad[1], 'FUNC') - - if cn is None: - error('Function not declared') - - for i in range(len(cn.arguments)): - if (cn.arguments[i] == 'in' and parlist[i][2] != 'CV') or \ - (cn.arguments[i] == 'inout' and parlist[i][2] != 'REF') or \ - (cn.arguments[i] == 'inandout' and parlist[i][2] != 'RET'): - - print("ERROR: False parameter types in a called function\n") - sys.exit() - - parlist = list() - - if cn_lvl == en_lvl: - asm_file.write('\tlw $t0, -4($sp)\n') - asm_file.write('\tsw $t0, -4($fp)\n') - else: - asm_file.write('\tsw $sp, -4($fp)\n') - - asm_file.write('\tadd $sp, $sp, %d\n' % framelength) - asm_file.write('\tjal L_%d\n' % cn.start_quad) - asm_file.write('\tadd $sp, $sp, -%d\n' % framelength) - elif quad[0] == 'begin_block': - asm_file.write('\tsw $ra,($sp)\n') - - elif quad[0] == 'end_block': - if name == program_name: - asm_file.write('\n') - else: - asm_file.write('\tlw $ra,($sp)\n') - asm_file.write('\tjr $ra\n') - elif quad[0] == 'halt': - asm_file.write('\tli $v0, 10\n') - asm_file.write('\tsyscall\n') - - -################################################################## -# # -# Other Functions # -# # -################################################################## - - -# Checking if token is valid -def is_valid_id(tk): - if tk not in token_captivated: - if not tk.isdigit(): - if tk.isalnum(): - return True - - return False - 
- -# An error function, outputting the type of the syntax problem -def error(x): - print("Line " + str(lineno)) - print("ERROR " + x + str(token)) - sys.exit() - - -# Open files -def open_files(input_filename, int_filename, c_filename, asm_filename): - global data, int_file, c_file, asm_file - - try: - data = open(input_filename, 'r') - int_file = open(int_filename, 'w') - c_file = open(c_filename, 'w') - asm_file = open(asm_filename, 'w') - asm_file.write('L:\n\tj Lmain\n') - - except (FileNotFoundError, IOError): - print("Couldn't read file, or file doesn't exist!") - sys.exit() - - -def write_int_to_file(): - - for key in quadDict: - int_file.write(str(key) + ': ' + str(quadDict[key][0]) + ',' + str(quadDict[key][1]) + ',' + str( - quadDict[key][2]) + ',' + str(quadDict[key][3]) + '\n') - - -def write_to_c(): - global quadDict, c_file - - c_file.write('#include \n\n') - c_file.write('int main()\n{\n') - - for key in quadDict: - info = to_c(key) - - semi = ';' if info is not '' else '\t' - - c_file.write( - '\tL_' + str(key) + ': ' + str(info) + semi + '\t//' + str(key) + ': ' + str(quadDict[key][0]) + ',' + str( - quadDict[key][1]) + ',' + str(quadDict[key][2]) + ',' + str(quadDict[key][3]) + '\n') - - c_file.write('}') - - -# Function to create C syntax based on quadDict -def to_c(key): - global quadDict - first = str(quadDict[key][0]) - second = str(quadDict[key][1]) - third = str(quadDict[key][2]) - fourth = str(quadDict[key][3]) - - if first in ('begin_block', 'end_block'): - return '' - elif first == ':=': - return fourth + '=' + second - elif first in ('+', '-', '*', '/'): - return fourth + '=' + second + first + third - elif first in ('=', '<>', '<', '<=', '>', '>='): - assignment = first - if first == '=': - assignment = '==' - elif first == '<>': - assignment = '!=' - - return 'if (' + second + assignment + third + ') goto L_' + fourth - elif first == 'jump': - return 'goto L_' + fourth - elif first == 'halt': - return 'return 0' - elif first == 'out': 
- return 'printf("%d\\n", ' + second + ')' - - return '' - - -def main(): - # Checking if file is passed - if len(sys.argv) < 2: - print("Please pass a file to be executed") - sys.exit() - - # Getting file name - filename = sys.argv[1][:-4] - int_filename = filename + '.int' - c_filename = filename + '.c' - asm_filename = filename + '.asm' - - open_files(sys.argv[1], int_filename, c_filename, asm_filename) - program() - - write_int_to_file() - write_to_c() - - # Closing files - data.close() - int_file.close() - c_file.close() - asm_file.close() - - print("Successful!\nFiles:\n\t" + int_filename + ",\n\t" + c_filename + ',\n\t' + asm_filename + - "\nwere created in your directory.") - - -if __name__ == "__main__": +if __name__ == '__main__': main() diff --git a/starlet/__init__.py b/starlet/__init__.py new file mode 100644 index 0000000..7b6731e --- /dev/null +++ b/starlet/__init__.py @@ -0,0 +1,15 @@ +from .lexer import Lexer, LexError +from .ir import IR +from .symtable import SymbolTable +from .parser import Parser, ParseError +from .mips_backend import MipsBackend +from .c_backend import write_to_c + +__all__ = [ + 'Lexer', 'LexError', + 'IR', + 'SymbolTable', + 'Parser', 'ParseError', + 'MipsBackend', + 'write_to_c', +] diff --git a/starlet/__main__.py b/starlet/__main__.py new file mode 100644 index 0000000..115e619 --- /dev/null +++ b/starlet/__main__.py @@ -0,0 +1,55 @@ +import sys + +from .lexer import Lexer, LexError +from .ir import IR +from .symtable import SymbolTable +from .parser import Parser, ParseError +from .mips_backend import MipsBackend +from .c_backend import write_to_c + + +def main(): + if len(sys.argv) < 2: + print('Usage: python3 starlet.py ') + sys.exit(1) + + input_path = sys.argv[1] + base = input_path[:-4] if input_path.endswith('.stl') else input_path + int_path = base + '.int' + c_path = base + '.c' + asm_path = base + '.asm' + + try: + src = open(input_path, 'r') + int_file = open(int_path, 'w') + c_file = open(c_path, 'w') + 
asm_file = open(asm_path, 'w') + except (FileNotFoundError, IOError) as e: + print(f"Could not open file: {e}") + sys.exit(1) + + try: + lexer = Lexer(src) + symtable = SymbolTable() + ir = IR() + mips = MipsBackend(asm_file, symtable) + parser = Parser(lexer, ir, symtable, mips) + parser.parse() + except (LexError, ParseError) as e: + print(f'ERROR: {e}') + sys.exit(1) + finally: + src.close() + + ir.write_int(int_file) + write_to_c(ir.quad_dict, c_file) + + int_file.close() + c_file.close() + asm_file.close() + + print(f'Successful!\nFiles:\n\t{int_path},\n\t{c_path},\n\t{asm_path}\nwere created in your directory.') + + +if __name__ == '__main__': + main() diff --git a/starlet/c_backend.py b/starlet/c_backend.py new file mode 100644 index 0000000..aa7518e --- /dev/null +++ b/starlet/c_backend.py @@ -0,0 +1,35 @@ +def write_to_c(quad_dict: dict, c_file): + c_file.write('#include \n\n') + c_file.write('int main()\n{\n') + + for key, quad in quad_dict.items(): + info = _to_c(key, quad) + semi = ';' if info != '' else '\t' + c_file.write( + f'\tL_{key}: {info}{semi}' + f'\t//{key}: {quad[0]},{quad[1]},{quad[2]},{quad[3]}\n' + ) + + c_file.write('}') + + +def _to_c(key: int, quad: list) -> str: + op, x, y, z = quad[0], str(quad[1]), str(quad[2]), str(quad[3]) + + if op in ('begin_block', 'end_block'): + return '' + elif op == ':=': + return z + '=' + x + elif op in ('+', '-', '*', '/'): + return z + '=' + x + op + y + elif op in ('=', '<>', '<', '<=', '>', '>='): + c_op = {'=': '==', '<>': '!='}.get(op, op) + return f'if ({x}{c_op}{y}) goto L_{z}' + elif op == 'jump': + return 'goto L_' + z + elif op == 'halt': + return 'return 0' + elif op == 'out': + return f'printf("%d\\n", {x})' + + return '' diff --git a/starlet/ir.py b/starlet/ir.py new file mode 100644 index 0000000..5a7a838 --- /dev/null +++ b/starlet/ir.py @@ -0,0 +1,64 @@ +from __future__ import annotations +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .symtable import SymbolTable + + +class 
IR: + + def __init__(self): + self.quad_dict: dict[int, list] = {} + self._next_label: int = 0 + self._t_counter: int = 1 + + # ------------------------------------------------------------------ + # Quad emission + # ------------------------------------------------------------------ + + def gen_quad(self, op=None, x='_', y='_', z='_'): + label = self._next_label + self._next_label += 1 + self.quad_dict[label] = [op, x, y, z] + + def next_quad(self) -> int: + return self._next_label + + # ------------------------------------------------------------------ + # Temporaries + # ------------------------------------------------------------------ + + def new_temp(self, symtable: 'SymbolTable') -> str: + name = 'T_' + str(self._t_counter) + self._t_counter += 1 + symtable.add_temp(name) + return name + + # ------------------------------------------------------------------ + # Backpatch helpers + # ------------------------------------------------------------------ + + @staticmethod + def empty_list() -> list: + return [] + + @staticmethod + def make_list(label: int) -> list: + return [label] + + @staticmethod + def merge(list1: list, list2: list) -> list: + return list1 + list2 + + def backpatch(self, labellist: list, target): + for key in labellist: + if key in self.quad_dict: + self.quad_dict[key][3] = target + + # ------------------------------------------------------------------ + # Serialisation + # ------------------------------------------------------------------ + + def write_int(self, f): + for key, quad in self.quad_dict.items(): + f.write(f'{key}: {quad[0]},{quad[1]},{quad[2]},{quad[3]}\n') diff --git a/starlet/lexer.py b/starlet/lexer.py new file mode 100644 index 0000000..5a17d6f --- /dev/null +++ b/starlet/lexer.py @@ -0,0 +1,136 @@ +import sys + + +KEYWORDS = frozenset({ + 'program', 'endprogram', 'declare', 'if', 'then', 'else', + 'endif', 'while', 'endwhile', 'dowhile', 'enddowhile', 'loop', 'endloop', 'exit', + 'forcase', 'endforcase', 'incase', 
'endincase', 'when', 'default', 'enddefault', + 'function', 'endfunction', 'return', 'in', 'inout', 'inandout', 'and', 'or', 'not', + 'input', 'print', +}) + +_MAX_INT = 32767 +_MAX_ID_LEN = 30 + + +class LexError(Exception): + def __init__(self, msg: str, lineno: int): + super().__init__(f'Line {lineno}: {msg}') + self.lineno = lineno + + +class Lexer: + """FSM tokenizer. Call next_token() to advance.""" + + # State constants + _OK = -2 + + def __init__(self, file_handle): + self._data = file_handle + self.lineno: int = 1 + + def next_token(self) -> str: + buffer: list[str] = [] + state = 0 + getback = False + + while state != self._OK: + char = self._data.read(1) + buffer.append(char) + + if state == 0: + if char.isalpha(): + state = 1 + elif char.isdigit(): + state = 2 + elif char == '<': + state = 3 + elif char == '>': + state = 4 + elif char == ':': + state = 5 + elif char == '/': + state = 6 + elif char in ('+', '-', '*', '=', ',', ';', '(', ')', '[', ']'): + state = self._OK + elif char == '': + state = self._OK + elif char.isspace(): + state = 0 + else: + raise LexError(f'invalid character: {char!r}', self.lineno) + + elif state == 1: + if not char.isalnum(): + getback = True + state = self._OK + + elif state == 2: + if not char.isdigit(): + if char.isalpha(): + raise LexError(f'invalid integer literal', self.lineno) + getback = True + state = self._OK + + elif state == 3: + if char != '=' and char != '>': + getback = True + state = self._OK + + elif state == 4: + if char != '=': + getback = True + state = self._OK + + elif state == 5: + if char != '=': + getback = True + state = self._OK + + elif state == 6: + if char == '/': + state = 7 + elif char == '*': + state = 8 + else: + getback = True + state = self._OK + + elif state == 7: + if char == '\n': + buffer.clear() + state = 0 + + elif state == 8: + if char == '*': + state = 9 + elif char == '': + raise LexError('unclosed block comment', self.lineno) + + elif state == 9: + if char == '/': + 
buffer.clear() + state = 0 + else: + state = 8 + + if char.isspace(): + if char == '\n': + self.lineno += 1 + if buffer: + buffer.pop() + getback = False + + if getback: + # Don't seek back when we hit EOF — '' doesn't belong to a next token + if char != '': + self._data.seek(self._data.tell() - 1) + if buffer: + buffer.pop() + + token = ''.join(buffer) + + if token.isdigit() and abs(int(token)) > _MAX_INT: + raise LexError(f'integer {token} out of bounds (max {_MAX_INT})', self.lineno) + + return token[:_MAX_ID_LEN] diff --git a/starlet/mips_backend.py b/starlet/mips_backend.py new file mode 100644 index 0000000..4400e72 --- /dev/null +++ b/starlet/mips_backend.py @@ -0,0 +1,270 @@ +from __future__ import annotations +import sys +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .symtable import SymbolTable + + +class MipsBackend: + + def __init__(self, asm_file, symtable: 'SymbolTable'): + self._f = asm_file + self._st = symtable + self._parlist: list = [] + self._lmain_flag: bool = True + self._f.write('L:\n\tj Lmain\n') + + # ------------------------------------------------------------------ + # Public entry point + # ------------------------------------------------------------------ + + def emit_block(self, quad_dict: dict, start: int, end: int, name: str): + for i in range(start, end): + self._write_to_asm(quad_dict[i], name, i) + + # ------------------------------------------------------------------ + # Address helpers + # ------------------------------------------------------------------ + + def _gnvlcode(self, v: str): + result = self._st.testing(v) + if result is None: + print('Undeclared variable ' + v) + sys.exit() + en, lvl = result + if en.entity_type == 'FUNC': + print('Undeclared variable ' + v) + sys.exit() + + current_lvl = self._st.current_level + diff = current_lvl - lvl + + self._f.write('\tlw $t0, -4($sp)\n') + while diff > 1: + self._f.write('\tlw $t0, -4($t0)\n') + diff -= 1 + self._f.write('\tadd $t0, $t0, - %d\n' % en.offset) + 
+ def _loadvr(self, v: str, r: str): + if str(v).isdigit(): + self._f.write(f'\tli $t{r}, {v}\n') + return + + result = self._st.testing(v) + if result is None: + print('Undeclared variable ' + v) + sys.exit() + en, lvl = result + current_lvl = self._st.current_level + + if en.entity_type == 'VAR' and lvl == 0: + self._f.write(f'\tlw $t{r}, -{en.offset}($s0)\n') + elif (en.entity_type == 'VAR' and lvl == current_lvl) or \ + (en.entity_type == 'PAR' and lvl == current_lvl and en.par_mode == 'cv') or \ + en.entity_type == 'TMPVAR': + self._f.write(f'\tlw $t{r}, -{en.offset}($sp)\n') + elif en.entity_type == 'PAR' and lvl == current_lvl and en.par_mode == 'ref': + self._f.write(f'\tlw $t0, -{en.offset}($sp)\n') + self._f.write(f'\tlw $t{r}, ($t0)\n') + elif (en.entity_type == 'VAR' and lvl < current_lvl) or \ + (en.entity_type == 'PAR' and lvl < current_lvl and en.par_mode == 'cv'): + self._gnvlcode(v) + self._f.write(f'\tlw $t{r}, ($t0)\n') + elif en.entity_type == 'PAR' and lvl < current_lvl and en.par_mode == 'ref': + self._gnvlcode(v) + self._f.write('\tlw $t0, ($t0)\n') + self._f.write(f'\tlw $t{r}, ($t0)\n') + else: + print(f'ERROR: (loadvr) could not load variable: {v}') + sys.exit() + + def _storerv(self, r: str, v: str): + result = self._st.testing(v) + if result is None: + print('Undeclared variable ' + v) + sys.exit() + en, lvl = result + current_lvl = self._st.current_level + + if en.entity_type == 'VAR' and lvl == 0: + self._f.write(f'\tsw $t{r}, -{en.offset}($s0)\n') + elif (en.entity_type == 'VAR' and lvl == current_lvl) or \ + (en.entity_type == 'PAR' and lvl == current_lvl and en.par_mode == 'cv') or \ + en.entity_type == 'TMPVAR': + self._f.write(f'\tsw $t{r}, -{en.offset}($sp)\n') + elif en.entity_type == 'PAR' and lvl == current_lvl and en.par_mode == 'ref': + self._f.write(f'\tlw $t0, -{en.offset}($sp)\n') + self._f.write(f'\tsw $t{r}, ($t0)\n') + elif (en.entity_type == 'VAR' and lvl < current_lvl) or \ + (en.entity_type == 'PAR' and lvl < 
current_lvl and en.par_mode == 'cv'): + self._gnvlcode(v) + self._f.write(f'\tsw $t{r}, ($t0)\n') + elif en.entity_type == 'PAR' and lvl < current_lvl and en.par_mode == 'ref': + self._gnvlcode(v) + self._f.write('\tlw $t0, ($t0)\n') + self._f.write(f'\tsw $t{r}, ($t0)\n') + else: + print(f'ERROR: (storerv) could not store to variable: {v}') + sys.exit() + + # ------------------------------------------------------------------ + # Main quad dispatcher + # ------------------------------------------------------------------ + + def _write_to_asm(self, quad: list, name: str, labelno: int): + st = self._st + + if name == st.program_name and self._lmain_flag: + self._f.write('Lmain:\n') + self._f.write('\tadd $sp,$sp, %d\n' % st.scopes[0].sp) + self._f.write('\tmove $s0,$sp\n') + self._lmain_flag = False + + self._f.write('L_' + str(labelno) + ':\n') + + op = quad[0] + + if op == 'jump': + self._f.write('\tj L_%s\n' % quad[3]) + + elif op == ':=': + self._loadvr(quad[1], '1') + self._storerv('1', quad[3]) + + elif op in ('=', '<>', '<', '<=', '>', '>='): + self._loadvr(quad[1], '1') + self._loadvr(quad[2], '2') + relop = {'=': 'beq', '<>': 'bne', '<': 'blt', '<=': 'ble', '>': 'bgt', '>=': 'bge'}[op] + self._f.write('\t%s $t1, $t2, L_%s\n' % (relop, quad[3])) + + elif op in ('+', '-', '*', '/'): + self._loadvr(quad[1], '1') + self._loadvr(quad[2], '2') + mips_op = {'+': 'add', '-': 'sub', '*': 'mul', '/': 'div'}[op] + self._f.write('\t%s $t1, $t1, $t2\n' % mips_op) + self._storerv('1', quad[3]) + + elif op == 'out': + self._f.write('\tli $v0, 1\n') + self._f.write('\tli $a0, %s\n' % quad[3]) + self._f.write('\tsyscall\n') + + elif op == 'in': + self._f.write('\tli $v0, 5\n') + self._f.write('\tsyscall\n') + + elif op == 'retv': + self._loadvr(quad[1], '1') + self._f.write('\tlw $t0, -8($sp)\n') + self._f.write('\tsw $t1, ($t0)\n') + + elif op == 'par': + self._emit_par(quad, name) + + elif op == 'call': + self._emit_call(quad, name) + + elif op == 'begin_block': + 
self._f.write('\tsw $ra,($sp)\n') + + elif op == 'end_block': + if name == st.program_name: + self._f.write('\n') + else: + self._f.write('\tlw $ra,($sp)\n') + self._f.write('\tjr $ra\n') + + elif op == 'halt': + self._f.write('\tli $v0, 10\n') + self._f.write('\tsyscall\n') + + # ------------------------------------------------------------------ + # par / call helpers + # ------------------------------------------------------------------ + + def _caller_framelength(self, name: str) -> tuple[int, int]: + """Returns (caller_nesting_level, framelength) for the current block.""" + st = self._st + if name == st.program_name: + return 0, st.scopes[0].sp + result = st.search_entity(name, 'FUNC') + if result is None: + print(f'Function not declared: {name}') + sys.exit() + en, en_lvl = result + return en_lvl, en.framelength + + def _emit_par(self, quad: list, name: str): + en_lvl, framelength = self._caller_framelength(name) + + if not self._parlist: + self._f.write('\tadd $fp, $sp, %d\n' % framelength) + + self._parlist.append(quad) + par_offset = 12 + 4 * (len(self._parlist) - 1) + + if quad[2] == 'CV': + self._loadvr(quad[1], '0') + self._f.write('\tsw $t0, -%d($fp)\n' % par_offset) + + elif quad[2] == 'REF': + result = self._st.testing(quad[1]) + if result is None: + print('Undeclared variable') + sys.exit() + var_en, var_lvl = result + if en_lvl == var_lvl: + if var_en.entity_type == 'VAR' or (var_en.entity_type == 'PAR' and var_en.par_mode == 'cv'): + self._f.write('\tadd $t0, $sp, -%d\n' % var_en.offset) + self._f.write('\tsw $t0, -%d($fp)\n' % par_offset) + elif var_en.entity_type == 'PAR' and var_en.par_mode == 'ref': + self._f.write('\t$t0, -%d($sp)\n' % var_en.offset) + self._f.write('\tsw $t0, -%d($fp)\n' % par_offset) + else: + if var_en.entity_type == 'VAR' or (var_en.entity_type == 'PAR' and var_en.par_mode == 'cv'): + self._gnvlcode(quad[1]) + self._f.write('\tsw $t0, -%d($fp)\n' % par_offset) + elif var_en.entity_type == 'PAR' and var_en.par_mode == 
'ref': + self._gnvlcode(quad[1]) + self._f.write('\tlw $t0, ($t0)\n') + self._f.write('\tsw $t0, -%d($fp)\n' % par_offset) + + elif quad[2] == 'RET': + result = self._st.testing(quad[1]) + if result is None: + print('Undeclared variable') + sys.exit() + var_en, _ = result + self._f.write('\tadd $t0, $sp, -%d\n' % var_en.offset) + self._f.write('\tsw $t0, -8($fp)\n') + + def _emit_call(self, quad: list, name: str): + st = self._st + en_lvl, framelength = self._caller_framelength(name) + + result = st.search_entity(quad[1], 'FUNC') + if result is None: + print(f'Function not declared: {quad[1]}') + sys.exit() + cn, cn_lvl = result + + for i, mode in enumerate(cn.arguments): + actual_mode = self._parlist[i][2] + if (mode == 'in' and actual_mode != 'CV') or \ + (mode == 'inout' and actual_mode != 'REF') or \ + (mode == 'inandout' and actual_mode != 'RET'): + print('ERROR: parameter type mismatch in call to ' + quad[1]) + sys.exit() + + self._parlist = [] + + if cn_lvl == en_lvl: + self._f.write('\tlw $t0, -4($sp)\n') + self._f.write('\tsw $t0, -4($fp)\n') + else: + self._f.write('\tsw $sp, -4($fp)\n') + + self._f.write('\tadd $sp, $sp, %d\n' % framelength) + self._f.write('\tjal L_%d\n' % cn.start_quad) + self._f.write('\tadd $sp, $sp, -%d\n' % framelength) diff --git a/starlet/parser.py b/starlet/parser.py new file mode 100644 index 0000000..15847c6 --- /dev/null +++ b/starlet/parser.py @@ -0,0 +1,539 @@ +from __future__ import annotations +from typing import TYPE_CHECKING + +from .lexer import KEYWORDS +from .symtable import CompilerError + +if TYPE_CHECKING: + from .lexer import Lexer + from .ir import IR + from .symtable import SymbolTable + from .mips_backend import MipsBackend + + +class ParseError(Exception): + pass + + +class Parser: + + def __init__(self, lexer: 'Lexer', ir: 'IR', symtable: 'SymbolTable', mips: 'MipsBackend'): + self._lex = lexer + self._ir = ir + self._st = symtable + self._mips = mips + + self.token: str = '' + self._loop_enabled: bool = 
False + self._func_enabled: bool = False + self._ret_enabled: bool = False + self._new_exit_list = None + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + + def _advance(self): + self.token = self._lex.next_token() + + def _is_valid_id(self, tk: str) -> bool: + return tk not in KEYWORDS and not tk.isdigit() and tk.isalnum() + + def _error(self, msg: str): + raise ParseError(f'Line {self._lex.lineno}: {msg}, found: {self.token!r}') + + # ------------------------------------------------------------------ + # Entry point + # ------------------------------------------------------------------ + + def parse(self): + self._program() + + # ------------------------------------------------------------------ + # Grammar rules + # ------------------------------------------------------------------ + + def _program(self): + self._advance() + if self.token != 'program': + self._error("expected 'program'") + self._advance() + if not self._is_valid_id(self.token): + self._error("expected program name (id)") + + self._st.program_name = self.token + name = self.token + self._advance() + + self._st.push_scope() + self._block(name) + + if self.token != 'endprogram': + self._error("expected 'endprogram'") + + def _block(self, name: str): + self._declarations() + self._subprograms() + + start_quad = self._ir.next_quad() + if name != self._st.program_name: + result = self._st.search_entity(name, 'FUNC') + if result: + result[0].set_start_quad(start_quad) + + self._ir.gen_quad('begin_block', name) + self._statements() + + if name == self._st.program_name: + self._ir.gen_quad('halt') + else: + result = self._st.search_entity(name, 'FUNC') + if result: + result[0].framelength = self._st.current_scope.get_sp() + + self._ir.gen_quad('end_block', name) + + end_quad = self._ir.next_quad() + self._mips.emit_block(self._ir.quad_dict, start_quad, end_quad, name) + + self._st.pop_scope() + + 
def _declarations(self): + while self.token == 'declare': + self._advance() + self._varlist() + if self.token != ';': + self._error("expected ';'") + self._advance() + + def _varlist(self): + if self._is_valid_id(self.token): + self._st.add_var(self.token) + self._advance() + while self.token == ',': + self._advance() + if not self._is_valid_id(self.token): + self._error("expected identifier in varlist") + self._st.add_var(self.token) + self._advance() + elif self.token != ';': + self._error("expected identifier or ';' in varlist") + + def _subprograms(self): + while self.token == 'function': + self._func_enabled = True + self._advance() + self._subprogram() + + def _subprogram(self): + if not self._is_valid_id(self.token): + self._error("expected function name (id)") + + name = self.token + new_scope = self._st.current_scope + self._st.push_scope(len(self._st.scopes), new_scope) + self._st.add_func(name) + + self._advance() + self._funcbody(name) + + if self.token != 'endfunction': + self._error("expected 'endfunction'") + if not self._ret_enabled: + self._error("no return statement in function") + + self._func_enabled = False + self._ret_enabled = False + self._advance() + + def _funcbody(self, name: str): + self._formalpars(name) + self._block(name) + + def _formalpars(self, name: str): + if self.token != '(': + self._error("expected '('") + self._advance() + self._formalparlist(name) + if self.token != ')': + self._error("expected ')'") + self._advance() + + def _formalparlist(self, name: str): + if self.token != ')': + self._formalparitem(name) + while self.token == ',': + self._advance() + self._formalparitem(name) + + def _formalparitem(self, name: str): + if self.token not in ('in', 'inout', 'inandout'): + self._error("expected 'in', 'inout', or 'inandout'") + par_mode = self.token + self._advance() + if not self._is_valid_id(self.token): + self._error("expected parameter name (id)") + self._st.add_arg_to_func(par_mode, name) + 
self._st.add_param(self.token, par_mode) + self._advance() + + def _statements(self): + self._statement() + while self.token == ';': + self._advance() + self._statement() + + def _statement(self): + if self.token == 'if': + self._advance() + self._if_stat() + elif self.token == 'while': + self._advance() + self._while_stat() + elif self.token == 'dowhile': + self._advance() + self._do_while_stat() + elif self.token == 'loop': + self._loop_enabled = True + self._advance() + self._loop_stat() + elif self.token == 'exit': + self._advance() + if not self._loop_enabled: + self._error("'exit' used outside a loop") + self._exit_stat() + elif self.token == 'forcase': + self._advance() + self._forcase_stat() + elif self.token == 'incase': + self._advance() + self._incase_stat() + elif self.token == 'return': + self._ret_enabled = True + self._advance() + exp = self._expression() + self._ir.gen_quad('retv', exp, '_', '_') + elif self.token == 'print': + self._advance() + exp = self._expression() + self._ir.gen_quad('out', exp) + elif self.token == 'input': + self._advance() + id_place = self._input_stat() + self._ir.gen_quad('inp', id_place) + elif self._is_valid_id(self.token): + self._assignment_stat() + + def _assignment_stat(self): + t1 = self.token + if not self._st.exists(t1): + self._error(f'undeclared variable {t1!r}') + self._advance() + if self.token != ':=': + self._error("expected ':='") + self._advance() + if not self._st.exists(self.token) and not self.token.isdigit(): + self._error(f'undeclared variable {self.token!r}') + exp = self._expression() + self._ir.gen_quad(':=', exp, '_', t1) + + def _if_stat(self): + if self.token != '(': + self._error("expected '(' after 'if'") + self._advance() + b_true, b_false = self._condition() + if self.token != ')': + self._error("expected ')' after condition") + self._advance() + self._ir.backpatch(b_true, self._ir.next_quad()) + if self.token != 'then': + self._error("expected 'then'") + self._advance() + 
self._statements() + skip = self._ir.make_list(self._ir.next_quad()) + self._ir.gen_quad('jump') + self._ir.backpatch(b_false, self._ir.next_quad()) + self._elsepart() + self._ir.backpatch(skip, self._ir.next_quad()) + if self.token != 'endif': + self._error("expected 'endif'") + self._advance() + + def _elsepart(self): + if self.token == 'else': + self._advance() + self._statements() + + def _while_stat(self): + quad = self._ir.next_quad() + if self.token != '(': + self._error("expected '(' after 'while'") + self._advance() + b_true, b_false = self._condition() + if self.token != ')': + self._error("expected ')' after condition") + self._advance() + self._ir.backpatch(b_true, self._ir.next_quad()) + self._statements() + self._ir.gen_quad('jump', '_', '_', str(quad)) + self._ir.backpatch(b_false, self._ir.next_quad()) + if self.token != 'endwhile': + self._error("expected 'endwhile'") + self._advance() + + def _do_while_stat(self): + quad = self._ir.next_quad() + self._statements() + if self.token != 'enddowhile': + self._error("expected 'enddowhile'") + self._advance() + if self.token != '(': + self._error("expected '(' after 'enddowhile'") + self._advance() + b_true, b_false = self._condition() + if self.token != ')': + self._error("expected ')' after condition") + self._ir.backpatch(b_true, quad) + self._ir.backpatch(b_false, self._ir.next_quad()) + self._advance() + + def _loop_stat(self): + quad = self._ir.next_quad() + self._statements() + self._ir.gen_quad('jump', '_', '_', str(quad)) + if self.token != 'endloop': + self._error("expected 'endloop'") + self._loop_enabled = False + self._advance() + if self._new_exit_list is not None: + self._ir.backpatch(self._new_exit_list, self._ir.next_quad()) + self._new_exit_list = None + + def _exit_stat(self): + self._new_exit_list = self._ir.make_list(self._ir.next_quad()) + self._ir.gen_quad('jump') + + def _forcase_stat(self): + flag_quad = self._ir.next_quad() + exit_list = self._ir.empty_list() + + while 
self.token == 'when': + self._advance() + if self.token != '(': + self._error("expected '(' after 'when'") + self._advance() + b_true, b_false = self._condition() + self._ir.backpatch(b_true, self._ir.next_quad()) + if self.token != ')': + self._error("expected ')' after condition") + self._advance() + if self.token != ':': + self._error("expected ':'") + self._advance() + self._statements() + t = self._ir.make_list(self._ir.next_quad()) + self._ir.gen_quad('jump') + exit_list = self._ir.merge(exit_list, t) + self._ir.backpatch(b_false, self._ir.next_quad()) + + if self.token != 'default': + self._error("expected 'default'") + self._advance() + if self.token != ':': + self._error("expected ':'") + self._advance() + self._statements() + self._ir.gen_quad('jump', '_', '_', str(flag_quad)) + self._ir.backpatch(exit_list, self._ir.next_quad()) + + if self.token != 'enddefault': + self._error("expected 'enddefault'") + self._advance() + if self.token != 'endforcase': + self._error("expected 'endforcase'") + self._advance() + + def _incase_stat(self): + t = self._ir.new_temp(self._st) + flag_quad = self._ir.next_quad() + self._ir.gen_quad(':=', '0', '_', t) + + while self.token == 'when': + self._advance() + if self.token != '(': + self._error("expected '(' after 'when'") + self._advance() + b_true, b_false = self._condition() + self._ir.backpatch(b_true, self._ir.next_quad()) + self._ir.gen_quad(':=', '1', '_', t) + if self.token != ')': + self._error("expected ')' after condition") + self._advance() + if self.token != ':': + self._error("expected ':'") + self._advance() + self._statements() + self._ir.backpatch(b_false, self._ir.next_quad()) + + if self.token != 'endincase': + self._error("expected 'endincase'") + self._advance() + self._ir.gen_quad('=', '1', t, str(flag_quad)) + + def _input_stat(self) -> str: + if not self._is_valid_id(self.token): + self._error("expected identifier for 'input'") + ret = self.token + self._advance() + return ret + + def 
_actualpars(self) -> bool: + if self.token == '(': + self._advance() + self._actualparlist() + if self.token != ')': + self._error("expected ')'") + self._advance() + return True + return False + + def _actualparlist(self): + self._actualparitem() + while self.token == ',': + self._advance() + self._actualparitem() + + def _actualparitem(self): + if self.token == 'in': + self._advance() + if not self._st.exists(self.token): + self._error(f'undeclared variable {self.token!r}') + exp = self._expression() + self._ir.gen_quad('par', exp, 'CV') + elif self.token == 'inout': + self._advance() + if not self._st.exists(self.token): + self._error(f'undeclared variable {self.token!r}') + t1 = self.token + if not self._is_valid_id(self.token): + self._error("expected identifier for 'inout'") + self._advance() + self._ir.gen_quad('par', t1, 'REF') + elif self.token == 'inandout': + self._advance() + if not self._st.exists(self.token): + self._error(f'undeclared variable {self.token!r}') + t2 = self.token + if not self._is_valid_id(self.token): + self._error("expected identifier for 'inandout'") + self._advance() + self._ir.gen_quad('par', t2, 'RET') + + def _condition(self): + b_true, b_false = self._boolterm() + while self.token == 'or': + self._ir.backpatch(b_false, self._ir.next_quad()) + self._advance() + c_true, c_false = self._boolterm() + b_true = self._ir.merge(b_true, c_true) + b_false = c_false + return b_true, b_false + + def _boolterm(self): + b_true, b_false = self._boolfactor() + while self.token == 'and': + self._ir.backpatch(b_true, self._ir.next_quad()) + self._advance() + c_true, c_false = self._boolfactor() + b_false = self._ir.merge(b_false, c_false) + b_true = c_true + return b_true, b_false + + def _boolfactor(self): + if self.token == 'not': + self._advance() + if self.token != '[': + self._error("expected '[' after 'not'") + self._advance() + ret = self._condition() + if self.token != ']': + self._error("expected ']'") + self._advance() + return ret[1], 
ret[0] # swap true/false for NOT + + elif self.token == '[': + self._advance() + ret = self._condition() + if self.token != ']': + self._error("expected ']'") + self._advance() + return ret + + else: + exp1 = self._expression() + op = self._relational_oper() + exp2 = self._expression() + b_true = self._ir.make_list(self._ir.next_quad()) + self._ir.gen_quad(op, exp1, exp2) + b_false = self._ir.make_list(self._ir.next_quad()) + self._ir.gen_quad('jump') + return b_true, b_false + + def _expression(self) -> str: + self._optional_sign() + t1 = self._term() + while self.token in ('+', '-'): + op = self.token + self._advance() + if not self._st.exists(self.token) and not self.token.isdigit(): + self._error(f'undeclared variable {self.token!r}') + t2 = self._term() + tmp = self._ir.new_temp(self._st) + self._ir.gen_quad(op, t1, t2, tmp) + t1 = tmp + return t1 + + def _term(self) -> str: + f1 = self._factor() + while self.token in ('*', '/'): + op = self.token + self._advance() + f2 = self._factor() + tmp = self._ir.new_temp(self._st) + self._ir.gen_quad(op, f1, f2, tmp) + f1 = tmp + return f1 + + def _factor(self) -> str: + if self.token.isdigit(): + ret = self.token + self._advance() + elif self.token == '(': + self._advance() + ret = self._expression() + if self.token != ')': + self._error("expected ')'") + self._advance() + elif self._is_valid_id(self.token): + ret = self.token + self._advance() + if self.token == '(': + self._actualpars() + new = self._ir.new_temp(self._st) + self._ir.gen_quad('par', new, 'RET') + self._ir.gen_quad('call', ret) + ret = new + else: + self._error("expected number, '(', or identifier") + return ret + + def _optional_sign(self): + if self.token in ('+', '-'): + self._advance() + + def _relational_oper(self) -> str: + if self.token not in ('=', '<=', '>=', '>', '<', '<>'): + self._error("expected relational operator") + op = self.token + self._advance() + return op diff --git a/starlet/symtable.py b/starlet/symtable.py new file mode 100644 
import sys


class CompilerError(Exception):
    """Raised when a declaration or lookup violates a symbol-table rule."""


class Entity:
    """Base record for anything stored in a :class:`Scope`.

    ``name`` is the identifier and ``entity_type`` is a tag string; the
    subclasses below use 'VAR', 'FUNC', 'PAR' and 'TMPVAR'.
    """

    def __init__(self, name, entity_type):
        self.name = name
        self.entity_type = entity_type

    def __str__(self):
        return self.name + ': ' + self.entity_type


class Variable(Entity):
    """A declared variable together with its offset inside its scope."""

    def __init__(self, name, offset=0):
        super().__init__(name, 'VAR')
        self.offset = offset

    def __str__(self):
        return super().__str__() + '\toffset: ' + str(self.offset)


class Function(Entity):
    """A function entry.

    ``start_quad`` and ``framelength`` start at -1 and are filled in later
    through the setters; ``arguments`` collects the parameter modes appended
    by :meth:`SymbolTable.add_arg_to_func`.
    """

    def __init__(self, name, ret_val, start_quad=-1):
        super().__init__(name, 'FUNC')
        self.ret_val = ret_val
        self.start_quad = start_quad
        self.arguments = []
        self.framelength = -1

    # The setters are kept for callers outside this module; they are plain
    # attribute assignments.
    def set_framelength(self, x):
        self.framelength = x

    def set_start_quad(self, x):
        self.start_quad = x

    def set_ret_val(self, x):
        self.ret_val = x

    def __str__(self):
        return (super().__str__()
                + ',\tStart_quad: ' + str(self.start_quad)
                + ',\tFramelength:' + str(self.framelength)
                + ',\tArgs:' + str(self.arguments))


class Parameter(Entity):
    """A formal parameter with its passing mode and offset."""

    # Source-level mode keywords -> short mode tags. A mode that is not a
    # key here is stored unchanged.
    _MODE_MAP = {'in': 'cv', 'inout': 'ref', 'inandout': 'ret'}

    def __init__(self, name, par_mode, offset=0):
        super().__init__(name, 'PAR')
        self.par_mode = self._MODE_MAP.get(par_mode, par_mode)
        self.offset = offset

    def __str__(self):
        return (super().__str__()
                + ',\tpar_mode: ' + self.par_mode
                + ',\toffset: ' + str(self.offset))


class TempVariable(Entity):
    """A compiler-generated temporary together with its offset."""

    def __init__(self, name, offset=0):
        super().__init__(name, 'TMPVAR')
        self.offset = offset

    def __str__(self):
        return super().__str__() + '\toffset: ' + str(self.offset)


class Scope:
    """The entities declared at one nesting level, plus the offset allocator."""

    # Offsets handed out by get_sp() start at 12 and advance by 4 per entity.
    FIRST_OFFSET = 12
    SLOT_SIZE = 4

    def __init__(self, nesting_level=0, enclosing_scope=None):
        self.entities = []
        self.nesting_level = nesting_level
        self.enclosing_scope = enclosing_scope
        self.sp = self.FIRST_OFFSET

    def get_sp(self):
        """Return the next free offset and advance the allocator by one slot."""
        ret = self.sp
        self.sp += self.SLOT_SIZE
        return ret

    def add_entity(self, ent):
        """Append *ent* to this scope without any duplicate checking."""
        self.entities.append(ent)

    def __str__(self):
        return (repr(self)
                + '\nNesting lvl: ' + repr(self.nesting_level)
                + '\nEnclosing Scope: ' + repr(self.enclosing_scope))


class SymbolTable:
    """A stack of :class:`Scope` objects with declaration and lookup helpers."""

    def __init__(self):
        self._scopes: list[Scope] = []
        self.program_name: str = ''

    # ------------------------------------------------------------------
    # Scope management
    # ------------------------------------------------------------------

    def push_scope(self, nesting_level=0, enclosing_scope=None):
        """Open a new innermost scope."""
        self._scopes.append(Scope(nesting_level, enclosing_scope))

    def pop_scope(self):
        """Close the innermost scope, dump it to stdout, and return it."""
        scope = self._scopes.pop()
        print(scope)
        for en in scope.entities:
            print(en)
        print('--------------------------------------------')
        return scope

    @property
    def current_scope(self) -> Scope:
        """The innermost scope (IndexError if no scope is open)."""
        return self._scopes[-1]

    @property
    def current_level(self) -> int:
        """Nesting level of the innermost scope."""
        return self._scopes[-1].nesting_level

    @property
    def scopes(self) -> list[Scope]:
        """The live scope stack, outermost first (not a copy)."""
        return self._scopes

    # ------------------------------------------------------------------
    # Entity addition
    # ------------------------------------------------------------------

    def add_var(self, name: str):
        """Declare variable *name* in the current scope.

        Raises:
            CompilerError: if the current scope already holds a variable or
                a parameter called *name*.
        """
        lvl = self.current_level
        # Validate before allocating so that a rejected declaration does not
        # consume a stack slot.
        if self._already_declared(name, 'VAR', lvl):
            raise CompilerError(f'Variable already declared: {name}')
        if self._exists_as_param(name, lvl):
            raise CompilerError(f'Symbol already declared as a parameter: {name}')
        offset = self.current_scope.get_sp()
        self.current_scope.add_entity(Variable(name, offset))

    def add_func(self, name: str, ret_val=0):
        """Register function *name* in the scope enclosing the current one.

        The entry goes into the second-innermost scope on the stack.

        Raises:
            CompilerError: if fewer than two scopes are open, i.e. there is
                no enclosing scope to hold the entry.
        """
        if len(self._scopes) < 2:
            raise CompilerError(f'No enclosing scope for function: {name}')
        self._scopes[-2].add_entity(Function(name, ret_val))

    def add_param(self, name: str, par_mode: str):
        """Declare parameter *name* with passing mode *par_mode*.

        Raises:
            CompilerError: if the current scope already has a parameter
                called *name*.
        """
        lvl = self.current_level
        # Same ordering as add_var: check first, allocate afterwards.
        if self._already_declared(name, 'PAR', lvl):
            raise CompilerError(f'Parameter already declared: {name}')
        offset = self.current_scope.get_sp()
        self.current_scope.add_entity(Parameter(name, par_mode, offset))

    def add_arg_to_func(self, par_mode: str, f_name: str):
        """Append *par_mode* to the argument list of function *f_name*.

        Raises:
            CompilerError: if no function called *f_name* is known.
        """
        result = self.search_entity(f_name, 'FUNC')
        if result is None:
            raise CompilerError(f'No function definition for: {f_name}')
        func_en, _ = result
        func_en.arguments.append(par_mode)

    def add_temp(self, name: str) -> TempVariable:
        """Allocate a temporary in the current scope and return its entity."""
        offset = self.current_scope.get_sp()
        temp = TempVariable(name, offset)
        self.current_scope.add_entity(temp)
        return temp

    # ------------------------------------------------------------------
    # Lookup helpers
    # ------------------------------------------------------------------

    def exists(self, name: str) -> bool:
        """Return True if any open scope holds an entity called *name*."""
        return any(
            entity.name == name
            for scope in self._scopes
            for entity in scope.entities
        )

    def search_entity(self, name: str, entity_type: str):
        """Search front-to-back; returns (entity, nesting_level) or None."""
        for scope in self._scopes:
            for entity in scope.entities:
                if entity.entity_type == entity_type and entity.name == name:
                    return entity, scope.nesting_level
        return None

    def testing(self, name: str):
        """Walk the static chain from innermost scope; returns (entity, nesting_level) or None."""
        # An empty table has no chain to walk; report "not found" rather
        # than failing with IndexError.
        scope = self._scopes[-1] if self._scopes else None
        while scope is not None:
            for entity in scope.entities:
                if entity.name == name:
                    return entity, scope.nesting_level
            scope = scope.enclosing_scope
        return None

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _scope_at_level(self, nesting_level: int):
        """Return the innermost open scope with *nesting_level*, or None.

        Matching on the scope's own ``nesting_level`` instead of using the
        level as a list index keeps lookups correct when stack position and
        nesting level differ.
        """
        for scope in reversed(self._scopes):
            if scope.nesting_level == nesting_level:
                return scope
        return None

    def _already_declared(self, name: str, entity_type: str, nesting_level: int) -> bool:
        """True if the scope at *nesting_level* has *name* with *entity_type*."""
        scope = self._scope_at_level(nesting_level)
        if scope is None:
            return False
        return any(
            entity.name == name and entity.entity_type == entity_type
            for entity in scope.entities
        )

    def _exists_as_param(self, name: str, level: int) -> bool:
        """True if the scope at *level* has a parameter called *name*."""
        return self._already_declared(name, 'PAR', level)
import os
import shutil
import sys
import tempfile

import pytest

from starlet.__main__ import main


# Example programs live in the repository's Examples/ directory; every
# '.stl' file found there becomes one parametrized test case.
EXAMPLES_DIR = os.path.join(os.path.dirname(__file__), '..', 'Examples')
EXAMPLE_FILES = sorted(
    f for f in os.listdir(EXAMPLES_DIR) if f.endswith('.stl')
)


@pytest.fixture
def tmp_dir():
    """Yield the path of a scratch directory that is removed afterwards."""
    with tempfile.TemporaryDirectory() as d:
        yield d


def compile_example(example_name: str, tmp_dir: str) -> tuple[str, str, str]:
    """Copy the .stl to tmp_dir, compile it, return (int_path, c_path, asm_path)."""
    src = os.path.join(EXAMPLES_DIR, example_name)
    dst = os.path.join(tmp_dir, example_name)
    shutil.copy(src, dst)

    # main() is driven through sys.argv here, so swap it in for the call and
    # always restore the original afterwards.
    old_argv = sys.argv
    sys.argv = ['starlet', dst]
    try:
        main()
    finally:
        sys.argv = old_argv

    # Strip the extension by name rather than by a fixed character count.
    base, _ = os.path.splitext(dst)
    return base + '.int', base + '.c', base + '.asm'


@pytest.mark.parametrize('example', EXAMPLE_FILES)
def test_example_compiles(example, tmp_dir):
    """Compiling an example produces all three output files."""
    int_path, c_path, asm_path = compile_example(example, tmp_dir)
    assert os.path.isfile(int_path), f'{example}: .int file missing'
    assert os.path.isfile(c_path), f'{example}: .c file missing'
    assert os.path.isfile(asm_path), f'{example}: .asm file missing'


@pytest.mark.parametrize('example', EXAMPLE_FILES)
def test_output_files_non_empty(example, tmp_dir):
    """None of the three output files is empty."""
    int_path, c_path, asm_path = compile_example(example, tmp_dir)
    assert os.path.getsize(int_path) > 0, f'{example}: .int file is empty'
    assert os.path.getsize(c_path) > 0, f'{example}: .c file is empty'
    assert os.path.getsize(asm_path) > 0, f'{example}: .asm file is empty'


@pytest.mark.parametrize('example', EXAMPLE_FILES)
def test_c_output_has_main(example, tmp_dir):
    """The generated C file contains an ``int main()`` definition."""
    _, c_path, _ = compile_example(example, tmp_dir)
    with open(c_path) as f:
        content = f.read()
    assert 'int main()' in content, f'{example}: .c output missing main()'
+@pytest.mark.parametrize('example', EXAMPLE_FILES) +def test_asm_output_has_lmain(example, tmp_dir): + _, _, asm_path = compile_example(example, tmp_dir) + with open(asm_path) as f: + content = f.read() + assert 'Lmain:' in content, f'{example}: .asm output missing Lmain label' + + +@pytest.mark.parametrize('example', EXAMPLE_FILES) +def test_int_output_has_halt(example, tmp_dir): + int_path, _, _ = compile_example(example, tmp_dir) + with open(int_path) as f: + content = f.read() + assert 'halt' in content, f'{example}: .int output missing halt quad' diff --git a/tests/test_lexer.py b/tests/test_lexer.py new file mode 100644 index 0000000..84f194c --- /dev/null +++ b/tests/test_lexer.py @@ -0,0 +1,130 @@ +import io +import pytest +from starlet.lexer import Lexer, LexError, KEYWORDS + + +def tokens(source: str) -> list[str]: + lexer = Lexer(io.StringIO(source)) + result = [] + while True: + tok = lexer.next_token() + if tok == '': + break + result.append(tok) + return result + + +class TestKeywords: + def test_program_keyword(self): + assert tokens('program') == ['program'] + + def test_all_keywords_recognized(self): + for kw in KEYWORDS: + assert tokens(kw) == [kw] + + +class TestIdentifiers: + def test_simple_id(self): + assert tokens('myVar') == ['myVar'] + + def test_id_with_digits(self): + assert tokens('var1') == ['var1'] + + def test_id_truncated_at_30_chars(self): + long_id = 'a' * 40 + result = tokens(long_id) + assert result == ['a' * 30] + + def test_multiple_ids(self): + assert tokens('foo bar') == ['foo', 'bar'] + + +class TestIntegers: + def test_single_digit(self): + assert tokens('5') == ['5'] + + def test_multi_digit(self): + assert tokens('123') == ['123'] + + def test_max_int(self): + assert tokens('32767') == ['32767'] + + def test_over_max_int_raises(self): + with pytest.raises(LexError): + tokens('32768') + + def test_letter_after_digit_raises(self): + with pytest.raises(LexError): + tokens('1a') + + +class TestOperators: + def 
test_assignment(self): + assert tokens(':=') == [':='] + + def test_colon_alone(self): + assert tokens(':') == [':'] + + def test_le(self): + assert tokens('<=') == ['<='] + + def test_ge(self): + assert tokens('>=') == ['>='] + + def test_ne(self): + assert tokens('<>') == ['<>'] + + def test_lt_alone(self): + assert tokens('<') == ['<'] + + def test_gt_alone(self): + assert tokens('>') == ['>'] + + def test_single_char_ops(self): + for op in ('+', '-', '*', '=', ',', ';', '(', ')', '[', ']'): + assert tokens(op) == [op] + + +class TestComments: + def test_line_comment_ignored(self): + assert tokens('// hello world\nfoo') == ['foo'] + + def test_block_comment_ignored(self): + assert tokens('/* hello */ bar') == ['bar'] + + def test_multiline_block_comment(self): + assert tokens('/* line1\nline2 */ baz') == ['baz'] + + def test_unclosed_block_comment_raises(self): + with pytest.raises(LexError): + tokens('/* unclosed') + + def test_line_counter(self): + lexer = Lexer(io.StringIO('a\nb\nc')) + # The lexer consumes the newline as a lookahead to delimit the token, + # so lineno increments before the token is returned. 
+ lexer.next_token() # 'a' — newline consumed, lineno advances + assert lexer.lineno == 2 + lexer.next_token() # 'b' — newline consumed, lineno advances + assert lexer.lineno == 3 + lexer.next_token() # 'c' — EOF as lookahead, no newline consumed + assert lexer.lineno == 3 + + +class TestWhitespace: + def test_spaces_skipped(self): + assert tokens(' a b ') == ['a', 'b'] + + def test_tabs_skipped(self): + assert tokens('\ta\tb') == ['a', 'b'] + + def test_newlines_skipped(self): + assert tokens('a\nb') == ['a', 'b'] + + +class TestEOF: + def test_empty_source(self): + assert tokens('') == [] + + def test_only_whitespace(self): + assert tokens(' \n\t ') == [] diff --git a/uv.lock b/uv.lock new file mode 100644 index 0000000..ab613af --- /dev/null +++ b/uv.lock @@ -0,0 +1,8 @@ +version = 1 +revision = 3 +requires-python = ">=3.13" + +[[package]] +name = "compiler" +version = "0.1.0" +source = { virtual = "." }