当前位置:   article > 正文

凯斯西储轴承数据CWRU数据集制作十分类_凯斯西储(cwru)数据集解读并将数据集划分为10分类(含代码)

凯斯西储(cwru)数据集解读并将数据集划分为10分类(含代码)

凯斯西储轴承数据CWRU数据集制作


问题描述

凯斯西储轴承数据CWRU数据集制作预处理代码。

解决办法

基于开源代码的改进。

import os
from scipy.io import loadmat
import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from datasets.SequenceDatasets import dataset
from datasets.sequence_aug import *
from tqdm import tqdm



def get_files(root, N):
    '''
    Collect signal windows and labels for the requested motor-load conditions.

    root: root directory of the CWRU data set
    N:    list of load-condition keys into the module-level ``dataname`` dict
    Returns [data, lab]: a list of signal windows and the matching class labels.
    '''
    data, lab = [], []
    for load_key in N:
        files = dataname[load_key]
        for idx in tqdm(range(len(files))):
            # Index 0 is the healthy-bearing recording, stored under the
            # "Normal Baseline Data" folder (datasetname[3]); all fault
            # recordings live under the 12k drive-end folder (datasetname[0]).
            folder = datasetname[3] if idx == 0 else datasetname[0]
            file_path = os.path.join(root, folder, files[idx]).replace("\\", "/")
            windows, window_labels = data_load(file_path, files[idx], label=label[idx])
            data.extend(windows)
            lab.extend(window_labels)

    return [data, lab]


def data_load(filename, axisname, label):
    '''
    Slice one .mat recording into non-overlapping windows of ``signal_size``.

    filename: path to the .mat file
    axisname: bare file name such as "105.mat"; its numeric stem selects the
              variable key inside the .mat file (e.g. "105.mat" -> "X105_DE_time")
    label:    class label attached to every window extracted from this file
    Returns (data, lab): list of (signal_size, 1) arrays and matching labels.
    '''
    stem = axisname.split(".")[0]
    # int() instead of eval(): evaluating a filename fragment as code is
    # unsafe and unnecessary — the stem is always a plain integer string.
    # File numbers below 100 are zero-padded in the .mat variable names
    # (e.g. "97.mat" -> "X097_DE_time").
    if int(stem) < 100:
        realaxis = "X0" + stem + axis[0]
    else:
        realaxis = "X" + stem + axis[0]
    fl = loadmat(filename)[realaxis]
    data = []
    lab = []
    # Non-overlapping sliding window; any tail shorter than signal_size
    # is dropped.
    start, end = 0, signal_size
    while end <= fl.shape[0]:
        data.append(fl[start:end])
        lab.append(label)
        start += signal_size
        end += signal_size

    return data, lab



def data_split(data_dir, transfer_task, normlizetype="0-1", transfer_learning=True):
    '''
    Build train/val(/test) splits from the CWRU data.

    data_dir:          root directory of the data set
    transfer_task:     [source_loads, target_loads], each a list of keys
                       into the module-level ``dataname`` dict
    normlizetype:      normalization scheme forwarded to ``Normalize``
    transfer_learning: if True, return (source_train, source_val) wrapped as
                       ``dataset`` objects; if False, return raw arrays
                       (xtrain, ytrain, xval, yval, xtest, ytest)
    '''
    source_N = transfer_task[0]
    target_N = transfer_task[1]
    data_transforms = {
        'train': Compose([
            Reshape(),
            Normalize(normlizetype),
            # RandomAddGaussian(),
            # RandomScale(),
            # RandomStretch(),
            # RandomCrop(),
            Retype(),
            # Scale(1)
        ]),
        'val': Compose([
            Reshape(),
            Normalize(normlizetype),
            Retype(),
            # Scale(1)
        ])
    }
    if transfer_learning:
        # get source train and val
        list_data = get_files(data_dir, source_N)
        data_pd = pd.DataFrame({"data": list_data[0], "label": list_data[1]})
        train_pd, val_pd = train_test_split(data_pd, test_size=0.2, random_state=40, stratify=data_pd["label"])
        source_train = dataset(list_data=train_pd, transform=data_transforms['train'])
        source_val = dataset(list_data=val_pd, transform=data_transforms['val'])

        # get target train and val
        # NOTE(review): the target split is built but not returned below —
        # kept to preserve the original behavior; confirm whether callers
        # expect the 4-tuple return instead.
        list_data = get_files(data_dir, target_N)
        data_pd = pd.DataFrame({"data": list_data[0], "label": list_data[1]})
        train_pd, val_pd = train_test_split(data_pd, test_size=0.2, random_state=40, stratify=data_pd["label"])
        target_train = dataset(list_data=train_pd, transform=data_transforms['train'])
        target_val = dataset(list_data=val_pd, transform=data_transforms['val'])
        return source_train, source_val   #, target_train, target_val
    else:
        # Single-domain split: 80% -> train+val pool, 20% -> test,
        # then the pool is halved into train and val.
        list_data = get_files(data_dir, source_N)
        data_pd = pd.DataFrame({"data": list_data[0], "label": list_data[1]})
        trval_pd, test_pd = train_test_split(data_pd, test_size=0.2, random_state=40)  #, stratify=data_pd["label"]
        train_pd, val_pd = train_test_split(trval_pd, test_size=0.5, random_state=40)
        xtrain = train_pd['data'].values
        ytrain = train_pd['label'].values
        xval = val_pd['data'].values
        yval = val_pd['label'].values
        # BUGFIX: the test arrays were taken from val_pd, so "val" and "test"
        # were identical and the held-out test_pd split was discarded.
        xtest = test_pd['data'].values
        ytest = test_pd['label'].values
        return xtrain, ytrain, xval, yval, xtest, ytest

