def rate_limit(domain: str, max_requests: int = 30, time_window: int = 60) -> None:
    """
    Configure the request rate limit for one domain.

    The limit is enforced by the shared limiter with a sliding window: at most
    ``max_requests`` requests are let through in any ``time_window`` seconds,
    and further requests are delayed rather than rejected.

    :param domain: host name as it appears in request URLs,
        e.g. "push2his.eastmoney.com"
    :param max_requests: maximum number of requests per window (default 30)
    :param time_window: window length in seconds (default 60)

    Example:
        >>> import adata
        >>> adata.rate_limit("push2his.eastmoney.com", max_requests=20, time_window=60)
    """
    limiter = get_rate_limiter()
    limiter.set_domain_config(domain, max_requests=max_requests, time_window=time_window)


def rate_limit_default(max_requests: int = 30, time_window: int = 60) -> None:
    """
    Configure the rate limit used for domains without their own setting.

    :param max_requests: default maximum number of requests per window (default 30)
    :param time_window: default window length in seconds (default 60)

    Example:
        >>> import adata
        >>> adata.rate_limit_default(max_requests=20, time_window=60)
    """
    limiter = get_rate_limiter()
    limiter.set_default_config(max_requests=max_requests, time_window=time_window)
logger = logging.getLogger("adata")


@dataclass
class RateLimitConfig:
    """Rate-limit settings for one domain (also used for the global default)."""
    max_requests: int = 30  # maximum number of requests allowed inside one window
    time_window: int = 60  # length of the sliding window, in seconds
    wait_message: bool = True  # log a warning whenever a caller has to wait


