Keras是一个开源的深度学习框架,能够高效地实现神经网络和深度学习模型。它由纽约大学的Francois Chollet开发,旨在提供一个简单易用的高层次API,以便开发人员能够快速搭建模型,从而节省时间和精力。Keras能够兼容各种底层深度学习框架,如TensorFlow、Theano和CNTK等。它已经成为深度学习领域中最受欢迎的框架之一,因为它既容易上手又具有灵活性。
Keras的设计初衷是让深度学习变得更容易,更快速地实现从数据到模型的过程。在使用Keras进行深度学习时,您无需编写多行代码来定义神经层、激活函数、优化器和损失函数等超参数,只需一行代码即可。此外,Keras还提供了丰富的预训练模型,可用于处理图像分类、自然语言处理、文本分类和序列分析等任务,从而大大减少了深度学习模型的开发和训练时间。
Keras还具有以下特点:
简单易用:Keras使用Python编写,提供了简单的API接口,让用户更加关注模型的设计和调整。
易于扩展:Keras可以兼容多种深度学习框架,如TensorFlow、Theano和CNTK等,能够利用它们的计算能力进行高效的训练和推理。
快速实现:Keras提供了多种预训练模型,无需从头开始开发模型,快速构建出高质量的深度学习模型。
支持多种语言:Keras不仅支持Python编程语言,还支持R和Java等其他编程语言。
开源社区:Keras在GitHub上有庞大的开源社区,拥有丰富的教程和示例,以便开发人员更好地学习和使用。
总之,Keras是一个简单易用、高效实现深度学习模型的框架,能够大大提升深度学习模型的开发和实现效率。
# 导入相关库
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten
from sklearn import datasets
import numpy as np
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')
Using TensorFlow backend.
# 生成样本数据集,两个特征列,两个分类二分类不需要onehot编码,直接将类别转换为0和1,分别代表正样本的概率。
X, y = datasets.make_classification(n_samples = 200, n_features = 2, n_informative = 2, n_redundant = 0,
n_repeated = 0, n_classes = 2, n_clusters_per_class = 1)
X, y
(array([[ 0.26364611, 0.77250816],
[ 0.91698377, 0.9802208 ],
[ 0.82634329, 0.9821341 ],
[-0.83833456, 0.88223515],
[ 1.11509338, 0.98632275],
[ 1.04196821, 0.97892474],
[ 0.77695264, 1.06320914],
[-2.16804253, 0.15267335],
[-1.96973867, 0.99244728],
[-1.35368845, 1.25840447],
[-0.52455148, 2.2351536 ],
[ 1.08554563, 1.03795405],
[ 0.88261697, 0.97793289],
[-1.03718795, 0.53830131],
[ 0.94628633, 0.96289949],
[ 1.16190683, 1.01806263],
[-2.07795249, 0.32376505],
[ 0.9370119 , 1.01060097],
[ 0.92750449, 0.98713143],
[-0.35800128, 1.4498587 ],
[-0.96709704, 1.77632874],
[-0.55995817, 1.58782776],
[ 0.88919948, 1.00133032],
[ 1.16465115, 1.05117935],
[-1.6969619 , 1.80088135],
[ 1.06292602, 1.04594288],
[-0.07792111, 0.98391779],
[-1.05188451, 1.26871626],
[-0.83494005, 0.93958161],
[ 1.10371115, 1.03558148],
[ 0.98674372, 1.04567265],
[-1.08345028, 1.18601788],
[-2.06487683, 0.17118219],
[ 1.02734931, 0.99326938],
[-0.11345441, 1.08515199],
[ 0.97705823, 1.01751506],
[-0.10872522, 0.91580496],
[-1.27087508, -0.19146954],
[ 0.87616438, 0.97685435],
[ 0.89526079, 0.98651642],
[ 0.96521071, 1.0206381 ],
[ 1.0530243 , 0.93365071],
[ 0.994778 , 0.99724912],
[ 0.98176246, 1.03168734],
[ 0.74458014, 0.97066564],
[ 0.91748012, 0.9524803 ],
[-1.92749946, 0.07784549],
[ 0.7790389 , 0.95517882],
[ 0.11824333, 1.81065221],
[ 0.97490265, 0.95326328],
[ 1.00355225, 0.96521073],
[ 1.08398178, 0.97814922],
[-1.0749128 , 1.77825305],
[ 0.74886096, 1.39448605],
[-0.1950267 , 1.57178284],
[ 1.069671 , 0.97202065],
[ 0.85757149, 1.01910676],
[-1.02014343, 1.14016873],
[-1.25252256, 0.02906454],
[ 0.93948239, 1.44153932],
[ 1.28777891, 1.00133477],
[-1.7010408 , 0.0821629 ],
[ 0.8390028 , 0.97712472],
[ 0.99480479, 1.05717262],
[ 1.20707509, 0.97462669],
[-2.18786288, 1.4515569 ],
[ 1.16027197, 1.09086817],
[ 1.02771087, 0.9907291 ],
[ 0.71829704, 0.98817911],
[ 0.88605935, 0.99158972],
[ 1.03589316, 0.99557438],
[ 1.15489923, 0.95378093],
[ 1.0668616 , 0.99316509],
[ 1.04848333, 1.09471239],
[-1.05108888, -0.071106 ],
[-1.19977682, 1.49257613],
[ 1.12232276, 0.99293853],
[-0.36977293, 1.59581 ],
[-0.27363841, 1.46272407],
[ 1.18075342, 0.95907983],
[ 1.01486256, 0.97501177],
[-0.41533403, 1.72366429],
[-0.18337732, 2.26674615],
[ 1.06777804, 1.00982417],
[ 1.17411206, 0.98088369],
[ 0.95355889, 1.05238272],
[-0.39459255, 1.97600217],
[ 0.90103447, 0.94080238],
[ 0.87268023, 1.00348657],
[-1.93323667, 1.04826094],
[ 0.10460058, 1.16348717],
[-1.85815599, 1.32669461],
[ 0.90426972, 0.97521677],
[-0.58409513, 0.9870014 ],
[-1.74011619, -0.21416096],
[-1.51931589, 0.34938829],
[ 1.02631005, 0.99378866],
[ 1.02869184, 0.99995857],
[ 0.79862419, 1.00291807],
[-1.34714457, 0.78937109],
[-2.54273315, 0.96748855],
[-1.86729291, 0.37250653],
[-0.89843699, 0.43898384],
[-1.83077543, 0.43636701],
[-0.89141966, 1.57275938],
[-0.96662858, 0.8196104 ],
[ 0.87417528, 1.00989496],
[ 0.93997582, 0.95616278],
[-1.85338565, 1.00940185],
[ 0.89565224, 0.95460192],
[-0.76327569, 0.93526008],
[-1.78345269, 1.53378105],
[ 0.77408528, 1.01387371],
[-1.47669576, 1.43472266],
[ 1.19417792, 1.0440538 ],
[ 1.15595665, 0.96823244],
[ 0.84068935, 1.01792225],
[ 1.11747629, 1.05722511],
[ 0.23722569, 1.54396395],
[-1.24609914, 0.30094681],
[-0.18745572, 1.04657197],
[ 0.90607352, 0.96120285],
[-2.02612 , 0.44082817],
[ 0.8762596 , 1.00607109],
[ 0.98791921, 1.02441508],
[-0.65307666, 1.22493946],
[ 0.94162298, 1.28044258],
[ 0.8622878 , 0.99707326],
[-0.27590245, 1.1547649 ],
[ 0.99268975, 1.02885589],
[ 1.0635428 , 1.03445117],
[-2.1378345 , 0.62797163],
[-1.40559883, 0.26079323],
[ 1.07732353, 1.01373432],
[-1.74785838, 1.25425571],
[-0.51461996, 1.2583831 ],
[ 1.02632384, 1.00203908],
[ 0.84413823, 2.99872324],
[ 1.10319604, 0.9615482 ],
[ 0.95870127, 1.0461775 ],
[-1.61872726, 0.55348188],
[ 1.22219183, 1.00893646],
[-0.04807925, 1.69061295],
[-3.86851327, -0.36829707],
[-0.84318558, 0.71791949],
[ 0.95549697, 1.02457587],
[ 0.15484069, 0.80992914],
[ 1.1947279 , 1.02301068],
[-0.88323476, 1.52212056],
[ 0.82715121, 0.99856576],
[-0.97808876, 2.01262021],
[-1.66906556, 0.70668215],
[ 1.29672679, 0.64929896],
[-0.45096669, 1.88364922],
[-2.70110985, 0.36698604],
[ 1.0795718 , 1.02443886],
[ 0.99150574, 0.98348741],
[-0.65205587, 1.86131659],
[-0.56754302, 1.87827013],
[ 1.12356817, 1.06645171],
[-2.72752499, 0.43018586],
[-2.74061782, -0.08021407],
[-0.3200331 , 1.09683115],
[ 1.0768664 , 1.0085724 ],
[-3.6325113 , 0.67221516],
[ 0.25830215, 0.79172286],
[ 1.07796662, 1.00493526],
[ 0.89606453, 0.98028498],
[-0.94518278, 1.52377526],
[ 0.90935946, 0.90695147],
[ 1.0148515 , 1.06783713],
[ 1.16686534, 0.99312304],
[-1.31640844, 0.32636521],
[-1.39485695, 0.47605367],
[-0.50763796, 2.04039346],
[-0.58489137, 1.16215935],
[-1.21643673, 1.16555051],
[-2.9813908 , -0.02123246],
[ 1.05056765, 1.0129612 ],
[ 1.01961575, 1.03539024],
[ 1.01227271, 0.96751672],
[ 0.12444867, 1.38342266],
[ 0.99713663, 0.96095512],
[ 0.98185855, 0.9941474 ],
[ 0.92998157, 1.03644759],
[-0.18646788, 2.02399395],
[-1.79776907, 0.97067984],
[-3.23433111, 0.54897531],
[-2.18617596, 0.33414794],
[ 1.16844027, 1.01821873],
[ 1.0428281 , 1.01154471],
[ 0.9159169 , 1.02463567],
[-1.3578118 , 0.67183832],
[-0.58824562, 1.08975919],
[ 1.01775857, 1.00733938],
[ 1.14847576, 1.01783862],
[-1.1115874 , 0.42278247],
[ 0.84772713, 0.99733494],
[ 1.00417018, 0.93763177],
[ 0.56134549, 1.20390517]]),
array([1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1,
0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0,
0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0,
0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0,
1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1,
1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1,
1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1,
1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0,
0, 1]))
# 构建神经网络模型
model = Sequential()
model.add(Dense(input_dim = 2, units = 1))
model.add(Activation('sigmoid'))
# 选定loss函数和优化器 model.compile(loss = 'binary_crossentropy', optimizer = 'sgd')
WARNING:tensorflow:From /home/nlp/anaconda3/lib/python3.7/site-packages/tensorflow/python/ops/nn_impl.py:180: add_dispatch_support..wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version. Instructions for updating: Use tf.where in 2.0, which has the same broadcast rule as np.where
# 训练过程
print("Training ----------")
for step in range(501):
cost = model.train_on_batch(X, y)
if step % 50 == 0:
print("After %d trainings, the cost: %f" % (step, cost))
Training ---------- After 0 trainings, the cost: 0.370295 After 50 trainings, the cost: 0.349558 After 100 trainings, the cost: 0.331982 After 150 trainings, the cost: 0.316872 After 200 trainings, the cost: 0.303725 After 250 trainings, the cost: 0.292170 After 300 trainings, the cost: 0.281923 After 350 trainings, the cost: 0.272767 After 400 trainings, the cost: 0.264532 After 450 trainings, the cost: 0.257079 After 500 trainings, the cost: 0.250299
# 测试过程
print("Testing ----------")
cost = model.evaluate(X, y, batch_size = 40)
print("test cost:", cost)
W, b = model.layers[0].get_weights()
print('Weights = ', W, '\nbiases = ', b)
Testing ---------- 200/200 [==============================] - 0s 53us/step test cost: 0.25016908943653104 Weights = [[-1.7198342 ] [-0.18482684]] biases = [0.47288144]
# 将训练结果绘出
Y_pred = model.predict(X)
# 将概率转化为类标号,概率在0-0.5时,转为0,概率在0.5-1时转为1
Y_pred = (Y_pred*2).astype('int')
# 绘制散点图 参数:x横轴 y纵轴 plt.subplot(2,1,1).scatter(X[:,0], X[:,1], c=Y_pred[:,0]) plt.subplot(2,1,2).scatter(X[:,0], X[:,1], c=y) plt.show()