if __name__ == '__main__':
    # Digital data was collected at 12,000 samples per second.
    signal_size = 1024
    # .mat files per motor load (rpm); within each list, index 0 is the
    # healthy recording and indices 1-9 are the nine fault conditions.
    dataname= {0:["97.mat","105.mat", "118.mat", "130.mat", "169.mat", "185.mat", "197.mat", "209.mat", "222.mat","234.mat"],  # 1797rpm
               1:["98.mat","106.mat", "119.mat", "131.mat", "170.mat", "186.mat", "198.mat", "210.mat", "223.mat","235.mat"],  # 1772rpm
               2:["99.mat","107.mat", "120.mat", "132.mat", "171.mat", "187.mat", "199.mat", "211.mat", "224.mat","236.mat"],  # 1750rpm
               3:["100.mat","108.mat", "121.mat","133.mat", "172.mat", "188.mat", "200.mat", "212.mat", "225.mat","237.mat"]}  # 1730rpm

    datasetname = ["12k Drive End Bearing Fault Data", "12k Fan End Bearing Fault Data", "48k Drive End Bearing Fault Data",
                   "Normal Baseline Data"]
    # Sensor channels available in each recording; only axis[0] (drive end)
    # is used by data_load.
    axis = ["_DE_time", "_FE_time", "_BA_time"]

    # One class label per file position: 0 = healthy, 1-9 = fault types.
    label = [i for i in range(0, 10)]

    data_dir = '../cwru'
    output_dir = '../../data/CWRU'
    transfer_task = [[0], [3]]
    normlizetype = 'mean - std'

    X_train, y_train, X_val, y_val, X_test, y_test = data_split(
        data_dir, transfer_task, normlizetype, transfer_learning=False)
    print(X_train)

    def _save_split(x, y, out_path):
        # Stack the windows into one tensor and move channels to dim 1:
        # (N, signal_size, 1) -> (N, 1, signal_size), then save as a dict.
        dat_dict = dict()
        dat_dict["samples"] = torch.tensor([item for item in x]).permute(0, 2, 1)
        dat_dict["labels"] = torch.from_numpy(y)
        torch.save(dat_dict, out_path)

    # BUGFIX: torch.save fails if the output directory does not exist yet.
    os.makedirs(output_dir, exist_ok=True)
    _save_split(X_train, y_train, os.path.join(output_dir, "train.pt"))
    _save_split(X_val, y_val, os.path.join(output_dir, "val.pt"))
    _save_split(X_test, y_test, os.path.join(output_dir, "test.pt"))


（此处原为代码清单的行号列表，属于网页抽取残留，已省略——完整代码见上方代码块。）

数据链接:https://pan.baidu.com/s/1ZKs3Ux_apfhyBL3RrpiEPQ
提取码:2f9j
–来自百度网盘超级会员V4的分享

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/weixin_40725706/article/detail/776196
推荐阅读
相关标签
  

闽ICP备14008679号