2025-07-11 13:52:19 +00:00

197 lines
5.5 KiB
Python

from abc import ABC, abstractmethod
from datetime import date, datetime, time, timedelta
from typing import Any, ClassVar, Dict, List, Optional, Set
from neo4j.time import Date as Neo4jDate
from neo4j.time import DateTime as Neo4jDateTime
from neo4j.time import Time as Neo4jTime
from pydantic import (
BaseModel,
ConfigDict,
Field,
PrivateAttr,
field_validator,
model_validator,
)
class CommonModel(BaseModel, ABC):
model_config = ConfigDict(
validate_assignment=True,
extra="forbid",
arbitrary_types_allowed=True,
)
created: datetime = Field(
default_factory=datetime.now, json_schema_extra={"set_on_create": True}
)
merged: Optional[datetime] = Field(default=None, validate_default=True)
_set_on_match: List[str] = PrivateAttr()
_set_on_create: List[str] = PrivateAttr()
_always_set: List[str] = PrivateAttr()
_neo4j_supported_types: ClassVar[Any] = (
list,
bool,
int,
bytearray,
float,
str,
bytes,
date,
time,
datetime,
timedelta,
)
def __init__(self, **data: dict):
super().__init__(**data)
self._set_on_match = self._get_prop_usage("set_on_match")
self._set_on_create = self._get_prop_usage("set_on_create")
self._always_set = [
x
for x in self.model_dump().keys()
if x not in self._set_on_match + self._set_on_create + ["source", "target"]
]
@classmethod
def _get_prop_usage(cls, usage_type: str) -> List[str]:
all_props = cls.model_json_schema()["properties"]
selected_props = []
for prop, entry in all_props.items():
if entry.get(usage_type) is True:
selected_props.append(prop)
return selected_props
def _get_prop_values(
self, props: List[str], exclude: Set[str] = set()
) -> Dict[str, Any]:
"""
Returns:
Dict[str, Any]: a dictionary of key/value pairs.
"""
prop_values = {
k: v for k, v in self.neo4j_dict(exclude=exclude).items() if k in props
}
return prop_values
@abstractmethod
def _get_merge_parameters(self) -> Dict[str, Any]:
raise NotImplementedError
@classmethod
def export_type_converter(cls, value: Any) -> Any:
if isinstance(value, dict):
raise TypeError("Neo4j doesn't support dict types for properties.")
elif isinstance(value, (tuple, set)):
new_value = list(value)
return cls.export_type_converter(new_value)
elif isinstance(value, list):
# items in a list must all be the same type
item_type = type(value[0])
for item in value:
if isinstance(item, item_type) is False:
raise TypeError(
"For neo4j, all items in a list must be of the same type."
)
return [cls.export_type_converter(x) for x in value]
elif isinstance(value, cls._neo4j_supported_types) is False:
return str(value)
else:
return value
@classmethod
def _export_dict_converter(cls, original_dict: Dict[str, Any]) -> Dict[str, Any]:
"""_summary_
Args:
export_dict (Dict[str, Any]): _description_
Returns:
Dict[str, Any]: _description_
"""
export_dict = original_dict.copy()
for k, v in export_dict.items():
export_dict[k] = cls.export_type_converter(v)
return export_dict
def neo4j_dict(self, **kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""Return a dict made up of only types compatible with neo4j
Returns:
dict: a dictionary export of this model instance
"""
export_dict = self.model_dump(exclude_none=True, **kwargs)
export_dict = self._export_dict_converter(export_dict)
return export_dict
#
# validators
#
@field_validator("merged")
def set_merged_to_created(
cls, value: Optional[datetime], values: Dict[str, Any]
) -> datetime:
"""By default, set the 'merged' time equal to the 'created' time.
If the 'merged' value has been explicitly set, this is preserved.
Args:
value (Optional[datetime]): the value of the field.
values (Dict[str, Any]): a dictionary of field/value pairs set so far.
Returns:
datetime: The merged datetime value.
"""
if value is None:
return values.data["created"]
else:
return value
@model_validator(mode="before")
@classmethod
def neo4j_datetime_to_native(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Datetimes come back from Neo4j as a non standard DateTime type.
We check for any values where that is the case and convert them to
native Python datetimes.
See https://neo4j.com/docs/api/python-driver/4.4/temporal_types.html for further info.
Args:
values (Dict[str, Any]): Dictionary of field/value pairs from pydantic.
Returns:
Dict[str, Any]: Returns the dictionary, with any Neo4jDateTimes updated.
"""
if not isinstance(values, dict):
raise ValueError
for key in values:
if isinstance(values[key], (Neo4jDateTime, Neo4jDate, Neo4jTime)):
values[key] = values[key].to_native()
return values