相关推荐recommended
【深度学习】实验10 使用Keras完成逻辑回归
作者:mmseoamin日期:2024-03-04

文章目录

  • 使用Keras完成逻辑回归
    • 1. 导入Keras库
    • 2. 生成数据集
    • 3. 构造神经网络模型
    • 4. 训练模型
    • 5. 测试模型
    • 6. 分析模型
    • 附:系列文章

      使用Keras完成逻辑回归

      Keras是一个开源的深度学习框架,能够高效地实现神经网络和深度学习模型。它由纽约大学的Francois Chollet开发,旨在提供一个简单易用的高层次API,以便开发人员能够快速搭建模型,从而节省时间和精力。Keras能够兼容各种底层深度学习框架,如TensorFlow、Theano和CNTK等。它已经成为深度学习领域中最受欢迎的框架之一,因为它既容易上手又具有灵活性。

      Keras的设计初衷是让深度学习变得更容易,更快速地实现从数据到模型的过程。在使用Keras进行深度学习时,您无需编写多行代码来定义神经层、激活函数、优化器和损失函数等超参数,只需一行代码即可。此外,Keras还提供了丰富的预训练模型,可用于处理图像分类、自然语言处理、文本分类和序列分析等任务,从而大大减少了深度学习模型的开发和训练时间。

      Keras还具有以下特点:

      1. 简单易用:Keras使用Python编写,提供了简单的API接口,让用户更加关注模型的设计和调整。

      2. 易于扩展:Keras可以兼容多种深度学习框架,如TensorFlow、Theano和CNTK等,能够利用它们的计算能力进行高效的训练和推理。

      3. 快速实现:Keras提供了多种预训练模型,无需从头开始开发模型,快速构建出高质量的深度学习模型。

      4. 支持多种语言:Keras不仅支持Python编程语言,还支持R和Java等其他编程语言。

      5. 开源社区:Keras在GitHub上有庞大的开源社区,拥有丰富的教程和示例,以便开发人员更好地学习和使用。

      总之,Keras是一个简单易用、高效实现深度学习模型的框架,能够大大提升深度学习模型的开发和实现效率。

      1. 导入Keras库

      # 导入相关库
      from keras.models import Sequential
      from keras.layers import Dense, Dropout, Activation, Flatten
      from sklearn import datasets
      import numpy as np
      import matplotlib.pyplot as plt
      import warnings
      warnings.filterwarnings('ignore')
      
      Using TensorFlow backend.
      

      2. 生成数据集

      # 生成样本数据集,两个特征列,两个分类二分类不需要onehot编码,直接将类别转换为0和1,分别代表正样本的概率。
      X, y = datasets.make_classification(n_samples = 200, n_features = 2, n_informative = 2, n_redundant = 0,
                                         n_repeated = 0, n_classes = 2, n_clusters_per_class = 1)
      X, y
      
         (array([[ 0.26364611,  0.77250816],
                 [ 0.91698377,  0.9802208 ],
                 [ 0.82634329,  0.9821341 ],
                 [-0.83833456,  0.88223515],
                 [ 1.11509338,  0.98632275],
                 [ 1.04196821,  0.97892474],
                 [ 0.77695264,  1.06320914],
                 [-2.16804253,  0.15267335],
                 [-1.96973867,  0.99244728],
                 [-1.35368845,  1.25840447],
                 [-0.52455148,  2.2351536 ],
                 [ 1.08554563,  1.03795405],
                 [ 0.88261697,  0.97793289],
                 [-1.03718795,  0.53830131],
                 [ 0.94628633,  0.96289949],
                 [ 1.16190683,  1.01806263],
                 [-2.07795249,  0.32376505],
                 [ 0.9370119 ,  1.01060097],
                 [ 0.92750449,  0.98713143],
                 [-0.35800128,  1.4498587 ],
                 [-0.96709704,  1.77632874],
                 [-0.55995817,  1.58782776],
                 [ 0.88919948,  1.00133032],
                 [ 1.16465115,  1.05117935],
                 [-1.6969619 ,  1.80088135],
                 [ 1.06292602,  1.04594288],
                 [-0.07792111,  0.98391779],
                 [-1.05188451,  1.26871626],
                 [-0.83494005,  0.93958161],
                 [ 1.10371115,  1.03558148],
                 [ 0.98674372,  1.04567265],
                 [-1.08345028,  1.18601788],
                 [-2.06487683,  0.17118219],
                 [ 1.02734931,  0.99326938],
                 [-0.11345441,  1.08515199],
                 [ 0.97705823,  1.01751506],
                 [-0.10872522,  0.91580496],
                 [-1.27087508, -0.19146954],
                 [ 0.87616438,  0.97685435],
                 [ 0.89526079,  0.98651642],
                 [ 0.96521071,  1.0206381 ],
                 [ 1.0530243 ,  0.93365071],
                 [ 0.994778  ,  0.99724912],
                 [ 0.98176246,  1.03168734],
                 [ 0.74458014,  0.97066564],
                 [ 0.91748012,  0.9524803 ],
                 [-1.92749946,  0.07784549],
                 [ 0.7790389 ,  0.95517882],
                 [ 0.11824333,  1.81065221],
                 [ 0.97490265,  0.95326328],
                 [ 1.00355225,  0.96521073],
                 [ 1.08398178,  0.97814922],
                 [-1.0749128 ,  1.77825305],
                 [ 0.74886096,  1.39448605],
                 [-0.1950267 ,  1.57178284],
                 [ 1.069671  ,  0.97202065],
                 [ 0.85757149,  1.01910676],
                 [-1.02014343,  1.14016873],
                 [-1.25252256,  0.02906454],
                 [ 0.93948239,  1.44153932],
                 [ 1.28777891,  1.00133477],
                 [-1.7010408 ,  0.0821629 ],
                 [ 0.8390028 ,  0.97712472],
                 [ 0.99480479,  1.05717262],
                 [ 1.20707509,  0.97462669],
                 [-2.18786288,  1.4515569 ],
                 [ 1.16027197,  1.09086817],
                 [ 1.02771087,  0.9907291 ],
                 [ 0.71829704,  0.98817911],
                 [ 0.88605935,  0.99158972],
                 [ 1.03589316,  0.99557438],
                 [ 1.15489923,  0.95378093],
                 [ 1.0668616 ,  0.99316509],
                 [ 1.04848333,  1.09471239],
                 [-1.05108888, -0.071106  ],
                 [-1.19977682,  1.49257613],
                 [ 1.12232276,  0.99293853],
                 [-0.36977293,  1.59581   ],
                 [-0.27363841,  1.46272407],
                 [ 1.18075342,  0.95907983],
                 [ 1.01486256,  0.97501177],
                 [-0.41533403,  1.72366429],
                 [-0.18337732,  2.26674615],
                 [ 1.06777804,  1.00982417],
                 [ 1.17411206,  0.98088369],
                 [ 0.95355889,  1.05238272],
                 [-0.39459255,  1.97600217],
                 [ 0.90103447,  0.94080238],
                 [ 0.87268023,  1.00348657],
                 [-1.93323667,  1.04826094],
                 [ 0.10460058,  1.16348717],
                 [-1.85815599,  1.32669461],
                 [ 0.90426972,  0.97521677],
                 [-0.58409513,  0.9870014 ],
                 [-1.74011619, -0.21416096],
                 [-1.51931589,  0.34938829],
                 [ 1.02631005,  0.99378866],
                 [ 1.02869184,  0.99995857],
                 [ 0.79862419,  1.00291807],
                 [-1.34714457,  0.78937109],
                 [-2.54273315,  0.96748855],
                 [-1.86729291,  0.37250653],
                 [-0.89843699,  0.43898384],
                 [-1.83077543,  0.43636701],
                 [-0.89141966,  1.57275938],
                 [-0.96662858,  0.8196104 ],
                 [ 0.87417528,  1.00989496],
                 [ 0.93997582,  0.95616278],
                 [-1.85338565,  1.00940185],
                 [ 0.89565224,  0.95460192],
                 [-0.76327569,  0.93526008],
                 [-1.78345269,  1.53378105],
                 [ 0.77408528,  1.01387371],
                 [-1.47669576,  1.43472266],
                 [ 1.19417792,  1.0440538 ],
                 [ 1.15595665,  0.96823244],
                 [ 0.84068935,  1.01792225],
                 [ 1.11747629,  1.05722511],
                 [ 0.23722569,  1.54396395],
                 [-1.24609914,  0.30094681],
                 [-0.18745572,  1.04657197],
                 [ 0.90607352,  0.96120285],
                 [-2.02612   ,  0.44082817],
                 [ 0.8762596 ,  1.00607109],
                 [ 0.98791921,  1.02441508],
                 [-0.65307666,  1.22493946],
                 [ 0.94162298,  1.28044258],
                 [ 0.8622878 ,  0.99707326],
                 [-0.27590245,  1.1547649 ],
                 [ 0.99268975,  1.02885589],
                 [ 1.0635428 ,  1.03445117],
                 [-2.1378345 ,  0.62797163],
                 [-1.40559883,  0.26079323],
                 [ 1.07732353,  1.01373432],
                 [-1.74785838,  1.25425571],
                 [-0.51461996,  1.2583831 ],
                 [ 1.02632384,  1.00203908],
                 [ 0.84413823,  2.99872324],
                 [ 1.10319604,  0.9615482 ],
                 [ 0.95870127,  1.0461775 ],
                 [-1.61872726,  0.55348188],
                 [ 1.22219183,  1.00893646],
                 [-0.04807925,  1.69061295],
                 [-3.86851327, -0.36829707],
                 [-0.84318558,  0.71791949],
                 [ 0.95549697,  1.02457587],
                 [ 0.15484069,  0.80992914],
                 [ 1.1947279 ,  1.02301068],
                 [-0.88323476,  1.52212056],
                 [ 0.82715121,  0.99856576],
                 [-0.97808876,  2.01262021],
                 [-1.66906556,  0.70668215],
                 [ 1.29672679,  0.64929896],
                 [-0.45096669,  1.88364922],
                 [-2.70110985,  0.36698604],
                 [ 1.0795718 ,  1.02443886],
                 [ 0.99150574,  0.98348741],
                 [-0.65205587,  1.86131659],
                 [-0.56754302,  1.87827013],
                 [ 1.12356817,  1.06645171],
                 [-2.72752499,  0.43018586],
                 [-2.74061782, -0.08021407],
                 [-0.3200331 ,  1.09683115],
                 [ 1.0768664 ,  1.0085724 ],
                 [-3.6325113 ,  0.67221516],
                 [ 0.25830215,  0.79172286],
                 [ 1.07796662,  1.00493526],
                 [ 0.89606453,  0.98028498],
                 [-0.94518278,  1.52377526],
                 [ 0.90935946,  0.90695147],
                 [ 1.0148515 ,  1.06783713],
                 [ 1.16686534,  0.99312304],
                 [-1.31640844,  0.32636521],
                 [-1.39485695,  0.47605367],
                 [-0.50763796,  2.04039346],
                 [-0.58489137,  1.16215935],
                 [-1.21643673,  1.16555051],
                 [-2.9813908 , -0.02123246],
                 [ 1.05056765,  1.0129612 ],
                 [ 1.01961575,  1.03539024],
                 [ 1.01227271,  0.96751672],
                 [ 0.12444867,  1.38342266],
                 [ 0.99713663,  0.96095512],
                 [ 0.98185855,  0.9941474 ],
                 [ 0.92998157,  1.03644759],
                 [-0.18646788,  2.02399395],
                 [-1.79776907,  0.97067984],
                 [-3.23433111,  0.54897531],
                 [-2.18617596,  0.33414794],
                 [ 1.16844027,  1.01821873],
                 [ 1.0428281 ,  1.01154471],
                 [ 0.9159169 ,  1.02463567],
                 [-1.3578118 ,  0.67183832],
                 [-0.58824562,  1.08975919],
                 [ 1.01775857,  1.00733938],
                 [ 1.14847576,  1.01783862],
                 [-1.1115874 ,  0.42278247],
                 [ 0.84772713,  0.99733494],
                 [ 1.00417018,  0.93763177],
                 [ 0.56134549,  1.20390517]]),
          array([1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1,
                 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0,
                 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1,
                 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0,
                 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0,
                 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1,
                 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1,
                 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1,
                 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0,
                 0, 1]))
      

      3. 构造神经网络模型

      # 构建神经网络模型
      model = Sequential()
      model.add(Dense(input_dim = 2, units = 1))
      model.add(Activation('sigmoid'))
      
      # 选定loss函数和优化器
      model.compile(loss = 'binary_crossentropy', optimizer = 'sgd')
      
      WARNING:tensorflow:From /home/nlp/anaconda3/lib/python3.7/site-packages/tensorflow/python/ops/nn_impl.py:180: add_dispatch_support..wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
      Instructions for updating:
      Use tf.where in 2.0, which has the same broadcast rule as np.where
      

      4. 训练模型

      # 训练过程
      print("Training ----------")
      for step in range(501):
          cost = model.train_on_batch(X, y)
          if step % 50 == 0:
              print("After %d trainings, the cost: %f" % (step, cost))
      
      Training ----------
      After 0 trainings, the cost: 0.370295
      After 50 trainings, the cost: 0.349558
      After 100 trainings, the cost: 0.331982
      After 150 trainings, the cost: 0.316872
      After 200 trainings, the cost: 0.303725
      After 250 trainings, the cost: 0.292170
      After 300 trainings, the cost: 0.281923
      After 350 trainings, the cost: 0.272767
      After 400 trainings, the cost: 0.264532
      After 450 trainings, the cost: 0.257079
      After 500 trainings, the cost: 0.250299
      

      5. 测试模型

      # 测试过程
      print("Testing ----------")
      cost = model.evaluate(X, y, batch_size = 40)
      print("test cost:", cost)
      W, b = model.layers[0].get_weights()
      print('Weights = ', W, '\nbiases = ', b)
      
      Testing ----------
      200/200 [==============================] - 0s 53us/step
      test cost: 0.25016908943653104
      Weights =  [[-1.7198342 ]
       [-0.18482684]] 
      biases =  [0.47288144]
      

      6. 分析模型

      # 将训练结果绘出
      Y_pred = model.predict(X)
      # 将概率转化为类标号,概率在0-0.5时,转为0,概率在0.5-1时转为1
      Y_pred = (Y_pred*2).astype('int')  
      
      # 绘制散点图 参数:x横轴 y纵轴
      plt.subplot(2,1,1).scatter(X[:,0], X[:,1], c=Y_pred[:,0])
      plt.subplot(2,1,2).scatter(X[:,0], X[:,1], c=y)
      plt.show()
      

      【深度学习】实验10 使用Keras完成逻辑回归,1,第1张

      附:系列文章

      序号文章目录直达链接
      1波士顿房价预测https://want595.blog.csdn.net/article/details/132181950
      2鸢尾花数据集分析https://want595.blog.csdn.net/article/details/132182057
      3特征处理https://want595.blog.csdn.net/article/details/132182165
      4交叉验证https://want595.blog.csdn.net/article/details/132182238
      5构造神经网络示例https://want595.blog.csdn.net/article/details/132182341
      6使用TensorFlow完成线性回归https://want595.blog.csdn.net/article/details/132182417
      7使用TensorFlow完成逻辑回归https://want595.blog.csdn.net/article/details/132182496
      8TensorBoard案例https://want595.blog.csdn.net/article/details/132182584
      9使用Keras完成线性回归https://want595.blog.csdn.net/article/details/132182723
      10使用Keras完成逻辑回归https://want595.blog.csdn.net/article/details/132182795
      11使用Keras预训练模型完成猫狗识别https://want595.blog.csdn.net/article/details/132243928
      12使用PyTorch训练模型https://want595.blog.csdn.net/article/details/132243989
      13使用Dropout抑制过拟合https://want595.blog.csdn.net/article/details/132244111
      14使用CNN完成MNIST手写体识别(TensorFlow)https://want595.blog.csdn.net/article/details/132244499
      15使用CNN完成MNIST手写体识别(Keras)https://want595.blog.csdn.net/article/details/132244552
      16使用CNN完成MNIST手写体识别(PyTorch)https://want595.blog.csdn.net/article/details/132244641
      17使用GAN生成手写数字样本https://want595.blog.csdn.net/article/details/132244764
      18自然语言处理https://want595.blog.csdn.net/article/details/132276591