博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
Matlab图像识别/检索系列(7)-10行代码完成深度学习网络之取中间层数据作为特征...
阅读量:6945 次
发布时间:2019-06-27

本文共 5513 字,大约阅读时间需要 18 分钟。

现在,大家都意识到深度神经网络在图像特征提取方面具有很强的能力,尽管其解释性不强,尽管人们对它的内部原理不十分清楚。那么能不能取出网络中某层数据作为图像特征而进行自己定制的其它处理呢?答案当然是肯定的。在Matlab2017b中,从网络取数据主要有两种方法。一是使用Neural Network Toolboxactivations函数,一是导入网络后直接使用网络某层的名字。

1.使用activations函数

%exam1.m    load digitTrainSet;    %创建CNN网络    layers = [imageInputLayer([28 28 1],'Normalization','none');                        convolution2dLayer(5,20);                        reluLayer();                        maxPooling2dLayer(2,'Stride',2);                        convolution2dLayer(5,16);                        reluLayer();                        maxPooling2dLayer(2,'Stride',2);                        fullyConnectedLayer(256);                        reluLayer();                        fullyConnectedLayer(10);                        softmaxLayer();                        classificationLayer()];    opts = trainingOptions('sgdm');    %训练CNN网络    net = trainNetwork(XTrain,TTrain,layers,opts);    %提取输入X的第6层输出数据    trainFeatures = activations(net,XTrain,6);    %训练多分类模型    svm = fitcecoc(trainFeatures,TTrain);    load digitTestSet;    %提取测试数据的第6层输出数据    testFeatures = activations(net,XTest,6);    %预测测试数据所属类别    testPredictions = predict(svm,testFeatures);        %对测试数据标签进行one-hot编码    ttest = dummyvar(double(TTest))' ;      %对测试数据预测标签进行one-hot编码    tpredictions = dummyvar(double(testPredictions))';    %对混淆矩阵做图    plotconfusion(ttest,tpredictions);    %计算准确率,即实际标签和预测标签相同个数的和/测试数据总数    accuracy = sum(TTest == testPredictions)/numel(TTest);

函数activations的用法是:

features = activations(net,X,layer,Name,Value)

参数中net表示创建的网络,X表示输入数据,layer表示层数,NameValue用来设置参数的值。函数dummyvar来自Statistics and Machine Learning Toolbox,其作用是将每个类别标签转换为只含有0和1的向量,即one-hot编码。如类别1和9分别转换为[0 1 0 0 0 0 0 0 0 0]和[0 0 0 0 0 0 0 0 0 1],这里共有10个类,类标签为0~9,每个类用10个0或1的数字表示,第几类用对应位置数字为1其它为0表示。

2.直接使用层的名字

在Matlab2017b中预置了一些常用深度网络,可以函数的形式直接调用,如alexnet、vgg16、vgg19和googlenet,可在Neural Network ToolboxFunctions中查看。第一次使用需要在Matlab主页工具栏的附加功能中下载。调用形式很简单,代码如下。

%exam2.m    unzip('MerchData.zip');    %创建图像集    images = imageDatastore('MerchData',...            'IncludeSubfolders',true,...            'LabelSource','foldernames');    %划分图像集    [trainingImages,testImages] = splitEachLabel(images,0.7,'randomized');    %获取训练图像总数    numTrainImages = numel(trainingImages.Labels);    %在图像总数中随机取16个数    idx = randperm(numTrainImages,16);    %显示16幅图像    figure    for i = 1:16            subplot(4,4,i)            I = readimage(trainingImages,idx(i));            imshow(I)    end    %调用alexnet网络    net = alexnet;    %设值要用的层为第7个全连接层    layer = 'fc7';    %提取训练图像fc7层数据    trainingFeatures = activations(net,trainingImages,layer);    %提取测试图像fc7层数据    testFeatures = activations(net,testImages,layer);    %拟合训练图像多分类器    classifier = fitcecoc(trainingFeatures,trainingLabels, 'FitPosterior',1);    %预测测试图像的类别标签    predictedLabels = predict(classifier,testFeatures);    %[label,NegLoss,PBScore,Posterior] = predict(classifier,testFeatures);    idx = [1 5 10 15];    figure    for i = 1:numel(idx)            subplot(2,2,i)            I = readimage(testImages,idx(i));            label = predictedLabels(idx(i));            imshow(I);            title(char(label));    end    accuracy = mean(predictedLabels == testLabels);

