This commit is contained in:
weixin_43297441 2025-12-03 16:52:04 +08:00
parent 8193801023
commit c2449664bf
2 changed files with 7 additions and 14 deletions

1
.gitignore vendored
View File

@ -123,6 +123,7 @@ celerybeat.pid
# SageMath parsed files
*.sage.py
*.npy
*.txt
# Environments
.env

View File

@ -186,19 +186,11 @@ def train_model(model, optimizer, device, prompts, labels, args):
logging.info(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")
losses.append(epoch_loss)
if test_auroc > best_test_auroc:
best_test_auroc = test_auroc
best_test_epoch = epoch
print(f"Best test AUROC: {best_test_auroc:.4f}, at epoch: {best_test_epoch}")
logging.info(
f"Best test AUROC: {best_test_auroc:.4f}, at epoch: {best_test_epoch}"
)
# 保存最佳AUROC时的centroids到npy文件
centroids_np = centroids.cpu().numpy()
centroids_file = os.path.join(dir_name, f"best_centroids_epoch_{epoch}.npy")
np.save(centroids_file, centroids_np)
print(f"Saved best centroids to {centroids_file}")
logging.info(f"Saved best centroids to {centroids_file}")
centroids_np = centroids.cpu().numpy()
centroids_file = os.path.join(dir_name, f"best_centroids_epoch_{epoch}.npy")
np.save(centroids_file, centroids_np)
# print(f"Saved best centroids to {centroids_file}")
logging.info(f"Saved last centroids to {centroids_file}")
logging.info(
f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {epoch_loss:.4f}, "
@ -305,7 +297,7 @@ def main():
parser.add_argument("--str_layer", type=int, default=9) # 插 SV 的层号(第几层 hidden
parser.add_argument("--component", type=str, default='res') # SV 以何种方式注入(残差等)
parser.add_argument("--lam", type=float, default=5) # SV 强度 λ
parser.add_argument("--init_num_epochs", type=int, default=20) # SV训练轮数
parser.add_argument("--init_num_epochs", type=int, default=5) # SV训练轮数
parser.add_argument("--optimizer", type=str, default='AdamW')
parser.add_argument('--train_ratio', type=float, default=0.8)
parser.add_argument('--val_ratio', type=float, default=0.1)