当前位置:   article > 正文

备战数学建模50-终结篇(攻坚站15)_数学建模降噪题目

数学建模降噪题目

今天应该数学建模的最后一篇博文了,我们好好梳理一下,对缺少的知识点做一个汇总,希望我们在国赛能取得一个好成绩,也希望看到这篇博客的同学都有好运,人生中的每一段旅程都有意义,希望我们都能享受过程并取得一个满意的结果,道阻且长,行则将至,向上吧,年轻人!

目录

一、互信息

1.1、互信息基本概念

1.2、互信息计算与matlab实现 

二、Mann-Kendall 检验

2.1、M-K检验的理论知识

2.2、M-K检验的matlab实现

三、小波分析

3.1、小波分析基本理论

3.2、小波分析去噪matlab实现

3.3、小波分析周期变化分析matlab实现

四、机器学习的常见树模型

 4.1、AdaBoost模型的基本思想和算法流程

4.2、AdaBoost模型的具体实现

4.3、GBDT(梯度提升决策树)基本理论

4.4、XGBoost的基本理论

4.5、XGBoost的实践(分类和回归)

五、支持向量机SVM

5.1、支持向量机基本理论

5.2、SVM实现持向量机回归(SVR)模型

5.3、SVM实现分类任务

5.4、SVM实现时间序列预测

六、基于prophet的时间序列预测

6.1、prophet理论与实现一

​6.2、 prophet理论与实现二

七、图像处理方法

7.1、图像的去噪、增强

7.2、图像特征提取


一、互信息

1.1、互信息基本概念

1.2、互信息计算与matlab实现 

可以选择与产品辛烷值(RON)非线性相关度较高的自变量,理论上互信息值较大的变量对于因变量的影响是较大的,是具有代表性的变量。但实际上,这些变量间相关性较高,即这些变量代表的可能是同一性质的操作,这一类性质的改变对于产品辛烷值的影响较大,所以它们的互信息量都处于较大的范围。为了使得选择的主要变量尽可能的具有代表性和独立性,所以需要在互信息分析的基础上去除掉自变量之间相关性高的部分,从而保证选择的变量不仅对因变量的相关度较高,而且各变量之间尽可能不相关,能够较为全面地代表操作变量整体。

我们这里使用matlab计算变量与产品辛烷值的互信息,选择互信息大的,在根据变量相关性,剔除相关性强的变量。

  1. %计算两列向量之间的互信息
  2. %u1:输入计算的向量1
  3. %u2:输入计算的向量2
  4. %wind_size:向量的长度
  5. function mi = calmi(u1, u2, wind_size)
  6. x = [u1, u2];
  7. n = wind_size;
  8. [xrow, xcol] = size(x);
  9. bin = zeros(xrow,xcol);
  10. pmf = zeros(n, 2);
  11. for i = 1:2
  12. minx = min(x(:,i));
  13. maxx = max(x(:,i));
  14. binwidth = (maxx - minx) / n;
  15. edges = minx + binwidth*(0:n);
  16. histcEdges = [-Inf edges(2:end-1) Inf];
  17. [occur,bin(:,i)] = histc(x(:,i),histcEdges,1); %通过直方图方式计算单个向量的直方图分布
  18. pmf(:,i) = occur(1:n)./xrow;
  19. end
  20. %计算u1和u2的联合概率密度
  21. jointOccur = accumarray(bin,1,[n,n]); %(xi,yi)两个数据同时落入n*n等分方格中的数量即为联合概率密度
  22. jointPmf = jointOccur./xrow;
  23. Hx = -(pmf(:,1))'*log2(pmf(:,1)+eps);
  24. Hy = -(pmf(:,2))'*log2(pmf(:,2)+eps);
  25. Hxy = -(jointPmf(:))'*log2(jointPmf(:)+eps);
  26. MI = Hx+Hy-Hxy;
  27. mi = MI/sqrt(Hx*Hy);
  1. clc
  2. clear
  3. load('origin.mat')
  4. wind_size = size(data,1)
  5. mi = zeros(354,1) ;
  6. for i = 1 : 354
  7. u1 = data(:,i);
  8. u2 = data(:,end);
  9. mi(i) = calmi(u1, u2, wind_size);
  10. end
  11. [res,index] = sort(mi,'descend')

二、Mann-Kendall 检验

2.1、M-K检验的理论知识

 Mann-Kendall 检验法(M-K 法)是用于提取序列变化趋势的有效工具,也是“应用气候学实习”课程中重要的授课内容。

M-K 检验法最初由曼(H. B. Mann)和肯德尔(M.G. Kendall)提出了原理并发展了这一方法,是世界气象组织推荐的用于提取序列变化趋势的有效工具  。M-K 检验法不受个别异常值的干扰,能够客观反映时间序列趋势,目前已经被广泛用于气候参数和水文序列的分析中。M-K 法可以根据输出的两个序列(UF和 UB)明确突变的时段和区域。

2.2、M-K检验的matlab实现

  1. Data = [1961,4.7;
  2. 1962,4.98;
  3. 1963,5.39;
  4. 1964,5.72;
  5. 1965,4.88;
  6. 1966,5.02;
  7. 1967,5.1;
  8. 1968,5.19;
  9. 1969,4.31;
  10. 1970,4.86;
  11. 1971,5.28;
  12. 1972,4.9;
  13. 1973,5.23;
  14. 1974,4.6;
  15. 1975,5.88;
  16. 1976,4.58;
  17. 1977,5.37;
  18. 1978,5.24;
  19. 1979,4.91;
  20. 1980,4.84;
  21. 1981,4.98;
  22. 1982,5.54;
  23. 1983,5.67;
  24. 1984,4.31;
  25. 1985,4.67;
  26. 1986,4.49;
  27. 1987,4.96;
  28. 1988,5.18;
  29. 1989,5.41;
  30. 1990,6.06;
  31. 1991,5.4;
  32. 1992,5.38;
  33. 1993,5.3;
  34. 1994,6.18;
  35. 1995,5.58;
  36. 1996,5.2;
  37. 1997,5.8;
  38. 1998,6.81;
  39. 1999,6.47;
  40. 2000,6;
  41. 2001,6.16;
  42. 2002,6.42;
  43. 2003,6.7;
  44. 2004,6.18;
  45. 2005,5.52;
  46. 2006,6.46;
  47. 2007,6.73;
  48. 2008,5.86;
  49. 2009,5.61;
  50. 2010,5.31;
  51. 2011,5.69;
  52. 2012,5.05;
  53. 2013,5.13;
  54. 2014,6.09;
  55. 2015,6.13
  56. ] ;
  57. y = Data(:,2);%平均温度序列
  58. Sk = zeros(size(y)); % 定义累计量序列 Sk,长度 = y,初始值 =0,Sk(1) =0
  59. UFk = zeros(size(y)); % 定义统计量 UFk,长度 = y,初始值 =0,UFk(1) =0
  60. s = 0; % 定义 Sk 序列的元素 s
  61. for i =2:length(y)
  62. for j =1:i
  63. if y(i) > y(j)
  64. s = s +1;
  65. else
  66. s = s +0;
  67. end
  68. end
  69. Sk(i) = s;
  70. E = i* (i - 1) /4; % Sk(i)的均值,见式(3)
  71. Var = i* (i - 1)* (2* i +5) /72; % Sk(i)方差,见式(3)
  72. UFk(i) = (Sk(i)- E) /sqrt(Var);% 正序列 UF 值,见式(2)
  73. end
  74. Sk2 = zeros(size(y)); % 定义逆序累计量序列 Sk2,长度= y,初始值 =0,Sk(2) =0
  75. UBk = zeros(size(y)); % 定义逆序统计量 UBk,长度 = y,初始值 =0,UBk(1) =0
  76. s =0;
  77. y2 = flipud(y); % 按时间序列逆转平均温度序列
  78. for i =2:length(y2)
  79. for j =1:i
  80. if y2(i) > y2(j)
  81. s = s +1;
  82. else
  83. s = s +0;
  84. end
  85. end
  86. Sk2(i) = s;
  87. E = i* (i - 1) /4; %均值
  88. Var = i* (i - 1)* (2* i +5) /72; %方差
  89. UBk(i) = 0 - (Sk2(i) - E) /sqrt(Var);
  90. end
  91. UBk2 = flipud(UBk); %逆序列 UB 值
  92. x = Data(:,1);%年份序列
  93. n = length(x);%年份序列的长度
  94. figure %做图
  95. plot(x,UFk,'r-','linewidth',1.5);%画 UF 线
  96. hold on
  97. plot(x,UBk2,'b-.','linewidth',1.5);%画 UB 线
  98. plot(x,1.96* ones(n,1),'k:','linewidth',1);
  99. axis([min(x),max(x),-5,5]);%设置 X 轴范围和间距
  100. legend('UF 统计量','UB 统计量','0.05 显著水平');% 设置图例
  101. xlabel('年 Year','FontName','TimesNewRoman','FontSize',10);%X 轴标题
  102. ylabel('统计量 MK Value','FontName','TimesNewRoman','Fontsize',10);%Y 轴标题
  103. hold on
  104. plot(x,-1.96 * ones(n,1),'k:','linewidth',1);
  105. plot(x,0 * ones(n,1),'k-. ','linewidth',1);% 图片绘制

由图可知,该地区 1961 ~2015 年气温呈显著上升趋势,UF 和 UB 统计量有交点
且交点在置信直线范围之间,表明气温在 1989 年前后发生了突变。

三、小波分析

3.1、小波分析基本理论

第一:去噪

小波分析的重要应用之一就是用于信号降噪。我们知道,一个含噪的一维信号模型可以表示为下图。其中s(k)为含噪信号,f(k)为有用信号,e(k)为噪声信号。这里我们认为e(k)是一个 1 级高斯白噪声,通常表现为高频信号,而实际工程中f(k)通常为低频信号或者是一些比较稳定的信号。

因此我们可按如下的方法进行降噪处理。首先对信号进行小波分解, 一般地,噪声信号多包含在具有较高频率的细节中,从而,可利用门限阈值等形式对所分解的小波系数进行处理,然后对信号进行小波重构即可达到对信号降噪的目的。对信号降噪实质上是抑制信号中的无用部分,恢复信号中有用部分的过程。

小波信号降噪一般分为以下三个步骤:

(1)确定小波分解的层数,对信号进行分解。

(2)确定各个分解层下细节信号的阈值,对细节信号进行阈值量化处理。

(3)利用阈值处理后的细节信号和逼近信号进行重构,得到降噪后的信号。

第二:周期变化分析

3.2、小波分析去噪matlab实现

  1. clc;
  2. clear all;
  3. % 载入信号
  4. % Load electrical signal and select a part of it.
  5. load leleccum;
  6. indx = 2600:3100;
  7. %装载采集的信号
  8. x = leleccum(indx);
  9. lx=length(x);
  10. t=[0:1:length(x)-1]';
  11. %% 绘制监测所得信号
  12. subplot(2,2,1);
  13. plot(t,x);
  14. title('原始信号');
  15. grid on
  16. %% 用db1小波对原始信号进行3层分解并提取小波系数
  17. [c,l]=wavedec(x,3,'db1');
  18. ca3=appcoef(c,l,'db1',3);
  19. cd3=detcoef(c,l,3);
  20. cd2=detcoef(c,l,2);
  21. cd1=detcoef(c,l,1);
  22. %% 对信号进行强制去噪处理并图示
  23. cdd3=zeros(1,length(cd3));
  24. cdd2=zeros(1,length(cd2));
  25. cdd1=zeros(1,length(cd1));
  26. c1=[ca3,cdd3,cdd2,cdd1];
  27. x1=waverec(c1,l,'db1');
  28. subplot(2,2,2);
  29. plot(x1);
  30. title('强制去噪后信号');
  31. grid on
  32. %% 默认阈值对信号去噪并图示%%
  33. %用ddencmp( )函数获得信号的默认阈值,使用wdencmp( )函数实现去噪过程
  34. [thr,sorh,keepapp]=ddencmp('den','wv',x);
  35. x2=wdencmp('gbl',c,l,'db1',3,thr,sorh,keepapp);
  36. subplot(2,2,3);
  37. plot(x2);
  38. title('默认阈值去噪后信号');
  39. grid on
  40. %% 给定的软阈值进行去噪处理并图示
  41. wname = 'db3'; lev = 5;
  42. [c,l] = wavedec(x,lev,wname);
  43. alpha = 1.5; m = l(1);
  44. [thr,nkeep] = wdcbm(c,l,alpha,m)
  45. [xd,cxd,lxd,perf0,perfl2] = wdencmp('lvd',c,l,wname,lev,thr,'h');
  46. subplot(2,2,4);
  47. plot(xd);
  48. title('给定软阈值去噪后信号');

