Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/diffcalc_API/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from . import config, database, server
from . import config, database, openapi, server
from ._version_git import __version__

# __all__ defines the public API for the package.
# Each module also defines its own __all__.

__all__ = ["__version__", "server", "config", "database"]
__all__ = ["__version__", "server", "config", "database", "openapi"]
7 changes: 7 additions & 0 deletions src/diffcalc_API/errors/hkl.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
class ErrorCodes(ErrorCodesBase):
INVALID_MILLER_INDICES = 400
INVALID_SCAN_BOUNDS = 400
INVALID_SOLUTION_BOUNDS = 400


responses = {code: ALL_RESPONSES[code] for code in np.unique(ErrorCodes.all_codes())}
Expand All @@ -32,3 +33,9 @@ def __init__(self, start: float, stop: float, inc: float) -> None:
f" to stop: {stop} in increments of: {inc}"
)
self.status_code = ErrorCodes.INVALID_SCAN_BOUNDS


class InvalidSolutionBoundsError(DiffcalcAPIException):
def __init__(self, detail: str) -> None:
self.detail = detail
self.status_code = ErrorCodes.INVALID_SOLUTION_BOUNDS
4 changes: 2 additions & 2 deletions src/diffcalc_API/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from diffcalc_API.models import ub
from diffcalc_API.models import hkl, ub

__all__ = ["ub"]
__all__ = ["ub", "hkl"]
46 changes: 46 additions & 0 deletions src/diffcalc_API/models/hkl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from dataclasses import dataclass
from typing import Iterator, List, Optional, Union

from diffcalc.hkl.geometry import Position


@dataclass
class SolutionConstraints:
axes: Optional[List[str]] = None
low_bound: Optional[List[float]] = None
high_bound: Optional[List[float]] = None
valid: bool = True
msg: str = ""

def __post_init__(self):
self.invalid_bounds()

def invalid_bounds(self) -> None:
axes, low_bound, high_bound = self.axes, self.low_bound, self.high_bound
msg = self.msg

if axes and low_bound and high_bound:
iterator: Iterator[Union[List[str], List[float]]] = iter(
[axes, low_bound, high_bound]
)
length = len(next(iterator))
same_length = all(len(each_list) == length for each_list in iterator)

if not same_length:
msg = "queries axes, low_bound and high_bound are not the same length."

if not all(angle in Position.fields for angle in axes):
msg = (
"query {axes} contains an angle which is not a subset of "
+ f"{Position.fields}"
)

elif axes or low_bound or high_bound:
msg = (
"If bounds are provided, a list of axes, low bounds and high bounds "
+ "must be provided as query parameters."
)

self.msg = msg
if self.msg:
self.valid = False
59 changes: 55 additions & 4 deletions src/diffcalc_API/routes/hkl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from fastapi import APIRouter, Depends, Query

from diffcalc_API.errors.hkl import InvalidSolutionBoundsError
from diffcalc_API.models.hkl import SolutionConstraints
from diffcalc_API.models.ub import HklModel, PositionModel
from diffcalc_API.services import hkl as service
from diffcalc_API.stores.protocol import HklCalcStore, get_store
Expand Down Expand Up @@ -30,11 +32,23 @@ async def lab_position_from_miller_indices(
name: str,
miller_indices: HklModel = Depends(),
wavelength: float = Query(..., example=1.0),
axes: Optional[List[str]] = Query(default=None, example=["mu", "nu", "phi"]),
low_bound: Optional[List[float]] = Query(default=None, example=[0.0, 0.0, -90.0]),
high_bound: Optional[List[float]] = Query(default=None, example=[90.0, 90.0, 90.0]),
store: HklCalcStore = Depends(get_store),
collection: Optional[str] = Query(default=None, example="B07"),
):
solution_constraints = SolutionConstraints(axes, low_bound, high_bound)
if not solution_constraints.valid:
raise InvalidSolutionBoundsError(solution_constraints.msg)

positions = await service.lab_position_from_miller_indices(
name, miller_indices, wavelength, store, collection
name,
miller_indices,
wavelength,
solution_constraints,
store,
collection,
)

