# hxp ctf 2022: whistler writeup

heyooo~! ^w^

soo this weekend was hxp ctf 2022 (yes, they name them a year
off, god knows why) and there was a super cute problem i wanted to share!!~ the problem is
"whistler" and had a single file, `vuln.py`, that was running on their server.

# /// the challenge ///

here it is bit by bit (heh). the gist is that there's a ring-LWE KEM being performed, and we're allowed to perform a chosen ciphertext attack against it.

here's the start of `vuln.py`:

```python
#!/usr/bin/env python3
import struct, hashlib, random, os
from Crypto.Cipher import AES

n = 256
q = 11777
w = 8
```

^ we start with a few imports and setting three values:

- `n`, the degree of the polynomial ring we're working over (namely the "truncated" polynomial ring `R_q = (ZZ/qZZ)[x]/(x^n + 1)`)
- `q`, the order of the prime field the polynomials are defined over
- and `w`, the width of the "centered binomial distribution" (CBD, lol) we'll sample coefficients of our private key from.

```python
sample = lambda rng: [bin(rng.getrandbits(w)).count('1') - w//2 for _ in range(n)]
```

^ this is the centered binomial distribution sampler. it just samples `w` random bits
and counts the number of '1's (sampling a "regular" binomial distribution), then
subtracts `w/2` to make the distribution centered around zero. this makes the
probability mass function look kinda like this:

```
                #
             #  #  #
             #  #  #
             #  #  #
          #  #  #  #  #
          #  #  #  #  #
    .  #  #  #  #  #  #  #  .
-5 -4 -3 -2 -1  0  1  2  3  4  5
```
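if you want exact numbers instead of ascii art, here's a tiny sketch that computes the probabilities directly. it assumes `w = 8` as in the challenge; the helper name `cbd_pmf` is mine and not something from `vuln.py`.

```python
from math import comb

w = 8  # same width as in vuln.py

def cbd_pmf(w):
    """Exact PMF of (number of 1s among w fair bits) - w//2."""
    return {k - w // 2: comb(w, k) / 2**w for k in range(w + 1)}

pmf = cbd_pmf(w)
for v, p in sorted(pmf.items()):
    # one text bar per value, scaled so the tallest bar is ~11 characters
    print(f"{v:+d}: {p:.4f} " + "#" * round(p * 40))
```

so a coefficient is `0` about 27% of the time (70/256), and `±4` only shows up with probability 1/256 each.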

note that the support is entirely on `{-4, -3, -2, ..., 3, 4}`. anywho,

```python
add = lambda f,g: [(x + y) % q for x,y in zip(f,g)]

def mul(f,g):
    r = [0]*n
    for i,x in enumerate(f):
        for j,y in enumerate(g):
            s,k = divmod(i+j, n)
            r[k] += (-1)**s * x*y
            r[k] %= q
    return r
```

^ here we have addition and multiplication of polynomials in `R_q`. first, note that
every polynomial in `R_q` can be equivalently represented as a polynomial of degree at most 255, as
in `R_q` we have that `x^256 = -1`. thus, polynomials are represented here as lists of
256 coefficients, i.e.

`a_255 x^255 + a_254 x^254 + ... + a_1 x + a_0`

is represented as the list

`[a_0, a_1, ..., a_254, a_255]`.

the multiplication is performed as just "normal" schoolbook polynomial multiplication,
except when powers of `x` at or above 256 appear, the exponent gets "wrapped around"
back to 0 and the resulting monomial is negated.
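to see the wrap-around in action, here's a small sketch with toy parameters (`n = 4`, `q = 17` are my picks so the lists stay readable; the challenge uses `n = 256`, `q = 11777`). the `mul` here follows the same logic as the one in `vuln.py`.

```python
n, q = 4, 17  # toy parameters, NOT the challenge's values

def mul(f, g):
    """Schoolbook multiplication in (ZZ/qZZ)[x]/(x^n + 1)."""
    r = [0] * n
    for i, x in enumerate(f):
        for j, y in enumerate(g):
            s, k = divmod(i + j, n)  # s == 1 exactly when the exponent wrapped
            r[k] = (r[k] + (-1) ** s * x * y) % q
    return r

x_poly = [0, 1, 0, 0]    # x
x_cubed = [0, 0, 0, 1]   # x^3
wrapped = mul(x_poly, x_cubed)           # x^4 = -1, stored as q - 1
scaled = mul([3, 0, 0, 0], [1, 2, 3, 4]) # a constant just scales each coefficient
print(wrapped, scaled)
```

the first product is `[16, 0, 0, 0]` (that's `-1` mod 17), and the second is `[3, 6, 9, 12]`: multiplying by a constant polynomial touches every coefficient independently.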

```python
def genkey():
    a = [random.randrange(q) for _ in range(n)]
    rng = random.SystemRandom()
    s,e = sample(rng), sample(rng)
    b = add(mul(a,s), e)
    return s, (a,b)
```

^ this part generates keys for the KEM. the private key `s` is a polynomial in `R_q` with coefficients drawn from our CBD. the public key consists of two polynomials `a` and `b`:

- `a` is a polynomial in `R_q` with uniform random coefficients.
- `b = as + e`, with `e` being a random "error" polynomial with coefficients drawn from our CBD.

```python
center = lambda v: min(v%q, v%q-q, key=abs)
extract = lambda r,d: [2*t//q for u,t in zip(r,d) if u]

ppoly = lambda g: struct.pack(f'<{n}H', *g).hex()
pbits = lambda g: ''.join(str(int(v)) for v in g)
hbits = lambda g: hashlib.sha256(pbits(g).encode()).digest()
mkaes = lambda bits: AES.new(hbits(bits), AES.MODE_CTR, nonce=b'')
```

^ we now randomly define a bunch of helper functions all at once ^^; uhh here's some summaries, skip past these if you're bored:

- `center`
  - takes in a value `v` mod `q`
  - possibly subtracts q from it to make it fall in `[-q/2, q/2]`
- `extract`
  - takes in a list `r` of 256 bits and a polynomial `d`
  - for every coefficient of `d` corresponding to a 1 in `r`, outputs a bit indicating whether the coefficient is greater than `q / 2`
- `ppoly`
  - takes in a polynomial `g`
  - turns it into a hex string of packed bytes representing its coefficients as little-endian 16-bit values
- `pbits`
  - takes in a list of bits (either `bool`s or `int`s) `g`
  - turns it into a string of '0's and '1's
- `hbits`
  - takes in a list of bits
  - applies `pbits` to it, and then hashes the bits using SHA-256
- `mkaes`
  - takes a list of bits
  - hashes it with `hbits` and then creates an AES engine (CTR mode, empty nonce) with that hash as the key
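the two helpers that matter most later are `center` and `extract`, so here's a quick sketch of them on made-up inputs. the lambdas are copied from `vuln.py`; the 4-coefficient "polynomial" `d` and the mask `r` are toy values of mine.

```python
q = 11777

center = lambda v: min(v % q, v % q - q, key=abs)
extract = lambda r, d: [2 * t // q for u, t in zip(r, d) if u]

small_pos = center(5)        # already close to 0, stays 5
small_neg = center(q - 5)    # q - 5 is really "-5" in disguise

d = [3, q - 3, 6000, 100]    # toy coefficients, all reduced mod q
r = [1, 1, 0, 1]             # mask: skip the third coefficient
bits = extract(r, d)         # one bit per selected coefficient
print(small_pos, small_neg, bits)
```

this gives `5`, `-5`, and `[0, 1, 0]`: the coefficient `q - 3` is above `q / 2` so it yields a `1`, and `6000` never shows up because its mask bit is `0`.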

```python
def encaps(pk):
    seed = os.urandom(32)
    rng = random.Random(seed)
    a,b = pk
    s,e = sample(rng), sample(rng)
    c = add(mul(a,s), e)
    d = add(mul(b,s), e)
    r = [int(abs(center(2*v)) > q//7) for v in d]
    bits = extract(r,d)
    return bits, (c,r)
```

^ finally! the KEM encapsulation functionnnn~. this function

- takes in `pk`, the public key of the other party
- creates `bits`, which is the shared symmetric key both parties derive, plus the ciphertext `(c,r)` that the other party needs to derive `bits` too

the steps here are (adding `'` to names to distinguish them from the `s` and `e` used
during the decapsulator's keygen, and the `d` value the decapsulator will later derive):

- make `s'` and `e'` random polynomials with "short" coefficients (i.e. drawn from CBD)
- compute `c = as' + e'`
- compute `d' = bs' + e'`
- calculate `r` as a set of bits indicating which coefficients of `d'` are safely away from both `0` and `q / 2` mod `q` (the code checks that `2d'` mod `q` is further than `q / 7` from 0, so `d'` itself is more than about `q / 14` from either boundary)
- `extract` the coefficients of `d'` corresponding to '1's in `r`, thresholding them against `q / 2`

this is a fairly standard ring-LWE KEM (see Kyber.CPAPKE, section 1.2.3 of the Kyber
spec) except for the part with `r` selecting a subset of the bits.

as far as I can tell, this is just so that when the other party recomputes something
close to `d'` in the next step, the error in their computation doesn't take a
coefficient from just below say `q` to just above `0`, flipping the bit outputted by
`extract`.
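to make that concrete, here's a sketch that brute-forces which coefficient values the `r` rule keeps. `center` and the threshold come straight from `encaps`; the helper name `keeps` and the `margin` computation are mine.

```python
q = 11777
center = lambda v: min(v % q, v % q - q, key=abs)

def keeps(v):
    """The test from encaps: would a coefficient equal to v get a 1 in r?"""
    return abs(center(2 * v)) > q // 7

kept = [v for v in range(q) if keeps(v)]

# distance (mod q) from each kept value to the nearest place where the
# output of extract flips, i.e. 0 (same as q) or q/2
margin = min(min(v, abs(v - q / 2), q - v) for v in kept)
print(len(kept), "of", q, "values kept, smallest margin", margin)
```

about 71% of all possible values survive, and each survivor is at least 841.5 (roughly `q / 14`) away from a boundary, so a small decapsulation error can't flip its bit.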

anyway, spoiler alert, that `r` addition turns out to be the downfall of this whole thing.

```python
def decaps(sk, ct):
    s = sk
    c,r = ct
    d = mul(c,s)
    return extract(r,d)
```

^ now, the decapsulation process done by the other party is much simpler.

- take the ciphertext `(c,r)` and compute `d = cs` (we'll see this is approximately equal to `d'`)
- `extract` the coefficients of `d` corresponding to '1's in `r`, thresholding them against `q / 2`.
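if you'd rather see the "approximately equal" part numerically first, here's a self-contained sketch that runs one keygen and one encapsulation locally and measures how far apart the two sides land. the arithmetic mirrors `vuln.py`; the fixed seed and the names `s2`, `e2`, `d_enc`, `d_dec`, `gap` are my own choices.

```python
import random

n, q, w = 256, 11777, 8
sample = lambda rng: [bin(rng.getrandbits(w)).count('1') - w // 2 for _ in range(n)]
add = lambda f, g: [(x + y) % q for x, y in zip(f, g)]
center = lambda v: min(v % q, v % q - q, key=abs)

def mul(f, g):
    r = [0] * n
    for i, x in enumerate(f):
        for j, y in enumerate(g):
            s, k = divmod(i + j, n)
            r[k] = (r[k] + (-1) ** s * x * y) % q
    return r

rng = random.Random(1234)            # fixed seed so the run is repeatable
a = [rng.randrange(q) for _ in range(n)]
s, e = sample(rng), sample(rng)      # decapsulator's secret and error
b = add(mul(a, s), e)                # public key: b = as + e
s2, e2 = sample(rng), sample(rng)    # encapsulator's s' and e'
c = add(mul(a, s2), e2)              # ciphertext part: c = as' + e'
d_enc = add(mul(b, s2), e2)          # what the encapsulator thresholds
d_dec = mul(c, s)                    # what the decapsulator thresholds

gap = max(abs(center(x - y)) for x, y in zip(d_enc, d_dec))
print("largest coefficient gap:", gap, "vs q/14 =", q // 14)
```

the largest gap should come out far below `q / 14`, which is why the `r`-selected bits almost always agree (and why `vuln.py` can afford to just retry keygen in the rare case they don't).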

the trick is that `d` here and `d'` from the encapsulation should be about the same, so
when we call `extract` we'll more than likely get the same bits. indeed:

`d' = bs' + e' = (as + e)s' + e' = ass' + es' + e'`

and

`d = cs = (as' + e')s = ass' + e's`

where `e`, `e'`, `s`, and `s'` have tiny coefficients, so the difference `d' - d = es' + e' - e's` is small too, i.e. it's unlikely
to nudge a value from being on one side of `q / 2` to the other. now, without further ado...

```python
if __name__ == '__main__':

    while True:
        sk, pk = genkey()
        dh, ct = encaps(pk)
        if decaps(sk, ct) == dh:
            break

    print('pk[0]:', ppoly(pk[0]))
    print('pk[1]:', ppoly(pk[1]))

    print('ct[0]:', ppoly(ct[0]))
    print('ct[1]:', pbits(ct[1]))

    flag = open('flag.txt').read().strip()
    print('flag: ', mkaes([0]+dh).encrypt(flag.encode()).hex())

    for _ in range(2048):
        c = list(struct.unpack(f'<{n}H', bytes.fromhex(input())))
        r = list(map('01'.index, input()))
        if len(r) != n or sum(r) < n//2: exit('!!!')

        bits = decaps(sk, (c,r))

        print(mkaes([1]+bits).encrypt(b'hxp<3you').hex())
```

^ at last, the heart of the challenge! we're tossed a public key from one party (the "decapsulator") and a ciphertext for that public key from the other (the "encapsulator"). then,

- we're shown the flag, encrypted using AES with key `[0] + bits`
- we're allowed to, 2048 times, send over a ciphertext, which the receiving party
  decapsulates to compute the shared key `bits'`, and then replies to us with the AES encryption of `"hxp<3you"` using the key `[1] + bits'`

there's also a pesky restriction that the values of `r` we supply need to have at least
`n/2` '1's, i.e. we can't try to "one-hot" out any particular coefficient of `d`.

given that we have no access to the outputted value `bits'` itself, and AES is, well,
hard to crack, it seems like there's not much we can do. unless...

# /// a first try ///

one piece of information we do have is whether two derived keys `bits'` are the same,
because with overwhelming likelihood the encryptions of the message `"hxp<3you"` will
be the same if and only if the keys used for encryption were the same.

the other key insight we need is that if we pass in the constant polynomial `1` for `c`, then `d = cs = s` and we can threshold the coefficients of the private key
itself! (everything is mod `q`, so a negative coefficient like `-3` shows up as `q - 3`, i.e. above `q/2`.)

so, we can begin our valiant effort: if we send `c = 1` with say,

`r = 10101010...101010`

(this has exactly `n/2` '1's, so we're good!) followed by `c = 1` with

`r = 01101010...101010`

which flips the first two bits of the previous `r`, we can check the two outputted
encrypted messages. __here's the key part:__ if the outputted messages are the same,
then the values for `bits'` derived in both cases were the same. the only thing that
changed is now the first bit of `bits'` comes from the first coefficient of `s`, `s[1]`, rather
than from the zeroth coefficient `s[0]`.

thus, if the outputs are the same, then `s[0]` and `s[1]` lie both above or both
below `q/2` (since this is what `extract` checks for each coefficient). similarly, if
the encrypted messages weren't the same, then they'd have to be on opposite sides of
`q/2`.
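here's an offline sketch of that comparison on a toy secret. big assumptions, all mine: `n = 8` instead of 256, a made-up `s`, and a SHA-256 fingerprint of the bits standing in for the server's AES reply (all we ever use is whether two replies are equal, so any deterministic function of `bits'` behaves the same for this purpose). the function name `oracle` is hypothetical too.

```python
import hashlib

q = 11777
extract = lambda r, d: [2 * t // q for u, t in zip(r, d) if u]

def oracle(s, r):
    """Stand-in for one server query with c = 1: a fingerprint of bits'."""
    d = [v % q for v in s]                 # c = 1, so d = cs = s (mod q)
    bits = extract(r, d)
    return hashlib.sha256(bytes([1] + bits)).hexdigest()

s = [-2, 3, 0, -1, 4, 1, -3, 2]            # toy secret with n = 8
r1 = [1, 0, 1, 0, 1, 0, 1, 0]              # uses s[0], s[2], s[4], s[6]
r2 = [0, 1, 1, 0, 1, 0, 1, 0]              # uses s[1] in place of s[0]
r4 = [1, 0, 1, 0, 0, 1, 1, 0]              # uses s[5] in place of s[4]

differ_01 = oracle(s, r1) != oracle(s, r2)  # s[0] = -2 vs s[1] = 3
same_45 = oracle(s, r1) == oracle(s, r4)    # s[4] = 4 vs s[5] = 1
print(differ_01, same_45)
```

both come out `True`: the replies differ when the swapped coefficients have opposite signs and match when they're on the same side, without us ever seeing `bits'` itself.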

# /// getting fancy with r ///

here's the cool thing: we can repeat this trick!! try on these values for `r`:

```
1 01010101...010101
2 10010101...010101
3 10100101...010101
4 10101001...010101
5 ...
6 10101010...101001
7 10101010...101010
```

^ comparing line 1 with line 2, we exchange `s[0]` in for `s[1]` in `extract`. swapping
line 2 for line 3 exchanges `s[2]` in for `s[3]`, line 3 for 4 exchanges `s[4]` for
`s[5]`, and so on for every pair of lines. if we now add

```
1 00110101...010101
2 00101101...010101
3 00101011...010101
4 ...
5 00101010...101101
6 00101010...101011
```

^ then, swapping line 1 of the last set (`0101...0101`) for line 1 of this set exchanges
`s[1]` for `s[2]`. continuing down this set we swap `s[3]` with `s[4]`, etc.

in particular, comparing the outputs for all of these values of `r` lets us check
whether every coefficient of `s` is on the same side of or a different side of `q/2`
from its neighbors. up to guessing whether `s[0] < q/2` or `s[0] > q/2`, this tells us
exactly which "side" every coefficient of `s` lies on!
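the whole chain can be tried out offline. this sketch is built on assumptions of mine: a toy `n = 16`, a random toy secret, `c = 1`, and a SHA-256 fingerprint standing in for the server's AES reply. the `r` patterns are constructed the way the solve script further down builds them (same idea as the listings, walked in a slightly different order).

```python
import hashlib, random

n, q = 16, 11777   # toy n (the challenge uses 256); same q
extract = lambda r, d: [2 * t // q for u, t in zip(r, d) if u]
fingerprint = lambda bits: hashlib.sha256(bytes(bits)).digest()  # stand-in reply

# two families of r; neighbours differ by swapping one adjacent pair
rs = [[0, 1] * i + [1, 0] * (n // 2 - i) for i in range(n // 2)]
rs += [[0] + [0, 1] * i + [1, 0] * (n // 2 - 1 - i) + [1] for i in range(n // 2)]

rng = random.Random(0)
s = [bin(rng.getrandbits(8)).count('1') - 4 for _ in range(n)]   # toy secret
d = [v % q for v in s]                                           # c = 1
replies = [fingerprint(extract(r, d)) for r in rs]

sides = [0]                                  # guess: s[0] is below q/2
for i in range(n // 2):
    # replies[i] vs replies[i + 1] compares coefficient 2i with 2i + 1
    sides.append(sides[-1] ^ (replies[i] != replies[i + 1]))
    if len(sides) < n:
        # second family compares coefficient 2i + 1 with 2i + 2
        j = n // 2 + i
        sides.append(sides[-1] ^ (replies[j] != replies[j + 1]))

truth = [2 * t // q for t in d]
flip = truth[0]                              # 1 if the initial guess was wrong
print(sides == [b ^ flip for b in truth])
```

this prints `True`: the recovered sides match the real ones, up to the single global flip that comes from guessing the side of `s[0]`.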

but we need the *exact* coefficients of `s`!! ://

this leads us to our next trick:

# /// short coefficients go spinny ///

turns out `c = 1` isn't the only value we can cleverly try. indeed, we can set `c` to
*any* integer mod `q`, and it will multiply the coefficients of `s` by that number.
turns out, you can identify any number in `{-4, -3, ..., 4}` by comparing multiples mod
`q` of them to `q/2`. here's a table!!

```
      * 10630  * 2991   * 9085   * 5954   * 378    * 3770   * 1532   * 11293
-4    < q/2    > q/2    > q/2    > q/2    > q/2    > q/2    < q/2    < q/2
-3    < q/2    < q/2    > q/2    < q/2    > q/2    < q/2    > q/2    < q/2
-2    < q/2    < q/2    < q/2    > q/2    > q/2    < q/2    > q/2    < q/2
-1    < q/2    > q/2    < q/2    < q/2    > q/2    > q/2    > q/2    < q/2
 0    < q/2    < q/2    < q/2    < q/2    < q/2    < q/2    < q/2    < q/2
 1    > q/2    < q/2    > q/2    > q/2    < q/2    < q/2    < q/2    > q/2
 2    > q/2    > q/2    > q/2    < q/2    < q/2    > q/2    < q/2    > q/2
 3    > q/2    > q/2    < q/2    > q/2    < q/2    > q/2    < q/2    > q/2
 4    > q/2    < q/2    < q/2    < q/2    < q/2    < q/2    > q/2    > q/2
```

how did I get these numbers in the top row, you ask? i just generated some random ones lmao. others work, idek.
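if you want to vet your own multipliers, the property to check is that the nine possible coefficients all get different rows. a sketch, with helper names (`signature`, `separates`) and the random-trial seed being my own inventions:

```python
import random

q = 11777
ks = [10630, 2991, 9085, 5954, 378, 3770, 1532, 11293]  # the ones from the table

def signature(v, ks):
    """Which side of q/2 each multiple k*v lands on (0 = below, 1 = above)."""
    return tuple(2 * ((v * k) % q) // q for k in ks)

def separates(ks):
    """True if all nine possible coefficients get distinct signatures."""
    return len({signature(v, ks) for v in range(-4, 5)}) == 9

table_ok = separates(ks)

# how often does a random 8-tuple of multipliers work? (arbitrary seed)
rng = random.Random(2022)
trials = [[rng.randrange(1, q) for _ in range(8)] for _ in range(1000)]
hits = sum(map(separates, trials))
print(table_ok, hits, "of 1000 random tuples also separate")
```

`table_ok` is `True` for the multipliers above, and the last line gives a rough feel for how picky you need to be when rolling new ones.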

anywho, yeah. that's the whole trick! here's the full attack. it gets a pretty box because i said so. \uwu/

stage 1: collecting the "side" data for each `c`

- for each multiplier `k` above in the table:
  - for each value of `r` above:
    - let `c` be the constant polynomial `k`
    - send `c` and `r` and let the server compute `d = cs` and hand us back our encrypted messages
  - assuming `d[0] < q/2`, do all the encrypted message comparisons to figure out which side of `q/2` each coefficient is on (our assumption `d[0] < q/2` could be wrong, but in that case we only need to flip which side we think each coefficient is on)
  - save these "side choices" for this value of `c`

stage 2: enumerating the possibilities for `d[0]` and getting `s`

- since `d[0] < q/2` may be false for some values of `c`, iterate over every possible combination of which `c`'s "side choices" need to be flipped versus which don't
  - use these guesses to try to determine each coefficient of `s` with the big table above
  - if this works, go through the `decaps` steps and try to recover the message with this value for `s`

since we have 8 values for `c` and 256 values of `r`, we use up exactly 2048 server
queries. it turns out, this attack works perfectly in practice! ^^
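stage 2 is easy to rehearse without a server. in this sketch i skip the queries entirely and fabricate what stage 1 would hand us: the true sides of `k * s[i]`, each row XORed with the side of `k * s[0]` (that's the effect of always guessing `d[0] < q/2`). the toy `n = 32`, the seed, and the names `side`, `observed`, `candidates` are all mine.

```python
import itertools, random

q, n = 11777, 32   # toy n; q as in the challenge
ks = [10630, 2991, 9085, 5954, 378, 3770, 1532, 11293]
side = lambda v, k: 2 * ((v * k) % q) // q
traces = {tuple(side(v, k) for k in ks): v for v in range(-4, 5)}

rng = random.Random(7)
s = [bin(rng.getrandbits(8)).count('1') - 4 for _ in range(n)]   # secret to recover

# simulated stage 1 output: sides relative to the guess "d[0] < q/2"
observed = [[side(v, k) ^ side(s[0], k) for v in s] for k in ks]

# stage 2: try every combination of per-k flips, keep the consistent ones
candidates = []
for flips in itertools.product([0, 1], repeat=len(ks)):
    rows = [[b ^ f for b in row] for f, row in zip(flips, observed)]
    try:
        candidates.append([traces[t] for t in zip(*rows)])
    except KeyError:
        continue   # some coefficient matched no table row: wrong flips

print(len(candidates), s in candidates)
```

the real secret always ends up among the candidates; if more than one flip pattern survives, trying each against the encrypted flag settles it.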

# /// the solve script ///

uhh so my solve script during the competition was *super* messy. like, embarrassingly.
i've cleaned it up a bit here but if there's typos forgive me lol.

```python
import struct
import itertools
from pwn import *
from vuln import *
```

^ yup, that last line's one you never want to see in a script.

```python
ks = [10630, 2991, 9085, 5954, 378, 3770, 1532, 11293] # random values, lol
traces = {tuple([2 * ((v * k) % q) // q for k in ks]): v for v in range(-4, 5)}
```

^ we store each multiplier `k` from the table above in `ks` and then compute `traces`,
which is a map from the rows of the table above back to the coefficient that generates
them.

for instance, the key `(0, 1, 1, 1, 1, 1, 0, 0)` corresponding to the first row of the
above table gives the value `-4`. the expression `2 * ((v * k) % q) // q` just
determines which side of `q / 2` each value `v = -4, -3, ..., 4` is when we multiply it
by `k` mod `q`.

```python
rs = []

for i in range(n // 2):
    r = [0, 1] * i + [1, 0] * (n // 2 - i)
    rs.append(r)

for i in range(n // 2):
    r = [0] + [0, 1] * i + [1, 0] * (n // 2 - 1 - i) + [1]
    rs.append(r)
```

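^ next, we generate our list of `rs` as above (the first family here starts from `1010...10` and works toward `0101...01`, the opposite direction from the listing earlier, but the comparisons are the same idea).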
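since one bad pattern makes the server bail with `!!!`, it's worth sanity-checking the list before burning queries. a small sketch (the checks are my addition; the construction is the same as in the script):

```python
n = 256

rs = []
for i in range(n // 2):
    rs.append([0, 1] * i + [1, 0] * (n // 2 - i))
for i in range(n // 2):
    rs.append([0] + [0, 1] * i + [1, 0] * (n // 2 - 1 - i) + [1])

# every r must pass the server's check: length n and at least n/2 ones
valid = all(len(r) == n and sum(r) >= n // 2 for r in rs)

# neighbouring patterns should differ in exactly two adjacent positions,
# i.e. each step swaps a single coefficient for its neighbour
single_swaps = True
for prev, cur in zip(rs, rs[1:]):
    diff = [i for i, (x, y) in enumerate(zip(prev, cur)) if x != y]
    single_swaps &= len(diff) == 2 and diff[1] == diff[0] + 1

print(len(rs), valid, single_swaps)
```

256 patterns give 255 neighbouring comparisons, which is exactly enough to link all 256 coefficients into one chain.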

```python
with remote("<server IP redacted>", 4421) as conn:
    conn.recvline() # we don't need pk[0]
    conn.recvline() # or pk[1] for that matter

    ct = []

    # parse c as a list of polynomial coefficients
    c_bytes = bytes.fromhex(conn.recvline().decode().split(":")[1])
    ct.append(struct.unpack(f'<{n}H', c_bytes))

    # parse r as a list of 0 and 1 values
    r_bytes = conn.recvline().decode().split(":")[1].strip()
    ct.append(list(map('01'.index, r_bytes)))

    # get encoded flag as bytes
    flag_enc = bytes.fromhex(conn.recvline().decode().split(":")[1])
```

^ we parse the ciphertext and the encrypted flag from the server (i wasn't feeling super pythonic ok forgive me)

```python
    # (still inside the `with` block from above)
    guesses = []

    for k in ks:
        encs = [] # encrypted messages received for this multiplier k
        for r in rs:
            c = [k] + [0] * (n - 1) # set c to a constant polynomial with constant term k
            conn.sendline(ppoly(c).encode()) # send it over
            conn.sendline(pbits(r).encode()) # send r too
            encs.append(bytes.fromhex(conn.recvline().decode())) # save the response

        guess_bits = [0] # guesses for whether each coefficient of d is > q/2, we assume d[0] < q/2

        for i in range(n // 2):
            # infer bit 2 * i + 1 from bit 2 * i (i.e. 1 from 0)
            guess = guess_bits[-1]
            if encs[i] != encs[i + 1]:
                guess = 1 if guess == 0 else 0
            guess_bits.append(guess)

            # infer bit 2 * i + 2 from bit 2 * i + 1 (i.e. 2 from 1)
            if len(guess_bits) != n:
                if encs[n // 2 + i] != encs[n // 2 + i + 1]:
                    guess = 1 if guess == 0 else 0
                guess_bits.append(guess)

        # save the guesses for this value of k
        guesses.append(guess_bits)
```

^ this just does exactly what we described with guessing the "sides" of each coefficient
of `d`. note that, as above, one set of `r` values handles comparisons with even offsets
while the other handles comparisons with odd offsets.

then, lastly...

```python
# iterate over which elements of `guesses` we might need to flip bitwise
for flips in itertools.product([True, False], repeat=8):
    # perform the flipping
    flipped_guesses = []
    for flip, guess_bits in zip(flips, guesses):
        if flip:
            guess_bits = [1 if bit == 0 else 0 for bit in guess_bits]
        flipped_guesses.append(guess_bits)

    # after flipping, get the "sides" for each coefficient for each k
    # i.e. s_traces[0] could be (0, 1, 1, 1, 1, 1, 0, 0), meaning s[0] is -4
    s_traces = list(zip(*flipped_guesses))

    # try to recover s
    try:
        s = [traces[t] for t in s_traces]
    except KeyError:
        continue

    # get that flag
    dh = decaps(s, ct)
    flag = mkaes([0]+dh).decrypt(flag_enc) # flag_enc was already decoded from hex above
    print(flag)
```

^ and this last step just does the flipping necessary because we might have misguessed
which side `d[0]` fell on for some of our values of `k`, and tries to recover `s` for
each possible set of flips.

running this script gives (after several minutes, lol)... `hxp{e4zy_p34zY_p34nuT_Bu7t3r}`

and that's the challenge! ^^

# /// conclusion ///

welllll, that's it!! i hope you enjoyed, i thought this was a clever little puzzle and i'm happy to now have a writeup for it here :)