打卡day35

news/2025/6/14 11:43:09

一、模型结构可视化

理解一个深度学习网络最重要的2点:

  1. 了解损失如何定义的,知道损失从何而来----把抽象的任务通过损失函数量化出来
  2. 了解参数总量,即知道每一层的设计才能退出—层设计决定参数总量

为了了解参数总量,我们需要知道层设计,以及每一层参数的数量。下面介绍1几个层可视化工具:

1.1 nn.model自带的方法

#  nn.Module 的内置功能,直接输出模型结构
print(model)

这是最基础、最简单的方法,直接打印模型对象,它会输出模型的结构,显示模型中各个层的名称和参数信息

# nn.Module 的内置功能,返回模型的可训练参数迭代器
for name, param in model.named_parameters():print(f"Parameter name: {name}, Shape: {param.shape}")

可以将模型中带有weight的参数(即权重)提取出来,并转为 numpy 数组形式,对其计算统计分布,并且绘制可视化图表

# 提取权重数据
import numpy as np
weight_data = {}
for name, param in model.named_parameters():if 'weight' in name:weight_data[name] = param.detach().cpu().numpy()# 可视化权重分布
fig, axes = plt.subplots(1, len(weight_data), figsize=(15, 5))
fig.suptitle('Weight Distribution of Layers')for i, (name, weights) in enumerate(weight_data.items()):# 展平权重张量为一维数组weights_flat = weights.flatten()# 绘制直方图axes[i].hist(weights_flat, bins=50, alpha=0.7)axes[i].set_title(name)axes[i].set_xlabel('Weight Value')axes[i].set_ylabel('Frequency')axes[i].grid(True, linestyle='--', alpha=0.7)plt.tight_layout()
plt.subplots_adjust(top=0.85)
plt.show()# 计算并打印每层权重的统计信息
print("\n=== 权重统计信息 ===")
for name, weights in weight_data.items():mean = np.mean(weights)std = np.std(weights)min_val = np.min(weights)max_val = np.max(weights)print(f"{name}:")print(f"  均值: {mean:.6f}")print(f"  标准差: {std:.6f}")print(f"  最小值: {min_val:.6f}")print(f"  最大值: {max_val:.6f}")print("-" * 30)

1.2 torchsummary库的summary方法

# pip install torchsummary -i https://pypi.tuna.tsinghua.edu.cn/simple
from torchsummary import summary
# 打印模型摘要,可以放置在模型定义后面
summary(model, input_size=(4,))

1.3 torchinfo库的summary方法

torchinfo 是提供比 torchsummary 更详细的模型摘要信息,包括每层的输入输出形状、参数数量、计算量等。

# pip install torchinfo -i https://pypi.tuna.tsinghua.edu.cn/simple
from torchinfo import summary
summary(model, input_size=(4, ))

二、 进度条功能

我们介绍下tqdm这个库,他非常适合用在循环中观察进度。尤其在深度学习这种训练是循环的场景中。他最核心的逻辑如下

  1. 创建一个进度条对象,并传入总迭代次数。一般用with语句创建对象,这样对象会在with语句结束后自动销毁,保证资源释放。with是常见的上下文管理器,这样的使用方式还有用with打开文件,结束后会自动关闭文件。

  2. 更新进度条,通过pbar.update(n)指定每次前进的步数n(适用于非固定步长的循环)。

2.1 手动更新

from tqdm import tqdm  # 先导入tqdm库
import time  # 用于模拟耗时操作# 创建一个总步数为10的进度条
with tqdm(total=10) as pbar:  # pbar是进度条对象的变量名# pbar 是 progress bar(进度条)的缩写,约定俗成的命名习惯。for i in range(10):  # 循环10次(对应进度条的10步)time.sleep(0.5)  # 模拟每次循环耗时0.5秒pbar.update(1)  # 每次循环后,进度条前进1步
from tqdm import tqdm
import time# 创建进度条时添加描述(desc)和单位(unit)
with tqdm(total=5, desc="下载文件", unit="个") as pbar:# 进度条这个对象,可以设置描述和单位# desc是描述,在左侧显示# unit是单位,在进度条右侧显示for i in range(5):time.sleep(1)pbar.update(1)  # 每次循环进度+1

unit 参数的核心作用是明确进度条中每个进度单位的含义,使可视化信息更具可读性。在深度学习训练中,常用的单位包括:

  • epoch:训练轮次(遍历整个数据集一次)。
  • batch:批次(每次梯度更新处理的样本组)。
  • sample:样本(单个数据点)

2.2 自动更新

from tqdm import tqdm
import time# 直接将range(3)传给tqdm,自动生成进度条
# 这个写法我觉得是有点神奇的,直接可以给这个对象内部传入一个可迭代对象,然后自动生成进度条
for i in tqdm(range(3), desc="处理任务", unit="epoch"):time.sleep(1)

三、 模型的推理

