2013-05-31 18:15:37Morris

[ZJ][擴充歐幾里德] d374. 6. X^2 ≡ 1 (mod M)

內容 :

給你一個式子X^2 ≡ 1 (mod M) 和M ( 0<X≦M ),請找出所有符合這個式子的X。
例如:M=5,則X 可以等於1 或4;M=8 時,X 可以等於1, 3, 5, 7。
請你寫一個程式,針對每一個M,輸出能滿足這個式子的X。

輸入說明 :

每一個測試檔裡有一個整數即為M,你可以假設M 不會大於2147483647。

輸出說明 :

第一行為一個整數n,代表共有多少組解。
接下來的n 行則為所有滿足此式子的X,並由小到大輸出。

範例輸入 :

15

範例輸出 :

4
1
4
11
14

提示 :

出處 :

97全國能力縣賽 (管理:pcshic)


推公式的流程:

x*x%m = 1
=> x*x = 1 + n*m
=>(x+1)(x-1) = n*m
令 n = np * nq, m = mp * mq
=> x+1 = mp * np, x-1 = mq * nq
=> mpnp - mqnq = 2
try all mp, mq, use extended gcd

窮舉 mp, mq, 然後用參數式窮舉所有可能 np, nq 的值。

先從擴充歐幾理德中得到 a' n1 + b' n2 = gcd(n1, n2)

然後將其右邊變成 n,同乘 n/gcd(n1, n2) => a * n1 + b * n2 = n

a, b 的參數式 a = a + lcm(n1, n2)/n1 * k, b = b + lcm(n1, n2)/n2 * k


#include <stdio.h>
#include <math.h>
#include <set>
using namespace std;
set<long long> ret;
long long exgcd(long long x, long long y, long long &a, long long &b) {
    int flag = 0;// ax + by = gcd(x,y)
    long long t, la = 1, lb = 0, ra = 0, rb = 1;
    while(x%y) {
        if(flag == 0)
            la -= x/y*ra, lb -= x/y*rb;
        else
            ra -= x/y*la, rb -= x/y*lb;
        t = x, x = y, y = t%y;
        flag = 1 - flag;
    }
    if(flag == 0)
        a = ra, b = rb;
    else
        a = la, b = lb;
    return y;
}
void sol(long long mp, long long mq) {
    long long np, nq, g;
    g = exgcd(mp, mq, np, nq);// mp np + mq nq = gcd(mp, mq)
    if(2%g) return;
    long long k = 2/g, k1, k2;
    np *= k, nq *= k; // mp np + mq nq = 2
    k1 = mp/g*mq/mp, k2 = mp/g*mq/mq;
    if(np <= 0) { // adjust a >= 0
        k = -(np/k1) + (np%k1 != 0);
        np += k*k1, nq -= k*k2;
    }
    // maximize np, minimize nq
    k = (-nq)/k2+1;
    np += k*k1;
    nq -= k*k2;
    while(np > 0 && nq <= 0) {
        long long x = mp*np-1;
        if(mq*(-nq)+1 == x)
            ret.insert(x);
        np -= k1;
        nq += k2;
    }
}
int main() {
    int m, i;
    while(scanf("%d", &m) == 1) {
        ret.clear();
        long long mp, mq;
        long long sq = sqrt(m);
        for(i = 1; i <= sq; i++) {
            if(m%i == 0) {
                mp = i, mq = m/i;
                sol(mp, mq);
                mp = m/i, mq = i;
                sol(mp, mq);
            }
        }
        set<long long> ans;
        for(set<long long>::iterator it = ret.begin();
            it != ret.end(); it++) {
            if(*it >= 1 && *it < m) {
                ans.insert(*it);
                if(*it != m)
                    ans.insert(m-(*it));
            }
        }
        printf("%d\n", ans.size());
        for(set<long long>::iterator it = ans.begin();
            it != ans.end(); it++)
            printf("%lld\n", *it);
    }
    return 0;
}

/*
x*x%m = 1
x*x = 1 + n*m
(x+1)(x-1) = n*m
n = np * nq, m = mp * mq
x+1 = mp * np, x-1 = mq * nq
=> 2 = mpnp - mqnq
try all mp, mq, use extended gcd
*/