135 lines
3.8 KiB
Python
Executable File
135 lines
3.8 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
|
|
from typing import List
|
|
from pwn import process, context
|
|
|
|
# Only let pwntools print messages at "error" level or above, so the test
# output is limited to this script's own PASS/FAIL lines.
context.log_level = "error"

# Names of the test cases to run. Each name N is used for the input file
# "N.in", the captured output "N.out" and the expected answer "N.ans".
cases_to_run = [str(number) for number in range(1, 8)]

# Length, in seconds, of a single recvrepeat() read window.
TIMEOUT_SECONDS = 0.01
|
|
|
|
|
|
def wrap_recvrepeat(r, timeout=None):
    """Read whatever *r* produces within a single short timeout window.

    Args:
        r: A pwntools tube; in this script, the ``process`` started by
            ``execute_process``.
        timeout: Seconds to wait for output. ``None`` (the default) means
            ``TIMEOUT_SECONDS``.

    Returns:
        The bytes received during the window, or ``b""`` if nothing arrived.
    """
    # Deliberately no "has the process exited?" short-circuit here. Output
    # written just before the process exited can still be sitting in the
    # pipe. pwntools' recvrepeat() reads until timeout *or EOF* and returns
    # what it buffered, so calling it after exit drains that data instead of
    # dropping it.
    if timeout is None:
        timeout = TIMEOUT_SECONDS
    return r.recvrepeat(timeout)
|
|
|
|
|
|
def recvrepeats(r):
    """Wait for the process to produce output, then read until it goes quiet.

    First, *r* is polled in short ``wrap_recvrepeat`` windows until a
    non-empty chunk arrives or the process exits. Then reading continues
    until one window comes back empty. Note that while the process stays
    alive without printing anything, this waits indefinitely.

    Args:
        r: The pwntools ``process`` being driven.

    Returns:
        All bytes read. This is ``b""`` if the process exited without
        producing any further output.
    """
    output = wrap_recvrepeat(r)

    # Phase 1: wait until something shows up.
    while output == b"":
        if r.poll() is not None:
            # The process exited after our last read. Anything it wrote just
            # before exiting may still be in the pipe, so drain it once.
            # recvrepeat returns the buffered data at EOF rather than raising.
            output = r.recvrepeat(TIMEOUT_SECONDS)
            break
        output = wrap_recvrepeat(r)

    # Phase 2: keep reading until a whole timeout window yields nothing.
    chunks = []
    while output != b"":
        chunks.append(output)
        output = wrap_recvrepeat(r)
    return b"".join(chunks)
|
|
|
|
|
|
def execute_process(
    case: str, command: List[str], stdin: List[str]
) -> tuple[int, bytes]:
    """Run *command* interactively, feeding it *stdin* one line at a time.

    Before each line is sent, all output produced so far is collected. The
    line itself is then appended to the transcript, so the result reads like
    an interactive terminal session.

    Args:
        case: Test-case label, used only in the progress message.
        command: argv of the program to run (no shell is involved).
        stdin: Lines to send, sent exactly as given (including any newline).

    Returns:
        ``(status, transcript)``:

        * *status* is 0 on success, or 1 if an exception was raised (the
          transcript is then ``b""``). This is the harness result, not the
          child's exit code.
        * *transcript* is the process output interleaved with the input
          lines. pwntools' ``process`` merges stderr into stdout by default.
    """
    print(f"Running case {case} with command: {command}")
    r = None
    try:
        r = process(command, shell=False)
        chunks = []
        for line in stdin:
            data = line.encode("utf-8")
            chunks.append(recvrepeats(r))
            # The line is recorded in the transcript even when it is not sent
            # because the process has already exited.
            chunks.append(data)
            if r.poll() is None:  # Only send if the process is still running
                try:
                    r.send(data)
                except EOFError:
                    # The process exited between poll() and send(). Keep the
                    # transcript gathered so far instead of discarding it.
                    pass
        chunks.append(recvrepeats(r))
    except Exception as e:
        print(f"Error: {e}")
        return 1, b""
    finally:
        # Always reap the child, even when an error interrupted the session.
        if r is not None:
            r.close()
    return 0, b"".join(chunks)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
# Clean up the diff file
|
|
with open("diff.txt", "w") as f:
|
|
f.write("")
|
|
|
|
for case in cases_to_run:
|
|
|
|
with open(f"{case}.in", "r") as f:
|
|
lines = f.readlines()
|
|
run_command: List[str] = lines[0].split()
|
|
input = lines[1:]
|
|
|
|
_, output = execute_process(case, run_command, input)
|
|
|
|
# Remove the last prompt
|
|
if output.endswith(b"(sdb) "):
|
|
output = output[:-6]
|
|
|
|
# Remove null bytes
|
|
output = output.replace(b"\x00", b"")
|
|
|
|
# Write the output to a file
|
|
with open(f"{case}.out", "wb") as f:
|
|
f.write(output)
|
|
|
|
diff_command = f"diff -w -B -u {case}.out {case}.ans"
|
|
diff_process = process(diff_command, shell=True)
|
|
diff_output = diff_process.recvall()
|
|
diff_process.close()
|
|
|
|
diff_lines = diff_output.decode("utf-8").split("\n")
|
|
diff_lines = [
|
|
line for line in diff_lines if line.startswith("-") or line.startswith("+")
|
|
]
|
|
diff_lines = [line for line in diff_lines if not line.startswith("---")]
|
|
diff_lines = [line for line in diff_lines if not line.startswith("+++")]
|
|
|
|
i = 0
|
|
while True:
|
|
if i + 1 >= len(diff_lines):
|
|
break
|
|
|
|
if "-$rbp" in diff_lines[i] and "+$rbp" in diff_lines[i + 1]:
|
|
output_line = diff_lines.pop(i)[1:].split()
|
|
expected_line = diff_lines.pop(i)[1:].split()
|
|
|
|
if len(output_line) != 6:
|
|
diff_lines.append(f"error")
|
|
break
|
|
|
|
output_rbp = int(output_line[1], 16)
|
|
output_rsp = int(output_line[3], 16)
|
|
output_r8 = int(output_line[5], 16)
|
|
expected_rbp = int(expected_line[1], 16)
|
|
expected_rsp = int(expected_line[3], 16)
|
|
expected_r8 = int(expected_line[5], 16)
|
|
|
|
if (
|
|
output_rbp - output_rsp != expected_rbp - expected_rsp
|
|
or output_r8 != expected_r8
|
|
):
|
|
diff_lines.append(f"error")
|
|
break
|
|
|
|
continue
|
|
|
|
i += 1
|
|
|
|
# Print the diff output if there is a difference
|
|
print(f"Case {case}: {'PASS' if len(diff_lines) == 0 else 'FAIL'}", end="\n\n")
|
|
|
|
# Print the diff output to `diff.txt`
|
|
if len(diff_lines) > 0:
|
|
with open("diff.txt", "a") as f:
|
|
f.write(diff_output.decode("utf-8"))
|
|
f.write("\n\n")
|