| | import matplotlib.pyplot as plt |
| | import seaborn as sns |
| | import pandas as pd |
| | import json |
| | import os |
| | from datetime import datetime |
| |
|
| | |
def set_chinese_font():
    """Configure matplotlib so CJK text renders correctly.

    Installs a fallback chain of fonts that contain Chinese glyphs and
    disables the Unicode minus so negative axis ticks use a plain '-'.
    """
    cjk_font_chain = ['Arial Unicode MS', 'SimHei', 'PingFang SC', 'Heiti TC']
    plt.rcParams['font.sans-serif'] = cjk_font_chain
    plt.rcParams['axes.unicode_minus'] = False
| |
|
def plot_data_distribution(dataset_dict, save_path=None):
    """Draw a pie chart of the Positive/Neutral/Negative label distribution.

    Args:
        dataset_dict: a DatasetDict-like mapping (the 'train' split is used
            when present) or a single dataset object exposing ``features``.
        save_path: optional path; when truthy, the figure is saved there.
    """
    set_chinese_font()

    # Prefer the training split when a DatasetDict-like object is passed in.
    ds = dataset_dict
    if hasattr(dataset_dict, 'keys') and 'train' in dataset_dict.keys():
        ds = dataset_dict['train']

    # Extract the raw label ids, tolerating either column name.
    if 'label' in ds.features:
        train_labels = ds['label']
    elif 'labels' in ds.features:
        train_labels = ds['labels']
    else:
        # Neither column is declared: fall back to per-example lookup.
        train_labels = [row.get('label', row.get('labels')) for row in ds]

    # Map class ids to readable names; unknown ids keep their string form.
    id2label = {0: 'Negative (消极)', 1: 'Neutral (中性)', 2: 'Positive (积极)'}
    readable_labels = [id2label.get(lbl, str(lbl)) for lbl in train_labels]

    counts = pd.DataFrame({'Label': readable_labels})['Label'].value_counts()

    plt.figure(figsize=(10, 6))
    plt.pie(
        counts,
        labels=counts.index,
        autopct='%1.1f%%',
        startangle=140,
        colors=sns.color_palette("pastel"),
    )
    plt.title('训练集情感分布')
    plt.tight_layout()

    if save_path:
        print(f"Saving distribution plot to {save_path}...")
        plt.savefig(save_path)
