Official

C - digitnum Editorial by en_translator


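Here \(f(x)\) is the number of positive integers not greater than \(x\) that have the same number of digits as \(x\). First, let’s look at some values of \(f(x)\).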

  • \(f(1)=1,f(2)=2, \dots ,f(9)=9\)
  • \(f(10)=1,f(11)=2, \dots , f(99)=90\)
  • \(f(100)=1,f(101)=2, \dots , f(999)=900\)
  • \(\dots\)

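Did you spot the rule? For a \(d\)-digit integer \(x\), \(f(x)=x-10^{d-1}+1\). In other words, within each group of integers that share the same number of digits, \(f\) takes the values \(1,2,3,\dots\) in order.
Now, let’s actually compute \(f(1)+f(2)+\dots+f(N)\).
Based on this observation, the following approach is simple.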

  • If there are \(K\) integers between \(1\) and \(\min(10-1,N)\) (inclusive), add \(1+2+\dots+K\) to the answer.
  • If there are \(K\) integers between \(10\) and \(\min(100-1,N)\) (inclusive), add \(1+2+\dots+K\) to the answer.
  • If there are \(K\) integers between \(100\) and \(\min(1000-1,N)\) (inclusive), add \(1+2+\dots+K\) to the answer.
  • \(\dots\)

In general, we can compute as follows.

  • For each integer \(k\) from \(1\) through \(18\), let \(K\) be the number of integers between \(10^{k-1}\) and \(\min(10^k-1,N)\) (inclusive), and add \(1+2+\dots+K\) to the answer. Eighteen digit counts are enough because \(N < 10^{18}\).
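The same rule can be written as a single formula. Here \(K_k\) is notation introduced only for this restatement. It denotes the block size for \(k\)-digit integers and is clamped at \(0\) when no \(k\)-digit integer is at most \(N\):

\[
K_k = \max\bigl(0,\ \min(10^k-1,\,N) - 10^{k-1} + 1\bigr), \qquad
\sum_{x=1}^{N} f(x) = \sum_{k=1}^{18} \frac{K_k\,(K_k+1)}{2}.
\]

For example, with \(N=16\) we get \(K_1=9\), \(K_2=7\), and \(K_k=0\) for \(k \ge 3\). The sum is therefore \(45+28=73\). The answer to report is this sum modulo \(998244353\).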

Finally, how can we find \(1+2+\dots+K\)? Note that \(K\) can be as large as \(9 \times 10^{17}\), and we have to find the answer modulo \(998244353\) (denoted by \(M\) below).
Setting the modulus aside for a moment, we have \(1+2+\dots+K=\frac{K\times(K+1)}{2}\). The remainder of \(K \times (K+1)\) divided by \(M\) equals the remainder by \(M\) of (\(K \bmod M\)) \(\times\) (\((K+1) \bmod M\)). Reducing first matters: \(K \times (K+1)\) itself can reach about \(8.1 \times 10^{35}\), far beyond the range of a 64-bit integer, whereas the product of the two remainders is below \(M^2 < 10^{18}\). Then how can we perform the division by \(2\)?
In fact, in the world of modulo \(998244353\), dividing by \(2\) is equivalent to multiplying by \(499122177\). Specifically, \(499122177\) is the inverse of \(2\) modulo \(998244353\), since \(2 \times 499122177 = 998244354 \equiv 1 \pmod{998244353}\). For more information, refer to this article (in Japanese) or other articles on the modular multiplicative inverse.

Similar problem:
ARC127-A

Sample code (C++):

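#include<bits/stdc++.h>
#define mod 998244353
#define inv2 499122177 // inverse of 2 modulo 998244353

using namespace std;

// Returns (1+2+...+x) mod 998244353, i.e. x(x+1)/2 mod M.
long long triangular_number(long long x){
  x%=mod;                // reduce first so that x*(x+1) fits in 64 bits
  long long res=x;
  res*=(x+1);res%=mod;   // x*(x+1) mod M
  res*=inv2;res%=mod;    // "divide" by 2 via the modular inverse
  return res;
}

int main(){
  long long n;
  cin >> n;

  long long res=0;
  long long p10=10;           // 10^dg
  for(int dg=1;dg<=18;dg++){
    long long l=p10/10;       // smallest dg-digit integer
    long long r=min(n,p10-1); // largest dg-digit integer not exceeding n
    if(l<=r){res+=triangular_number(r-l+1);res%=mod;}
    if(dg<18)p10*=10;         // stop at 10^18 to avoid signed overflow
  }

  cout << res << '\n';
  return 0;
}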
