LU分解をもとにした最小二乗法、エラーは出ないが求めてる結果と違う

python

1import japanize_matplotlib 2import matplotlib.pyplot as plt 3from decimal import Decimal, getcontext 4 5getcontext().prec = 56 7# CSVファイルの読み込みからのコード8csv_file_path = 'data.csv'9 10with open(csv_file_path, 'r') as file:11 lines = file.readlines()12 13header = lines[0].split()14columns = [col.strip() for col in header]15 16data = [line.split() for line in lines[1:]]17matrix = [[float(value) for value in row] for row in data]18 19# データを月ごとにグループ化して平均を取る20grouped_data = []21temp_sum = [0.0] * len(columns)22count = 023 24for row in matrix:25 for i in range(len(columns)):26 temp_sum[i] += row[i]27 28 count += 129 30 if count % 12 == 0:31 average_row = [value / 12 for value in temp_sum]32 grouped_data.append(average_row)33 temp_sum = [0.0] * len(columns)34 35grouped_data = grouped_data[:-1] # 余分な行を削除36 37odd_grouped_data = []38even_grouped_data = []39 40# 偶数行と奇数行に分ける41for i, row in enumerate(grouped_data):42 if i % 2 == 0:43 even_grouped_data.append(row)44 else:45 odd_grouped_data.append(row)46 47# LU分解48def Crout_func(matrix):49 n = len(matrix)50 L = [[0] * n for _ in range(n)]51 U = [[0] * n for _ in range(n)]52 53 epsilon = Decimal('1e-10') # ゼロに近い値として使用する小さな値54 55 for i in range(n):56 for k in range(i, n):57 sum_LU = Decimal(0)58 for j in range(i):59 sum_LU += L[i][j] * U[j][k]60 U[i][k] = matrix[i][k] - sum_LU 61 62 for k in range(i, n):63 if i == k:64 if abs(U[i][i]) < epsilon:65 U[i][i] = epsilon 66 L[i][i] = Decimal(1)67 else:68 sum_LU = Decimal(0)69 for j in range(i):70 sum_LU += L[k][j] * U[j][i]71 if abs(U[i][i]) < epsilon:72 U[i][i] = epsilon 73 L[k][i] = (matrix[k][i] - sum_LU) / U[i][i]74 75 return n, L, U 76 77# 前進消去78def Zenshin(n, L, B):79 y = [0] * n 80 for i in range(n):81 sum_y = 082 for j in range(i):83 sum_y += L[i][j] * y[j]84 y[i] = B[i] - sum_y 85 return y 86 87# 後退代入88def Koutai(n, U, y):89 x = [0] * n 90 for i in range(n-1, -1, -1):91 sum_x = 092 for j in range(i + 1, n):93 sum_x += U[i][j] * x[j]94 x[i] = (y[i] - sum_x) / U[i][i]95 return x 96 97# 多項式回帰の実行98def Takousiki(n_jisu, x_years, y_values):99 x_years_decimal = [Decimal(x) for x in x_years]100 101 A = []102 for i in range(n_jisu + 1):103 row = []104 for j in range(n_jisu + 1):105 sum_value = Decimal(0)106 for x in x_years_decimal:107 term = x ** (i + j)108 sum_value += term 109 row.append(sum_value)110 A.append(row)111 112 B = []113 for i in range(n_jisu + 1):114 sum_value = Decimal(0)115 for x, y in zip(x_years_decimal, y_values):116 term = x ** i 117 sum_value += Decimal(y) * term 118 B.append(sum_value)119 120 # 対角成分に小さな正則化項を追加121 for i in range(n_jisu + 1):122 A[i][i] += Decimal('1e-6')123 124 n, L, U = Crout_func(A)125 y = Zenshin(n, L, B)126 keisu = Koutai(n, U, y)127 128# 選択した次数を出力129 print(f'次数:{n_jisu}')130 131 # 誤差計算と出力132 error_squared = sum((sum(coeff * x ** i for i, coeff in enumerate(keisu)) - Decimal(y_val)) ** 2 for x, y_val in zip(x_years_decimal, y_values))133 print(f'二乗誤差: {error_squared}')134 135 # 予測結果のプロット136 min_x = int(min(x_years) * 10)137 max_x = int(max(x_years) * 10)138 x_ans = [i * 0.1 for i in range(min_x, max_x + 1)]139 y_ans = [sum(coeff * Decimal(x) ** i for i, coeff in enumerate(keisu)) for x in x_ans]140 141 plt.scatter(x_years, y_values, label='Data Points')142 plt.plot(x_ans, y_ans, label='多項式回帰', color='red')143 plt.ylim(-2000,30000)144 plt.xlabel('X 経常収支(億円)')145 plt.ylabel('Y 金融収支(億円)')146 plt.legend()147 plt.show()148 149# 奇数行と偶数行それぞれに対して多項式回帰を実行150def use_grouped_data(odd_grouped_data, even_grouped_data, n_jisu):151 # odd_grouped_dataを用いて多項式回帰を実行152 # 奇数行のデータを用いて多項式回帰を実行153 x_data_odd = [item[1] for item in odd_grouped_data]154 y_data_odd = [item[2] for item in odd_grouped_data]155 Takousiki(n_jisu, x_data_odd, y_data_odd)156 157 # even_grouped_dataを用いて多項式回帰を実行158 # 偶数行のデータを用いて多項式回帰を実行159 x_data_even = [item[1] for item in even_grouped_data]160 y_data_even = [item[2] for item in even_grouped_data]161 Takousiki(n_jisu, x_data_even, y_data_even)162 163 164# それぞれのデータに対して多項式回帰を実行165# フィッティングデータと検証データを入れ替えて,お互いの二乗誤差がそれぞれ一番小さくなる次数は21であった166use_grouped_data(odd_grouped_data, even_grouped_data,n_jisu=14)167 168 169 170 171

コメントを投稿

0 コメント