heyooo~! ^w^

soo this weekend was hxp ctf 2022 (yes, they name them a year off, god knows why) and there was a super cute problem i wanted to share!!~ the problem is "whistler" and had a single file, vuln.py, that was running on their server.


/// the challenge ///

here it is bit by bit (heh). the gist is that there's a ring-LWE KEM being performed, and we're allowed to perform a chosen ciphertext attack against it.

here's the start of vuln.py:

#!/usr/bin/env python3
import struct, hashlib, random, os
from Crypto.Cipher import AES

n = 256
q = 11777
w = 8

^ we start with a few imports and setting three values:

  • n, the degree of the polynomial ring we're working over (namely the "truncated" polynomial ring R_q = (ZZ/qZZ)[x]/(x^n + 1))
  • q, the order of the prime field the polynomials are defined over
  • and w, the width of the "centered binomial distribution" (CBD, lol) we'll sample coefficients of our private key from.

sample = lambda rng: [bin(rng.getrandbits(w)).count('1') - w//2 for _ in range(n)]

^ this is the centered binomial distribution sampler. it just samples w random bits and counts the number of '1's (sampling a "regular" binomial distribution), then subtracts w/2 to make the distribution centered around zero. this makes the probability mass function look kinda like this:

                              #
                         #    #    #
                         #    #    #
                         #    #    #
                    #    #    #    #    #
                    #    #    #    #    #
          .    #    #    #    #    #    #    #    .
    -5   -4   -3   -2   -1    0    1    2    3    4    5

note that the support is entirely on {-4, -3, -2, ..., 3, 4}. anywho,
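(if you want the exact numbers behind that picture: here's a tiny sketch of my own, not part of vuln.py, that computes the probabilities for w = 8. the helper name cbd_pmf is made up for this post.)

```python
from fractions import Fraction
from math import comb

w = 8  # same width as in vuln.py

def cbd_pmf(w):
    # P(popcount of w uniform bits == k) = C(w, k) / 2^w, then shift by w // 2
    return {k - w // 2: Fraction(comb(w, k), 2 ** w) for k in range(w + 1)}

pmf = cbd_pmf(w)
for v in sorted(pmf):
    print(f"{v:+d}: {pmf[v]}")
```

the values -4 and 4 each only show up with probability 1/256, which is why the tails of the picture are so tiny.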

add = lambda f,g: [(x + y) % q for x,y in zip(f,g)]

def mul(f,g):
    r = [0]*n
    for i,x in enumerate(f):
        for j,y in enumerate(g):
            s,k = divmod(i+j, n)
            r[k] += (-1)**s * x*y
            r[k] %= q
    return r

^ here we have addition and multiplication of polynomials in R_q. first, note that every polynomial in R_q can be equivalently represented as a polynomial of degree at most 255, as in R_q we have that x^256 = -1. thus, polynomials are represented here as lists of 256 coefficients, i.e.

a_255 x^255 + a_254 x^254 + ... + a_1 x + a_0

is represented as the list

[a_0, a_1, ..., a_254, a_255].

the multiplication is performed as just "normal" schoolbook polynomial multiplication, except when powers of x at or above 256 appear, the exponent gets "wrapped around" back to 0 and the resulting monomial is negated.
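to see the wrap-around in action, here's a toy version with n = 4 instead of 256. this is just my own sketch: the mul below takes n as an extra parameter, which the real vuln.py doesn't do.

```python
q = 11777

def mul(f, g, n):
    # schoolbook product in (ZZ/qZZ)[x]/(x^n + 1): x^n wraps around to -1
    r = [0] * n
    for i, x in enumerate(f):
        for j, y in enumerate(g):
            s, k = divmod(i + j, n)
            r[k] = (r[k] + (-1) ** s * x * y) % q
    return r

n = 4                  # toy size
x = [0, 1, 0, 0]       # the polynomial x
x3 = [0, 0, 0, 1]      # the polynomial x^3
prod = mul(x, x3, n)   # x * x^3 = x^4 = -1, which is q - 1 mod q
print(prod)
```

the single nonzero coefficient lands back in slot 0 with its sign flipped, exactly the "wrapped around and negated" behaviour described above.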

def genkey():
    a = [random.randrange(q) for _ in range(n)]
    rng = random.SystemRandom()
    s,e = sample(rng), sample(rng)
    b = add(mul(a,s), e)
    return s, (a,b)

^ this part generates keys for the KEM. the private key s is a polynomial in R_q with coefficients drawn from our CBD. the public key consists of two polynomials a and b:

  • a is a polynomial in R_q with uniform random coefficients.
  • b = as + e, with e being a random "error" polynomial with coefficients drawn from our CBD.

center = lambda v: min(v%q, v%q-q, key=abs)
extract = lambda r,d: [2*t//q for u,t in zip(r,d) if u]

ppoly = lambda g: struct.pack(f'<{n}H', *g).hex()
pbits = lambda g: ''.join(str(int(v)) for v in g)
hbits = lambda g: hashlib.sha256(pbits(g).encode()).digest()
mkaes = lambda bits: AES.new(hbits(bits), AES.MODE_CTR, nonce=b'')

^ we now randomly define a bunch of helper functions all at once ^^; uhh here's some summaries, skip past these if you're bored:

  • center
    • takes in a value v mod q
    • possibly subtracts q from it to make it fall in [-q/2, q/2]
  • extract
    • takes in a list r of 256 bits and a polynomial d
    • for every coefficient of d corresponding to a 1 in r, outputs a bit indicating whether the coefficient is greater than q / 2
  • ppoly
    • takes in a polynomial g
    • turns it into packed bytes representing its coefficients as little-endian 16-bit values
  • pbits
    • takes in a list of bits (either bools or ints) g
    • turns it into a string of '0's and '1's
  • hbits
    • takes in a list of bits
    • applies pbits to it, and then hashes the bits using SHA-256
  • mkaes
    • takes a list of bits
    • hashes it with hbits and then creates an AES-CTR cipher (with an empty nonce) using that hash as the key

def encaps(pk):
    seed = os.urandom(32)
    rng = random.Random(seed)
    a,b = pk
    s,e = sample(rng), sample(rng)
    c = add(mul(a,s), e)
    d = add(mul(b,s), e)
    r = [int(abs(center(2*v)) > q//7) for v in d]
    bits = extract(r,d)
    return bits, (c,r)

^ finally! the KEM encapsulation functionnnn~. this function

  • takes in pk, the public key of the other party
  • creates bits which is the shared symmetric key both parties derive, plus the ciphertext (c,r) that the other party needs to derive bits too

the steps here are (adding ' to names to distinguish them from the s and e used during the decapsulator's keygen, and the d value the decapsulator will later derive):

  1. make s' and e' random polynomials with "short" coefficients (i.e. drawn from CBD)
  2. compute c = as' + e'
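  3. compute d' = bs' + e'
  4. calculate r as a set of bits indicating which coefficients of 2d' (reduced mod q and centered) are further than a distance of q / 7 away from 0, i.e. which coefficients of d' sit comfortably far from both 0 and q / 2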
  5. extract the coefficients of d' corresponding to '1's in r, thresholding them against q / 2

this is a fairly standard ring-LWE KEM (see Kyber.CPAPKE, section 1.2.3 of the Kyber spec) except for the part with r selecting a subset of the bits.

as far as I can tell, this is just so that when the other party recomputes something close to d' in the next step, the error in their computation doesn't push a coefficient across one of extract's boundaries (say from just below q to just above 0, or from just below q / 2 to just above it), flipping the bit outputted by extract.
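here's a little sketch of that filter on a few hand-picked values. center is copied from vuln.py; the helper is_safe and the sample values are mine, purely for illustration.

```python
q = 11777

center = lambda v: min(v % q, v % q - q, key=abs)  # copied from vuln.py

def is_safe(v):
    # the same test encaps uses to build each bit of r
    return int(abs(center(2 * v)) > q // 7)

# values near 0, near q/4, on either side of q/2, near 3q/4, and near q
samples = [5, q // 4, q // 2 - 5, q // 2 + 5, 3 * q // 4, q - 5]
for v in samples:
    print(v, "extract bit:", 2 * v // q, "kept by r:", is_safe(v))
```

only the values around q/4 and 3q/4 survive; anything hugging 0, q/2, or q gets dropped because a small error could flip its bit.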

anyway, spoiler alert, that r addition turns out to be the downfall of this whole thing.

def decaps(sk, ct):
    s = sk
    c,r = ct
    d = mul(c,s)
    return extract(r,d)

^ now, the decapsulation process done by the other party is much simpler.

  1. take the ciphertext (c,r) and compute d = cs (we'll see this is approximately equal to d')
  2. extract the coefficients of d corresponding to '1's in r, thresholding them against q / 2.

the trick is that d here and d' from the encapsulation should be about the same, so when we call extract we'll more than likely get the same bits. indeed:

d' = bs' + e' = (as + e)s' + e' = ass' + es' + e'

and

d = cs = (as' + e')s = ass' + e's

so d' - d = es' + e' - e's, where e, e', s, and s' all have tiny coefficients, i.e. their contributions are unlikely to nudge a value from being on one side of q / 2 to the other. now, without further ado...
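(ok, one tiny bit of ado first: a numerical sanity check that the two values really land close together. this is my own sketch that re-implements the pieces of vuln.py it needs; the fixed seed and the names d_enc / d_dec are made up here, and a seeded random.Random is used only so the run is repeatable.)

```python
import random

n, q, w = 256, 11777, 8

sample = lambda rng: [bin(rng.getrandbits(w)).count('1') - w // 2 for _ in range(n)]
add = lambda f, g: [(x + y) % q for x, y in zip(f, g)]
center = lambda v: min(v % q, v % q - q, key=abs)

def mul(f, g):
    # negacyclic schoolbook product, as in vuln.py
    r = [0] * n
    for i, x in enumerate(f):
        for j, y in enumerate(g):
            s, k = divmod(i + j, n)
            r[k] = (r[k] + (-1) ** s * x * y) % q
    return r

rng = random.Random(1234)            # fixed seed, repeatability only
a = [rng.randrange(q) for _ in range(n)]
s, e = sample(rng), sample(rng)      # decapsulator's secret and error
s2, e2 = sample(rng), sample(rng)    # encapsulator's s' and e'

b = add(mul(a, s), e)                # public key part
c = add(mul(a, s2), e2)              # ciphertext part
d_enc = add(mul(b, s2), e2)          # the value encaps thresholds
d_dec = mul(c, s)                    # the value decaps thresholds

gap = max(abs(center(x - y)) for x, y in zip(d_enc, d_dec))
print("largest coefficient of the difference:", gap, "vs q/14 =", q // 14)
```

the gap should come out far below q/14, which is the safety margin the r filter enforces around each boundary.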

if __name__ == '__main__':

    while True:
        sk, pk = genkey()
        dh, ct = encaps(pk)
        if decaps(sk, ct) == dh:
            break

    print('pk[0]:', ppoly(pk[0]))
    print('pk[1]:', ppoly(pk[1]))

    print('ct[0]:', ppoly(ct[0]))
    print('ct[1]:', pbits(ct[1]))

    flag = open('flag.txt').read().strip()
    print('flag: ', mkaes([0]+dh).encrypt(flag.encode()).hex())

    for _ in range(2048):
        c = list(struct.unpack(f'<{n}H', bytes.fromhex(input())))
        r = list(map('01'.index, input()))
        if len(r) != n or sum(r) < n//2: exit('!!!')

        bits = decaps(sk, (c,r))

        print(mkaes([1]+bits).encrypt(b'hxp<3you').hex())

^ at last, the heart of the challenge! we're tossed a public key from one party (the "decapsulator") and a ciphertext for that public key from the other (the "encapsulator"). then,

  • we're shown the flag, encrypted using AES with key [0] + bits
  • we're allowed to, 2048 times, send over a ciphertext, which the receiving party decapsulates to compute the shared key bits', and then replies to us with the AES encryption of "hxp<3you" using the key [1] + bits'

there's also a pesky restriction that the values of r we supply need to have at least n/2 '1's, i.e. we can't try to "one-hot" out any particular coefficient of d.

given that we have no access to the outputted value bits' itself, and AES is, well, hard to crack, it seems like there's not much we can do. unless...


/// a first try ///

one piece of information we do have is whether two derived keys bits' are the same, because with overwhelming likelihood the encryptions of the message "hxp<3you" will be the same if and only if the keys used for encryption were the same.

the other key insight we need is that if we pass in the constant polynomial 1 for c, then d = cs = s and we can threshold the coefficients of the private key itself!
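a quick check of that claim on a toy ring (n = 8), and of what happens for other constants. this is a sketch with a made-up secret, and mul is re-implemented so the snippet runs on its own.

```python
q = 11777
n = 8  # toy size, the challenge uses 256

def mul(f, g):
    # same negacyclic schoolbook product as vuln.py
    r = [0] * n
    for i, x in enumerate(f):
        for j, y in enumerate(g):
            s, k = divmod(i + j, n)
            r[k] = (r[k] + (-1) ** s * x * y) % q
    return r

secret = [-4, 3, 0, 1, -1, 2, 4, -2]   # made-up short secret
k = 2991
const_one = [1] + [0] * (n - 1)        # the constant polynomial 1
const_k = [k] + [0] * (n - 1)          # the constant polynomial k

print(mul(const_one, secret))  # the secret itself, reduced mod q
print(mul(const_k, secret))    # every coefficient scaled by k mod q
```

note that negative coefficients like -4 show up as q - 4, i.e. way above q/2.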

so, we can begin our valiant effort: if we send c = 1 with say,

r = 10101010...101010

(this has exactly n/2 '1's, so we're good!) followed by c = 1 with

r = 01101010...101010

which flips the first two bits of the previous r, we can check the two outputted encrypted messages. here's the key part: if the outputted messages are the same, then the values for bits' derived in both cases were the same. the only thing that changed is now the first bit of bits' comes from the first coefficient of s, s[1], rather than from the zeroth coefficient s[0].

thus, if the outputs are the same, then both s[0] and s[1] lie both above or both below q/2 (since this is what extract checks for each coefficient). similarly, if the encrypted messages weren't the same, then they'd have to be on opposite sides of q/2.
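here's the same idea against a local stand-in for the server. big assumptions: the oracle function below is hypothetical, it only handles constant c, and it uses a plain SHA-256 of the extracted bits instead of the AES encryption of "hxp<3you". all we care about is whether two responses match.

```python
import hashlib

q, n = 11777, 8  # toy n; the real challenge uses n = 256

def oracle(s, k, r):
    # d = k * s for a constant c = k, then extract + hash
    d = [(k * x) % q for x in s]
    bits = [2 * t // q for u, t in zip(r, d) if u]
    return hashlib.sha256(''.join(map(str, bits)).encode()).hexdigest()

def same_side(s, k=1):
    r1 = [1, 0] * (n // 2)               # 1010...10, selects s[0]
    r2 = [0, 1] + [1, 0] * (n // 2 - 1)  # 0110...10, selects s[1] instead
    return oracle(s, k, r1) == oracle(s, k, r2)

print(same_side([1, 2, 0, 0, 0, 0, 0, 0]))   # 1 and 2 are both below q/2
print(same_side([1, -2, 0, 0, 0, 0, 0, 0]))  # -2 is q - 2 mod q, so above q/2
```

both masks have exactly n/2 ones, so they'd pass the server's check too.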


/// getting fancy with r ///

here's the cool thing: we can repeat this trick!! try on these values for r:

1  01010101...010101
2  10010101...010101
3  10100101...010101
4  10101001...010101
5  ...
6  10101010...101001
7  10101010...101010

^ comparing line 1 with line 2, we exchange s[0] in for s[1] in extract. swapping line 2 for line 3 exchanges s[2] in for s[3], line 3 for 4 exchanges s[4] for s[5], and so on for every pair of lines. if we now add

1  00110101...010101
2  00101101...010101
3  00101011...010101
4  ...
5  00101010...101101
6  00101010...101011

^ then, swapping line 1 of the previous set for line 1 of this set exchanges s[1] for s[2]. continuing down this set we swap s[3] with s[4], etc.

in particular, comparing the outputs for all of these values of r lets us check whether every coefficient of s is on the same side of or a different side of q/2 from its neighbors. up to guessing whether s[0] < q/2 or s[0] > q/2, this tells us exactly which "side" every coefficient of s lies on!
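if you don't trust the hand-waving, here's a sketch that builds a full schedule of 256 masks and checks the properties we need. it walks the pairs in a slightly different order than the listings above (it starts from 1010...10), and the helper swapped_pair is a name i made up.

```python
n = 256

rs = []
for i in range(n // 2):
    # 1010...10, then 01 10...10, then 0101 10...10, and so on
    rs.append([0, 1] * i + [1, 0] * (n // 2 - i))
for i in range(n // 2):
    # 0101...01 shifted by one slot, handling the odd-offset pairs
    rs.append([0] + [0, 1] * i + [1, 0] * (n // 2 - 1 - i) + [1])

def swapped_pair(r_old, r_new):
    # positions where two consecutive masks differ
    return tuple(j for j in range(n) if r_old[j] != r_new[j])

pairs = [swapped_pair(rs[t], rs[t + 1]) for t in range(len(rs) - 1)]
print(len(rs), pairs[:3], pairs[-3:])
```

every mask has exactly n/2 ones, and each consecutive pair of masks trades one coefficient for its neighbour, so 256 queries cover all 255 adjacent pairs.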

but we need the exact coefficients of s!! ://

this leads us to our next trick:


/// short coefficients go spinny ///

turns out c = 1 isn't the only value we can cleverly try. indeed, we can set c to any integer mod q, and it will multiply the coefficients of s by that number. turns out, you can identify any number in {-4, -3, .., 4} by comparing multiples mod q of them to q/2. here's a table!!

        * 10630   * 2991    * 9085    * 5954    * 378     * 3770    * 1532    * 11293
  -4    < q/2     > q/2     > q/2     > q/2     > q/2     > q/2     < q/2     < q/2
  -3    < q/2     < q/2     > q/2     < q/2     > q/2     < q/2     > q/2     < q/2
  -2    < q/2     < q/2     < q/2     > q/2     > q/2     < q/2     > q/2     < q/2
  -1    < q/2     > q/2     < q/2     < q/2     > q/2     > q/2     > q/2     < q/2
   0    < q/2     < q/2     < q/2     < q/2     < q/2     < q/2     < q/2     < q/2
   1    > q/2     < q/2     > q/2     > q/2     < q/2     < q/2     < q/2     > q/2
   2    > q/2     > q/2     > q/2     < q/2     < q/2     > q/2     < q/2     > q/2
   3    > q/2     > q/2     < q/2     > q/2     < q/2     > q/2     < q/2     > q/2
   4    > q/2     < q/2     < q/2     < q/2     < q/2     < q/2     > q/2     > q/2

how did I get these numbers in the top row, you ask? i just generated some random ones lmao. others work, idek.
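for the skeptical, a short sketch that rebuilds the table from the multipliers and checks that no two rows collide (0 means "< q/2", 1 means "> q/2"). the helper name side is mine.

```python
q = 11777
ks = [10630, 2991, 9085, 5954, 378, 3770, 1532, 11293]

def side(v, k):
    # 0 if k*v mod q is below q/2, 1 if it is above
    return 2 * ((v * k) % q) // q

rows = {v: tuple(side(v, k) for k in ks) for v in range(-4, 5)}
for v, row in rows.items():
    print(f"{v:+d}", row)

# every row must be distinct, otherwise two coefficients would be confused
assert len(set(rows.values())) == len(rows)
```

swapping in your own list of multipliers and re-running the assert is an easy way to test whether another random choice works.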

anywho, yeah. that's the whole trick! here's the full attack. it gets a pretty box because i said so. \uwu/

stage 1: collecting the "side" data for each c

  • for each multiplier k above in the table:
    • for each value of r above:
      • let c be the constant polynomial k
      • send c and r and let the server compute d = cs and hand us back our encrypted messages
    • assuming d[0] < q/2, do all the encrypted message comparisons to figure out which side of q/2 each coefficient is on (our assumption d[0] < q/2 could be wrong, but in that case we only need to flip which side we think each coefficient is on)
    • save these "side choices" for this value of c

stage 2: enumerating the possibilities for d[0] and getting s

  • since d[0] < q/2 may be false for some values of c, iterate over every possible combination of which c's "side choices" need to be flipped versus which don't
    • use these guesses to try to determine each coefficient of s with the big table above
    • if this works, go through the decaps steps and try to recover the message with this value for s

since we have 8 values for c and 256 values of r, we use up exactly 2048 server queries. it turns out, this attack works perfectly in practice! ^^
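here's both stages run end to end against a local simulation, as a sketch under some loud assumptions: n is shrunk to 16 so it finishes instantly, make_oracle is a hypothetical stand-in for the server that only accepts constant c and returns a SHA-256 of the extracted bits, and the seed is arbitrary.

```python
import hashlib, itertools, random

q, n, w = 11777, 16, 8   # toy n; the challenge uses n = 256
ks = [10630, 2991, 9085, 5954, 378, 3770, 1532, 11293]
traces = {tuple(2 * ((v * k) % q) // q for k in ks): v for v in range(-4, 5)}

def make_oracle(s):
    def oracle(k, r):
        bits = [2 * ((k * x) % q) // q for u, x in zip(r, s) if u]
        return hashlib.sha256(bytes(bits)).digest()
    return oracle

def masks():
    rs = [[0, 1] * i + [1, 0] * (n // 2 - i) for i in range(n // 2)]
    rs += [[0] + [0, 1] * i + [1, 0] * (n // 2 - 1 - i) + [1] for i in range(n // 2)]
    return rs

def sides(oracle, k):
    # stage 1: side of each k*s[j], assuming k*s[0] < q/2
    encs = [oracle(k, r) for r in masks()]
    diff = [0] * (n - 1)
    for t in range(n - 1):
        # masks t and t+1 trade coefficient j for coefficient j+1
        j = 2 * t if t < n // 2 else 2 * (t - n // 2) + 1
        diff[j] = int(encs[t] != encs[t + 1])
    bits = [0]
    for d in diff:
        bits.append(bits[-1] ^ d)
    return bits

def recover(oracle):
    # stage 2: try every combination of per-multiplier flips
    guesses = [sides(oracle, k) for k in ks]
    found = []
    for flips in itertools.product([0, 1], repeat=len(ks)):
        rows = zip(*[[b ^ f for b in g] for f, g in zip(flips, guesses)])
        try:
            found.append([traces[row] for row in rows])
        except KeyError:
            continue
    return found

rng = random.Random(2022)
secret = [bin(rng.getrandbits(w)).count('1') - w // 2 for _ in range(n)]
candidates = recover(make_oracle(secret))
print(secret in candidates, len(candidates))
```

the true secret always survives stage 2; if more than one candidate does, each one just gets tried against the encrypted flag.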


/// the solve script ///

uhh so my solve script during the competition was super messy. like, embarrassingly. i've cleaned it up a bit here but if there's typos forgive me lol.

import struct
import itertools
from pwn import *
from vuln import *

^ yup, that last line's one you never want to see in a script.

ks = [10630, 2991, 9085, 5954, 378, 3770, 1532, 11293] # random values, lol
traces = {tuple([2 * ((v * k) % q) // q for k in ks]): v for v in range(-4, 5)}

^ we store each multiplier k from the table above in ks and then compute traces, which is a map from the rows of the table above back to the coefficient that generates them.

for instance, the key (0, 1, 1, 1, 1, 1, 0, 0) corresponding to the first row of the above table gives the value -4. the expression 2 * ((v * k) % q) // q just determines which side of q / 2 each value v = -4, -3, ..., 4 lands on when we multiply it by k mod q (0 for below, 1 for above).

rs = []

for i in range(n // 2):
    r = [0, 1] * i + [1, 0] * (n // 2 - i)
    rs.append(r)

for i in range(n // 2):
    r = [0] + [0, 1] * i + [1, 0] * (n // 2 - 1 - i) + [1]
    rs.append(r)

^ next, we generate our list of rs as above.

with remote("<server IP redacted>", 4421) as conn:
    conn.recvline() # we don't need pk[0]
    conn.recvline() # or pk[1] for that matter

    ct = []

    # parse c as a list of polynomial coefficients
    c_bytes = bytes.fromhex(conn.recvline().decode().split(":")[1])
    ct.append(struct.unpack(f'<{n}H', c_bytes))

    # parse r as a list of 0 and 1 values
    r_bytes = conn.recvline().decode().split(":")[1].strip()
    ct.append(list(map('01'.index, r_bytes)))

    # get encoded flag as bytes
    flag_enc = bytes.fromhex(conn.recvline().decode().split(":")[1])

^ we parse the ciphertext and the encrypted flag from the server (i wasn't feeling super pythonic ok forgive me)

    guesses = []

    for k in ks:
        encs = [] # encrypted messages received for this multiplier k
        for r in rs:
            c = [k] + [0] * (n - 1) # set c to a constant polynomial with constant term k
            conn.sendline(ppoly(c).encode()) # send it over
            conn.sendline(pbits(r).encode()) # send r too
            encs.append(bytes.fromhex(conn.recvline().decode())) # save the response

        guess_bits = [0] # guesses for whether each bit is > q/2, we assume d[0] < q/2

        for i in range(n // 2):
            # infer bit 2 * i + 1 from bit 2 * i (i.e. 1 from 0)
            guess = guess_bits[-1]
            if encs[i] != encs[i + 1]:
                guess = 1 if guess == 0 else 0
            guess_bits.append(guess)

            # infer bit 2 * i + 2 from bit 2 * i + 1 (i.e. 2 from 1)
            if len(guess_bits) != n:
                if encs[n // 2 + i] != encs[n // 2 + i + 1]:
                    guess = 1 if guess == 0 else 0
                guess_bits.append(guess)

        # save the guesses for this value of k
        guesses.append(guess_bits)

^ this just does exactly what we described with guessing the "sides" of each coefficient of d. note that, as above, one set of r values handles comparisons with even offsets while the other handles comparisons with odd offsets.

then, lastly...

# iterate over which elements of `guesses` we might need to flip bitwise
for flips in itertools.product([True, False], repeat=8):
    # perform the flipping
    flipped_guesses = []
    for flip, guess_bits in zip(flips, guesses):
        if flip:
            guess_bits = [1 if bit == 0 else 0 for bit in guess_bits]
        flipped_guesses.append(guess_bits)

    # after flipping, get the "sides" for each coefficient for each k
    # i.e. s_traces[0] could be (0, 1, 1, 1, 1, 1, 0, 0), meaning s[0] is -4
    s_traces = list(zip(*flipped_guesses))

    # try to recover s
    try:
        s = [traces[t] for t in s_traces]
    except KeyError:
        continue

    # get that flag (flag_enc was already decoded from hex when we parsed it)
    dh = decaps(s, ct)
    flag = mkaes([0]+dh).decrypt(flag_enc)
    print(flag)

^ and this last step just does the flipping necessary because we might have misguessed which side d[0] fell on for some of our values of k, and tries to recover s for each possible set of flips.

running this script gives (after several minutes, lol)... hxp{e4zy_p34zY_p34nuT_Bu7t3r}

and that's the challenge! ^^


/// conclusion ///

welllll, that's it!! i hope you enjoyed, i thought this was a clever little puzzle and i'm happy to now have a writeup for it here :)