提出 #43651390


ソースコード 拡げる

#![allow(non_snake_case)]


/// Reads the operations, runs the beam search and prints one decision per operation
/// ("A" for +1, "B" otherwise).
fn main() {
    let input = Input::input();

    // The root is the all-zero state; its evaluation is the optimistic total
    // taken from the precomputed table at turn 0 with |X[j]| == 0 for every j.
    let initial_score: i16 = (0..N).map(|j| input.score[0][j][0]).sum();
    let root = Node {
        track_id: !0,
        state: [0; N],
        hash: 0,
        refs: 0,
        score: initial_score,
    };

    let ops = BeamSearch::new(root).solve(&input);

    // Build the whole answer first and write it once.
    let mut out = String::with_capacity(ops.len() * 2);
    for op in ops {
        out.push_str(if op == 1 { "A\n" } else { "B\n" });
    }
    print!("{}", out);
}


/// Number of elements in the state array; operation indices lie in `0..N`.
const N: usize = 20;
use proconio::{*,marker::*};


/// Problem input together with the lookup tables used by the beam-search evaluation.
struct Input{
    /// `cmds[t]`: the three 0-based element indices affected by operation `t`.
    cmds:[[usize;3];TURN],
    /// `score[t][j][k]`: optimistic number of operations among `t..TURN` after which element
    /// `j` is zero, given `|X[j]| == k` before operation `t`. It assumes every operation that
    /// touches `j` moves it toward zero, or away from zero only when it is already zero.
    ///
    /// The `k` axis has `TURN + 1` entries because `|X[j]|` can reach `TURN` after the last
    /// operation (the evaluator reads `score[t + 1][j][|new|]`).
    score:[[[i16;TURN+1];N];TURN+1],
    /// Random per-element keys used to hash states incrementally.
    zob:[u64;N],
}
impl Input{
    /// Reads the input from stdin and precomputes `score` and `zob`.
    ///
    /// # Panics
    /// Panics if the number of operations differs from the compile-time constant `TURN`,
    /// because all tables are fixed-size arrays.
    fn input()->Input{
        input!{
            turn:usize,
            i_cmds:[[Usize1;3];turn],
        }
        // Reading exactly `turn` rows means a size mismatch reports this message instead
        // of an end-of-input panic inside `input!`.
        assert_eq!(turn,TURN,"this solver only supports T = {} operations",TURN);
        let mut cmds=[[!0;3];TURN];
        for (dst,src) in cmds.iter_mut().zip(&i_cmds){
            dst.copy_from_slice(src);
        }

        // Backward DP over operations; row TURN stays all zero (no future operations).
        let mut score=[[[0i16;TURN+1];N];TURN+1];
        for i in (0..TURN).rev(){
            for j in 0..N{
                if cmds[i].contains(&j){
                    // Touched: 0 is forced to 1; otherwise step toward zero, scoring
                    // when the element lands on zero.
                    score[i][j][0]=score[i+1][j][1];
                    for k in 1..=TURN{
                        score[i][j][k]=score[i+1][j][k-1]+(k==1) as i16;
                    }
                }else{
                    // Untouched: the value is unchanged and counts if it is zero.
                    for k in 0..=TURN{
                        score[i][j][k]=score[i+1][j][k]+(k==0) as i16;
                    }
                }
            }
        }

        // Fixed seed so runs are reproducible.
        let mut zob=[0;N];
        use rand::prelude::*;
        let mut rng=rand_pcg::Pcg64Mcg::new(0);
        for z in zob.iter_mut(){
            *z=rng.gen();
        }

        Input{cmds,score,zob}
    }
}


/// Index type for node parents and track entries; presumably `u32` rather than `usize`
/// to keep `Cand` and the track entries compact.
#[allow(non_camel_case_types)]
type uint = u32;


/// One state in the beam.
#[derive(Clone,Default)]
struct Node{
    /// Index of this node's last move in `BeamSearch::track` (`!0` for the root).
    track_id:uint,
    /// Evaluation: zeros counted so far plus the optimistic future count.
    score:i16,
    /// Incremental hash of `state`.
    hash:u64,
    /// Current value of every element.
    state:[i8;N],
    /// Number of selected candidates that extend this node during an update.
    refs:u8,
}
impl Node{
    /// Returns a copy of this node with `cand` applied for operation `turn`.
    fn new_node(&self,input:&Input,turn:usize,cand:&Cand)->Node{
        let mut child=self.clone();
        child.apply(input,turn,cand);
        child
    }

