Timing leaks and multi-threading

by Landon | Aug 24, 2021

What if the server that verified MACs took longer to verify a correct MAC than an incorrect one? Or, put differently, what if you could tell a partially correct guess apart from an obviously wrong one? If you can, you can break MAC authentication schemes, and that’s what the cryptopals authors are getting at in challenges 31 and 32.

Write a function, call it "insecure_compare", that implements the == operation by doing byte-at-a-time comparisons with early exit (ie, return false at the first non-matching byte).

In the loop for "insecure_compare", add a 50ms sleep (sleep 50ms after each byte).

Use your "insecure_compare" function to verify the HMACs on incoming requests, and test that the whole contraption works. Return a 500 if the MAC is invalid, and a 200 if it's OK.
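Here is a minimal, self-contained sketch of those two pieces. It assumes HMAC-SHA1 from the JDK's javax.crypto.Mac and leaves out the HTTP layer entirely: the hypothetical `verify` method just returns the status code the endpoint would send. The class name, key, and message are illustrative only.

```java
import java.nio.charset.StandardCharsets;
import java.security.GeneralSecurityException;
import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;

public class InsecureCompareSketch {

    // Byte-at-a-time comparison with early exit; sleeps after every byte that matches.
    static boolean insecureCompare(byte[] expected, byte[] actual, long delayMillis) {
        if (expected.length != actual.length) {
            return false;
        }
        for (int i = 0; i < expected.length; i++) {
            if (expected[i] != actual[i]) {
                return false; // early exit: this is the timing leak
            }
            sleep(delayMillis);
        }
        return true;
    }

    private static void sleep(long millis) {
        try {
            Thread.sleep(millis);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
    }

    static byte[] hmacSha1(byte[] key, byte[] message) {
        try {
            Mac mac = Mac.getInstance("HmacSHA1");
            mac.init(new SecretKeySpec(key, "HmacSHA1"));
            return mac.doFinal(message); // always 20 bytes
        } catch (GeneralSecurityException e) {
            throw new IllegalStateException(e);
        }
    }

    // Hypothetical endpoint logic: 200 for a valid MAC, 500 otherwise.
    static int verify(byte[] key, byte[] file, byte[] signature, long delayMillis) {
        return insecureCompare(hmacSha1(key, file), signature, delayMillis) ? 200 : 500;
    }

    public static void main(String[] args) {
        byte[] key = "YELLOW SUBMARINE".getBytes(StandardCharsets.UTF_8); // illustrative key
        byte[] file = "foo".getBytes(StandardCharsets.UTF_8);
        byte[] good = hmacSha1(key, file);
        byte[] bad = good.clone();
        bad[19] ^= 1; // flip one bit in the last byte
        System.out.println(verify(key, file, good, 50)); // prints 200 (20 matching bytes, about 1 s)
        System.out.println(verify(key, file, bad, 50));  // prints 500 (19 matching bytes, about 0.95 s)
    }
}
```

The only difference between those two calls is how long they take before answering, and that gap is exactly what the attack measures.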

This is simple enough. Remember that an HMAC-SHA1 value is simply a 20-byte array, sent to the server as a 40-character hexadecimal string. If you take a 20-byte array and change it one byte at a time, you can fairly easily tell when a byte in your candidate is correct. How? Brute force each byte of the hash.
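A quick illustration of both points, as a sketch using only the JDK: an HMAC-SHA1 tag is 20 bytes, its hex form is 40 characters, and the candidates for one position are just 256 copies of the current guess that differ in that byte. The `toHex` and `candidatesFor` helpers are hypothetical names, not part of any library.

```java
import java.nio.charset.StandardCharsets;
import java.security.GeneralSecurityException;
import java.util.ArrayList;
import java.util.List;
import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;

public class CandidateSketch {

    // Lowercase hex, two characters per byte.
    static String toHex(byte[] bytes) {
        StringBuilder sb = new StringBuilder(bytes.length * 2);
        for (byte b : bytes) {
            sb.append(String.format("%02x", b & 0xff));
        }
        return sb.toString();
    }

    // All 256 variants of base that differ only at the given position.
    static List<byte[]> candidatesFor(byte[] base, int position) {
        List<byte[]> out = new ArrayList<>(256);
        for (int v = 0; v < 256; v++) {
            byte[] c = base.clone();
            c[position] = (byte) v; // Java bytes are signed: 128..255 wrap to -128..-1
            out.add(c);
        }
        return out;
    }

    public static void main(String[] args) throws GeneralSecurityException {
        Mac mac = Mac.getInstance("HmacSHA1");
        mac.init(new SecretKeySpec("key".getBytes(StandardCharsets.UTF_8), "HmacSHA1"));
        byte[] tag = mac.doFinal("file".getBytes(StandardCharsets.UTF_8));
        System.out.println(tag.length);          // 20
        System.out.println(toHex(tag).length()); // 40

        List<byte[]> round = candidatesFor(new byte[20], 0);
        System.out.println(round.size());          // 256
        System.out.println(toHex(round.get(255))); // "ff" followed by 38 zeros
    }
}
```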

  • Stand up a vulnerable server. I used Spring Boot and Jetty. You can see my vulnerable server code here.
  • Starting with a byte array of all zeros, make 256 requests against the server, rotating byte n=0 (the first byte) through all 256 possible values, and measure how long each request takes. The correct value for position n=0 is the one whose request took the longest.
  • Save that value in position n=0.
  • Repeat for positions n=1 through n=19 until the server finally returns a 200.
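The steps above can be sketched without a network at all. This is an illustration under simplifying assumptions, not the exploiter shown further down: the hypothetical `SimulatedServer` reports the time insecure_compare would have spent (50 ms per matching byte) instead of actually sleeping, so there is no noise and a single measurement per candidate is enough.

```java
import java.util.Arrays;
import java.util.Random;

public class ByteAtATimeSketch {

    // Hypothetical stand-in for the HTTP endpoint: reports the time insecure_compare
    // would have spent (delay per matching byte) instead of sleeping.
    static final class SimulatedServer {
        private final byte[] secretMac;
        private final long delayMillis;

        SimulatedServer(byte[] secretMac, long delayMillis) {
            this.secretMac = secretMac.clone();
            this.delayMillis = delayMillis;
        }

        boolean accepts(byte[] guess) {
            return Arrays.equals(secretMac, guess);
        }

        long simulatedMillis(byte[] guess) {
            int matched = 0;
            while (matched < secretMac.length && secretMac[matched] == guess[matched]) {
                matched++;
            }
            return matched * delayMillis;
        }
    }

    static byte[] crack(SimulatedServer server, int macLength) {
        byte[] forged = new byte[macLength]; // start from all zeros
        for (int n = 0; n < macLength; n++) {
            int bestValue = 0;
            long bestTime = -1;
            for (int v = 0; v < 256; v++) { // rotate position n through every value
                forged[n] = (byte) v;
                long t = server.simulatedMillis(forged);
                if (t > bestTime) { // the slowest response wins
                    bestTime = t;
                    bestValue = v;
                }
            }
            forged[n] = (byte) bestValue; // save the byte at position n
            if (server.accepts(forged)) { // a real attack checks for a 200 here
                break;
            }
        }
        return forged;
    }

    public static void main(String[] args) {
        byte[] secret = new byte[20];
        new Random(42).nextBytes(secret);
        byte[] forged = crack(new SimulatedServer(secret, 50), 20);
        System.out.println(Arrays.equals(secret, forged)); // true
    }
}
```

Against a real server, `simulatedMillis` becomes a timed HTTP round trip and `accepts` becomes a check for a 200, which is where the noise discussed next comes in.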

Challenge 31 was simple because the timing leak is so obvious at 50 ms. Challenge 32 was much more difficult: with a 5 ms delay instead of 50 ms, statistical noise matters far more. Any number of things on the server can cause delays of a millisecond or two, especially when you’re running a small web server on an underpowered system. Every solution to this challenge I could find elsewhere skipped the explicit web server I had set up and relied on direct method calls between classes (who could blame them?), and most simply averaged the request times and took the candidate whose requests took the longest on average.
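Why averaging helps, under an idealized assumption: if each request's noise is independent with the same standard deviation σ, the mean of N requests for candidate v has standard deviation σ/√N. Real server noise is neither independent nor symmetric, so treat this as intuition only; σ = 2 ms below is a hypothetical figure, not a measurement, and N = 33 and N = 100 are the sample counts the exploiter further down uses.

```latex
\begin{aligned}
\bar{t}_v &= \frac{1}{N}\sum_{j=1}^{N} t_{v,j}, \qquad \operatorname{SD}(\bar{t}_v) = \frac{\sigma}{\sqrt{N}} \\
\sigma = 2\ \text{ms}:\quad \frac{2}{\sqrt{33}}\ \text{ms} &\approx 0.35\ \text{ms}, \qquad \frac{2}{\sqrt{100}}\ \text{ms} = 0.2\ \text{ms}
\end{aligned}
```

Both are small next to a 5 ms signal, which is one way to see why a few dozen samples per candidate can be enough, and why anything that inflates σ, such as piling on threads, hurts.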

My solution does rely on averaging, but it adds a process of elimination: a limited number of requests (33) for every possible byte value, then roughly 3x as many (100) for the five candidates that took the longest in the first round. For both 31 and 32, I used multi-threading, and the same tool cracked both the 50 ms and the 5 ms timing leak. For the 50 ms leak, I used upwards of 30 threads; the noise that added to response times was not enough to pollute the results. For the 5 ms leak, I could only use two or three threads and still break the MAC; any more and the noise was too much. You can see my solution below.
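The process of elimination can be sketched in isolation too. This is a simplified, single-threaded illustration rather than the exploiter below: the hypothetical `NoisyOracle` adds 5 ms per matching byte plus up to 10 ms of uniform random noise, so one sample per candidate is unreliable. Round one takes 33 samples of every value and keeps the five slowest means; round two takes 100 samples of each finalist and keeps the slowest.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Random;

public class TwoRoundSketch {

    // Hypothetical noisy timing source: 5 ms per matching byte plus 0-10 ms of uniform noise.
    static final class NoisyOracle {
        private final byte[] secret;
        private final Random rng;

        NoisyOracle(byte[] secret, long seed) {
            this.secret = secret.clone();
            this.rng = new Random(seed);
        }

        double time(byte[] guess) {
            int matched = 0;
            while (matched < secret.length && secret[matched] == guess[matched]) {
                matched++;
            }
            return 5.0 * matched + 10.0 * rng.nextDouble();
        }
    }

    static double meanTime(NoisyOracle oracle, byte[] guess, int samples) {
        double total = 0;
        for (int j = 0; j < samples; j++) {
            total += oracle.time(guess);
        }
        return total / samples;
    }

    // Scores each candidate value at position n by mean time and keeps the `keep` slowest.
    static List<Integer> slowest(NoisyOracle oracle, byte[] base, int n, List<Integer> values,
                                 int samples, int keep) {
        List<double[]> scored = new ArrayList<>();
        for (int v : values) {
            byte[] guess = base.clone();
            guess[n] = (byte) v;
            scored.add(new double[] {v, meanTime(oracle, guess, samples)});
        }
        scored.sort(Comparator.comparingDouble((double[] s) -> s[1]).reversed());
        List<Integer> kept = new ArrayList<>();
        for (int i = 0; i < keep && i < scored.size(); i++) {
            kept.add((int) scored.get(i)[0]);
        }
        return kept;
    }

    static int findByte(NoisyOracle oracle, byte[] base, int n) {
        List<Integer> all = new ArrayList<>();
        for (int v = 0; v < 256; v++) {
            all.add(v);
        }
        List<Integer> finalists = slowest(oracle, base, n, all, 33, 5); // round 1: wide, cheap
        return slowest(oracle, base, n, finalists, 100, 1).get(0);      // round 2: narrow, ~3x samples
    }

    public static void main(String[] args) {
        byte[] secret = {0x3c, (byte) 0xa7, 0x11}; // hypothetical 3-byte "MAC" keeps the demo fast
        NoisyOracle oracle = new NoisyOracle(secret, 1);
        byte[] forged = new byte[secret.length];
        for (int n = 0; n < forged.length; n++) {
            forged[n] = (byte) findByte(oracle, forged, n);
        }
        StringBuilder hex = new StringBuilder();
        for (byte b : forged) {
            hex.append(String.format("%02x", b & 0xff));
        }
        System.out.println(hex);
    }
}
```

Because the noise here is uniform and independent, averaging works almost perfectly; real server noise is messier, which is where the thread count started to matter.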

Final note: I used multi-threading to try to speed up build times, with mixed results. This crack takes a long time to run. Once you find the first byte, every subsequent request is also delayed by it. Finding the first byte takes a fraction of a second. Finding the second byte with a single thread takes at least (50ms * 256) + 50ms, roughly 13 seconds, because every request now spends 50 ms on the correct first byte. The third byte takes (100ms * 256) + 50ms, roughly 26 seconds. The time per byte grows linearly with its position, so the total time grows quadratically with the length of the hash. If you can multi-thread, you can make concurrent requests and arrive at the solution more quickly. Even so, getting reliable results takes time: with my solution, each challenge takes 20 minutes or so to compute. Rather than sit around for 3/4 of an hour waiting for both cracks to finish, I tagged these tests and ignore them by default. I only run them when I push to a specific branch on GitHub, and GitHub tells me later if there’s a problem.
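The arithmetic above can be written out as a rough lower bound, using the same simplified model: with a delay d, every request while cracking position n (counting from 0) pays n·d for the already-correct prefix, and the correct candidate pays one extra d. This ignores network latency and assumes one request per candidate on a single thread.

```latex
\begin{aligned}
T_n &\approx 256 \cdot n \cdot d + d, \qquad d = 50\ \text{ms} \\
T_1 &\approx 256 \cdot 50\ \text{ms} + 50\ \text{ms} = 12{,}850\ \text{ms} \approx 13\ \text{s} \\
\sum_{n=0}^{19} T_n &\approx 256 \cdot d \cdot \sum_{n=0}^{19} n + 20\,d
  = 256 \cdot 50 \cdot 190\ \text{ms} + 1{,}000\ \text{ms} \approx 41\ \text{min}
\end{aligned}
```

At d = 5 ms the same bound is about a tenth of that, roughly 4 minutes, but the repeated samples needed to beat the noise multiply it many times over, which is where concurrent requests earn their keep.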

Here’s my timing leak exploiter:

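import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.tuple.Pair;
import org.bouncycastle.util.Arrays;
import org.bouncycastle.util.encoders.Hex;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.web.client.RestTemplate;

import java.net.URI;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * A class dedicated to exploiting timing leaks in order to complete
 * challenges 31 and 32.
 */
@Slf4j
public class C31_32_TimingLeakExploiter {

    private final String file;
    private final int port;
    private final RestTemplate restTemplate;
    private final Executor ex;

    public C31_32_TimingLeakExploiter(String file, int port, RestTemplate restTemplate, int numOfThreads) {
        this.file = file;
        this.port = port;
        this.restTemplate = restTemplate;
        // more threads finish sooner but add noise to the measured response times
        this.ex = Executors.newFixedThreadPool(numOfThreads);
    }

    @SneakyThrows
    public void exploitLeak(final byte[] forgedHash) {
        for (int i = 0; i < forgedHash.length; i++) {
            log.info("starting round one for byte {} with the full byte set", i);
            // round 1: every byte value, 33 requests each; keep the 5 slowest on average
            Set<Byte> byteSet = IntStream.range(Byte.MIN_VALUE, Byte.MAX_VALUE + 1)
                    .mapToObj(n -> (byte) n)
                    .collect(Collectors.toSet());
            final SortedSet<Pair<Byte, Double>> tree = gatherSortedData(forgedHash, i, byteSet, 5, 33);
            log.info("round one over. top five results: {}", tree);
            // round 2: the 5 survivors, 100 requests each (about 3x round 1); keep the slowest
            byteSet = tree.stream().map(Pair::getKey).collect(Collectors.toSet());
            final SortedSet<Pair<Byte, Double>> tree2 = gatherSortedData(forgedHash, i, byteSet, 1, 100);
            forgedHash[i] = tree2.first().getKey();
            log.info("found a byte ({}). now the hash is {}", tree2.first().getKey(), Hex.toHexString(forgedHash));

            // stop as soon as the server accepts the forged MAC
            if (HttpStatus.OK == makeRequest(forgedHash).get().getKey()) {
                log.info("the hash was {}", Hex.toHexString(forgedHash));
                break;
            }
        }
    }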

    /**
     * Sends limitOfRequest requests for each candidate value at position i and returns the
     * limitOfNewCandidates candidates with the highest mean response time, slowest first.
     */
    private SortedSet<Pair<Byte, Double>> gatherSortedData(final byte[] forgedHash, final int i, Collection<Byte> initialCandidates,
                                                           int limitOfNewCandidates, int limitOfRequest) {
        // submit every request up front so the thread pool can work on them concurrently
        final var futuresMap = new HashMap<Byte, List<CompletableFuture<Pair<HttpStatus, Long>>>>();
        for (byte k : initialCandidates) {
            final List<byte[]> samples = new ArrayList<>();
            for (int j = 0; j < limitOfRequest; j++) {
                var clone = Arrays.clone(forgedHash);
                clone[i] = k;
                samples.add(clone);
            }
            final var futures = samples.stream().map(this::makeRequest).collect(Collectors.toList());
            futuresMap.put(k, futures);
        }

        // slowest mean first; ties are broken on the byte value so the TreeSet never
        // treats two different candidates with the same mean as duplicates
        final SortedSet<Pair<Byte, Double>> candidates = new TreeSet<>(
                Comparator.<Pair<Byte, Double>, Double>comparing(Pair::getRight, Comparator.reverseOrder())
                        .thenComparing(Pair::getLeft));

        for (Map.Entry<Byte, List<CompletableFuture<Pair<HttpStatus, Long>>>> entry : futuresMap.entrySet()) {
            final var futures = entry.getValue();
            // block until every sample for this candidate is back, then average
            final long totalTime = futures.stream()
                    .map(CompletableFuture::join)
                    .mapToLong(Pair::getRight)
                    .sum();
            final double mean = (double) totalTime / futures.size();
            if (candidates.size() < limitOfNewCandidates) {
                candidates.add(Pair.of(entry.getKey(), mean));
            } else if (candidates.last().getValue() < mean) {
                // slower than the fastest of the current top candidates: swap it in
                candidates.add(Pair.of(entry.getKey(), mean));
                candidates.remove(candidates.last());
            }
        }
        return candidates;
    }

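    /**
     * Sends one request with the candidate signature and records how long the round trip took.
     * Assumes the injected RestTemplate does not throw on 4xx/5xx responses; with Spring's
     * default error handler, the server's 500 for a wrong MAC would surface as an exception.
     */
    public CompletableFuture<Pair<HttpStatus, Long>> makeRequest(final byte[] forgedHash) {
        return CompletableFuture.supplyAsync(() -> {
            final String signature = Hex.toHexString(forgedHash);
            final URI uri = URI.create(String.format("http://localhost:%s/leak/test/%s?signature=%s",
                    port,
                    file,
                    signature
            ));
            // time only the request itself; nanoTime is monotonic and finer-grained than currentTimeMillis
            final long startTime = System.nanoTime();
            final ResponseEntity<String> response = restTemplate.getForEntity(uri, String.class);
            final long responseTime = System.nanoTime() - startTime;
            if (response.getStatusCode() == HttpStatus.BAD_REQUEST) {
                throw new AssertionError("Got a bad request response");
            }
            return Pair.of(response.getStatusCode(), responseTime);
        }, ex);
    }
}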