pytorch를 keras로 변환하기 위해 여러가지 패키지를 활용해보고 진행해보았으나 패키지 의존성문제도 심하고 CPU와 GPU 사이의 문제나, 특정 레이어를 지원하지 않거나 인코딩 문제가 발생(MaxPool)하거나, 특정 옵션에서 에러가 발생하기도하고(Conv2d Padding options), pytorch의 명명규칙이 keras와 다른 등의 이유로 오류가 많이 발생했다. 우선은 pytorch가 onnx와 호환이 좋지 못한 것 같았다. 우선은 이러한 호환문제를 최소화 하고자 pytorch모델과 keras 모델을 각각 생성한 뒤 weight를 교환하는 방식이 가장 이상적일 것이라 판단되어 진행해보았다. 결과는 다소 아쉬웠다. 생각보다는 오차가 조금 크게 나타났다.

Torch 모델 생성

모델은 가장 익숙할 것으로 생각되는 손글씨 숫자 이미지를 기반으로 0~9를 분류하는 예제인  MNIST 모델을 활용하였다.

학습을 위한 제너레이터 생성

import torch
import tensorflow as tf
from torchsummary import summary as summary_
from torchvision import datasets, transforms
from matplotlib import pyplot as plt
import numpy as np
device = torch.device(
	'cuda' if torch.cuda.is_available() else 'cpu')

train_data = datasets.MNIST(
	root = './data/02/',
	train=True,
    download=True,
    transform=transforms.ToTensor())
test_data = datasets.MNIST(
	root = './data/02/',
    train=False,
    download=True,
    transform=transforms.ToTensor())

batch_size = 50 ; learning_rate = 0.0001
epoch_num = 4

train_loader = torch.utils.data.DataLoader(
	dataset=train_data,
    batch_size = batch_size, shuffle = True)
test_loader = torch.utils.data.DataLoader(
	dataset=test_data,
    batch_size = batch_size, shuffle = True)

print('number of training data : ', len(train_data))
print('number of test data : ', len(test_data))

# sample data
image, label = train_data[0]
image=image.squeeze().numpy()
plt.imshow(image, cmap='gray')

Torch 기반 CNN 모델 구축

Torch는 자료를 기본적으로 Batch_size, Channel, Height, Width순으로 받으며 tensorflow는 Batch_size, Height, Weight, Channel 순으로 자료를 받는다. 이를 참고하여 본인의 모델을 만들도록 하자. 현재 1, 32, 3, 1의 의미는 1개의 채널을 입력으로 받아 32개의 채널을 돌려주며, 커널크기가 3이며, stride가 1이라는 의미이다.