return {"payload": positions}
Expand All @@ -61,11 +75,25 @@ async def scan_hkl(
stop: List[float] = Query(..., example=[2, 0, 2]),
inc: List[float] = Query(..., example=[0.1, 0, 0.1]),
wavelength: float = Query(..., example=1),
axes: Optional[List[str]] = Query(default=None, example=["mu", "nu", "phi"]),
low_bound: Optional[List[float]] = Query(default=None, example=[0.0, 0.0, -90.0]),
high_bound: Optional[List[float]] = Query(default=None, example=[90.0, 90.0, 90.0]),
store: HklCalcStore = Depends(get_store),
collection: Optional[str] = Query(default=None, example="B07"),
):
solution_constraints = SolutionConstraints(axes, low_bound, high_bound)
if not solution_constraints.valid:
raise InvalidSolutionBoundsError(solution_constraints.msg)

scan_results = await service.scan_hkl(
name, start, stop, inc, wavelength, store, collection
name,
start,
stop,
inc,
wavelength,
solution_constraints,
store,
collection,
)
return {"payload": scan_results}

Expand All @@ -77,11 +105,18 @@ async def scan_wavelength(
stop: float = Query(..., example=2.0),
inc: float = Query(..., example=0.2),
hkl: HklModel = Depends(),
axes: Optional[List[str]] = Query(default=None, example=["mu", "nu", "phi"]),
low_bound: Optional[List[float]] = Query(default=None, example=[0.0, 0.0, -90.0]),
high_bound: Optional[List[float]] = Query(default=None, example=[90.0, 90.0, 90.0]),
store: HklCalcStore = Depends(get_store),
collection: Optional[str] = Query(default=None, example="B07"),
):
solution_constraints = SolutionConstraints(axes, low_bound, high_bound)
if not solution_constraints.valid:
raise InvalidSolutionBoundsError(solution_constraints.msg)

scan_results = await service.scan_wavelength(
name, start, stop, inc, hkl, store, collection
name, start, stop, inc, hkl, solution_constraints, store, collection
)
return {"payload": scan_results}

Expand All @@ -95,11 +130,27 @@ async def scan_constraint(
inc: float = Query(..., example=1),
hkl: HklModel = Depends(),
wavelength: float = Query(..., example=1.0),
axes: Optional[List[str]] = Query(default=None, example=["mu", "nu", "phi"]),
low_bound: Optional[List[float]] = Query(default=None, example=[0.0, 0.0, -90.0]),
high_bound: Optional[List[float]] = Query(default=None, example=[90.0, 90.0, 90.0]),
store: HklCalcStore = Depends(get_store),
collection: Optional[str] = Query(default=None, example="B07"),
):
solution_constraints = SolutionConstraints(axes, low_bound, high_bound)
if not solution_constraints.valid:
raise InvalidSolutionBoundsError(solution_constraints.msg)

scan_results = await service.scan_constraint(
name, constraint, start, stop, inc, hkl, wavelength, store, collection
name,
constraint,
start,
stop,
inc,
hkl,
wavelength,
solution_constraints,
store,
collection,
)

return {"payload": scan_results}
2 changes: 1 addition & 1 deletion src/diffcalc_API/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
@app.exception_handler(DiffcalcException)
async def diffcalc_exception_handler(request: Request, exc: DiffcalcException):
tb = traceback.format_exc()
logger.warn(f"Diffcalc Exception caught by middleware: {tb}")
logger.warning(f"Diffcalc Exception caught by middleware: {tb}")

