import tensorflow as tf
import numpy as np
from sklearn import metrics
from sklearn.datasets import load_svmlight_file
from sklearn.utils import shuffle # Define the placeholder
# Logistic regression trained one sample at a time with the FTRL optimizer
# (TensorFlow 1.x graph API), then scored by ROC AUC on a separate test file.

# Width of the feature vector. It is also passed to load_svmlight_file so the
# train and test matrices always match the placeholder shape, even when a
# file does not contain the highest feature index.
NUM_FEATURES = 12568
# Number of rows used for each periodic loss report during training.
size = 1000
# The placeholders are float32, where `1 - 1e-10` rounds to exactly 1.0 and
# would leave log(1 - y) unprotected. 1e-7 is large enough to survive the
# float32 rounding, so both logs stay finite.
CLIP_EPS = 1e-7

# Define the placeholders.
x = tf.placeholder("float", [None, NUM_FEATURES])
y_ = tf.placeholder("float", [None, 1])

# Define the variables of the model.
W = tf.Variable(tf.random_uniform([1, NUM_FEATURES], -1.0, 1.0))
b = tf.Variable(tf.zeros([1]))

# Predicted probability; built once and shared by the loss and the evaluation.
y_pred = tf.sigmoid(tf.matmul(x, tf.transpose(W)) + b)
# Clip y so that log(y) and log(1 - y) cannot become infinite.
y = tf.clip_by_value(y_pred, CLIP_EPS, 1 - CLIP_EPS)

# Minimize the negative log likelihood (summed over the fed rows, shape [1, 1]).
# NOTE(review): this form presumes labels are 0/1 - confirm the svmlight
# files are not encoded as -1/+1.
loss = (-tf.matmul(tf.transpose(y_), tf.log(y))
        - tf.matmul(tf.transpose(1 - y_), tf.log(1 - y)))
# Built once here so the progress report does not add a new op to the graph
# every time it runs.
mean_loss = loss / size
optimizer = tf.train.FtrlOptimizer(
    0.03, l1_regularization_strength=0.01, l2_regularization_strength=0.01)
train = optimizer.minimize(loss)

# Before starting, initialize the variables. We will 'run' this first.
init = tf.group(tf.global_variables_initializer(),
                tf.local_variables_initializer())

x_train, y_train = load_svmlight_file("./train_data_process",
                                      n_features=NUM_FEATURES)
x_train_new, y_train_new = shuffle(x_train, y_train)
num_train = x_train_new.shape[0]

# Launch the graph; the context manager closes the session when done.
with tf.Session() as sess:
    sess.run(init)

    for sample_index in range(num_train):
        # One optimizer step on a single (1, NUM_FEATURES) dense row.
        sess.run(train, {
            x: x_train_new[sample_index].toarray(),
            y_: y_train_new[sample_index:sample_index + 1].reshape([1, 1]),
        })
        # Every 200 steps, report the mean loss over the next `size` rows,
        # as long as a full window is still available.
        if sample_index % 200 == 0 and sample_index + size < num_train:
            stop = sample_index + size
            print(sample_index, sess.run(mean_loss, {
                x: x_train_new[sample_index:stop].toarray(),
                y_: y_train_new[sample_index:stop].reshape([size, 1]),
            }))

    # End: print the trained model. Fetched once after the loop instead of
    # on every step, which also keeps these names defined for an empty
    # training set.
    train_W = sess.run(W)
    train_b = sess.run(b)
    print('W:', train_W)
    print('b:', train_b)

    # (A tf.train.Saver checkpoint from "./model" could be restored here
    # instead of training in-process.)
    x_data, y_data = load_svmlight_file("./test_data_process",
                                        n_features=NUM_FEATURES)
    # y_pred depends only on x, so the labels do not need to be fed.
    y_pre = sess.run(y_pred, feed_dict={x: x_data.toarray()})

auc = metrics.roc_auc_score(y_data.reshape([-1, 1]), y_pre)
print(auc)
使用的是公司的模型训练数据，抽取了其中一部分，测试得到的 AUC 是 0.91。

最新文章

  1. 如何利用 JConsole观察分析Java程序的运行,进行排错调优
  2. 【opencv】轮廓相关
  3. java的I/O操作:文件的路径
  4. Android提高篇之自定义dialog实现processDialog“正在加载”效果、使用Animation实现图片旋转
  5. 纯CSS3实现的图片滑块程序,效果非常酷
  6. 【BZOJ】1053: [HAOI2007]反素数ant
  7. Oracle分页查询SQL实现
  8. FLEX 网格布局及响应式处理
  9. Android运用自己的标题栏
  10. 利用GeneratedKeyHolder获得新增数据主键值
  11. Netty 5.0源码分析之综述
  12. 时间处理之strtotime
  13. Linux显示一行显示列总计
  14. 证书,CSP与Openssl
  15. 使用chcache 缓存
  16. 没有内涵段子可以刷了,利用Python爬取段友之家贴吧图片和小视频(含源码)
  17. centos重启报错Umounting file systems:umount:/opt:device is busy
  18. [Android] 状态栏的一些认识
  19. ASP.net在网页上显示当前时间,利用AJAX不刷新网页
  20. 用ajax实现用户名的检测(JavaScript方法)

热门文章

  1. PAT 1144 The Missing Number[简单]
  2. 【基础算法】- 个人认为最快的 Fibonacci 程序
  3. day6-面向对象
  4. 网页采集利器 phpQuery
  5. web前端攻城狮整理的收藏夹
  6. 0728am thinkphp介绍
  7. hdp (ambari) 集成hue
  8. Web服务器端程序的实现
  9. PHP联接MySQL
  10. 【英语学习】How do I stop overthinking at night?