import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras.datasets import mnist

# Configure fonts so the Chinese title renders. SimHei is tried first; the
# other CJK families are fallbacks for machines that do not ship SimHei, and
# the previously configured sans-serif families are kept at the end so
# matplotlib can still fall back to its defaults.
plt.rcParams['font.sans-serif'] = [
    'SimHei', 'Microsoft YaHei', 'PingFang SC',
    'Noto Sans CJK SC', 'WenQuanYi Micro Hei',
] + list(plt.rcParams['font.sans-serif'])
# Render minus signs as ASCII '-' instead of U+2212, which some CJK fonts lack.
plt.rcParams['axes.unicode_minus'] = False

# Load MNIST; the call yields ((train images, train labels),
# (test images, test labels)).
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Choose which training sample to visualise (index 0 here).
idx = 0
img, label = x_train[idx], y_train[idx]

fig, ax = plt.subplots(figsize=(8, 8), dpi=200)

# Grid geometry is taken from the image itself instead of hard-coding 28,
# so the drawing works for any 2-D image.
rows, cols = img.shape
major_step = 10  # spacing, in pixels, between the dashed red guide lines

# Show the image without interpolation so each pixel stays a crisp square.
# vmin/vmax pin the colour scale to the full 0-255 range; without them imshow
# rescales to the image's own min/max, and the text-colour threshold below
# would no longer match the shade actually drawn.
ax.imshow(img, cmap="gray_r", interpolation="nearest", vmin=0, vmax=255)

# --- Fine grid: one cell per pixel ---
# imshow centres pixel k at coordinate k, so cell borders sit at k - 0.5.
ax.set_xticks(np.arange(-0.5, cols, 1), minor=True)
ax.set_yticks(np.arange(-0.5, rows, 1), minor=True)
ax.grid(which="minor", color="gray", linestyle="-", linewidth=0.3)

# --- Coarse grid: a dashed red guide line every `major_step` pixels ---
# (Lines every 10 pixels, not a 10x10 grid of cells.)
ax.set_xticks(np.arange(-0.5, cols, major_step), minor=False)
ax.set_yticks(np.arange(-0.5, rows, major_step), minor=False)
ax.grid(which="major", color="red", linestyle="--", linewidth=1.0)

# Hide tick marks and labels; the ticks exist only to position grid lines.
ax.tick_params(which="both", bottom=False, left=False,
               labelbottom=False, labelleft=False)

ax.set_title(f"数字: {label}", fontsize=16)

# Write each pixel value (0-255) at the centre of its cell. np.ndenumerate
# yields ((row j, column i), value), while ax.text takes (x, y) = (i, j).
for (j, i), val in np.ndenumerate(img):
    ax.text(i, j, str(int(val)),
            ha="center", va="center",
            # With gray_r on a fixed 0-255 scale, values >= 128 are drawn
            # dark, so use white text there for contrast.
            color="black" if val < 128 else "white",
            fontsize=6)

plt.show()