# File-viewer metadata captured with this copy: 546 lines, 23 KiB, Python.
"""
|
||
NewsImpactOnStocks - 热点新闻股票影响分析工具
|
||
|
||
该工具可以分析热点新闻对股票市场的影响,找出可能受影响的股票,并给出投资建议。
|
||
基于DeepSeek API实现,使用deepseek-reasoner模型进行分析。
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import pandas as pd
|
||
import openai
|
||
import json
|
||
import time
|
||
import argparse
|
||
from datetime import datetime
|
||
import re
|
||
import traceback
|
||
|
||
# DeepSeek API configuration.
# The key is read from the environment instead of being committed to source
# control. NOTE(review): the key that used to be hard-coded here is exposed in
# the repository history and should be revoked.
API_KEY = os.environ.get("DEEPSEEK_API_KEY", "")
BASE_URL = os.environ.get("DEEPSEEK_BASE_URL", "https://api.deepseek.com")

# Directory containing this script; data and output paths are resolved against it.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))

# Path of the listed-company information spreadsheet.
COMPANY_INFO_PATH = os.path.join(SCRIPT_DIR, "上市公司信息表.xlsx")

# Directory for the Markdown analysis reports (created at import time).
OUTPUT_DIR = os.path.join(SCRIPT_DIR, "analysis_results")
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
||
|
||
class Logger:
    """Print log lines to stdout and optionally mirror them to a timestamped file.

    When ``log_to_file`` is true, ``<SCRIPT_DIR>/logs/news_analysis_<YYYYmmdd_HHMMSS>.log``
    is created (UTF-8). The logger can be used as a context manager so the file
    is closed even if the caller raises; ``close()`` is safe to call repeatedly,
    and logging after ``close()`` still prints to stdout instead of failing.
    """

    def __init__(self, log_to_file=False):
        """Create the logger; open the log file immediately when *log_to_file* is true."""
        self.log_to_file = log_to_file
        self.log_file = None

        if log_to_file:
            log_dir = os.path.join(SCRIPT_DIR, "logs")
            os.makedirs(log_dir, exist_ok=True)

            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            self.log_file = open(
                os.path.join(log_dir, f"news_analysis_{timestamp}.log"),
                "w",
                encoding="utf-8",
            )

    def log(self, message, level="INFO"):
        """Emit ``[<time>] [<level>] <message>`` to stdout and, if open, to the log file."""
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        log_message = f"[{timestamp}] [{level}] {message}"

        print(log_message)

        if self.log_to_file and self.log_file is not None:
            self.log_file.write(log_message + "\n")
            # Flush per line so the file is useful even if the process dies.
            self.log_file.flush()

    def close(self):
        """Close the log file if one is open; further calls are no-ops."""
        if self.log_file is not None:
            self.log_file.close()
            # Drop the handle so later log() calls don't write to a closed file.
            self.log_file = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False
|
||
|
||
class NewsImpactAnalyzer:
    """Analyse the impact of a hot news item on listed companies.

    Workflow (driven by :meth:`analyze_news`):

    1. Ask the DeepSeek ``deepseek-reasoner`` model which industries the news
       affects, and why.
    2. Keyword-search the listed-company spreadsheet for companies in those
       industries and rank them by a weighted match score.
    3. Ask the model to assess the news' impact on the top-ranked companies.
    4. Write a Markdown report to ``OUTPUT_DIR``.
    """

    # Model used for both API calls.
    _MODEL = "deepseek-reasoner"

    # Columns read by the keyword search; if any is missing the search is skipped.
    _REQUIRED_COLUMNS = ("Symbol", "ShortName", "IndustryName", "MAINBUSSINESS", "BusinessScope")

    # Vocabulary for the fallback industry extraction (result keeps this order).
    _COMMON_INDUSTRIES = (
        "互联网", "金融", "银行", "保险", "证券", "房地产", "医药", "医疗", "健康",
        "教育", "零售", "消费", "制造", "能源", "电力", "新能源", "汽车", "电子",
        "半导体", "通信", "传媒", "娱乐", "旅游", "餐饮", "物流", "交通", "航空",
        "铁路", "船舶", "钢铁", "煤炭", "石油", "化工", "农业", "食品", "饮料",
        "纺织", "服装", "建筑", "建材", "家电", "软件", "硬件", "人工智能", "云计算",
        "大数据", "区块链", "物联网", "5G", "军工", "航天", "环保", "新材料",
    )

    # The "影响行业：A、B、C" line requested by the industry prompt. Accepts both the
    # ASCII and the full-width colon plus optional Markdown bold around the label,
    # since model output uses either form.
    _INDUSTRY_LINE_RE = re.compile(r"影响行业\**\s*[:：]\s*(.*)")
    # Separators between industry names: ASCII/full-width comma, 顿号, semicolons.
    _INDUSTRY_SEP_RE = re.compile(r"[,，、;；]")
    # Characters trimmed from both ends of each parsed industry name.
    _INDUSTRY_TRIM_CHARS = " \t*。.…"

    def __init__(self, logger=None):
        """Configure the API client and load the company sheet.

        Args:
            logger: object with a ``log(message, level="INFO")`` method; a new
                console-only ``Logger`` is created when omitted.
        """
        self.logger = logger or Logger()
        if not API_KEY:
            self.logger.log("未配置DeepSeek API密钥 (API_KEY 为空),API调用将会失败", "WARNING")

        # openai 0.28.x style initialisation: the module itself is configured
        # globally and used as the client.
        openai.api_key = API_KEY
        openai.api_base = BASE_URL
        self.client = openai

        self.load_company_info()

    def load_company_info(self):
        """Load ``COMPANY_INFO_PATH`` into ``self.company_df``.

        Best effort: on any failure the error and traceback are logged and an
        empty DataFrame is used, so later company searches return nothing.
        """
        try:
            self.company_df = pd.read_excel(COMPANY_INFO_PATH)
        except Exception as e:  # deliberate best effort: keep the tool usable
            self.logger.log(f"加载上市公司信息失败: {e}", "ERROR")
            self.logger.log(traceback.format_exc(), "ERROR")
            self.company_df = pd.DataFrame()
        else:
            self.logger.log(f"成功加载上市公司信息,共 {len(self.company_df)} 条记录")

    def _chat(self, system_prompt, user_prompt, temperature):
        """Send one non-streaming chat request and return the reply text ("" if empty)."""
        response = self.client.ChatCompletion.create(
            model=self._MODEL,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            temperature=temperature,
            stream=False,
        )
        # Guard against a missing/None content so callers always get a str.
        return response.choices[0].message.content or ""

    def analyze_news_impact_on_industries(self, news):
        """Ask the model which industries *news* affects.

        Returns:
            dict with ``"industries"`` (list of industry names, possibly empty)
            and ``"reasons"`` (the model's full answer, or an error message if
            the API call failed).
        """
        self.logger.log("开始分析新闻对行业的影响...")

        system_prompt = """
