Keras
Keras 是一个用 Python 编写的高级神经网络 API,最初由 François Chollet 创建,并于 2017 年合并到 TensorFlow 中,但依然可以作为一个独立的框架使用。它是一个开源的深度学习框架,运行在 TensorFlow、Theano 或 Microsoft Cognitive Toolkit (CNTK) 等深度学习后端之上。
你可以使用Keras快速进行模型训练,同时使用SwanLab进行实验跟踪与可视化。
1. 引入SwanLabLogger
python
from swanlab.integration.keras import SwanLabLogger
2. 与model.fit配合
首先初始化SwanLab:
python
swanlab.init( project="keras_mnist", experiment_name="mnist_example", description="Keras MNIST Example" )
然后,在model.fit
的callbacks
参数中添加SwanLabLogger
,即可完成集成:
python
model.fit(..., callbacks=[SwanLabLogger()])
3. 案例-MNIST
python
from swanlab.integration.keras import SwanLabLogger import tensorflow as tf import swanlab # Initialize SwanLab swanlab.init( project="keras_mnist", experiment_name="mnist_example", description="Keras MNIST Example" ) # Load and preprocess MNIST data (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data() x_train = x_train.reshape(-1, 28, 28, 1).astype('float32') / 255.0 x_test = x_test.reshape(-1, 28, 28, 1).astype('float32') / 255.0 # Build a simple CNN model model = tf.keras.Sequential([ tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)), tf.keras.layers.MaxPooling2D(), tf.keras.layers.Conv2D(64, 3, activation='relu'), tf.keras.layers.MaxPooling2D(), tf.keras.layers.Flatten(), tf.keras.layers.Dense(64, activation='relu'), tf.keras.layers.Dense(10, activation='softmax') ]) # Compile the model model.compile( optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'] ) # Train the model with SwanLabLogger model.fit( x_train, y_train, epochs=5, validation_data=(x_test, y_test), callbacks=[SwanLabLogger()] )
效果演示: