3030from sqlalchemy_utils import UUIDType
3131
3232from airflow ._shared .secrets_masker import mask_secret
33- from airflow .configuration import ensure_secrets_loaded
33+ from airflow .configuration import conf , ensure_secrets_loaded
3434from airflow .models .base import ID_LEN , Base
3535from airflow .models .crypto import get_fernet
3636from airflow .models .team import Team
@@ -149,7 +149,7 @@ def get(
149149 # means SQLA etc is loaded, but we can't avoid that unless/until we add import shims as a big
150150 # back-compat layer
151151
152- # If this is set it means are in some kind of execution context (Task, Dag Parse or Triggerer perhaps)
152+ # If this is set it means we are in some kind of execution context (Task, Dag Parse or Triggerer perhaps)
153153 # and should use the Task SDK API server path
154154 if hasattr (sys .modules .get ("airflow.sdk.execution_time.task_runner" ), "SUPERVISOR_COMMS" ):
155155 warnings .warn (
@@ -185,6 +185,7 @@ def set(
185185 value : Any ,
186186 description : str | None = None ,
187187 serialize_json : bool = False ,
188+ team_id : str | None = None ,
188189 session : Session | None = None ,
189190 ) -> None :
190191 """
@@ -196,13 +197,14 @@ def set(
196197 :param value: Value to set for the Variable
197198 :param description: Description of the Variable
198199 :param serialize_json: Serialize the value to a JSON string
200+ :param team_id: ID of the team associated to the variable (if any)
199201 :param session: optional session, use if provided or create a new one
200202 """
201203 # TODO: This is not the best way of having compat, but it's "better than erroring" for now. This still
202204 # means SQLA etc is loaded, but we can't avoid that unless/until we add import shims as a big
203205 # back-compat layer
204206
205- # If this is set it means are in some kind of execution context (Task, Dag Parse or Triggerer perhaps)
207+ # If this is set it means we are in some kind of execution context (Task, Dag Parse or Triggerer perhaps)
206208 # and should use the Task SDK API server path
207209 if hasattr (sys .modules .get ("airflow.sdk.execution_time.task_runner" ), "SUPERVISOR_COMMS" ):
208210 warnings .warn (
@@ -221,6 +223,11 @@ def set(
221223 )
222224 return
223225
226+ if team_id and not conf .getboolean ("core" , "multi_team" ):
227+ raise ValueError (
228+ "Multi-team mode is not configured in the Airflow environment. To assign a team to a variable, multi-mode must be enabled."
229+ )
230+
224231 # check if the secret exists in the custom secrets' backend.
225232 Variable .check_for_write_conflict (key = key )
226233 if serialize_json :
@@ -235,7 +242,7 @@ def set(
235242 ctx = create_session ()
236243
237244 with ctx as session :
238- new_variable = Variable (key = key , val = stored_value , description = description )
245+ new_variable = Variable (key = key , val = stored_value , description = description , team_id = team_id )
239246
240247 val = new_variable ._val
241248 is_encrypted = new_variable .is_encrypted
@@ -252,13 +259,15 @@ def set(
252259 val = val ,
253260 description = description ,
254261 is_encrypted = is_encrypted ,
262+ team_id = team_id ,
255263 )
256264 stmt = pg_stmt .on_conflict_do_update (
257265 index_elements = ["key" ],
258266 set_ = dict (
259267 val = val ,
260268 description = description ,
261269 is_encrypted = is_encrypted ,
270+ team_id = team_id ,
262271 ),
263272 )
264273 elif dialect_name == "mysql" :
@@ -269,11 +278,13 @@ def set(
269278 val = val ,
270279 description = description ,
271280 is_encrypted = is_encrypted ,
281+ team_id = team_id ,
272282 )
273283 stmt = mysql_stmt .on_duplicate_key_update (
274284 val = val ,
275285 description = description ,
276286 is_encrypted = is_encrypted ,
287+ team_id = team_id ,
277288 )
278289 else :
279290 from sqlalchemy .dialects .sqlite import insert as sqlite_insert
@@ -283,13 +294,15 @@ def set(
283294 val = val ,
284295 description = description ,
285296 is_encrypted = is_encrypted ,
297+ team_id = team_id ,
286298 )
287299 stmt = sqlite_stmt .on_conflict_do_update (
288300 index_elements = ["key" ],
289301 set_ = dict (
290302 val = val ,
291303 description = description ,
292304 is_encrypted = is_encrypted ,
305+ team_id = team_id ,
293306 ),
294307 )
295308
0 commit comments