Keras是一个开源的深度学习框架,能够高效地实现神经网络和深度学习模型。它由纽约大学的Francois Chollet开发,旨在提供一个简单易用的高层次API,以便开发人员能够快速搭建模型,从而节省时间和精力。Keras能够兼容各种底层深度学习框架,如TensorFlow、Theano和CNTK等。它已经成为深度学习领域中最受欢迎的框架之一,因为它既容易上手又具有灵活性。
Keras的设计初衷是让深度学习变得更容易,更快速地实现从数据到模型的过程。在使用Keras进行深度学习时,您无需编写多行代码来定义神经层、激活函数、优化器和损失函数等超参数,只需一行代码即可。此外,Keras还提供了丰富的预训练模型,可用于处理图像分类、自然语言处理、文本分类和序列分析等任务,从而大大减少了深度学习模型的开发和训练时间。
Keras还具有以下特点:
简单易用:Keras使用Python编写,提供了简单的API接口,让用户更加关注模型的设计和调整。
易于扩展:Keras可以兼容多种深度学习框架,如TensorFlow、Theano和CNTK等,能够利用它们的计算能力进行高效的训练和推理。
快速实现:Keras提供了多种预训练模型,无需从头开始开发模型,快速构建出高质量的深度学习模型。
支持多种语言:Keras不仅支持Python编程语言,还支持R和Java等其他编程语言。
开源社区:Keras在GitHub上有庞大的开源社区,拥有丰富的教程和示例,以便开发人员更好地学习和使用。
总之,Keras是一个简单易用、高效实现深度学习模型的框架,能够大大提升深度学习模型的开发和实现效率。
# 导入相关库 from keras.models import Sequential from keras.layers import Dense, Dropout, Activation, Flatten from sklearn import datasets import numpy as np import matplotlib.pyplot as plt import warnings warnings.filterwarnings('ignore')
Using TensorFlow backend.
# 生成样本数据集,两个特征列,两个分类二分类不需要onehot编码,直接将类别转换为0和1,分别代表正样本的概率。 X, y = datasets.make_classification(n_samples = 200, n_features = 2, n_informative = 2, n_redundant = 0, n_repeated = 0, n_classes = 2, n_clusters_per_class = 1) X, y
(array([[ 0.26364611, 0.77250816], [ 0.91698377, 0.9802208 ], [ 0.82634329, 0.9821341 ], [-0.83833456, 0.88223515], [ 1.11509338, 0.98632275], [ 1.04196821, 0.97892474], [ 0.77695264, 1.06320914], [-2.16804253, 0.15267335], [-1.96973867, 0.99244728], [-1.35368845, 1.25840447], [-0.52455148, 2.2351536 ], [ 1.08554563, 1.03795405], [ 0.88261697, 0.97793289], [-1.03718795, 0.53830131], [ 0.94628633, 0.96289949], [ 1.16190683, 1.01806263], [-2.07795249, 0.32376505], [ 0.9370119 , 1.01060097], [ 0.92750449, 0.98713143], [-0.35800128, 1.4498587 ], [-0.96709704, 1.77632874], [-0.55995817, 1.58782776], [ 0.88919948, 1.00133032], [ 1.16465115, 1.05117935], [-1.6969619 , 1.80088135], [ 1.06292602, 1.04594288], [-0.07792111, 0.98391779], [-1.05188451, 1.26871626], [-0.83494005, 0.93958161], [ 1.10371115, 1.03558148], [ 0.98674372, 1.04567265], [-1.08345028, 1.18601788], [-2.06487683, 0.17118219], [ 1.02734931, 0.99326938], [-0.11345441, 1.08515199], [ 0.97705823, 1.01751506], [-0.10872522, 0.91580496], [-1.27087508, -0.19146954], [ 0.87616438, 0.97685435], [ 0.89526079, 0.98651642], [ 0.96521071, 1.0206381 ], [ 1.0530243 , 0.93365071], [ 0.994778 , 0.99724912], [ 0.98176246, 1.03168734], [ 0.74458014, 0.97066564], [ 0.91748012, 0.9524803 ], [-1.92749946, 0.07784549], [ 0.7790389 , 0.95517882], [ 0.11824333, 1.81065221], [ 0.97490265, 0.95326328], [ 1.00355225, 0.96521073], [ 1.08398178, 0.97814922], [-1.0749128 , 1.77825305], [ 0.74886096, 1.39448605], [-0.1950267 , 1.57178284], [ 1.069671 , 0.97202065], [ 0.85757149, 1.01910676], [-1.02014343, 1.14016873], [-1.25252256, 0.02906454], [ 0.93948239, 1.44153932], [ 1.28777891, 1.00133477], [-1.7010408 , 0.0821629 ], [ 0.8390028 , 0.97712472], [ 0.99480479, 1.05717262], [ 1.20707509, 0.97462669], [-2.18786288, 1.4515569 ], [ 1.16027197, 1.09086817], [ 1.02771087, 0.9907291 ], [ 0.71829704, 0.98817911], [ 0.88605935, 0.99158972], [ 1.03589316, 0.99557438], [ 1.15489923, 0.95378093], [ 1.0668616 , 0.99316509], [ 1.04848333, 1.09471239], [-1.05108888, -0.071106 ], [-1.19977682, 1.49257613], [ 1.12232276, 0.99293853], [-0.36977293, 1.59581 ], [-0.27363841, 1.46272407], [ 1.18075342, 0.95907983], [ 1.01486256, 0.97501177], [-0.41533403, 1.72366429], [-0.18337732, 2.26674615], [ 1.06777804, 1.00982417], [ 1.17411206, 0.98088369], [ 0.95355889, 1.05238272], [-0.39459255, 1.97600217], [ 0.90103447, 0.94080238], [ 0.87268023, 1.00348657], [-1.93323667, 1.04826094], [ 0.10460058, 1.16348717], [-1.85815599, 1.32669461], [ 0.90426972, 0.97521677], [-0.58409513, 0.9870014 ], [-1.74011619, -0.21416096], [-1.51931589, 0.34938829], [ 1.02631005, 0.99378866], [ 1.02869184, 0.99995857], [ 0.79862419, 1.00291807], [-1.34714457, 0.78937109], [-2.54273315, 0.96748855], [-1.86729291, 0.37250653], [-0.89843699, 0.43898384], [-1.83077543, 0.43636701], [-0.89141966, 1.57275938], [-0.96662858, 0.8196104 ], [ 0.87417528, 1.00989496], [ 0.93997582, 0.95616278], [-1.85338565, 1.00940185], [ 0.89565224, 0.95460192], [-0.76327569, 0.93526008], [-1.78345269, 1.53378105], [ 0.77408528, 1.01387371], [-1.47669576, 1.43472266], [ 1.19417792, 1.0440538 ], [ 1.15595665, 0.96823244], [ 0.84068935, 1.01792225], [ 1.11747629, 1.05722511], [ 0.23722569, 1.54396395], [-1.24609914, 0.30094681], [-0.18745572, 1.04657197], [ 0.90607352, 0.96120285], [-2.02612 , 0.44082817], [ 0.8762596 , 1.00607109], [ 0.98791921, 1.02441508], [-0.65307666, 1.22493946], [ 0.94162298, 1.28044258], [ 0.8622878 , 0.99707326], [-0.27590245, 1.1547649 ], [ 0.99268975, 1.02885589], [ 1.0635428 , 1.03445117], [-2.1378345 , 0.62797163], [-1.40559883, 0.26079323], [ 1.07732353, 1.01373432], [-1.74785838, 1.25425571], [-0.51461996, 1.2583831 ], [ 1.02632384, 1.00203908], [ 0.84413823, 2.99872324], [ 1.10319604, 0.9615482 ], [ 0.95870127, 1.0461775 ], [-1.61872726, 0.55348188], [ 1.22219183, 1.00893646], [-0.04807925, 1.69061295], [-3.86851327, -0.36829707], [-0.84318558, 0.71791949], [ 0.95549697, 1.02457587], [ 0.15484069, 0.80992914], [ 1.1947279 , 1.02301068], [-0.88323476, 1.52212056], [ 0.82715121, 0.99856576], [-0.97808876, 2.01262021], [-1.66906556, 0.70668215], [ 1.29672679, 0.64929896], [-0.45096669, 1.88364922], [-2.70110985, 0.36698604], [ 1.0795718 , 1.02443886], [ 0.99150574, 0.98348741], [-0.65205587, 1.86131659], [-0.56754302, 1.87827013], [ 1.12356817, 1.06645171], [-2.72752499, 0.43018586], [-2.74061782, -0.08021407], [-0.3200331 , 1.09683115], [ 1.0768664 , 1.0085724 ], [-3.6325113 , 0.67221516], [ 0.25830215, 0.79172286], [ 1.07796662, 1.00493526], [ 0.89606453, 0.98028498], [-0.94518278, 1.52377526], [ 0.90935946, 0.90695147], [ 1.0148515 , 1.06783713], [ 1.16686534, 0.99312304], [-1.31640844, 0.32636521], [-1.39485695, 0.47605367], [-0.50763796, 2.04039346], [-0.58489137, 1.16215935], [-1.21643673, 1.16555051], [-2.9813908 , -0.02123246], [ 1.05056765, 1.0129612 ], [ 1.01961575, 1.03539024], [ 1.01227271, 0.96751672], [ 0.12444867, 1.38342266], [ 0.99713663, 0.96095512], [ 0.98185855, 0.9941474 ], [ 0.92998157, 1.03644759], [-0.18646788, 2.02399395], [-1.79776907, 0.97067984], [-3.23433111, 0.54897531], [-2.18617596, 0.33414794], [ 1.16844027, 1.01821873], [ 1.0428281 , 1.01154471], [ 0.9159169 , 1.02463567], [-1.3578118 , 0.67183832], [-0.58824562, 1.08975919], [ 1.01775857, 1.00733938], [ 1.14847576, 1.01783862], [-1.1115874 , 0.42278247], [ 0.84772713, 0.99733494], [ 1.00417018, 0.93763177], [ 0.56134549, 1.20390517]]), array([1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1]))
# 构建神经网络模型 model = Sequential() model.add(Dense(input_dim = 2, units = 1)) model.add(Activation('sigmoid'))
# 选定loss函数和优化器 model.compile(loss = 'binary_crossentropy', optimizer = 'sgd')
WARNING:tensorflow:From /home/nlp/anaconda3/lib/python3.7/site-packages/tensorflow/python/ops/nn_impl.py:180: add_dispatch_support..wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version. Instructions for updating: Use tf.where in 2.0, which has the same broadcast rule as np.where
# 训练过程 print("Training ----------") for step in range(501): cost = model.train_on_batch(X, y) if step % 50 == 0: print("After %d trainings, the cost: %f" % (step, cost))
Training ---------- After 0 trainings, the cost: 0.370295 After 50 trainings, the cost: 0.349558 After 100 trainings, the cost: 0.331982 After 150 trainings, the cost: 0.316872 After 200 trainings, the cost: 0.303725 After 250 trainings, the cost: 0.292170 After 300 trainings, the cost: 0.281923 After 350 trainings, the cost: 0.272767 After 400 trainings, the cost: 0.264532 After 450 trainings, the cost: 0.257079 After 500 trainings, the cost: 0.250299
# 测试过程 print("Testing ----------") cost = model.evaluate(X, y, batch_size = 40) print("test cost:", cost) W, b = model.layers[0].get_weights() print('Weights = ', W, '\nbiases = ', b)
Testing ---------- 200/200 [==============================] - 0s 53us/step test cost: 0.25016908943653104 Weights = [[-1.7198342 ] [-0.18482684]] biases = [0.47288144]
# 将训练结果绘出 Y_pred = model.predict(X) # 将概率转化为类标号,概率在0-0.5时,转为0,概率在0.5-1时转为1 Y_pred = (Y_pred*2).astype('int')
# 绘制散点图 参数:x横轴 y纵轴 plt.subplot(2,1,1).scatter(X[:,0], X[:,1], c=Y_pred[:,0]) plt.subplot(2,1,2).scatter(X[:,0], X[:,1], c=y) plt.show()
序号 | 文章目录 | 直达链接 |
---|---|---|
1 | 波士顿房价预测 | https://want595.blog.csdn.net/article/details/132181950 |
2 | 鸢尾花数据集分析 | https://want595.blog.csdn.net/article/details/132182057 |
3 | 特征处理 | https://want595.blog.csdn.net/article/details/132182165 |
4 | 交叉验证 | https://want595.blog.csdn.net/article/details/132182238 |
5 | 构造神经网络示例 | https://want595.blog.csdn.net/article/details/132182341 |
6 | 使用TensorFlow完成线性回归 | https://want595.blog.csdn.net/article/details/132182417 |
7 | 使用TensorFlow完成逻辑回归 | https://want595.blog.csdn.net/article/details/132182496 |
8 | TensorBoard案例 | https://want595.blog.csdn.net/article/details/132182584 |
9 | 使用Keras完成线性回归 | https://want595.blog.csdn.net/article/details/132182723 |
10 | 使用Keras完成逻辑回归 | https://want595.blog.csdn.net/article/details/132182795 |
11 | 使用Keras预训练模型完成猫狗识别 | https://want595.blog.csdn.net/article/details/132243928 |
12 | 使用PyTorch训练模型 | https://want595.blog.csdn.net/article/details/132243989 |
13 | 使用Dropout抑制过拟合 | https://want595.blog.csdn.net/article/details/132244111 |
14 | 使用CNN完成MNIST手写体识别(TensorFlow) | https://want595.blog.csdn.net/article/details/132244499 |
15 | 使用CNN完成MNIST手写体识别(Keras) | https://want595.blog.csdn.net/article/details/132244552 |
16 | 使用CNN完成MNIST手写体识别(PyTorch) | https://want595.blog.csdn.net/article/details/132244641 |
17 | 使用GAN生成手写数字样本 | https://want595.blog.csdn.net/article/details/132244764 |
18 | 自然语言处理 | https://want595.blog.csdn.net/article/details/132276591 |
上一篇:【C语言】学生考勤管理系统