    /// Applies `cand` in place: adds `cand.op` to the elements touched by operation `turn`
    /// and adopts the candidate's precomputed hash and score.
    fn apply(&mut self,input:&Input,turn:usize,cand:&Cand){
        let delta=cand.op;
        input.cmds[turn].iter().for_each(|&idx| self.state[idx]+=delta);
        self.hash=cand.hash;
        self.score=cand.eval_score;
    }
}


/// A candidate move: apply `op` to the state of the node at pool index `parent`.
#[derive(Clone)]
struct Cand{
    /// +1 (printed as "A") adds to the three selected elements, -1 ("B") subtracts.
    op:i8,
    /// Pool index of the node this move extends.
    parent:uint,
    /// Evaluation of the resulting state: zeros so far plus optimistic future zeros.
    eval_score:i16,
    /// Hash of the resulting state, used to drop duplicate states.
    hash:u64,
}
impl Cand {
    /// Score used to pick the final answer.
    ///
    /// With the current evaluation this is `eval_score`. The input parameter is kept (but
    /// unused, hence `_input`) so an exact recomputation could be plugged in later.
    fn raw_score(&self,_input:&Input)->i16{
        self.eval_score
    }
}


/// Beam width: maximum number of states kept after each operation.
const MAX_WIDTH: usize = 100_000;
/// Number of operations `T`; `Input::input` asserts the input has exactly this many.
const TURN: usize = 100;


/// Beam search over the A/B decisions with in-place node reuse.
///
/// `nodes` is a fixed pool. `free[..at]` lists the pool indices in the current beam and
/// `free[at..]` the indices that may be overwritten. Moves are not stored per node. Each
/// expansion appends `(previous track id, op)` to `track`, and `restore` follows these
/// links back to the root, which is marked by the sentinel `!0`.
struct BeamSearch{
    /// `(parent track id, op)` for every expansion made so far.
    track:Vec<(uint,i8)>,
    /// Node pool; only the entries listed in `free[..at]` are live.
    nodes:Vec<Node>,
    /// Permutation of `0..nodes.len()`: live indices first, then reusable ones.
    free:Vec<usize>,
    /// Number of live nodes (length of the live prefix of `free`).
    at:usize,
    /// Scratch buffer holding the candidates selected for the current update.
    cands:Vec<Cand>,
}
impl BeamSearch{
    /// Creates a search whose beam contains only `node`, stored at pool index 0.
    ///
    /// # Panics
    /// Panics if the worst-case number of track entries could reach the `!0` root sentinel.
    fn new(node:Node)->BeamSearch{
        const MAX_NODES:usize=MAX_WIDTH*TURN;
        let sentinel:uint=!0;
        assert!(MAX_NODES<sentinel as usize);
        // An update leaves at most MAX_WIDTH live nodes; the pool has room for twice that.
        let mut nodes=vec![Node::default();MAX_WIDTH*2];
        nodes[0]=node;

        BeamSearch{
            free:(0..nodes.len()).collect(),
            nodes,
            at:1,
            track:Vec::with_capacity(MAX_NODES),
            cands:Vec::with_capacity(MAX_WIDTH),
        }
    }

    /// Appends the successors of every live node for operation `turn` to `cands`.
    fn enum_cands(&self,input:&Input,turn:usize,cands:&mut Vec<Cand>){
        for &i in &self.free[..self.at]{
            self.append_cands(input,turn,i,cands);
        }
    }

    /// Replaces the beam with the states produced by `cands` (moves for operation `turn`).
    ///
    /// Live nodes not referenced by any candidate are released. A parent referenced by
    /// several candidates is cloned for all but its last candidate, which updates the parent
    /// in place. Each candidate therefore becomes exactly one live node.
    fn update<I:Iterator<Item=Cand>>(&mut self,input:&Input,turn:usize,cands:I){
        self.cands.clear();
        for cand in cands{
            self.nodes[cand.parent as usize].refs+=1;
            self.cands.push(cand);
        }

        // Release unreferenced nodes by swapping them behind the live prefix.
        for i in (0..self.at).rev(){
            if self.nodes[self.free[i]].refs==0{
                self.at-=1;
                self.free.swap(i,self.at);
            }
        }

        for cand in &self.cands{
            let node=&mut self.nodes[cand.parent as usize];
            node.refs-=1;
            let prev=node.track_id;

            let new=if node.refs==0{
                // Last child of this parent: reuse the parent's slot.
                node.apply(input,turn,cand);
                node
            }
            else{
                let mut new=node.new_node(input,turn,cand);
                new.refs=0;
                let idx=self.free[self.at];
                self.at+=1;
                self.nodes[idx]=new;
                &mut self.nodes[idx]
            };

            self.track.push((prev,cand.op));
            new.track_id=self.track.len() as uint-1;
        }
    }