# 角色
你是一位顶尖的证券分析师,拥有深厚的行业知识和敏锐的市场洞察力。能够对热点新闻进行深入分析,判断其对各行业的影响,并给出详细理由。

## 技能
### 技能 1:分析新闻对行业的影响
1. 当用户提供热点新闻标题或内容时,确定其中的关键要素;
2. 结合最新的市场趋势、经济环境以及各行业发展情况,深入分析该新闻可能对哪些行业产生影响;
3. 输出全部可能有被影响的行业,并整合到一句话中。

## 输出格式
请按以下格式输出:

影响行业:行业1、行业2、行业3...

影响分析:
1. 行业1:[分析理由]
2. 行业2:[分析理由]
...

## 限制:
- 只分析与新闻相关的行业影响,拒绝回答与新闻无关的问题。
- 分析理由要充分、有条理。
"""

        user_prompt = f"""
下面是用户提供的热点新闻信息,请分析该新闻可能影响的全部行业,并给出详细理由。

===热点新闻开始===
{news}
===热点新闻结束===
"""

        try:
            self.logger.log("调用DeepSeek API分析行业影响...")
            analysis = self._chat(system_prompt, user_prompt, temperature=0.1)
        except Exception as e:  # API/network failures are reported, not raised
            self.logger.log(f"分析新闻对行业的影响时出错: {e}", "ERROR")
            self.logger.log(traceback.format_exc(), "ERROR")
            return {
                "industries": [],
                "reasons": "分析过程中出现错误,请检查API配置或网络连接。",
            }

        self.logger.log("行业影响分析完成")

        # Prefer the structured "影响行业:" line; fall back to vocabulary matching.
        industries = self._parse_industry_list(analysis) or self._extract_industries_from_text(analysis)
        return {"industries": industries, "reasons": analysis}

    @classmethod
    def _parse_industry_list(cls, analysis):
        """Return the industries on the first "影响行业:" line of *analysis* ([] if absent)."""
        match = cls._INDUSTRY_LINE_RE.search(analysis or "")
        if not match:
            return []
        items = (item.strip(cls._INDUSTRY_TRIM_CHARS) for item in cls._INDUSTRY_SEP_RE.split(match.group(1)))
        return [item for item in items if item]

    def _extract_industries_from_text(self, text):
        """Return the ``_COMMON_INDUSTRIES`` entries that occur in *text*, in list order.

        Plain substring matching, so overlapping names (e.g. "能源" and "新能源")
        can both be returned. ``None`` is treated as empty text.
        """
        text = text or ""
        return [industry for industry in self._COMMON_INDUSTRIES if industry in text]

    @staticmethod
    def _build_keyword_weights(industries):
        """Map search keywords derived from *industries* to their weights.

        NOTE(review): the nesting of this if/elif chain was reconstructed from a
        whitespace-stripped copy of the file; confirm it against upstream.
        """
        weights = {}
        for industry in industries:
            if "汽车" in industry:
                weights["汽车零部件"] = 8
                weights["汽车"] = 5
                if "新能源" in industry:
                    weights["新能源汽车"] = 10
                    weights["新能源"] = 8
            elif "电池" in industry:
                weights["锂电池"] = 10
                weights["电池"] = 7
            elif "电子" in industry:
                weights["消费电子"] = 7
                weights["电子"] = 5
            else:
                # Any other industry is searched by its own name with weight 5.
                weights[industry] = 5
        return weights

    @staticmethod
    def _select_search_keywords(industries, industry_keywords):
        """Return (deduplicated, in first-seen order) the keywords to search for.

        For each industry, use the weighted keywords contained in its name, or
        the industry name itself when none is contained.
        """
        selected = []
        for industry in industries:
            contained = [keyword for keyword in industry_keywords if keyword in industry] or [industry]
            for keyword in contained:
                if keyword not in selected:
                    selected.append(keyword)
        return selected

    def _score_companies(self, search_keywords, industry_keywords):
        """Return ``{Symbol: score}`` summed over keyword matches in three text columns."""
        df = self.company_df
        company_scores = {}
        for keyword in search_keywords:
            weight = industry_keywords.get(keyword, 5)  # default weight is 5
            self.logger.log(f"搜索关键词 '{keyword}' (权重: {weight})...")

            try:
                # regex=False: keywords come from model output and may contain
                # regex metacharacters such as "+" or "(".
                industry_match = df['IndustryName'].str.contains(keyword, na=False, regex=False)
                business_match = df['MAINBUSSINESS'].str.contains(keyword, na=False, regex=False)
                scope_match = df['BusinessScope'].str.contains(keyword, na=False, regex=False)
            except (AttributeError, TypeError) as e:  # e.g. a non-text column
                self.logger.log(f"搜索关键词 '{keyword}' 时出错: {e}", "ERROR")
                self.logger.log(traceback.format_exc(), "ERROR")
                continue

            matched = industry_match | business_match | scope_match
            self.logger.log(f"找到 {int(matched.sum())} 家与 '{keyword}' 相关的公司")

            # Iterate positionally with zip so scoring works for any DataFrame index.
            rows = zip(df['Symbol'], industry_match, business_match, scope_match)
            for symbol, in_industry, in_business, in_scope in rows:
                if not (in_industry or in_business or in_scope):
                    continue
                score = company_scores.get(symbol, 0)
                if in_industry:  # industry name: highest weight
                    score += weight * 2
                if in_business:  # main business: medium weight
                    score += weight * 1.5
                if in_scope:  # business scope: lowest weight
                    score += weight * 1
                company_scores[symbol] = score
        return company_scores

    @staticmethod
    def _company_info(row, score):
        """Build the public company dict from a company-sheet row and its score."""
        return {
            "Symbol": str(row.get('Symbol', '')),
            "ShortName": str(row.get('ShortName', '')),
            "IndustryName": str(row.get('IndustryName', '')),
            "FullName": str(row.get('FullName', '')),
            "BusinessScope": str(row.get('BusinessScope', '')),
            "MAINBUSSINESS": str(row.get('MAINBUSSINESS', '')),
            "Score": score,
        }

    def search_related_companies(self, industries):
        """Find listed companies related to *industries*, ranked by match score.

        Each search keyword is matched as a plain substring against
        ``IndustryName`` (weight x2), ``MAINBUSSINESS`` (x1.5) and
        ``BusinessScope`` (x1). Scores are summed per company, and the 15 best
        are kept. If the industries mention 小米 / 消费电子 / 智能手机 / 汽车,
        the first company whose short name contains "小米" is prepended with
        score 100.0.

        Returns:
            list of dicts with str fields Symbol, ShortName, IndustryName,
            FullName, BusinessScope, MAINBUSSINESS plus Score. Each symbol
            appears at most once. The list is empty when nothing matches or
            when the company sheet is empty or lacks a required column.
        """
        self.logger.log(f"开始查找与行业相关的上市公司: {', '.join(industries)}")

        df = self.company_df
        missing = [col for col in self._REQUIRED_COLUMNS if col not in df.columns]
        if df.empty or missing:
            self.logger.log(f"上市公司信息为空或缺少必要字段 {missing},无法查找相关公司", "WARNING")
            return []

        industry_keywords = self._build_keyword_weights(industries)

        # Hard-coded special case: industries that hint at 小米 give the keyword
        # "小米" a high weight (it is only searched if an industry name contains it).
        if any(tag in industry for industry in industries for tag in ("小米", "消费电子", "互联网")):
            self.logger.log("新闻与小米相关,强制添加小米公司")
            industry_keywords["小米"] = 10

        search_keywords = self._select_search_keywords(industries, industry_keywords)
        self.logger.log(f"搜索关键词: {', '.join(search_keywords)}")

        company_scores = self._score_companies(search_keywords, industry_keywords)
        if not company_scores:
            self.logger.log("未找到与行业相关的公司", "WARNING")
            return []

        sorted_companies = sorted(company_scores.items(), key=lambda item: item[1], reverse=True)
        top_companies = sorted_companies[:15]

        self.logger.log(f"共找到 {len(company_scores)} 家相关公司,并按相关度排序")
        self.logger.log(f"选择得分最高的 {len(top_companies)} 家公司进行分析")
        self.logger.log("前5家公司代码和得分:")
        for rank, (symbol, score) in enumerate(sorted_companies[:5], start=1):
            self.logger.log(f" {rank}. 代码: {symbol}, 类型: {type(symbol)}, 得分: {score:.2f}")

        related_companies = []
        seen_symbols = set()

        # Hard-coded special case: put 小米 first with the maximum score.
        if any(tag in ''.join(industries) for tag in ('小米', '消费电子', '智能手机', '汽车')):
            xiaomi_matches = df[df['ShortName'].str.contains('小米', na=False, regex=False)]
            if len(xiaomi_matches) > 0:
                xiaomi_info = self._company_info(xiaomi_matches.iloc[0], 100.0)
                related_companies.append(xiaomi_info)
                seen_symbols.add(xiaomi_info["Symbol"])
                self.logger.log(f"特殊处理:添加小米公司 {xiaomi_info['ShortName']},得分: 100.00")

        for symbol, score in top_companies:
            company_matches = df[df['Symbol'] == symbol]
            if len(company_matches) == 0:
                self.logger.log(f"警告: 找不到股票代码为 {symbol} 的公司信息", "WARNING")
                continue

            company_info = self._company_info(company_matches.iloc[0], score)
            if company_info["Symbol"] in seen_symbols:
                # Already included (e.g. by the 小米 special case); skip the duplicate.
                continue
            seen_symbols.add(company_info["Symbol"])
            related_companies.append(company_info)
            self.logger.log(f"公司: {company_info['ShortName']},得分: {score:.2f}")

        return related_companies

    def analyze_company_impact(self, news, impact_reasons, company_list):
        """Ask the model how *news* affects each company in *company_list*.

        Args:
            news: the news text.
            impact_reasons: the industry analysis text from step 1.
            company_list: dicts as returned by :meth:`search_related_companies`.

        Returns:
            The model's answer, or an error message if the API call failed.
        """
        self.logger.log(f"开始分析新闻对 {len(company_list)} 家公司的影响...")

        system_prompt = """
