模型蒸馏数据提取脚本:
# Distillation data extraction script.
#
# Runs a teacher causal-LM over an instruction dataset and stores the teacher's
# generated answer next to the dataset's original answer, so the pairs can be
# used later to train a student model.
import json
import os

import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

# Path settings
teacher_model_path = "D:/Qwen2.5-7B-Instruct"
dataset_path = "D:/Qwen2.5-7B-Instruct/sdata/alpaca_zh_demo.json"
output_dir = "D:/Qwen2.5-7B-Instruct/data"
os.makedirs(output_dir, exist_ok=True)

# Prompts longer than this many tokens are truncated before generation.
MAX_PROMPT_TOKENS = 512
# Extra tokens granted on top of the reference answer's token count.
OUTPUT_LENGTH_MARGIN = 10

# Load the teacher model and its tokenizer
print("Loading teacher model...")
teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model_path)
teacher_model = AutoModelForCausalLM.from_pretrained(
    teacher_model_path,
    # Half precision only when a GPU is available; fall back to float32 on CPU.
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",
)

# Load the dataset
print("Loading dataset...")


def load_custom_dataset(file_path):
    """Read the UTF-8 JSON file at *file_path* and return the parsed content."""
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)


raw_dataset = load_custom_dataset(dataset_path)


def format_prompt(example):
    """Build the text prompt for one dataset record.

    Args:
        example: Mapping with an ``instruction`` key and an optional ``input``
            key. An empty or missing ``input`` is left out of the prompt.

    Returns:
        The prompt string, ending with ``"Output: "`` so the model continues
        with the answer.
    """
    prompt = f"Instruction: {example['instruction']}\n"
    if example.get('input'):
        prompt += f"Input: {example['input']}\n"
    prompt += "Output: "
    return prompt


# Distillation pass - only the teacher model's outputs are collected
print("Running distillation (teacher model inference)...")
distilled_data = []

for example in tqdm(raw_dataset):
    try:
        # Prepare the model input
        prompt = format_prompt(example)
        inputs = teacher_tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=MAX_PROMPT_TOKENS,
        )
        inputs = {k: v.to(teacher_model.device) for k, v in inputs.items()}
        prompt_token_count = inputs["input_ids"].shape[1]

        # Size the generation budget from the reference answer. The tokenizer
        # returns a mapping, so the token count must come from "input_ids";
        # len() of the mapping itself is only the number of its keys.
        reference_token_count = len(teacher_tokenizer(example['output'])["input_ids"])
        output_length = reference_token_count + OUTPUT_LENGTH_MARGIN

        # Run the teacher model
        with torch.no_grad():
            outputs = teacher_model.generate(
                **inputs,
                max_new_tokens=output_length,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
            )

        # Decode only the newly generated tokens. Cutting the decoded text by
        # len(prompt) is unreliable when the prompt was truncated or does not
        # round-trip through the tokenizer character-for-character.
        generated_tokens = outputs[0][prompt_token_count:]
        teacher_response = teacher_tokenizer.decode(
            generated_tokens, skip_special_tokens=True
        )

        # Record the distilled example
        distilled_example = {
            "instruction": example["instruction"],
            "input": example.get("input", ""),
            "original_output": example["output"],
            "teacher_output": teacher_response,
        }
        distilled_data.append(distilled_example)

    except Exception as e:
        # Best effort: report the failing record and keep processing the rest.
        print(f"Error processing example: {e}")
        continue

# Save the distilled data
output_file = os.path.join(output_dir, "distilled_data.json")
with open(output_file, 'w', encoding='utf-8') as f:
    json.dump(distilled_data, f, ensure_ascii=False, indent=2)

print(f"Distillation complete! Data saved to {output_file}")
说明:
运行前请先安装模型所需的运行环境(如 torch、transformers、tqdm);脚本中输出长度的计算方式(参考答案的 token 数加上余量)可按需调整,也可以根据实际情况自行扩展。
声明:本站内容来自公开平台,如若侵犯到您的权益,请联系我们,我们会第一时间删除!联系QQ:502428990。
评论(0)