    /// Returns the ops from the root up to the node at pool index `idx`, in order.
    fn restore(&self,mut idx:uint)->Vec<i8>{
        idx=self.nodes[idx as usize].track_id;
        let mut ret=vec![];
        while idx!=!0{
            ret.push(self.track[idx as usize].1);
            idx=self.track[idx as usize].0;
        }
        ret.reverse();
        ret
    }

    /// Appends the successors of node `idx` for operation `turn`.
    ///
    /// Op +1 is always generated. Op -1 is skipped on the first turn, presumably because the
    /// root state is symmetric there (the evaluation only looks at `|X[j]|`).
    fn append_cands(&self,input:&Input,turn:usize,idx:usize,cands:&mut Vec<Cand>){
        let node=&self.nodes[idx];

        // Incremental re-evaluation of the three touched elements: swap their optimistic
        // value at `turn` for the one at `turn + 1`, and count a zero created right now.
        let next=|add:i8|->(i16,u64){
            let mut score=node.score;
            let mut hash=node.hash;
            for &n in &input.cmds[turn]{
                score-=input.score[turn][n][node.state[n].abs() as usize];
                let new=node.state[n]+add;
                // The hash changes by `add * zob[n]` modulo 2^64. `add as u64` is 2^64 - 1
                // for add == -1, so wrapping arithmetic must be explicit: plain `*`/`+=`
                // panic on overflow in debug builds.
                hash=hash.wrapping_add((add as u64).wrapping_mul(input.zob[n]));
                score+=input.score[turn+1][n][new.abs() as usize];
                score+=(new==0) as i16;
            }
            (score,hash)
        };

        let (eval_score,hash)=next(1);
        let cand=Cand{
            op:1,
            parent:idx as uint,
            eval_score,hash,
        };
        cands.push(cand);

        if turn!=0{
            let (eval_score,hash)=next(-1);
            let cand=Cand{
                op:-1,
                parent:idx as uint,
                eval_score,hash,
            };
            cands.push(cand);
        }
    }

    /// Runs the search over all `TURN` operations and returns the chosen ops (+1/-1) in order.
    ///
    /// # Panics
    /// Panics if a turn yields no candidate. This cannot happen while the beam is non-empty,
    /// because every node produces at least the +1 move.
    fn solve(&mut self,input:&Input)->Vec<i8>{
        use std::cmp::Reverse;
        let width=MAX_WIDTH;

        let mut cands=Vec::<Cand>::with_capacity(MAX_WIDTH);
        let mut set=NopHashSet::default();
        for t in 0..TURN{
            if t!=0{
                // Keep the `width` best candidates: partial selection first, then sort
                // so that the best-scored copy of a duplicate state is the one kept.
                if cands.len()>width{
                    cands.select_nth_unstable_by_key(width,|a|Reverse(a.eval_score));
                    cands.truncate(width);
                }
                cands.sort_unstable_by_key(|a|Reverse(a.eval_score));
                set.clear();
                self.update(input,t-1,cands.iter().filter(|cand|
                    set.insert(cand.hash)
                ).take(width).cloned());
            }

            cands.clear();
            self.enum_cands(input,t,&mut cands);
            // Message: "no legal next move exists".
            assert!(!cands.is_empty(),"次の合法手が存在しないよ");
        }

        let best=cands.iter().max_by_key(|a|a.raw_score(input)).unwrap();
        eprintln!("score = {}",best.raw_score(input));

        let mut ret=self.restore(best.parent);
        ret.push(best.op);

        ret
    }
}


#[allow(unused)]
mod nop_hash{
    //! Identity hashing for keys that are already well-mixed `u64`s (e.g. random state hashes).
    use std::collections::{HashMap,HashSet};
    use std::hash::{BuildHasherDefault,Hasher};

    /// A `Hasher` whose output is the last `u64` written to it, unchanged.
    #[derive(Default)]
    pub struct NopHasher{
        last:u64,
    }

    impl Hasher for NopHasher{
        // Arbitrary byte input is unsupported; only `u64` keys are expected.
        fn write(&mut self,_:&[u8]){
            panic!();
        }

        #[inline]
        fn finish(&self)->u64{
            self.last
        }

        #[inline]
        fn write_u64(&mut self,n:u64){
            self.last=n;
        }
    }

    /// `HashMap` that uses the key's `u64` value directly as its hash.
    pub type NopHashMap<K,V>=HashMap<K,V,BuildHasherDefault<NopHasher>>;
    /// `HashSet` that uses the value's `u64` directly as its hash.
    pub type NopHashSet<V>=HashSet<V,BuildHasherDefault<NopHasher>>;
}
#[allow(unused)]
use nop_hash::*;



