查看原文
其他

PyTorch 教程-针对风格转换的优化过程

点击关注👉 Python架构师 2024-02-11
整理:python架构师

我们现在已经有了三张图片,现在可以执行我们的优化过程。为了执行优化过程,我们需要按照以下步骤进行操作:

步骤1:

在第一步中,我们定义了一些基本参数,这些参数帮助我们可视化训练过程并简化训练过程。第一个参数在每次迭代时显示目标图像,以便我们可以检查优化过程。我们使用Adam优化器定义了目标图像,并设置了学习率。最后但同样重要的是,我们定义了训练过程应该进行的优化步骤数。

由于训练过程可能需要很长时间才能完成,因此我们需要在结果和时间效率之间保持平衡。因此,我们将定义我们的步骤,并在我们的情况下将步骤限制为2100。
show_every=300 optimizer=optim.Adam([target],lr=0.003) steps=2100

步骤2:

现在,我们实现了一些用于数据可视化的代码行。我们定义了一个图像数组,该数组将在训练过程中存储目标图像。在训练过程结束后,我们可以通过这些图像创建一个视频,以了解样式和内容图像如何结合优化目标图像。我们将解开目标图像的形状。
height,width,channels=im_convert(target).shape image_array=np.empty(shape=(300,height,width,channels))
我们将定义一个捕获帧,它帮助我们每次捕获一帧。最后,我们将定义一个计数器变量,它将跟踪数组索引。
capture_frame=steps/300 counter=0

资源分享

👉点击领取:最全Python资料合集

优化迭代过程

#Defining a loop statement from 1 to steps+1 for ii in range(1,steps+1): #To ensure that our loop runs for the defined number of steps # Extracting feature for our current target image target_features=get_features(target,vgg) #Calculating the content loss for the iteration content_loss=torch.mean((target_features['conv4_2']content_features['conv4_2'])**2) #Initializing style loss style_loss=0 #The style loss is the result of a combine loss from five different layer within our model. #For this reason we iterate through the five style features to get the error at each layer. for layer in style_weights: #Collecting the target feature for the specific layer from the target feature variable target_feature=target_features[layer] #Applying gram matrix function to our target feature target_gram=gram_matrix(target_feature) #Getting style_gram value for our style image from the style grams variable style_gram=style_grams[layer] #Calculating the layer style loss as content loss layer_style_loss=style_weights[layer]*torch.mean((target_gram-style_gram)**2) #Obtaining feature dimensions _,d,h,w=target_feature.shape #Calculating total style loss style_loss += layer_style_loss/(d*h*w) #Calculating total loss total_loss=content_weight*content_loss+style_weight*style_loss #Using the optimizer to update parameters within our target image optimizer.zero_grad() total_loss.backward() optimizer.step() #Process for visualization throughout the training process #Comparing the iteration variable with our show every if ii % show_every==0: #Printing total loss print('Total loss:',total_loss.item()) #Printing the iteration print('Iteration',ii) #Printting the target images plt.imshow(im_convert(target)) #Removing the axis on the image plt.axis('off') # Showing image plt.show() #Comparing the iteration variable with our capture frame variable if ii%capture_frame==0: # Capturing a frame at every 700 iteration #Storing the target image into the image_array image_array[counter]=im_convert(target) # Increment in the counter variable counter=counter+1

当我们运行代码时,它将给出预期的输出:

绘制内容、样式和最终目标图像

#Making a grid arrangement with a single row and three columns for our three images fig,(ax1,ax2,ax3)=plt.subplots(1,3,figsize=(20,10)) #Plotting content image ax1.imshow(im_convert(content)) ax1.axis('off') #Plotting style image ax2.imshow(im_convert(style)) ax2.axis('off') #Plotting target image ax3.imshow(im_convert(target)) ax3.axis('off')

完整代码

#Required Libraries import torch import torch.optim as optim from torchvision import transforms, models from PIL import Image import matplotlib.pyplot as plt import numpy as np #Creating Model vgg=models.vgg19(pretrained=True).features for param in vgg.parameters(): param.requires_grad_(False) #Add model to device device=torch.device("cuda" if torch.cuda.is_available() else "cpu") vgg.to(device) #Load Iamge def load_image(img_path,max_size=400,shape=None): image=Image.open(img_path).convert('RGB') if max(image.size)>max_size: size=max_size else: size=max(image.size) if shape is not None: size=shape in_transform=transforms.Compose([ transforms.Resize(size), transforms.ToTensor(), transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)) ]) image=in_transform(image).unsqueeze(0) return image content=load_image('ab.jpg').to(device) style=load_image('abc.jpg',shape=content.shape[-2:]).to(device) #Image Conversion def im_convert(tensor): image=tensor.cpu().clone().detach().numpy() image=image.squeeze() image=image.transpose(1,2,0) image=image*np.array((0.5,0.5,0.5))+np.array((0.5,0.5,0.5)) image=image.clip(0,1) return image #Plotting Images fig, (ax1,ax2)=plt.subplots(1,2,figsize=(20,10)) ax1.imshow(im_convert(content)) ax1.axis('off') ax2.imshow(im_convert(style)) ax2.axis('off') #Getting Features def get_features(image,model): layers={'0':'conv1_1', '5':'conv2_1', '10':'conv3_1', '19':'conv4_1', '21':'conv4_2', '28':'conv5_1',} features={} for name, layer in model._modules.items(): image=layer(image) if name in layers: features[layers[name]]=image return features #Making content and style features content_features=get_features(content,vgg) style_features=get_features(style, vgg) #Creating gram matrix def gram_matrix(tensor): _,d,h,w=tensor.size() tensor=tensor.view(d,h*w) gram=torch.mm(tensor,tensor.t()) return gram #Creating style grams style_grams={layer:gram_matrix(style_features[layer]) for layer in style_features} #Initializing style weights style_weights={'conv1_1':1., 'conv2_1':0.75, 'conv3_1':0.2, 'conv4_1':0.2, 'conv5_1':0.2} content_weight=1 style_weight=1e6 target=content.clone().requires_grad_(True).to(device) #Performing optimization show_every=300 optimizer=optim.Adam([target],lr=0.003) steps=2100 height,width,channels=im_convert(target).shape image_array=np.empty(shape=(300,height,width,channels)) capture_frame=steps/300 counter=0 for ii in range(1,steps+1): target_features=get_features(target,vgg) content_loss=torch.mean((target_features['conv4_2']-content_features['conv4_2'])**2) style_loss=0 for layer in style_weights: target_feature=target_features[layer] target_gram=gram_matrix(target_feature) style_gram=style_grams[layer] layer_style_loss=style_weights[layer]*torch.mean((target_gram-style_gram)**2) _,d,h,w=target_feature.shape style_loss += layer_style_loss/(d*h*w) total_loss=content_weight*content_loss+style_weight*style_loss optimizer.zero_grad() total_loss.backward() optimizer.step() #Plotting output images if ii % show_every==0: print('Total loss:',total_loss.item()) print('Iteration',ii) plt.imshow(im_convert(target)) plt.axis('off') plt.show() if ii%capture_frame==0: image_array[counter]=im_convert(target) counter=counter+1 #Plotting content, style and target images fig,(ax1,ax2,ax3)=plt.subplots(1,3,figsize=(20,10)) ax1.imshow(im_convert(content)) ax1.axis('off') ax2.imshow(im_convert(style)) ax2.axis('off') ax3.imshow(im_convert(target)) ax3.axis('off')

输出:


 
热门推荐

继续滑动看下一个

PyTorch 教程-针对风格转换的优化过程

点击关注👉 Python架构师
向上滑动看下一个

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存