class DomainRateLimiter:
    """
    Sliding-window request limiter keyed by host name.

    Every host gets its own window of request timestamps. The limits applied
    to a host come from the most specific configured domain that matches it
    ("push2his.eastmoney.com" falls back to "eastmoney.com"), and finally
    from the global default configuration.
    """

    def __init__(self):
        # host -> monotonic timestamps of the requests still inside the window
        self._domain_windows: Dict[str, deque] = {}
        # configured domain -> limits
        self._domain_configs: Dict[str, RateLimitConfig] = {}
        # limits for hosts that match no configured domain
        self._default_config = RateLimitConfig()
        # Re-entrant: get_stats() calls get_domain_config() while holding it.
        self._lock = threading.RLock()

    @staticmethod
    def _normalize_domain(domain: str) -> str:
        """Return the lookup key for a domain: trimmed and lower-cased."""
        return domain.strip().lower()

    @staticmethod
    def _check_limits(max_requests: Optional[int], time_window: Optional[int]) -> None:
        """Reject limits that would make acquire() fail or never admit a request."""
        if max_requests is not None and max_requests < 1:
            raise ValueError(f"max_requests must be >= 1, got {max_requests!r}")
        if time_window is not None and time_window <= 0:
            raise ValueError(f"time_window must be > 0, got {time_window!r}")

    def set_domain_config(self, domain: str, max_requests: Optional[int] = None,
                          time_window: Optional[int] = None) -> None:
        """
        Set the limits for one domain (and, by inheritance, its sub-domains).

        :param domain: domain name such as "eastmoney.com" (case-insensitive)
        :param max_requests: maximum requests per window; None keeps the current value
        :param time_window: window length in seconds; None keeps the current value
        :raises ValueError: if max_requests < 1 or time_window <= 0
        """
        # Validate first so a bad call never leaves a half-created entry behind.
        self._check_limits(max_requests, time_window)
        key = self._normalize_domain(domain)
        with self._lock:
            config = self._domain_configs.setdefault(key, RateLimitConfig())
            if max_requests is not None:
                config.max_requests = max_requests
            if time_window is not None:
                config.time_window = time_window

            logger.info(f"[RateLimiter] 域名 {key} 限流配置已更新: "
                        f"{config.max_requests}次/"
                        f"{config.time_window}秒")

    def set_default_config(self, max_requests: Optional[int] = None,
                           time_window: Optional[int] = None) -> None:
        """
        Set the limits applied to hosts that match no configured domain.

        :param max_requests: maximum requests per window; None keeps the current value
        :param time_window: window length in seconds; None keeps the current value
        :raises ValueError: if max_requests < 1 or time_window <= 0
        """
        self._check_limits(max_requests, time_window)
        with self._lock:
            if max_requests is not None:
                self._default_config.max_requests = max_requests
            if time_window is not None:
                self._default_config.time_window = time_window

            logger.info(f"[RateLimiter] 默认限流配置已更新: "
                        f"{self._default_config.max_requests}次/"
                        f"{self._default_config.time_window}秒")

    def get_domain_config(self, domain: str) -> RateLimitConfig:
        """
        Return the limits that apply to ``domain``.

        Lookup order: the exact host, then each parent domain obtained by
        dropping the left-most label, then the global default.
        """
        candidate = self._normalize_domain(domain)
        with self._lock:
            while candidate:
                config = self._domain_configs.get(candidate)
                if config is not None:
                    return config
                # "a.b.com" -> "b.com" -> "com" -> "" (loop ends)
                _, _, candidate = candidate.partition(".")
            return self._default_config

    def _extract_domain(self, url: str) -> str:
        """
        Return the lower-cased host of ``url`` without port or credentials.

        URLs that cannot be parsed, or that carry no host (for example when
        the scheme is missing), are all bucketed under "unknown".
        """
        try:
            # .hostname strips "user:pw@" and ":port" and unwraps IPv6 brackets;
            # splitting netloc on ':' gets all three cases wrong.
            host = urlparse(url).hostname
        except (ValueError, TypeError, AttributeError):
            return "unknown"
        if isinstance(host, bytes):
            host = host.decode("ascii", "ignore")
        return host or "unknown"

    def acquire(self, url: str) -> float:
        """
        Block until a request to ``url`` fits inside its host's window.

        :param url: URL about to be requested
        :return: seconds spent sleeping (0.0 when the request was admitted at once)
        """
        domain = self._extract_domain(url)
        waited = 0.0

        # Re-check after every sleep: other threads may have filled the window
        # in the meantime, and reset() may have replaced the deque.
        while True:
            with self._lock:
                config = self.get_domain_config(domain)
                # Guards against a config mutated to < 1 behind the setters' back.
                limit = max(1, config.max_requests)
                span = config.time_window
                window = self._domain_windows.setdefault(domain, deque())

                now = time.monotonic()
                cutoff = now - span
                # "<=" so an entry exactly `span` old is expired; a full window
                # therefore always yields a strictly positive wait below.
                while window and window[0] <= cutoff:
                    window.popleft()

                if len(window) < limit:
                    window.append(now)
                    return waited

                wait_time = window[0] + span - now
                announce = config.wait_message
                shown_limit = config.max_requests

            # Log and sleep outside the lock so other hosts are not blocked.
            if announce:
                logger.warning(f"[RateLimiter] 域名 {domain} 请求频率超限 "
                               f"({shown_limit}次/{span}秒), "
                               f"等待 {wait_time:.2f} 秒...")
            time.sleep(wait_time)
            waited += wait_time

    def get_stats(self, domain: Optional[str] = None) -> dict:
        """
        Return limiter statistics.

        :param domain: host to inspect; None returns an overview of all hosts
        :return: for a host, its request count inside the current window plus
            the limits that apply to it (present even before the first
            request); otherwise the known hosts, the configured domains and
            the default limits
        """
        with self._lock:
            if not domain:
                return {
                    "domains": list(self._domain_windows.keys()),
                    "configured_domains": list(self._domain_configs.keys()),
                    "default_config": {
                        "max_requests": self._default_config.max_requests,
                        "time_window": self._default_config.time_window
                    }
                }

            key = self._normalize_domain(domain)
            config = self.get_domain_config(key)
            window = self._domain_windows.get(key, ())
            cutoff = time.monotonic() - config.time_window
            return {
                "domain": key,
                "request_count": sum(1 for stamp in window if stamp > cutoff),
                "max_requests": config.max_requests,
                "time_window": config.time_window,
                "window_size": len(window)
            }

    def reset(self, domain: Optional[str] = None) -> None:
        """
        Forget recorded requests; configured limits are kept.

        :param domain: host to reset; None resets every host
        """
        with self._lock:
            if domain:
                key = self._normalize_domain(domain)
                window = self._domain_windows.get(key)
                if window is not None:
                    window.clear()
                logger.info(f"[RateLimiter] 域名 {key} 限流状态已重置")
            else:
                self._domain_windows.clear()
                logger.info("[RateLimiter] 所有域名限流状态已重置")
