import sys
import os

# An assembler for the Game Boy CPU (the Sharp LR35902, a Z80-like chip).

# Used to track files we've included and prevent recursive includes.
# None until the command-line loop creates a fresh set for each top-level file.
included_names = None

# Expression operators for the {...} evaluator in Block, grouped by
# precedence level. Each maps the operator text to a function that computes
# its integer result.
EVAL_UNOPS={
    "-":(lambda a:(-a)),
    "~":(lambda a:(~a)),
    "!":(lambda a:(int(not a))),
    "<":(lambda a:(a & 0xFF)),          # low byte
    ">":(lambda a:((a>>8) & 0xFF))      # high byte
}

EVAL_ADDOPS={
    "+":(lambda a, b: (a+b)),
    "-":(lambda a, b: (a-b))
}

EVAL_MULOPS={
    "*":(lambda a, b: (a*b)),
    # Floor division: on Python 3 '/' returns a float, which cannot be pushed
    # as a byte or re-parsed from an !equ. '//' matches Python 2's int '/'.
    "/":(lambda a, b: (a//b)),
    "%":(lambda a, b: (a%b))
}

EVAL_SHIFTOPS={
    "<<":(lambda a, b: (a<<b)),
    ">>":(lambda a, b: (a>>b))
}

EVAL_BITOPS={
    "&":(lambda a, b: (a&b)),
    "|":(lambda a, b: (a|b)),
    "^":(lambda a, b: (a^b))
}

def FindBinop(expr, i, oplist):
    """Return the operator from ``oplist`` that starts at ``expr[i]``, or None.

    ``i`` is None when the caller has already consumed the whole expression,
    which means there is no further operator. This must be checked explicitly:
    ``str.startswith(name, None)`` would test from the start of the string and
    could report a bogus operator (e.g. the leading '-' of "-1-2").
    """
    if i is None:
        return None
    for name in oplist:
        if expr.startswith(name, i):
            return name
    return None

# Every operator/grouping token text. Block.evalValue stops scanning a number
# or name at any character found in this set.
EVAL_ALLOPS = {"(", ")", "{", "}"}.union(
    EVAL_UNOPS, EVAL_ADDOPS, EVAL_MULOPS, EVAL_SHIFTOPS, EVAL_BITOPS)

class Block:
    """One relocatable chunk of ROM, assembled from a single source file.

    Offsets kept in ``symbols`` and the reloc tables are relative to the start
    of ``data``; ``base`` (from !orig, or assigned by the linker) turns them
    into CPU addresses and ``bank`` (from !bank) selects the ROM bank.
    """

    def __init__(self, name):
        self.name = name
        # Most recent ordinary (non-dot) label; it scopes ".label" names.
        self.current_label = None
        # Address of data[0]; set by !orig or when the linker places the block.
        self.base = None
        # ROM bank number; set by !bank.
        self.bank = None
        self.has_error = False
        self.data = bytearray()
        # Symbol name -> offset into data.
        self.symbols = {}
        # Symbol name -> offsets in data to patch once addresses are known:
        # full 16-bit address, high byte, low byte, signed 8-bit jump offset.
        self.relocs16 = {}
        self.relocs8h = {}
        self.relocs8l = {}
        self.relocsRel = {}
        # (offset, length) of every !ascii string, for the symbol file.
        self.ascii = []
        # !equ / !set name -> replacement text.
        self.equ = {}

    def error(self, e):
        """Report error ``e`` for this block and mark the block as failed."""
        print(self.name + " ERROR: " + e)
        self.has_error = True

    def _qualify(self, l):
        """Return the symbol-table name for label ``l``, or None on error.

        Dot-labels are local to the preceding ordinary label and are stored
        as "." + parent + l (e.g. ".main.loop").
        """
        if not l:
            self.error("Empty label name.")
            return None
        if l[0] != '.':
            return l
        if self.current_label is None:
            self.error("Dot-label used without a previous label.")
            return None
        return "." + self.current_label + l

    def placeLabel(self, l):
        """Define label ``l`` at the current end of the block."""
        name = self._qualify(l)
        if name is None:
            return
        # Check the qualified name so a repeated dot-label under the same
        # parent is caught too.
        if name in self.symbols:
            self.error("Symbol " + l + " defined twice.")
            return
        if l[0] != '.':
            self.current_label = l
        self.symbols[name] = len(self.data)

    def pushByte(self, b):
        """Append the low 8 bits of ``b``."""
        self.data.append(b & 0xFF)

    def pushBlob(self, d):
        """Append every item of ``d`` as a byte (low 8 bits of int(item))."""
        for b in d:
            self.data.append(int(b) & 0xFF)

    def pushShort(self, s):
        """Append ``s`` as a little-endian 16-bit value."""
        self.pushByte(s)
        self.pushByte(s >> 8)

    def pushCode(self, c):
        """Append an opcode; values >= 0x100 are two bytes, high byte first."""
        if c >= 0x100:
            self.pushByte(c >> 8)
        self.pushByte(c)

    def _pushReloc(self, l, relocs):
        """Record that the bytes about to be pushed must be patched with ``l``."""
        name = self._qualify(l)
        if name is None:
            return
        relocs.setdefault(name, []).append(len(self.data))

    def pushReloc16(self, l):
        """Push a 16-bit placeholder for the address of ``l``."""
        self._pushReloc(l, self.relocs16)
        self.pushShort(0)

    def pushReloc8h(self, l):
        """Push a byte placeholder for the high byte of ``l``'s address."""
        self._pushReloc(l, self.relocs8h)
        self.pushByte(0)

    def pushReloc8l(self, l):
        """Push a byte placeholder for the low byte of ``l``'s address."""
        self._pushReloc(l, self.relocs8l)
        self.pushByte(0)

    def pushRelocRel(self, l):
        """Push a byte placeholder for a relative jump offset to ``l``."""
        self._pushReloc(l, self.relocsRel)
        self.pushByte(0)

    def pushASCII(self, s):
        """Append ``s`` one byte per character and remember its span."""
        self.ascii.append((len(self.data), len(s)))
        for c in s:
            self.pushByte(ord(c))

    def resolveReloc(self, l, addr, all_blocks):
        """Return the absolute address of symbol ``l`` (0 after an error).

        Names starting with '$' or '.' are looked up in this block only;
        others in every block of ``all_blocks``. ``addr`` is unused and kept
        for interface compatibility.
        """
        if l[0] == '$' or l[0] == '.':
            candidates = [self]
        else:
            candidates = all_blocks
        owner = None
        for b in candidates:
            if l in b.symbols:
                if owner is not None:
                    self.error("Symbol " + l + " defined in two different modules")
                owner = b
        if owner is None:
            self.error("Symbol " + l + " was not defined")
            return 0
        if owner.base is None:
            # The owning block failed and was never given an address.
            self.error("Symbol " + l + " is in block " + owner.name + ", which was never placed")
            return 0
        return owner.symbols[l] + owner.base

    def resolveRelocs(self, all_blocks):
        """Patch every recorded relocation now that all bases are known."""
        for l, addrs in self.relocs16.items():
            for addr in addrs:
                dest = self.resolveReloc(l, addr, all_blocks)
                self.data[addr] = dest & 0xFF
                self.data[addr + 1] = (dest >> 8) & 0xFF
        for l, addrs in self.relocs8h.items():
            for addr in addrs:
                dest = self.resolveReloc(l, addr, all_blocks)
                self.data[addr] = (dest >> 8) & 0xFF
        for l, addrs in self.relocs8l.items():
            for addr in addrs:
                dest = self.resolveReloc(l, addr, all_blocks)
                self.data[addr] = dest & 0xFF
        for l, addrs in self.relocsRel.items():
            for addr in addrs:
                target = self.resolveReloc(l, addr, all_blocks)
                # The offset is relative to the address just after the
                # offset byte (the next instruction) and is a signed byte.
                offset = target - (self.base + addr + 1)
                if offset < -128 or offset > 127:
                    self.error("Relative jump is too far")
                else:
                    self.data[addr] = offset & 0xFF

    def evalBinop(self, expr, i, oplist, func):
        """Parse ``operand (op operand)*`` left-associatively at one level.

        ``func`` parses an operand at the next-tighter precedence level.
        Returns (next index, value); the index is None once the whole
        expression has been consumed.
        """
        i, val = func(self, expr, i)
        while i is not None:
            name = FindBinop(expr, i, oplist)
            if not name:
                break
            i += len(name)
            if i == len(expr):
                self.error("Expected operand for " + name)
                return (None, 0)
            i, rhs = func(self, expr, i)
            try:
                val = oplist[name](val, rhs)
            except ZeroDivisionError:
                self.error("Division by zero in expression")
                val = 0
        return (i, val)

    # Precedence, loosest to tightest (note: unlike C, the bitwise operators
    # bind tightest): + -, then * / %, then << >>, then & | ^.
    def evalExpr(self, expr, i):
        return self.evalBinop(expr, i, EVAL_ADDOPS, Block.evalTerm)

    def evalTerm(self, expr, i):
        return self.evalBinop(expr, i, EVAL_MULOPS, Block.evalFactor)

    def evalFactor(self, expr, i):
        return self.evalBinop(expr, i, EVAL_SHIFTOPS, Block.evalShiftFactor)

    def evalShiftFactor(self, expr, i):
        return self.evalBinop(expr, i, EVAL_BITOPS, Block.evalValue)

    def evalValue(self, expr, i):
        """Parse a unary operator, a parenthesised expression, or a literal.

        Names are replaced through !equ definitions until an integer literal
        is reached. Returns (next index or None at end, value).
        """
        if i >= len(expr):
            self.error("Expected operand at end of expression")
            return (None, 0)
        unop = EVAL_UNOPS.get(expr[i], None)
        if unop:
            i, val = self.evalValue(expr, i + 1)
            return (i, unop(val))
        if expr[i] == '(':
            i, val = self.evalExpr(expr, i + 1)
            if i is None or expr[i] != ')':
                self.error("Expected ')' after expression")
                return (i, val)
            i += 1
            return (None if i == len(expr) else i, val)
        start = i
        while i < len(expr) and expr[i] not in EVAL_ALLOPS:
            i += 1
        key = expr[start:i]
        # Follow !equ chains, stopping if a name repeats (cyclic definitions).
        seen = set()
        while key in self.equ and key not in seen:
            seen.add(key)
            key = self.equ[key]
        try:
            val = int(key, 0)
        except ValueError:
            self.error("Expected only integers in expression")
            val = 0
        return (None if i == len(expr) else i, val)

    def eval(self, expr):
        """Evaluate ``expr`` (the text between braces) and return its value."""
        i, val = self.evalExpr(expr, 0)
        if i is not None:
            self.error("Unexpected '" + expr[i:] + "' in expression")
        return val

# Mnemonics that take no operands, mapped to their opcode. Values >= 0x100
# are two-byte encodings, pushed high byte first by Block.pushCode.
ZERO_ARGS = {
    "nop":0,
    # Standard spellings plus the original "rcla"/"rcra", kept so existing
    # sources still assemble.
    "rlca":0x07,
    "rcla":0x07,
    "rrca":0x0F,
    "rcra":0x0F,
    "rla":0x17,
    "rra":0x1F,
    "daa":0x27,
    "cpl":0x2F,
    "scf":0x37,
    "ccf":0x3F,
    "halt":0x76,
    "ret":0xC9,
    "retnz":0xC0,
    "retz":0xC8,
    "retnc":0xD0,
    "reti":0xD9,
    "retc":0xD8,
    "ei":0xFB,
    "di":0xF3,
    "stop":0x1000,  # emitted as 0x10 0x00
    "rst00":0xC7,
    "rst0":0xC7,
    "rst08":0xCF,
    "rst8":0xCF,
    "rst10":0xD7,
    "rst18":0xDF,
    "rst20":0xE7,
    "rst28":0xEF,
    "rst30":0xF7,
    "rst38":0xFF
}

# Operand spellings treated as registers (writeOpcode lower-cases an operand
# before checking it against this tuple) rather than numbers or labels.
REGISTERS = (
    # 8-bit registers
    "h", "l", "a", "f", "b", "c", "d", "e",
    # 16-bit registers and register pairs
    "sp", "pc", "hl", "af", "bc", "de",
    # Register-indirect forms
    "(bc)", "(de)", "(hl)", "(hl+)", "(hl-)",
    "(c)",  # This is weird, but it does exist.
)

# Mnemonics that take operands. Each operand selects one level of the nested
# dicts until an opcode is reached. Keys are register spellings (see
# REGISTERS) or immediate/address slots such as "d8", "d16", "(a8)", "(a16)",
# "i16" and "r8" (matched by locateOperand).
SOME_ARGS = {
    "jr":{ "r8":0x18 },
    "jrz":{ "r8":0x28 },
    "jrc":{ "r8":0x38 },
    "jrnz":{ "r8":0x20 },
    "jrnc":{ "r8":0x30 },
    "jp":{ "i16":0xC3, "hl":0xE9, "(hl)":0xE9 },
    "jpz":{ "i16":0xCA },
    "jpc":{ "i16":0xDA },
    "jpnz":{ "i16":0xC2 },
    "jpnc":{ "i16":0xD2 },
    "call":{ "i16":0xCD },
    "callz":{ "i16":0xCC },
    "callc":{ "i16":0xDC },
    "callnz":{ "i16":0xC4 },
    "callnc":{ "i16":0xD4 },
    "add":{
        "b":0x80,
        "c":0x81,
        "d":0x82,
        "e":0x83,
        "h":0x84,
        "l":0x85,
        "(hl)":0x86,
        "a":0x87,
        "hl":{
            "bc":0x09,
            "de":0x19,
            "hl":0x29,
            "sp":0x39
        },
        "sp":{ "r8":0xE8 },  # ADD SP,r8 (0xE9 is JP HL)
        "d8":0xC6
    },
    "adc":{
        "b":0x88,
        "c":0x89,
        "d":0x8A,
        "e":0x8B,
        "h":0x8C,
        "l":0x8D,
        "(hl)":0x8E,
        "a":0x8F,
        "d8":0xCE
    },
    "sub":{
        "b":0x90,
        "c":0x91,
        "d":0x92,
        "e":0x93,
        "h":0x94,
        "l":0x95,
        "(hl)":0x96,
        "a":0x97,
        "d8":0xD6
    },
    "sbc":{
        "b":0x98,
        "c":0x99,
        "d":0x9A,
        "e":0x9B,
        "h":0x9C,
        "l":0x9D,
        "(hl)":0x9E,
        "a":0x9F,
        "d8":0xDE
    },
    "and":{
        "b":0xA0,
        "c":0xA1,
        "d":0xA2,
        "e":0xA3,
        "h":0xA4,
        "l":0xA5,
        "(hl)":0xA6,
        "a":0xA7,
        "d8":0xE6
    },
    "xor":{
        "b":0xA8,
        "c":0xA9,
        "d":0xAA,
        "e":0xAB,
        "h":0xAC,
        "l":0xAD,
        "(hl)":0xAE,
        "a":0xAF,
        "d8":0xEE
    },
    "or":{
        "b":0xB0,
        "c":0xB1,
        "d":0xB2,
        "e":0xB3,
        "h":0xB4,
        "l":0xB5,
        "(hl)":0xB6,
        "a":0xB7,
        "d8":0xF6
    },
    "cp":{
        "b":0xB8,
        "c":0xB9,
        "d":0xBA,
        "e":0xBB,
        "h":0xBC,
        "l":0xBD,
        "(hl)":0xBE,
        "a":0xBF,
        "d8":0xFE
    },
    "inc":{
        "bc":0x03,
        "de":0x13,
        "hl":0x23,
        "sp":0x33,

        "b":0x04,
        "d":0x14,
        "h":0x24,
        "(hl)":0x34,

        "c":0x0C,
        "e":0x1C,
        "l":0x2C,
        "a":0x3C
    },
    "dec":{
        "bc":0x0B,
        "de":0x1B,
        "hl":0x2B,
        "sp":0x3B,

        "b":0x05,
        "d":0x15,
        "h":0x25,
        "(hl)":0x35,

        "c":0x0D,
        "e":0x1D,
        "l":0x2D,
        "a":0x3D,
    },
    "pop":{
        "bc":0xC1,
        "de":0xD1,
        "hl":0xE1,
        "af":0xF1
    },
    "push":{
        "bc":0xC5,
        "de":0xD5,
        "hl":0xE5,
        "af":0xF5
    },
    "ld":{
        "bc":{ "d16":0x01 },
        "(bc)":{ "a":0x02 },
        "de":{ "d16":0x11 },
        "(de)":{ "a":0x12 },
        "hl":{ "d16":0x21 },
        "(hl+)":{ "a":0x22 },
        "(hl-)":{ "a":0x32 },
        "(a16)":{ "a":0xEA, "sp":0x08 },
        "sp":{ "d16":0x31, "hl":0xF9 }
    },
    "ldh":{
        "(a8)":{ "a": 0xE0 },
        "a":{ "(a8)": 0xF0, "(c)":0xF2 },
        "(c)":{ "a":0xE2 } # Weird, but true.
    }
}

# Generate the LD r,r' definitions, since it's easy to automate. Each
# destination row starts at its base opcode and the source register picks the
# column, in LD_ARGS order (ld b,b = 0x40, ld b,c = 0x41, ...). This replaces
# any existing SOME_ARGS["ld"][dest] table; extra forms are added afterwards.
LD_ARGS = (
    ("b", 0x40),
    ("c", 0x48),
    ("d", 0x50),
    ("e", 0x58),
    ("h", 0x60),
    ("l", 0x68),
    ("(hl)", 0x70),
    ("a", 0x78)
)
for o1, base in LD_ARGS:
    SOME_ARGS["ld"][o1] = dict(
        (o2, base + column)
        for column, (o2, _) in enumerate(LD_ARGS)
        # "ld (hl),(hl)" would land on 0x76, which is HALT, not a load.
        if not (o1 == "(hl)" and o2 == "(hl)")
    )

# CB-prefixed rotate/shift mnemonics, in encoding order.
CB_OPS = (
    "rlc",
    "rrc",
    "rl",
    "rr",
    "sla",
    "sra",
    "swap",
    "srl",
)

# Bit mnemonics; each is expanded to e.g. bit0..bit7.
BIT_OPS = ("bit", "res", "set")

# Every CB-prefixed mnemonic in encoding order: the shifts/rotates above, then
# bit0..bit7, res0..res7, set0..set7.
_cb_mnemonics = list(CB_OPS)
for _op in BIT_OPS:
    _cb_mnemonics.extend(_op + str(_n) for _n in range(0, 8))

# Opcodes count up from 0xCB00: one per register (in LD_ARGS order) for each
# mnemonic in turn.
for _index, _mnemonic in enumerate(_cb_mnemonics):
    _first = 0xCB00 + _index * len(LD_ARGS)
    SOME_ARGS[_mnemonic] = dict(
        (reg, _first + slot) for slot, (reg, _) in enumerate(LD_ARGS)
    )

# A few oddballs that would conflict with full overwriting above: they are
# merged into the per-register tables the LD generator just created.
for _dest, _source, _opcode in (
    ("a", "(a16)", 0xFA),

    ("a", "(bc)", 0x0A),
    ("a", "(de)", 0x1A),
    ("a", "(hl+)", 0x2A),
    ("a", "(hl-)", 0x3A),

    ("b", "d8", 0x06),
    ("d", "d8", 0x16),
    ("h", "d8", 0x26),
    ("(hl)", "d8", 0x36),

    ("c", "d8", 0x0E),
    ("e", "d8", 0x1E),
    ("l", "d8", 0x2E),
    ("a", "d8", 0x3E),
):
    SOME_ARGS["ld"][_dest][_source] = _opcode

def locateOperand(op, ptr, allow=(8,16,'r')):
    """Find which immediate/address key of opcode table ``op`` an operand fits.

    Candidate keys are <prefix><width>, with prefixes tried in the order
    'a', 'd', 'i' and, if 'r' is in ``allow``, 'r'. Widths "16" are tried when
    16 is allowed, then "8" when 8 is allowed (always for 'r'). When ``ptr``
    is true the parenthesised form "(key)" is also accepted.

    Returns (key, push function) with Block.pushShort for 16-bit keys and
    Block.pushByte for 8-bit ones, or (None, None) after printing the table.
    """
    prefixes = "adi" + ("r" if 'r' in allow else "")
    for prefix in prefixes:
        widths = []
        if 16 in allow:
            widths.append(("16", Block.pushShort))
        if 8 in allow or prefix == 'r':
            widths.append(("8", Block.pushByte))
        for width, pusher in widths:
            bare = prefix + width
            spellings = (bare, "(" + bare + ")") if ptr else (bare,)
            for key in spellings:
                if key in op:
                    return (key, pusher)

    print("Error in " + str(op))
    return (None, None)

def writeOpcode(mnemonic, block, ops, i, operands):
    """Emit the instruction chosen by ``operands[i:]`` from opcode table ``ops``.

    ``ops`` is either a finished opcode (an int, emitted with pushCode) or a
    dict keyed by register spelling or immediate/address slot (see
    locateOperand). Each operand drills one level down; the opcode is pushed
    by the innermost call, so any immediate value or relocation placeholder
    for an operand is pushed after the opcode. Errors are reported through
    ``block.error`` and stop emission for this instruction.
    """
    if not isinstance(ops, dict):
        # Reached a concrete opcode.
        if i == len(operands):
            block.pushCode(ops)
        else:
            block.error("Too many operands for " + mnemonic)
        return
    if i == len(operands):
        block.error("Missing operand for " + mnemonic)
        return

    operand = operands[i]
    if not operand:
        block.error("Empty operand for " + mnemonic)
        return

    if operand.lower() in REGISTERS:
        val = ops.get(operand.lower(), None)
        if val is None:
            block.error("(3) Invalid operand " + operand + " for " + mnemonic)
            return
        writeOpcode(mnemonic, block, val, i + 1, operands)
        return

    ptr = operand[0] == '(' and operand[-1] == ')'
    if ptr:
        operand = operand[1:-1]
        if not operand:
            block.error("Empty operand for " + mnemonic)
            return

    if operand[0] == '{' and operand[-1] == '}':
        # Evaluate expression.
        num = block.eval(operand[1:-1])
    else:
        # Substitute !equ names, keeping any <, >, = or @ prefix.
        if operand[0] in "<>=@":
            prefix, name = operand[0], operand[1:]
            operand = prefix + block.equ.get(name, name)
        else:
            operand = block.equ.get(operand, operand)
        if not operand:
            block.error("Empty operand for " + mnemonic)
            return
        try:
            if operand[0] in "<>=@":
                num = int(operand[1:], 0)
                if operand[0] == '>':
                    num >>= 8
            else:
                num = int(operand, 0)
        except ValueError:
            # Not a number: treat it as a label.
            num = None

    if num is not None:
        key, func = locateOperand(ops, ptr)
        if key is None:
            block.error("(1) Invalid operand " + operand + " for " + mnemonic)
            return
        writeOpcode(mnemonic, block, ops[key], i + 1, operands)
        func(block, num)
        return

    # Must be an address: the prefix selects the relocation kind and width.
    prefix = operand[0]
    if prefix == '>':
        pushReloc, allow = Block.pushReloc8h, [8]
    elif prefix == '<':
        pushReloc, allow = Block.pushReloc8l, [8]
    elif prefix == '@':
        pushReloc, allow = Block.pushRelocRel, ['r']
    else:
        pushReloc, allow = Block.pushReloc16, [16]

    if prefix in "=<>@":
        operand = operand[1:]
    if not operand:
        block.error("Missing label name for " + mnemonic)
        return

    key, func = locateOperand(ops, ptr, allow)
    if key is None:
        block.error("(2) Invalid operand " + operand + " for " + mnemonic)
        return
    writeOpcode(mnemonic, block, ops[key], i + 1, operands)
    pushReloc(block, operand)

def assemble(block, mnemonic, operands=()):
    """Assemble one statement (a directive or an instruction) into ``block``.

    ``mnemonic`` is matched case-insensitively; names starting with '!' are
    assembler directives, anything else is looked up in ZERO_ARGS (when there
    are no operands) or SOME_ARGS. ``operands`` are the already-normalised
    operand strings. Source mistakes are reported through ``block.error``;
    I/O errors from !include / !incbin propagate to the caller.
    """
    global included_names
    mnemonic = mnemonic.lower()

    if not mnemonic:
        block.error("Expected a mnemonic")
        return

    if mnemonic.startswith("!bank"):
        if block.bank is not None:
            block.error("!bank specified more than once")
            return

        if len(operands) != 1:
            block.error("!bank must have one integer argument")
            return

        try:
            bank = int(operands[0], 0)
        except ValueError:
            block.error("Invalid bank value " + operands[0])
            return
        block.bank = bank
        # Without this return the directive fell through to the
        # "Unknown directive" check below and failed the block.
        return

    # TODO: We should really split blocks with origin, not jam it in here.
    if mnemonic.startswith("!orig"):
        if block.base is not None:
            block.error(mnemonic + " specified more than once")
            return

        if len(operands) != 1:
            block.error(mnemonic + " must have one integer argument")
            return

        try:
            base = int(operands[0], 0)
        except ValueError:
            block.error("Invalid base value " + operands[0])
            return
        block.base = base
        return

    if mnemonic == "!equ" or mnemonic == "!set":
        if len(operands) < 2:
            block.error(mnemonic + " requires a name and a value")
            return
        name, value = operands[0], operands[1]
        if value[:1] == '{' and value[-1:] == '}':
            # Evaluate now; the result is stored as text like any other value.
            block.equ[name] = str(block.eval(value[1:-1]))
        else:
            block.equ[name] = value
        return

    if mnemonic == "!byte" or mnemonic == "!short":
        push = block.pushByte if mnemonic == "!byte" else block.pushShort
        for o in operands:
            try:
                value = int(o, 0)
            except ValueError:
                block.error("Invalid integer literal " + o)
                return
            push(value)
        return

    if mnemonic.startswith("!alig"):
        # !align N[, PAD]: pad with PAD (default 0) until the length is a
        # multiple of N.
        if len(operands) == 0:
            block.error("!align requires an operand.")
            return
        try:
            n = int(operands[0], 0)
        except ValueError:
            n = 0
        if n <= 0:
            block.error("Invalid alignment " + operands[0])
            return
        pad = 0
        if len(operands) > 1:
            try:
                pad = int(operands[1], 0)
            except ValueError:
                block.error("Invalid alignment padding " + operands[1])
                return
        while len(block.data) % n != 0:
            block.pushByte(pad)
        return

    if mnemonic.startswith("!check"):
        # Very lazy hack to automate ROM header checksums: checksum of the
        # preceding 25 bytes, computed as x = x - byte - 1.
        checksum = 0
        for b in block.data[-25:]:
            checksum = checksum - b - 1
        block.pushByte(checksum)
        return

    if mnemonic == "!logo":
        # Very lazy directive to write the NINTENDO logo data.
        block.pushBlob((
            0xCE,0xED,0x66,0x66,0xCC,0x0D,0x00,0x0B,0x03,0x73,0x00,0x83,0x00,0x0C,0x00,0x0D,
            0x00,0x08,0x11,0x1F,0x88,0x89,0x00,0x0E,0xDC,0xCC,0x6E,0xE6,0xDD,0xDD,0xD9,0x99,
            0xBB,0xBB,0x67,0x63,0x6E,0x0E,0xEC,0xCC,0xDD,0xDC,0x99,0x9F,0xBB,0xB9,0x33,0x3E))
        return

    if mnemonic == "!ascii":
        block.pushASCII("".join(operands))
        return

    if mnemonic == "!image" or mnemonic.startswith("!sprit"):
        # Each operand is one 8-pixel row; it becomes two bytes: the low bit
        # of every pixel, then the high bit.
        PIXEL_CONSTS = { ".": 0, "0":0, "1":1, "2":2, "3":3, "x":3, "X":3 }
        for o in operands:
            if len(o) != 8:
                block.error("!image directive must be followed by 8 characters.")
                continue
            pixels = bytearray()
            for c in o:
                b = PIXEL_CONSTS.get(c, None)
                if b is None:
                    block.error("!image characters must be in 0-3")
                    b = 0
                pixels.append(b)
            byte = 0
            for pix in pixels:
                byte = (byte << 1) | (pix & 1)
            block.pushByte(byte)
            byte = 0
            for pix in pixels:
                byte = (byte << 1) | ((pix >> 1) & 1)
            block.pushByte(byte)
        return

    if mnemonic == "!include":
        if included_names is None:
            included_names = set()
        for o in operands:
            if o in included_names:
                block.error("Recursive include of " + o)
                continue
            included_names.add(o)
            try:
                assembleFile(block, o)
            except Exception:
                print("Error in " + o)
                raise
            finally:
                included_names.remove(o)
        return

    if mnemonic == "!incbin":
        for o in operands:
            with open(o, "rb") as f:
                # bytearray iterates as ints on both Python 2 and 3.
                block.pushBlob(bytearray(f.read()))
        return

    if mnemonic.startswith('!'):
        block.error("Unknown directive " + mnemonic)
        return

    # Handle zero-operand mnemonics
    if len(operands) == 0:
        code = ZERO_ARGS.get(mnemonic, None)
        if code is None:
            block.error("Unknown zero-operand opcode " + mnemonic)
        else:
            block.pushCode(code)
        return

    code = SOME_ARGS.get(mnemonic, None)
    if code is None:
        block.error("Unknown opcode " + mnemonic)
    else:
        writeOpcode(mnemonic, block, code, 0, operands)

# Characters that may appear in a label or mnemonic ('!' starts directives).
LABEL_CHARS = ("!"
               + "abcdefghijklmnopqrstuvwxyz"
               + "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
               + ".$@"
               + "1234567890"
               + "_")

# One Block per successfully assembled source file; the linker code below
# sorts and places them.
blocks = []

def _splitUnquoted(text, sep):
    """Split ``text`` on the character ``sep``, except inside "double quotes".

    Quote characters stay in the pieces so later stages can still see them.
    """
    pieces = []
    current = ""
    in_quote = False
    for c in text:
        if c == '"':
            in_quote = not in_quote
        elif c == sep and not in_quote:
            pieces.append(current)
            current = ""
            continue
        current += c
    pieces.append(current)
    return pieces


def _parseOperand(raw):
    """Normalise one raw operand.

    Strips the quotes (keeping their contents verbatim), drops unquoted
    whitespace and rewrites [ ] as ( ).
    """
    op = ""
    in_quote = False
    for c in raw:
        if in_quote:
            if c == '"':
                in_quote = False
            else:
                op += c
        elif c == '"':
            in_quote = True
        elif c == '[':
            op += "("
        elif c == ']':
            op += ")"
        elif c not in " \t\v\r":
            op += c
    return op


def _assembleStatement(block, part):
    """Assemble one stripped '&'-separated statement.

    A leading ``label:`` is placed, and whatever follows it is assembled too.
    """
    while part:
        # Lex out the mnemonic, or label.
        end = 0
        while end < len(part) and part[end] in LABEL_CHARS:
            end += 1
        mnemonic = part[:end]
        if end < len(part) and part[end] == ':':
            if not mnemonic:
                block.error("Expected a label name before ':'")
            else:
                block.placeLabel(mnemonic)
            part = part[end + 1:].strip()
            continue
        rest = part[end:].strip()
        if rest:
            operands = [_parseOperand(raw) for raw in _splitUnquoted(rest, ",")]
            assemble(block, mnemonic, operands)
        else:
            assemble(block, mnemonic)
        return


def assembleFile(block, path):
    """Assemble every statement in the source file at ``path`` into ``block``.

    ';' starts a comment and '&' separates statements on one line; neither is
    special inside "double quotes". Operands are comma-separated.
    """
    with open(path, "r") as f:
        for line in f:
            code = _splitUnquoted(line, ";")[0]
            for part in _splitUnquoted(code, "&"):
                _assembleStatement(block, part.strip())

    # Pad to an even length with 0xFF, but only for a top-level file. An
    # !include'd file shares its parent's block and is listed in
    # included_names while it is assembled; padding there would insert a
    # stray byte into the middle of the parent's code.
    if included_names is None or path not in included_names:
        if len(block.data) & 1:
            block.pushByte(0xFF)

# Default rom size.
romsize = 0x8000
def setRomSize(n):
    """Set the output ROM size from option text ``n`` (int literal, base 0)."""
    # Without 'global' this only assigned a local and the option had no effect.
    global romsize
    romsize = int(n, 0)

# Highest bank index used by the linker; None means "derive it from !bank".
num_banks = None
def setNumBanks(n):
    """Set num_banks from option text ``n`` (int literal, base 0)."""
    # Without 'global' this only assigned a local and the option had no effect.
    global num_banks
    num_banks = int(n, 0)

# Command-line cursor: index of the next sys.argv entry to examine. It is
# shared (as a global) with MaybeParseArgument.
num_args = len(sys.argv)
i = 1

# Recognised options: (short flag, long flag, handler for the value, help text).
OPTIONS = (
    ("-s", "--size", setRomSize, "Sets the rom size."),
    ("-b", "--banks", setNumBanks, "Sets the number of banks."),
)

def MaybeParseArgument(short, long, proc, _help):
    """Try to consume the option at ``sys.argv[i]`` (``i`` is a global).

    Accepts "SHORT VALUE", "LONG VALUE", "SHORT=VALUE" and "LONG=VALUE".
    On success calls ``proc(VALUE)``, leaves ``i`` on the last argv entry
    consumed and returns True. Returns False, with ``i`` unchanged, when the
    entry is not this option, when the value is missing or empty, or when
    ``proc`` raises.
    """
    global i
    assert i < num_args
    start = i
    opt = sys.argv[i]
    if opt == short or opt == long:
        if i + 1 >= num_args:
            print(opt + " requires an argument")
            return False
        i += 1
        val = sys.argv[i]
    elif opt.startswith(short + "="):
        val = opt[len(short) + 1:]
    elif opt.startswith(long + "="):
        val = opt[len(long) + 1:]
    else:
        return False

    if not val:
        print(opt + " requires an argument")
        i = start
        return False
    try:
        proc(val)
    except Exception:
        # Leave the cursor where it was so the caller reports this option.
        i = start
        return False
    return True

def usage():
    """Print the banner, the command-line syntax and help for every option."""
    lines = [
        "Emily's Gameboy Assembler",
        "USAGE: " + sys.argv[0] + " [options] <files>",
        "OPTIONS:",
        "  --help|-h\tDisplays this message then exits.",
    ]
    lines += ["  %s|%s\t%s" % (short_flag, long_flag, text)
              for short_flag, long_flag, _proc, text in OPTIONS]
    print("\n".join(lines))

# Parse leading options; the first non-option argument (or "--") ends them.
while i < num_args:
    arg = sys.argv[i]
    if arg == "--":
        # "--" only marks the end of the options; it is not a source file.
        i += 1
        break
    if not arg.startswith("-"):
        break

    if arg in ("--help", "-h"):
        usage()
        sys.exit(0)

    # any() stops at the first option that accepts the argument.
    found = any(MaybeParseArgument(*opt) for opt in OPTIONS)
    if not found:
        print("Error in option '" + arg + "'")
        usage()
        sys.exit(1)

    i += 1

# Every remaining argument is a source file, assembled into its own Block.
# Each top-level file starts with an empty include-tracking set.
while i < num_args:
    source_path = sys.argv[i]
    i += 1
    included_names = set()
    source_block = Block(source_path)
    try:
        assembleFile(source_block, source_path)
    except BaseException:
        print("Error in " + source_path)
        raise
    blocks.append(source_block)

# Nothing to link if no source files were given.
if not blocks:
    sys.exit(0)

# Put blocks without !bank into bank 0 and validate explicit !bank placements:
# bank 0 is 0x0000-0x3FFF, other banks 0x4000-0x7FFF. Also find the highest
# bank in use, for when -b/--banks was not given.
max_bank = 0
for b in blocks:
    if b.bank is None:
        b.bank = 0
        continue
    if b.bank < 0 or (num_banks is not None and b.bank > num_banks):
        b.error("Invalid bank " + str(b.bank))
        continue
    max_bank = max(max_bank, b.bank)
    if b.base is None:
        b.error("!bank specified, but no !orig")
        continue
    end = b.base + len(b.data)  # one past the block's last byte
    if b.bank == 0:
        if b.base >= 0x4000:
            b.error("Invalid base for bank 0")
        elif end > 0x4000:
            b.error("Block is too large to fit into bank 0")
    elif b.base < 0x4000 or b.base >= 0x8000:
        b.error("Invalid base for bank " + str(b.bank))
    elif end > 0x8000:
        b.error("Block is too large to fit into bank " + str(b.bank))

if num_banks is None:
    num_banks = max_bank

# One list of (base, length) spans per bank. The comprehension gives every
# bank its own list; [[]] * n would make all banks share a single list.
bank_spans = [[] for _ in range(num_banks + 1)]

# First, place all already originated blocks.
for b in blocks:
    if b.base is None or b.has_error:
        continue
    b_len = len(b.data)
    spans = bank_spans[b.bank]
    for s_base, s_len in spans:
        # Half-open ranges [start, start + len) overlap only when each one
        # starts before the other ends; blocks that merely touch are fine.
        if s_base < b.base + b_len and b.base < s_base + s_len:
            b.error("Conflicting origin")
    if not b.has_error:
        spans.append((b.base, b_len))

# Make sure every bank has at least one span for the fitting pass to start
# from. If bank 0 has no originated blocks, the first file given on the
# command line is placed at address 0.
first = blocks[0]
for bank, spans in enumerate(bank_spans):
    if spans:
        spans.sort(key=lambda a: a[0])
    elif bank == 0 and first.base is None and first.bank == 0 and not first.has_error:
        first.base = 0
        spans.append((0, len(first.data)))
    else:
        # Zero-length marker at the start of the bank's address window.
        spans.append((0 if bank == 0 else 0x4000, 0))

blocks.sort(key=lambda b: len(b.data))
# Place the remaining blocks, smallest first: each goes into the first gap
# between placed spans that can hold it, or after the last span. A simple
# first-fit, nothing clever.
for b in blocks:
    if b.has_error or b.base is not None:
        continue
    b_len = len(b.data)
    spans = bank_spans[b.bank]
    i = 0
    while i + 1 < len(spans):
        gap = spans[i + 1][0] - (spans[i][0] + spans[i][1])
        if gap >= b_len:  # an exact fit still fits
            break
        i += 1
    b.base = spans[i][0] + spans[i][1]
    spans.append((b.base, b_len))
    spans.sort(key=lambda s: s[0])

# Order blocks by bank, then address. Blocks that failed may have no base
# (or bank); they are skipped below, so they only need a key that does not
# crash.
blocks.sort(key=lambda b: ((b.bank or 0) << 16) + (b.base or 0))
out = open("a.gb", "wb")
debug = open("a.sym", "w")
debug.write("; Emily's Assembler symbol info\n")
at = 0  # bytes written to a.gb so far
zero = bytearray([0])
zero16 = bytearray([0] * 16)

# Resolve all relocs.
for b in blocks:
    if b.has_error:
        continue
    b.resolveRelocs(blocks)

# Copy each block to its offset in the ROM file (zero-filling any gap before
# it) and write its symbols to the debug file.
for b in blocks:
    if b.has_error:
        continue
    bank = b.bank or 0
    if bank > 0:
        # Banks >= 1 are addressed at 0x4000-0x7FFF; bank N occupies file
        # offsets N*0x4000 onwards.
        base = b.base - 0x4000 + 0x4000 * bank
    else:
        # Bank 0 / unbanked code maps straight onto the file.
        base = b.base
    if base < at:
        b.error("Block overlaps data already written to the ROM")
        continue

    # This is just to make writing a little faster
    while at + 16 <= base:
        out.write(zero16)
        at += 16
    while at < base:
        out.write(zero)
        at += 1

    # Write the binary info
    out.write(b.data)
    at += len(b.data)
    # Write the debug symbol info as "BB:AAAA name" lines.
    bank_name = hex(bank)[2:].upper().rjust(2, '0')
    for sym in b.symbols:
        if sym[0] == '.':
            continue  # dot-labels are not written to the symbol file
        elif sym[0] == '$':
            name = b.name + sym
        else:
            name = sym
        addr = hex(b.symbols[sym] + b.base)[2:].upper().rjust(4, '0')
        debug.write(bank_name + ":" + addr + " " + name.replace('.', '_') + "\n")

    for a in b.ascii:
        addr = hex(a[0] + b.base)[2:].upper().rjust(4, '0')
        asc = hex(a[1])[2:].upper().rjust(4, '0')
        debug.write(bank_name + ":" + addr + " .asc:" + asc + "\n")

# Report usage, then zero-fill the image up to the configured ROM size.
percent = int(at / (romsize / 100.0)) if romsize else 0
print("ROM used " + str(at) + " bytes of " + str(romsize) + " (" + str(percent) + "%)")

while at < romsize - 16:
    out.write(zero16)
    at += 16

while at < romsize:
    out.write(zero)
    at += 1

out.close()
debug.close()

# Exit status 2 if any block reported an error.
if any(b.has_error for b in blocks):
    sys.exit(2)
sys.exit(0)
