Grid search hyperparameter tuning
http://scikit-learn.org/stable/modules/grid_search.html
1. Hyperparameter search methods: GridSearchCV and RandomizedSearchCV
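A minimal sketch of the difference between the two, assuming scikit-learn 0.22 or later and SciPy 1.4 or later (for scipy.stats.loguniform). The diabetes data, the alpha values and n_iter=20 are illustrative choices, not from the original. GridSearchCV evaluates every value in a list; RandomizedSearchCV samples n_iter candidates from a distribution.

```python
from scipy.stats import loguniform
from sklearn.datasets import load_diabetes
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV

X, y = load_diabetes(return_X_y=True)

# Exhaustive search: every alpha in the list is evaluated with 5-fold CV.
grid = GridSearchCV(Ridge(), param_grid={"alpha": [0.001, 0.01, 0.1, 1.0, 10.0]}, cv=5)
grid.fit(X, y)

# Randomized search: n_iter alphas are drawn from a log-uniform distribution
# on [1e-3, 1e1], so the search budget does not grow with the range size.
rand = RandomizedSearchCV(
    Ridge(),
    param_distributions={"alpha": loguniform(1e-3, 1e1)},
    n_iter=20,
    cv=5,
    random_state=0,  # fixes the sampled candidates for reproducibility
)
rand.fit(X, y)

print("grid  :", grid.best_params_, round(grid.best_score_, 4))
print("random:", rand.best_params_, round(rand.best_score_, 4))
```

Randomized search is usually preferred when some parameters are continuous or when the full grid would be too large to evaluate.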
2. Advanced parameter search techniques
2.1 Specifying an objective metric
By default, parameter search uses the score function of the estimator to evaluate a parameter setting. These are sklearn.metrics.accuracy_score for classification and sklearn.metrics.r2_score for regression.
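A small check of this default, assuming the iris data and an SVC classifier (illustrative choices not in the original): with scoring left as None, each candidate is scored by SVC.score, which is mean accuracy, so the results should match an explicit scoring="accuracy" and a plain cross_val_score run on the same folds.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV, cross_val_score
from sklearn.svm import SVC

X, y = load_iris(return_X_y=True)
param_grid = {"C": [0.1, 1, 10]}

# scoring=None (the default): each candidate is scored with SVC.score.
default_search = GridSearchCV(SVC(), param_grid=param_grid, cv=5).fit(X, y)

# Explicit accuracy scoring; same stratified 5-fold splits, so same numbers.
acc_search = GridSearchCV(SVC(), param_grid=param_grid, cv=5, scoring="accuracy").fit(X, y)

# Manual cross-validation of the C=1 candidate for comparison.
manual = cross_val_score(SVC(C=1), X, y, cv=5).mean()

print(default_search.cv_results_["mean_test_score"])
print(acc_search.cv_results_["mean_test_score"])
print(manual)
```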
2.2 Specifying multiple metrics for evaluation
Multimetric scoring can be specified either as a list of predefined scorer names (strings) or as a dict mapping each scorer name to a scorer function and/or a predefined scorer name.
http://scikit-learn.org/stable/modules/model_evaluation.html#multimetric-scoring
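A sketch of the dict form, assuming Ridge on the diabetes data (an illustrative choice). The display names "r2", "neg_mse" and "mae" are arbitrary labels chosen here; the values mix predefined scorer names and a make_scorer object. A list such as ["r2", "neg_mean_squared_error"] would work the same way. With several metrics, refit must name the metric used to choose best_params_.

```python
from sklearn.datasets import load_diabetes
from sklearn.linear_model import Ridge
from sklearn.metrics import make_scorer, mean_absolute_error
from sklearn.model_selection import GridSearchCV

X, y = load_diabetes(return_X_y=True)

# Display name -> predefined scorer name or scorer object.
scoring = {
    "r2": "r2",
    "neg_mse": "neg_mean_squared_error",
    # Losses are negated by greater_is_better=False so that higher is better.
    "mae": make_scorer(mean_absolute_error, greater_is_better=False),
}

search = GridSearchCV(
    Ridge(),
    param_grid={"alpha": [0.01, 0.1, 1.0]},
    scoring=scoring,
    refit="r2",  # metric used to pick best_params_ and refit best_estimator_
    cv=5,
)
search.fit(X, y)

# cv_results_ gets one mean_test_<name> column per metric.
for name in scoring:
    print(name, search.cv_results_["mean_test_" + name])
print("best by r2:", search.best_params_)
```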
2.3 Composite estimators and parameter spaces: the Pipeline approach
http://scikit-learn.org/stable/modules/pipeline.html#pipeline
>>> from sklearn.pipeline import Pipeline
>>> from sklearn.svm import SVC
>>> from sklearn.decomposition import PCA
>>> estimators = [('reduce_dim', PCA()), ('clf', SVC())]
>>> pipe = Pipeline(estimators)
>>> pipe # check pipe
Pipeline(memory=None,
         steps=[('reduce_dim', PCA(copy=True,...)),
                ('clf', SVC(C=1.0,...))])
>>> from sklearn.pipeline import make_pipeline
>>> from sklearn.naive_bayes import MultinomialNB
>>> from sklearn.preprocessing import Binarizer
>>> make_pipeline(Binarizer(), MultinomialNB())
Pipeline(memory=None,
         steps=[('binarizer', Binarizer(copy=True, threshold=0.0)),
                ('multinomialnb', MultinomialNB(alpha=1.0,
                                                class_prior=None,
                                                fit_prior=True))])
>>> pipe.set_params(clf__C=10)  # set parameter C of the 'clf' step
>>> from sklearn.model_selection import GridSearchCV
>>> param_grid = dict(reduce_dim__n_components=[2, 5, 10],
... clf__C=[0.1, 10, 100])
>>> grid_search = GridSearchCV(pipe, param_grid=param_grid)
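The session above builds the search but never fits it. A runnable sketch, assuming the digits dataset (an illustrative choice: it has 64 features, so n_components up to 10 is valid). It shows the step-name__parameter naming rule and then fits the same grid over PCA and SVC.

```python
from sklearn.datasets import load_digits
from sklearn.decomposition import PCA
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.svm import SVC

X, y = load_digits(return_X_y=True)

pipe = Pipeline([("reduce_dim", PCA()), ("clf", SVC())])

# Step parameters are exposed as <step name>__<parameter name>.
print("clf__C" in pipe.get_params(), "reduce_dim__n_components" in pipe.get_params())

# 3 x 3 = 9 candidates, each cross-validated with the default 5 folds.
param_grid = dict(reduce_dim__n_components=[2, 5, 10], clf__C=[0.1, 10, 100])
grid_search = GridSearchCV(pipe, param_grid=param_grid)
grid_search.fit(X, y)

print(grid_search.best_params_)
print(round(grid_search.best_score_, 3))

# The refitted best pipeline (PCA + SVC) predicts directly on raw features.
print(grid_search.best_estimator_.predict(X[:5]))
```

Because PCA is refitted inside each training fold, the dimensionality reduction never sees the validation fold, which is one reason to tune the whole pipeline rather than preprocessing the data once up front.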
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Sep 5 10:22:07 2017
@author: xinpingbao
"""
import numpy as np
from sklearn import datasets
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import make_scorer
# load the diabetes datasets
dataset = datasets.load_diabetes()
X = dataset.data
y = dataset.target
# prepare a range of alpha values to test
alphas = np.array([1,0.1,0.01,0.001,0.0001,0])
# create and fit a ridge regression model, testing each alpha
model = Ridge()
grid = GridSearchCV(estimator=model, param_grid=dict(alpha=alphas))  # default scoring: Ridge.score, i.e. sklearn.metrics.r2_score
# grid = GridSearchCV(estimator=model, param_grid=dict(alpha=alphas), scoring='neg_mean_squared_error')  # MSE, negated so that higher is better
grid.fit(X, y)
print(grid)
# summarize the results of the grid search
print(grid.best_score_)
print(grid.best_estimator_.alpha)
############################ Custom error / score functions ############################
model = Ridge()
alphas = np.array([1,0.1,0.01,0.001,0.0001,0])
param_grid1 = dict(alpha=alphas)
def my_mse_error(real, pred):
    # Asymmetric weighted MSE: w_high weights over-predictions (real < pred),
    # w_low weights under-predictions. With both set to 1.0 it is plain MSE.
    w_high = 1.0
    w_low = 1.0
    weight = w_high * (real - pred < 0.0) + w_low * (real - pred >= 0.0)
    mse = (np.sum((real - pred)**2 * weight) / float(len(real)))
    return mse
def my_r2_score(y_true, y_pred):
    # Hand-written coefficient of determination: 1 - SS_res / SS_tot.
    nume = np.sum((y_true - y_pred) ** 2)
    deno = np.sum((y_true - np.average(y_true, axis=0)) ** 2)
    r2_score = 1 - (nume / deno)
    return r2_score
error_score1 = make_scorer(my_mse_error, greater_is_better=False)  # lower error is better; the scorer returns -error
error_score2 = make_scorer(my_r2_score, greater_is_better=True)  # higher R^2 is better
# custom_scoring = {'weighted_MSE': error_score1, 'r2': error_score2}  # multi-metric variant; also needs refit='r2'
grid_search = GridSearchCV(model, param_grid=param_grid1, scoring=error_score2, n_jobs=-1)  # built-in alternative: scoring='neg_mean_absolute_error'
grid_result = grid_search.fit(X,y)
# summarize results
print("Best: %f using %s" % (grid_result.best_score_, grid_result.best_params_))
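A quick sanity check of the custom scorers, as a sketch: the weights w_high=2.0 and w_low=1.0 and the fixed alpha=0.1 are hypothetical values chosen here to make the asymmetry visible. It shows that make_scorer with greater_is_better=False returns the negated loss (so best_score_ from such a search is negative), and that the hand-written R^2 matches sklearn.metrics.r2_score.

```python
import numpy as np
from sklearn.datasets import load_diabetes
from sklearn.linear_model import Ridge
from sklearn.metrics import make_scorer, r2_score

X, y = load_diabetes(return_X_y=True)

def my_mse_error(real, pred, w_high=2.0, w_low=1.0):
    # Hypothetical weights: over-predictions (real - pred < 0) cost twice as much.
    weight = np.where(real - pred < 0.0, w_high, w_low)
    return np.sum((real - pred) ** 2 * weight) / float(len(real))

def my_r2_score(y_true, y_pred):
    nume = np.sum((y_true - y_pred) ** 2)
    deno = np.sum((y_true - np.average(y_true, axis=0)) ** 2)
    return 1 - nume / deno

model = Ridge(alpha=0.1).fit(X, y)
pred = model.predict(X)

# greater_is_better=False: the scorer returns -my_mse_error, so that
# GridSearchCV can still pick the candidate with the highest score.
loss_scorer = make_scorer(my_mse_error, greater_is_better=False)
print(my_mse_error(y, pred), loss_scorer(model, X, y))

# The hand-written R^2 agrees with the library implementation.
print(my_r2_score(y, pred), r2_score(y, pred))
```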