use std::cmp::{self, Ordering, Ordering::*};
use std::mem::{self, MaybeUninit};
use std::ptr;

/// Drop guard that copies one `T` from `src` to `dest` when it goes out of scope.
///
/// The sorting helpers in this file use it to write a temporarily extracted element back
/// into the slice. The slice then ends up fully initialized even if the comparator panics.
struct CopyOnDrop<T> {
    src: *const T,
    dest: *mut T,
}

impl<T> Drop for CopyOnDrop<T> {
    fn drop(&mut self) {
        // SAFETY: whoever creates the guard must keep `src` and `dest` valid, aligned and
        // non-overlapping for the guard's whole lifetime.
        unsafe {
            ptr::copy_nonoverlapping(self.src, self.dest, 1);
        }
    }
}

/// Moves the last element of `v` left to its sorted position, assuming `v[..len - 1]`
/// is already sorted by `is_less` (one step of insertion sort).
fn shift_tail<T, F>(v: &mut [T], is_less: &mut F)
where
    F: FnMut(&T, &T) -> bool,
{
    let len = v.len();
    // SAFETY: every index used is below `len` (guarded by `len >= 2` and the loop range).
    // The last element is moved out into `tmp`. `hole` writes it back into the current gap
    // when dropped, including on panic, so no element is duplicated or lost.
    unsafe {
        if len >= 2 && is_less(v.get_unchecked(len - 1), v.get_unchecked(len - 2)) {
            let tmp = mem::ManuallyDrop::new(ptr::read(v.get_unchecked(len - 1)));
            let v = v.as_mut_ptr();
            let mut hole = CopyOnDrop { src: &*tmp, dest: v.add(len - 2) };
            ptr::copy_nonoverlapping(v.add(len - 2), v.add(len - 1), 1);

            // Shift elements greater than `tmp` one slot right; the gap follows them.
            for i in (0..len - 2).rev() {
                if !is_less(&*tmp, &*v.add(i)) {
                    break;
                }

                ptr::copy_nonoverlapping(v.add(i), v.add(i + 1), 1);
                hole.dest = v.add(i);
            }
            // `hole` is dropped here, writing `tmp` into the gap.
        }
    }
}

/// Sorts `v` in place with insertion sort (used here for short slices).
fn insertion_sort<T, F>(v: &mut [T], is_less: &mut F)
where
    F: FnMut(&T, &T) -> bool,
{
    // Grow the sorted prefix: `v[..end - 1]` is sorted when `shift_tail` places `v[end - 1]`.
    let mut end = 2;
    while end <= v.len() {
        shift_tail(&mut v[..end], is_less);
        end += 1;
    }
}

