TensorFlow样例一
2024-09-07 00:27:54
假设原函数为 f(x) = 5x^2 + 3,为了估计出这个函数,定义参数未知的函数g(x, w) = w0 x^2 + w1 x + w2,现要找出适合的w使g(x, w) ≈ f(x)。将这个问题转化为求解参数w使得损失函数L(w) = ∑ (f(x) - g(x, w))^2最小,求解过程使用了随机梯度下降(Stochastic Gradient Descent)。求解问题的代码如下:
import numpy as np import tensorflow as tf # Placeholders are used to feed values from python to TensorFlow ops. We define # two placeholders, one for input feature x, and one for output y. x = tf.placeholder(tf.float32) y = tf.placeholder(tf.float32) # Assuming we know that the desired function is a polynomial of 2nd degree, we # allocate a vector of size 3 to hold the coefficients. The variable will be # automatically initialized with random noise. w = tf.get_variable("w", shape=[3, 1]) # We define yhat to be our estimate of y. f = tf.stack([tf.square(x), x, tf.ones_like(x)], 1) yhat = tf.squeeze(tf.matmul(f, w), 1) # The loss is defined to be the l2 distance between our estimate of y and its # true value. We also added a shrinkage term, to ensure the resulting weights # would be small. loss = tf.nn.l2_loss(yhat - y) + 0.1 * tf.nn.l2_loss(w) # We use the Adam optimizer with learning rate set to 0.1 to minimize the loss. train_op = tf.train.AdamOptimizer(0.1).minimize(loss) def generate_data(): x_val = np.random.uniform(-10.0, 10.0, size=100) y_val = 5 * np.square(x_val) + 3 return x_val, y_val sess = tf.Session() # Since we are using variables we first need to initialize them. sess.run(tf.global_variables_initializer()) for _ in range(1000): x_val, y_val = generate_data() _, loss_val = sess.run([train_op, loss], {x: x_val, y: y_val}) print(loss_val) print(sess.run([w]))
求解过程如下:
4380421.0 3147655.5 4625718.5 3493661.0 3061016.0 3057624.5 3104206.2 …… 103.7392 98.461266 113.29772 104.56809 89.75495 …… 17.354445 17.66056 17.716873 18.782757 16.015532 [array([[4.9863739e+00], [6.9120852e-04], [3.8031762e+00]], dtype=float32)]
最新文章
- SpringMvc相关配置的作用
- 1、linux网络服务实验 用PuTTY连接Linux
- struts2学习记录
- 更便捷的Android多渠道打包方式
- webpack 教程资源收集
- Android dex分包方案
- Python中__init__方法/__name__系统变量讲解
- 【ZZ】MySql语句大全:创建、授权、查询、修改等
- Gazebo Ros入门
- e2e 自动化集成测试 架构 实例 WebStorm Node.js Mocha WebDriverIO Selenium Step by step (三) SqlServer数据库的访问
- mapreduce实现全局排序
- What is Windows Clustering
- warning: shared library text segment is not shareable
- myeclipse2015复制项目需要修改的地方
- Cesium基础使用介绍
- 2>;&;1的意思
- kubernetes集群搭建(3):master节点安装
- Codeforces 757 C Felicity is Coming!
- An error occurred (500 Error)
- 详解UILabel的adjustsFontSizeToFitWidth值