diff --git a/packages/common/src/types/command/PartialTargetDescriptor.types.ts b/packages/common/src/types/command/PartialTargetDescriptor.types.ts index 61bbd268da..13a9239a9d 100644 --- a/packages/common/src/types/command/PartialTargetDescriptor.types.ts +++ b/packages/common/src/types/command/PartialTargetDescriptor.types.ts @@ -230,7 +230,7 @@ export interface SimpleScopeType { type: SimpleScopeTypeType; } -export type ScopeTypeType = SimpleScopeTypeType | ScopeType["type"]; +export type ScopeTypeType = ScopeType["type"]; export interface CustomRegexScopeType { type: "customRegex"; diff --git a/packages/cursorless-engine/src/languages/LanguageDefinition.ts b/packages/cursorless-engine/src/languages/LanguageDefinition.ts index 194c0778fb..0504719592 100644 --- a/packages/cursorless-engine/src/languages/LanguageDefinition.ts +++ b/packages/cursorless-engine/src/languages/LanguageDefinition.ts @@ -87,16 +87,22 @@ export class LanguageDefinition { * document. We use this in our surrounding pair code. * * @param document The document to search - * @param captureName The name of a capture to search for + * @param captureNames Optional capture names to include * @returns A map of captures in the document */ - getCapturesMap(document: TextDocument) { - const matches = this.query.matches(document); - const result: Partial> = {}; + getCapturesMap( + document: TextDocument, + captureNames: readonly T[], + ) { + const matches = this.query.matchesForCaptures( + document, + new Set(captureNames), + ); + const result: Partial> = {}; for (const match of matches) { for (const capture of match.captures) { - const name = capture.name as SimpleScopeTypeType; + const name = capture.name as T; if (result[name] == null) { result[name] = []; } diff --git a/packages/cursorless-engine/src/languages/TreeSitterQuery/TreeSitterQuery.ts b/packages/cursorless-engine/src/languages/TreeSitterQuery/TreeSitterQuery.ts index 18f4974bb9..8d4a0412c6 100644 --- a/packages/cursorless-engine/src/languages/TreeSitterQuery/TreeSitterQuery.ts +++ b/packages/cursorless-engine/src/languages/TreeSitterQuery/TreeSitterQuery.ts @@ -1,10 +1,10 @@ import type { Position, TextDocument, TreeSitter } from "@cursorless/common"; import type * as treeSitter from "web-tree-sitter"; import { ide } from "../../singletons/ide.singleton"; +import { getNormalizedCaptureName } from "./captureNames"; import { checkCaptureStartEnd } from "./checkCaptureStartEnd"; import { getNodeRange } from "./getNodeRange"; import { isContainedInErrorNode } from "./isContainedInErrorNode"; -import { normalizeCaptureName } from "./normalizeCaptureName"; import { parsePredicatesWithErrorHandling } from "./parsePredicatesWithErrorHandling"; import { positionToPoint } from "./positionToPoint"; import type { @@ -55,7 +55,7 @@ export class TreeSitterQuery { hasCapture(name: string): boolean { return this.query.captureNames.some( - (n) => normalizeCaptureName(n) === name, + (n) => getNormalizedCaptureName(n) === name, ); } @@ -64,29 +64,86 @@ export class TreeSitterQuery { start?: Position, end?: Position, ): QueryMatch[] { - if (!treeSitterQueryCache.isValid(document, start, end)) { - const matches = this.getAllMatches(document, start, end); - treeSitterQueryCache.update(document, start, end, matches); + return this.getMatches(document, start, end, undefined); + } + + matchesForCaptures( + document: TextDocument, + captureNames: Set, + ): QueryMatch[] { + return this.getMatches(document, undefined, undefined, captureNames); + } + + private getMatches( + document: TextDocument, + start: Position | undefined, + end: Position | undefined, + captureNameFilter: Set | undefined, + ): QueryMatch[] { + if ( + !treeSitterQueryCache.isValid(document, start, end, captureNameFilter) + ) { + const matches = this.calculateMatches( + document, + start, + end, + captureNameFilter, + ); + treeSitterQueryCache.update( + document, + start, + end, + captureNameFilter, + matches, + ); } return treeSitterQueryCache.get(); } - private getAllMatches( + private calculateMatches( document: TextDocument, - start?: Position, - end?: Position, + start: Position | undefined, + end: Position | undefined, + captureNameFilter: Set | undefined, ): QueryMatch[] { const matches = this.getTreeMatches(document, start, end); const results: QueryMatch[] = []; for (const match of matches) { - const mutableMatch = this.createMutableQueryMatch(document, match); + if ( + captureNameFilter != null && + !match.captures.some((capture) => + captureNameFilter.has(getNormalizedCaptureName(capture.name)), + ) + ) { + continue; + } + + const hasPatternPredicates = + this.patternPredicates[match.patternIndex].length > 0; + + const mutableMatch = this.createMutableQueryMatch( + document, + match, + // If there are pattern predicates, we need to include all captures when + // creating the mutable match, since the predicates may depend on any of + // the captures. + !hasPatternPredicates ? captureNameFilter : undefined, + ); if (!this.runPredicates(mutableMatch)) { continue; } - results.push(this.createQueryMatch(mutableMatch)); + const queryMatch = this.createQueryMatch( + mutableMatch, + // We only need to filter here if we didn't filter in createMutableQueryMatch() + hasPatternPredicates ? captureNameFilter : undefined, + ); + + if (queryMatch != null) { + results.push(queryMatch); + } } return results; @@ -107,10 +164,19 @@ export class TreeSitterQuery { private createMutableQueryMatch( document: TextDocument, match: treeSitter.QueryMatch, + captureNameFilter?: Set, ): MutableQueryMatch { - return { - patternIdx: match.patternIndex, - captures: match.captures.map(({ name, node }) => ({ + const captures: MutableQueryCapture[] = []; + + for (const { name, node } of match.captures) { + if ( + captureNameFilter != null && + !captureNameFilter.has(getNormalizedCaptureName(name)) + ) { + continue; + } + + captures.push({ name, node, document, @@ -118,7 +184,12 @@ export class TreeSitterQuery { insertionDelimiter: undefined, allowMultiple: false, hasError: () => isContainedInErrorNode(node), - })), + }); + } + + return { + patternIdx: match.patternIndex, + captures, }; } @@ -131,7 +202,10 @@ export class TreeSitterQuery { return true; } - private createQueryMatch(match: MutableQueryMatch): QueryMatch { + private createQueryMatch( + match: MutableQueryMatch, + captureNameFilter?: Set, + ): QueryMatch | undefined { const result: MutableQueryCapture[] = []; const map = new Map< string, @@ -144,7 +218,10 @@ export class TreeSitterQuery { // name, for which we'd return a capture with name `foo`. for (const capture of match.captures) { - const name = normalizeCaptureName(capture.name); + const name = getNormalizedCaptureName(capture.name); + if (captureNameFilter != null && !captureNameFilter.has(name)) { + continue; + } const range = getStartOfEndOfRange(capture); const existing = map.get(name); @@ -168,6 +245,10 @@ export class TreeSitterQuery { } } + if (result.length === 0) { + return undefined; + } + if (this.shouldCheckCaptures) { this.checkCaptures(Array.from(map.values())); } diff --git a/packages/cursorless-engine/src/languages/TreeSitterQuery/TreeSitterQueryCache.ts b/packages/cursorless-engine/src/languages/TreeSitterQuery/TreeSitterQueryCache.ts index a2b251912e..ee229a501d 100644 --- a/packages/cursorless-engine/src/languages/TreeSitterQuery/TreeSitterQueryCache.ts +++ b/packages/cursorless-engine/src/languages/TreeSitterQuery/TreeSitterQueryCache.ts @@ -8,6 +8,7 @@ export class TreeSitterQueryCache { private startPosition: Position | undefined; private endPosition: Position | undefined; private matches: QueryMatch[] = []; + private captureNames: Set | undefined; clear() { this.documentUri = ""; @@ -15,6 +16,7 @@ export class TreeSitterQueryCache { this.documentLanguageId = ""; this.startPosition = undefined; this.endPosition = undefined; + this.captureNames = undefined; this.matches = []; } @@ -22,13 +24,15 @@ export class TreeSitterQueryCache { document: TextDocument, startPosition: Position | undefined, endPosition: Position | undefined, + captureNames: Set | undefined, ) { return ( this.documentVersion === document.version && this.documentUri === document.uri.toString() && this.documentLanguageId === document.languageId && positionsEqual(this.startPosition, startPosition) && - positionsEqual(this.endPosition, endPosition) + positionsEqual(this.endPosition, endPosition) && + setEqual(this.captureNames, captureNames) ); } @@ -36,6 +40,7 @@ export class TreeSitterQueryCache { document: TextDocument, startPosition: Position | undefined, endPosition: Position | undefined, + captureNames: Set | undefined, matches: QueryMatch[], ) { this.documentVersion = document.version; @@ -43,6 +48,7 @@ export class TreeSitterQueryCache { this.documentLanguageId = document.languageId; this.startPosition = startPosition; this.endPosition = endPosition; + this.captureNames = captureNames; this.matches = matches; } @@ -58,4 +64,19 @@ function positionsEqual(a: Position | undefined, b: Position | undefined) { return a.isEqual(b); } +function setEqual(a: Set | undefined, b: Set | undefined) { + if (a == null || b == null) { + return a === b; + } + if (a.size !== b.size) { + return false; + } + for (const item of a) { + if (!b.has(item)) { + return false; + } + } + return true; +} + export const treeSitterQueryCache = new TreeSitterQueryCache(); diff --git a/packages/cursorless-engine/src/languages/TreeSitterQuery/captureNames.ts b/packages/cursorless-engine/src/languages/TreeSitterQuery/captureNames.ts new file mode 100644 index 0000000000..eba556276e --- /dev/null +++ b/packages/cursorless-engine/src/languages/TreeSitterQuery/captureNames.ts @@ -0,0 +1,86 @@ +import { pseudoScopes, simpleScopeTypeTypes } from "@cursorless/common"; + +const wildcard = "_"; +const captureNames = [ + ...simpleScopeTypeTypes.filter((s) => !pseudoScopes.has(s)), + wildcard, + "interior", +]; + +const positionRelationships = ["prefix", "leading", "trailing"]; +const positionSuffixes = [ + "startOf", + "endOf", + "start.startOf", + "start.endOf", + "end.startOf", + "end.endOf", +]; + +const rangeRelationships = [ + "domain", + "removal", + "iteration", + "iteration.domain", +]; +const rangeSuffixes = [ + "start", + "end", + "start.startOf", + "start.endOf", + "end.startOf", + "end.endOf", +]; + +const allowedCaptures = new Set(); + +for (const captureName of captureNames) { + // Wildcard is not allowed by itself without a relationship + if (captureName !== wildcard) { + // eg: statement + allowedCaptures.add(captureName); + + // eg: statement.start | statement.start.endOf + for (const suffix of rangeSuffixes) { + allowedCaptures.add(`${captureName}.${suffix}`); + } + } + + for (const relationship of positionRelationships) { + // eg: statement.leading + allowedCaptures.add(`${captureName}.${relationship}`); + + for (const suffix of positionSuffixes) { + // eg: statement.leading.endOf + allowedCaptures.add(`${captureName}.${relationship}.${suffix}`); + } + } + + for (const relationship of rangeRelationships) { + // eg: statement.domain + allowedCaptures.add(`${captureName}.${relationship}`); + + for (const suffix of rangeSuffixes) { + // eg: statement.domain.start | statement.domain.start.endOf + allowedCaptures.add(`${captureName}.${relationship}.${suffix}`); + } + } +} + +const normalizedCaptureNamesMap = new Map(); + +for (const captureName of allowedCaptures) { + normalizedCaptureNamesMap.set(captureName, normalizeCaptureName(captureName)); +} + +function normalizeCaptureName(name: string): string { + return name.replace(/(\.(start|end))?(\.(startOf|endOf))?$/, ""); +} + +export function isCaptureAllowed(captureName: string): boolean { + return allowedCaptures.has(captureName); +} + +export function getNormalizedCaptureName(captureName: string): string { + return normalizedCaptureNamesMap.get(captureName) ?? captureName; +} diff --git a/packages/cursorless-engine/src/languages/TreeSitterQuery/normalizeCaptureName.ts b/packages/cursorless-engine/src/languages/TreeSitterQuery/normalizeCaptureName.ts deleted file mode 100644 index 5322ff1556..0000000000 --- a/packages/cursorless-engine/src/languages/TreeSitterQuery/normalizeCaptureName.ts +++ /dev/null @@ -1,3 +0,0 @@ -export function normalizeCaptureName(name: string): string { - return name.replace(/(\.(start|end))?(\.(startOf|endOf))?$/, ""); -} diff --git a/packages/cursorless-engine/src/languages/TreeSitterQuery/validateQueryCaptures.ts b/packages/cursorless-engine/src/languages/TreeSitterQuery/validateQueryCaptures.ts index 9ff3f94f19..46b86337b1 100644 --- a/packages/cursorless-engine/src/languages/TreeSitterQuery/validateQueryCaptures.ts +++ b/packages/cursorless-engine/src/languages/TreeSitterQuery/validateQueryCaptures.ts @@ -1,76 +1,6 @@ -import { - pseudoScopes, - showError, - simpleScopeTypeTypes, -} from "@cursorless/common"; +import { showError } from "@cursorless/common"; import { ide } from "../../singletons/ide.singleton"; - -const wildcard = "_"; -const captureNames = [ - ...simpleScopeTypeTypes.filter((s) => !pseudoScopes.has(s)), - wildcard, - "interior", -]; - -const positionRelationships = ["prefix", "leading", "trailing"]; -const positionSuffixes = [ - "startOf", - "endOf", - "start.startOf", - "start.endOf", - "end.startOf", - "end.endOf", -]; - -const rangeRelationships = [ - "domain", - "removal", - "iteration", - "iteration.domain", -]; -const rangeSuffixes = [ - "start", - "end", - "start.startOf", - "start.endOf", - "end.startOf", - "end.endOf", -]; - -const allowedCaptures = new Set(); - -for (const captureName of captureNames) { - // Wildcard is not allowed by itself without a relationship - if (captureName !== wildcard) { - // eg: statement - allowedCaptures.add(captureName); - - // eg: statement.start | statement.start.endOf - for (const suffix of rangeSuffixes) { - allowedCaptures.add(`${captureName}.${suffix}`); - } - } - - for (const relationship of positionRelationships) { - // eg: statement.leading - allowedCaptures.add(`${captureName}.${relationship}`); - - for (const suffix of positionSuffixes) { - // eg: statement.leading.endOf - allowedCaptures.add(`${captureName}.${relationship}.${suffix}`); - } - } - - for (const relationship of rangeRelationships) { - // eg: statement.domain - allowedCaptures.add(`${captureName}.${relationship}`); - - for (const suffix of rangeSuffixes) { - // eg: statement.domain.start | statement.domain.start.endOf - allowedCaptures.add(`${captureName}.${relationship}.${suffix}`); - } - } -} +import { isCaptureAllowed } from "./captureNames"; // Not a comment. ie line is not starting with `;;` // Not a string. @@ -94,7 +24,7 @@ export function validateQueryCaptures(file: string, rawQuery: string): void { continue; } - if (!allowedCaptures.has(captureName)) { + if (!isCaptureAllowed(captureName)) { const lineNumber = match.input.slice(0, match.index).split("\n").length; errors.push(`${file}(${lineNumber}) invalid capture '@${captureName}'.`); } diff --git a/packages/cursorless-engine/src/processTargets/modifiers/scopeHandlers/SurroundingPairScopeHandler/getDelimiterOccurrences.ts b/packages/cursorless-engine/src/processTargets/modifiers/scopeHandlers/SurroundingPairScopeHandler/getDelimiterOccurrences.ts index 1ec74624e6..2af9777db9 100644 --- a/packages/cursorless-engine/src/processTargets/modifiers/scopeHandlers/SurroundingPairScopeHandler/getDelimiterOccurrences.ts +++ b/packages/cursorless-engine/src/processTargets/modifiers/scopeHandlers/SurroundingPairScopeHandler/getDelimiterOccurrences.ts @@ -23,7 +23,12 @@ export function getDelimiterOccurrences( return []; } - const capturesMap = languageDefinition?.getCapturesMap(document) ?? {}; + const capturesMap = + languageDefinition?.getCapturesMap(document, [ + "disqualifyDelimiter", + "pairDelimiter", + "textFragment", + ]) ?? {}; const disqualifyDelimiters = new OneWayRangeFinder( getSortedCaptures(capturesMap.disqualifyDelimiter), ); diff --git a/packages/cursorless-engine/src/processTargets/modifiers/scopeHandlers/SurroundingPairScopeHandler/getIndividualDelimiters.ts b/packages/cursorless-engine/src/processTargets/modifiers/scopeHandlers/SurroundingPairScopeHandler/getIndividualDelimiters.ts index 74850d1047..a2cd5f818a 100644 --- a/packages/cursorless-engine/src/processTargets/modifiers/scopeHandlers/SurroundingPairScopeHandler/getIndividualDelimiters.ts +++ b/packages/cursorless-engine/src/processTargets/modifiers/scopeHandlers/SurroundingPairScopeHandler/getIndividualDelimiters.ts @@ -6,7 +6,7 @@ import type { import { isString } from "@cursorless/common"; import { concat, uniq } from "lodash-es"; import { complexDelimiterMap, getSimpleDelimiterMap } from "./delimiterMaps"; -import type { IndividualDelimiter } from "./types"; +import { DelimiterSide, type IndividualDelimiter } from "./types"; /** * Given a list of delimiters, returns a list where each element corresponds to @@ -55,14 +55,14 @@ function getSimpleIndividualDelimiters( const side = (() => { if (isLeft && !isRight) { - return "left"; + return DelimiterSide.left; } if (!isLeft && isRight) { - return "right"; + return DelimiterSide.right; } // If delimiter text is the same for left and right, we say its side // is "unknown", so must be determined from context. - return "unknown"; + return DelimiterSide.unknown; })(); return { diff --git a/packages/cursorless-engine/src/processTargets/modifiers/scopeHandlers/SurroundingPairScopeHandler/getSurroundingPairOccurrences.ts b/packages/cursorless-engine/src/processTargets/modifiers/scopeHandlers/SurroundingPairScopeHandler/getSurroundingPairOccurrences.ts index 0a031d0421..da5f70944a 100644 --- a/packages/cursorless-engine/src/processTargets/modifiers/scopeHandlers/SurroundingPairScopeHandler/getSurroundingPairOccurrences.ts +++ b/packages/cursorless-engine/src/processTargets/modifiers/scopeHandlers/SurroundingPairScopeHandler/getSurroundingPairOccurrences.ts @@ -1,5 +1,6 @@ import type { Range } from "@cursorless/common"; import findLastIndex from "lodash-es/findLastIndex"; +import { DelimiterSide } from "./types"; import type { DelimiterOccurrence, IndividualDelimiter, @@ -41,7 +42,8 @@ export function getSurroundingPairOccurrences( if (closestOpeningDelimiterMatch == null) { const openingDelimiterInfo = occurrence.delimiterInfos.find( - ({ side }) => side === "left" || side === "unknown", + ({ side }) => + side === DelimiterSide.left || side === DelimiterSide.unknown, ); // Pure closing delimiters with no matching opener are ignored. @@ -84,7 +86,7 @@ function getClosestOpeningDelimiterMatch( let closestMatch: OpeningDelimiterMatch | undefined; for (const delimiterInfo of occurrence.delimiterInfos) { - if (delimiterInfo.side === "left") { + if (delimiterInfo.side === DelimiterSide.left) { continue; } diff --git a/packages/cursorless-engine/src/processTargets/modifiers/scopeHandlers/SurroundingPairScopeHandler/types.ts b/packages/cursorless-engine/src/processTargets/modifiers/scopeHandlers/SurroundingPairScopeHandler/types.ts index 520e9630fc..963d959ecd 100644 --- a/packages/cursorless-engine/src/processTargets/modifiers/scopeHandlers/SurroundingPairScopeHandler/types.ts +++ b/packages/cursorless-engine/src/processTargets/modifiers/scopeHandlers/SurroundingPairScopeHandler/types.ts @@ -5,7 +5,11 @@ import type { Range, SimpleSurroundingPairName } from "@cursorless/common"; * or if we do not know. Note that the terms "opening" and "closing" could be * used instead of "left" and "right", respectively. */ -export type DelimiterSide = "unknown" | "left" | "right"; +export enum DelimiterSide { + unknown, + left, + right, +} /** * A description of one possible side of a delimiter diff --git a/packages/cursorless-vscode-e2e/src/suite/performance.vscode.test.ts b/packages/cursorless-vscode-e2e/src/suite/performance.vscode.test.ts index 1d8f43dcb3..5048e59571 100644 --- a/packages/cursorless-vscode-e2e/src/suite/performance.vscode.test.ts +++ b/packages/cursorless-vscode-e2e/src/suite/performance.vscode.test.ts @@ -15,9 +15,15 @@ import { isMac } from "@cursorless/node-common"; const testData = generateTestData(100); const multiplier = calculateMultiplier(); const smallThresholdMs = 50 * multiplier; +const midThresholdMs = 200 * multiplier; const largeThresholdMs = 300 * multiplier; const xlThresholdMs = 400 * multiplier; -const thresholds = [smallThresholdMs, largeThresholdMs, xlThresholdMs]; +const thresholds = [ + smallThresholdMs, + midThresholdMs, + largeThresholdMs, + xlThresholdMs, +]; type ModifierType = "containing" | "previous" | "every"; @@ -78,14 +84,10 @@ suite(`Performance ${thresholds.join("/")} ms`, async function () { ["collectionItem", largeThresholdMs, "every"], ["collectionItem", largeThresholdMs, "previous"], // Surrounding pair - [{ type: "surroundingPair", delimiter: "curlyBrackets" }, largeThresholdMs], - [{ type: "surroundingPair", delimiter: "any" }, largeThresholdMs], - [{ type: "surroundingPair", delimiter: "any" }, largeThresholdMs, "every"], - [ - { type: "surroundingPair", delimiter: "any" }, - largeThresholdMs, - "previous", - ], + [{ type: "surroundingPair", delimiter: "curlyBrackets" }, midThresholdMs], + [{ type: "surroundingPair", delimiter: "any" }, midThresholdMs], + [{ type: "surroundingPair", delimiter: "any" }, midThresholdMs, "every"], + [{ type: "surroundingPair", delimiter: "any" }, midThresholdMs, "previous"], ]; for (const [scope, threshold, modifierType] of fixtures) { @@ -111,7 +113,7 @@ suite(`Performance ${thresholds.join("/")} ms`, async function () { test( "Select collectionItem with multiple cursors", asyncSafety(() => - selectWithMultipleCursors(largeThresholdMs, { + selectWithMultipleCursors(midThresholdMs, { type: "collectionItem", }), ), @@ -126,6 +128,23 @@ suite(`Performance ${thresholds.join("/")} ms`, async function () { }), ), ); + + test( + "Swap key / value with multiple cursors", + asyncSafety(() => + testWithMultipleCursors(midThresholdMs, { + name: "swapTargets", + target1: { + type: "primitive", + modifiers: [getModifier({ type: "collectionKey" })], + }, + target2: { + type: "primitive", + modifiers: [getModifier({ type: "value" })], + }, + }), + ), + ); }); function removeToken(thresholdMs: number) { @@ -139,6 +158,19 @@ function removeToken(thresholdMs: number) { } function selectWithMultipleCursors(thresholdMs: number, scopeType: ScopeType) { + return testWithMultipleCursors(thresholdMs, { + name: "setSelection", + target: { + type: "primitive", + modifiers: [getModifier(scopeType)], + }, + }); +} + +function testWithMultipleCursors( + thresholdMs: number, + action: ActionDescriptor, +) { const beforeCallback = async (editor: vscode.TextEditor) => { await runCursorlessAction({ name: "setSelectionBefore", @@ -151,16 +183,7 @@ function selectWithMultipleCursors(thresholdMs: number, scopeType: ScopeType) { assert.equal(editor.selections.length, 100, "Expected 100 cursors"); }; - const callback = () => { - return runCursorlessAction({ - name: "setSelection", - target: { - type: "primitive", - modifiers: [getModifier(scopeType)], - }, - }); - }; - + const callback = () => runCursorlessAction(action); return testPerformanceCallback(thresholdMs, callback, beforeCallback); }