| | |
| |
|
def plot_training_history(log_history, save_path=None):
    """Plot training loss and validation accuracy curves from Trainer logs.

    Args:
        log_history: list of log-entry dicts, as found in the
            ``log_history`` field of a ``trainer_state.json``.
        save_path: optional output path for the PNG; when omitted, a
            timestamped file is created under ``Config.RESULTS_DIR/images``.

    Side effects:
        Saves the figure, and when validation metrics exist also writes a
        small ``metrics_<timestamp>.txt`` summary next to the plot.
    """
    set_chinese_font()

    if not log_history:
        print("没有可用的训练日志。")
        return

    df = pd.DataFrame(log_history)

    # BUGFIX: not every run logs every metric. If 'loss' or 'eval_accuracy'
    # never appeared in the logs, the column is absent from df and indexing
    # it raised KeyError. Guard with an empty-frame fallback instead.
    train_loss = df[df['loss'].notna()] if 'loss' in df.columns else df.iloc[0:0]
    eval_acc = df[df['eval_accuracy'].notna()] if 'eval_accuracy' in df.columns else df.iloc[0:0]

    plt.figure(figsize=(14, 5))

    # Left panel: training (and, when available, validation) loss.
    plt.subplot(1, 2, 1)
    if not train_loss.empty:
        plt.plot(train_loss['epoch'], train_loss['loss'], label='Training Loss', color='salmon')
    if 'eval_loss' in df.columns:
        eval_loss = df[df['eval_loss'].notna()]
        plt.plot(eval_loss['epoch'], eval_loss['eval_loss'], label='Validation Loss', color='skyblue')
    plt.title('训练损失 (Loss) 曲线')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True, alpha=0.3)

    # Right panel: validation accuracy, only drawn when it was logged.
    if not eval_acc.empty:
        plt.subplot(1, 2, 2)
        plt.plot(eval_acc['epoch'], eval_acc['eval_accuracy'], label='Validation Accuracy', color='lightgreen', marker='o')
        plt.title('验证集准确率 (Accuracy)')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        plt.legend()
        plt.grid(True, alpha=0.3)

    plt.tight_layout()

    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    # BUGFIX: Config is only imported by the __main__ entry point, so the
    # original unconditional Config.RESULTS_DIR reference raised NameError
    # for library callers. Reference Config lazily — only when we must
    # invent a default path — and otherwise derive the output directory
    # from the caller-supplied save_path.
    if save_path is None:
        save_dir = os.path.join(Config.RESULTS_DIR, "images")
        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, f"training_metrics_{timestamp}.png")
    else:
        save_dir = os.path.dirname(save_path) or "."
        os.makedirs(save_dir, exist_ok=True)

    print(f"Saving plot to {save_path}...")
    plt.savefig(save_path)

    # Persist a short text summary of the final validation metrics.
    if not eval_acc.empty:
        final_acc = eval_acc.iloc[-1]['eval_accuracy']
        final_loss = eval_acc.iloc[-1]['eval_loss'] if 'eval_loss' in eval_acc.columns else "N/A"
        metrics_file = os.path.join(save_dir, f"metrics_{timestamp}.txt")
        with open(metrics_file, "w") as f:
            f.write(f"Timestamp: {timestamp}\n")
            f.write(f"Final Validation Accuracy: {final_acc:.4f}\n")
            f.write(f"Final Validation Loss: {final_loss}\n")
            f.write(f"Plot saved to: {os.path.basename(save_path)}\n")
        print(f"Saved metrics text to {metrics_file}")
| |
|
def load_and_plot_logs(log_dir):
    """Load ``trainer_state.json`` from a checkpoint directory and plot it.

    Args:
        log_dir: directory expected to contain a ``trainer_state.json``.
            Prints a message and returns early if the file is missing.
    """
    json_path = os.path.join(log_dir, 'trainer_state.json')
    if not os.path.exists(json_path):
        print(f"未找到日志文件: {json_path}")
        return

    with open(json_path, 'r') as f:
        trainer_state = json.load(f)

    plot_training_history(trainer_state['log_history'])
| |
|
if __name__ == "__main__":
    import sys
    import os

    # Make the project root importable when running this file directly.
    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    sys.path.append(project_root)

    from src.config import Config

    # Part 1 (best effort): plot the dataset's label distribution. This
    # needs the tokenizer/dataset locally, so failures are reported, not fatal.
    try:
        print("\n正在加载数据集以生成样本分布分析...")
        from transformers import AutoTokenizer
        from src.dataset import DataProcessor

        tokenizer = AutoTokenizer.from_pretrained(Config.BASE_MODEL)
        processor = DataProcessor(tokenizer)
        dataset = processor.get_processed_dataset(cache_dir=Config.DATA_DIR)

        timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        dist_save_path = os.path.join(Config.RESULTS_DIR, "images", f"data_distribution_{timestamp}.png")

        plot_data_distribution(dataset, save_path=dist_save_path)
        print(f"数据样本分布分析已保存至: {dist_save_path}")

    except Exception as e:
        print(f"无法生成数据分布图 (可能是数据尚未下载或处理): {e}")

    # Part 2: find the most recent checkpoint and plot its training curves.
    import glob

    search_paths = [
        Config.OUTPUT_DIR,
        os.path.join(Config.RESULTS_DIR, "checkpoint-*"),
    ]

    candidates = []
    for pattern in search_paths:
        candidates.extend(glob.glob(pattern))

    if candidates:
        # Newest entry by modification time wins.
        latest_ckpt = max(candidates, key=os.path.getmtime)
        print(f"Loading logs from: {latest_ckpt}")
        load_and_plot_logs(latest_ckpt)
    else:
        print("未找到任何 checkpoint 或 trainer_state.json 日志文件。")
| |
|