/// Partitions `v` into elements less than `pivot` followed by elements not less than it,
/// and returns how many elements are less than `pivot`.
///
/// Works block-wise (mirroring `partition_in_blocks` from the Rust standard library):
/// - Up to `BLOCK` elements are scanned from each end.
/// - The offsets of out-of-place elements are recorded in `u8` buffers.
/// - Matching left/right pairs are then exchanged in one cyclic pass.
fn partition_in_blocks<T, F>(v: &mut [T], pivot: &T, is_less: &mut F) -> usize
where
    F: FnMut(&T, &T) -> bool,
{
    // Offsets are stored as `u8`, so a block must not exceed 256 elements.
    const BLOCK: usize = 128;

    // Left cursor, current left block size, and pending offsets `start_l..end_l`
    // of left-block elements that belong on the right.
    let mut l = v.as_mut_ptr();
    let mut block_l = BLOCK;
    let mut start_l = ptr::null_mut();
    let mut end_l = ptr::null_mut();
    let mut offsets_l = [MaybeUninit::<u8>::uninit(); BLOCK];

    // Right cursor (one past the end), current right block size, and pending offsets
    // `start_r..end_r` (counted backwards from `r`) of elements that belong on the left.
    let mut r = unsafe { l.add(v.len()) };
    let mut block_r = BLOCK;
    let mut start_r = ptr::null_mut();
    let mut end_r = ptr::null_mut();
    let mut offsets_r = [MaybeUninit::<u8>::uninit(); BLOCK];

    // Number of `T` elements between `l` and `r` (requires `l <= r`, non-zero-sized `T`).
    fn width<T>(l: *mut T, r: *mut T) -> usize {
        assert!(mem::size_of::<T>() > 0);
        (r as usize - l as usize) / mem::size_of::<T>()
    }

    loop {
        // With at most two blocks left this is the last iteration; resize the blocks so
        // that they exactly cover the unprocessed range.
        let is_done = width(l, r) <= 2 * BLOCK;

        if is_done {
            let mut rem = width(l, r);
            // A side that still has pending offsets keeps its full block.
            if start_l < end_l || start_r < end_r {
                rem -= BLOCK;
            }

            if start_l < end_l {
                block_r = rem;
            } else if start_r < end_r {
                block_l = rem;
            } else {
                block_l = rem / 2;
                block_r = rem - block_l;
            }
        }

        // Refill left offsets with elements not less than the pivot. The write is
        // unconditional and `end_l` only advances for out-of-place elements (branch-free).
        if start_l == end_l {
            start_l = offsets_l.as_mut_ptr() as *mut _;
            end_l = start_l;
            let mut elem = l;

            for i in 0..block_l {
                unsafe {
                    *end_l = i as u8;
                    end_l = end_l.offset(!is_less(&*elem, pivot) as isize);
                    elem = elem.offset(1);
                }
            }
        }

        // Refill right offsets, scanning backwards from `r`, with elements less than the pivot.
        if start_r == end_r {
            start_r = offsets_r.as_mut_ptr() as *mut _;
            end_r = start_r;
            let mut elem = r;

            for i in 0..block_r {
                unsafe {
                    elem = elem.offset(-1);
                    *end_r = i as u8;
                    end_r = end_r.offset(is_less(&*elem, pivot) as isize);
                }
            }
        }

        // Number of out-of-place pairs that can be exchanged in this round.
        let count = cmp::min(width(start_l, end_l), width(start_r, end_r));

        if count > 0 {
            macro_rules! left {
                () => {
                    l.offset(*start_l as isize)
                };
            }
            macro_rules! right {
                () => {
                    r.offset(-(*start_r as isize) - 1)
                };
            }

            // SAFETY: recorded offsets index into the current left block `l..l+block_l`
            // and right block `r-block_r..r`, which lie inside `v`. The cyclic permutation
            // uses one temporary and leaves every slot initialized exactly once.
            unsafe {
                let tmp = ptr::read(left!());
                ptr::copy_nonoverlapping(right!(), left!(), 1);

                for _ in 1..count {
                    start_l = start_l.offset(1);
                    ptr::copy_nonoverlapping(left!(), right!(), 1);
                    start_r = start_r.offset(1);
                    ptr::copy_nonoverlapping(right!(), left!(), 1);
                }

                ptr::copy_nonoverlapping(&tmp, right!(), 1);
                mem::forget(tmp);
                start_l = start_l.offset(1);
                start_r = start_r.offset(1);
            }
        }

        // Advance past any block whose out-of-place elements are all handled.
        if start_l == end_l {
            l = unsafe { l.offset(block_l as isize) };
        }

        if start_r == end_r {
            r = unsafe { r.offset(-(block_r as isize)) };
        }

        if is_done {
            break;
        }
    }

    // At most one side still has pending out-of-place elements; swap them to the
    // boundary and return the partition point.
    if start_l < end_l {
        while start_l < end_l {
            unsafe {
                end_l = end_l.offset(-1);
                ptr::swap(l.offset(*end_l as isize), r.offset(-1));
                r = r.offset(-1);
            }
        }
        width(v.as_mut_ptr(), r)
    } else if start_r < end_r {
        while start_r < end_r {
            unsafe {
                end_r = end_r.offset(-1);
                ptr::swap(l, r.offset(-(*end_r as isize) - 1));
                l = l.offset(1);
            }
        }
        width(v.as_mut_ptr(), l)
    } else {
        width(v.as_mut_ptr(), l)
    }
}

/// Partitions `v` around the element at index `pivot`.
///
/// Afterwards the pivot sits at the returned index `mid`. Elements before it are less than
/// the pivot and elements after it are not less. The returned `bool` is `true` when the
/// prefix/suffix scan found nothing left to move (`l >= r`).
fn partition<T, F>(v: &mut [T], pivot: usize, is_less: &mut F) -> (usize, bool)
where
    F: FnMut(&T, &T) -> bool,
{
    let (mid, was_partitioned) = {
        // Move the pivot to the front and split it off from the rest.
        v.swap(0, pivot);
        let (pivot, v) = v.split_at_mut(1);
        let pivot = &mut pivot[0];

        // Compare against a bitwise copy of the pivot. The guard writes it back to `v[0]`
        // when this scope ends, including during unwinding from a panicking comparator.
        let tmp = mem::ManuallyDrop::new(unsafe { ptr::read(pivot) });
        let _pivot_guard = CopyOnDrop { src: &*tmp, dest: pivot };
        let pivot = &*tmp;

        // Skip the prefix already less than the pivot and the suffix already not less.
        let mut l = 0;
        let mut r = v.len();

        // SAFETY: `l < r <= v.len()` is checked before each unchecked access.
        unsafe {
            while l < r && is_less(v.get_unchecked(l), pivot) {
                l += 1;
            }

            while l < r && !is_less(v.get_unchecked(r - 1), pivot) {
                r -= 1;
            }
        }

        (l + partition_in_blocks(&mut v[l..r], pivot, is_less), l >= r)
    };

    // Place the pivot between the two partitions.
    v.swap(0, mid);

    (mid, was_partitioned)
}