# 角色
你是一位顶尖的证券分析师,擅长依据热点新闻的主题及关键信息,准确分析其对不同行业相关股票的潜在影响。

## 技能
### 技能 1:解读新闻对股票的影响
1. 当用户提供一条热点新闻时,认真剖析新闻的主题、核心内容及潜在影响因素。
2. 依据影响原因和上市公司的基本信息,分析新闻事件对公司造成的影响作出专业的分析。
3. 按照下方输出格式逐一输出每家公司的分析结果(包含公司简称,股票代码,公司简介,主营业务,影响分析)。

## 输出格式
请按以下格式输出每家公司的分析结果:

🎯 公司简称(代码:股票代码)

📝 公司简介: <公司介绍>

🌟 主营业务: <主营业务>

📈 影响分析: <详细说明新闻中的哪些因素影响了该股票>

## 限制:
- 只围绕热点新闻与股票影响进行分析,拒绝回答与此无关的问题。
- 输出内容严格按照给定格式组织,不得偏离要求。
- "影响分析"部分简洁明了,不超过200 字。
"""

        # One paragraph of basic information per company.
        company_list_text = "".join(
            f"公司简称:{company['ShortName']},股票代码:{company['Symbol']},行业:{company['IndustryName']},"
            f"主营业务:{company['MAINBUSSINESS']},公司简介:{company['FullName']}\n\n"
            for company in company_list
        )

        user_prompt = f"""
