[알고리즘]백준 10830: 행렬 제곱
알고리즘 유형: 수학
문제
크기가 N*N인 행렬 A가 주어진다. 이때, A의 B제곱을 구하는 프로그램을 작성하시오. 수가 매우 커질 수 있으니, A^B의 각 원소를 1,000으로 나눈 나머지를 출력한다.
입력
첫째 줄에 행렬의 크기 N과 B가 주어진다. (2 ≤ N ≤ 5, 1 ≤ B ≤ 100,000,000,000)
둘째 줄부터 N개의 줄에 행렬의 각 원소가 주어진다. 행렬의 각 원소는 1,000보다 작거나 같은 자연수 또는 0이다.
출력
첫째 줄부터 N개의 줄에 걸쳐 행렬 A를 B제곱한 결과를 출력한다.
풀이
필요한 부분은 세가지입니다
1. 행렬곱 구현
2. 나머지만 출력
3. 계산을 빠르게
첫번째로 행렬 곱을 구현하겠습니다. 여기서 나머지를 출력하는 것을 동시에 진행하겠습니다
A, B 행렬의 곱의 결과는 A행렬의 행벡터와 B행렬의 열벡터의 내적과 같습니다
C = A@B(@는 행렬의 곱을 표현하는 연산자라고 하겠습니다)
C의 1번째 행의 1번째 열의 성분을 예로 들어보겠습니다
$C_{11} = rowA_1 \cdot colB_1$입니다 ($rowA_n, colB_n$은 각각 A와 B의 열벡터, 행벡터를 의미합니다)
$C_{12} = rowA_1 \cdot colB_2$입니다
그렇다면 1행의 n번째 성분은
$C_{1n} = rowA_1 \cdot colB_n$이 됩니다
이번에는 n행의 1번째 성분을 보겠습니다
$C_{n1} = rowA_n \cdot colB_1$이 됩니다
뭔가 조금 보이시나요?
$C_{ij} = rowA_i \cdot colB_j$가 됩니다.
그런데 사실 중요한 부분은 우리가 손으로 끄적거릴게 아니기 때문에
코드로 어떻게 구현하는지 입니다
$C_{ij} = rowA_i \cdot colB_j$ 이 식의 내적을 전부 풀어보겠습니다 (편의상 3x3 행렬로 가정하겠습니다)
$C_{ij} = A_{i1}*B{1j}+A_{i2}*B{2j}+A_{i3}*B{3j} = \sum^{3}_{jk =1} A_{ik} * B_{jk}$ 입니다
이를 코드로 구현하면 i, j , k 세가지 변수를 활용한 삼중 for문으로 구현할 수 있습니다
def dot(a:list,b:list):
c =[[] for i in range(N)]
for i in range(N):
for j in range(N):
c_ij = 0
for k in range(N):
c_ij+=a[i][k]*b[k][j]
c[i].append(c_ij%1000)
return c
맨 아래에 %1000을 진행해주는 것을 확인할 수 있는데,
맨 마지막 단계에서 해주는 것이 아니라 중간중간에 이런식으로 나머지 연산을 해주면
계산결과가 달라지는게 아닌가..? 라는 생각이 들 수도 있습니다
그런데 나머지 연산(모듈러 연산)은
C =A*B에서 C mod N = A mod N * B mod N과 같습니다
그렇기 때문에 숫자의 크기를 계속 일정하게 유지하기 위해서 미리미리 모듈러 연산을 진행해주는 것이 좋습니다
마지막으로 연산량에 문제가 있습니다
매번 행렬곱을 반복하기에는 행렬곱 자체가 요구하는 연산량이 너무 많습니다
N x N 행렬의 경우 1개의 원소를 구하기 위해 $N$만큼의 곱셈을 수행해야 하기 때문에
전체로 보면 $N^3$만큼의 곱셉이 필요합니다.
그렇기 때문에 행렬곱을 수행하는 횟수를 가능한 줄여주는 편이 좋습니다
그래서 해야하는 제곱 횟수를 이진수로 변환하여 자릿수가 1인 경우에만 dot연산을 수행하도록 하였습니다
이 경우 최악의 경우 1000억번의 경우에도 연산을 1011101001000011101101110100000000000의 자
릿수에 해당하는 37번(*2)만 수행하면 완료됩니다
그리고 원인은 제대로 파악을 못했는데 [[1000,1000], [1000,1000]]같은 행렬이 들어오는 경우
[[0,0],[0,0]]이 아니라 원래 값이 그대로 출력하는 문제가 있어 마지막에 이를 걸러줄 수 있는 코드를
추가해주었습니다
import sys
N,M = map(int,sys.stdin.readline().strip().split())
A= []
A_zeros=[]
for i in range(N):
A_i= list(map(int, sys.stdin.readline().strip().split()))
A.append(A_i)
A_zeros = [0 for i in range(len(A_i))]
def dot(a:list,b:list):
c =[[] for i in range(N)]
for i in range(N):
for j in range(N):
c_ij = 0
for k in range(N):
c_ij+=a[i][k]*b[k][j]
c[i].append(c_ij%1000)
return c
A_sum =None
for degit in (bin(M)[2:][::-1]):
if degit!="0" and A_sum==None:
A_sum =A
elif degit !="0" and A_sum!=None:
A_sum=dot(A_sum,A)
A =dot(A,A)
for i in range(len(A_sum)):
for j in range(len(A_sum)):
if A_sum[i][j]==1000:
A_sum[i][j]=0
for row in A_sum:
print(*row, sep=" ")