From 312f52634536c9a865a82ee4cdf4cd682009c701 Mon Sep 17 00:00:00 2001 From: kageru Date: Fri, 3 Dec 2021 12:00:58 +0100 Subject: [PATCH] Make day 3 not awful --- 2021/src/bin/day03.rs | 102 ++++++++++++++++++++++++++---------------- 1 file changed, 64 insertions(+), 38 deletions(-) diff --git a/2021/src/bin/day03.rs b/2021/src/bin/day03.rs index ba75002..62aa8a2 100644 --- a/2021/src/bin/day03.rs +++ b/2021/src/bin/day03.rs @@ -1,60 +1,58 @@ +#![feature(int_log)] #![feature(test)] extern crate test; use aoc2021::common::*; -use itertools::Itertools; const DAY: usize = 03; -type Parsed = Vec>; +type Parsed = Vec; fn parse_input(raw: &str) -> Parsed { - raw.lines().map(|line| line.chars().map(|c| c == '1').collect()).collect() + raw.lines().map(|line| usize::from_str_radix(line, 2).unwrap()).collect() } -fn most_common(parsed: &Parsed) -> String { - parsed - .iter() - .skip(1) - .fold(parsed[0].iter().map(|&b| b as usize).collect_vec(), |acc, bits| { - acc.iter().zip(bits).map(|(a, &b)| a + (b as usize)).collect() - }) - .into_iter() - .map(|b| if b as f32 >= (parsed.len() as f32) / 2.0 { '1' } else { '0' }) - .collect() +fn bit_at(x: usize, n: usize) -> bool { + (x >> n) & 1 != 0 } -fn part1(parsed: &Parsed) -> usize { - let gamma = most_common(parsed); - let epsilon: String = gamma.chars().map(|c| if c == '0' { '1' } else { '0' }).collect(); - let gamma = usize::from_str_radix(&gamma, 2).unwrap(); - let epsilon = usize::from_str_radix(&epsilon, 2).unwrap(); +fn most_common(parsed: &Parsed, bits: usize) -> usize { + (0..bits).rev().map(|n| most_common_at(parsed, n)).fold(0, |acc, b| (acc | (b as usize)) << 1) >> 1 +} + +fn most_common_at(parsed: &Parsed, n: usize) -> bool { + parsed.iter().filter(|&&x| bit_at(x, n)).count() * 2 >= parsed.len() +} + +fn invert(n: usize) -> usize { + !n & ((1 << n.log2()) - 1) +} + +fn part1(parsed: &Parsed, bits: usize) -> usize { + let gamma = most_common(parsed, bits); + let epsilon = invert(gamma); gamma * epsilon } -fn part2(parsed: &Parsed) -> usize { +fn part2(parsed: &Parsed, bits: usize) -> usize { let mut matching_gamma = parsed.clone(); let mut matching_epsilon = parsed.clone(); - for i in 0..parsed[0].len() { - let gamma = most_common(&matching_gamma); - let epsilon = most_common(&matching_epsilon); - let epsilon: String = epsilon.chars().map(|c| if c == '0' { '1' } else { '0' }).collect(); - matching_gamma.retain(|n| n[i] == (gamma.chars().nth(i).unwrap() == '1')); + for i in (0..bits).rev() { + let gamma = most_common_at(&matching_gamma, i); + let epsilon = !most_common_at(&matching_epsilon, i); + matching_gamma.retain(|&n| bit_at(n, i) == gamma); if matching_epsilon.len() > 1 { - matching_epsilon.retain(|n| n[i] == (epsilon.chars().nth(i).unwrap() == '1')); + matching_epsilon.retain(|&n| bit_at(n, i) == epsilon); } } - let gamma: String = matching_gamma[0].iter().map(|&b| if b { '1' } else { '0' }).collect(); - let epsilon: String = matching_epsilon[0].iter().map(|&b| if b { '1' } else { '0' }).collect(); - let gamma = usize::from_str_radix(&gamma, 2).unwrap(); - let epsilon = usize::from_str_radix(&epsilon, 2).unwrap(); - gamma * epsilon + debug_assert_eq!(matching_gamma.len(), 1); + debug_assert_eq!(matching_epsilon.len(), 1); + matching_gamma[0] * matching_epsilon[0] } fn main() { - let raw = read_file(DAY); - let input = parse_input(&raw); - println!("Part 1: {}", part1(&input)); - println!("Part 2: {}", part2(&input)); + let input = parse_input(&read_file(DAY)); + println!("Part 1: {}", part1(&input, 12)); + println!("Part 2: {}", part2(&input, 12)); } #[cfg(test)] @@ -75,9 +73,37 @@ mod tests { 00010 01010"; - test!(part1() == 198); - test!(part2() == 230); - bench!(part1() == 3549854); - bench!(part2() == 3765399); + #[test] + fn most_common_test() { + let parsed = parse_input(TEST_INPUT); + assert_eq!(most_common(&parsed, 5), 0b10110) + } + + #[test] + fn invert_test() { + let gamma = 0b10110; + assert_eq!(invert(gamma), 0b01001); + } + + #[test] + fn most_common_at_test() { + let parsed = parse_input(TEST_INPUT); + assert_eq!(most_common_at(&parsed, 4), true); + } + + #[test] + fn bit_at_test() { + assert_eq!(bit_at(0b111, 0), true); + assert_eq!(bit_at(0b111, 1), true); + assert_eq!(bit_at(0b111, 2), true); + assert_eq!(bit_at(0b111, 3), false); + assert_eq!(bit_at(0b101, 1), false); + assert_eq!(bit_at(0b11101, 3), true); + } + + test!(part1(5) == 198); + test!(part2(5) == 230); + bench!(part1(12) == 3549854); + bench!(part2(12) == 3765399); bench_input!(len == 1000); }