Official

G - Ai + Bj + Ck = X (1 <= i, j, k <= N) Editorial by physics0523


\(N \le 10^6\) なので、 \(i\) を固定する全探索ができます。 このとき、 問題は \(Bj+Ck=Y\) ( \(Y=X-Ai\) ) に帰着されます。

この不定方程式はユークリッドの拡張互除法(以降 extgcd )を用いて解くことができます。

具体的には、 extgcd で \(Bj+Ck=\gcd(B,C)\) を満たす解がひとつ出て、これを定数倍することで右辺が \(Y\%B\) なる解をひとつ得ることが出来ます。( 逆に、これが得られないなら固定した \(i\) に対して解が無いことも分かります。 )
さらに、不足分の \(j\) を調節することで元の不定方程式の解をひとつ得ます。
ここで、この時の各変数の値を long long に収めながら処理できることに注意してください。(詳しくはコード中のリンクも参照してください)

一般解は \(j'=j+tC/\gcd(B,C), k'=k-tB/\gcd(B,C)\) ( \(t\) は整数 ) という形をしているので、これを元に各変数が \(1\) 以上 \(N\) 以下となる \(t\) の数を数え上げれば元の不定方程式の解の数も数え上げることが出来ます。

実装例 (C++):

#include<bits/stdc++.h>
 
using namespace std;
 
// https://math.stackexchange.com/questions/670405/does-the-extended-euclidean-algorithm-always-return-the-smallest-coefficients-of
// https://teratail.com/questions/176282
// https://ei1333.github.io/luzhiled/snippets/math/extgcd.html
template< typename T >
T extgcd(T a, T b, T &x, T &y) {
  T d = a;
  if(b != 0) {
    d = extgcd(b, a % b, y, x);
    y -= (a / b) * x;
  } else {
    x = 1;
    y = 0;
  }
  return d;
}
 
long long llceil(long long a,long long b){
  if(a%b==0){return a/b;}
 
  if(a>=0){return (a/b)+1;}
  else{return -((-a)/b);}
}
 
long long llfloor(long long a,long long b){
  if(a%b==0){return a/b;}
 
  if(a>=0){return (a/b);}
  else{return -((-a)/b)-1;}
}
 
using pl=pair<long long,long long>;
pl findseg(pl seg,long long ini,long long step){
  if(step>0){
    return {llceil(seg.first-ini,step), llfloor(seg.second-ini,step)};
  }
  else{
    step*=-1;
    return {llceil(ini-seg.second,step), llfloor(ini-seg.first,step)};
  }
}
 
int main(){
  long long n,a,b,c,x;
  cin >> n >> a >> b >> c >> x;
 
  // 0 <= i,j,k < n
  x -= (a+b+c);
  if(x<0){cout << "0\n";return 0;}
 
  long long res=0;
  long long gbc=gcd(b,c);
 
  long long stepj=(c/gbc);
  long long stepk=-(b/gbc);
 
  for(long long i=0;i<n;i++){
    long long rx = x-a*i;
    if(rx<0){break;}
 
    // b*j + c*k = rx
    if(rx%gbc){continue;}
    long long j,k;
    extgcd(b,c,j,k);
    j*=((rx%b)/gbc);
    k*=((rx%b)/gbc);
    long long jadd=(rx/b);
    j+=jadd;
 
    pl sj=findseg({0,n-1},j,stepj);
    pl sk=findseg({0,n-1},k,stepk);
    long long fl=max(sj.first,sk.first);
    long long fr=min(sj.second,sk.second);
    res+=max(0ll,fr-fl+1);
  }
  cout << res << "\n";
  return 0;
}

posted:
last update: