1 minute read

1. Import Libraries

import numpy as np
import torch
import torch.utils.data
import torch.nn as nn # neural network
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torchvision import datasets, transforms
from torchvision.utils import make_grid, save_image

2. Tool: Show and Save an Image

import matplotlib.pyplot as plt

# Helper that displays an image and saves it to disk
def show_adn_save(file_name,img):
    """Display an image tensor and write it to ``./<file_name>.png``.

    Args:
        file_name: Base name of the output file (no directory, no extension).
        img: Object exposing ``.numpy()`` that yields a 3-axis array. The axes
            are reordered as (1, 2, 0) before plotting -- presumably turning a
            channels-first image into channels-last; confirm against callers.
    """
    # Reorder the axes so matplotlib receives the layout it expects.
    pixels = img.numpy().transpose((1, 2, 0))
    out_path = "./%s.png"%file_name
    plt.imshow(pixels)
    plt.imsave(out_path, pixels)

3. Define the class for RBM

# Source: https://blog.paperspace.com/beginners-guide-to-boltzmann-machines-pytorch/
class RBM(nn.Module):
    """Restricted Boltzmann machine with binary units, trained via CD-k.

    The model owns a weight matrix ``W`` of shape ``(n_hin, n_vis)`` plus one
    bias vector per layer. ``forward`` runs ``k`` rounds of alternating Gibbs
    sampling and returns the input together with the final visible sample;
    the difference of their free energies is the contrastive-divergence loss.

    Args:
        n_vis: Number of visible units.
        n_hin: Number of hidden units.
        k: Number of Gibbs sampling rounds performed by ``forward``.

    Raises:
        ValueError: If ``k`` is smaller than 1 (``forward`` would have no
            reconstruction to return).
    """

    def __init__(self,
                 n_vis=784,
                 n_hin=500,
                 k=5):
        super().__init__()
        if k < 1:
            raise ValueError("k must be at least 1, got %r" % (k,))
        # Small random weights, zero biases.
        self.W = nn.Parameter(torch.randn(n_hin, n_vis) * 1e-2)
        self.v_bias = nn.Parameter(torch.zeros(n_vis))
        self.h_bias = nn.Parameter(torch.zeros(n_hin))
        self.k = k

    def sample_from_p(self, p):
        """Draw a 0/1 sample per element, equal to 1 with probability ``p``.

        The uniform noise is created with ``rand_like`` so it lives on the
        same device and has the same dtype as ``p``. The result carries no
        gradient, which matches the zero gradient of the sign-based
        formulation it replaces.
        """
        return (p > torch.rand_like(p)).to(p.dtype)

    def v_to_h(self, v):
        """Return ``(p_h, sample_h)``: hidden activation probabilities given
        the visible batch ``v`` and a binary sample drawn from them."""
        p_h = torch.sigmoid(F.linear(v, self.W, self.h_bias))
        return p_h, self.sample_from_p(p_h)

    def h_to_v(self, h):
        """Return ``(p_v, sample_v)``: visible activation probabilities given
        the hidden batch ``h`` and a binary sample drawn from them."""
        p_v = torch.sigmoid(F.linear(h, self.W.t(), self.v_bias))
        return p_v, self.sample_from_p(p_v)

    def forward(self, v):
        """Run ``k`` Gibbs rounds starting from ``v``.

        Returns:
            Tuple ``(v, v_)`` of the unchanged input and the visible sample
            obtained after the last round.
        """
        h_ = self.v_to_h(v)[1]
        for _ in range(self.k):
            v_ = self.h_to_v(h_)[1]
            h_ = self.v_to_h(v_)[1]
        return v, v_

    def free_energy(self, v):
        """Return the mean free energy of a 2-D batch ``v`` (rows = samples).

        Per sample this is ``-v . v_bias - sum_j softplus(W v + h_bias)_j``.
        """
        vbias_term = v.mv(self.v_bias)
        wx_b = F.linear(v, self.W, self.h_bias)
        # softplus(x) == log(1 + exp(x)) but does not overflow for large x.
        hidden_term = F.softplus(wx_b).sum(1)
        return (-hidden_term - vbias_term).mean()


