Add problems/regularisation #19

Open · wants to merge 1 commit into base: main
66 changes: 66 additions & 0 deletions backend/static/problems/reguralarisation/description.md
@@ -0,0 +1,66 @@
Regularisation is a countermeasure against overfitting: adding a regularisation term to the cost function suppresses overfitting. Well-known choices for the term are L1 regularisation, which uses the L1 norm of the weights, and L2 regularisation, which uses the squared L2 norm. $\lambda$ is called the regularisation parameter. In L2 regularisation the coefficient is often taken to be $\lambda/2$ so that the term is easy to handle when the cost function is differentiated with respect to $w$.
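To make the conventions explicit (a restatement of the standard definitions, not anything specific to this problem): writing $J_0(\bm{w})$ for the unregularised cost,
$$
J_{L1}(\bm{w}) = J_0(\bm{w}) + \lambda\|\bm{w}\|_1, \qquad J_{L2}(\bm{w}) = J_0(\bm{w}) + \frac{\lambda}{2}\|\bm{w}\|_2^2, \qquad \frac{\partial}{\partial \bm{w}}\,\frac{\lambda}{2}\|\bm{w}\|_2^2 = \lambda\bm{w},
$$
so the $\lambda/2$ convention makes the gradient of the penalty come out as simply $\lambda\bm{w}$.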

Suppose a prediction $\hat{\bm{y}}\in \mathbb{R}^{t}$ and an observed value $\bm{y} \in \mathbb{R}^{t}$ are given, where $\bm{y}$ is a one-hot encoded vector of dimension $t\times 1$.
(Example)
$$
\hat{\bm{y}} = \begin{pmatrix}0.2\\0.1\\0.9\\\vdots\end{pmatrix}, \bm{y}=\begin{pmatrix}0\\0\\1\\\vdots\end{pmatrix}
$$

Suppose the cost function, with the regularisation term $\lambda\|\bm{w}\|^2/2$ added ($\bm{w}$ being the weights), is given by
$$
J(\bm{w})=-\sum_{j=1}^t \left[\,y_j\log(\hat{y}_j)+ (1-y_j)\log(1-\hat{y}_j)\right]+\frac{\lambda}{2}\|\bm{w}\|^2 .
$$

Based on this, given $\bm{x}\in\mathbb{R}^m$ and $\bm{y}\in \mathbb{R}^{t}$, and with the prediction defined as $\hat{\bm{y}}=\text{softmax}(\bm{w}\bm{x})$ for a weight matrix $\bm{w}\in\mathbb{R}^{t\times m}$, find the minimum value of the cost function with the regularisation term added, taking $\lambda=1$. The initial value of $\bm{w}$ is $\bm{0}$.
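As a sanity check (a worked value derived from the formula above, not stated in the original): at the initial point $\bm{w}=\bm{0}$ the softmax output is uniform, $\hat{y}_j = 1/t$, so for Sample 1 below ($t=3$, $\bm{y}=(0,1,0)^\top$)
$$
J(\bm{0}) = -\left[\log\tfrac{1}{3} + 2\log\tfrac{2}{3}\right] \approx 1.0986 + 0.8109 = 1.9095,
$$
which gradient descent then drives down to the expected output $0.39896\ldots$.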

## Constraints
- $1 \leq m \leq 20$
- $1 \leq t \leq 5$
- $-100 \leq x_i\leq 100$ ($1 \leq i \leq m$)
- $y_i \in \{0, 1\}$

## Input
Input is given from standard input in the following format.

```plaintext
m t
x_1 x_2 ... x_m
y_1 y_2 ... y_t
```
The input can therefore be read with the following code (note that `torch` must be imported):
```python3
import torch

m, t = map(int, input().split())
x = torch.tensor(list(map(float, input().split()))).reshape(-1, 1)
y = torch.tensor(list(map(int, input().split()))).reshape(-1, 1)
```
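For Sample 1, this parsing produces column vectors; the sketch below (with the sample values hard-coded in place of standard input) shows the resulting shapes:
```python3
import torch

# Sample 1 values hard-coded in place of standard input.
x = torch.tensor([0., 1., 2., 3.]).reshape(-1, 1)  # shape (m, 1) = (4, 1)
y = torch.tensor([0, 1, 0]).reshape(-1, 1)         # shape (t, 1) = (3, 1)
print(x.shape, y.shape)  # torch.Size([4, 1]) torch.Size([3, 1])
```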
## Output
Print the minimised $J(\bm{w})$. Your answer is judged correct if the absolute error from the expected answer is at most $10^{-3}$.
```plaintext
J_w
```

## Samples
### Sample 1
#### Input
```plaintext
4 3
0 1 2 3
0 1 0
```

#### Output
```plaintext
0.39896008372306824
```
### Sample 2
#### Input
```plaintext
6 4
1 3 2 4 3 5
0 1 0 0
```
#### Output
```plaintext
0.1829862892627716
```
3 changes: 3 additions & 0 deletions backend/static/problems/reguralarisation/in/01.in
@@ -0,0 +1,3 @@
10 5
1 3 1 3 1 6 7 2 4 2
0 1 0 0 0
3 changes: 3 additions & 0 deletions backend/static/problems/reguralarisation/in/02.in
@@ -0,0 +1,3 @@
7 5
2 3 1 6 1 3 0
0 1 0 0 0
3 changes: 3 additions & 0 deletions backend/static/problems/reguralarisation/in/03.in
@@ -0,0 +1,3 @@
6 4
1 1 0 2 7 3
0 0 0 1
3 changes: 3 additions & 0 deletions backend/static/problems/reguralarisation/in/04.in
@@ -0,0 +1,3 @@
7 3
1 2 0 1 2 4 9
0 0 1
3 changes: 3 additions & 0 deletions backend/static/problems/reguralarisation/in/05.in
@@ -0,0 +1,3 @@
8 2
1 2 1 0 4 7 2 1
0 1
3 changes: 3 additions & 0 deletions backend/static/problems/reguralarisation/in/06.in
@@ -0,0 +1,3 @@
8 4
0 1 1 3 4 5 6 3
0 0 1 0
1 change: 1 addition & 0 deletions backend/static/problems/reguralarisation/out/01.out
@@ -0,0 +1 @@
0.12536248564720154
1 change: 1 addition & 0 deletions backend/static/problems/reguralarisation/out/02.out
@@ -0,0 +1 @@
0.21797901391983032
1 change: 1 addition & 0 deletions backend/static/problems/reguralarisation/out/03.out
@@ -0,0 +1 @@
0.18298643827438354
1 change: 1 addition & 0 deletions backend/static/problems/reguralarisation/out/04.out
@@ -0,0 +1 @@
0.1042860746383667
1 change: 1 addition & 0 deletions backend/static/problems/reguralarisation/out/05.out
@@ -0,0 +1 @@
0.08774268627166748
1 change: 1 addition & 0 deletions backend/static/problems/reguralarisation/out/06.out
@@ -0,0 +1 @@
0.1365353912115097
18 changes: 18 additions & 0 deletions backend/static/problems/reguralarisation/problem.yaml
@@ -0,0 +1,18 @@
# Summary
summary:
  title: Regularisation
  points: 100
  section: 5

# Constraints
constraints:
  # Time limit (milliseconds)
  time: 2000
  # Memory limit (MB)
  memory: 256
  # Judge with error tolerance?
  error_judge: true
  # Allowed absolute error
  absolute_error: 1e-3
  # Allowed relative error
  relative_error: 1e-3
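Based on these settings, the judge presumably accepts an answer when either tolerance is met. A minimal sketch of such a comparison (an assumption for illustration, not the repository's actual judge code):
```python3
def is_accepted(expected: float, actual: float,
                abs_err: float = 1e-3, rel_err: float = 1e-3) -> bool:
    # Accept if either the absolute or the relative error is within tolerance.
    diff = abs(actual - expected)
    return diff <= abs_err or diff <= rel_err * abs(expected)

print(is_accepted(0.39896008372306824, 0.3989))  # True
```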
21 changes: 21 additions & 0 deletions backend/static/problems/reguralarisation/solution.py
@@ -0,0 +1,21 @@
# Solution
import torch

# Read input
m, t = map(int, input().split())
x = torch.tensor(list(map(float, input().split()))).reshape(-1, 1)
y = torch.tensor(list(map(int, input().split()))).reshape(-1, 1)

# Initialise the parameters
w = torch.zeros((t, m), requires_grad=True)

# Minimise by plain gradient descent (100 steps, learning rate 0.1)
for i in range(100):
    pred = torch.softmax(torch.matmul(w, x), dim=0)
    J_w = -torch.sum(y * torch.log(pred) + (1 - y) * torch.log(1 - pred)) + 0.5 * torch.norm(w) ** 2
    J_w.backward()
    with torch.no_grad():
        w -= 0.1 * w.grad
        w.grad.zero_()
# Output
print(J_w.item())
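An equivalent variant using `torch.optim.SGD` in place of the manual update (a sketch, not part of this PR; it performs the same 100 gradient-descent steps with learning rate 0.1, here with the Sample 1 values hard-coded):
```python3
import torch

# Sample 1 values hard-coded in place of standard input.
x = torch.tensor([0., 1., 2., 3.]).reshape(-1, 1)
y = torch.tensor([0, 1, 0]).reshape(-1, 1)
t, m = 3, 4

w = torch.zeros((t, m), requires_grad=True)
opt = torch.optim.SGD([w], lr=0.1)

for _ in range(100):
    opt.zero_grad()
    pred = torch.softmax(w @ x, dim=0)
    # Cross-entropy over all t entries plus the L2 penalty (lambda = 1).
    J_w = -torch.sum(y * torch.log(pred) + (1 - y) * torch.log(1 - pred)) \
          + 0.5 * torch.norm(w) ** 2
    J_w.backward()
    opt.step()

print(J_w.item())  # ~0.39896 for Sample 1, matching the expected output
```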