模型蒸馏数据提取脚本：使用教师模型（Qwen2.5-7B-Instruct）对 Alpaca 格式数据集逐条推理，并保存教师模型的输出作为蒸馏数据。

import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import json
from tqdm import tqdm

# Path settings. The defaults are the original Windows locations; each one can
# be overridden with an environment variable so the script runs elsewhere
# without editing the source.
teacher_model_path = os.environ.get("DISTILL_TEACHER_MODEL", "D:/Qwen2.5-7B-Instruct")
dataset_path = os.environ.get(
    "DISTILL_DATASET", "D:/Qwen2.5-7B-Instruct/sdata/alpaca_zh_demo.json"
)
output_dir = os.environ.get("DISTILL_OUTPUT_DIR", "D:/Qwen2.5-7B-Instruct/data")

# Make sure the output directory exists (no error if it already does).
os.makedirs(output_dir, exist_ok=True)

# Load the teacher model and its tokenizer from the same checkpoint directory.
print("Loading teacher model...")
teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model_path)

# float16 when a CUDA device is available, float32 otherwise;
# device_map="auto" lets transformers decide where to place the weights.
_teacher_load_kwargs = dict(
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",
)
teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model_path, **_teacher_load_kwargs)

# Dataset loading.
def load_custom_dataset(file_path):
    """Read a UTF-8 encoded JSON file and return its parsed contents."""
    with open(file_path, encoding='utf-8') as handle:
        return json.load(handle)


print("Loading dataset...")
raw_dataset = load_custom_dataset(dataset_path)

# Preprocessing: turn one dataset record into the plain-text teacher prompt.
def format_prompt(example):
    """Render a record as an "Instruction / Input / Output" prompt string.

    The "Input:" line is included only when the record's ``input`` field is
    present and non-empty. The prompt always ends with "Output: " so the model
    continues from that point.
    """
    pieces = [f"Instruction: {example['instruction']}\n"]
    extra_input = example.get('input', '')
    if extra_input:
        pieces.append(f"Input: {extra_input}\n")
    pieces.append("Output: ")
    return "".join(pieces)

# Distillation pass: run the teacher model on every example and keep only its output.
print("Running distillation (teacher model inference)...")
distilled_data = []

for idx, example in enumerate(tqdm(raw_dataset)):
    try:
        # Build the prompt and move the tokenized tensors onto the model's device.
        # NOTE(review): right-side truncation at 512 tokens can drop the trailing
        # "Output: " cue for very long prompts — confirm this is acceptable.
        # NOTE(review): Qwen2.5-*-Instruct is a chat model; formatting with
        # teacher_tokenizer.apply_chat_template may yield better teacher outputs.
        prompt = format_prompt(example)
        inputs = teacher_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
        inputs = {k: v.to(teacher_model.device) for k, v in inputs.items()}
        prompt_token_count = inputs["input_ids"].shape[1]

        # Generation budget = token length of the reference answer + 10 tokens of slack.
        # The tokenizer returns a dict-like BatchEncoding, so the length must be taken
        # on its "input_ids"; len() of the encoding itself only counts its keys.
        reference_ids = teacher_tokenizer(example["output"], add_special_tokens=False)["input_ids"]
        output_length = len(reference_ids) + 10

        # Sample the teacher's answer; inference only, so no gradients are tracked.
        with torch.no_grad():
            outputs = teacher_model.generate(
                **inputs,
                max_new_tokens=output_length,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
            )

        # For a causal LM, generate() returns prompt tokens followed by new tokens,
        # so drop the prompt *tokens* before decoding. Slicing the decoded text by
        # len(prompt) characters goes wrong whenever decode(encode(prompt)) != prompt
        # (truncation, whitespace or special-token normalization).
        teacher_response = teacher_tokenizer.decode(
            outputs[0][prompt_token_count:], skip_special_tokens=True
        )

        # Keep the original record alongside the teacher's answer.
        distilled_data.append({
            "instruction": example["instruction"],
            "input": example.get("input", ""),
            "original_output": example["output"],
            "teacher_output": teacher_response,
        })
    except Exception as e:
        # Best effort: report which example failed and continue with the rest.
        print(f"Error processing example {idx}: {e}")

# Write all collected examples to a single pretty-printed UTF-8 JSON file
# (ensure_ascii=False keeps Chinese text readable instead of \u-escaped).
output_file = os.path.join(output_dir, "distilled_data.json")
serialized = json.dumps(distilled_data, ensure_ascii=False, indent=2)
with open(output_file, mode="w", encoding="utf-8") as out_fp:
    out_fp.write(serialized)

print(f"Distillation complete! Data saved to {output_file}")

说明：

运行前请先安装所需的运行环境（torch、transformers、tqdm；使用 device_map="auto" 还需要安装 accelerate）。生成长度（代码中的 output_length）可按需调整，也可根据实际情况自行扩展！

声明：本站内容来自公开平台，如若侵犯到您的权益，请联系我们，我们会第一时间删除！联系QQ：502428990。