# 在测试集上评估模型,此时model内部已经是训练好的参数了
# 评估模型
model.eval() # 设置模型为评估模式
with torch.no_grad(): # torch.no_grad()的作用是禁用梯度计算,可以提高模型推理速度outputs = model(X_test)  # 对测试数据进行前向传播,获得预测结果_, predicted = torch.max(outputs, 1) # torch.max(outputs, 1)返回每行的最大值和对应的索引#这个函数返回2个值,分别是最大值和对应索引,参数1是在第1维度(行)上找最大值,_ 是Python的约定,表示忽略这个返回值,所以这个写法是找到每一行最大值的下标# 此时outputs是一个tensor,p每一行是一个样本,每一行有3个值,分别是属于3个类别的概率,取最大值的下标就是预测的类别# predicted == y_test判断预测值和真实值是否相等,返回一个tensor,1表示相等,0表示不等,然后求和,再除以y_test.size(0)得到准确率# 因为这个时候数据是tensor,所以需要用item()方法将tensor转化为Python的标量# 之所以不用sklearn的accuracy_score函数,是因为这个函数是在CPU上运行的,需要将数据转移到CPU上,这样会慢一些# size(0)获取第0维的长度,即样本数量correct = (predicted == y_test).sum().item() # 计算预测正确的样本数accuracy = correct / y_test.size(0)print(f'测试集准确率: {accuracy * 100:.2f}%')
文章来源:https://blog.csdn.net/2301_81663101/article/details/148198194
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:https://dhexx.cn/news/show-5526273.html

相关文章

2.2.1 05年T2

引言 本文将从一预习、二自习、三学习、四复习等四个阶段来分析2005年考研英语阅读第二篇文章。为了便于后续阅读,我将第四部分复习放在了首位。 四、复习 方法:错误思路分析总结考点文章梳理 4.1 错题分析 题目:26(细节题&…

day11制作窗口(鼠标显示、图层和图层控制器、显示窗口、高速计数器、消除闪烁)

day11制作窗口 鼠标显示问题 if (mx > binfo->scrnx - 16) { mx binfo->scrnx - 16; } if (my > binfo->scrny - 16) { my binfo->scrny - 16; }因为这段代码的问题,会完整显示整个鼠标,导致鼠标所指的像素无法覆盖整个屏幕像素 i…

静态分配动态绑定

静态分配看编译时类型 比如你用了多态A anew B() A就是编译时类型,会先在A里分配方法,分配原则遵循找最匹配,其次是兼容 然后看运行时类型,如果重写了就执行重写的,没重写就执行静态分配的方…

零基础入门Selenium自动化测试:自动登录edu邮箱

🌟 Selenium简单概述一下 Selenium 是一个开源的自动化测试工具,主要用于 Web 应用程序的功能测试。它能够模拟用户操作浏览器的行为(如点击按钮、填写表单、导航页面等),应用于前端开发、测试和运维领域。 特点 跨…

解决 Supabase “permission denied for table XXX“ 错误

解决 Supabase “permission denied for table” 错误 问题描述 在使用 Supabase 开发应用时,你可能会遇到以下错误: [Nest] ERROR [ExceptionsHandler] Object(4) {code: 42501,details: null,hint: null,message: permission denied for table user…

【软考向】Chapter 2 程序设计语言基础知识

程序设计语言概述低级语言 —— 机器指令、汇编语言高级语言 ——翻译:汇编、解释和编译语言处理程序基础 —— 翻译给计算机,汇编、编译、解释三类编译程序基本原理 —— 词法分析、语法分析、语义分析、中间代码生成、代码优化、目标代码生成文法和语言的形式描述确定的有限…

[创业之路-377]:企业战略管理案例分析-战略制定/设计-市场洞察“五看”:看宏观之社会发展趋势:数字化、智能化、个性化的趋势对初创公司的战略机会

数字化、智能化、个性化趋势为初创公司带来了捕捉长尾需求、提升运营效率、创新商业模式等战略机会,具体分析如下: 一、数字化趋势带来的战略机会 捕捉长尾需求:数字化技术能够帮助初创公司更好地捕捉市场中的长尾需求,满足那些…

tvalid寄存器的理解

if(!out_axis_tvalid_reg || m_axis_tready ) beginend m_axis_tready 是上拍下一级给的ready信号 out_axis_tvalid_reg是上一拍,本级给下级的valid信号 一共有四种组合,然后可以通过这个if语句,在接下来的begin ... end中,用来…

【一. Java基础:注释、变量与数据类型详解】

1. Java 基础概念 1.1 注释 注释:对代码的解释和说明文字 java的三种注释: 单行注释:两个斜杠 // 后面跟着你的注释内容 //哈哈多行注释:以 /* 开头,以 */ 结尾,中间可以写很多行 /*哈哈哈哈哈哈…

python打卡day35@浙大疏锦行

知识点回顾: 三种不同的模型可视化方法:推荐torchinfo打印summary权重分布可视化进度条功能:手动和自动写法,让打印结果更加美观推理的写法:评估模式 作业:调整模型定义时的超参数,对比下效果。…