TensorFlow.js是一个用于在浏览器和Node.js中进行机器学习的JavaScript库。本文将详细讲解如何使用TensorFlow.js进行鸢尾花种类的预测,并提供两个示例说明。
示例1:使用预训练模型进行鸢尾花种类预测
以下是使用预训练模型进行鸢尾花种类预测的示例代码:
<!DOCTYPE html>
<html>
<head>
<title>TensorFlow.js鸢尾花种类预测</title>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@3.8.0"></script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/iris@3.0.0"></script>
</head>
<body>
<h1>TensorFlow.js鸢尾花种类预测</h1>
<div>
<label>Sepal Length:</label>
<input type="number" id="sepal-length" step="0.1" min="0" max="10">
</div>
<div>
<label>Sepal Width:</label>
<input type="number" id="sepal-width" step="0.1" min="0" max="10">
</div>
<div>
<label>Petal Length:</label>
<input type="number" id="petal-length" step="0.1" min="0" max="10">
</div>
<div>
<label>Petal Width:</label>
<input type="number" id="petal-width" step="0.1" min="0" max="10">
</div>
<button onclick="predict()">Predict</button>
<div id="result"></div>
<script>
async function predict() {
const sepalLength = document.getElementById('sepal-length').value;
const sepalWidth = document.getElementById('sepal-width').value;
const petalLength = document.getElementById('petal-length').value;
const petalWidth = document.getElementById('petal-width').value;
const model = await tf.loadLayersModel('https://storage.googleapis.com/tfjs-models/tfjs/iris_v1/model.json');
const input = tf.tensor2d([[sepalLength, sepalWidth, petalLength, petalWidth]]);
const prediction = model.predict(input);
const output = prediction.dataSync();
const species = ['Setosa', 'Versicolor', 'Virginica'][output.indexOf(Math.max(...output))];
document.getElementById('result').innerHTML = `Predicted Species: ${species}`;
}
</script>
</body>
</html>
在这个示例中,我们首先引入了TensorFlow.js和预训练模型iris_v1。然后,我们定义了一个简单的HTML页面,包含四个输入框和一个预测按钮。接着,我们使用tf.loadLayersModel()
方法加载预训练模型。最后,我们使用model.predict()
方法进行预测,并根据预测结果输出鸢尾花的种类。
示例2:使用自定义模型进行鸢尾花种类预测
以下是使用自定义模型进行鸢尾花种类预测的示例代码:
<!DOCTYPE html>
<html>
<head>
<title>TensorFlow.js鸢尾花种类预测</title>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@3.8.0"></script>
</head>
<body>
<h1>TensorFlow.js鸢尾花种类预测</h1>
<div>
<label>Sepal Length:</label>
<input type="number" id="sepal-length" step="0.1" min="0" max="10">
</div>
<div>
<label>Sepal Width:</label>
<input type="number" id="sepal-width" step="0.1" min="0" max="10">
</div>
<div>
<label>Petal Length:</label>
<input type="number" id="petal-length" step="0.1" min="0" max="10">
</div>
<div>
<label>Petal Width:</label>
<input type="number" id="petal-width" step="0.1" min="0" max="10">
</div>
<button onclick="predict()">Predict</button>
<div id="result"></div>
<script>
async function predict() {
const sepalLength = document.getElementById('sepal-length').value;
const sepalWidth = document.getElementById('sepal-width').value;
const petalLength = document.getElementById('petal-length').value;
const petalWidth = document.getElementById('petal-width').value;
const model = await tf.loadLayersModel('model.json');
const input = tf.tensor2d([[sepalLength, sepalWidth, petalLength, petalWidth]]);
const prediction = model.predict(input);
const output = prediction.dataSync();
const species = ['Setosa', 'Versicolor', 'Virginica'][output.indexOf(Math.max(...output))];
document.getElementById('result').innerHTML = `Predicted Species: ${species}`;
}
</script>
</body>
</html>
在这个示例中,我们首先定义了一个简单的HTML页面,包含四个输入框和一个预测按钮。接着,我们使用tf.loadLayersModel()
方法加载自定义模型。最后,我们使用model.predict()
方法进行预测,并根据预测结果输出鸢尾花的种类。
结语
以上是使用TensorFlow.js进行鸢尾花种类预测的完整攻略,包含了使用预训练模型和使用自定义模型的示例说明。在实际应用中,我们可以根据具体情况选择适合的方法来进行机器学习预测。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:TensorFlow.js机器学习预测鸢尾花种类 - Python技术站