diff --git a/.gitignore b/.gitignore index 09c8938..231f10e 100644 --- a/.gitignore +++ b/.gitignore @@ -123,6 +123,7 @@ celerybeat.pid # SageMath parsed files *.sage.py *.npy +*.txt # Environments .env diff --git a/steer_vector.py b/steer_vector.py index 136b4fb..d3b6e1c 100644 --- a/steer_vector.py +++ b/steer_vector.py @@ -186,19 +186,11 @@ def train_model(model, optimizer, device, prompts, labels, args): logging.info(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}") losses.append(epoch_loss) - if test_auroc > best_test_auroc: - best_test_auroc = test_auroc - best_test_epoch = epoch - print(f"Best test AUROC: {best_test_auroc:.4f}, at epoch: {best_test_epoch}") - logging.info( - f"Best test AUROC: {best_test_auroc:.4f}, at epoch: {best_test_epoch}" - ) - # 保存最佳AUROC时的centroids到npy文件 - centroids_np = centroids.cpu().numpy() - centroids_file = os.path.join(dir_name, f"best_centroids_epoch_{epoch}.npy") - np.save(centroids_file, centroids_np) - print(f"Saved best centroids to {centroids_file}") - logging.info(f"Saved best centroids to {centroids_file}") + centroids_np = centroids.cpu().numpy() + centroids_file = os.path.join(dir_name, f"best_centroids_epoch_{epoch}.npy") + np.save(centroids_file, centroids_np) + # print(f"Saved best centroids to {centroids_file}") + logging.info(f"Saved last centroids to {centroids_file}") logging.info( f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {epoch_loss:.4f}, " @@ -305,7 +297,7 @@ def main(): parser.add_argument("--str_layer", type=int, default=9) # 插 SV 的层号(第几层 hidden) parser.add_argument("--component", type=str, default='res') # SV 以何种方式注入(残差等) parser.add_argument("--lam", type=float, default=5) # SV 强度 λ - parser.add_argument("--init_num_epochs", type=int, default=20) # SV训练轮数 + parser.add_argument("--init_num_epochs", type=int, default=5) # SV训练轮数 parser.add_argument("--optimizer", type=str, default='AdamW') parser.add_argument('--train_ratio', type=float, default=0.8) parser.add_argument('--val_ratio', type=float, default=0.1)