import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras.datasets import mnist
# 设置中文字体
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# 加载 MNIST 数据集
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# 选择一个样本,比如第 0 张
idx = 0
img = x_train[idx]
label = y_train[idx]
fig, ax = plt.subplots(figsize=(8, 8), dpi=200)
# 显示图像(关闭插值)
ax.imshow(img, cmap="gray_r", interpolation="nearest")
# --- 像素网格(28x28) ---
ax.set_xticks(np.arange(-0.5, 28, 1), minor=True)
ax.set_yticks(np.arange(-0.5, 28, 1), minor=True)
ax.grid(which="minor", color="gray", linestyle="-", linewidth=0.3)
# --- 大网格(10x10) ---
ax.set_xticks(np.arange(-0.5, 28, 10), minor=False)
ax.set_yticks(np.arange(-0.5, 28, 10), minor=False)
ax.grid(which="major", color="red", linestyle="--", linewidth=1.0)
# 关闭坐标轴刻度
ax.tick_params(which="both", bottom=False, left=False,
labelbottom=False, labelleft=False)
ax.set_title(f"数字: {label}", fontsize=16)
# 在每个格子中间标注像素值(0-255)
for (j, i), val in np.ndenumerate(img):
ax.text(i, j, str(int(val)),
ha="center", va="center",
color="black" if val < 128 else "white", # 深色背景用白字
fontsize=6)
plt.show()
文章字数:1040
阅读时间: 3 分钟
等 人表示很赞
评论