3.3、小波分析周期变化分析matlab实现

  1. %1.xiaozao函数,是需要对标准化的序列进行消除数据噪音分析;
  2. %2.Db3函数,是对数列进行Db3趋势分析;
  3. %3.period函数,是求得时间序列的实部和模的平方。
  4. %其中周期变化图是实部的等值线图
  5. %而小波方差是模的平方的算数平均。
  6. clc
  7. clear
  8. close all;
  9. load 暴雨量.mat
  10. start_year=1958
  11. a=s(:,1);
  12. b=zscore(a);
  13. scales=[1:1:32];
  14. %进行连续小波变换得到小波系数矩阵,选择复morlet小波函数
  15. wf=cwt(b,scales,'cmor1-1'); %计算小波系数
  16. shibu=real(wf);% 求得系数的实部
  17. mo=abs(wf); %计算小波系数模的绝对值
  18. mofang=mo.^2; %计算小波系数的模方
  19. fangcha=mean(mofang,2); %计算小波方差,小波方差是模的平方的算数平均
  20. %**********画小波实部*************
  21. figure(1);
  22. j = j + 1;
  23. % subplot(121);
  24. % axis([1961,2015,0,50]);
  25. width=713;%宽度,像素数
  26. height=493;%高度
  27. left=300;%距屏幕左下角水平距离
  28. bottem=200;%距屏幕左下角垂直距离
  29. set(gcf,'position',[left,bottem,width,height])
  30. contourf(shibu,10,'-');
  31. colormap('Jet');
  32. colorbar;
  33. hold on
  34. set(gca,'FontSize',13,'Fontname', 'Times New Roman','Fontweight','bold');
  35. xlabel('Year/a','FontName','Times new roman','FontSize',16,'Fontweight','bold');
  36. ylabel('Scales/year','FontName','Times new roman','FontSize',16,'Fontweight','bold');
  37. %set(gca,'XTick',1965:5:2017);
  38. %set(gca,'XTicklabel', 1962:1:2017); %更新XTickLabel
  39. set(gca,'xlim',[1 length(s)],'XTick',1:roundn(length(s)/5,0):length(s),'XTickLabel',start_year:roundn(length(s)/5,0):(start_year+length(s)-1))%修改横坐标的范围
  40. title('(a)','Fontname','Times new roman','FontSize',18,'Fontweight','bold','position',[-4,52]);
  41. %saveas(gca,[path_out5,num2str(j)],'png');
  42. %close;
  43. %********小波方差*************%
  44. figure(2);
  45. j = j + 1;
  46. width=713;%宽度,像素数
  47. height=493;%高度
  48. left=300;%距屏幕左下角水平距离
  49. bottem=200;%距屏幕左下角垂直距离
  50. set(gcf,'position',[left,bottem,width,height])
  51. % subplot(122);
  52. plot(fangcha,'k-','linewidth',1.5);
  53. set(gca,'FontName','Times new roman','FontSize',16,'Fontweight','bold');
  54. xlabel('Scales/year','FontName','Times new roman','FontSize',16,'Fontweight','bold');
  55. ylabel('Variance','FontName','Times new roman','FontSize',16,'Fontweight','bold');
  56. title('(b)','Fontname','Times new roman','FontSize',18,'Fontweight','bold','position',[-3,1.8]);
  57. set(gca,'XTick',0:5:31);
  58. axis([1 32 0 2]);
  59. grid on;
  60. %saveas(gca,[path_out5,num2str(j)],'png');
  61. %close;
  62. %********小波模**************%
  63. figure(3);
  64. j = j + 1;
  65. %subplot(122);
  66. width=713;%宽度,像素数
  67. height=493;%高度
  68. left=300;%距屏幕左下角水平距离
  69. bottem=200;%距屏幕左下角垂直距离
  70. set(gcf,'position',[left,bottem,width,height])
  71. contourf(mo,10,'-');
  72. colormap('Jet');
  73. colorbar;
  74. hold on
  75. set(gca,'FontName','Times new roman','FontSize',13,'Fontweight','bold');
  76. xlabel('Year','FontName','Times new roman','FontSize',16,'Fontweight','bold');
  77. ylabel('Scales/year','FontName','Times new roman','FontSize',16,'Fontweight','bold');
  78. title('(c)','Fontname','Times new roman','FontSize',18,'Fontweight','bold','position',[-4,52]);
  79. set(gca,'xlim',[1 length(s)],'XTick',1:roundn(length(s)/5,0):length(s),'XTickLabel',start_year:roundn(length(s)/5,0):(start_year+length(s)-1))
  80. %saveas(gca,[path_out5,num2str(j)],'png');
  81. %close;
  82. %********小波模方**************%
  83. figure(4);
  84. j = j + 1;
  85. %subplot(122);
  86. width=713;%宽度,像素数
  87. height=493;%高度
  88. left=300;%距屏幕左下角水平距离
  89. bottem=200;%距屏幕左下角垂直距离
  90. set(gcf,'position',[left,bottem,width,height])
  91. contourf(mofang,10,'-');
  92. colormap('Jet');
  93. colorbar;
  94. hold on
  95. set(gca,'FontName','Times new roman','FontSize',13,'Fontweight','bold');
  96. xlabel('Year','FontName','Times new roman','FontSize',16,'Fontweight','bold');
  97. ylabel('Scales/year','FontName','Times new roman','FontSize',16,'Fontweight','bold');
  98. set(gca,'xlim',[1 length(s)],'XTick',1:roundn(length(s)/5,0):length(s),'XTickLabel',start_year:roundn(length(s)/5,0):(start_year+length(s)-1))
  99. %saveas(gca,[path_out5,num2str(j)],'png');

小波实部等值线图如下:

 小波方差图如下:

小波模等值线图如下:

小波模方的等值线图:

四、机器学习的常见树模型

由于之前的博客介绍了决策树和随机森林,这次主要介绍AdaBoost,GDBT,XGBoost,LigthGBM四种模型的理论及实现过程。

 4.1、AdaBoost模型的基本思想和算法流程

Adaboost是一种迭代算法,其核心思想是针对同一个训练集训练不同的分类器(弱分类器),然后把这些弱分类器集合起来,构成一个更强的最终分类器(强分类器)。

我们看一下Adaboost的模型,就是给分类误差小的分类器分配更多的权值,给分类误差大的分类器分配更大的权值。

 我们看一下Adaboost的具体实现流程,首先输入训练样本x和y,然后初始化训练样本的权值分布,具体如下:

 接下来进行遍历得到所有的弱分类器和所有的权值,具体如下:

 最后得到最终的分类器如下:

该算法其实是一个简单的弱分类算法提升过程,这个过程通过不断的训练,可以提高对数据的分类能力。整个过程如下所示:

1. 先通过对N个训练样本的学习得到第一个弱分类器

2. 将分错的样本和其他的新数据一起构成一个新的N个的训练样本,通过对这个样本的学习得到第二个弱分类器 ;

3. 将1和2都分错了的样本加上其他的新样本构成另一个新的N个的训练样本,通过对这个样本的学习得到第三个弱分类器;

4. 最终经过提升的强分类器。即某个数据被分为哪一类要由各分类器权值决定。

4.2、AdaBoost模型的具体实现

