[학부수업] Restricted Boltzmann Machine 실습 (12-2)
1. Import Library
import numpy as np
import torch
import torch.utils.data
import torch.nn as nn # neural network
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torchvision import datasets, transforms
from torchvision.utils import make_grid, save_image
2. Tool: Image Show and Save
import matplotlib.pyplot as plt
# Helper that displays an image tensor and saves it to disk.
def show_adn_save(file_name, img):
    """Display an image tensor with matplotlib and save it as ``./<file_name>.png``.

    Args:
        file_name: Base name of the output file (without the ``.png`` suffix).
        img: 3-D tensor whose axes are permuted (0, 1, 2) -> (1, 2, 0) before
            plotting; the callers in this file pass ``make_grid`` output, so
            this is presumably channel-first (C, H, W) -> (H, W, C).
    """
    # detach() + cpu() so tensors that still track gradients or live on an
    # accelerator can be converted; a bare .numpy() raises for both.
    npimg = np.transpose(img.detach().cpu().numpy(), (1, 2, 0))
    f = f"./{file_name}.png"
    plt.imshow(npimg)
    plt.imsave(f, npimg)
3. Define the class for RBM
# Source: https://blog.paperspace.com/beginners-guide-to-boltzmann-machines-pytorch/
class RBM(nn.Module):
    """Restricted Boltzmann machine with binary units and k-step Gibbs sampling.

    Args:
        n_vis: Number of visible units (the default 784 matches the flattened
            28x28 inputs used elsewhere in this file).
        n_hin: Number of hidden units.
        k: Number of Gibbs steps ``forward`` runs; must be at least 1.
    """

    def __init__(self,
                 n_vis=784,
                 n_hin=500,
                 k=5):
        super(RBM, self).__init__()
        # Small random weights; both bias vectors start at zero.
        self.W = nn.Parameter(torch.randn(n_hin, n_vis) * 1e-2)
        self.v_bias = nn.Parameter(torch.zeros(n_vis))
        self.h_bias = nn.Parameter(torch.zeros(n_hin))
        self.k = k

    def sample_from_p(self, p):
        """Return a 0/1 tensor shaped like ``p``: 1 where ``p`` exceeds uniform noise."""
        # rand_like keeps the noise on the same device and dtype as ``p``.
        return (p > torch.rand_like(p)).to(p.dtype)

    def v_to_h(self, v):
        """Return ``(p_h, sample_h)``: hidden activation probabilities and a binary sample."""
        p_h = torch.sigmoid(F.linear(v, self.W, self.h_bias))
        sample_h = self.sample_from_p(p_h)
        return p_h, sample_h

    def h_to_v(self, h):
        """Return ``(p_v, sample_v)``: visible activation probabilities and a binary sample."""
        p_v = torch.sigmoid(F.linear(h, self.W.t(), self.v_bias))
        sample_v = self.sample_from_p(p_v)
        return p_v, sample_v

    def forward(self, v):
        """Run ``self.k`` Gibbs steps starting from ``v``.

        Returns:
            ``(v, v_)``: the input unchanged and the visible sample produced by
            the last Gibbs step.

        Raises:
            ValueError: If ``self.k`` is smaller than 1 (no sample would exist).
        """
        if self.k < 1:
            raise ValueError("k must be >= 1 to produce a Gibbs sample, got %r" % (self.k,))
        # Sampling is a hard threshold, so no useful gradient flows through the
        # chain; skipping autograd here avoids building a graph per step.
        with torch.no_grad():
            _, h_ = self.v_to_h(v)
            for _ in range(self.k):
                _, v_ = self.h_to_v(h_)
                _, h_ = self.v_to_h(v_)
        return v, v_

    def free_energy(self, v):
        """Return the batch mean of ``-v . v_bias - sum_j softplus(W v + h_bias)_j``.

        ``v`` must be 2-D (rows are samples) because ``Tensor.mv`` is used.
        """
        vbias_term = v.mv(self.v_bias)
        wx_b = F.linear(v, self.W, self.h_bias)
        # softplus(x) == log(1 + exp(x)) but does not overflow for large x.
        hidden_term = F.softplus(wx_b).sum(1)
        return (-hidden_term - vbias_term).mean()
4. Load MNIST Dataset
# Mini-batch size shared by both loaders.
batch_size = 64
# Training split: downloaded on first use and reshuffled every epoch so the
# mini-batches differ between epochs.
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./MNIST_data', train=True, download=True,
                   transform=transforms.Compose([transforms.ToTensor()])),
    batch_size=batch_size,
    shuffle=True)
# Test split: kept in dataset order.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./MNIST_data', train=False,
                   transform=transforms.Compose([transforms.ToTensor()])),
    batch_size=batch_size)
출력
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST_data/MNIST/raw/train-images-idx3-ubyte.gz
0%| | 0/9912422 [00:00<?, ?it/s]
Extracting ./MNIST_data/MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST_data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST_data/MNIST/raw/train-labels-idx1-ubyte.gz
0%| | 0/28881 [00:00<?, ?it/s]
Extracting ./MNIST_data/MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST_data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST_data/MNIST/raw/t10k-images-idx3-ubyte.gz
0%| | 0/1648877 [00:00<?, ?it/s]
Extracting ./MNIST_data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST_data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST_data/MNIST/raw/t10k-labels-idx1-ubyte.gz
0%| | 0/4542 [00:00<?, ?it/s]
Extracting ./MNIST_data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST_data/MNIST/raw
5. Main - RBM Train and Test
rbm = RBM(k=1)  # k = number of contrastive-divergence (Gibbs) steps
train_op = optim.Adam(rbm.parameters(), 0.005)  # learning rate = 0.005
for epoch in range(10):
    loss_ = []
    for data, target in train_loader:  # one optimizer step per mini-batch
        # Flatten each image to a 784-vector, then binarize it by sampling
        # every pixel as a Bernoulli draw with its intensity as probability.
        data = data.view(-1, 784)
        sample_data = data.bernoulli()
        v, v1 = rbm(sample_data)
        # Loss is the free energy of the data minus that of the model sample.
        loss = rbm.free_energy(v) - rbm.free_energy(v1)
        loss_.append(loss.item())
        train_op.zero_grad()
        loss.backward()
        train_op.step()
    # Mean loss over the epoch's mini-batches.
    print(np.mean(loss_))
출력
/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py:1960: UserWarning: nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.
warnings.warn("nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.")
-4.478195336835979
-1.164984371616388
0.8052755002019756
1.7930132575126598
2.2253803505317995
2.5580441733158983
2.7660305779625864
2.891407460292011
3.0023812820662314
3.0375535391541177
6. Test
testset = datasets.MNIST('./MNIST_data', train=False, transform=transforms.Compose([transforms.ToTensor()]))
n_samples = 32  # load 32 samples in total
# Raw uint8 images -> flattened float vectors scaled to [0, 1].
sample_data = testset.data[:n_samples].view(-1, 784)
sample_data = sample_data.float() / 255.
# Inference only: no autograd graph is needed for the Gibbs chain.
with torch.no_grad():
    v, v1 = rbm(sample_data)
show_adn_save("real_testdata", make_grid(v.view(n_samples, 1, 28, 28).data))
출력
# Tile the 32 model outputs in `v1` into a single grid image, then show and save it.
generated_grid = make_grid(v1.view(32, 1, 28, 28).data)
show_adn_save("generated_testdata", generated_grid)
출력
Leave a comment