欢迎您访问365答案网,请分享给您的朋友!
生活常识 学习资料

统一常见机器学习模型的保存与加载

时间:2023-08-23
统一常见机器学习模型的保存与加载

支持以下模型:
thundersvm
lightgbm(gpu)
deepforest
xgboost(gpu)
sklearn

代码:

# Unified "save" helper for several common machine-learning model types, plus
# a small demo that round-trips each model through save/load.
#
# Output written by model_save(estimator, name, dir):
#   deepforest CascadeForestClassifier -> <dir>/<name>       (estimator.save)
#   lightgbm LGBMClassifier            -> <dir>/<name>.pkl   (joblib.dump)
#   sklearn RandomForestClassifier     -> <dir>/<name>.pkl   (joblib.dump)
#   thundersvm SVC                     -> <dir>/<name>.pkl   (SVC.save_to_file)
#   xgboost XGBClassifier              -> <dir>/<name>.json  (save_model)
#   sklearn LogisticRegression         -> <dir>/<name>.pkl   (joblib.dump)
#   sklearn MLPClassifier              -> <dir>/<name>.pkl   (joblib.dump)
import os

import deepforest
import joblib
import lightgbm
import numpy as np
import sklearn
# ``import sklearn`` alone does not load these submodules; import them
# explicitly so the sklearn.ensemble / sklearn.linear_model /
# sklearn.neural_network attribute lookups below cannot raise AttributeError.
import sklearn.ensemble
import sklearn.linear_model
import sklearn.neural_network
import thundersvm
import xgboost
from sklearn.datasets import make_classification


def _resolve_saver(estimator):
    """Map ``estimator`` to ``(label, file_suffix, save_fn)``.

    ``save_fn`` takes the full output path as its only argument. Type checks
    run in the same order as the original if/elif chain.

    Raises:
        TypeError: if ``estimator`` is not one of the supported types.
    """

    def dump(path):
        joblib.dump(estimator, path)

    if isinstance(estimator, deepforest.CascadeForestClassifier):
        return "deepforest", "", estimator.save
    if isinstance(estimator, lightgbm.sklearn.LGBMClassifier):
        return "lightgbm", ".pkl", dump
    if isinstance(estimator, sklearn.ensemble.RandomForestClassifier):
        return "randomforest", ".pkl", dump
    if isinstance(estimator, thundersvm.SVC):
        # Written by thundersvm's own save_to_file, not by pickle/joblib; the
        # ".pkl" suffix is kept only so existing files and load paths still work.
        return "svm", ".pkl", estimator.save_to_file
    if isinstance(estimator, xgboost.XGBClassifier):
        return "xgboost", ".json", estimator.save_model
    if isinstance(estimator, sklearn.linear_model.LogisticRegression):
        return "logistic", ".pkl", dump
    if isinstance(estimator, sklearn.neural_network.MLPClassifier):
        return "mlp", ".pkl", dump
    cls = type(estimator)
    raise TypeError(
        "estimator model saving error! unsupported type: "
        f"{cls.__module__}.{cls.__qualname__}"
    )


def model_save(estimator, save_name, save_path):
    """Save ``estimator`` under ``save_path`` with its library's own mechanism.

    Args:
        estimator: a fitted model of one of the supported types listed in the
            header comment.
        save_name: base file name, without extension.
        save_path: directory to save into; created (with parents) if missing.
            An empty string means the current working directory.

    Returns:
        The path that was written: ``save_path/save_name`` for deepforest,
        otherwise that path with ``.pkl`` or ``.json`` appended.

    Raises:
        TypeError: for unsupported estimator types. TypeError subclasses
            Exception, so existing ``except Exception`` handlers still catch it.
    """
    # Resolve first, so an unsupported estimator does not leave a new,
    # empty directory behind.
    label, suffix, save_fn = _resolve_saver(estimator)
    print("save " + label)
    # joblib.dump / save_model do not create missing directories, so without
    # this the very first save into a fresh "./modeltest" fails.
    if save_path:
        os.makedirs(save_path, exist_ok=True)
    target = os.path.join(save_path, save_name) + suffix
    save_fn(target)
    return target


def _demo(save_dir="./modeltest"):
    """Fit, save, reload and re-predict for several model types.

    Each section prints ``predict_proba`` from the in-memory model, then the
    same call on the reloaded model, so the two outputs can be compared.
    """
    train_x, train_y = make_classification(100, 5)
    test_x = np.array([[1, 3, 5, 10, 5], [7, 52, 14, 61, 2]])

    lgb = lightgbm.sklearn.LGBMClassifier()
    lgb.fit(train_x, train_y)
    saved = model_save(lgb, "lgb", save_dir)
    print(lgb.predict_proba(test_x))
    print(joblib.load(saved).predict_proba(test_x))
    print("=================================")

    svm = thundersvm.SVC(probability=True)
    svm.fit(train_x, train_y)
    saved = model_save(svm, "svm", save_dir)
    print(svm.predict_proba(test_x))
    loaded_svm = thundersvm.SVC(probability=True)
    loaded_svm.load_from_file(saved)
    print(loaded_svm.predict_proba(test_x))
    print("=================================")

    # NOTE(review): use_label_encoder is deprecated in newer xgboost releases
    # (may only trigger an "unused parameter" warning) — confirm and drop.
    xgb = xgboost.XGBClassifier(use_label_encoder=False)
    xgb.fit(train_x, train_y)
    saved = model_save(xgb, "xgb", save_dir)
    print(xgb.predict_proba(test_x))
    loaded_xgb = xgboost.XGBClassifier()
    loaded_xgb.load_model(saved)
    print(loaded_xgb.predict_proba(test_x))
    print("=================================")

    log = sklearn.linear_model.LogisticRegression()
    log.fit(train_x, train_y)
    saved = model_save(log, "log", save_dir)
    print(log.predict_proba(test_x))
    print(joblib.load(saved).predict_proba(test_x))


if __name__ == "__main__":
    _demo()

Copyright © 2016-2020 www.365daan.com All Rights Reserved. 365答案网 版权所有 备案号:

部分内容来自互联网,版权归原作者所有。如有冒犯,请联系我们,我们将在三个工作日内妥善处理。