class CNN(torch.nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        # N, C, H, W = 
			# Batch_size, Channel, height, width
        # tensorflow = N ,H, W, C
        self.conv1 = torch.nn.Conv2d(
        	1, 32, 3, 1, padding='same')
        self.conv2 = torch.nn.Conv2d(
        	32, 64, 3, 1, padding='same')
        self.dropout = torch.nn.Dropout2d(0.25)
        # (입력 뉴런, 출력 뉴런)
        # 7 * 7 * 64 = 3136
        self.fc1 = torch.nn.Linear(3136, 1000)    
        self.fc2 = torch.nn.Linear(1000, 10)
    
    def forward(self, x):
        x = self.conv1(x)
        x = torch.nn.functional.relu(x)
        x = torch.nn.functional.max_pool2d(x, 2)
        x = self.conv2(x)
        x = torch.nn.functional.relu(x)
        x = torch.nn.functional.max_pool2d(x, 2)
        x = self.dropout(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = torch.nn.functional.relu(x)
        x = self.fc2(x)
        output = torch.nn.functional.log_softmax(x, dim=1)
        return output

torch_model = CNN().to(device)
optimizer = torch.optim.Adam(
	torch_model.parameters(), lr = learning_rate)
criterion = torch.nn.CrossEntropyLoss()
torch_model.eval()

Torch 기반 CNN 모델 학습

torch_model.train()
i = 1
for epoch in range(epoch_num):
    for data, target in train_loader:
        data = data.to(device)
        target = target.to(device)
        optimizer.zero_grad()
        output = torch_model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if i % 1000 == 0:
            print(
            "Train Step : {}\tLoss : {:3f}".format(i, loss.item()))
        i += 1

torch_model.eval()

keras 기반 CNN 모델 구축

inputs=tf.keras.layers.Input((1,28,28))

conv1=tf.keras.layers.Conv2D(
    filters=32, kernel_size=3, strides=1, 
    padding="same", data_format='channels_first',
    activation='relu')(inputs)
# act1=tf.keras.layers.ReLU()(conv1)
pool1=tf.keras.layers.MaxPool2D((2,2), 
	data_format='channels_first')(conv1)
conv2=tf.keras.layers.Conv2D(
    filters=64, kernel_size=3, strides=1, 
    padding="same", data_format='channels_first',
    activation='relu')(pool1)
# act2=tf.keras.layers.ReLU()(conv2)
pool2=tf.keras.layers.MaxPool2D((2,2), 
	data_format='channels_first')(conv2)
drop1=tf.keras.layers.Dropout(.25)(pool2)
flat1=tf.keras.layers.Flatten()(drop1)
dense1=tf.keras.layers.Dense(1000,
	activation='relu')(flat1)
# act3=tf.keras.layers.ReLU()(dense1)
outputs=tf.keras.layers.Dense(10,
	activation='softmax')(dense1)
# outputs=tf.keras.layers.Softmax()(dense2)
keras_model= tf.keras.Model(inputs,outputs)

weight 이동(torch -> keras)

torch와 keras의 weight 의 shape이 다르므로, 동일하게 변경해주자.  

layer_info=torch_model.state_dict()
layer_names=[i for i in layer_info]

print(keras_model.layers[1],
		'\n',layer_names[0],layer_names[1])
        
keras_model.layers[1].set_weights(
	(
    	np.transpose(
    		layer_info[layer_names[0]].detach(
        	).cpu().numpy(),
        axes=(2,3,1,0)),
    	layer_info[layer_names[1]].detach(
        ).cpu().numpy()[::]
    )
)

feature Map을 통해 확인

오차는 0.000001 정도부터 발생하는 것으로 보인다. 소숫점 문제등으로 발생했으리라 추정되는데 다소 크게 느껴진다.

keras_input=image[np.newaxis,np.newaxis]
torch_input=torch.from_numpy(
	image[np.newaxis,np.newaxis]).to(device)

torch_pred=torch_model.conv1(
	torch_input)[0,0,:,:].detach().cpu().numpy()
keras_pred=keras_model.layers[1](keras_input)[0,0,:,:]

plt.imshow(torch_pred,clim=(0,1))

plt.imshow(keras_pred,clim=(0,1))

plt.imshow((torch_pred-keras_pred).numpy(),
	clim=(0,.000001))

채널을 last로 설정

inputs=tf.keras.layers.Input((28,28,1))
conv1=tf.keras.layers.Conv2D(
    filters=32, kernel_size=3, strides=1, 
    padding="same",activation='relu')(inputs)
# act1=tf.keras.layers.ReLU()(conv1)
pool1=tf.keras.layers.MaxPool2D((2,2))(conv1)
conv2=tf.keras.layers.Conv2D(
    filters=64, kernel_size=3, strides=1, 
    padding="same",activation='relu')(pool1)
# act2=tf.keras.layers.ReLU()(conv2)
pool2=tf.keras.layers.MaxPool2D((2,2))(conv2)
drop1=tf.keras.layers.Dropout(.25)(pool2)
flat1=tf.keras.layers.Flatten()(drop1)
dense1=tf.keras.layers.Dense(1000,activation='relu')(flat1)
# act3=tf.keras.layers.ReLU()(dense1)
outputs=tf.keras.layers.Dense(10,activation='softmax')(dense1)
# outputs=tf.keras.layers.Softmax()(dense2)
keras_model= tf.keras.Model(inputs,outputs)


print(keras_model.layers[1],'\n',
	layer_names[0],layer_names[1])
keras_model.layers[1].set_weights((
    np.transpose(layer_info[
    	layer_names[0]].detach().cpu().numpy(),
                 axes=(2,3,1,0)),
    layer_info[layer_names[1]].detach().cpu().numpy()[::]))
    
keras_input=image[np.newaxis,:,:,np.newaxis]
torch_input=torch.from_numpy(
	image[np.newaxis,np.newaxis]).to(device)

torch_pred=torch_model.conv1(
	torch_input)[0,0,:,:].detach().cpu().numpy()
keras_pred=keras_model.layers[1](keras_input)[0,:,:,0]

plt.imshow(torch_pred,clim=(0,1))
plt.imshow(keras_pred,clim=(0,1))