4. Load MNIST Dataset

batch_size = 64

# Training split: stored under ./MNIST_data (downloaded when requested) and
# served in batches of `batch_size`, each image converted by ToTensor.
train_loader = torch.utils.data.DataLoader(
    dataset=datasets.MNIST(
        root='./MNIST_data',
        train=True,
        download=True,
        transform=transforms.Compose([transforms.ToTensor()]),
    ),
    batch_size=batch_size,
)

# Test split: same location and transform; no download flag is passed here.
test_loader = torch.utils.data.DataLoader(
    dataset=datasets.MNIST(
        root='./MNIST_data',
        train=False,
        transform=transforms.Compose([transforms.ToTensor()]),
    ),
    batch_size=batch_size,
)

출력

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST_data/MNIST/raw/train-images-idx3-ubyte.gz



  0%|          | 0/9912422 [00:00<?, ?it/s]


Extracting ./MNIST_data/MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST_data/MNIST/raw/train-labels-idx1-ubyte.gz



  0%|          | 0/28881 [00:00<?, ?it/s]


Extracting ./MNIST_data/MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST_data/MNIST/raw/t10k-images-idx3-ubyte.gz



  0%|          | 0/1648877 [00:00<?, ?it/s]


Extracting ./MNIST_data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST_data/MNIST/raw/t10k-labels-idx1-ubyte.gz



  0%|          | 0/4542 [00:00<?, ?it/s]


Extracting ./MNIST_data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST_data/MNIST/raw

5. Main - RBM Train and Test

rbm = RBM(k=1)  # k = number of contrastive-divergence (Gibbs) rounds per update
train_op = optim.Adam(rbm.parameters(), 0.005)  # learning rate = 0.005

for epoch in range(10):
    loss_ = []
    # Mini-batch updates with Adam; `target` (the labels) is never used.
    for data, target in train_loader:
        data = data.view(-1, 784)  # flatten every image to a 784-vector
        # Binarise the batch: each value is used as the probability of a 1.
        sample_data = data.bernoulli()

        v, v1 = rbm(sample_data)
        # Contrastive-divergence loss: free energy of the data minus free
        # energy of the reconstruction.
        loss = rbm.free_energy(v) - rbm.free_energy(v1)
        loss_.append(loss.item())
        train_op.zero_grad()
        loss.backward()
        train_op.step()
    # Average loss over all batches of this epoch.
    print(np.mean(loss_))

출력
/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py:1960: UserWarning: nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.
  warnings.warn("nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.")

-4.478195336835979
-1.164984371616388
0.8052755002019756
1.7930132575126598
2.2253803505317995
2.5580441733158983
2.7660305779625864
2.891407460292011
3.0023812820662314
3.0375535391541177

6. Test

# MNIST test split read from ./MNIST_data (no download flag is passed here).
testset = datasets.MNIST(
    root='./MNIST_data',
    train=False,
    transform=transforms.Compose([transforms.ToTensor()]),
)

# Load 32 samples in total, flatten each to 784 values and divide by 255
# (presumably mapping 0-255 pixel values into [0, 1]).
sample_data = testset.data[:32].view(-1, 784)
sample_data = sample_data.float() / 255.

# v is the batch that was fed in, v1 the sample the RBM produced from it.
v, v1 = rbm(sample_data)
show_adn_save(
    "real_testdata",
    make_grid(v.data.view(32, 1, 28, 28)),
)

출력

output_12_0

# Tile the 32 generated samples into a single grid image, show it and save it.
show_adn_save(
    "generated_testdata",
    make_grid(v1.data.view(32, 1, 28, 28)),
)

출력

output_13_0

Leave a comment