This commit is contained in:
parent
8193801023
commit
c2449664bf
|
|
@ -123,6 +123,7 @@ celerybeat.pid
|
|||
# SageMath parsed files
|
||||
*.sage.py
|
||||
*.npy
|
||||
*.txt
|
||||
|
||||
# Environments
|
||||
.env
|
||||
|
|
|
|||
|
|
@ -186,19 +186,11 @@ def train_model(model, optimizer, device, prompts, labels, args):
|
|||
logging.info(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")
|
||||
losses.append(epoch_loss)
|
||||
|
||||
if test_auroc > best_test_auroc:
|
||||
best_test_auroc = test_auroc
|
||||
best_test_epoch = epoch
|
||||
print(f"Best test AUROC: {best_test_auroc:.4f}, at epoch: {best_test_epoch}")
|
||||
logging.info(
|
||||
f"Best test AUROC: {best_test_auroc:.4f}, at epoch: {best_test_epoch}"
|
||||
)
|
||||
# 保存最佳AUROC时的centroids到npy文件
|
||||
centroids_np = centroids.cpu().numpy()
|
||||
centroids_file = os.path.join(dir_name, f"best_centroids_epoch_{epoch}.npy")
|
||||
np.save(centroids_file, centroids_np)
|
||||
print(f"Saved best centroids to {centroids_file}")
|
||||
logging.info(f"Saved best centroids to {centroids_file}")
|
||||
# print(f"Saved best centroids to {centroids_file}")
|
||||
logging.info(f"Saved last centroids to {centroids_file}")
|
||||
|
||||
logging.info(
|
||||
f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {epoch_loss:.4f}, "
|
||||
|
|
@ -305,7 +297,7 @@ def main():
|
|||
parser.add_argument("--str_layer", type=int, default=9) # 插 SV 的层号(第几层 hidden)
|
||||
parser.add_argument("--component", type=str, default='res') # SV 以何种方式注入(残差等)
|
||||
parser.add_argument("--lam", type=float, default=5) # SV 强度 λ
|
||||
parser.add_argument("--init_num_epochs", type=int, default=20) # SV训练轮数
|
||||
parser.add_argument("--init_num_epochs", type=int, default=5) # SV训练轮数
|
||||
parser.add_argument("--optimizer", type=str, default='AdamW')
|
||||
parser.add_argument('--train_ratio', type=float, default=0.8)
|
||||
parser.add_argument('--val_ratio', type=float, default=0.1)
|
||||
|
|
|
|||
Loading…
Reference in New Issue