/// Partitions `v` into elements not greater than `v[pivot]` followed by elements greater
/// than it. Returns how many elements are not greater, the pivot itself included.
///
/// The pivot stays at index 0. When no element of `v` is less than the pivot (the situation
/// in which `partition_at_index_loop` calls this), the left part is exactly the elements
/// equal to the pivot.
fn partition_equal<T, F>(v: &mut [T], pivot: usize, is_less: &mut F) -> usize
where
    F: FnMut(&T, &T) -> bool,
{
    v.swap(0, pivot);
    let (pivot, v) = v.split_at_mut(1);
    let pivot = &mut pivot[0];

    // Bitwise copy of the pivot; the guard restores it into `v[0]` on exit or unwind.
    let tmp = mem::ManuallyDrop::new(unsafe { ptr::read(pivot) });
    let _pivot_guard = CopyOnDrop { src: &*tmp, dest: pivot };
    let pivot = &*tmp;

    // Hoare-style scan from both ends, swapping misplaced pairs.
    let mut l = 0;
    let mut r = v.len();
    loop {
        // SAFETY: accesses are guarded by `l < r <= v.len()`; the swap indices satisfy
        // `l < r` after `r -= 1`.
        unsafe {
            while l < r && !is_less(pivot, v.get_unchecked(l)) {
                l += 1;
            }

            while l < r && is_less(pivot, v.get_unchecked(r - 1)) {
                r -= 1;
            }

            if l >= r {
                break;
            }

            r -= 1;
            let ptr = v.as_mut_ptr();
            ptr::swap(ptr.add(l), ptr.add(r));
            l += 1;
        }
    }

    // `+ 1` accounts for the pivot at index 0.
    l + 1
}

/// Chooses a pivot index for `v` and reports whether `v` looks already sorted.
///
/// - `len >= 8`: the pivot is the median of three samples at 1/4, 2/4 and 3/4 of the slice.
/// - `len >= 50`: each sample is first replaced by the median of itself and its two
///   neighbours.
/// - Shorter slices use index `len / 4 * 2` without comparing.
///
/// Swaps between sample indices are counted. Zero swaps suggests ascending data, so the
/// second result is `true`. Reaching `MAX_SWAPS` suggests descending data: the slice is
/// reversed and the mirrored index is returned.
fn choose_pivot<T, F>(v: &mut [T], is_less: &mut F) -> (usize, bool)
where
    F: FnMut(&T, &T) -> bool,
{
    const SHORTEST_MEDIAN_OF_MEDIANS: usize = 50;
    // Upper bound on swaps: 4 calls to `sort3`, each doing at most 3 swaps.
    const MAX_SWAPS: usize = 4 * 3;

    let len = v.len();

    let mut a = len / 4 * 1;
    let mut b = len / 4 * 2;
    let mut c = len / 4 * 3;

    let mut swaps = 0;

    if len >= 8 {
        // Orders two sample indices so that `v[*a] <= v[*b]`; only the indices move.
        let mut sort2 = |a: &mut usize, b: &mut usize| unsafe {
            if is_less(v.get_unchecked(*b), v.get_unchecked(*a)) {
                ptr::swap(a, b);
                swaps += 1;
            }
        };

        // Orders three sample indices; afterwards `*b` is the median.
        let mut sort3 = |a: &mut usize, b: &mut usize, c: &mut usize| {
            sort2(a, b);
            sort2(b, c);
            sort2(a, b);
        };

        if len >= SHORTEST_MEDIAN_OF_MEDIANS {
            // Replaces `*a` with the index of the median of `v[*a - 1..=*a + 1]`.
            let mut sort_adjacent = |a: &mut usize| {
                let tmp = *a;
                sort3(&mut (tmp - 1), a, &mut (tmp + 1));
            };

            sort_adjacent(&mut a);
            sort_adjacent(&mut b);
            sort_adjacent(&mut c);
        }

        sort3(&mut a, &mut b, &mut c);
    }

    if swaps < MAX_SWAPS {
        (b, swaps == 0)
    } else {
        v.reverse();
        (len - 1 - b, true)
    }
}