下面使用python实现该模型的算法,完成一个二分类任务,我们使用Sklearn中的AdaBoost接口进行实践,具体如下:
 

  1. import numpy as np
  2. import matplotlib.pyplot as plt
  3. from sklearn.ensemble import AdaBoostClassifier
  4. from sklearn.tree import DecisionTreeClassifier
  5. from sklearn.datasets import make_gaussian_quantiles
  6. # 用make_gaussian_quantiles生成多组多维正态分布的数据
  7. # 生成2维正态分布,设定样本数1000,协方差2
  8. #其中x1是200行2列的数据,y1是200个输出样本表示分类结果
  9. x1, y1 = make_gaussian_quantiles(
  10. cov=2., n_samples=200, n_features=2, n_classes=2, shuffle=True, random_state=1)
  11. # 为了增加样本分布的复杂度,再生成一个数据分布
  12. #x2是300行2列的数据,y2是300个输出样本表示分类结果
  13. x2, y2 = make_gaussian_quantiles(mean=(
  14. 3, 3), cov=1.5, n_samples=300, n_features=2, n_classes=2, shuffle=True, random_state=1)
  15. #合并X水平合并,y竖直合并,然后按0和1不同颜色绘制散点图
  16. X=np.vstack((x1,x2))
  17. y=np.hstack((y1,1-y2))
  18. # 绘制生成数据
  19. plt.scatter(X[:,0],X[:,1],c=y)
  20. plt.show()
  21. #设定弱分类器CART
  22. weakClassifier=DecisionTreeClassifier(max_depth=2)
  23. #构建模型并进行训练
  24. clf=AdaBoostClassifier(base_estimator=weakClassifier,algorithm='SAMME',n_estimators=300,learning_rate=0.8)
  25. clf.fit(X, y)
  26. # 模型预测
  27. x1_min=X[:,0].min()-1
  28. x1_max=X[:,0].max()+1
  29. x2_min=X[:,1].min()-1
  30. x2_max=X[:,1].max()+1
  31. x1_,x2_=np.meshgrid(np.arange(x1_min,x1_max,0.02),np.arange(x2_min,x2_max,0.02))
  32. y_=clf.predict(np.c_[x1_.ravel(),x2_.ravel()])
  33. print(y)
  34. # 结果绘制
  35. #绘制分类效果
  36. y_=y_.reshape(x1_.shape)
  37. plt.contourf(x1_,x2_,y_,cmap=plt.cm.Paired)
  38. plt.scatter(X[:,0],X[:,1],c=y)
  39. plt.show()

原始的散点图与分类后的效果图如下:

4.3、GBDT(梯度提升决策树)基本理论

我们看一下GDBT模型,就是梯度提升+决策树,利用损失函数的负梯度尽心你和学习器。

 我们具体看一下为什么可以在GDBT模型中使用负梯度作为残差进行拟合,具体如下:

 我们看一下这个GDBT的梯度提升的流程,具体如下:

4.4、XGBoost的基本理论

我们看一下XGBoost是GBBT模型的一种,XGBoost提供并行树提升(也称为GBDT,GBM),可以快速准确地解决许多数据科学问题。

 我们先回顾一下决策树的概念,就是将不用的类别映射到叶子节点的概率进行分类和回归。

 使用单个树进行集成学习的能力优先,一般考虑使用多棵树进行集成学习,就是随机森林或者提升树。

 对于XGBoost的模型形式如下:是利用向前分布算法,学习到包含K棵树的加法模型。

4.5、XGBoost的实践(分类和回归)

分类问题,根据输入特征进行学习生成多个弱学习器,将多个弱学习器组合成一个强学习器,通过强学习器进行预测,输入的多组数据一共包含四个特征,输入的分类一共为3类:

  1. import xgboost as xgb
  2. from xgboost import plot_importance,plot_tree
  3. from sklearn.datasets import load_iris
  4. from sklearn.model_selection import train_test_split
  5. from sklearn.metrics import accuracy_score
  6. import matplotlib.pyplot as plt
  7. # 加载样本数据集
  8. #X有四个特征,y有三个类别:0,1,2
  9. iris = load_iris()
  10. X,y = iris.data,iris.target
  11. # 获取特征名称:四个名称
  12. feature_name = iris.feature_names
  13. # 数据分割
  14. X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=3)
  15. # 模型训练
  16. model = xgb.XGBClassifier(max_depth=5, n_estimators=50, silent=True, objective='multi:softmax',feature_names=feature_name)
  17. model.fit(X_train, y_train)
  18. # 预测
  19. y_pred = model.predict(X_test)
  20. print(y_pred)
  21. # 计算准确率
  22. accuracy = accuracy_score(y_test,y_pred)
  23. print("accuarcy: %.2f%%" % (accuracy*100.0))
  24. # 显示重要特征
  25. plot_importance(model)
  26. plot_tree(model,num_trees=5)
  27. plt.show()

测试集预测的结果如下,一共分为三类,即0,1,2.

四个特征的重要性排名如下:

绘制的决策树如下:

 

回归问题:

根据输入特征和输出特征进行回归,输入的多组数据的特征数目是9个,对结果进行预测,代码如下:

  1. import xgboost as xgb
  2. from xgboost import plot_importance,plot_tree
  3. from sklearn.model_selection import train_test_split
  4. from sklearn.datasets import load_boston
  5. import matplotlib.pyplot as plt
  6. # 获取数据
  7. boston = load_boston()
  8. X,y = boston.data,boston.target
  9. # 数据集划分
  10. X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
  11. # 模型训练
  12. model = xgb.XGBRegressor(max_depth=5, learning_rate=0.1, n_estimators=50, silent=True, objective='reg:gamma')
  13. model.fit(X_train, y_train)
  14. # 预测
  15. y_pred = model.predict(X_test)
  16. print(y_pred)
  17. # 显示重要特征
  18. plot_importance(model)
  19. # 可视化树的生成情况,num_trees是树的索引
  20. plot_tree(model, num_trees=17)
  21. plt.show()

回归的预测结果如下:

对9个特征的重要性进行排名如下:

绘制的决策树如下所示:

五、支持向量机SVM

5.1、支持向量机基本理论

支持向量机(Support Vector Machine, SVM)是一类按监督学习(supervised learning)方式对数据进行二元分类的广义线性分类器(generalized linear classifier),其决策边界是对学习样本求解的最大边距超平面(maximum-margin hyperplane) 

支持向量机可以用于分类,回归预测和时间序列预测。

5.2、SVM实现持向量机回归(SVR)模型

