diff --git a/server/mergin/sync/interfaces.py b/server/mergin/sync/interfaces.py
index bb2e9843..4f30c2bc 100644
--- a/server/mergin/sync/interfaces.py
+++ b/server/mergin/sync/interfaces.py
@@ -110,6 +110,12 @@ def get_by_name(self, name):
         """
         pass

+    def get_by_names(self, names):
+        """
+        Return list of workspaces whose names are in the given collection.
+        """
+        pass
+
     @abstractmethod
     def get_by_project(self, project):
         """
diff --git a/server/mergin/sync/models.py b/server/mergin/sync/models.py
index 5f4aa967..3b32cda9 100644
--- a/server/mergin/sync/models.py
+++ b/server/mergin/sync/models.py
@@ -6,6 +6,7 @@
 import json
 import logging
 import os
+import re
 import threading
 import time
 import uuid
@@ -18,7 +19,9 @@
 from blinker import signal
 from flask_login import current_user
 from pygeodiff import GeoDiff
+from functools import cached_property
 from sqlalchemy import text, null, desc, nullslast, tuple_
+from sqlalchemy.orm import contains_eager, joinedload, load_only
 from sqlalchemy.dialects.postgresql import ARRAY, BIGINT, UUID, JSONB, ENUM, insert
 from sqlalchemy.types import String
 from sqlalchemy.ext.hybrid import hybrid_property
@@ -136,6 +139,12 @@ def workspace(self):
         project_workspace = current_app.ws_handler.get(self.workspace_id)
         return project_workspace

+    @cached_property
+    def _has_conflict(self) -> bool:
+        """True if any current project file matches a known conflict-copy pattern."""
+        pattern = r"(\.gpkg|\.qgs|.qgz)(.*conflict.*)|( \(.*conflict.*)"
+        return any(re.search(pattern, f.path) for f in self.files)
+
     def get_latest_files_cache(self) -> List[int]:
         """Get latest file history ids either from cached table or calculate them on the fly"""
         if self.latest_project_files.file_history_ids is not None:
@@ -658,7 +667,7 @@ def __init__(
     def path(self) -> str:
         return self.file.path

-    @property
+    @cached_property
     def diff(self) -> Optional[FileDiff]:
         """Diff file pushed with UPDATE_DIFF change type.

@@ -713,9 +722,37 @@ def changes(
         if not (is_versioned_file(file) and since is not None and to is not None):
             return []

-        history = []
+        # when since=1 the range spans the entire project history; narrow it to
+        # the most recent CREATE/DELETE so we don't load records from previous
+        # file lifecycles that the Python break would discard anyway
+        if since == 1:
+            boundary = (
+                FileHistory.query.join(ProjectFilePath)
+                .filter(
+                    ProjectFilePath.project_id == project_id,
+                    ProjectFilePath.path == file,
+                    FileHistory.project_version_name <= to,
+                    FileHistory.change.in_(
+                        [PushChangeType.CREATE.value, PushChangeType.DELETE.value]
+                    ),
+                )
+                .order_by(desc(FileHistory.project_version_name))
+                .with_entities(FileHistory.project_version_name)
+                .first()
+            )
+            since = boundary[0] if boundary else since
+
         full_history = (
             FileHistory.query.join(ProjectFilePath)
+            .join(FileHistory.version)
+            .join(ProjectVersion.project)
+            .options(
+                contains_eager(FileHistory.file).load_only(ProjectFilePath.path),
+                contains_eager(FileHistory.version)
+                .load_only(ProjectVersion.name, ProjectVersion.project_id)
+                .contains_eager(ProjectVersion.project)
+                .load_only(Project.storage_params),
+            )
             .filter(
                 ProjectFilePath.project_id == project_id,
                 FileHistory.project_version_name <= to,
@@ -726,6 +763,7 @@ def changes(
             .all()
         )

+        history = []
         for item in full_history:
             history.append(item)

@@ -1781,11 +1819,15 @@ def diff_summary(self):

     def changes_count(self) -> Dict:
         """Return number of changes by type"""
-        query = f"SELECT change, COUNT(change) FROM file_history WHERE version_id = :version_id GROUP BY change;"
+        query = "SELECT change, COUNT(change) FROM file_history WHERE version_id = :version_id GROUP BY change;"
         params = {"version_id": self.id}
         result = db.session.execute(text(query), params).fetchall()
         return {row[0]: row[1] for row in result}

+    @cached_property
+    def _changes_count(self) -> Dict:
+        return self.changes_count()
+
     @property
     def zip_path(self):
         return os.path.join(
diff --git a/server/mergin/sync/public_api.yaml b/server/mergin/sync/public_api.yaml
index 157e8262..7f5749d1 100644
--- a/server/mergin/sync/public_api.yaml
+++ b/server/mergin/sync/public_api.yaml
@@ -1124,11 +1124,6 @@ components:
           - added
           - updated
           - removed
-        expiration:
-          nullable: true
-          type: string
-          format: date-time
-          example: 2019-02-26T08:47:58.636074Z
     UploadFileInfo:
       allOf:
         - $ref: "#/components/schemas/FileInfo"
diff --git a/server/mergin/sync/public_api_controller.py b/server/mergin/sync/public_api_controller.py
index 8f142e71..90c71945 100644
--- a/server/mergin/sync/public_api_controller.py
+++ b/server/mergin/sync/public_api_controller.py
@@ -24,8 +24,10 @@
 )
 from pygeodiff import GeoDiffLibError
 from flask_login import current_user
-from sqlalchemy import and_, desc, asc
+import re
+from sqlalchemy import and_, desc, asc, text, tuple_
 from sqlalchemy.exc import IntegrityError, SQLAlchemyError
+from sqlalchemy.orm import contains_eager, joinedload, load_only, selectinload
 from gevent import sleep
 import base64
 from werkzeug.exceptions import HTTPException, Conflict
@@ -37,6 +39,7 @@
 from ..auth.models import User
 from .models import (
     FileSyncErrorType,
+    FileDiff,
     Project,
     ProjectVersion,
     Upload,
@@ -397,19 +400,73 @@ def get_project(project_name, namespace, since="", version=None):  # noqa: E501
         abort(400, "Parameters 'since' and 'version' are mutually exclusive")
     elif since:
         data = ProjectSchema(exclude=["storage_params"]).dump(project)
-        # append history for versioned files
+        since_version = ProjectVersion.from_v_name(since)
+        versioned_paths = [f.path for f in project.files if is_versioned_file(f.path)]
+
+        # load history for all versioned files in one query; only the columns
+        # actually used downstream are fetched from the joined tables
+        all_history = (
+            FileHistory.query.join(ProjectFilePath)
+            .join(FileHistory.version)
+            .options(
+                contains_eager(FileHistory.file).load_only(
+                    ProjectFilePath.path, ProjectFilePath.project_id
+                ),
+                contains_eager(FileHistory.version).load_only(ProjectVersion.name),
+            )
+            .filter(
+                ProjectFilePath.project_id == project.id,
+                FileHistory.project_version_name.between(
+                    since_version, project.latest_version
+                ),
+                ProjectFilePath.path.in_(versioned_paths),
+            )
+            .order_by(FileHistory.file_path_id, desc(FileHistory.project_version_name))
+            .all()
+        )
+
+        # partition by file and apply stop-at-CREATE logic, matching FileHistory.changes behaviour
+        history_by_file: dict = {}
+        for item in all_history:
+            fid = item.file_path_id
+            file_history = history_by_file.setdefault(fid, [])
+            if file_history and file_history[-1].change in (
+                PushChangeType.CREATE.value,
+                PushChangeType.DELETE.value,
+            ):
+                continue
+            file_history.append(item)
+
+        # batch-load all FileDiff records needed across all files in one query
+        update_diff_items = [
+            i
+            for items in history_by_file.values()
+            for i in items
+            if i.change == PushChangeType.UPDATE_DIFF.value
+        ]
+        if update_diff_items:
+            diffs = FileDiff.query.filter(
+                FileDiff.file_path_id.in_({i.file_path_id for i in update_diff_items}),
+                FileDiff.rank == 0,
+                FileDiff.version.in_(
+                    [i.project_version_name for i in update_diff_items]
+                ),
+            ).all()
+            diff_map = {(d.file_path_id, d.version): d for d in diffs}
+            for item in update_diff_items:
+                item.__dict__["diff"] = diff_map.get(
+                    (item.file_path_id, item.project_version_name)
+                )
+
+        path_to_file_id = {i.file.path: i.file_path_id for i in all_history}
         files = []
         for f in project.files:
             history_field = {}
-            for item in FileHistory.changes(
-                project.id,
-                f.path,
-                ProjectVersion.from_v_name(since),
-                project.latest_version,
-            ):
-                history_field[ProjectVersion.to_v_name(item.version.name)] = (
-                    FileHistorySchema(exclude=("mtime",)).dump(item)
-                )
+            if is_versioned_file(f.path):
+                for item in history_by_file.get(path_to_file_id.get(f.path), []):
+                    history_field[ProjectVersion.to_v_name(item.version.name)] = (
+                        FileHistorySchema(exclude=("mtime", "expiration")).dump(item)
+                    )
             files.append({**asdict(f), "history": history_field})
         data["files"] = files
     elif version:
@@ -446,6 +503,8 @@ def get_paginated_project_versions(
     project = require_project(namespace, project_name, ProjectPermissions.Read)
     query = ProjectVersion.query.filter(
         and_(ProjectVersion.project_id == project.id, ProjectVersion.name != 0)
+    ).options(
+        joinedload(ProjectVersion.project).load_only(Project.name, Project.workspace_id)
     )
     query = (
         query.order_by(desc(ProjectVersion.name))
@@ -455,11 +514,59 @@ def get_paginated_project_versions(
     paginate = query.paginate(page=page, per_page=per_page)
     result = paginate.items
     total = paginate.total
-    versions = ProjectVersionListSchema(many=True).dump(result)
+
+    # batch-resolve workspace names for the page
+    ws_ids = {v.project.workspace_id for v in result}
+    workspaces_map = {w.id: w.name for w in current_app.ws_handler.get_by_ids(ws_ids)}
+
+    # batch-compute change counts for all versions in the page in one query
+    if result:
+        version_ids = [v.id for v in result]
+        rows = db.session.execute(
+            text(
+                "SELECT version_id, change, COUNT(change) AS cnt"
+                " FROM file_history"
+                " WHERE version_id = ANY(:ids)"
+                " GROUP BY version_id, change"
+            ),
+            {"ids": version_ids},
+        ).fetchall()
+        counts_map = {}
+        for row in rows:
+            counts_map.setdefault(row.version_id, {})[row.change] = row.cnt
+        for v in result:
+            v.__dict__["_changes_count"] = counts_map.get(v.id, {})
+
+    ctx = {"workspaces_map": workspaces_map}
+    versions = ProjectVersionListSchema(many=True, context=ctx).dump(result)
     data = {"versions": versions, "count": total}
     return data, 200


+def _precompute_has_conflict(projects):
+    """Pre-populate _has_conflict on each project using a single SQL query."""
+    if not projects:
+        return
+    conflict_regex = r"(\.gpkg|\.qgs|.qgz)(.*conflict.*)|( \(.*conflict.*)"
+    rows = db.session.execute(
+        text(
+            """
+            SELECT DISTINCT lpf.project_id
+            FROM latest_project_files lpf
+            CROSS JOIN unnest(lpf.file_history_ids) AS fh_id
+            JOIN file_history fh ON fh.id = fh_id
+            JOIN project_file_path fp ON fp.id = fh.file_path_id
+            WHERE lpf.project_id = ANY(:project_ids)
+            AND fp.path ~ :pattern
+            """
+        ),
+        {"project_ids": [p.id for p in projects], "pattern": conflict_regex},
+    ).fetchall()
+    conflict_ids = {row.project_id for row in rows}
+    for p in projects:
+        p.__dict__["_has_conflict"] = p.id in conflict_ids
+
+
 def get_projects_by_names():  # noqa: E501
     """List mergin projects specified by list of projects with namespaces and names

@@ -470,38 +577,71 @@ def get_projects_by_names():  # noqa: E501
     list_of_projects = request.json.get("projects", [])
     if len(list_of_projects) > 50:
         abort(400, "Too many projects")
+
+    # batch-resolve workspaces by name (one DB query for DB-backed handlers)
+    unique_ws_names = {
+        key.split("/")[0].lower()
+        for key in list_of_projects
+        if len(key.split("/")) == 2
+    }
+    workspaces_by_name = {
+        ws.name.lower(): ws
+        for ws in current_app.ws_handler.get_by_names(unique_ws_names)
+    }
+
     results = {}
-    for project in list_of_projects:
-        projects = projects_query(ProjectPermissions.Read, as_admin=False)
-        splitted = project.split("/")
-        if len(splitted) != 2:
-            results[project] = {"error": 404}
+    valid_projects = []  # list of (key, workspace, project_name)
+    for key in list_of_projects:
+        parts = key.split("/")
+        if len(parts) != 2:
+            results[key] = {"error": 404}
             continue
-        ws = splitted[0]
-        name = splitted[1]
-        workspace = current_app.ws_handler.get_by_name(ws)
+        workspace = workspaces_by_name.get(parts[0].lower())
         if not workspace:
-            results[project] = {"error": 404}
+            results[key] = {"error": 404}
             continue
-        result = projects.filter(
-            Project.workspace_id == workspace.id, Project.name == name
-        ).first()
-        if result:
-            users_map = {
-                u.id: u.username
-                for u in User.query.select_from(ProjectUser)
-                .join(User)
-                .filter(ProjectUser.project_id == result.id)
-                .all()
-            }
-            workspaces_map = {workspace.id: workspace.name}
-            ctx = {"users_map": users_map, "workspaces_map": workspaces_map}
-            results[project] = ProjectListSchema(context=ctx).dump(result)
-        else:
-            if not current_user or not current_user.is_authenticated:
-                results[project] = {"error": 401}
+        valid_projects.append((key, workspace, parts[1]))
+
+    if valid_projects:
+        # batch-fetch all requested projects, eagerly loading project_users so
+        # members_by_role / get_role don't trigger per-project lazy loads
+        ws_name_pairs = [(ws.id, name) for _, ws, name in valid_projects]
+        found_projects = (
+            projects_query(ProjectPermissions.Read, as_admin=False)
+            .options(selectinload(Project.project_users))
+            .filter(tuple_(Project.workspace_id, Project.name).in_(ws_name_pairs))
+            .all()
+        )
+        found_map = {(p.workspace_id, p.name): p for p in found_projects}
+
+        # batch-fetch all project members in one query
+        users_map = {
+            u.id: u.username
+            for u in User.query.select_from(ProjectUser)
+            .join(User)
+            .filter(ProjectUser.project_id.in_([p.id for p in found_projects]))
+            .all()
+        }
+        ws_ids = {p.workspace_id for p in found_projects}
+        workspaces_map = {
+            w.id: w.name for w in current_app.ws_handler.get_by_ids(ws_ids)
+        }
+
+        _precompute_has_conflict(found_projects)
+
+        ctx = {"users_map": users_map, "workspaces_map": workspaces_map}
+
+        for key, workspace, name in valid_projects:
+            result = found_map.get((workspace.id, name))
+            if result:
+                results[key] = ProjectListSchema(context=ctx).dump(result)
             else:
-                results[project] = {"error": 404}
+                results[key] = (
+                    {"error": 401}
+                    if not current_user or not current_user.is_authenticated
+                    else {"error": 404}
+                )
+
     return results, 200
@@ -521,9 +661,11 @@ def get_projects_by_uuids(uuids):  # noqa: E501

     projects = (
         projects_query(ProjectPermissions.Read, as_admin=False)
+        .options(selectinload(Project.project_users))
         .filter(Project.id.in_(proj_ids))
         .all()
     )
+    _precompute_has_conflict(projects)
     ws_ids = set([p.workspace_id for p in projects])
     projects_ids = [p.id for p in projects]
     users_map = {
@@ -605,9 +747,12 @@ def get_paginated_projects(
         public,
         only_public,
     )
-    pagination = projects.paginate(page=page, per_page=per_page)
+    pagination = projects.options(selectinload(Project.project_users)).paginate(
+        page=page, per_page=per_page
+    )
     result = pagination.items
     total = pagination.total
+    _precompute_has_conflict(result)

     # create user map id:username passed to project schema to minimize queries to db
     projects_ids = [p.id for p in result]
@@ -618,7 +763,7 @@ def get_paginated_projects(
         .filter(ProjectUser.project_id.in_(projects_ids))
         .all()
     }
-    ws_ids = [p.workspace_id for p in projects]
+    ws_ids = [p.workspace_id for p in result]
     workspaces_map = {w.id: w.name for w in current_app.ws_handler.get_by_ids(ws_ids)}
     ctx = {"users_map": users_map, "workspaces_map": workspaces_map}
     sleep(
@@ -1191,10 +1336,30 @@ def get_resource_history(project_name, namespace, path):  # noqa: E501
     )

     data = ProjectFileSchema().dump(fh)
+    history = FileHistory.changes(project.id, path, 1, project.latest_version)
+
+    # batch-load all rank-0 FileDiff records needed for the history in one query
+    diff_map = {}
+    if history:
+        update_diff_versions = [
+            i.project_version_name
+            for i in history
+            if i.change == PushChangeType.UPDATE_DIFF.value
+        ]
+        if update_diff_versions:
+            diffs = FileDiff.query.filter(
+                FileDiff.file_path_id == history[0].file_path_id,
+                FileDiff.rank == 0,
+                FileDiff.version.in_(update_diff_versions),
+            ).all()
+            diff_map = {d.version: d for d in diffs}
+
     history_field = {}
-    for item in FileHistory.changes(project.id, path, 1, project.latest_version):
+    for item in history:
+        if item.change == PushChangeType.UPDATE_DIFF.value:
+            item.__dict__["diff"] = diff_map.get(item.project_version_name)
         history_field[ProjectVersion.to_v_name(item.version.name)] = FileHistorySchema(
-            exclude=("mtime",)
+            exclude=("mtime", "expiration")
         ).dump(item)

     data["history"] = history_field
diff --git a/server/mergin/sync/schemas.py b/server/mergin/sync/schemas.py
index da18f7db..c2184c62 100644
--- a/server/mergin/sync/schemas.py
+++ b/server/mergin/sync/schemas.py
@@ -2,7 +2,6 @@
 #
 # SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-MerginMaps-Commercial

-import re
 from marshmallow import fields, ValidationError, Schema, post_dump
 from flask_login import current_user
 from flask import current_app
@@ -192,7 +191,7 @@ class ProjectListSchema(ma.SQLAlchemyAutoSchema):
     id = fields.UUID()
     name = fields.Str()
     namespace = fields.Method("get_workspace_name")
-    access = fields.Function(lambda obj: ProjectAccessSchema().dump(obj))
+    access = fields.Method("get_access")
     permissions = fields.Function(project_user_permissions)
     version = fields.Function(lambda obj: ProjectVersion.to_v_name(obj.latest_version))
     updated = fields.Method("get_updated")
@@ -200,22 +199,14 @@ class ProjectListSchema(ma.SQLAlchemyAutoSchema):
     creator = fields.Integer(attribute="creator_id")
     disk_usage = fields.Integer()
     tags = fields.List(fields.Str())
-    has_conflict = fields.Method("get_has_conflict")
+    has_conflict = fields.Function(lambda obj: obj._has_conflict)
+
+    def get_access(self, obj):
+        return ProjectAccessSchema(context=self.context).dump(obj)

     def get_updated(self, obj):
         return obj.updated if obj.updated else obj.created

-    def get_has_conflict(self, obj):
-        """Check if there is any conflict file in project generated by client
-        Patterns to check:
-        - file.[gpkg|qgs|qgz]_conflict_copy (older convention)
-        - file.gpkg_rebase_conflicts (older convention)
-        - file (conflicted copy, user vx).*
-        - file (edit conflict, user vx).json
-        """
-        regex = r"(\.gpkg|\.qgs|.qgz)(.*conflict.*)|( \(.*conflict.*)"
-        return any(re.search(regex, file.path) for file in obj.files)
-
     def get_workspace_name(self, obj):
         """Discover ProjectListSchema workspace name"""
         try:
@@ -368,22 +359,25 @@ class ProjectAccessDetailSchema(Schema):

 class ProjectVersionListSchema(ma.SQLAlchemyAutoSchema):
     project_name = fields.Function(lambda obj: obj.project.name)
-    namespace = fields.Function(lambda obj: obj.project.workspace.name)
+    namespace = fields.Method("get_namespace")
     name = fields.Function(lambda obj: ProjectVersion.to_v_name(obj.name))
     author = fields.String(attribute="author.username")
     created = DateTimeWithZ()
     changes = fields.Method("_changes")
     project_size = fields.Integer()

+    def get_namespace(self, obj):
+        workspaces_map = self.context.get("workspaces_map", {})
+        return workspaces_map.get(obj.project.workspace_id, "")
+
     def _changes(self, obj):
-        result = obj.changes_count()
-        data = {
-            "added": result.get(PushChangeType.CREATE.value, 0),
-            "updated": result.get(PushChangeType.UPDATE.value, 0),
-            "updated_diff": result.get(PushChangeType.UPDATE_DIFF.value, 0),
-            "removed": result.get(PushChangeType.DELETE.value, 0),
+        counts = obj._changes_count
+        return {
+            "added": counts.get(PushChangeType.CREATE.value, 0),
+            "updated": counts.get(PushChangeType.UPDATE.value, 0),
+            "updated_diff": counts.get(PushChangeType.UPDATE_DIFF.value, 0),
+            "removed": counts.get(PushChangeType.DELETE.value, 0),
         }
-        return data

     class Meta:
         model = ProjectVersion
diff --git a/server/mergin/sync/workspace.py b/server/mergin/sync/workspace.py
index e7575e46..4c9866c5 100644
--- a/server/mergin/sync/workspace.py
+++ b/server/mergin/sync/workspace.py
@@ -144,6 +144,14 @@ def get_by_name(self, name):
             return
         return self.factory_method()

+    def get_by_names(self, names):
+        result = []
+        for name in set(names):
+            ws = self.get_by_name(name)
+            if ws:
+                result.append(ws)
+        return result
+
     def get_by_project(self, project):
         return self.factory_method()

diff --git a/server/mergin/tests/test_project_controller.py b/server/mergin/tests/test_project_controller.py
index 60c36ee2..d1f1afd6 100644
--- a/server/mergin/tests/test_project_controller.py
+++ b/server/mergin/tests/test_project_controller.py
@@ -166,7 +166,6 @@ def test_file_history(client, diff_project):
     assert "v1" not in history
     assert "v3" in history
     assert "location" not in history["v7"]
-    assert "expiration" in history["v7"]


 def test_get_paginated_projects(client):
@@ -2167,7 +2166,12 @@ def test_project_conflict_files(diff_project, file):
             }
         ]
     }
+    project_id = diff_project.id
     _ = add_project_version(diff_project, changes)
+    # expunge so the identity map releases the instance; re-query gives a fresh
+    # object without the stale cached_property value
+    db.session.expunge(diff_project)
+    diff_project = db.session.get(Project, project_id)
     project_info = ProjectListSchema(only=("has_conflict",), context=ctx).dump(
         diff_project
     )
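# --- Illustrative sketch (not part of the patch above) ---
# The batched loaders in this change pre-populate values by writing straight
# into the instance __dict__ (item.__dict__["diff"], v.__dict__["_changes_count"],
# p.__dict__["_has_conflict"]). That works because functools.cached_property
# stores its result under the attribute name in the instance __dict__ and, being
# a non-data descriptor, is shadowed by that entry on later access. The Record
# and "expensive" names below are hypothetical, used only to show the mechanism.
from functools import cached_property


class Record:
    @cached_property
    def expensive(self) -> int:
        # would normally hit the database; never runs once the value is primed
        return self._load()

    def _load(self) -> int:
        raise RuntimeError("unexpected lazy load")


r = Record()
# priming: an entry in the instance __dict__ shadows the cached_property
r.__dict__["expensive"] = 42
assert r.expensive == 42  # the getter (and _load) is never invoked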