return responses.JSONResponse(
status_code=400,
Expand Down
45 changes: 37 additions & 8 deletions src/diffcalc_API/services/hkl.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from diffcalc.hkl.geometry import Position

from diffcalc_API.errors.hkl import InvalidMillerIndicesError, InvalidScanBoundsError
from diffcalc_API.models.hkl import SolutionConstraints
from diffcalc_API.models.ub import HklModel, PositionModel
from diffcalc_API.stores.protocol import HklCalcStore

Expand All @@ -13,6 +14,7 @@ async def lab_position_from_miller_indices(
name: str,
miller_indices: HklModel,
wavelength: float,
solution_constraints: SolutionConstraints,
store: HklCalcStore,
collection: Optional[str],
) -> List[Dict[str, float]]:
Expand All @@ -22,8 +24,9 @@ async def lab_position_from_miller_indices(
raise InvalidMillerIndicesError()

all_positions = hklcalc.get_position(*miller_indices.dict().values(), wavelength)
result = combine_lab_position_results(all_positions, solution_constraints)

return combine_lab_position_results(all_positions)
return result


async def miller_indices_from_lab_position(
Expand All @@ -44,14 +47,15 @@ async def scan_hkl(
stop: List[float],
inc: List[float],
wavelength: float,
solution_constraints: SolutionConstraints,
store: HklCalcStore,
collection: Optional[str],
) -> Dict[str, List[Dict[str, float]]]:
hklcalc = await store.load(name, collection)

if (len(start) != 3) or (len(stop) != 3) or (len(inc) != 3):
raise InvalidMillerIndicesError(
detail="start, stop and inc must have three floats for each miller index."
"start, stop and inc must have three floats for each miller index."
)

axes_values = [
Expand All @@ -63,10 +67,14 @@ async def scan_hkl(

for h, k, l in product(*axes_values):
if all([idx == 0 for idx in (h, k, l)]):
raise InvalidMillerIndicesError() # what if this goes through 0?
raise InvalidMillerIndicesError(
"choose a hkl range that does not cross through [0, 0, 0]"
) # is this good enough? do people need scans through 0,0,0?

all_positions = hklcalc.get_position(h, k, l, wavelength)
results[f"({h}, {k}, {l})"] = combine_lab_position_results(all_positions)
results[f"({h}, {k}, {l})"] = combine_lab_position_results(
all_positions, solution_constraints
)

return results

Expand All @@ -77,6 +85,7 @@ async def scan_wavelength(
stop: float,
inc: float,
hkl: HklModel,
solution_constraints: SolutionConstraints,
store: HklCalcStore,
collection: Optional[str],
) -> Dict[str, List[Dict[str, float]]]:
Expand All @@ -90,7 +99,9 @@ async def scan_wavelength(

for wavelength in wavelengths:
all_positions = hklcalc.get_position(*hkl.dict().values(), wavelength)
result[f"{wavelength}"] = combine_lab_position_results(all_positions)
result[f"{wavelength}"] = combine_lab_position_results(
all_positions, solution_constraints
)

return result

Expand All @@ -103,6 +114,7 @@ async def scan_constraint(
inc: float,
hkl: HklModel,
wavelength: float,
solution_constraints: SolutionConstraints,
store: HklCalcStore,
collection: Optional[str],
) -> Dict[str, List[Dict[str, float]]]:
Expand All @@ -115,7 +127,9 @@ async def scan_constraint(
for value in np.arange(start, stop + inc, inc):
setattr(hklcalc, constraint, value)
all_positions = hklcalc.get_position(*hkl.dict().values(), wavelength)
result[f"{value}"] = combine_lab_position_results(all_positions)
result[f"{value}"] = combine_lab_position_results(
all_positions, solution_constraints
)

return result

Expand All @@ -128,12 +142,27 @@ def generate_axis(start: float, stop: float, inc: float):


def combine_lab_position_results(
positions: List[Tuple[Position, Dict[str, float]]]
positions: List[Tuple[Position, Dict[str, float]]],
solution_constraints: SolutionConstraints,
) -> List[Dict[str, float]]:
axes = solution_constraints.axes
low_bound = solution_constraints.low_bound
high_bound = solution_constraints.high_bound

result = []

for position in positions:
result.append({**position[0].asdict, **position[1]})
physical_angles, virtual_angles = position
if axes and low_bound and high_bound:
if all(
[
low_bound[i] < getattr(physical_angles, angle) < high_bound[i]
for i, angle in enumerate(axes)
]
):
result.append({**physical_angles.asdict, **virtual_angles})
else:
result.append({**physical_angles.asdict, **virtual_angles})

return result

Expand Down
Loading