def enable_rate_limit(self) -> None:
    """Turn rate limiting on for requests sent through this instance."""
    # request() consults this flag before asking the limiter for a permit.
    self._rate_limit_enabled = True

def disable_rate_limit(self) -> None:
    """Turn rate limiting off for requests sent through this instance."""
    self._rate_limit_enabled = False
请求数据结果 res = None for i in range(times): if wait_time: diff --git a/tests/other/rate_limiter_demo.py b/tests/other/rate_limiter_demo.py new file mode 100644 index 0000000..0b2ddcb --- /dev/null +++ b/tests/other/rate_limiter_demo.py @@ -0,0 +1,256 @@ +# -*- coding: utf-8 -*- +""" +@desc: 限流器使用示例 +@author: 1nchaos +@time: 2026/3/18 +@log: 演示如何使用adata的请求限流功能 +""" + +""" +============================================ +AData 请求限流功能使用指南 +============================================ + +功能概述: +--------- +AData 现在内置了基于域名的请求限流功能,可以有效防止因高频请求导致的: +- IP被封禁 +- API配额耗尽 +- 服务商限制 + +核心特性: +--------- +1. 默认30次/分钟的保守策略 +2. 支持按域名独立配置限流参数 +3. 运行时动态调整,无需重启 +4. 对现有代码零侵入 +5. 超限请求给出明确的等待提示 + +使用方法: +--------- +""" + +# ============================================================ +# 示例1: 基本使用(使用默认限流配置) +# ============================================================ +def demo_basic_usage(): + """ + 基本使用示例 + 只需正常导入adata,限流器会自动生效(默认30次/分钟) + """ + import adata + + # 限流器默认已启用,无需额外配置 + # 查询股票行情,会自动应用限流 + df = adata.stock.market.get_market( + stock_code='000001', + k_type=1, + start_date='2026-03-17' + ) + print(df) + + +# ============================================================ +# 示例2: 自定义特定域名的限流参数 +# ============================================================ +def demo_custom_domain_limit(): + """ + 为特定数据源设置自定义限流参数 + """ + import adata + + # 为东方财富接口设置更严格的限制:20次/分钟 + adata.rate_limit("eastmoney.com", max_requests=20, time_window=60) + + # 为新浪接口设置更宽松的限制:60次/分钟 + adata.rate_limit("sina.com.cn", max_requests=60, time_window=60) + + # 后续对这些域名的请求会自动应用相应的限流策略 + df = adata.stock.market.get_market( + stock_code='000001', + k_type=1, + start_date='2026-03-17' + ) + print(df) + + +# ============================================================ +# 示例3: 修改默认限流参数 +# ============================================================ +def demo_default_limit(): + """ + 修改全局默认限流参数 + """ + import adata + + # 设置默认限制为20次/分钟(对所有未单独配置的域名生效) + adata.rate_limit_default(max_requests=20, time_window=60) + + # 
def demo_test_rate_limit():
    """
    Exercise the limiter against the live market API and print timings.

    Phase 1 sends 20 requests, which fits inside the 30-per-60-seconds
    budget. Phase 2 sends 40 requests, which exceeds it, so the limiter has
    to pause and the phase is expected to take more than 30 seconds.

    The limiter state is reset before each phase; otherwise the 20 requests
    of phase 1 would still occupy the window and phase 2 would start
    throttling after only 10 requests.

    NOTE: sends real HTTP requests and takes a long time to finish.
    """
    import time
    import adata

    def _run_batch(total, report_every, show_elapsed):
        """Send ``total`` market requests and return the elapsed seconds."""
        started = time.time()
        for i in range(total):
            try:
                adata.stock.market.get_market(
                    stock_code='000001',
                    k_type=1,
                    start_date='2026-03-17'
                )
                if (i + 1) % report_every == 0:
                    if show_elapsed:
                        elapsed_so_far = time.time() - started
                        print(f" 已完成 {i + 1}/{total} 次请求,累计用时: {elapsed_so_far:.2f} 秒")
                    else:
                        print(f" 已完成 {i + 1}/{total} 次请求")
            except Exception as e:
                # Demo code: report the failure and keep going so the timing
                # of the remaining requests can still be observed.
                print(f" 第 {i + 1} 次请求异常: {e}")
        return time.time() - started

    # Limit under test: 30 requests per 60 seconds.
    adata.rate_limit_reset()
    adata.rate_limit_default(max_requests=30, time_window=60)

    print("开始测试限流效果...")
    print("配置: 30次/分钟")
    print("=" * 60)

    # Phase 1: inside the budget, should finish quickly.
    print("\n测试1: 20次请求(在30次限制内)")
    elapsed_20 = _run_batch(20, report_every=5, show_elapsed=False)
    print(f" 20次请求完成,用时: {elapsed_20:.2f} 秒")

    # Start phase 2 from an empty window so it measures exactly 40 requests.
    adata.rate_limit_reset()

    # Phase 2: over the budget, the limiter must make the caller wait.
    print("\n测试2: 40次请求(超过30次限制)")
    elapsed_40 = _run_batch(40, report_every=10, show_elapsed=True)
    print(f" 40次请求完成,总用时: {elapsed_40:.2f} 秒")

    print("\n" + "=" * 60)
    print("测试结论:")
    print(f" - 20次请求用时 {elapsed_20:.2f} 秒(预期 < 30秒)")
    print(f" - 40次请求用时 {elapsed_40:.2f} 秒(预期 > 30秒,因为触发了限流等待)")
    print("=" * 60)