请根据用户输入的热点新闻进行分析:

===热点新闻开始===
{news}
===热点新闻结束===

影响原因:
{impact_reasons}

上市公司基本信息:
{company_list_text}
"""

        try:
            self.logger.log("调用DeepSeek API分析公司影响...")
            analysis_result = self._chat(system_prompt, user_prompt, temperature=0.2)
        except Exception as e:  # API/network failures are reported, not raised
            self.logger.log(f"分析新闻对公司的影响时出错: {e}", "ERROR")
            self.logger.log(traceback.format_exc(), "ERROR")
            return "分析过程中出现错误,请检查API配置或网络连接。"

        self.logger.log("公司影响分析完成")
        return analysis_result

    def _save_report(self, news, industry_reasons, company_analysis):
        """Write the Markdown report to ``OUTPUT_DIR``; return its path, or None on failure."""
        now = datetime.now()
        result_file = os.path.join(OUTPUT_DIR, f"news_analysis_{now.strftime('%Y%m%d_%H%M%S')}.md")
        report = (
            "# 热点新闻股票影响分析\n\n"
            f"分析时间: {now.strftime('%Y-%m-%d %H:%M:%S')}\n\n"
            "## 新闻内容\n\n"
            f"{news}\n\n"
            "## 行业影响分析\n\n"
            f"{industry_reasons}\n\n"
            "## 公司影响分析\n\n"
            f"{company_analysis}\n"
        )
        try:
            with open(result_file, "w", encoding="utf-8") as f:
                f.write(report)
        except OSError as e:
            # Don't throw away a finished (paid-for) analysis because saving failed.
            self.logger.log(f"保存分析结果失败: {e}", "ERROR")
            return None
        self.logger.log(f"分析结果已保存至: {result_file}")
        return result_file

    def analyze_news(self, news, max_companies=10):
        """Run the full pipeline for *news* and return the company-impact text.

        Args:
            news: the news text.
            max_companies: at most this many of the ranked companies are sent
                to the model (keeps the request small).

        Returns:
            The company impact analysis, or a user-facing message explaining
            why the pipeline stopped early.
        """
        self.logger.log("=" * 50)
        self.logger.log("开始新闻分析流程")
        self.logger.log("=" * 50)

        # 1. Industries affected by the news.
        self.logger.log("步骤1: 分析新闻对行业的影响")
        industry_analysis = self.analyze_news_impact_on_industries(news)
        if not industry_analysis["industries"]:
            self.logger.log("未找到受影响的行业,分析结束", "WARNING")
            return "未能识别出该新闻可能影响的行业,请尝试提供更详细的新闻内容。"

        # 2. Listed companies in those industries.
        self.logger.log("步骤2: 查找与行业相关的上市公司")
        related_companies = self.search_related_companies(industry_analysis["industries"])
        if not related_companies:
            self.logger.log("未找到相关上市公司,分析结束", "WARNING")
            return f"已识别出可能受影响的行业: {', '.join(industry_analysis['industries'])}\n\n但未找到与这些行业相关的上市公司。"

        if len(related_companies) > max_companies:
            self.logger.log(f"相关公司过多,限制为前{max_companies}家进行详细分析")
            related_companies = related_companies[:max_companies]

        # 3. Impact on each company.
        self.logger.log(f"步骤3: 分析新闻对{len(related_companies)}家公司的影响")
        company_impact_analysis = self.analyze_company_impact(
            news,
            industry_analysis["reasons"],
            related_companies,
        )

        # 4. Persist the report.
        self._save_report(news, industry_analysis["reasons"], company_impact_analysis)

        self.logger.log("=" * 50)
        self.logger.log("新闻分析流程完成")
        self.logger.log("=" * 50)

        return company_impact_analysis
|
||
|
||
def display_banner():
    """Print the program's title box to stdout, framed by blank lines."""
    rows = (
        "",
        "╔═══════════════════════════════════════════════════════════════╗",
        "║ ║",
        "║ 📈 热点新闻股票影响分析工具 📊 ║",
        "║ ║",
        "║ 基于DeepSeek AI的智能分析系统 ║",
        "║ 版本: 1.0.0 ║",
        "║ ║",
        "╚═══════════════════════════════════════════════════════════════╝",
        "",
    )
    print("\n".join(rows))