/// Quickselect loop: narrows `v` until the element at `index` is in its sorted position.
///
/// When set, `pred` is an earlier pivot that is not greater than any element of `v` (it is
/// the pivot whose right side `v` now is). If the new pivot is not greater than `pred`,
/// `partition_equal` groups all elements equal to it in one step. Slices of at most
/// `MAX_INSERTION` elements are finished with insertion sort.
fn partition_at_index_loop<'a, T, F>(
    mut v: &'a mut [T],
    mut index: usize,
    is_less: &mut F,
    mut pred: Option<&'a T>,
) where
    F: FnMut(&T, &T) -> bool,
{
    loop {
        // Short slices: sorting outright is cheaper than more partitioning.
        const MAX_INSERTION: usize = 10;
        if v.len() <= MAX_INSERTION {
            insertion_sort(v, is_less);
            return;
        }

        let (pivot, _) = choose_pivot(v, is_less);

        // The pivot equals `pred`: group its duplicates on the left.
        if let Some(p) = pred {
            if !is_less(p, &v[pivot]) {
                let mid = partition_equal(v, pivot, is_less);

                // `index` falls inside the run of equal elements: done.
                if mid > index {
                    return;
                }

                // Otherwise continue with the elements greater than the pivot.
                v = &mut v[mid..];
                index = index - mid;
                pred = None;
                continue;
            }
        }

        let (mid, _) = partition(v, pivot, is_less);

        // Split into [left | pivot | right] and continue on the side containing `index`.
        let (left, right) = v.split_at_mut(mid);
        let (pivot, right) = right.split_at_mut(1);
        let pivot = &pivot[0];

        if mid < index {
            v = right;
            index = index - mid - 1;
            pred = Some(pivot);
        } else if mid > index {
            v = left;
        } else {
            return;
        }
    }
}

fn partition_at_index<T, F>(
    v: &mut [T],
    index: usize,
    mut is_less: F,
) -> (&mut [T], &mut T, &mut [T])
where
    F: FnMut(&T, &T) -> bool,
{
    use Greater;
    use Less;

    if index >= v.len() {
        panic!("partition_at_index index {} greater than length of slice {}", index, v.len());
    }

    if mem::size_of::<T>() == 0 {
    } else if index == v.len() - 1 {
        let (max_index, _) = v
            .iter()
            .enumerate()
            .max_by(|&(_, x), &(_, y)| if is_less(x, y) { Less } else { Greater })
            .unwrap();
        v.swap(max_index, index);
    } else if index == 0 {
        let (min_index, _) = v
            .iter()
            .enumerate()
            .min_by(|&(_, x), &(_, y)| if is_less(x, y) { Less } else { Greater })
            .unwrap();
        v.swap(min_index, index);
    } else {
        partition_at_index_loop(v, index, &mut is_less, None);
    }

    let (left, right) = v.split_at_mut(index);
    let (pivot, right) = right.split_at_mut(1);
    let pivot = &mut pivot[0];
    (left, pivot, right)
}


/// Backport of the slice `select_nth_unstable*` methods for compilers that lack them.
///
/// After a call, `self[index]` holds the element a full sort would place there. Elements
/// before it do not compare greater, and elements after it do not compare less.
pub trait Nth {
    /// Element type being reordered.
    type T;

    /// Selects using the elements' natural `Ord` order.
    fn select_nth_unstable(&mut self, index: usize)
    where
        Self::T: Ord;

    /// Selects using the comparator `compare`.
    fn select_nth_unstable_by<F>(&mut self, index: usize, compare: F)
    where
        F: FnMut(&Self::T, &Self::T) -> Ordering;

    /// Selects by comparing the keys produced by `f`.
    fn select_nth_unstable_by_key<K, F>(&mut self, index: usize, f: F)
    where
        K: Ord,
        F: FnMut(&Self::T) -> K;
}

impl<T> Nth for [T]{
    type T = T;

    #[inline]
    fn select_nth_unstable(&mut self, index: usize) where T: Ord{
        partition_at_index(self, index, T::lt);
    }

    #[inline]
    fn select_nth_unstable_by<F: FnMut(&T, &T) -> Ordering>(&mut self, index: usize, mut compare: F){
        partition_at_index(self, index, |a: &T, b: &T| compare(a, b) == Less);
    }
    
    #[inline]
    fn select_nth_unstable_by_key<K: Ord, F: FnMut(&T) -> K>(&mut self, index:usize, mut f: F){
        partition_at_index(self, index, |a: &T, b: &T| f(a).lt(&f(b)));
    }
}