我们将数据划分为训练集和测试集,训练集有354组数据和13个特征,测试集是152组数据和13个特征,对数据进行回归预测,matlab的代码如下:

  1. close all;
  2. clc
  3. clear
  4. %% 下载数据
  5. load('p_train.mat');
  6. load('p_test.mat');
  7. load('t_train.mat');
  8. load('t_test.mat');
  9. %% 数据归一化
  10. %输入样本归一化
  11. [pn_train,ps1] = mapminmax(p_train');
  12. pn_train = pn_train';
  13. pn_test = mapminmax('apply',p_test',ps1);
  14. pn_test = pn_test';
  15. %输出样本归一化
  16. [tn_train,ps2] = mapminmax(t_train');
  17. tn_train = tn_train';
  18. tn_test = mapminmax('apply',t_test',ps2);
  19. tn_test = tn_test';
  20. %% SVR模型创建/训练
  21. % 寻找最佳c参数/g参数——交叉验证方法
  22. % SVM模型有两个非常重要的参数C与gamma。
  23. % 其中 C是惩罚系数,即对误差的宽容度。
  24. % c越高,说明越不能容忍出现误差,容易过拟合。C越小,容易欠拟合。C过大或过小,泛化能力变差
  25. % gamma是选择RBF函数作为kernel后,该函数自带的一个参数。隐含地决定了数据映射到新的特征空间后的分布,
  26. % gamma越大,支持向量越少,gamma值越小,支持向量越多。支持向量的个数影响训练与预测的速度。
  27. [c,g] = meshgrid(-10:0.5:10,-10:0.5:10);
  28. [m,n] = size(c);
  29. cg = zeros(m,n);
  30. eps = 10^(-4);
  31. v = 5;
  32. bestc = 0;
  33. bestg = 0;
  34. error = Inf;
  35. for i = 1:m
  36. for j = 1:n
  37. cmd = ['-v ',num2str(v),' -t 2',' -c ',num2str(2^c(i,j)),' -g ',num2str(2^g(i,j) ),' -s 3 -p 0.1'];
  38. cg(i,j) = svmtrain(tn_train,pn_train,cmd);
  39. if cg(i,j) < error
  40. error = cg(i,j);
  41. bestc = 2^c(i,j);
  42. bestg = 2^g(i,j);
  43. end
  44. if abs(cg(i,j) - error) <= eps && bestc > 2^c(i,j)
  45. error = cg(i,j);
  46. bestc = 2^c(i,j);
  47. bestg = 2^g(i,j);
  48. end
  49. end
  50. end
  51. % 创建/训练SVR
  52. cmd = [' -t 2',' -c ',num2str(bestc),' -g ',num2str(bestg),' -s 3 -p 0.01'];
  53. model = svmtrain(tn_train,pn_train,cmd);
  54. %% SVR仿真预测
  55. % [Predict_1,error_1,dec_values_1] = svmpredict(tn_train,pn_train,model);
  56. [Predict_2,error_2,dec_values_2] = svmpredict(tn_test,pn_test,model);
  57. % 反归一化
  58. % predict_1 = mapminmax('reverse',Predict_1,ps2);
  59. predict_2 = mapminmax('reverse',Predict_2,ps2);
  60. %% 计算误差
  61. [len,~]=size(predict_2);
  62. error = t_test - predict_2;
  63. error = error';
  64. MAE1=sum(abs(error./t_test'))/len;
  65. MSE1=error*error'/len;
  66. RMSE1=MSE1^(1/2);
  67. R = corrcoef(t_test,predict_2);
  68. r = R(1,2);
  69. disp(['........支持向量回归误差计算................'])
  70. disp(['平均绝对误差MAE为:',num2str(MAE1)])
  71. disp(['均方误差为MSE:',num2str(MSE1)])
  72. disp(['均方根误差RMSE为:',num2str(RMSE1)])
  73. disp(['决定系数 R^2为:',num2str(r)])
  74. figure(1)
  75. plot(1:length(t_test),t_test,'r-*',1:length(t_test),predict_2,'b:o')
  76. grid on
  77. legend('真实值','预测值')
  78. xlabel('样本编号')
  79. ylabel('')
  80. string_2 = {'测试集预测结果对比'};
  81. title(string_2)

预测效果如下:

5.3、SVM实现分类任务

下面是使用SVM实现对红酒分类的预测,一共187组数据,13个特征,输出的类别为3类。

  1. %% Matlab神经网络43个案例分析
  2. % 基于SVM的数据分类预测——意大利葡萄酒种类识别
  3. %% 清空环境变量
  4. close all;
  5. clear;
  6. clc;
  7. format compact;
  8. %% 数据提取
  9. % 载入测试数据wine,其中包含的数据为classnumber = 3,wine:178*13的矩阵,wine_labes:178*1的列向量
  10. load chapter_WineClass.mat;
  11. % 画出测试数据的box可视化图
  12. figure;
  13. boxplot(wine,'orientation','horizontal','labels',categories);
  14. title('wine数据的box可视化图','FontSize',12);
  15. xlabel('属性值','FontSize',12);
  16. grid on;
  17. % 画出测试数据的分维可视化图
  18. figure
  19. subplot(3,5,1);
  20. hold on
  21. for run = 1:178
  22. plot(run,wine_labels(run),'*');
  23. end
  24. xlabel('样本','FontSize',10);
  25. ylabel('类别标签','FontSize',10);
  26. title('class','FontSize',10);
  27. for run = 2:14
  28. subplot(3,5,run);
  29. hold on;
  30. str = ['attrib ',num2str(run-1)];
  31. for i = 1:178
  32. plot(i,wine(i,run-1),'*');
  33. end
  34. xlabel('样本','FontSize',10);
  35. ylabel('属性值','FontSize',10);
  36. title(str,'FontSize',10);
  37. end
  38. % 选定训练集和测试集
  39. % 将第一类的1-30,第二类的60-95,第三类的131-153做为训练集
  40. train_wine = [wine(1:30,:);wine(60:95,:);wine(131:153,:)];
  41. % 相应的训练集的标签也要分离出来
  42. train_wine_labels = [wine_labels(1:30);wine_labels(60:95);wine_labels(131:153)];
  43. % 将第一类的31-59,第二类的96-130,第三类的154-178做为测试集
  44. test_wine = [wine(31:59,:);wine(96:130,:);wine(154:178,:)];
  45. % 相应的测试集的标签也要分离出来
  46. test_wine_labels = [wine_labels(31:59);wine_labels(96:130);wine_labels(154:178)];
  47. %% 数据预处理
  48. % 数据预处理,将训练集和测试集归一化到[0,1]区间
  49. [mtrain,ntrain] = size(train_wine);
  50. [mtest,ntest] = size(test_wine);
  51. dataset = [train_wine;test_wine];
  52. % mapminmax为MATLAB自带的归一化函数
  53. [dataset_scale,ps] = mapminmax(dataset',0,1);
  54. dataset_scale = dataset_scale';
  55. train_wine = dataset_scale(1:mtrain,:);
  56. test_wine = dataset_scale( (mtrain+1):(mtrain+mtest),: );
  57. %% SVM网络训练
  58. tic;
  59. model = svmtrain(train_wine_labels, train_wine, '-c 2 -g 1');
  60. toc;
  61. %% SVM网络预测
  62. tic;
  63. [predict_label, accuracy,dec_value1] = svmpredict(test_wine_labels, test_wine, model);
  64. toc;
  65. %% 结果分析
  66. % 测试集的实际分类和预测分类图
  67. % 通过图可以看出只有一个测试样本是被错分的
  68. figure;
  69. hold on;
  70. plot(test_wine_labels,'o');
  71. plot(predict_label,'r*');
  72. xlabel('测试集样本','FontSize',12);
  73. ylabel('类别标签','FontSize',12);
  74. legend('实际测试集分类','预测测试集分类');
  75. title('测试集的实际分类和预测分类图','FontSize',12);
  76. grid on;

5.4、SVM实现时间序列预测

我们看一下SVM对于时间序列的预测,数据如下,第1列为时间,后面3列为时间的序列的数据,根据svm模型对时间序列数据进行预测,首先绘制B,C,D三类的时间序列变化图。

绘制出时间序列变化图,具体的时间序列变化如下所示,三组都是400个时间序列的样本数据。

 模型训练与预测的代码如下,对测试集进行预测,绘制预测值和真实值的曲线,最后绘制预测误差的曲线。

  1. #time时间列,single1,信号值,取前多少个X_data预测下一个数据
  2. def time_slice(time,single,X_lag):
  3. sample = []
  4. label = []
  5. for k in range(len(time) - X_lag - 1):
  6. t = k + X_lag
  7. sample.append(single[k:t])
  8. label.append(single[t + 1])
  9. return sample,label
  10. sample,label = time_slice(time,single1,5)
  11. # 数据集划分
  12. X_train, X_test, y_train, y_test = train_test_split(sample, label, test_size=0.3, random_state=42)
  13. # 数据集掷乱
  14. random_seed = 13
  15. X_train, y_train = shuffle(X_train, y_train, random_state=random_seed)
  16. # 参数设置SVR准备
  17. parameters = {'kernel': ['rbf'], 'gamma': np.logspace(-5, 0, num=6, base=2.0),
  18. 'C': np.logspace(-5, 5, num=11, base=2.0)}
  19. # 网格搜索:选择十折交叉验证
  20. svr = svm.SVR()
  21. grid_search = GridSearchCV(svr, parameters, cv=10, n_jobs=4, scoring='neg_mean_squared_error')
  22. # SVR模型训练
  23. grid_search.fit(X_train, y_train)
  24. # 输出最终的参数
  25. print(grid_search.best_params_)
  26. # 模型的精度
  27. print(grid_search.best_score_)
  28. # SVR模型保存
  29. joblib.dump(grid_search, 'svr.pkl')
  30. # SVR模型加载
  31. svr = joblib.load('svr.pkl')
  32. # SVR模型测试
  33. y_hat = svr.predict(X_test)
  34. # 计算预测值与实际值的残差绝对值
  35. abs_vals = np.abs(y_hat - y_test)
  36. plt.subplot(1, 1, 1)
  37. plt.plot(y_test, c='k', label='data')
  38. plt.plot(y_hat, c='g', label='svr model')
  39. plt.xlabel('data')
  40. plt.ylabel('target')
  41. plt.title('Support Vector Regression')
  42. plt.legend()
  43. plt.show()
  44. plt.subplot(1, 1, 1)
  45. plt.plot(abs_vals)
  46. plt.show()

使用训练好的svm模型进行预测值和真实值的拟合曲线图如下:

预测的值和真实值之间的误差变化图,具体如下:

 

六、基于prophet的时间序列预测

6.1、prophet理论与实现一

我们可以看一下Prophet模型进行时间序列预测的基本模型,具体包括趋势项,周期项,节假日项,误差项,四个项目共同组成prophet模型。

我们根据原始时间序列数据去预测后30天的申购总额和赎回总额,原始数据是2013年7月到2014年的8月,我们使用当前的时间序列数据预测未来30天的,具体如下:

 python代码实现如下:

  1. import pandas as pd
  2. import numpy as np
  3. import matplotlib.pyplot as plt
  4. from prophet import Prophet
  5. from prophet.diagnostics import cross_validation
  6. from prophet.diagnostics import performance_metrics
  7. from prophet.plot import plot_cross_validation_metric
  8. import warnings
  9. warnings.filterwarnings('ignore')
  10. #读取数据
  11. data_user = pd.read_csv('user_balance_table.csv')
  12. #将第一列的时间数据转换成固定格式
  13. data_user['report_date'] = pd.to_datetime(data_user['report_date'], format='%Y%m%d')
  14. #输出前面的头部信息
  15. print(data_user.head())
  16. #取时间列和另外要预测的两列
  17. data_user_byday = data_user.groupby(['report_date'])['total_purchase_amt','total_redeem_amt'].sum().sort_values(['report_date']).reset_index()
  18. print(data_user_byday.head())
  19. # 定义模型
  20. def FB(data: pd.DataFrame) -> pd.DataFrame:
  21. df = pd.DataFrame({
  22. 'ds': data.report_date,
  23. 'y': data.total_purchase_amt,
  24. })
  25. #申购总额的最大值和最小值
  26. df['cap'] = data.total_purchase_amt.values.max()
  27. df['floor'] = data.total_purchase_amt.values.min()
  28. m = Prophet(
  29. changepoint_prior_scale=0.05,
  30. daily_seasonality=False,
  31. yearly_seasonality=True, # 年周期性
  32. weekly_seasonality=True, # 周周期性
  33. growth="logistic",
  34. )
  35. m.add_seasonality(name='monthly', period=30.5, fourier_order=5, prior_scale=0.1)#月周期性
  36. m.add_country_holidays(country_name='CN') # 中国所有的节假日
  37. m.fit(df)
  38. future = m.make_future_dataframe(periods=30, freq='D') # 预测时长
  39. #预测的申购总额的最大值和最小值
  40. future['cap'] = data.total_purchase_amt.values.max()
  41. future['floor'] = data.total_purchase_amt.values.min()
  42. forecast = m.predict(future)
  43. fig = m.plot_components(forecast)
  44. fig1 = m.plot(forecast)
  45. return forecast
  46. result_purchase = FB(data_user_byday)
  47. print(result_purchase)
  48. plt.show()

预测结果如下:

预测的周期性和趋势图等如下:

6.2、 prophet理论与实现二

数据部分,2013年7月到2014年8月的数据,对后30天的赎回总额进行预测,具体如下:

python代码如下:

  1. import pandas as pd
  2. import numpy as np
  3. import matplotlib.pyplot as plt
  4. from prophet import Prophet
  5. from prophet.diagnostics import cross_validation
  6. from prophet.diagnostics import performance_metrics
  7. from prophet.plot import plot_cross_validation_metric
  8. import warnings
  9. warnings.filterwarnings('ignore')
  10. #读取数据
  11. data_user = pd.read_csv('user_balance_table.csv')
  12. #将第一列的时间数据转换成固定格式
  13. data_user['report_date'] = pd.to_datetime(data_user['report_date'], format='%Y%m%d')
  14. #输出前面的头部信息
  15. print(data_user.head())
  16. #取时间列和另外要预测的两列
  17. data_user_byday = data_user.groupby(['report_date'])['total_purchase_amt','total_redeem_amt'].sum().sort_values(['report_date']).reset_index()
  18. print(data_user_byday.head())
  19. # 定义模型
  20. def FB(data: pd.DataFrame) -> pd.DataFrame:
  21. df = pd.DataFrame({
  22. 'ds': data.report_date,
  23. 'y': data.total_redeem_amt,
  24. })
  25. df['cap'] = data.total_redeem_amt.values.max()
  26. df['floor'] = data.total_redeem_amt.values.min()
  27. m = Prophet(
  28. changepoint_prior_scale=0.05,
  29. daily_seasonality=False,
  30. yearly_seasonality=True, # 年周期性
  31. weekly_seasonality=True, # 周周期性
  32. growth="logistic",
  33. )
  34. #365/12
  35. m.add_seasonality(name='monthly', period=30.5, fourier_order=5, prior_scale=0.1) # 月周期性
  36. m.add_country_holidays(country_name='CN' ) # 中国所有的节假日
  37. m.fit(df)
  38. future = m.make_future_dataframe(periods=30, freq='D' ) # 预测时长
  39. future['cap'] = data.total_redeem_amt.values.max()
  40. future['floor'] = data.total_redeem_amt.values.min()
  41. forecast = m.predict(future)
  42. fig = m.plot_components(forecast)
  43. fig1 = m.plot(forecast)
  44. return forecast
  45. result_redeem = FB(data_user_byday)
  46. print(result_redeem)
  47. plt.show()

预测结果如下:

周期性分析结果如下:

七、图像处理方法

7.1、图像的去噪、增强

1、去噪是采用滤波的方式,本文使用了三种滤波方式:均值滤波,中值滤波,高斯高通滤波,滤波的主要代码如下,滤波的主要效果展示在图片中:

  1. clear;
  2. clc
  3. g = imread('bird.jpg');
  4. gg = imnoise(g, 'gaussian'); %添加高斯噪声
  5. subplot(2,2,1);
  6. imshow(gg);
  7. title('高斯噪声');
  8. j = 2;
  9. for i = 3:4:11
  10. subplot(2,2,j);
  11. G = avefilter(gg, i);
  12. imshow(G);
  13. title([num2str(i), '\ast', num2str(i), '均值滤波']);
  14. j = j+1;
  15. end
  16. figure(2);
  17. g = imread('bird.png');
  18. gg = imnoise(g, 'salt & pepper', 0.05); %添加椒盐噪声
  19. subplot(2,2,1), imshow(gg);
  20. title('椒盐噪声');
  21. j = 2;
  22. for i = 3:4:11
  23. G = medianfilter(gg, i);
  24. subplot(2,2,j);
  25. imshow(G);
  26. title([num2str(i), '\ast', num2str(i), '中值滤波']);
  27. j = j+1;
  28. end
  29. figure(3);
  30. d0=50; %阈值
  31. image=imread('bird.jpg');
  32. [M,N,P] = size(image);
  33. img_f = fft2(double(image));%傅里叶变换得到频谱
  34. img_f=fftshift(img_f); %移到中间
  35. m_mid=floor(M/2);%中心点坐标
  36. n_mid=floor(N/2);
  37. h = zeros(M,N,P);%高斯低通滤波器构造
  38. for i = 1:M
  39. for j = 1:N
  40. d = ((i-m_mid)^2+(j-n_mid)^2);
  41. h(i,j) = exp(-d/(2*(d0^2)));
  42. end
  43. end
  44. img_lpf = h.*img_f;
  45. img_lpf=ifftshift(img_lpf); %中心平移回原来状态
  46. img_lpf=uint8(real(ifft2(img_lpf))); %反傅里叶变换,取实数部分
  47. subplot(1,2,1);
  48. imshow(image);
  49. title('原图');
  50. subplot(1,2,2);
  51. imshow(img_lpf);
  52. title('高斯低通滤波d=50');
  1. function G = avefilter(F, k)
  2. % F 是待处理的图像
  3. % k 是模版的大小,奇数
  4. [m,n,p] = size(F) ;
  5. % 转换数据类型,便于计算
  6. G = uint16(zeros(m, n));
  7. Ft = uint16(F);
  8. M = uint16(ones(k, k));
  9. h = (k+1)/2;
  10. for i = 1:m
  11. for j = 1:n
  12. if((i < h)|| (j < h)|| (i > m-h+1)|| (j > n-h+1)) %不能被模版处理的区域
  13. G(i, j) = Ft(i, j);
  14. continue; %像素值不变
  15. end
  16. %取同样大小的图像块,中间的像素是待处理的像素
  17. T = Ft(i-(k-1)/2: i+(k-1)/2, j-(k-1)/2: j+(k-1)/2);
  18. T = T.*M; %和模版相乘
  19. G(i, j) = sum(T(:))/k^2; %结果求和并计算平均值
  20. end
  21. end
  22. G = uint8(G); %结果转换成8-bit图像的数据类型
  1. function G = medianfilter(F, k)
  2. % F 是待处理的图像
  3. % k 是模版的大小,奇数
  4. [m,n,p] = size(F) ;
  5. % 转换数据类型,便于计算
  6. G = uint16(zeros(m, n)); Ft = uint16(F); M = uint16(ones(k, k));
  7. h = (k+1)/2;
  8. for i = 1:m
  9. for j = 1:n
  10. if((i < h)|| (j < h)|| (i > m-h+1)|| (j > n-h+1)) %不能被模版处理的区域
  11. G(i, j) = Ft(i, j);
  12. continue; %像素值不变
  13. end
  14. % 取同样大小的图像块,中间的像素是待处理的像素
  15. T = Ft(i-(k-1)/2: i+(k-1)/2, j-(k-1)/2: j+(k-1)/2);
  16. T = T(:); %将矩阵转换为一维向量
  17. G(i, j) = median(T); %求中值并赋值给中间像素
  18. end
  19. end
  20. G = uint8(G); %结果转换成8-bit图像的数据类型

滤波效果如下:

 

2-接下来对图像进行增强处理, 主要采用两种方式,一种是灰度线性变换,另外一种是直方图均衡变换。

matlab实现灰度线性变换和直方图均衡变换进行图像增强的代码如下:

  1. %% 获取灰色图像直方图
  2. I = imread('bird.jpg'); %读取图片
  3. I = rgb2gray(I); %把图片从rgb格式转为灰度图
  4. row = size(I, 1); %获取图片像素的行列数
  5. column = size(I, 2);
  6. N = zeros(1, 256); %一个空的容器,用来记录每个像素出现的次数
  7. % 两个循环用来遍历每一个像素
  8. for i = 1:row
  9. for j = 1:column
  10. k = I(i, j); % 获取该像素点的像素值
  11. N(k + 1) = N(k + 1) + 1; % 记录下该像素值出现的次数
  12. end
  13. end
  14. %展示图片
  15. figure(1);
  16. subplot(121);imshow(I);
  17. subplot(122);bar(N);
  18. %% 灰色线性变换进行图像增强
  19. I = imread('bird.jpg');
  20. I = rgb2gray(I);
  21. I = double(I);
  22. J = (I - 80) * 255 / 70;
  23. row = size(I, 1);
  24. column = size(I, 2);
  25. for i = 1:row
  26. for j = 1:column
  27. if J(i, j) < 0
  28. J(i, j) = 0;
  29. elseif J(i, j) > 255
  30. J(i, j) = 255;
  31. end
  32. end
  33. end
  34. figure(2);
  35. subplot(121);imshow(uint8(I));
  36. subplot(122);imshow(uint8(J));
  37. %% 直方图均衡变换进行图像增强
  38. %R,G,B直方图展示
  39. figure(3) ;
  40. I = imread('bird.jpg');
  41. subplot(221);imshow(I);
  42. subplot(222);imhist(I(:, :, 1));title('R');
  43. subplot(223);imhist(I(:, :, 2));title('G');
  44. subplot(224);imhist(I(:, :, 3));title('B');
  45. %均衡化方法
  46. I = imread('bird.jpg');
  47. G = rgb2gray(I) ;
  48. J = histeq(G);
  49. figure(4);
  50. subplot(221);imshow(G);
  51. subplot(222);imshow(J);
  52. subplot(223);imhist(G);
  53. subplot(224);imhist(J);

效果图如下:

原始图像和原始图像的灰色直方图如下:

对图像做了灰色线性变换后的图像对比如下:

原始图像的RGB直方图如下:

对原始图像做了直方图均衡变换的效果图:

7.2、图像特征提取

使用SITF算法进行图像特征提取,提取的特征位置如下,具体的matlab代码如下:

我们先看一下提取的效果:

主函数如下:

  1. clear;
  2. clc
  3. [image, descriptors, locs] = sift('deng.jpg');
  4. disp('descriptors如下:') ;
  5. disp(descriptors) ;
  6. image1 = imread('deng.jpg');
  7. showkeys(image1, locs)

sift算法函数如下:

  1. % [image, descriptors, locs] = sift(imageFile)
  2. %
  3. % This function reads an image and returns its SIFT keypoints.
  4. % Input parameters:
  5. % imageFile: the file name for the image.
  6. %
  7. % Returned:
  8. % image: the image array in double format
  9. % descriptors: a K-by-128 matrix, where each row gives an invariant
  10. % descriptor for one of the K keypoints. The descriptor is a vector
  11. % of 128 values normalized to unit length.
  12. % locs: K-by-4 matrix, in which each row has the 4 values for a
  13. % keypoint location (row, column, scale, orientation). The
  14. % orientation is in the range [-PI, PI] radians.
  15. %
  16. % Credits: Thanks for initial version of this program to D. Alvaro and
  17. % J.J. Guerrero, Universidad de Zaragoza (modified by D. Lowe)
  18. function [image, descriptors, locs] = sift(imageFile)
  19. % Load image
  20. image1 = imread(imageFile);
  21. image = rgb2gray(image1) ;
  22. % If you have the Image Processing Toolbox, you can uncomment the following
  23. % lines to allow input of color images, which will be converted to grayscale.
  24. % if isrgb(image)
  25. % image = rgb2gray(image);
  26. % end
  27. [rows, cols] = size(image);
  28. % Convert into PGM imagefile, readable by "keypoints" executable
  29. f = fopen('tmp.pgm', 'w');
  30. if f == -1
  31. error('Could not create file tmp.pgm.');
  32. end
  33. fprintf(f, 'P5\n%d\n%d\n255\n', cols, rows);
  34. fwrite(f, image', 'uint8');
  35. fclose(f);
  36. % Call keypoints executable
  37. if isunix
  38. command = '!./sift ';
  39. else
  40. command = '!siftWin32 ';
  41. end
  42. command = [command ' <tmp.pgm >tmp.key'];
  43. eval(command);
  44. % Open tmp.key and check its header
  45. g = fopen('tmp.key', 'r');
  46. if g == -1
  47. error('Could not open file tmp.key.');
  48. end
  49. [header, count] = fscanf(g, '%d %d', [1 2]);
  50. if count ~= 2
  51. error('Invalid keypoint file beginning.');
  52. end
  53. num = header(1);
  54. len = header(2);
  55. if len ~= 128
  56. error('Keypoint descriptor length invalid (should be 128).');
  57. end
  58. % Creates the two output matrices (use known size for efficiency)
  59. locs = double(zeros(num, 4));
  60. descriptors = double(zeros(num, 128));
  61. % Parse tmp.key
  62. for i = 1:num
  63. [vector, count] = fscanf(g, '%f %f %f %f', [1 4]); %row col scale ori
  64. if count ~= 4
  65. error('Invalid keypoint file format');
  66. end
  67. locs(i, :) = vector(1, :);
  68. [descrip, count] = fscanf(g, '%d', [1 len]);
  69. if (count ~= 128)
  70. error('Invalid keypoint file value.');
  71. end
  72. % Normalize each input vector to unit length
  73. descrip = descrip / sqrt(sum(descrip.^2));
  74. descriptors(i, :) = descrip(1, :);
  75. end
  76. fclose(g);

特征点的展示函数如下:

  1. % showkeys(image, locs)
  2. %
  3. % This function displays an image with SIFT keypoints overlayed.
  4. % Input parameters:
  5. % image: the file name for the image (grayscale)
  6. % locs: matrix in which each row gives a keypoint location (row,
  7. % column, scale, orientation)
  8. function showkeys(image, locs)
  9. disp('Drawing SIFT keypoints ...');
  10. % Draw image with keypoints
  11. figure('Position', [50 50 size(image,2) size(image,1)]);
  12. colormap('gray');
  13. imagesc(image);
  14. hold on;
  15. imsize = size(image);
  16. for i = 1: size(locs,1)
  17. % Draw an arrow, each line transformed according to keypoint parameters.
  18. TransformLine(imsize, locs(i,:), 0.0, 0.0, 1.0, 0.0);
  19. TransformLine(imsize, locs(i,:), 0.85, 0.1, 1.0, 0.0);
  20. TransformLine(imsize, locs(i,:), 0.85, -0.1, 1.0, 0.0);
  21. end
  22. hold off;
  1. % ------ Subroutine: TransformLine -------
  2. % Draw the given line in the image, but first translate, rotate, and
  3. % scale according to the keypoint parameters.
  4. %
  5. % Parameters:
  6. % Arrays:
  7. % imsize = [rows columns] of image
  8. % keypoint = [subpixel_row subpixel_column scale orientation]
  9. %
  10. % Scalars:
  11. % x1, y1; begining of vector
  12. % x2, y2; ending of vector
  13. function TransformLine(imsize, keypoint, x1, y1, x2, y2)
  14. % The scaling of the unit length arrow is set to approximately the radius
  15. % of the region used to compute the keypoint descriptor.
  16. len = 6 * keypoint(3);
  17. % Rotate the keypoints by 'ori' = keypoint(4)
  18. s = sin(keypoint(4));
  19. c = cos(keypoint(4));
  20. % Apply transform
  21. r1 = keypoint(1) - len * (c * y1 + s * x1);
  22. c1 = keypoint(2) + len * (- s * y1 + c * x1);
  23. r2 = keypoint(1) - len * (c * y2 + s * x2);
  24. c2 = keypoint(2) + len * (- s * y2 + c * x2);
  25. line([c1 c2], [r1 r2], 'Color', 'r');

道阻且长,行则将至,我们都是这条人生路途中的追梦人,大家加油吧,希望我们数模竞赛能取得一个好成绩,加油吧,少年!

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/article/detail/55682?site
推荐阅读
相关标签
  

闽ICP备14008679号