新聞中心
這里有您想知道的互聯(lián)網(wǎng)營(yíng)銷(xiāo)解決方案
如何實(shí)現(xiàn)keras訓(xùn)練淺層卷積網(wǎng)絡(luò)并保存和加載模型-創(chuàng)新互聯(lián)
創(chuàng)新互聯(lián)www.cdcxhl.cn八線(xiàn)動(dòng)態(tài)BGP香港云服務(wù)器提供商,新人活動(dòng)買(mǎi)多久送多久,劃算不套路!
不懂如何實(shí)現(xiàn)keras訓(xùn)練淺層卷積網(wǎng)絡(luò)并保存和加載模型?其實(shí)想解決這個(gè)問(wèn)題也不難,下面讓小編帶著大家一起學(xué)習(xí)怎么去解決,希望大家閱讀完這篇文章后大所收獲。
這里我們使用keras定義簡(jiǎn)單的神經(jīng)網(wǎng)絡(luò)全連接層訓(xùn)練MNIST數(shù)據(jù)集和cifar10數(shù)據(jù)集:
keras_mnist.py
from sklearn.preprocessing import LabelBinarizer from sklearn.model_selection import train_test_split from sklearn.metrics import classification_report from keras.models import Sequential from keras.layers.core import Dense from keras.optimizers import SGD from sklearn import datasets import matplotlib.pyplot as plt import numpy as np import argparse # 命令行參數(shù)運(yùn)行 ap = argparse.ArgumentParser() ap.add_argument("-o", "--output", required=True, help="path to the output loss/accuracy plot") args =vars(ap.parse_args()) # 加載數(shù)據(jù)MNIST,然后歸一化到【0,1】,同時(shí)使用75%做訓(xùn)練,25%做測(cè)試 print("[INFO] loading MNIST (full) dataset") dataset = datasets.fetch_mldata("MNIST Original", data_home="/home/king/test/python/train/pyimagesearch/nn/data/") data = dataset.data.astype("float") / 255.0 (trainX, testX, trainY, testY) = train_test_split(data, dataset.target, test_size=0.25) # 將label進(jìn)行one-hot編碼 lb = LabelBinarizer() trainY = lb.fit_transform(trainY) testY = lb.transform(testY) # keras定義網(wǎng)絡(luò)結(jié)構(gòu)784--256--128--10 model = Sequential() model.add(Dense(256, input_shape=(784,), activation="relu")) model.add(Dense(128, activation="relu")) model.add(Dense(10, activation="softmax")) # 開(kāi)始訓(xùn)練 print("[INFO] training network...") # 0.01的學(xué)習(xí)率 sgd = SGD(0.01) # 交叉驗(yàn)證 model.compile(loss="categorical_crossentropy", optimizer=sgd, metrics=['accuracy']) H = model.fit(trainX, trainY, validation_data=(testX, testY), epochs=100, batch_size=128) # 測(cè)試模型和評(píng)估 print("[INFO] evaluating network...") predictions = model.predict(testX, batch_size=128) print(classification_report(testY.argmax(axis=1), predictions.argmax(axis=1), target_names=[str(x) for x in lb.classes_])) # 保存可視化訓(xùn)練結(jié)果 plt.style.use("ggplot") plt.figure() plt.plot(np.arange(0, 100), H.history["loss"], label="train_loss") plt.plot(np.arange(0, 100), H.history["val_loss"], label="val_loss") plt.plot(np.arange(0, 100), H.history["acc"], label="train_acc") plt.plot(np.arange(0, 100), H.history["val_acc"], label="val_acc") plt.title("Training Loss and Accuracy") plt.xlabel("# Epoch") plt.ylabel("Loss/Accuracy") plt.legend() plt.savefig(args["output"])
當(dāng)前名稱(chēng):如何實(shí)現(xiàn)keras訓(xùn)練淺層卷積網(wǎng)絡(luò)并保存和加載模型-創(chuàng)新互聯(lián)
網(wǎng)頁(yè)網(wǎng)址:http://www.dlmjj.cn/article/coojhs.html