Push It To The Limit - WACON 2023 Prequal
Challenge Files
from Crypto.Util.number import bytes_to_long, getStrongPrime
# Challenge generator: RSA-encrypt the flag and leak the upper half of p.
SIZE = 1024
e = 0x10001

with open("flag.txt", "rb") as f:
    m = bytes_to_long(f.read())

p = getStrongPrime(SIZE)
q = getStrongPrime(SIZE)
n = p * q
c = pow(m, e, n)

# Clear the low SIZE // 2 bits of p, so only its upper bits are published.
p_msb = (p >> (SIZE // 2)) << (SIZE // 2)

print(f"{n = }")
print(f"{c = }")
print(f"{p_msb = }")
n = 24712135189687942739677490021030751776088469214818275631687482073531676912880823269667196936095460153002434759403063429337125873794523587731746689517070810687221399532024093572951282737818446579992570629531618780373767724789390101166147862982539311016801595612323156816999866783427829783286164172896802725820761659256555627406518829192800217880692359914672894220547306033679060066475600137205045054015651689487444267401130160872050085589597109014374199731072611044277806027332254214020499883131062627540945260814416104971893858787291926267157394988131329441246648393933117451348643609850156730059817506513924523851733
c = 19285290054358264594160191119053363484661054622854927550086540936229836207751905061897299540539735528766803248513199392889410922209106513019275525361297785136742517684745274089253401778969310170805452788203125136583847273167894915706201708268160138117578035286292385848441833691098676192230945185815890266453215404593242520989429750775723053435372661531195966551199012453469748764989624596296116016310586535749198878013241527430239006604194528859329192316989103910514620735894760979900228995139208829267762309798970482895132300580481270883276800390489213520429816698576642899381455153039281329012831320123165127378159
p_msb = 161405912451824860188834725646055524173328544131300133372580621368926433914138476338787007253318242142454894032713487340762003643551953941809023233323836630063065828499586237941251339865726273353740523275987884928619323490566227483094269770052935277592758770273832919929071652425379016974435907024060290170880
Solve
Didn’t play this ctf btw, just upsolving.
It’s very similar to this one.
https://connor-mccartney.github.io/cryptography/small-roots/corrupt-key-1-picoMini
The upper half of the bits of p is given; the only difference is that p is now 1024 bits instead of 512 bits.
This makes it a lot slower, so I edited my previous code to include 3 optimisations from maple3142.
Optimisation 1
Just a minor optimisation: since p must be odd, we can reduce the bound by 1 bit by changing
f = p_high * 2**(p_bits-p_high_bits) + x
x = small_roots(f, X=2**(p_bits-p_high_bits), beta=0.5, m=m)
to
f = p_high * 2**(p_bits-p_high_bits) + 2*x + 1
x = small_roots(f, X=2**(p_bits-p_high_bits-1), beta=0.5, m=m)
Optimisation 2
Using flatter for faster LLL.
# Install the Eigen package with pacman (non-interactive).
sudo pacman -S eigen --noconfirm
# Fetch the flatter sources into ~/Documents/flatter.
cd ~/Documents
git clone https://github.com/keeganryan/flatter
cd flatter
# Configure in the source directory and build with 4 parallel jobs.
cmake .
make -j4
# Symlink the built binary into /usr/local/bin so it can be invoked as `flatter`
# (assumes /usr/local/bin is on PATH).
sudo ln -s ~/Documents/flatter/bin/flatter /usr/local/bin/flatter
from subprocess import check_output
from re import findall
def flatter(M):
    """Reduce the lattice basis ``M`` with the external ``flatter`` program.

    The matrix is serialised with each row wrapped in square brackets, its
    entries separated by spaces, rows separated by newlines and the whole
    thing wrapped in one more pair of brackets. That text is written to the
    program's stdin, and every integer found in its stdout is read back into
    a matrix with the same dimensions as ``M``.

    Args:
        M: Matrix to reduce; must be iterable by rows and provide
            ``nrows()`` and ``ncols()``.

    Returns:
        A matrix of shape ``M.nrows()`` x ``M.ncols()`` built from the
        integers printed by ``flatter``.

    Raises:
        subprocess.CalledProcessError: If ``flatter`` exits with a non-zero
            status.
    """
    rows = (" ".join(str(entry) for entry in row) for row in M)
    payload = "[[" + "]\n[".join(rows) + "]]"
    output = check_output(["flatter"], input=payload.encode())
    # Collect every (possibly negative) integer in the order it was printed.
    entries = [int(token) for token in findall(b"-?\\d+", output)]
    return matrix(M.nrows(), M.ncols(), entries)
And now we can do B = flatter(B)
instead of B = B.LLL()
Experiment
Here’s the full value of p:
p = 161405912451824860188834725646055524173328544131300133372580621368926433914138476338787007253318242142454894032713487340762003643551953941809023233323836632396674586164821404065443903169766781702197174899338334027128103867874700640036605974611327518250687560220955598412727224450293311080620976484498655311739
Now let’s analyse how long brute-forcing different numbers of bits would take:
from Crypto.Util.number import *
import time
from subprocess import check_output
from re import findall
def flatter(M):
    """Reduce the lattice basis ``M`` with the external ``flatter`` program.

    The matrix is serialised with each row wrapped in square brackets, its
    entries separated by spaces, rows separated by newlines and the whole
    thing wrapped in one more pair of brackets. That text is written to the
    program's stdin, and every integer found in its stdout is read back into
    a matrix with the same dimensions as ``M``.

    Args:
        M: Matrix to reduce; must be iterable by rows and provide
            ``nrows()`` and ``ncols()``.

    Returns:
        A matrix of shape ``M.nrows()`` x ``M.ncols()`` built from the
        integers printed by ``flatter``.

    Raises:
        subprocess.CalledProcessError: If ``flatter`` exits with a non-zero
            status.
    """
    rows = (" ".join(str(entry) for entry in row) for row in M)
    payload = "[[" + "]\n[".join(rows) + "]]"
    output = check_output(["flatter"], input=payload.encode())
    # Collect every (possibly negative) integer in the order it was printed.
    entries = [int(token) for token in findall(b"-?\\d+", output)]
    return matrix(M.nrows(), M.ncols(), entries)
def small_roots(f, X, beta=1.0, m=None):
    """Search for small roots of a univariate polynomial modulo a divisor of N.

    ``N`` is the characteristic of the ring ``f`` lives in. A lattice is
    built from shifted multiples of ``N`` and the monic form of ``f``, with
    column ``j`` scaled by ``X**j``. It is reduced with ``flatter``, and the
    integer roots of the first reduced row (read back as a polynomial) are
    returned.

    Args:
        f: Univariate polynomial over a ring of characteristic ``N``.
        X: Bound on the absolute value of the roots to report.
        beta: Exponent used to size the lattice and in the final filter.
        m: Number of powers of ``f`` used in the lattice. When ``None`` it is
            derived from ``beta``, ``X``, ``N`` and the degree of ``f``.

    Returns:
        A list of integer roots ``r`` with ``abs(r) <= X``.
    """
    N = f.parent().characteristic()
    delta = f.degree()
    if m is None:
        epsilon = RR(beta**2 / delta - log(2 * X, N))
        m = max(beta**2 / (delta * epsilon), 7 * beta / delta).ceil()
    t = int((delta * m * (1 / beta - 1)).floor())

    f = f.monic().change_ring(ZZ)
    _, (x,) = f.parent().objgens()

    # Shift polynomials: x^j * N^(m-i) * f^i first, then x^i * f^m.
    shifts = []
    for i in range(m):
        for j in range(delta):
            shifts.append(x**j * N**(m - i) * f**i)
    for i in range(t):
        shifts.append(x**i * f**m)

    # Row r holds the coefficients of shifts[r], column c scaled by X^c.
    B = Matrix(ZZ, len(shifts), delta * m + max(delta, t))
    for row, poly in enumerate(shifts):
        for col in range(poly.degree() + 1):
            B[row, col] = poly[col] * X**col
    B = flatter(B)

    # Undo the column scaling on the first reduced row to get a polynomial.
    f = sum([ZZ(B[0, i] // X**i) * x**i for i in range(B.ncols())])
    ring = f.base_ring()
    roots = {ring(r) for r, _multiplicity in f.roots() if abs(r) <= X}
    # NOTE(review): ``f`` now names the reduced integer polynomial, so
    # ``f(root)`` is 0 for every root collected above and ``N.gcd(0)`` is
    # ``N``; this filter therefore appears to accept every root. Callers
    # should verify returned roots against their original polynomial.
    return [root for root in roots if N.gcd(ZZ(f(root))) >= N**beta]
def recover(p_high, n, m):
    """Recover a prime factor ``p`` of ``n`` from its most significant bits.

    ``p`` is modelled as ``p_high * 2**k + 2*x + 1``, where ``k`` is the
    number of unknown low bits and ``x`` is the small unknown. The ``2*x + 1``
    form assumes ``p`` is odd. ``p`` is assumed to have half as many bits as
    ``n``.

    Args:
        p_high: The known top bits of ``p``, as an integer (already shifted
            down so that its bit length is the number of known bits).
        n: The modulus that ``p`` divides.
        m: Lattice parameter forwarded to ``small_roots``.

    Returns:
        ``p`` as an ``int`` if a root yields a non-trivial divisor of ``n``,
        otherwise ``None``.
    """
    p_bits = (len(bin(n)) - 2) // 2
    p_high_bits = len(bin(p_high)) - 2
    unknown_bits = p_bits - p_high_bits

    x = PolynomialRing(Zmod(n), "x").gen()
    f = p_high * 2**unknown_bits + 2*x + 1
    roots = small_roots(f, X=2**(unknown_bits - 1), beta=0.5, m=m)

    # small_roots does not reliably filter out spurious roots, so check every
    # candidate against n instead of trusting the first one.
    for root in roots:
        p = int(f(root))
        if 1 < p < n and n % p == 0:
            return p
    return None
n = 24712135189687942739677490021030751776088469214818275631687482073531676912880823269667196936095460153002434759403063429337125873794523587731746689517070810687221399532024093572951282737818446579992570629531618780373767724789390101166147862982539311016801595612323156816999866783427829783286164172896802725820761659256555627406518829192800217880692359914672894220547306033679060066475600137205045054015651689487444267401130160872050085589597109014374199731072611044277806027332254214020499883131062627540945260814416104971893858787291926267157394988131329441246648393933117451348643609850156730059817506513924523851733
p = 161405912451824860188834725646055524173328544131300133372580621368926433914138476338787007253318242142454894032713487340762003643551953941809023233323836632396674586164821404065443903169766781702197174899338334027128103867874700640036605974611327518250687560220955598412727224450293311080620976484498655311739
# For each number of brute-forced bits, find the smallest lattice parameter m
# (never decreasing between rounds) for which recovery succeeds, and
# extrapolate the time of one attempt to all 2**bits guesses.
m = 1
for bits in range(15, 3, -1):
    # Pretend the top 512 + bits bits of the known p were guessed correctly.
    p_high = p >> (512 - bits)
    while True:
        starttime = time.time()
        # Use a separate name so the ground-truth p is never overwritten
        # (recover returns None when no root is found).
        candidate = recover(p_high, n, m=m)
        t = time.time() - starttime
        # p is known here, so compare directly instead of using primality as
        # a proxy; this is also safe when candidate is None.
        if candidate is not None and candidate == p:
            print(f"bruting {bits} bits with m={m} will take {round(2**bits * t / 3600, 2)} hours (single-threaded)")
            break
        m += 1
bruting 15 bits with m=17 will take 59.11 hours (single-threaded)
bruting 14 bits with m=18 will take 37.49 hours (single-threaded)
bruting 13 bits with m=19 will take 23.5 hours (single-threaded)
bruting 12 bits with m=21 will take 17.42 hours (single-threaded)
bruting 11 bits with m=23 will take 12.47 hours (single-threaded)
bruting 10 bits with m=25 will take 8.77 hours (single-threaded)
bruting 9 bits with m=27 will take 5.67 hours (single-threaded)
bruting 8 bits with m=30 will take 4.42 hours (single-threaded)
bruting 7 bits with m=33 will take 3.32 hours (single-threaded)
bruting 6 bits with m=38 will take 3.28 hours (single-threaded)
bruting 5 bits with m=44 will take 2.66 hours (single-threaded)
bruting 4 bits with m=53 will take 2.53 hours (single-threaded)
Optimisation 3
Parallelism:
You can tweak the values of bits, m and threads; the best values will probably
depend on your machine. Here’s what worked well for me: bits=6, m=38, threads=6.
It finds the flag in about 25 minutes.
from Crypto.Util.number import *
import time
from subprocess import check_output
from re import findall
from concurrent.futures import ProcessPoolExecutor
import os
def flatter(M):
    """Reduce the lattice basis ``M`` with the external ``flatter`` program.

    The matrix is serialised with each row wrapped in square brackets, its
    entries separated by spaces, rows separated by newlines and the whole
    thing wrapped in one more pair of brackets. That text is written to the
    program's stdin, and every integer found in its stdout is read back into
    a matrix with the same dimensions as ``M``.

    Args:
        M: Matrix to reduce; must be iterable by rows and provide
            ``nrows()`` and ``ncols()``.

    Returns:
        A matrix of shape ``M.nrows()`` x ``M.ncols()`` built from the
        integers printed by ``flatter``.

    Raises:
        subprocess.CalledProcessError: If ``flatter`` exits with a non-zero
            status.
    """
    rows = (" ".join(str(entry) for entry in row) for row in M)
    payload = "[[" + "]\n[".join(rows) + "]]"
    output = check_output(["flatter"], input=payload.encode())
    # Collect every (possibly negative) integer in the order it was printed.
    entries = [int(token) for token in findall(b"-?\\d+", output)]
    return matrix(M.nrows(), M.ncols(), entries)
def small_roots(f, X, beta=1.0, m=None):
    """Search for small roots of a univariate polynomial modulo a divisor of N.

    ``N`` is the characteristic of the ring ``f`` lives in. A lattice is
    built from shifted multiples of ``N`` and the monic form of ``f``, with
    column ``j`` scaled by ``X**j``. It is reduced with ``flatter``, and the
    integer roots of the first reduced row (read back as a polynomial) are
    returned.

    Args:
        f: Univariate polynomial over a ring of characteristic ``N``.
        X: Bound on the absolute value of the roots to report.
        beta: Exponent used to size the lattice and in the final filter.
        m: Number of powers of ``f`` used in the lattice. When ``None`` it is
            derived from ``beta``, ``X``, ``N`` and the degree of ``f``.

    Returns:
        A list of integer roots ``r`` with ``abs(r) <= X``.
    """
    N = f.parent().characteristic()
    delta = f.degree()
    if m is None:
        epsilon = RR(beta**2 / delta - log(2 * X, N))
        m = max(beta**2 / (delta * epsilon), 7 * beta / delta).ceil()
    t = int((delta * m * (1 / beta - 1)).floor())

    f = f.monic().change_ring(ZZ)
    _, (x,) = f.parent().objgens()

    # Shift polynomials: x^j * N^(m-i) * f^i first, then x^i * f^m.
    shifts = []
    for i in range(m):
        for j in range(delta):
            shifts.append(x**j * N**(m - i) * f**i)
    for i in range(t):
        shifts.append(x**i * f**m)

    # Row r holds the coefficients of shifts[r], column c scaled by X^c.
    B = Matrix(ZZ, len(shifts), delta * m + max(delta, t))
    for row, poly in enumerate(shifts):
        for col in range(poly.degree() + 1):
            B[row, col] = poly[col] * X**col
    B = flatter(B)

    # Undo the column scaling on the first reduced row to get a polynomial.
    f = sum([ZZ(B[0, i] // X**i) * x**i for i in range(B.ncols())])
    ring = f.base_ring()
    roots = {ring(r) for r, _multiplicity in f.roots() if abs(r) <= X}
    # NOTE(review): ``f`` now names the reduced integer polynomial, so
    # ``f(root)`` is 0 for every root collected above and ``N.gcd(0)`` is
    # ``N``; this filter therefore appears to accept every root. Callers
    # should verify returned roots against their original polynomial.
    return [root for root in roots if N.gcd(ZZ(f(root))) >= N**beta]
def recover(p_high, n, m):
    """Recover a prime factor ``p`` of ``n`` from its most significant bits.

    ``p`` is modelled as ``p_high * 2**k + 2*x + 1``, where ``k`` is the
    number of unknown low bits and ``x`` is the small unknown. The ``2*x + 1``
    form assumes ``p`` is odd. ``p`` is assumed to have half as many bits as
    ``n``.

    Args:
        p_high: The known top bits of ``p``, as an integer (already shifted
            down so that its bit length is the number of known bits).
        n: The modulus that ``p`` divides.
        m: Lattice parameter forwarded to ``small_roots``.

    Returns:
        ``p`` as an ``int`` if a root yields a non-trivial divisor of ``n``,
        otherwise ``None``.
    """
    p_bits = (len(bin(n)) - 2) // 2
    p_high_bits = len(bin(p_high)) - 2
    unknown_bits = p_bits - p_high_bits

    x = PolynomialRing(Zmod(n), "x").gen()
    f = p_high * 2**unknown_bits + 2*x + 1
    roots = small_roots(f, X=2**(unknown_bits - 1), beta=0.5, m=m)

    # small_roots does not reliably filter out spurious roots, so check every
    # candidate against n instead of trusting the first one.
    for root in roots:
        p = int(f(root))
        if 1 < p < n and n % p == 0:
            return p
    return None
n = 24712135189687942739677490021030751776088469214818275631687482073531676912880823269667196936095460153002434759403063429337125873794523587731746689517070810687221399532024093572951282737818446579992570629531618780373767724789390101166147862982539311016801595612323156816999866783427829783286164172896802725820761659256555627406518829192800217880692359914672894220547306033679060066475600137205045054015651689487444267401130160872050085589597109014374199731072611044277806027332254214020499883131062627540945260814416104971893858787291926267157394988131329441246648393933117451348643609850156730059817506513924523851733
_p_high = 161405912451824860188834725646055524173328544131300133372580621368926433914138476338787007253318242142454894032713487340762003643551953941809023233323836630063065828499586237941251339865726273353740523275987884928619323490566227483094269770052935277592758770273832919929071652425379016974435907024060290170880
c = 19285290054358264594160191119053363484661054622854927550086540936229836207751905061897299540539735528766803248513199392889410922209106513019275525361297785136742517684745274089253401778969310170805452788203125136583847273167894915706201708268160138117578035286292385848441833691098676192230945185815890266453215404593242520989429750775723053435372661531195966551199012453469748764989624596296116016310586535749198878013241527430239006604194528859329192316989103910514620735894760979900228995139208829267762309798970482895132300580481270883276800390489213520429816698576642899381455153039281329012831320123165127378159
bits = 6
m = 38
def solve(x):
    """Try one guess ``x`` for the ``bits`` bits of p just below the known half.

    The guess is placed directly below the 512 known upper bits in
    ``_p_high``, the top ``512 + bits`` bits are handed to ``recover``, and if
    that yields a prime factor of ``n`` the ciphertext ``c`` is decrypted and
    printed.

    Args:
        x: Candidate value for the guessed bits, ``0 <= x < 2**bits``.

    Returns:
        None. On success the flag is printed and the current process is
        killed with ``kill -9``, presumably to bring the whole worker pool
        down and stop the search.
    """
    # Flush so progress is visible even when stdout is not a terminal.
    print(x, flush=True)
    _p = _p_high + x * 2**(512-bits)
    # Keep only the top 512 + bits bits ("0b" prefix accounts for the +2);
    # this relies on _p being exactly 1024 bits long.
    p_high = int(bin(_p)[:512+bits+2], 2)
    p = recover(p_high, n, m=m)
    # Reject failed attempts and anything that is not a real prime factor.
    if p is None or p <= 1 or n % p != 0 or not is_prime(p):
        return
    q = n//p
    d = pow(65537, -1, (p-1)*(q-1))
    flag = int(pow(c, d, n))
    # The process is SIGKILLed on the next line, which discards any buffered
    # output, so the flag must be flushed explicitly.
    print(x, long_to_bytes(flag), flush=True)
    os.system(f"kill -9 {os.getpid()}")
from concurrent.futures.process import BrokenProcessPool

# Fan the 2**bits guesses out over six worker processes.
with ProcessPoolExecutor(max_workers=6) as executor:
    search_space = range(2**bits)
    try:
        # Drain the result iterator: an exception raised inside a worker
        # (e.g. the flatter binary is missing) is only re-raised when its
        # result is consumed, and would otherwise be silently dropped.
        for _ in executor.map(solve, search_space):
            pass
    except BrokenProcessPool:
        # Expected on success: solve() kills its own worker process after
        # printing the flag, which breaks the pool.
        print("worker pool terminated (expected once the flag has been printed)")