This weekend I followed the wiki to implement basic big-integer multiplication using the Toom-3 algorithm. Unexpectedly, it is slower than long (grade-school) multiplication right from the start, and it never catches up as the inputs grow. I would like it to beat grade-school multiplication for numbers of up to 500 digits. How should I go about this?
I tried to optimize it: I reserved the vectors' capacity and removed redundant code, but it did not help much.
Also, should I use `vector<long long>` to store my base digits?
The full source code is on GitHub:
// Type of one limb. Multiplication forms the product of two limbs plus carries
// (at most digit_base^2 - 1), so this type must be able to hold that value.
using BigIntBase = long long;
// A magnitude stored as little-endian limbs: element 0 is the least
// significant limb.
using BigIntDigits = std::vector<BigIntBase>;
// Decimal digits per limb. 9 is the largest power-of-ten exponent for which
// digit_base^2 = 10^18 still fits in a (at least 64-bit) long long, whose
// maximum is about 9.2 * 10^18. 10^10 squared would already overflow.
static constexpr int digit_base_len = 9;
// Limb radix b = 10^digit_base_len.
static constexpr BigIntBase digit_base = 1000000000;
// Signed overflow inside a constant expression is ill-formed. If someone
// raises digit_base past the safe range, this fails to compile instead of
// silently overflowing at run time.
static_assert(digit_base * digit_base > 0, "digit_base squared must fit in BigIntBase");
class BigInt {
public:
BigInt(int digits_capacity = 0, bool nega = false) {
negative = nega;
digits.reserve(digits_capacity);
}
BigInt(BigIntDigits _digits, bool nega = false) {
negative = nega;
digits = _digits;
}
BigInt(const span<const BigIntBase> &range, bool nega = false) {
negative = nega;
digits = BigIntDigits(range.begin(), range.end());
}
BigInt operator+(const BigInt &rhs) {
if ((*this).negative == rhs.negative)
return BigInt(plus((*this).digits, rhs.digits), (*this).negative);
if (greater((*this).digits, rhs.digits))
return BigInt(minus((*this).digits, rhs.digits), (*this).negative);
return BigInt(minus(rhs.digits, (*this).digits), rhs.negative);
}
BigInt operator-(const BigInt &rhs) { return *this + BigInt(rhs.digits, !rhs.negative); }
BigInt operator*(const BigInt &rhs) {
if ((*this).digits.empty() || rhs.digits.empty()) {
return BigInt();
} else if ((*this).digits.size() == 1 && rhs.digits.size() == 1) {
BigIntBase val = (*this).digits[0] * rhs.digits[0];
return BigInt(val < digit_base ? BigIntDigits{val} : BigIntDigits{val % digit_base, val / digit_base}, (*this).negative ^ rhs.negative);
} else if ((*this).digits.size() == 1)
return BigInt(multiply(rhs, (*this).digits[0]).digits, (*this).negative ^ rhs.negative);
else if (rhs.digits.size() == 1)
return BigInt(multiply((*this), rhs.digits[0]).digits, (*this).negative ^ rhs.negative);
return BigInt(toom3(span((*this).digits), span(rhs.digits)), (*this).negative ^ rhs.negative);
}
string to_string() {
if (this->digits.empty())
return "0";
stringstream ss;
if (this->negative)
ss << "-";
ss << std::to_string(this->digits.back());
for (auto it = this->digits.rbegin() + 1; it != this->digits.rend(); ++it)
ss << setw(digit_base_len) << setfill('0') << std::to_string(*it);
return ss.str();
}
BigInt from_string(string s) {
digits.clear();
negative = s[0] == '-';
for (int pos = max(0, (int)s.size() - digit_base_len); pos >= 0; pos -= digit_base_len)
digits.push_back(stoll(s.substr(pos, digit_base_len)));
if (s.size() % digit_base_len)
digits.push_back(stoll(s.substr(0, s.size() % digit_base_len)));
return *this;
}
private:
bool negative;
BigIntDigits digits;
const span<const BigIntBase> toom3_slice_num(const span<const BigIntBase> &num, const int &n, const int &i) {
int begin = n * i;
if (begin < num.size()) {
const span<const BigIntBase> result = num.subspan(begin, min((int)num.size() - begin, i));
return result;
}
return span<const BigIntBase>();
}
BigIntDigits toom3(const span<const BigIntBase> &num1, const span<const BigIntBase> &num2) {
int i = ceil(max(num1.size() / 3.0, num2.size() / 3.0));
const span<const BigIntBase> m0 = toom3_slice_num(num1, 0, i);
const span<const BigIntBase> m1 = toom3_slice_num(num1, 1, i);
const span<const BigIntBase> m2 = toom3_slice_num(num1, 2, i);
const span<const BigIntBase> n0 = toom3_slice_num(num2, 0, i);
const span<const BigIntBase> n1 = toom3_slice_num(num2, 1, i);
const span<const BigIntBase> n2 = toom3_slice_num(num2, 2, i);
BigInt pt0 = plus(m0, m2);
BigInt pp0 = m0;
BigInt pp1 = plus(pt0.digits, m1);
BigInt pn1 = pt0 - m1;
BigInt pn2 = multiply(pn1 + m2, 2) - m0;
BigInt pin = m2;
BigInt qt0 = plus(n0, n2);
BigInt qp0 = n0;
BigInt qp1 = plus(qt0.digits, n1);
BigInt qn1 = qt0 - n1;
BigInt qn2 = multiply(qn1 + n2, 2) - n0;
BigInt qin = n2;
BigInt rp0 = pp0 * qp0;
BigInt rp1 = pp1 * qp1;
BigInt rn1 = pn1 * qn1;
BigInt rn2 = pn2 * qn2;
BigInt rin = pin * qin;
BigInt r0 = rp0;
BigInt r4 = rin;
BigInt r3 = divide(rn2 - rp1, 3);
BigInt r1 = divide(rp1 - rn1, 2);
BigInt r2 = rn1 - rp0;
r3 = divide(r2 - r3, 2) + multiply(rin, 2);
r2 = r2 + r1 - r4;
r1 = r1 - r3;
BigIntDigits result = r0.digits;
if (!r1.digits.empty()) {
shift_left(r1.digits, i);
result = plus(result, r1.digits);
}
if (!r2.digits.empty()) {
shift_left(r2.digits, i << 1);
result = plus(result, r2.digits);
}
if (!r3.digits.empty()) {
shift_left(r3.digits, i * 3);
result = plus(result, r3.digits);
}
if (!r4.digits.empty()) {
shift_left(r4.digits, i << 2);
result = plus(result, r4.digits);
}
return result;
}
BigIntDigits plus(const span<const BigIntBase> &lhs, const span<const BigIntBase> &rhs) {
if (lhs.empty())
return BigIntDigits(rhs.begin(), rhs.end());
if (rhs.empty())
return BigIntDigits(lhs.begin(), lhs.end());
int max_length = max(lhs.size(), rhs.size());
BigIntDigits result;
result.reserve(max_length + 1);
for (int w = 0; w < max_length; ++w)
result.push_back((lhs.size() > w ? lhs[w] : 0) + (rhs.size() > w ? rhs[w] : 0));
for (int w = 0; w < result.size() - 1; ++w) {
result[w + 1] += result[w] / digit_base;
result[w] %= digit_base;
}
if (result.back() >= digit_base) {
result.push_back(result.back() / digit_base);
result[result.size() - 2] %= digit_base;
}
return result;
}
BigIntDigits minus(const span<const BigIntBase> &lhs, const span<const BigIntBase> &rhs) {
if (lhs.empty())
return BigIntDigits(rhs.begin(), rhs.end());
if (rhs.empty())
return BigIntDigits(lhs.begin(), lhs.end());
BigIntDigits result;
result.reserve(lhs.size() + 1);
for (int w = 0; w < lhs.size(); ++w)
result.push_back((lhs.size() > w ? lhs[w] : 0) - (rhs.size() > w ? rhs[w] : 0));
for (int w = 0; w < result.size() - 1; ++w)
if (result[w] < 0) {
result[w + 1] -= 1;
result[w] += digit_base;
}
while (!result.empty() && !result.back())
result.pop_back();
return result;
}
void shift_left(BigIntDigits &lhs, const int n) {
if (!lhs.empty()) {
BigIntDigits zeros(n, 0);
lhs.insert(lhs.begin(), zeros.begin(), zeros.end());
}
}
BigInt divide(const BigInt &lhs, const int divisor) {
BigIntDigits reminder(lhs.digits);
BigInt result(lhs.digits.capacity(), lhs.negative);
for (int w = reminder.size() - 1; w >= 0; --w) {
result.digits.insert(result.digits.begin(), reminder[w] / divisor);
reminder[w - 1] += (reminder[w] % divisor) * digit_base;
}
while (!result.digits.empty() && !result.digits.back())
result.digits.pop_back();
return result;
}
BigInt multiply(const BigInt &lhs, const int multiplier) {
BigInt result(lhs.digits, lhs.negative);
for (int w = 0; w < result.digits.size(); ++w)
result.digits[w] *= multiplier;
for (int w = 0; w < result.digits.size(); ++w)
if (result.digits[w] >= digit_base) {
if (w + 1 == result.digits.size())
result.digits.push_back(result.digits[w] / digit_base);
else
result.digits[w + 1] += result.digits[w] / digit_base;
result.digits[w] %= digit_base;
}
return result;
}
bool greater(const BigIntDigits &lhs, const BigIntDigits &rhs) {
if (lhs.size() == rhs.size()) {
int w = lhs.size() - 1;
while (w >= 0 && lhs[w] == rhs[w])
--w;
return w >= 0 && lhs[w] > rhs[w];
} else
return lhs.size() > rhs.size();
}
};