当前位置:   article > 正文

【tensorflow框架神经网络实现鸢尾花分类_Keras】

【tensorflow框架神经网络实现鸢尾花分类_Keras】

1、前言

【tensorflow框架神经网络实现鸢尾花分类】一文中使用自定义的方式,实现了鸢尾花数据集的分类工作。在这里使用tensorflow中的keras模块快速、极简实现鸢尾花分类任务。

2、鸢尾花分类

import tensorflow as tf
from sklearn import datasets
import numpy as np

# 加载数据集
np.random.seed(0)
iris = datasets.load_iris()
x_train, y_train = iris.data, iris.target
np.random.seed(0)
np.random.shuffle(x_train)
np.random.seed(0)
np.random.shuffle(y_train)

# 设置随机种子
tf.random.set_seed(0)

# 构建模型
model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(3, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2(0.01))
])

# 编译模型
model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.1),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

# 训练模型
model.fit(x_train, y_train, batch_size=32, epochs=500, validation_split=0.2, validation_freq=20)

# 打印模型摘要
model.summary()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31

3、结果打印

在这里插入图片描述

声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号