|
6 | 6 | @software: PyCharm |
7 | 7 | @time: 18-12-25 下午4:58 |
8 | 8 | """ |
| 9 | +import threading |
9 | 10 | from collections import MutableMapping, MutableSequence |
10 | 11 | from contextlib import contextmanager |
11 | 12 | from typing import Dict, List, NoReturn, Union |
@@ -93,6 +94,7 @@ def __init__(self, app=None, *, username: str = "root", passwd: str = None, host |
93 | 94 | self.dialect = dialect |
94 | 95 | self.msg_zh = None |
95 | 96 | self.scoped_sessions: Dict[str, scoped_session] = {} # 主要保存其他scope session |
| 97 | + self.registry = threading.local() # 当前线程注册bind key |
96 | 98 |
|
97 | 99 | # 这里要用重写的BaseQuery, 根据BaseQuery的规则,Model中的query_class也需要重新指定为子类model, |
98 | 100 | # 但是从Model的初始化看,如果Model的query_class为None的话还是会设置为和Query一致,符合要求 |
@@ -154,8 +156,8 @@ def init_app(self, app, username: str = None, passwd: str = None, host: str = No |
154 | 156 |
|
155 | 157 | @app.teardown_appcontext |
156 | 158 | def shutdown_other_session(response_or_exc): |
157 | | - for _, session_ in self.scoped_sessions.items(): |
158 | | - session_.remove() |
| 159 | + for bind_key in getattr(self.registry, "bind_keys", set()): |
| 160 | + self.scoped_sessions[bind_key].remove() |
159 | 161 | return response_or_exc |
160 | 162 |
|
161 | 163 | def get_engine(self, app=None, bind=None): |
@@ -231,6 +233,10 @@ def gen_session(self, bind_key: str, session_options: Dict = None) -> Session: |
231 | 233 | session = self.scoped_sessions[bind_key]() |
232 | 234 | session.bind_key = bind_key # 设置bind key |
233 | 235 | session = self.ping_session(session) # 校验重连,保证可用 |
| 236 | + # 加入当前线程bindkey,用于自动关闭处理 |
| 237 | + if hasattr(self.registry, "bind_keys") is False: |
| 238 | + self.registry.bind_keys = set() |
| 239 | + self.registry.bind_keys.add(bind_key) |
234 | 240 |
|
235 | 241 | return session |
236 | 242 |
|
|
0 commit comments