+""" +Simple test for rate limiter - standalone version +""" +import sys +import time +import threading +from collections import deque +from dataclasses import dataclass +from typing import Dict, Optional +from urllib.parse import urlparse + +# Copy the rate limiter code here for standalone testing + +@dataclass +class RateLimitConfig: + max_requests: int = 30 + time_window: int = 60 + wait_message: bool = True + + +class DomainRateLimiter: + def __init__(self): + self._domain_windows: Dict[str, deque] = {} + self._domain_configs: Dict[str, RateLimitConfig] = {} + self._default_config = RateLimitConfig() + self._lock = threading.RLock() + + def set_domain_config(self, domain: str, max_requests: Optional[int] = None, + time_window: Optional[int] = None) -> None: + with self._lock: + if domain not in self._domain_configs: + self._domain_configs[domain] = RateLimitConfig() + if max_requests is not None: + self._domain_configs[domain].max_requests = max_requests + if time_window is not None: + self._domain_configs[domain].time_window = time_window + + def get_domain_config(self, domain: str) -> RateLimitConfig: + with self._lock: + return self._domain_configs.get(domain, self._default_config) + + def _extract_domain(self, url: str) -> str: + try: + parsed = urlparse(url) + domain = parsed.netloc.lower() + if ':' in domain: + domain = domain.split(':')[0] + return domain + except Exception: + return "unknown" + + def acquire(self, url: str) -> float: + domain = self._extract_domain(url) + config = self.get_domain_config(domain) + + with self._lock: + if domain not in self._domain_windows: + self._domain_windows[domain] = deque() + + window = self._domain_windows[domain] + now = time.time() + + cutoff_time = now - config.time_window + while window and window[0] < cutoff_time: + window.popleft() + + if len(window) >= config.max_requests: + oldest_request = window[0] + wait_time = (oldest_request + config.time_window) - now + + if wait_time > 0: + self._lock.release() + 
try: + time.sleep(wait_time) + finally: + self._lock.acquire() + + now = time.time() + cutoff_time = now - config.time_window + while window and window[0] < cutoff_time: + window.popleft() + + window.append(now) + return 0.0 + + def get_stats(self, domain: str) -> dict: + with self._lock: + if domain not in self._domain_windows: + return {"domain": domain, "request_count": 0} + window = self._domain_windows[domain] + config = self.get_domain_config(domain) + now = time.time() + valid_count = sum(1 for t in window if t > now - config.time_window) + return { + "domain": domain, + "request_count": valid_count, + "max_requests": config.max_requests, + "time_window": config.time_window, + } + + +print("=" * 60) +print("Testing Rate Limiter Basic Functionality") +print("=" * 60) + +# Test 1: Basic rate limiting +print("\nTest 1: Basic rate limiting (2 requests per 2 seconds)") +limiter = DomainRateLimiter() +limiter.set_domain_config('test.com', max_requests=2, time_window=2) + +# First 2 requests should pass immediately +start = time.time() +limiter.acquire('http://test.com/api') +limiter.acquire('http://test.com/api') +elapsed1 = time.time() - start +print(f" First 2 requests took: {elapsed1:.2f}s (expected < 0.5s)") + +# 3rd request should wait +start = time.time() +limiter.acquire('http://test.com/api') +elapsed2 = time.time() - start +print(f" 3rd request took: {elapsed2:.2f}s (expected > 1s)") + +if elapsed1 < 0.5 and elapsed2 > 1.0: + print(" PASSED!") +else: + print(" FAILED!") + +# Test 2: Different domains are independent +print("\nTest 2: Different domains are independent") +limiter2 = DomainRateLimiter() +limiter2.set_domain_config('a.com', max_requests=1, time_window=10) +limiter2.set_domain_config('b.com', max_requests=10, time_window=10) + +start = time.time() +limiter2.acquire('http://a.com/api') +elapsed_a1 = time.time() - start + +start = time.time() +limiter2.acquire('http://b.com/api') +elapsed_b1 = time.time() - start + +print(f" a.com 1st request: 
{elapsed_a1:.2f}s") +print(f" b.com 1st request: {elapsed_b1:.2f}s") + +if elapsed_a1 < 0.1 and elapsed_b1 < 0.1: + print(" PASSED!") +else: + print(" FAILED!") + +# Test 3: Default config +print("\nTest 3: Default config") +limiter3 = DomainRateLimiter() +config = limiter3.get_domain_config('unknown.com') +print(f" Default max_requests: {config.max_requests} (expected 30)") +print(f" Default time_window: {config.time_window} (expected 60)") + +if config.max_requests == 30 and config.time_window == 60: + print(" PASSED!") +else: + print(" FAILED!") + +# Test 4: Stats +print("\nTest 4: Stats tracking") +limiter4 = DomainRateLimiter() +limiter4.set_domain_config('stats.com', max_requests=5, time_window=60) + +for _ in range(3): + limiter4.acquire('http://stats.com/api') + +stats = limiter4.get_stats('stats.com') +print(f" Request count: {stats['request_count']} (expected 3)") +print(f" Max requests: {stats['max_requests']} (expected 5)") + +if stats['request_count'] == 3 and stats['max_requests'] == 5: + print(" PASSED!") +else: + print(" FAILED!") + +# Test 5: Simulate the 30 requests/minute scenario +print("\nTest 5: Simulate 30 requests/minute limit") +limiter5 = DomainRateLimiter() +limiter5.set_domain_config('api.com', max_requests=30, time_window=60) + +# Send 30 requests quickly +start = time.time() +for i in range(30): + limiter5.acquire('http://api.com/data') +elapsed_30 = time.time() - start +print(f" 30 requests took: {elapsed_30:.2f}s (should be quick)") + +# 31st request should trigger wait (but we'll wait less for testing) +# Actually, let's use a smaller window for faster testing +limiter5.set_domain_config('api2.com', max_requests=5, time_window=3) + +start = time.time() +for i in range(5): + limiter5.acquire('http://api2.com/data') +elapsed_5 = time.time() - start +print(f" First 5 requests to api2.com took: {elapsed_5:.2f}s") + +start = time.time() +limiter5.acquire('http://api2.com/data') # 6th request +elapsed_6th = time.time() - start +print(f" 
class TestRateLimiter(unittest.TestCase):
    """Unit tests for DomainRateLimiter; no network access is needed."""

    def setUp(self):
        """Clear the shared limiter's recorded requests before each test."""
        adata.rate_limit_reset()

    def test_domain_rate_limiter_basic(self):
        """Requests inside the budget are admitted without waiting."""
        limiter = DomainRateLimiter()

        # Budget for the test domain: 2 requests per 3 seconds.
        limiter.set_domain_config("test.com", max_requests=2, time_window=3)

        start = time.time()
        limiter.acquire("http://test.com/api")
        limiter.acquire("http://test.com/api")
        elapsed = time.time() - start

        self.assertLess(elapsed, 0.5, "前2次请求应该立即通过")

    def test_domain_rate_limiter_wait(self):
        """A request beyond the budget is delayed."""
        limiter = DomainRateLimiter()

        # Budget for the test domain: 2 requests per 2 seconds.
        limiter.set_domain_config("test.com", max_requests=2, time_window=2)

        # Use up the budget.
        limiter.acquire("http://test.com/api")
        limiter.acquire("http://test.com/api")

        # The third request has to wait for the oldest one to leave the window.
        start = time.time()
        limiter.acquire("http://test.com/api")
        elapsed = time.time() - start

        self.assertGreaterEqual(elapsed, 1.0, "第3次请求应该等待至少1秒")

    def test_different_domains_independent(self):
        """Exhausting one domain's budget does not delay another domain."""
        limiter = DomainRateLimiter()

        # A 1-second window keeps this "quick" test quick while still giving
        # a wait that is clearly distinguishable from an immediate pass.
        limiter.set_domain_config("a.com", max_requests=1, time_window=1)
        limiter.set_domain_config("b.com", max_requests=10, time_window=1)

        # a.com, first request: immediate.
        start = time.time()
        limiter.acquire("http://a.com/api")
        elapsed_a1 = time.time() - start

        # b.com, first request: immediate, unaffected by a.com.
        start = time.time()
        limiter.acquire("http://b.com/api")
        elapsed_b1 = time.time() - start

        # a.com, second request: must wait for the window to roll over.
        start = time.time()
        limiter.acquire("http://a.com/api")
        elapsed_a2 = time.time() - start

        self.assertLess(elapsed_a1, 0.1)
        self.assertLess(elapsed_b1, 0.1)
        self.assertGreaterEqual(elapsed_a2, 0.9, "a.com第2次请求应该等待")

    def test_default_config(self):
        """An unconfigured domain gets the default limits."""
        limiter = DomainRateLimiter()

        config = limiter.get_domain_config("unknown.com")
        self.assertEqual(config.max_requests, 30)
        self.assertEqual(config.time_window, 60)

    def test_stats(self):
        """Stats report the number of requests inside the window."""
        limiter = DomainRateLimiter()
        limiter.set_domain_config("stats.com", max_requests=5, time_window=60)

        for _ in range(3):
            limiter.acquire("http://stats.com/api")

        stats = limiter.get_stats("stats.com")
        self.assertEqual(stats["request_count"], 3)
        self.assertEqual(stats["max_requests"], 5)