|
||
|
||
def parse_arguments():
    """Build the command-line parser and return the parsed namespace.

    Options:
        --news TEXT   news text to analyse
        --file PATH   file containing the news text
        --log         also write log lines to a file
    """
    cli = argparse.ArgumentParser(description="热点新闻股票影响分析工具")
    option_specs = (
        ("--news", {"type": str, "help": "要分析的新闻内容"}),
        ("--file", {"type": str, "help": "包含新闻内容的文件路径"}),
        ("--log", {"action": "store_true", "help": "是否记录日志到文件"}),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()
|
||
|
||
def main():
    """Command-line entry point.

    The news text comes from ``--news``, or else from the UTF-8 file given by
    ``--file``. If neither yields any text, the user is prompted repeatedly
    until they type '退出' / 'exit' / 'quit' or stdin reaches EOF. The logger
    is always closed on the way out, even if analysis raises.
    """
    display_banner()

    args = parse_arguments()
    logger = Logger(log_to_file=args.log)
    try:
        analyzer = NewsImpactAnalyzer(logger)

        news_content = None
        if args.news:
            news_content = args.news
        elif args.file:
            try:
                with open(args.file, "r", encoding="utf-8") as f:
                    news_content = f.read()
            except (OSError, UnicodeDecodeError) as e:
                # Fall through to interactive mode, as before.
                logger.log(f"读取新闻文件失败: {e}", "ERROR")

        if news_content:
            # One-shot analysis of the news given on the command line.
            print("\n开始分析新闻对股票的影响...")
            result = analyzer.analyze_news(news_content)
            print("\n分析结果:")
            print(result)
            return

        # Interactive mode.
        while True:
            print("\n请输入热点新闻(输入'退出'结束程序):")
            try:
                news = input()
            except EOFError:
                # stdin closed (e.g. Ctrl-D or piped input exhausted): exit cleanly.
                break

            if news.strip().lower() in ['退出', 'exit', 'quit']:
                break

            if not news.strip():
                print("新闻内容不能为空,请重新输入。")
                continue

            print("\n开始分析新闻对股票的影响...")
            result = analyzer.analyze_news(news)

            print("\n分析结果:")
            print(result)
            print("\n" + "=" * 50)
    finally:
        logger.close()
|
||
|
||
if __name__ == "__main__":
|
||
try:
|
||
main()
|
||
except KeyboardInterrupt:
|
||
print("\n程序已被用户中断")
|
||
except Exception as e:
|
||
print(f"程序运行出错: {e}")
|
||
print(traceback.format_exc())
|