提出情報

提出日時
問題 A49 - Heuristic 2
ユーザ rhoo
言語 Rust (1.42.0)
得点 48822
コード長 19295 Byte
結果 AC
実行時間 948 ms
メモリ 81596 KiB

コンパイルエラー

warning: the item `Greater` is imported redundantly
   --> src/main.rs:679:9
    |
309 | use std::cmp::{self, Ordering, Ordering::*};
    |                                ----------- the item `Greater` is already imported here
...
679 |     use Greater;
    |         ^^^^^^^
    |
    = note: `#[warn(unused_imports)]` on by default

warning: the item `Less` is imported redundantly
   --> src/main.rs:680:9
    |
309 | use std::cmp::{self, Ordering, Ordering::*};
    |                                ----------- the item `Less` is already imported here
...
680 |     use Less;
    |         ^^^^

warning: unused variable: `input`
   --> src/main.rs:118:24
    |
118 |     fn raw_score(&self,input:&Input)->i16{
    |                        ^^^^^ help: consider prefixing with an underscore: `_input`
    |
    = note: `#[warn(unused_variables)]` on by default

ジャッジ結果

セット名 All
得点 / 配点 48822 / 1000
結果
AC × 50
セット名 テストケース
All in01.txt, in02.txt, in03.txt, in04.txt, in05.txt, in06.txt, in07.txt, in08.txt, in09.txt, in10.txt, in11.txt, in12.txt, in13.txt, in14.txt, in15.txt, in16.txt, in17.txt, in18.txt, in19.txt, in20.txt, in21.txt, in22.txt, in23.txt, in24.txt, in25.txt, in26.txt, in27.txt, in28.txt, in29.txt, in30.txt, in31.txt, in32.txt, in33.txt, in34.txt, in35.txt, in36.txt, in37.txt, in38.txt, in39.txt, in40.txt, in41.txt, in42.txt, in43.txt, in44.txt, in45.txt, in46.txt, in47.txt, in48.txt, in49.txt, in50.txt
ケース名 結果 実行時間 メモリ
in01.txt AC 817 ms 80680 KiB
in02.txt AC 818 ms 81220 KiB
in03.txt AC 895 ms 81104 KiB
in04.txt AC 842 ms 81320 KiB
in05.txt AC 820 ms 80896 KiB
in06.txt AC 812 ms 80664 KiB
in07.txt AC 813 ms 81360 KiB
in08.txt AC 902 ms 81596 KiB
in09.txt AC 858 ms 80776 KiB
in10.txt AC 866 ms 81500 KiB
in11.txt AC 918 ms 80472 KiB
in12.txt AC 832 ms 81208 KiB
in13.txt AC 830 ms 81012 KiB
in14.txt AC 824 ms 81096 KiB
in15.txt AC 841 ms 81108 KiB
in16.txt AC 815 ms 80692 KiB
in17.txt AC 843 ms 80336 KiB
in18.txt AC 882 ms 80788 KiB
in19.txt AC 948 ms 81008 KiB
in20.txt AC 819 ms 80656 KiB
in21.txt AC 832 ms 80796 KiB
in22.txt AC 808 ms 80956 KiB
in23.txt AC 811 ms 81292 KiB
in24.txt AC 820 ms 80480 KiB
in25.txt AC 825 ms 81024 KiB
in26.txt AC 825 ms 81204 KiB
in27.txt AC 820 ms 80792 KiB
in28.txt AC 826 ms 80472 KiB
in29.txt AC 821 ms 80964 KiB
in30.txt AC 832 ms 80748 KiB
in31.txt AC 862 ms 81428 KiB
in32.txt AC 821 ms 80932 KiB
in33.txt AC 831 ms 81108 KiB
in34.txt AC 820 ms 81012 KiB
in35.txt AC 839 ms 81264 KiB
in36.txt AC 883 ms 81060 KiB
in37.txt AC 835 ms 81288 KiB
in38.txt AC 823 ms 81304 KiB
in39.txt AC 817 ms 81016 KiB
in40.txt AC 828 ms 80480 KiB
in41.txt AC 818 ms 81016 KiB
in42.txt AC 817 ms 80908 KiB
in43.txt AC 805 ms 80824 KiB
in44.txt AC 814 ms 81212 KiB
in45.txt AC 815 ms 81460 KiB
in46.txt AC 811 ms 80780 KiB
in47.txt AC 829 ms 80564 KiB
in48.txt AC 929 ms 81028 KiB
in49.txt AC 819 ms 81316 KiB
in50.txt AC 810 ms 81516 KiB
sample_01.txt RE 2 ms 2208 KiB