+ self.assertGreaterEqual(success_count, 15, "至少15次请求应该成功") + + def test_40_requests_exceeds_limit(self): + """ + 测试40次请求(超过30次/分钟限制,应该触发限流等待) + + 预期:40次请求应该至少需要等待1分钟(因为每30次需要等待60秒窗口) + """ + print("\n=== 测试40次请求(超过限制)===") + + # 设置限流:30次/分钟 + adata.rate_limit_default(max_requests=30, time_window=60) + + start_time = time.time() + success_count = 0 + + for i in range(40): + try: + df = adata.stock.market.get_market( + stock_code='000001', + k_type=1, + start_date='2026-03-17' + ) + if not df.empty: + success_count += 1 + if (i + 1) % 10 == 0: + elapsed_so_far = time.time() - start_time + print(f" 已完成 {i + 1}/40 次请求,用时: {elapsed_so_far:.2f} 秒") + except Exception as e: + print(f" 第 {i + 1} 次请求异常: {e}") + + elapsed = time.time() - start_time + print(f" 40次请求完成,总用时: {elapsed:.2f} 秒") + print(f" 成功次数: {success_count}/40") + + # 40次请求,每30次需要等待60秒,所以至少应该等待约30-60秒 + # 加上网络和处理时间,总时间应该 > 30秒 + self.assertGreater(elapsed, 30, "40次请求应该触发限流等待,总时间应该超过30秒") + self.assertGreaterEqual(success_count, 30, "至少30次请求应该成功") + + def test_domain_specific_limit(self): + """ + 测试特定域名限流配置 + + 为eastmoney.com设置更严格的限制 + """ + print("\n=== 测试特定域名限流配置 ===") + + # 设置默认限制较宽松 + adata.rate_limit_default(max_requests=100, time_window=60) + + # 为eastmoney.com设置更严格的限制:5次/10秒 + adata.rate_limit("eastmoney.com", max_requests=5, time_window=10) + + start_time = time.time() + + # 发送8次请求 + for i in range(8): + try: + df = adata.stock.market.get_market( + stock_code='000001', + k_type=1, + start_date='2026-03-17' + ) + elapsed_so_far = time.time() - start_time + print(f" 第 {i + 1} 次请求完成,累计用时: {elapsed_so_far:.2f} 秒") + except Exception as e: + print(f" 第 {i + 1} 次请求异常: {e}") + + elapsed = time.time() - start_time + print(f" 8次请求完成,总用时: {elapsed:.2f} 秒") + + # 8次请求,每5次需要等待10秒,所以至少应该等待约5秒 + self.assertGreater(elapsed, 5, "超过5次请求应该触发限流等待") + + def test_rate_limit_disable_enable(self): + """测试禁用和启用限流""" + print("\n=== 测试禁用/启用限流 ===") + + # 先禁用限流 + adata.rate_limit_disable() + + # 设置很严格的限制(如果启用会触发) + 
class TestRateLimiterAPI(unittest.TestCase):
    """Tests for the public ``adata.rate_limit*`` helper functions."""

    def setUp(self):
        """Start from an empty limiter and undo global changes afterwards."""
        adata.rate_limit_reset()
        # The limiter is a process-wide singleton: put the default limits back
        # and drop recorded requests so tests cannot influence each other.
        self.addCleanup(adata.rate_limit_reset)
        self.addCleanup(adata.rate_limit_default, max_requests=30, time_window=60)

    def test_rate_limit_api(self):
        """rate_limit() stores the limits that stats later report."""
        adata.rate_limit("test.com", max_requests=10, time_window=30)

        # Per-domain stats carry the limits once the domain has recorded a
        # request, so register one before reading them.
        get_rate_limiter().acquire("http://test.com/api")

        stats = adata.rate_limit_stats("test.com")
        self.assertEqual(stats["max_requests"], 10)
        self.assertEqual(stats["time_window"], 30)

    def test_rate_limit_default_api(self):
        """rate_limit_default() changes the default limits."""
        adata.rate_limit_default(max_requests=50, time_window=120)

        stats = adata.rate_limit_stats()
        self.assertEqual(stats["default_config"]["max_requests"], 50)
        self.assertEqual(stats["default_config"]["time_window"], 120)

    def test_rate_limit_stats_api(self):
        """rate_limit_stats() without a domain returns the overview."""
        stats = adata.rate_limit_stats()
        self.assertIn("domains", stats)
        self.assertIn("default_config", stats)

    def test_rate_limit_reset_api(self):
        """rate_limit_reset(domain) clears that domain's recorded requests."""
        # Give the domain its own generous budget so the two requests below
        # never block, whatever limits other tests left behind.
        adata.rate_limit("reset-test.com", max_requests=10, time_window=30)

        limiter = get_rate_limiter()
        limiter.acquire("http://reset-test.com/api")
        limiter.acquire("http://reset-test.com/api")

        # Both requests are recorded before the reset.
        stats_before = limiter.get_stats("reset-test.com")
        self.assertEqual(stats_before["request_count"], 2)

        adata.rate_limit_reset("reset-test.com")

        # Nothing is left afterwards.
        stats_after = limiter.get_stats("reset-test.com")
        self.assertEqual(stats_after["request_count"], 0)