需要注意的是,要使用的网络层的名字可以在导入网络后,用调试模式查看net变量的值,进一步看网络每层的名字。如下图:

Matlab图像识别/检索系列(7)-10行代码完成深度学习网络之取中间层数据作为特征
然后,查看第20层,如下图:
Matlab图像识别/检索系列(7)-10行代码完成深度学习网络之取中间层数据作为特征
可见其层的名字为‘fc7’。
也可以查看Matlab帮助文档中alexnet的网络结构,或者在Matlab的命令行窗口输入

net = alexnet    net.Layers

结果显示如下:

ans =   25x1 Layer array with layers:     1   'data'     Image Input                   227x227x3 images with 'zerocenter' normalization     2   'conv1'    Convolution                   96 11x11x3 convolutions with stride [4  4] and padding [0  0]     3   'relu1'    ReLU                          ReLU     4   'norm1'    Cross Channel Normalization   cross channel normalization with 5 channels per element     5   'pool1'    Max Pooling                   3x3 max pooling with stride [2  2] and padding [0  0]     6   'conv2'    Convolution                   256 5x5x48 convolutions with stride [1  1] and padding [2  2]     7   'relu2'    ReLU                          ReLU     8   'norm2'    Cross Channel Normalization   cross channel normalization with 5 channels per element     9   'pool2'    Max Pooling                   3x3 max pooling with stride [2  2] and padding [0  0]    10   'conv3'    Convolution                   384 3x3x256 convolutions with stride [1  1] and padding [1  1]    11   'relu3'    ReLU                          ReLU    12   'conv4'    Convolution                   384 3x3x192 convolutions with stride [1  1] and padding [1  1]    13   'relu4'    ReLU                          ReLU    14   'conv5'    Convolution                   256 3x3x192 convolutions with stride [1  1] and padding [1  1]    15   'relu5'    ReLU                          ReLU    16   'pool5'    Max Pooling                   3x3 max pooling with stride [2  2] and padding [0  0]    17   'fc6'      Fully Connected               4096 fully connected layer    18   'relu6'    ReLU                          ReLU    19   'drop6'    Dropout                       50% dropout    20   'fc7'      Fully Connected               4096 fully connected layer    21   'relu7'    ReLU                          ReLU    22   'drop7'    Dropout                       50% dropout    23   'fc8'      Fully Connected               1000 fully connected layer    24   'prob'     Softmax                       softmax    25   'output'   Classification Output         crossentropyex with 'tench', 'goldfish', and 998 other classes

转载于:https://blog.51cto.com/8764888/2053964

你可能感兴趣的文章
我的友情链接
查看>>
自己写的一个javascript首页图片切换组件
查看>>
Linux系统-tcpdump常用抓包命令
查看>>
MySQL 5.7新特性:在线开启和关闭基于GTID的复制
查看>>
XSS研究3-来自内部的XSS***的防范
查看>>
LintCode刷题(First Day) A+B问题
查看>>
zabbix-3.4-快速入门
查看>>
学习笔记TF053:循环神经网络,TensorFlow Model Zoo,强化学习,深度森林,深度学习艺术...
查看>>
trunc函数的用法
查看>>
Python TCP编程 Errno 98: Address already in use
查看>>
python除了利用arrow计算时间之外,还可以用datetime计算
查看>>
Rpm+二进制包+源码包@聊聊
查看>>
数组拷贝之System.arraycopy学习
查看>>
Java实现的有道云笔记图片批量下载工具
查看>>
单例模式你会几种写法?
查看>>
配置Tomcat监听80端口 配置Tomcat虚拟主机 Tomcat日志
查看>>
用笨办法学习编程
查看>>
python 笔记 之 for循环 打印 9x9乘法表
查看>>
Spring Cloud云架构 - SSO单点登录之OAuth2.0登录认证(1)
查看>>
Netty 基础认识 (二)
查看>>