当前位置:   article > 正文

使用Tensorflow训练1DCNN网络_tensorflow keras 1d cnn

tensorflow keras 1d cnn

数据准备

需要准备训练集数据和标签、测试集数据和标签,上述数据都是csv文件

代码效果

能够读取相应的数据,完成数据拼接,并送入神经网络中训练并获得acc和loss曲线

代码

  1. import tensorflow as tf
  2. import pandas as pd
  3. import numpy as np
  4. from matplotlib import pyplot as plt
  5. from sklearn.model_selection import train_test_split
  6. def read_csv_file(train_data_file_path,train_label_file_path,test_data_file_path,test_label_file_path):
  7. """
  8. 读取csv文件并将文件进行拼接
  9. :param train_data_file_path: 训练数据路径
  10. :param train_label_file_path: 训练标签路径
  11. :param test_data_file_path: 测试数据路径
  12. :param test_label_file_path: 测试标签路径
  13. :return: 返回拼接完成后的路径
  14. """
  15. #从csv中读取数据
  16. train_data = pd.read_csv(train_data_file_path,header=None)
  17. train_label = pd.read_csv(train_label_file_path,header=None)
  18. test_data = pd.read_csv(test_data_file_path,header=None)
  19. test_label = pd.read_csv(test_label_file_path,header=None)
  20. ##########将数据集拼接起来
  21. #数据与标签拼接
  22. dataset_train = pd.concat([train_data,train_label],axis=1)
  23. dataset_test = pd.concat([test_data,test_label],axis=1)
  24. #训练集与测试集拼接
  25. dataset = pd.concat([dataset_train,dataset_test],axis=0).sample(frac=1,random_state=0).reset_index(drop=True)
  26. return dataset
  27. def get_train_test(dataset,data_ndim=1):
  28. """
  29. 划分训练数据和测试数据,并转变数据维数
  30. :param dataset: 数据拼接
  31. :param data_ndim: 数据的维数
  32. :return: 训练集和测试集的标签和数据
  33. """
  34. #获得训练数据和标签
  35. X = dataset.iloc[:,:-1]
  36. y = dataset.iloc[:,-1]
  37. #划分训练集和测试集
  38. X_train,X_test,y_train,y_test = train_test_split(X,y,test_size=0.2,random_state=42)
  39. #改变数据维度让他符合(数量,长度,维度)的要求
  40. X_train = np.array(X_train).reshape(X_train.shape[0],X_train.shape[1],data_ndim)
  41. X_test = np.array(X_test).reshape(X_test.shape[0],X_test.shape[1],data_ndim)
  42. print("X Train shape: ", X_train.shape)
  43. print("X Test shape: ", X_test.shape)
  44. return X_train,X_test,y_train,y_test
  45. def bulid(X_train,y_train,X_test,y_test,batch_size=10,epochs=10):
  46. """
  47. 搭建网络结构完成训练
  48. :param X_train: 训练集数据
  49. :param y_train: 训练集标签
  50. :param X_test: 测试集数据
  51. :param y_test: 测试集标签
  52. :param batch_size: 批次大小
  53. :param epochs: 循环轮数
  54. :return: acc和loss曲线
  55. """
  56. model = tf.keras.models.Sequential([
  57. tf.keras.layers.Conv1D(filters=32, kernel_size=(3,), padding='same',
  58. activation=tf.keras.layers.LeakyReLU(alpha=0.001), input_shape=(X_train.shape[1], 1)),
  59. tf.keras.layers.Conv1D(filters=64, kernel_size=(3,), padding='same',
  60. activation=tf.keras.layers.LeakyReLU(alpha=0.001)),
  61. tf.keras.layers.Conv1D(filters=128, kernel_size=(3,), padding='same',
  62. activation=tf.keras.layers.LeakyReLU(alpha=0.001)),
  63. tf.keras.layers.MaxPool1D(pool_size=(3,), strides=2, padding='same'),
  64. tf.keras.layers.Dropout(0.5),
  65. tf.keras.layers.Flatten(),
  66. tf.keras.layers.Dense(units=256, activation=tf.keras.layers.LeakyReLU(alpha=0.001)),
  67. tf.keras.layers.Dense(units=512, activation=tf.keras.layers.LeakyReLU(alpha=0.001)),
  68. tf.keras.layers.Dense(units=10, activation='softmax'),
  69. ])
  70. model.compile(optimizer='adam',
  71. loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
  72. metrics=['sparse_categorical_accuracy'])
  73. history = model.fit(X_train, y_train, epochs=epochs, batch_size=batch_size, validation_data=(X_test, y_test))
  74. model.summary()
  75. # 获得训练集和测试集的acc和loss曲线
  76. acc = history.history['sparse_categorical_accuracy']
  77. val_acc = history.history['val_sparse_categorical_accuracy']
  78. loss = history.history['loss']
  79. val_loss = history.history['val_loss']
  80. # 绘制acc曲线
  81. plt.subplot(1, 2, 1)
  82. plt.plot(acc, label='Training Accuracy')
  83. plt.plot(val_acc, label='Validation Accuracy')
  84. plt.title('Training and Validation Accuracy')
  85. plt.legend()
  86. # 绘制loss曲线
  87. plt.subplot(1, 2, 2)
  88. plt.plot(loss, label='Training Loss')
  89. plt.plot(val_loss, label='Validation Loss')
  90. plt.title('Training and Validation Loss')
  91. plt.legend()
  92. plt.show()
  93. if __name__ == "__main__":
  94. x_test_csv_path = "D:/桌面文件夹/重采样后/800维信号数据(csv)/test/test_data.csv"
  95. y_test_csv_path = "D:/桌面文件夹/重采样后/800维信号数据(csv)/test/test_label.csv"
  96. x_train_csv_path = "D:/桌面文件夹/重采样后/800维信号数据(csv)/train/train_data.csv"
  97. y_train_csv_path = "D:/桌面文件夹/重采样后/800维信号数据(csv)/train/train_label.csv"
  98. dataset = read_csv_file(x_train_csv_path,y_train_csv_path,x_test_csv_path,y_test_csv_path)
  99. X_train,X_test,y_train,y_test = get_train_test(dataset=dataset,data_ndim=1)
  100. bulid(X_train,y_train,X_test,y_test)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/寸_铁/article/detail/747198
推荐阅读
相关标签
  

闽ICP备14008679号