| 序号 | 文章目录 | 直达链接 |
|---|---|---|
| 1 | 波士顿房价预测 | https://want595.blog.csdn.net/article/details/132181950 |
| 2 | 鸢尾花数据集分析 | https://want595.blog.csdn.net/article/details/132182057 |
| 3 | 特征处理 | https://want595.blog.csdn.net/article/details/132182165 |
| 4 | 交叉验证 | https://want595.blog.csdn.net/article/details/132182238 |
| 5 | 构造神经网络示例 | https://want595.blog.csdn.net/article/details/132182341 |
| 6 | 使用TensorFlow完成线性回归 | https://want595.blog.csdn.net/article/details/132182417 |
| 7 | 使用TensorFlow完成逻辑回归 | https://want595.blog.csdn.net/article/details/132182496 |
| 8 | TensorBoard案例 | https://want595.blog.csdn.net/article/details/132182584 |
| 9 | 使用Keras完成线性回归 | https://want595.blog.csdn.net/article/details/132182723 |
| 10 | 使用Keras完成逻辑回归 | https://want595.blog.csdn.net/article/details/132182795 |
| 11 | 使用Keras预训练模型完成猫狗识别 | https://want595.blog.csdn.net/article/details/132243928 |
| 12 | 使用PyTorch训练模型 | https://want595.blog.csdn.net/article/details/132243989 |
| 13 | 使用Dropout抑制过拟合 | https://want595.blog.csdn.net/article/details/132244111 |
| 14 | 使用CNN完成MNIST手写体识别(TensorFlow) | https://want595.blog.csdn.net/article/details/132244499 |
| 15 | 使用CNN完成MNIST手写体识别(Keras) | https://want595.blog.csdn.net/article/details/132244552 |
| 16 | 使用CNN完成MNIST手写体识别(PyTorch) | https://want595.blog.csdn.net/article/details/132244641 |
| 17 | 使用GAN生成手写数字样本 | https://want595.blog.csdn.net/article/details/132244764 |
| 18 | 自然语言处理 | https://want595.blog.csdn.net/article/details/132276591 |
上一篇:【C语言】学生考勤管理系统