This commit is contained in:
parent
8193801023
commit
c2449664bf
|
|
@ -123,6 +123,7 @@ celerybeat.pid
|
||||||
# SageMath parsed files
|
# SageMath parsed files
|
||||||
*.sage.py
|
*.sage.py
|
||||||
*.npy
|
*.npy
|
||||||
|
*.txt
|
||||||
|
|
||||||
# Environments
|
# Environments
|
||||||
.env
|
.env
|
||||||
|
|
|
||||||
|
|
@ -186,19 +186,11 @@ def train_model(model, optimizer, device, prompts, labels, args):
|
||||||
logging.info(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")
|
logging.info(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")
|
||||||
losses.append(epoch_loss)
|
losses.append(epoch_loss)
|
||||||
|
|
||||||
if test_auroc > best_test_auroc:
|
centroids_np = centroids.cpu().numpy()
|
||||||
best_test_auroc = test_auroc
|
centroids_file = os.path.join(dir_name, f"best_centroids_epoch_{epoch}.npy")
|
||||||
best_test_epoch = epoch
|
np.save(centroids_file, centroids_np)
|
||||||
print(f"Best test AUROC: {best_test_auroc:.4f}, at epoch: {best_test_epoch}")
|
# print(f"Saved best centroids to {centroids_file}")
|
||||||
logging.info(
|
logging.info(f"Saved last centroids to {centroids_file}")
|
||||||
f"Best test AUROC: {best_test_auroc:.4f}, at epoch: {best_test_epoch}"
|
|
||||||
)
|
|
||||||
# 保存最佳AUROC时的centroids到npy文件
|
|
||||||
centroids_np = centroids.cpu().numpy()
|
|
||||||
centroids_file = os.path.join(dir_name, f"best_centroids_epoch_{epoch}.npy")
|
|
||||||
np.save(centroids_file, centroids_np)
|
|
||||||
print(f"Saved best centroids to {centroids_file}")
|
|
||||||
logging.info(f"Saved best centroids to {centroids_file}")
|
|
||||||
|
|
||||||
logging.info(
|
logging.info(
|
||||||
f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {epoch_loss:.4f}, "
|
f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {epoch_loss:.4f}, "
|
||||||
|
|
@ -305,7 +297,7 @@ def main():
|
||||||
parser.add_argument("--str_layer", type=int, default=9) # 插 SV 的层号(第几层 hidden)
|
parser.add_argument("--str_layer", type=int, default=9) # 插 SV 的层号(第几层 hidden)
|
||||||
parser.add_argument("--component", type=str, default='res') # SV 以何种方式注入(残差等)
|
parser.add_argument("--component", type=str, default='res') # SV 以何种方式注入(残差等)
|
||||||
parser.add_argument("--lam", type=float, default=5) # SV 强度 λ
|
parser.add_argument("--lam", type=float, default=5) # SV 强度 λ
|
||||||
parser.add_argument("--init_num_epochs", type=int, default=20) # SV训练轮数
|
parser.add_argument("--init_num_epochs", type=int, default=5) # SV训练轮数
|
||||||
parser.add_argument("--optimizer", type=str, default='AdamW')
|
parser.add_argument("--optimizer", type=str, default='AdamW')
|
||||||
parser.add_argument('--train_ratio', type=float, default=0.8)
|
parser.add_argument('--train_ratio', type=float, default=0.8)
|
||||||
parser.add_argument('--val_ratio', type=float, default=0.1)
|
parser.add_argument('--val_ratio', type=float, default=0.1)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue