diff --git a/src/cli/mm.test.ts b/src/cli/mm.test.ts index 083472f..edd0ee5 100644 --- a/src/cli/mm.test.ts +++ b/src/cli/mm.test.ts @@ -533,6 +533,54 @@ describe('parseLaunchArgs', () => { ); }); + it('parses --platform value', () => { + expect(parseLaunchArgs(['--platform', 'ios'])).toStrictEqual({ + platform: 'ios', + }); + }); + + it('parses --device-id value', () => { + expect( + parseLaunchArgs(['--device-id', '4A3B2C1D-E5F6-7890-ABCD-EF1234567890']), + ).toStrictEqual({ + deviceId: '4A3B2C1D-E5F6-7890-ABCD-EF1234567890', + }); + }); + + it('parses --platform and --device-id together', () => { + expect( + parseLaunchArgs([ + '--platform', + 'android', + '--device-id', + 'emulator-5554', + ]), + ).toStrictEqual({ + platform: 'android', + deviceId: 'emulator-5554', + }); + }); + + it('exits for --platform without value', () => { + expect(() => parseLaunchArgs(['--platform'])).toThrowError('process.exit'); + expect(stderrSpy).toHaveBeenCalledWith( + 'Error: --platform requires a value (browser|ios|android)\n', + ); + }); + + it('exits for --platform with flag as value', () => { + expect(() => parseLaunchArgs(['--platform', '--force'])).toThrowError( + 'process.exit', + ); + }); + + it('exits for --device-id without value', () => { + expect(() => parseLaunchArgs(['--device-id'])).toThrowError('process.exit'); + expect(stderrSpy).toHaveBeenCalledWith( + 'Error: --device-id requires a value\n', + ); + }); + it('writes warning for unknown flags', () => { parseLaunchArgs(['--unknown']); expect(stderrSpy).toHaveBeenCalledWith( diff --git a/src/cli/mm.ts b/src/cli/mm.ts index 64cdb6b..99333b6 100644 --- a/src/cli/mm.ts +++ b/src/cli/mm.ts @@ -1296,6 +1296,8 @@ export function parseLaunchArgs(args: string[]): Record { '--goal', '--force', '--flow-tags', + '--platform', + '--device-id', ]); for (let i = 0; i < args.length; i++) { @@ -1341,6 +1343,22 @@ export function parseLaunchArgs(args: string[]): Record { process.exit(1); } result.flowTags = args[i].split(',').map((tag) => tag.trim()); + } else if (arg === '--platform') { + i += 1; + if (!args[i] || args[i].startsWith('--')) { + process.stderr.write( + 'Error: --platform requires a value (browser|ios|android)\n', + ); + process.exit(1); + } + result.platform = args[i]; + } else if (arg === '--device-id') { + i += 1; + if (!args[i] || args[i].startsWith('--')) { + process.stderr.write('Error: --device-id requires a value\n'); + process.exit(1); + } + result.deviceId = args[i]; } else if (arg.startsWith('--') && !knownFlags.has(arg)) { process.stderr.write(`Warning: unknown launch flag '${arg}'\n`); } @@ -1365,7 +1383,7 @@ Environment Variables: Falls back to the current git worktree root. Lifecycle: - mm launch [--context e2e|prod] [--state default|onboarding|custom] [--extension-path ] [--goal ] [--force] [--flow-tags ] + mm launch [--context e2e|prod] [--state default|onboarding|custom] [--extension-path ] [--goal ] [--force] [--flow-tags ] [--platform browser|ios|android] [--device-id ] mm cleanup [--shutdown] mm status mm stop [--force] diff --git a/src/index.ts b/src/index.ts index 2275fbe..ba19096 100644 --- a/src/index.ts +++ b/src/index.ts @@ -2,6 +2,17 @@ export type * from './capabilities/types.js'; export * from './capabilities/context.js'; +// Platform +export type { + PlatformType, + IPlatformDriver, + ClickActionResult, + TypeActionResult, + GetTextActionResult, + PlatformScreenshotOptions, +} from './platform'; +export { PlaywrightPlatformDriver } from './platform'; + // Session Manager Interface (transport-agnostic) export type { ISessionManager, diff --git a/src/platform/index.ts b/src/platform/index.ts new file mode 100644 index 0000000..0a54f09 --- /dev/null +++ b/src/platform/index.ts @@ -0,0 +1,12 @@ +export type { + PlatformType, + TargetType, + ClickActionResult, + TypeActionResult, + GetTextActionResult, + PlatformScreenshotOptions, + WithinScope, + IPlatformDriver, +} from './types.js'; + +export { PlaywrightPlatformDriver } from './playwright-driver.js'; diff --git a/src/platform/playwright-driver.test.ts b/src/platform/playwright-driver.test.ts new file mode 100644 index 0000000..c35c7f6 --- /dev/null +++ b/src/platform/playwright-driver.test.ts @@ -0,0 +1,387 @@ +import { describe, it, expect, vi, afterEach } from 'vitest'; + +import { PlaywrightPlatformDriver } from './playwright-driver.js'; +import type { IPlatformDriver } from './types.js'; +import { createMockSessionManager } from '../tools/test-utils/mock-factories.js'; +import * as discoveryModule from '../tools/utils/discovery.js'; + +function createMockLocator() { + return { + click: vi.fn().mockResolvedValue(undefined), + fill: vi.fn().mockResolvedValue(undefined), + waitFor: vi.fn().mockResolvedValue(undefined), + textContent: vi.fn().mockResolvedValue('Hello World'), + }; +} + +function createMockPage() { + return { + url: vi.fn().mockReturnValue('chrome-extension://abc/home.html'), + locator: vi.fn(() => createMockLocator()), + isClosed: vi.fn(() => false), + }; +} + +function createDriver( + pageOverride?: object, + sessionOverride?: object, +): IPlatformDriver { + const page = (pageOverride ?? createMockPage()) as any; + const sessionManager = (sessionOverride ?? + createMockSessionManager({ hasActive: true })) as any; + return new PlaywrightPlatformDriver(() => page, sessionManager); +} + +describe('PlaywrightPlatformDriver', () => { + afterEach(() => { + vi.restoreAllMocks(); + }); + + describe('getPlatform', () => { + it('returns browser', () => { + const driver = createDriver(); + expect(driver.getPlatform()).toBe('browser'); + }); + }); + + describe('getCurrentUrl', () => { + it('returns the page URL', () => { + const page = createMockPage(); + const driver = createDriver(page); + expect(driver.getCurrentUrl()).toBe('chrome-extension://abc/home.html'); + expect(page.url).toHaveBeenCalledOnce(); + }); + }); + + describe('click', () => { + it('delegates to waitForTarget and locator.click', async () => { + const locator = createMockLocator(); + vi.spyOn(discoveryModule, 'waitForTarget').mockResolvedValue( + locator as any, + ); + + const driver = createDriver(); + const result = await driver.click( + 'testId', + 'submit-btn', + new Map(), + 5000, + ); + + expect(result.clicked).toBe(true); + expect(result.target).toBe('testId:submit-btn'); + expect(locator.click).toHaveBeenCalledOnce(); + }); + + it('passes within scope to waitForTarget', async () => { + const locator = createMockLocator(); + vi.spyOn(discoveryModule, 'waitForTarget').mockResolvedValue( + locator as any, + ); + + const driver = createDriver(); + await driver.click('testId', 'btn', new Map(), 5000, { + type: 'testId', + value: 'parent-container', + }); + + expect(discoveryModule.waitForTarget).toHaveBeenCalledWith( + expect.anything(), + 'testId', + 'btn', + expect.any(Map), + 5000, + { type: 'testId', value: 'parent-container' }, + ); + }); + + it('throws timeout error when waitForTarget times out', async () => { + vi.spyOn(discoveryModule, 'waitForTarget').mockRejectedValue( + new Error('Timeout 5000ms exceeded waiting for element'), + ); + + const driver = createDriver(); + await expect( + driver.click('testId', 'missing', new Map(), 5000), + ).rejects.toThrowError('Timeout 5000ms exceeded'); + }); + + it('returns pageClosedAfterClick when page closes during click', async () => { + const locator = createMockLocator(); + const pageClosedError = new Error( + 'Target page, context or browser has been closed', + ); + locator.click.mockRejectedValue(pageClosedError); + vi.spyOn(discoveryModule, 'waitForTarget').mockResolvedValue( + locator as any, + ); + + const driver = createDriver(); + const result = await driver.click('testId', 'close-btn', new Map(), 5000); + + expect(result.clicked).toBe(true); + expect(result.pageClosedAfterClick).toBe(true); + }); + + it('throws on non-page-closed errors', async () => { + const locator = createMockLocator(); + locator.click.mockRejectedValue(new Error('Element detached')); + vi.spyOn(discoveryModule, 'waitForTarget').mockResolvedValue( + locator as any, + ); + + const driver = createDriver(); + await expect( + driver.click('testId', 'btn', new Map(), 5000), + ).rejects.toThrowError('Element detached'); + }); + + it('throws when budget exhausted by waitForTarget', async () => { + const locator = createMockLocator(); + vi.spyOn(discoveryModule, 'waitForTarget').mockResolvedValue( + locator as any, + ); + + const nowSpy = vi.spyOn(Date, 'now'); + nowSpy.mockReturnValueOnce(1000); + nowSpy.mockReturnValueOnce(6000); + + const driver = createDriver(); + await expect( + driver.click('testId', 'btn', new Map(), 5), + ).rejects.toThrowError('visibility wait consumed entire budget'); + expect(locator.click).not.toHaveBeenCalled(); + }); + }); + + describe('type', () => { + it('delegates to waitForTarget and locator.fill', async () => { + const locator = createMockLocator(); + vi.spyOn(discoveryModule, 'waitForTarget').mockResolvedValue( + locator as any, + ); + + const driver = createDriver(); + const result = await driver.type( + 'testId', + 'email-input', + 'user@test.com', + new Map(), + 5000, + ); + + expect(result.typed).toBe(true); + expect(result.textLength).toBe(13); + expect(locator.fill).toHaveBeenCalledWith( + 'user@test.com', + expect.any(Object), + ); + }); + + it('passes within scope to waitForTarget', async () => { + const locator = createMockLocator(); + vi.spyOn(discoveryModule, 'waitForTarget').mockResolvedValue( + locator as any, + ); + + const driver = createDriver(); + await driver.type('testId', 'input', 'text', new Map(), 5000, { + type: 'a11yRef', + value: 'e5', + }); + + expect(discoveryModule.waitForTarget).toHaveBeenCalledWith( + expect.anything(), + 'testId', + 'input', + expect.any(Map), + 5000, + { type: 'a11yRef', value: 'e5' }, + ); + }); + + it('throws when budget exhausted by waitForTarget', async () => { + const locator = createMockLocator(); + vi.spyOn(discoveryModule, 'waitForTarget').mockResolvedValue( + locator as any, + ); + + const nowSpy = vi.spyOn(Date, 'now'); + nowSpy.mockReturnValueOnce(1000); + nowSpy.mockReturnValueOnce(6000); + + const driver = createDriver(); + await expect( + driver.type('testId', 'input', 'text', new Map(), 5), + ).rejects.toThrowError('visibility wait consumed entire budget'); + expect(locator.fill).not.toHaveBeenCalled(); + }); + }); + + describe('waitForElement', () => { + it('delegates to waitForTarget', async () => { + const locator = createMockLocator(); + vi.spyOn(discoveryModule, 'waitForTarget').mockResolvedValue( + locator as any, + ); + + const driver = createDriver(); + await driver.waitForElement('testId', 'loading', new Map(), 5000); + + expect(discoveryModule.waitForTarget).toHaveBeenCalledOnce(); + }); + }); + + describe('getText', () => { + it('returns text content from locator', async () => { + const locator = createMockLocator(); + locator.textContent.mockResolvedValue('Balance: $100'); + vi.spyOn(discoveryModule, 'waitForTarget').mockResolvedValue( + locator as any, + ); + + const driver = createDriver(); + const result = await driver.getText('testId', 'balance', new Map(), 5000); + + expect(result.text).toBe('Balance: $100'); + expect(result).toHaveLength(13); + }); + + it('returns empty string when textContent is null', async () => { + const locator = createMockLocator(); + locator.textContent.mockResolvedValue(null); + vi.spyOn(discoveryModule, 'waitForTarget').mockResolvedValue( + locator as any, + ); + + const driver = createDriver(); + const result = await driver.getText('testId', 'empty', new Map(), 5000); + + expect(result.text).toBe(''); + expect(result).toHaveLength(0); + }); + + it('throws when budget exhausted by waitForTarget', async () => { + const locator = createMockLocator(); + vi.spyOn(discoveryModule, 'waitForTarget').mockResolvedValue( + locator as any, + ); + + const nowSpy = vi.spyOn(Date, 'now'); + nowSpy.mockReturnValueOnce(1000); + nowSpy.mockReturnValueOnce(6000); + + const driver = createDriver(); + await expect( + driver.getText('testId', 'el', new Map(), 5), + ).rejects.toThrowError('visibility wait consumed entire budget'); + expect(locator.textContent).not.toHaveBeenCalled(); + }); + }); + + describe('getAccessibilityTree', () => { + it('delegates to collectTrimmedA11ySnapshot', async () => { + const mockNodes = [ + { ref: 'e1', role: 'button', name: 'Submit', path: [] }, + ]; + const mockRefMap = new Map([['e1', 'role=button[name="Submit"]']]); + vi.spyOn(discoveryModule, 'collectTrimmedA11ySnapshot').mockResolvedValue( + { nodes: mockNodes, refMap: mockRefMap } as any, + ); + + const driver = createDriver(); + const { nodes, refMap } = await driver.getAccessibilityTree(); + + expect(nodes).toStrictEqual(mockNodes); + expect(refMap).toBe(mockRefMap); + }); + + it('passes rootSelector to collectTrimmedA11ySnapshot', async () => { + vi.spyOn(discoveryModule, 'collectTrimmedA11ySnapshot').mockResolvedValue( + { nodes: [], refMap: new Map() } as any, + ); + + const driver = createDriver(); + await driver.getAccessibilityTree('#main-content'); + + expect(discoveryModule.collectTrimmedA11ySnapshot).toHaveBeenCalledWith( + expect.anything(), + '#main-content', + ); + }); + }); + + describe('getTestIds', () => { + it('delegates to collectTestIds', async () => { + const mockItems = [{ testId: 'submit', tag: 'button', visible: true }]; + vi.spyOn(discoveryModule, 'collectTestIds').mockResolvedValue( + mockItems as any, + ); + + const driver = createDriver(); + const items = await driver.getTestIds(50); + + expect(items).toStrictEqual(mockItems); + expect(discoveryModule.collectTestIds).toHaveBeenCalledWith( + expect.anything(), + 50, + ); + }); + }); + + describe('screenshot', () => { + it('delegates to sessionManager.screenshot', async () => { + const mockResult = { + path: '/tmp/ss.png', + base64: 'abc', + width: 800, + height: 600, + }; + const sessionManager = createMockSessionManager({ hasActive: true }); + vi.spyOn(sessionManager, 'screenshot').mockResolvedValue(mockResult); + + const driver = new PlaywrightPlatformDriver( + () => createMockPage() as any, + sessionManager as any, + ); + + const result = await driver.screenshot({ name: 'test', fullPage: true }); + + expect(result).toStrictEqual(mockResult); + expect(sessionManager.screenshot).toHaveBeenCalledWith({ + name: 'test', + fullPage: true, + selector: undefined, + }); + }); + }); + + describe('getAppState', () => { + it('delegates to sessionManager.getExtensionState', async () => { + const mockState = { + isLoaded: true, + currentUrl: 'chrome-extension://abc/home.html', + extensionId: 'abc', + isUnlocked: true, + currentScreen: 'home', + accountAddress: null, + networkName: null, + chainId: null, + balance: null, + }; + const sessionManager = createMockSessionManager({ hasActive: true }); + vi.spyOn(sessionManager, 'getExtensionState').mockResolvedValue( + mockState, + ); + + const driver = new PlaywrightPlatformDriver( + () => createMockPage() as any, + sessionManager as any, + ); + + const state = await driver.getAppState(); + + expect(state).toStrictEqual(mockState); + }); + }); +}); diff --git a/src/platform/playwright-driver.ts b/src/platform/playwright-driver.ts new file mode 100644 index 0000000..d2d8db5 --- /dev/null +++ b/src/platform/playwright-driver.ts @@ -0,0 +1,285 @@ +import type { Page } from '@playwright/test'; + +import type { + IPlatformDriver, + TargetType, + ClickActionResult, + TypeActionResult, + GetTextActionResult, + PlatformScreenshotOptions, + PlatformType, + WithinScope, +} from './types.js'; +import type { + ScreenshotResult, + ExtensionState, +} from '../capabilities/types.js'; +import type { ISessionManager } from '../server/session-manager.js'; +import { isPageClosedError } from '../tools/error-classification.js'; +import type { TestIdItem, A11yNodeTrimmed } from '../tools/types/discovery.js'; +import { + collectTestIds, + collectTrimmedA11ySnapshot, + waitForTarget, +} from '../tools/utils/discovery.js'; + +/** + * Platform driver implementation for Playwright-based browser automation. + * Wraps existing Playwright interaction and discovery functions behind + * the IPlatformDriver interface for cross-platform tool delegation. + */ +export class PlaywrightPlatformDriver implements IPlatformDriver { + readonly #getPage: () => Page; + + readonly #sessionManager: ISessionManager; + + /** + * @param getPage - Getter for the current active Playwright page. + * @param sessionManager - The session manager for screenshot and state delegation. + */ + constructor(getPage: () => Page, sessionManager: ISessionManager) { + this.#getPage = getPage; + this.#sessionManager = sessionManager; + } + + /** + * Click an element, handling page-closed errors as successful navigation clicks. + * + * @param targetType - The type of target identifier (a11yRef, testId, selector). + * @param targetValue - The target value used for element lookup. + * @param refMap - Map of a11y refs to selectors. + * @param timeoutMs - Maximum time in milliseconds for the interaction. + * @param within - Optional parent scope for chained locator resolution. + * @returns The click result with success status and target info. + */ + async click( + targetType: TargetType, + targetValue: string, + refMap: Map, + timeoutMs: number, + within?: WithinScope, + ): Promise { + const page = this.#getPage(); + const deadline = Date.now() + timeoutMs; + + const locator = await waitForTarget( + page, + targetType, + targetValue, + refMap, + timeoutMs, + within, + ); + + const remaining = deadline - Date.now(); + if (remaining <= 0) { + throw new Error( + `Timeout ${timeoutMs}ms exceeded: visibility wait consumed entire budget for ${targetType}:${targetValue}`, + ); + } + + try { + await locator.click({ timeout: remaining }); + } catch (error) { + if (isPageClosedError(error)) { + return { + clicked: true, + target: `${targetType}:${targetValue}`, + pageClosedAfterClick: true, + }; + } + throw error; + } + + return { + clicked: true, + target: `${targetType}:${targetValue}`, + }; + } + + /** + * Type text into an input element after waiting for visibility. + * + * @param targetType - The type of target identifier. + * @param targetValue - The target value used for element lookup. + * @param text - The text to type into the input. + * @param refMap - Map of a11y refs to selectors. + * @param timeoutMs - Maximum time in milliseconds. + * @param within - Optional parent scope. + * @returns The type result with success status and text length. + */ + async type( + targetType: TargetType, + targetValue: string, + text: string, + refMap: Map, + timeoutMs: number, + within?: WithinScope, + ): Promise { + const page = this.#getPage(); + const deadline = Date.now() + timeoutMs; + + const locator = await waitForTarget( + page, + targetType, + targetValue, + refMap, + timeoutMs, + within, + ); + + const remaining = deadline - Date.now(); + if (remaining <= 0) { + throw new Error( + `Timeout ${timeoutMs}ms exceeded: visibility wait consumed entire budget for ${targetType}:${targetValue}`, + ); + } + + await locator.fill(text, { timeout: remaining }); + + return { + typed: true, + target: `${targetType}:${targetValue}`, + textLength: text.length, + }; + } + + /** + * Wait for an element to become visible on the page. + * + * @param targetType - The type of target identifier. + * @param targetValue - The target value used for element lookup. + * @param refMap - Map of a11y refs to selectors. + * @param timeoutMs - Maximum time in milliseconds. + * @param within - Optional parent scope. + */ + async waitForElement( + targetType: TargetType, + targetValue: string, + refMap: Map, + timeoutMs: number, + within?: WithinScope, + ): Promise { + const page = this.#getPage(); + await waitForTarget( + page, + targetType, + targetValue, + refMap, + timeoutMs, + within, + ); + } + + /** + * Read the text content of an element. + * + * @param targetType - The type of target identifier. + * @param targetValue - The target value used for element lookup. + * @param refMap - Map of a11y refs to selectors. + * @param timeoutMs - Maximum time in milliseconds. + * @param within - Optional parent scope. + * @returns The text content, target descriptor, and character length. + */ + async getText( + targetType: TargetType, + targetValue: string, + refMap: Map, + timeoutMs: number, + within?: WithinScope, + ): Promise { + const page = this.#getPage(); + const deadline = Date.now() + timeoutMs; + + const locator = await waitForTarget( + page, + targetType, + targetValue, + refMap, + timeoutMs, + within, + ); + + const remaining = deadline - Date.now(); + if (remaining <= 0) { + throw new Error( + `Timeout ${timeoutMs}ms exceeded: visibility wait consumed entire budget for ${targetType}:${targetValue}`, + ); + } + + const text = (await locator.textContent({ timeout: remaining })) ?? ''; + + return { + text, + target: `${targetType}:${targetValue}`, + length: text.length, + }; + } + + /** + * Capture the trimmed accessibility tree with deterministic refs. + * + * @param rootSelector - Optional CSS selector to scope the snapshot. + * @returns The accessibility nodes and ref-to-selector map. + */ + async getAccessibilityTree( + rootSelector?: string, + ): Promise<{ nodes: A11yNodeTrimmed[]; refMap: Map }> { + const page = this.#getPage(); + return collectTrimmedA11ySnapshot(page, rootSelector); + } + + /** + * Collect visible test IDs from the current page. + * + * @param limit - Maximum number of test IDs to return. + * @returns Array of test ID items. + */ + async getTestIds(limit?: number): Promise { + const page = this.#getPage(); + return collectTestIds(page, limit ?? 150); + } + + /** + * Capture a screenshot of the current page. + * + * @param options - Screenshot options (name, fullPage, selector). + * @returns Screenshot result with path and dimensions. + */ + async screenshot( + options: PlatformScreenshotOptions, + ): Promise { + return this.#sessionManager.screenshot({ + name: options.name, + fullPage: options.fullPage, + selector: options.selector, + }); + } + + /** + * Get the current extension state. + * + * @returns The extension state including URL, screen, network, balance. + */ + async getAppState(): Promise { + return this.#sessionManager.getExtensionState(); + } + + /** + * Get the current page URL. + * + * @returns The URL of the active page. + */ + getCurrentUrl(): string { + return this.#getPage().url(); + } + + /** + * Get the platform type. + * + * @returns 'browser' for this driver. + */ + getPlatform(): PlatformType { + return 'browser'; + } +} diff --git a/src/platform/types.ts b/src/platform/types.ts new file mode 100644 index 0000000..8085c1e --- /dev/null +++ b/src/platform/types.ts @@ -0,0 +1,88 @@ +import type { + ScreenshotResult, + ExtensionState, +} from '../capabilities/types.js'; +import type { TestIdItem, A11yNodeTrimmed } from '../tools/types/discovery.js'; + +export type PlatformType = 'browser' | 'ios' | 'android'; + +export type TargetType = 'a11yRef' | 'testId' | 'selector'; + +export type ClickActionResult = { + clicked: boolean; + target: string; + pageClosedAfterClick?: boolean; +}; + +export type TypeActionResult = { + typed: boolean; + target: string; + textLength: number; +}; + +export type GetTextActionResult = { + text: string; + target: string; + length: number; +}; + +export type PlatformScreenshotOptions = { + name: string; + fullPage?: boolean; + selector?: string; + includeBase64?: boolean; +}; + +export type WithinScope = { + type: TargetType; + value: string; +}; + +export type IPlatformDriver = { + click( + targetType: TargetType, + targetValue: string, + refMap: Map, + timeoutMs: number, + within?: WithinScope, + ): Promise; + + type( + targetType: TargetType, + targetValue: string, + text: string, + refMap: Map, + timeoutMs: number, + within?: WithinScope, + ): Promise; + + waitForElement( + targetType: TargetType, + targetValue: string, + refMap: Map, + timeoutMs: number, + within?: WithinScope, + ): Promise; + + getText( + targetType: TargetType, + targetValue: string, + refMap: Map, + timeoutMs: number, + within?: WithinScope, + ): Promise; + + getAccessibilityTree( + rootSelector?: string, + ): Promise<{ nodes: A11yNodeTrimmed[]; refMap: Map }>; + + getTestIds(limit?: number): Promise; + + screenshot(options: PlatformScreenshotOptions): Promise; + + getAppState(): Promise; + + getCurrentUrl(): string; + + getPlatform(): PlatformType; +}; diff --git a/src/server/create-server.test.ts b/src/server/create-server.test.ts index d8edd45..2242503 100644 --- a/src/server/create-server.test.ts +++ b/src/server/create-server.test.ts @@ -1789,3 +1789,88 @@ describe('observation compaction in HTTP responses', () => { expect(body.observations).toBeUndefined(); }); }); + +describe('createServer observation driver freshness after launch', () => { + let server: ServerInstance; + let state: DaemonState; + let mockSM: ReturnType; + let mobileDriver: { + getPlatform: ReturnType; + getAppState: ReturnType; + getTestIds: ReturnType; + getAccessibilityTree: ReturnType; + }; + + beforeEach(async () => { + await fs.mkdir(tmpDir, { recursive: true }); + + mobileDriver = { + getPlatform: vi.fn().mockReturnValue('ios'), + getAppState: vi.fn().mockResolvedValue({ + isLoaded: true, + currentUrl: '', + extensionId: 'io.metamask', + isUnlocked: true, + currentScreen: 'unknown', + accountAddress: null, + networkName: null, + chainId: null, + balance: null, + }), + getTestIds: vi.fn().mockResolvedValue([]), + getAccessibilityTree: vi + .fn() + .mockResolvedValue({ nodes: [], refMap: new Map() }), + }; + + mockSM = createMockSessionManager(); + const getPlatformDriverMock = vi.fn().mockReturnValue(undefined); + (mockSM as Record).getPlatformDriver = + getPlatformDriverMock; + // Session starts inactive; launch activates it and sets the mobile driver + mockSM.hasActiveSession.mockReturnValue(false); + mockSM.launch.mockImplementation(async () => { + mockSM.hasActiveSession.mockReturnValue(true); + getPlatformDriverMock.mockReturnValue(mobileDriver); + return { + sessionId: 'mobile-session', + extensionId: 'io.metamask', + state: {}, + }; + }); + + server = createServer( + buildConfig({ + sessionManager: mockSM as unknown as ServerConfig['sessionManager'], + }), + ); + state = await server.start(); + }); + + afterEach(async () => { + await server.stop(); + await fs.rm(tmpDir, { recursive: true, force: true }).catch(() => {}); + }); + + it('uses the mobile driver for observations after launch sets it', async () => { + const res = await httpRequest(`http://127.0.0.1:${state.port}/launch`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ platform: 'ios' }), + }); + const body = (await res.json()) as { + ok: boolean; + observations?: { state: unknown }; + }; + + expect(res.status).toBe(200); + expect(body.ok).toBe(true); + // The mobile driver's getAppState should have been called for + // observation collection, NOT the Playwright path (which would + // fail since there's no browser page). + expect(mobileDriver.getAppState).toHaveBeenCalled(); + expect(mobileDriver.getTestIds).toHaveBeenCalled(); + expect(mobileDriver.getAccessibilityTree).toHaveBeenCalled(); + expect(body.observations).toBeDefined(); + }); +}); diff --git a/src/server/create-server.ts b/src/server/create-server.ts index 14d51a1..6196dcd 100644 --- a/src/server/create-server.ts +++ b/src/server/create-server.ts @@ -14,7 +14,12 @@ import { KnowledgeStore, createDefaultObservation, } from '../knowledge-store/knowledge-store.js'; -import { toolRegistry, getToolCategory } from '../tools/registry.js'; +import { PlaywrightPlatformDriver } from '../platform/playwright-driver.js'; +import { + toolRegistry, + getToolCategory, + isBrowserOnlyTool, +} from '../tools/registry.js'; import type { ToolCategory } from '../tools/registry.js'; import type { StepRecordObservation, @@ -300,6 +305,14 @@ export function createServer(config: ServerConfig): ServerInstance { * @returns A ToolContext with lazy page and refMap accessors. */ function buildToolContext(wfCtx: WorkflowContext): ToolContext { + const driver = config.sessionManager.hasActiveSession() + ? (config.sessionManager.getPlatformDriver?.() ?? + new PlaywrightPlatformDriver( + () => config.sessionManager.getPage(), + config.sessionManager, + )) + : undefined; + return { sessionManager: config.sessionManager, get page(): ReturnType { @@ -313,6 +326,7 @@ export function createServer(config: ServerConfig): ServerInstance { workflowContext: wfCtx, knowledgeStore, toolRegistry, + driver, }; } @@ -433,6 +447,21 @@ export function createServer(config: ServerConfig): ServerInstance { const currentWorkflowContext = workflowContext; const category = getToolCategory(toolName); + + if (isBrowserOnlyTool(toolName)) { + const platformDriver = config.sessionManager.getPlatformDriver?.(); + if (platformDriver && platformDriver.getPlatform() !== 'browser') { + res.json({ + ok: false, + error: { + code: 'MM_TOOL_NOT_SUPPORTED_ON_PLATFORM', + message: `Tool "${toolName}" is not supported on ${platformDriver.getPlatform()} platform`, + }, + }); + return; + } + } + const inputRecord = validatedInput as Record; const toolTimeoutMs = inputRecord?.timeoutMs; const queueTimeoutMs = @@ -457,6 +486,29 @@ export function createServer(config: ServerConfig): ServerInstance { try { obs = await Promise.race([ (async (): Promise => { + // Re-read the driver from the session manager rather than + // using context.driver, which was captured before the tool + // ran. The launch tool creates the session (and sets the + // platform driver) during execution, so the context snapshot + // would still be undefined at that point. + const currentDriver = + config.sessionManager.getPlatformDriver?.() ?? + context.driver; + + if ( + currentDriver && + currentDriver.getPlatform() !== 'browser' + ) { + const state = await currentDriver.getAppState(); + const testIds = await currentDriver.getTestIds( + OBSERVATION_TESTID_LIMIT, + ); + const { nodes, refMap: newRefMap } = + await currentDriver.getAccessibilityTree(); + config.sessionManager.setRefMap(newRefMap); + return createDefaultObservation(state, testIds, nodes); + } + const page = config.sessionManager.getPage(); if (category === 'mutating') { @@ -482,8 +534,6 @@ export function createServer(config: ServerConfig): ServerInstance { } let state = await config.sessionManager.getExtensionState(); - // Post-mutation recheck: if currentScreen is 'unknown' after a mutation, - // the extension's internal router may not have updated yet. Poll briefly. if ( category === 'mutating' && state.currentScreen === 'unknown' diff --git a/src/server/session-manager.ts b/src/server/session-manager.ts index 2ac10fe..dae4f62 100644 --- a/src/server/session-manager.ts +++ b/src/server/session-manager.ts @@ -23,6 +23,7 @@ import type { StateSnapshotCapability, ScreenshotResult, } from '../capabilities/types.js'; +import type { IPlatformDriver, PlatformType } from '../platform/types.js'; import type { TabRole, SessionState, SessionMetadata } from '../tools/types'; /** @@ -63,6 +64,10 @@ export type SessionLaunchInput = { }; /** Smart contracts to deploy on launch */ seedContracts?: string[]; + /** Target platform (defaults to 'browser') */ + platform?: PlatformType; + /** Device ID for explicit mobile device targeting */ + deviceId?: string; }; /** @@ -309,4 +314,12 @@ export type ISessionManager = { }; canSwitchContext: boolean; }; + + // ----------------------------------------------------------------------------- + // Platform Driver (Optional — consumers that support mobile implement these) + // ----------------------------------------------------------------------------- + + getPlatformDriver?(): IPlatformDriver | undefined; + + setPlatformDriver?(driver: IPlatformDriver): void; }; diff --git a/src/tools/batch.test.ts b/src/tools/batch.test.ts index 2cfffab..61d59fd 100644 --- a/src/tools/batch.test.ts +++ b/src/tools/batch.test.ts @@ -9,9 +9,14 @@ function createMockContext( options: { hasActive?: boolean; toolRegistry?: Map>; + driverPlatform?: 'browser' | 'ios' | 'android'; } = {}, ): ToolContext { - const { hasActive = true, toolRegistry } = options; + const { hasActive = true, toolRegistry, driverPlatform } = options; + + const driver = driverPlatform + ? ({ getPlatform: () => driverPlatform } as ToolContext['driver']) + : undefined; return { sessionManager: createMockSessionManager({ hasActive }), @@ -20,6 +25,7 @@ function createMockContext( workflowContext: {}, knowledgeStore: {}, toolRegistry, + driver, } as unknown as ToolContext; } @@ -624,4 +630,57 @@ describe('runStepsTool', () => { ); } }); + + it('rejects browser-only tools on non-browser platforms', async () => { + const navigateHandler = vi.fn(); + const clickHandler = vi.fn().mockResolvedValue({ ok: true, result: {} }); + const context = createMockContext({ + toolRegistry: new Map([ + ['navigate', navigateHandler], + ['click', clickHandler], + ]), + driverPlatform: 'ios', + }); + + const result = await runStepsTool( + { + steps: [ + { tool: 'navigate', args: { screen: 'home' } }, + { tool: 'click', args: { testId: 'btn' } }, + ], + }, + context, + ); + + expect(result.ok).toBe(true); + if (result.ok) { + expect(result.result.steps).toHaveLength(2); + expect(result.result.steps[0].ok).toBe(false); + expect(result.result.steps[0].error?.code).toBe( + 'MM_TOOL_NOT_SUPPORTED_ON_PLATFORM', + ); + expect(navigateHandler).not.toHaveBeenCalled(); + expect(result.result.steps[1].ok).toBe(true); + expect(clickHandler).toHaveBeenCalled(); + } + }); + + it('allows browser-only tools on browser platform', async () => { + const navigateHandler = vi.fn().mockResolvedValue({ ok: true, result: {} }); + const context = createMockContext({ + toolRegistry: new Map([['navigate', navigateHandler]]), + driverPlatform: 'browser', + }); + + const result = await runStepsTool( + { steps: [{ tool: 'navigate', args: { screen: 'home' } }] }, + context, + ); + + expect(result.ok).toBe(true); + if (result.ok) { + expect(result.result.steps[0].ok).toBe(true); + expect(navigateHandler).toHaveBeenCalled(); + } + }); }); diff --git a/src/tools/batch.ts b/src/tools/batch.ts index dc88651..aff5f2a 100644 --- a/src/tools/batch.ts +++ b/src/tools/batch.ts @@ -1,3 +1,4 @@ +import { isBrowserOnlyTool } from './registry.js'; import type { RunStepsInput, RunStepsResult, StepResult } from './types'; import { ErrorCodes } from './types'; import { createToolError, createToolSuccess } from './utils.js'; @@ -159,6 +160,31 @@ export async function runStepsTool( continue; } + if (isBrowserOnlyTool(tool)) { + const driverPlatform = context.driver?.getPlatform(); + if (driverPlatform && driverPlatform !== 'browser') { + stepResults.push({ + tool, + ok: false, + error: { + code: 'MM_TOOL_NOT_SUPPORTED_ON_PLATFORM', + message: `Tool "${tool}" is not supported on ${driverPlatform} platform`, + }, + meta: { + durationMs: Date.now() - stepStartTime, + timestamp: new Date().toISOString(), + }, + }); + failed += 1; + + if (stopOnError) { + break; + } + + continue; + } + } + const schema = tool in toolSchemas ? toolSchemas[tool as ToolName] : undefined; let validatedArgs: Record = args; diff --git a/src/tools/discovery-tools.test.ts b/src/tools/discovery-tools.test.ts index 683a7af..d691aa5 100644 --- a/src/tools/discovery-tools.test.ts +++ b/src/tools/discovery-tools.test.ts @@ -19,6 +19,7 @@ import { createMockSessionManager } from './test-utils/mock-factories.js'; import type { A11yNodeTrimmed, TestIdItem } from './types'; import { ErrorCodes } from './types/errors.js'; import * as discoveryModule from './utils/discovery.js'; +import { PlaywrightPlatformDriver } from '../platform/playwright-driver.js'; import type { ToolContext } from '../types/http.js'; function createMockPage(): Page { @@ -33,29 +34,34 @@ function createMockContext( } = {}, ): ToolContext { const { hasActive = true } = options; - - return { - sessionManager: createMockSessionManager({ - hasActive, + const page = createMockPage(); + const sessionManager = createMockSessionManager({ + hasActive, + sessionId: 'test-session-123', + sessionMetadata: { + schemaVersion: 1, sessionId: 'test-session-123', - sessionMetadata: { - schemaVersion: 1, - sessionId: 'test-session-123', - createdAt: '2026-02-04T00:00:00.000Z', - goal: 'Test discovery', - flowTags: ['discovery'], - tags: [], - launch: { - stateMode: 'default', - }, + createdAt: '2026-02-04T00:00:00.000Z', + goal: 'Test discovery', + flowTags: ['discovery'], + tags: [], + launch: { + stateMode: 'default', }, - }), - page: createMockPage(), + }, + }); + + return { + sessionManager, + page, refMap: new Map(), workflowContext: {}, knowledgeStore: { generatePriorKnowledge: vi.fn().mockResolvedValue(undefined), }, + driver: hasActive + ? new PlaywrightPlatformDriver(() => page, sessionManager as any) + : undefined, } as unknown as ToolContext; } @@ -222,23 +228,20 @@ describe('discovery-tools', () => { expect(context.sessionManager.setRefMap).toHaveBeenCalledWith(mockRefMap); }); - it('collects test ids with observation limit', async () => { + it('updates refMap after snapshot', async () => { const context = createMockContext(); + const refMap = new Map([['e1', 'role=button[name="OK"]']]); vi.spyOn(discoveryModule, 'collectTrimmedA11ySnapshot').mockResolvedValue( { - nodes: [], - refMap: new Map(), + nodes: [{ ref: 'e1', role: 'button', name: 'OK', path: [] }], + refMap, }, ); - vi.spyOn(discoveryModule, 'collectTestIds').mockResolvedValue([]); await accessibilitySnapshotTool({}, context); - expect(discoveryModule.collectTestIds).toHaveBeenCalledWith( - context.page, - 50, - ); + expect(context.sessionManager.setRefMap).toHaveBeenCalledWith(refMap); }); it('returns error when no active session', async () => { diff --git a/src/tools/discovery-tools.ts b/src/tools/discovery-tools.ts index 60962fc..d40e00a 100644 --- a/src/tools/discovery-tools.ts +++ b/src/tools/discovery-tools.ts @@ -8,14 +8,8 @@ import type { ListTestIdsResult, PriorKnowledgeContext, } from './types'; -import { - DEFAULT_TESTID_LIMIT, - OBSERVATION_TESTID_LIMIT, -} from './utils/constants.js'; -import { - collectTestIds, - collectTrimmedA11ySnapshot, -} from './utils/discovery.js'; +import { ErrorCodes } from './types'; +import { DEFAULT_TESTID_LIMIT } from './utils/constants.js'; import { createToolError, createToolSuccess, @@ -23,6 +17,28 @@ import { } from './utils.js'; import type { ToolContext, ToolResponse } from '../types/http.js'; +/** + * Validates that the platform driver is available, returning it or an error. + * + * @param context - The tool execution context. + * @returns The driver if available, or an error response. + */ +function requireDriver( + context: ToolContext, +): + | { driver: NonNullable } + | { error: ToolResponse } { + if (!context.driver) { + return { + error: createToolError( + ErrorCodes.MM_NO_ACTIVE_SESSION, + 'No platform driver available', + ), + }; + } + return { driver: context.driver }; +} + /** * Collects visible test IDs from the current page. * @@ -38,15 +54,18 @@ export async function listTestIdsTool( if (missingSession) { return missingSession; } + const driverResult = requireDriver(context); + if ('error' in driverResult) { + return driverResult.error; + } + const { driver } = driverResult; const limit = input.limit ?? DEFAULT_TESTID_LIMIT; try { - const items = await collectTestIds(context.page, limit); - const { refMap } = await collectTrimmedA11ySnapshot(context.page); - + const items = await driver.getTestIds(limit); + const { refMap } = await driver.getAccessibilityTree(); context.sessionManager.setRefMap(refMap); - return createToolSuccess({ items }); } catch (error) { const errorInfo = classifyDiscoveryError(error); @@ -70,16 +89,17 @@ export async function accessibilitySnapshotTool( if (missingSession) { return missingSession; } + const driverResult = requireDriver(context); + if ('error' in driverResult) { + return driverResult.error; + } + const { driver } = driverResult; try { - const { nodes, refMap } = await collectTrimmedA11ySnapshot( - context.page, + const { nodes, refMap } = await driver.getAccessibilityTree( input.rootSelector, ); - context.sessionManager.setRefMap(refMap); - await collectTestIds(context.page, OBSERVATION_TESTID_LIMIT); - return createToolSuccess({ nodes }); } catch (error) { const errorInfo = classifyDiscoveryError(error); @@ -102,28 +122,39 @@ export async function describeScreenTool( if (missingSession) { return missingSession; } + const driverResult = requireDriver(context); + if ('error' in driverResult) { + return driverResult.error; + } + const { driver } = driverResult; try { - const state = await context.sessionManager.getExtensionState(); - const testIds = await collectTestIds(context.page, DEFAULT_TESTID_LIMIT); - const { nodes, refMap } = await collectTrimmedA11ySnapshot(context.page); + const state = await driver.getAppState(); + const testIds = await driver.getTestIds(DEFAULT_TESTID_LIMIT); + const { nodes, refMap } = await driver.getAccessibilityTree(); context.sessionManager.setRefMap(refMap); const trackedPages = context.sessionManager.getTrackedPages(); - const activePage = context.sessionManager.getPage(); - const activeTracked = trackedPages.find((tp) => tp.page === activePage); - const activeTab = activeTracked - ? { role: activeTracked.role, url: activePage.url() } - : undefined; + let activeTab: DescribeScreenResult['activeTab']; + try { + const activePage = context.sessionManager.getPage(); + const activeTracked = trackedPages.find((tp) => tp.page === activePage); + activeTab = activeTracked + ? { role: activeTracked.role, url: activePage.url() } + : undefined; + } catch { + activeTab = undefined; + } let screenshot: DescribeScreenResult['screenshot'] = null; if (input.includeScreenshot) { const screenshotName = input.screenshotName ?? 'describe-screen'; - const result = await context.sessionManager.screenshot({ + const result = await driver.screenshot({ name: screenshotName, fullPage: true, + includeBase64: input.includeScreenshotBase64, }); screenshot = { diff --git a/src/tools/driver-delegation.test.ts b/src/tools/driver-delegation.test.ts new file mode 100644 index 0000000..19ad5e2 --- /dev/null +++ b/src/tools/driver-delegation.test.ts @@ -0,0 +1,580 @@ +import { describe, it, expect, vi, afterEach } from 'vitest'; + +import { + listTestIdsTool, + accessibilitySnapshotTool, + describeScreenTool, +} from './discovery-tools.js'; +import { + clickTool, + getTextTool, + typeTool, + waitForTool, +} from './interaction.js'; +import { screenshotTool } from './screenshot.js'; +import { getStateTool } from './state.js'; +import { createMockSessionManager } from './test-utils/mock-factories.js'; +import { ErrorCodes } from './types/errors.js'; +import type { IPlatformDriver } from '../platform/types.js'; +import type { ToolContext } from '../types/http.js'; + +function createMockDriver( + overrides: Partial = {}, +): IPlatformDriver { + return { + click: vi.fn().mockResolvedValue({ clicked: true, target: 'testId:btn' }), + type: vi.fn().mockResolvedValue({ + typed: true, + target: 'testId:input', + textLength: 5, + }), + waitForElement: vi.fn().mockResolvedValue(undefined), + getText: vi + .fn() + .mockResolvedValue({ text: 'hello', target: 'testId:el', length: 5 }), + getAccessibilityTree: vi.fn().mockResolvedValue({ + nodes: [{ ref: 'e1', role: 'button', name: 'OK', path: [] }], + refMap: new Map([['e1', 'role=button[name="OK"]']]), + }), + getTestIds: vi + .fn() + .mockResolvedValue([{ testId: 'submit', tag: 'button', visible: true }]), + screenshot: vi.fn().mockResolvedValue({ + path: '/tmp/shot.png', + base64: 'abc', + width: 400, + height: 800, + }), + getAppState: vi.fn().mockResolvedValue({ + isLoaded: true, + currentUrl: '', + extensionId: 'test', + isUnlocked: true, + currentScreen: 'home', + accountAddress: null, + networkName: null, + chainId: null, + balance: null, + }), + getCurrentUrl: vi.fn().mockReturnValue(''), + getPlatform: vi.fn().mockReturnValue('ios'), + ...overrides, + }; +} + +function createMockPage() { + return { + url: vi.fn().mockReturnValue('chrome-extension://test/home.html'), + isClosed: vi.fn().mockReturnValue(false), + }; +} + +function createContextWithDriver(driver: IPlatformDriver): ToolContext { + const page = createMockPage(); + const sessionManager = createMockSessionManager({ hasActive: true }); + vi.spyOn(sessionManager as any, 'getPage').mockReturnValue(page); + return { + sessionManager, + page: page as any, + refMap: new Map(), + workflowContext: {} as any, + knowledgeStore: { + generatePriorKnowledge: vi.fn().mockResolvedValue(undefined), + } as any, + toolRegistry: new Map(), + driver, + } as unknown as ToolContext; +} + +function createContextWithoutDriver(): ToolContext { + const page = createMockPage(); + const sessionManager = createMockSessionManager({ hasActive: true }); + vi.spyOn(sessionManager as any, 'getPage').mockReturnValue(page); + return { + sessionManager, + page: page as any, + refMap: new Map(), + workflowContext: {} as any, + knowledgeStore: { + generatePriorKnowledge: vi.fn().mockResolvedValue(undefined), + } as any, + toolRegistry: new Map(), + driver: undefined, + } as unknown as ToolContext; +} + +describe('tool delegation to context.driver', () => { + afterEach(() => { + vi.restoreAllMocks(); + }); + + describe('no driver with active session', () => { + it('clickTool returns error when driver is undefined', async () => { + const context = createContextWithoutDriver(); + const result = await clickTool({ testId: 'btn' }, context); + + expect(result.ok).toBe(false); + if (!result.ok) { + expect(result.error.code).toBe(ErrorCodes.MM_NO_ACTIVE_SESSION); + expect(result.error.message).toContain('platform driver'); + } + }); + + it('typeTool returns error when driver is undefined', async () => { + const context = createContextWithoutDriver(); + const result = await typeTool( + { testId: 'input', text: 'hello' }, + context, + ); + + expect(result.ok).toBe(false); + if (!result.ok) { + expect(result.error.code).toBe(ErrorCodes.MM_NO_ACTIVE_SESSION); + } + }); + + it('waitForTool returns error when driver is undefined', async () => { + const context = createContextWithoutDriver(); + const result = await waitForTool({ testId: 'el' }, context); + + expect(result.ok).toBe(false); + if (!result.ok) { + expect(result.error.code).toBe(ErrorCodes.MM_NO_ACTIVE_SESSION); + } + }); + + it('getTextTool returns error when driver is undefined', async () => { + const context = createContextWithoutDriver(); + const result = await getTextTool({ testId: 'label' }, context); + + expect(result.ok).toBe(false); + if (!result.ok) { + expect(result.error.code).toBe(ErrorCodes.MM_NO_ACTIVE_SESSION); + } + }); + + it('listTestIdsTool returns error when driver is undefined', async () => { + const context = createContextWithoutDriver(); + const result = await listTestIdsTool({}, context); + + expect(result.ok).toBe(false); + if (!result.ok) { + expect(result.error.code).toBe(ErrorCodes.MM_NO_ACTIVE_SESSION); + } + }); + + it('accessibilitySnapshotTool returns error when driver is undefined', async () => { + const context = createContextWithoutDriver(); + const result = await accessibilitySnapshotTool({}, context); + + expect(result.ok).toBe(false); + if (!result.ok) { + expect(result.error.code).toBe(ErrorCodes.MM_NO_ACTIVE_SESSION); + } + }); + + it('describeScreenTool returns error when driver is undefined', async () => { + const context = createContextWithoutDriver(); + const result = await describeScreenTool({}, context); + + expect(result.ok).toBe(false); + if (!result.ok) { + expect(result.error.code).toBe(ErrorCodes.MM_NO_ACTIVE_SESSION); + } + }); + + it('screenshotTool returns error when driver is undefined', async () => { + const context = createContextWithoutDriver(); + const result = await screenshotTool({ name: 'test' }, context); + + expect(result.ok).toBe(false); + if (!result.ok) { + expect(result.error.code).toBe(ErrorCodes.MM_NO_ACTIVE_SESSION); + } + }); + + it('getStateTool returns error when driver is undefined', async () => { + const context = createContextWithoutDriver(); + const result = await getStateTool({} as any, context); + + expect(result.ok).toBe(false); + if (!result.ok) { + expect(result.error.code).toBe(ErrorCodes.MM_NO_ACTIVE_SESSION); + } + }); + }); + + describe('interaction tools', () => { + it('clickTool delegates to driver.click', async () => { + const driver = createMockDriver(); + const context = createContextWithDriver(driver); + + const result = await clickTool({ testId: 'btn' }, context); + + expect(result.ok).toBe(true); + expect(driver.click).toHaveBeenCalledWith( + 'testId', + 'btn', + expect.any(Map), + expect.any(Number), + undefined, + ); + }); + + it('typeTool delegates to driver.type', async () => { + const driver = createMockDriver(); + const context = createContextWithDriver(driver); + + const result = await typeTool( + { testId: 'input', text: 'hello' }, + context, + ); + + expect(result.ok).toBe(true); + expect(driver.type).toHaveBeenCalledWith( + 'testId', + 'input', + 'hello', + expect.any(Map), + expect.any(Number), + undefined, + ); + }); + + it('waitForTool delegates to driver.waitForElement', async () => { + const driver = createMockDriver(); + const context = createContextWithDriver(driver); + + const result = await waitForTool({ testId: 'spinner' }, context); + + expect(result.ok).toBe(true); + expect(driver.waitForElement).toHaveBeenCalledOnce(); + }); + + it('getTextTool delegates to driver.getText', async () => { + const driver = createMockDriver(); + const context = createContextWithDriver(driver); + + const result = await getTextTool({ testId: 'label' }, context); + + expect(result.ok).toBe(true); + expect(driver.getText).toHaveBeenCalledOnce(); + }); + + it('clickTool maps timeout errors to MM_CLICK_TIMEOUT with descriptive message', async () => { + const driver = createMockDriver({ + click: vi + .fn() + .mockRejectedValue( + new Error('Timeout 15000ms exceeded waiting for element'), + ), + }); + const context = createContextWithDriver(driver); + + const result = await clickTool({ testId: 'slow-btn' }, context); + + expect(result.ok).toBe(false); + if (!result.ok) { + expect(result.error.code).toBe(ErrorCodes.MM_CLICK_TIMEOUT); + expect(result.error.message).toContain('Click timed out'); + expect(result.error.message).toContain('describe-screen'); + } + }); + + it('clickTool maps target-not-found errors to MM_TARGET_NOT_FOUND', async () => { + const driver = createMockDriver({ + click: vi.fn().mockRejectedValue(new Error('Unknown a11yRef: e99')), + }); + const context = createContextWithDriver(driver); + + const result = await clickTool({ a11yRef: 'e99' }, context); + + expect(result.ok).toBe(false); + if (!result.ok) { + expect(result.error.code).toBe(ErrorCodes.MM_TARGET_NOT_FOUND); + } + }); + + it('typeTool maps timeout errors to MM_TYPE_TIMEOUT with descriptive message', async () => { + const driver = createMockDriver({ + type: vi + .fn() + .mockRejectedValue( + new Error('Timeout 15000ms exceeded waiting for element'), + ), + }); + const context = createContextWithDriver(driver); + + const result = await typeTool( + { testId: 'slow-input', text: 'hello' }, + context, + ); + + expect(result.ok).toBe(false); + if (!result.ok) { + expect(result.error.code).toBe(ErrorCodes.MM_TYPE_TIMEOUT); + expect(result.error.message).toContain('Type timed out'); + } + }); + + it('typeTool classifies fill errors', async () => { + const driver = createMockDriver({ + type: vi.fn().mockRejectedValue(new Error('Fill failed: detached')), + }); + const context = createContextWithDriver(driver); + + const result = await typeTool( + { testId: 'input', text: 'hello' }, + context, + ); + + expect(result.ok).toBe(false); + if (!result.ok) { + expect(result.error.code).toBe(ErrorCodes.MM_TYPE_FAILED); + } + }); + + it('waitForTool maps timeout errors to MM_WAIT_TIMEOUT', async () => { + const driver = createMockDriver({ + waitForElement: vi + .fn() + .mockRejectedValue( + new Error('Timeout 10000ms exceeded waiting for element'), + ), + }); + const context = createContextWithDriver(driver); + + const result = await waitForTool({ testId: 'spinner' }, context); + + expect(result.ok).toBe(false); + if (!result.ok) { + expect(result.error.code).toBe(ErrorCodes.MM_WAIT_TIMEOUT); + } + }); + + it('getTextTool maps timeout errors to MM_GETTEXT_TIMEOUT with descriptive message', async () => { + const driver = createMockDriver({ + getText: vi + .fn() + .mockRejectedValue( + new Error('Timeout 15000ms exceeded waiting for element'), + ), + }); + const context = createContextWithDriver(driver); + + const result = await getTextTool({ testId: 'label' }, context); + + expect(result.ok).toBe(false); + if (!result.ok) { + expect(result.error.code).toBe(ErrorCodes.MM_GETTEXT_TIMEOUT); + expect(result.error.message).toContain('GetText timed out'); + } + }); + + it('getTextTool classifies getText errors', async () => { + const driver = createMockDriver({ + getText: vi + .fn() + .mockRejectedValue(new Error('textContent failed: detached')), + }); + const context = createContextWithDriver(driver); + + const result = await getTextTool({ testId: 'label' }, context); + + expect(result.ok).toBe(false); + if (!result.ok) { + expect(result.error.code).toBe(ErrorCodes.MM_GETTEXT_FAILED); + } + }); + }); + + describe('discovery tools', () => { + it('listTestIdsTool delegates to driver.getTestIds', async () => { + const driver = createMockDriver(); + const context = createContextWithDriver(driver); + + const result = await listTestIdsTool({}, context); + + expect(result.ok).toBe(true); + expect(driver.getTestIds).toHaveBeenCalledOnce(); + expect(driver.getAccessibilityTree).toHaveBeenCalledOnce(); + }); + + it('accessibilitySnapshotTool delegates to driver.getAccessibilityTree', async () => { + const driver = createMockDriver(); + const context = createContextWithDriver(driver); + + const result = await accessibilitySnapshotTool({}, context); + + expect(result.ok).toBe(true); + expect(driver.getAccessibilityTree).toHaveBeenCalledOnce(); + }); + + it('describeScreenTool delegates to driver for state, testIds, and a11y', async () => { + const driver = createMockDriver(); + const context = createContextWithDriver(driver); + + const result = await describeScreenTool({}, context); + + expect(result.ok).toBe(true); + expect(driver.getAppState).toHaveBeenCalledOnce(); + expect(driver.getTestIds).toHaveBeenCalledOnce(); + expect(driver.getAccessibilityTree).toHaveBeenCalledOnce(); + }); + + it('describeScreenTool handles getPage failure gracefully', async () => { + const driver = createMockDriver(); + const sessionManager = createMockSessionManager({ hasActive: true }); + vi.spyOn(sessionManager as any, 'getPage').mockImplementation(() => { + throw new Error('No page on mobile'); + }); + const context = { + sessionManager, + page: {} as any, + refMap: new Map(), + workflowContext: {} as any, + knowledgeStore: { + generatePriorKnowledge: vi.fn().mockResolvedValue(undefined), + } as any, + toolRegistry: new Map(), + driver, + } as unknown as ToolContext; + + const result = await describeScreenTool({}, context); + + expect(result.ok).toBe(true); + if (result.ok) { + expect(result.result.activeTab).toBeUndefined(); + } + }); + + it('listTestIdsTool returns error when driver throws', async () => { + const driver = createMockDriver({ + getTestIds: vi.fn().mockRejectedValue(new Error('Snapshot failed')), + }); + const context = createContextWithDriver(driver); + + const result = await listTestIdsTool({}, context); + + expect(result.ok).toBe(false); + if (!result.ok) { + expect(result.error.code).toBe(ErrorCodes.MM_DISCOVERY_FAILED); + } + }); + + it('accessibilitySnapshotTool returns error when driver throws', async () => { + const driver = createMockDriver({ + getAccessibilityTree: vi + .fn() + .mockRejectedValue(new Error('A11y snapshot failed')), + }); + const context = createContextWithDriver(driver); + + const result = await accessibilitySnapshotTool({}, context); + + expect(result.ok).toBe(false); + if (!result.ok) { + expect(result.error.code).toBe(ErrorCodes.MM_DISCOVERY_FAILED); + } + }); + }); + + describe('screenshot tool', () => { + it('screenshotTool delegates to driver.screenshot', async () => { + const driver = createMockDriver(); + const context = createContextWithDriver(driver); + + const result = await screenshotTool({ name: 'test' }, context); + + expect(result.ok).toBe(true); + expect(driver.screenshot).toHaveBeenCalledWith( + expect.objectContaining({ + name: 'test', + }), + ); + }); + + it('screenshotTool returns error when driver throws', async () => { + const driver = createMockDriver({ + screenshot: vi + .fn() + .mockRejectedValue(new Error('Screenshot capture failed')), + }); + const context = createContextWithDriver(driver); + + const result = await screenshotTool({ name: 'fail' }, context); + + expect(result.ok).toBe(false); + if (!result.ok) { + expect(result.error.code).toBe(ErrorCodes.MM_SCREENSHOT_FAILED); + } + }); + }); + + describe('state tool', () => { + it('getStateTool delegates to driver.getAppState', async () => { + const driver = createMockDriver(); + const context = createContextWithDriver(driver); + + const result = await getStateTool({} as any, context); + + expect(result.ok).toBe(true); + expect(driver.getAppState).toHaveBeenCalledOnce(); + }); + + it('getStateTool handles getPage failure gracefully', async () => { + const driver = createMockDriver(); + const sessionManager = createMockSessionManager({ hasActive: true }); + vi.spyOn(sessionManager as any, 'getPage').mockImplementation(() => { + throw new Error('No page on mobile'); + }); + const context = { + sessionManager, + page: {} as any, + refMap: new Map(), + workflowContext: {} as any, + knowledgeStore: { + generatePriorKnowledge: vi.fn().mockResolvedValue(undefined), + } as any, + toolRegistry: new Map(), + driver, + } as unknown as ToolContext; + + const result = await getStateTool({} as any, context); + + expect(result.ok).toBe(true); + if (result.ok) { + expect(result.result.tabs?.active.role).toBe('other'); + expect(result.result.tabs?.active.url).toBe(''); + } + }); + + it('getStateTool uses driver.getAppState when no stateSnapshotCapability', async () => { + const driver = createMockDriver({ + getPlatform: vi.fn().mockReturnValue('browser'), + }); + const context = createContextWithDriver(driver); + + const result = await getStateTool({} as any, context); + + expect(result.ok).toBe(true); + expect(driver.getAppState).toHaveBeenCalledOnce(); + }); + + it('getStateTool returns error when driver throws', async () => { + const driver = createMockDriver({ + getAppState: vi + .fn() + .mockRejectedValue(new Error('State retrieval failed')), + }); + const context = createContextWithDriver(driver); + + const result = await getStateTool({} as any, context); + + expect(result.ok).toBe(false); + if (!result.ok) { + expect(result.error.code).toBe(ErrorCodes.MM_STATE_FAILED); + } + }); + }); +}); diff --git a/src/tools/interaction.test.ts b/src/tools/interaction.test.ts index 11a38d7..7dc2d98 100644 --- a/src/tools/interaction.test.ts +++ b/src/tools/interaction.test.ts @@ -17,6 +17,7 @@ import { createMockSessionManager } from './test-utils/mock-factories.js'; import { ErrorCodes } from './types/errors.js'; import * as discoveryModule from './utils/discovery.js'; import * as targetsModule from './utils/targets.js'; +import { PlaywrightPlatformDriver } from '../platform/playwright-driver.js'; import type { ToolContext } from '../types/http.js'; function createMockLocator() { @@ -43,14 +44,20 @@ function createMockContext( refMap?: Map; } = {}, ): ToolContext { + const page = (options.page ?? createMockPage()) as ToolContext['page']; + const sessionManager = createMockSessionManager({ + hasActive: options.hasActive ?? true, + }); + const hasSession = options.hasActive ?? true; return { - sessionManager: createMockSessionManager({ - hasActive: options.hasActive ?? true, - }), - page: (options.page ?? createMockPage()) as ToolContext['page'], + sessionManager, + page, refMap: options.refMap ?? new Map(), workflowContext: {}, knowledgeStore: {}, + driver: hasSession + ? new PlaywrightPlatformDriver(() => page, sessionManager as any) + : undefined, } as unknown as ToolContext; } @@ -324,6 +331,7 @@ describe('interaction', () => { expect(result.ok).toBe(false); if (!result.ok) { expect(result.error.code).toBe(ErrorCodes.MM_CLICK_TIMEOUT); + expect(result.error.message).toContain('describe-screen'); } }); @@ -350,7 +358,6 @@ describe('interaction', () => { const nowSpy = vi.spyOn(Date, 'now'); nowSpy.mockReturnValueOnce(1000); nowSpy.mockReturnValueOnce(1006); - nowSpy.mockReturnValueOnce(1006); const result = await clickTool( { testId: 'slow-button', timeoutMs: 5 }, @@ -675,7 +682,6 @@ describe('interaction', () => { const nowSpy = vi.spyOn(Date, 'now'); nowSpy.mockReturnValueOnce(2000); nowSpy.mockReturnValueOnce(2006); - nowSpy.mockReturnValueOnce(2006); const result = await typeTool( { testId: 'slow-input', text: 'value', timeoutMs: 5 }, @@ -1147,7 +1153,6 @@ describe('interaction', () => { const nowSpy = vi.spyOn(Date, 'now'); nowSpy.mockReturnValueOnce(3000); nowSpy.mockReturnValueOnce(3006); - nowSpy.mockReturnValueOnce(3006); const result = await getTextTool( { testId: 'slow-text', timeoutMs: 5 }, diff --git a/src/tools/interaction.ts b/src/tools/interaction.ts index e182a53..cb2fbc9 100644 --- a/src/tools/interaction.ts +++ b/src/tools/interaction.ts @@ -1,11 +1,8 @@ -import type { Locator } from '@playwright/test'; - import { classifyClickError, classifyGetTextError, classifyTypeError, classifyWaitError, - isPageClosedError, } from './error-classification.js'; import type { ClickInput, @@ -20,8 +17,7 @@ import type { } from './types'; import { ErrorCodes } from './types'; import { DEFAULT_INTERACTION_TIMEOUT_MS } from './utils/constants.js'; -import { waitForTarget } from './utils/discovery.js'; -import type { TargetType, WithinScope } from './utils/discovery.js'; +import type { TargetType } from './utils/discovery.js'; import { validateTargetSelection } from './utils/targets.js'; import { isInvalidTargetSelection, @@ -32,25 +28,21 @@ import { createToolSuccess, requireActiveSession, } from './utils.js'; +import type { WithinScope } from '../platform/types.js'; import type { ToolContext, ToolResponse } from '../types/http.js'; -/** - * Checks whether the given error is a Playwright timeout error. - * - * @param error - The error to inspect. - * @returns True if the error represents an action timeout. - */ -function isActionTimeoutError(error: unknown): boolean { - return error instanceof Error && error.name === 'TimeoutError'; -} - type ValidatedTarget = { targetType: TargetType; targetValue: string; }; +type ValidatedInteraction = { + target: ValidatedTarget; + driver: NonNullable; +}; + /** - * Validates session and target selection for interaction tools. + * Validates session, driver, and target selection for interaction tools. * Returns an error response if validation fails, or the resolved target. * * @param input - The tool input with target selection fields. @@ -60,12 +52,21 @@ type ValidatedTarget = { function validateInteraction( input: ClickInput | TypeInput | WaitForInput | GetTextInput, context: ToolContext, -): { error: ToolResponse } | { target: ValidatedTarget } { +): { error: ToolResponse } | ValidatedInteraction { const missingSession = requireActiveSession(context); if (missingSession) { return { error: missingSession }; } + if (!context.driver) { + return { + error: createToolError( + ErrorCodes.MM_NO_ACTIVE_SESSION, + 'No platform driver available', + ), + }; + } + const validation = validateTargetSelection(input); if (isInvalidTargetSelection(validation)) { @@ -88,128 +89,12 @@ function validateInteraction( targetType: validation.type, targetValue: validation.value, }, + driver: context.driver, }; } -type InteractionErrorInfo = { - code: string; - message: string; -}; - -type RunInteractionWithTimeoutOptions = { - context: ToolContext; - timeoutMs: number; - within?: WithinTarget; - targetType: TargetType; - targetValue: string; - timeoutErrorCode: string; - classifyError: (error: unknown) => InteractionErrorInfo; - action: (locator: Locator, timeout: number) => Promise; - createSuccessResult: (result: TResult) => ToolResponse; - formatTimeoutMessage: ( - phase: 'deadline' | 'action', - elapsedMs: number, - ) => string; - handleActionError?: ( - error: unknown, - locator: Locator, - ) => ToolResponse | undefined; -}; - -/** - * Runs an element interaction within a deadline-based timeout. - * - * @param options - The interaction configuration object. - * @param options.context - The tool execution context with session and page. - * @param options.timeoutMs - Maximum time in milliseconds for the interaction. - * @param options.within - Optional parent scope to restrict element search. - * @param options.targetType - The type of target identifier (a11yRef, testId, selector). - * @param options.targetValue - The target value used for element lookup. - * @param options.timeoutErrorCode - The error code to use when the interaction times out. - * @param options.classifyError - Classifies a caught error into a code and message. - * @param options.action - The interaction to perform on the resolved locator. - * @param options.createSuccessResult - Creates the tool response from a successful result. - * @param options.formatTimeoutMessage - Formats the timeout error message for a given phase. - * @param options.handleActionError - Optional handler for action errors before fallback. - * @returns The tool response for the interaction outcome. - */ -async function runInteractionWithTimeout({ - context, - timeoutMs, - within, - targetType, - targetValue, - timeoutErrorCode, - classifyError, - action, - createSuccessResult, - formatTimeoutMessage, - handleActionError, -}: RunInteractionWithTimeoutOptions): Promise> { - const startTime = Date.now(); - const deadline = startTime + timeoutMs; - const withinScope = resolveWithinScope(within); - let locator: Locator | undefined; - - try { - locator = await waitForTarget( - context.page, - targetType, - targetValue, - context.refMap, - timeoutMs, - withinScope, - ); - } catch (error) { - const errorInfo = classifyError(error); - if (errorInfo.code === ErrorCodes.MM_WAIT_TIMEOUT) { - return createToolError(timeoutErrorCode, errorInfo.message); - } - return createToolError(errorInfo.code, errorInfo.message); - } - - const remaining = deadline - Date.now(); - - if (remaining <= 0) { - const elapsedMs = Date.now() - startTime; - return createToolError( - timeoutErrorCode, - formatTimeoutMessage('deadline', elapsedMs), - ); - } - - try { - const result = await action(locator, remaining); - return createSuccessResult(result); - } catch (actionError) { - const handledError = handleActionError?.(actionError, locator); - if (handledError) { - return handledError; - } - - // Page-closed errors can surface with name='TimeoutError' due to - // Playwright race conditions. Classify them before the timeout - // check so they are never misreported as action timeouts. - if (isPageClosedError(actionError)) { - const errorInfo = classifyError(actionError); - return createToolError(errorInfo.code, errorInfo.message); - } - - if (isActionTimeoutError(actionError)) { - const elapsedMs = Date.now() - startTime; - return createToolError( - timeoutErrorCode, - formatTimeoutMessage('action', elapsedMs), - ); - } - - const errorInfo = classifyError(actionError); - return createToolError(errorInfo.code, errorInfo.message); - } -} - /** - * Converts a WithinTarget input to the WithinScope format expected by waitForTarget. + * Converts a WithinTarget input to the WithinScope format expected by the platform driver. * * @param within - The optional within target from tool input. * @returns The resolved scope, or undefined if no within target is provided. @@ -250,38 +135,26 @@ export async function clickTool( const timeoutMs = input.timeoutMs ?? DEFAULT_INTERACTION_TIMEOUT_MS; const { targetType, targetValue } = validated.target; - return runInteractionWithTimeout({ - context, - timeoutMs, - within: input.within, - targetType, - targetValue, - timeoutErrorCode: ErrorCodes.MM_CLICK_TIMEOUT, - classifyError: classifyClickError, - action: async (locator, timeout) => { - await locator.click({ timeout }); - return { - clicked: true, - target: `${targetType}:${targetValue}`, - }; - }, - createSuccessResult: createToolSuccess, - formatTimeoutMessage: (phase, elapsedMs) => - phase === 'deadline' - ? `Click timed out after ${elapsedMs}ms. Note: the click action may have completed in the background after this timeout. Run describe-screen to verify current page state before retrying.` - : `Click action timed out after ${elapsedMs}ms. Note: the click action may have completed in the background after this timeout. Run describe-screen to verify current page state before retrying.`, - handleActionError: (error) => { - if (!isPageClosedError(error)) { - return undefined; - } - return createToolSuccess({ - clicked: true, - target: `${targetType}:${targetValue}`, - pageClosedAfterClick: true, - }); - }, - }); + try { + const result = await validated.driver.click( + targetType, + targetValue, + context.refMap, + timeoutMs, + resolveWithinScope(input.within), + ); + return createToolSuccess(result); + } catch (error) { + const errorInfo = classifyClickError(error); + if (errorInfo.code === ErrorCodes.MM_WAIT_TIMEOUT) { + return createToolError( + ErrorCodes.MM_CLICK_TIMEOUT, + `Click timed out after ${timeoutMs}ms. Note: the click action may have completed in the background after this timeout. Run describe-screen to verify current page state before retrying.`, + ); + } + return createToolError(errorInfo.code, errorInfo.message); + } } /** @@ -302,28 +175,27 @@ export async function typeTool( const timeoutMs = input.timeoutMs ?? DEFAULT_INTERACTION_TIMEOUT_MS; const { targetType, targetValue } = validated.target; - return runInteractionWithTimeout({ - context, - timeoutMs, - within: input.within, - targetType, - targetValue, - timeoutErrorCode: ErrorCodes.MM_TYPE_TIMEOUT, - classifyError: classifyTypeError, - action: async (locator, timeout) => { - await locator.fill(input.text, { timeout }); - return { - typed: true, - target: `${targetType}:${targetValue}`, - textLength: input.text.length, - }; - }, - createSuccessResult: createToolSuccess, - formatTimeoutMessage: (phase, elapsedMs) => - phase === 'deadline' - ? `Type timed out after ${elapsedMs}ms.` - : `Type action timed out after ${elapsedMs}ms.`, - }); + + try { + const result = await validated.driver.type( + targetType, + targetValue, + input.text, + context.refMap, + timeoutMs, + resolveWithinScope(input.within), + ); + return createToolSuccess(result); + } catch (error) { + const errorInfo = classifyTypeError(error); + if (errorInfo.code === ErrorCodes.MM_WAIT_TIMEOUT) { + return createToolError( + ErrorCodes.MM_TYPE_TIMEOUT, + `Type timed out after ${timeoutMs}ms.`, + ); + } + return createToolError(errorInfo.code, errorInfo.message); + } } /** @@ -346,15 +218,13 @@ export async function waitForTool( const { targetType, targetValue } = validated.target; try { - await waitForTarget( - context.page, + await validated.driver.waitForElement( targetType, targetValue, context.refMap, timeoutMs, resolveWithinScope(input.within), ); - return createToolSuccess({ found: true, target: `${targetType}:${targetValue}`, @@ -383,26 +253,24 @@ export async function getTextTool( const timeoutMs = input.timeoutMs ?? DEFAULT_INTERACTION_TIMEOUT_MS; const { targetType, targetValue } = validated.target; - return runInteractionWithTimeout({ - context, - timeoutMs, - within: input.within, - targetType, - targetValue, - timeoutErrorCode: ErrorCodes.MM_GETTEXT_TIMEOUT, - classifyError: classifyGetTextError, - action: async (locator, timeout) => { - const text = (await locator.textContent({ timeout })) ?? ''; - return { - text, - target: `${targetType}:${targetValue}`, - length: text.length, - }; - }, - createSuccessResult: createToolSuccess, - formatTimeoutMessage: (phase, elapsedMs) => - phase === 'deadline' - ? `GetText timed out after ${elapsedMs}ms.` - : `GetText action timed out after ${elapsedMs}ms.`, - }); + + try { + const result = await validated.driver.getText( + targetType, + targetValue, + context.refMap, + timeoutMs, + resolveWithinScope(input.within), + ); + return createToolSuccess(result); + } catch (error) { + const errorInfo = classifyGetTextError(error); + if (errorInfo.code === ErrorCodes.MM_WAIT_TIMEOUT) { + return createToolError( + ErrorCodes.MM_GETTEXT_TIMEOUT, + `GetText timed out after ${timeoutMs}ms.`, + ); + } + return createToolError(errorInfo.code, errorInfo.message); + } } diff --git a/src/tools/platform-gating.test.ts b/src/tools/platform-gating.test.ts new file mode 100644 index 0000000..a2ff076 --- /dev/null +++ b/src/tools/platform-gating.test.ts @@ -0,0 +1,32 @@ +import { describe, it, expect } from 'vitest'; + +import { isBrowserOnlyTool } from './registry.js'; + +describe('isBrowserOnlyTool', () => { + it('returns true for browser-only tools', () => { + expect(isBrowserOnlyTool('navigate')).toBe(true); + expect(isBrowserOnlyTool('switch_to_tab')).toBe(true); + expect(isBrowserOnlyTool('close_tab')).toBe(true); + expect(isBrowserOnlyTool('wait_for_notification')).toBe(true); + expect(isBrowserOnlyTool('cdp')).toBe(true); + }); + + it('returns false for cross-platform tools', () => { + expect(isBrowserOnlyTool('click')).toBe(false); + expect(isBrowserOnlyTool('type')).toBe(false); + expect(isBrowserOnlyTool('wait_for')).toBe(false); + expect(isBrowserOnlyTool('describe_screen')).toBe(false); + expect(isBrowserOnlyTool('screenshot')).toBe(false); + expect(isBrowserOnlyTool('get_state')).toBe(false); + expect(isBrowserOnlyTool('launch')).toBe(false); + expect(isBrowserOnlyTool('cleanup')).toBe(false); + expect(isBrowserOnlyTool('list_testids')).toBe(false); + expect(isBrowserOnlyTool('accessibility_snapshot')).toBe(false); + expect(isBrowserOnlyTool('get_text')).toBe(false); + expect(isBrowserOnlyTool('clipboard')).toBe(false); + }); + + it('returns false for unknown tools', () => { + expect(isBrowserOnlyTool('nonexistent_tool')).toBe(false); + }); +}); diff --git a/src/tools/registry.ts b/src/tools/registry.ts index f4c0fd8..4ef04a4 100644 --- a/src/tools/registry.ts +++ b/src/tools/registry.ts @@ -129,3 +129,21 @@ export const TOOL_CATEGORIES: Record = { export function getToolCategory(toolName: string): ToolCategory { return TOOL_CATEGORIES[toolName] ?? 'mutating'; } + +const BROWSER_ONLY_TOOLS = new Set([ + 'navigate', + 'switch_to_tab', + 'close_tab', + 'wait_for_notification', + 'cdp', +]); + +/** + * Checks if a tool is only available on the browser platform. + * + * @param toolName - The registered tool name to check. + * @returns True if the tool is browser-only, false if cross-platform. + */ +export function isBrowserOnlyTool(toolName: string): boolean { + return BROWSER_ONLY_TOOLS.has(toolName); +} diff --git a/src/tools/screenshot.test.ts b/src/tools/screenshot.test.ts index 1b2ee2e..583a582 100644 --- a/src/tools/screenshot.test.ts +++ b/src/tools/screenshot.test.ts @@ -10,6 +10,7 @@ import { describe, it, expect, vi } from 'vitest'; import { screenshotTool } from './screenshot.js'; import { createMockSessionManager } from './test-utils/mock-factories.js'; import { ErrorCodes } from './types/errors.js'; +import { PlaywrightPlatformDriver } from '../platform/playwright-driver.js'; import type { ToolContext } from '../types/http.js'; function createMockContext( @@ -18,13 +19,18 @@ function createMockContext( } = {}, ): ToolContext { const { hasActive = true } = options; + const page = {} as ToolContext['page']; + const sessionManager = createMockSessionManager({ hasActive }); return { - sessionManager: createMockSessionManager({ hasActive }), - page: {} as ToolContext['page'], + sessionManager, + page, refMap: new Map(), workflowContext: {}, knowledgeStore: {}, + driver: hasActive + ? new PlaywrightPlatformDriver(() => page, sessionManager as any) + : undefined, } as unknown as ToolContext; } diff --git a/src/tools/screenshot.ts b/src/tools/screenshot.ts index 5a842c4..a1af873 100644 --- a/src/tools/screenshot.ts +++ b/src/tools/screenshot.ts @@ -1,5 +1,6 @@ import { classifyScreenshotError } from './error-classification.js'; import type { ScreenshotInput, ScreenshotToolResult } from './types'; +import { ErrorCodes } from './types'; import { createToolError, createToolSuccess, @@ -23,12 +24,20 @@ export async function screenshotTool( return missingSession; } + if (!context.driver) { + return createToolError( + ErrorCodes.MM_NO_ACTIVE_SESSION, + 'No platform driver available', + ); + } + try { const screenshotName = input.name ?? `screenshot-${Date.now()}`; - const result = await context.sessionManager.screenshot({ + const result = await context.driver.screenshot({ name: screenshotName, fullPage: input.fullPage ?? true, selector: input.selector, + includeBase64: input.includeBase64, }); const response: ScreenshotToolResult = { diff --git a/src/tools/state.test.ts b/src/tools/state.test.ts index 3969f5a..2b1da1a 100644 --- a/src/tools/state.test.ts +++ b/src/tools/state.test.ts @@ -12,6 +12,7 @@ import type { StateSnapshotCapability } from '../capabilities/types.js'; import { createMockSessionManager } from './test-utils/mock-factories.js'; import type { MockSessionManagerOptions } from './test-utils/mock-factories.js'; import { ErrorCodes } from './types/errors.js'; +import { PlaywrightPlatformDriver } from '../platform/playwright-driver.js'; import type { ToolContext } from '../types/http.js'; function createMockPage(url = 'chrome-extension://ext-123/home.html') { @@ -36,12 +37,21 @@ function createMockContext( options.stateSnapshotCapability, ); + const activePage = options.page ?? page; + const hasSession = options.hasActive ?? false; + return { sessionManager, - page: options.page ?? page, + page: activePage, refMap: new Map(), workflowContext: {}, knowledgeStore: {}, + driver: hasSession + ? new PlaywrightPlatformDriver( + () => activePage as any, + sessionManager as any, + ) + : undefined, } as unknown as ToolContext & { sessionManager: ReturnType; }; diff --git a/src/tools/state.ts b/src/tools/state.ts index c974cce..54798fb 100644 --- a/src/tools/state.ts +++ b/src/tools/state.ts @@ -1,7 +1,8 @@ import type { Page } from '@playwright/test'; import { classifyStateError } from './error-classification.js'; -import type { GetStateResult } from './types'; +import type { GetStateResult, TabInfo } from './types'; +import { ErrorCodes } from './types'; import { createToolError, createToolSuccess, @@ -15,14 +16,14 @@ import type { ISessionManager } from '../server/session-manager.js'; import type { ToolContext, ToolResponse } from '../types/http.js'; /** - * Retrieves the extension state using the snapshot capability or session manager. + * Retrieves the extension state using the snapshot capability or driver. * * @param page - The active Playwright page. * @param sessionManager - The session manager instance. * @param stateSnapshotCapability - Optional capability for direct state snapshots. * @returns The current extension state. */ -async function getState( +async function getStateWithCapability( page: Page, sessionManager: ISessionManager, stateSnapshotCapability?: StateSnapshotCapability, @@ -54,27 +55,46 @@ export async function getStateTool( return missingSession; } + if (!context.driver) { + return createToolError( + ErrorCodes.MM_NO_ACTIVE_SESSION, + 'No platform driver available', + ); + } + try { - const state = await getState( - context.page, - context.sessionManager, + const stateSnapshotCapability = context.workflowContext.stateSnapshot ?? - context.sessionManager.getStateSnapshotCapability(), - ); + context.sessionManager.getStateSnapshotCapability?.(); + + const state = + stateSnapshotCapability && context.driver.getPlatform() === 'browser' + ? await getStateWithCapability( + context.page, + context.sessionManager, + stateSnapshotCapability, + ) + : await context.driver.getAppState(); const trackedPages = context.sessionManager.getTrackedPages(); - const activePage = context.sessionManager.getPage(); - const activeTabInfo = trackedPages.find( - (trackedPage) => trackedPage.page === activePage, - ); + let activeTab: TabInfo; + try { + const activePage = context.sessionManager.getPage(); + const tracked = trackedPages.find( + (trackedPage) => trackedPage.page === activePage, + ); + activeTab = { + role: tracked?.role ?? 'other', + url: activePage.url(), + }; + } catch { + activeTab = { role: 'other', url: '' }; + } return createToolSuccess({ state, tabs: { - active: { - role: activeTabInfo?.role ?? 'other', - url: activePage.url(), - }, + active: activeTab, tracked: trackedPages.map((trackedPage) => ({ role: trackedPage.role, url: trackedPage.url, diff --git a/src/tools/types/tool-inputs.ts b/src/tools/types/tool-inputs.ts index 52c8293..cd04775 100644 --- a/src/tools/types/tool-inputs.ts +++ b/src/tools/types/tool-inputs.ts @@ -24,6 +24,10 @@ export type LaunchInput = { tags?: string[]; seedContracts?: SmartContractName[]; force?: boolean; + /** Target platform: browser (default), ios, or android. */ + platform?: 'browser' | 'ios' | 'android'; + /** Explicit device ID for mobile platforms (iOS UDID or Android serial). */ + deviceId?: string; }; export type CleanupInput = { diff --git a/src/types/http.ts b/src/types/http.ts index d1cac1c..6daf1ed 100644 --- a/src/types/http.ts +++ b/src/types/http.ts @@ -8,6 +8,7 @@ import type { Page } from '@playwright/test'; import type { PortMap, WorkflowContext } from '../capabilities/context.js'; import type { KnowledgeStore } from '../knowledge-store/knowledge-store.js'; +import type { IPlatformDriver } from '../platform/types.js'; import type { ISessionManager } from '../server/session-manager.js'; /** @@ -29,6 +30,8 @@ export type ToolContext = { knowledgeStore: KnowledgeStore; /** Tool registry for batch execution (run_steps) */ toolRegistry: Map>; + /** Platform driver for cross-platform element interaction (browser, iOS, Android) */ + driver?: IPlatformDriver; }; /** diff --git a/src/validation/schemas.test.ts b/src/validation/schemas.test.ts index 07f2887..fa07377 100644 --- a/src/validation/schemas.test.ts +++ b/src/validation/schemas.test.ts @@ -16,6 +16,7 @@ import { navigateInputSchema, networkMockRouteRuleSchema, mockNetworkInputSchema, + launchInputSchema, } from './schemas.js'; describe('switchToTabInputSchema', () => { @@ -399,3 +400,66 @@ describe('network mock schemas', () => { expect(result.success).toBe(false); }); }); + +describe('launchInputSchema', () => { + it('preserves platform field', () => { + const input = { platform: 'ios' }; + const result = launchInputSchema.safeParse(input); + + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.platform).toBe('ios'); + } + }); + + it('preserves deviceId field', () => { + const input = { deviceId: 'emulator-5554' }; + const result = launchInputSchema.safeParse(input); + + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.deviceId).toBe('emulator-5554'); + } + }); + + it('preserves platform and deviceId together', () => { + const input = { + platform: 'android' as const, + deviceId: 'emulator-5554', + stateMode: 'default' as const, + }; + const result = launchInputSchema.safeParse(input); + + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.platform).toBe('android'); + expect(result.data.deviceId).toBe('emulator-5554'); + expect(result.data.stateMode).toBe('default'); + } + }); + + it('rejects invalid platform value', () => { + const input = { platform: 'windows' }; + const result = launchInputSchema.safeParse(input); + + expect(result.success).toBe(false); + }); + + it('rejects empty deviceId', () => { + const input = { deviceId: '' }; + const result = launchInputSchema.safeParse(input); + + expect(result.success).toBe(false); + }); + + it('accepts launch input without platform or deviceId', () => { + const input = { stateMode: 'default' as const }; + const result = launchInputSchema.safeParse(input); + + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.platform).toBeUndefined(); + expect(result.data.deviceId).toBeUndefined(); + } + }); +}); diff --git a/src/validation/schemas.ts b/src/validation/schemas.ts index 3909848..f28446d 100644 --- a/src/validation/schemas.ts +++ b/src/validation/schemas.ts @@ -226,6 +226,21 @@ export const launchInputSchema = z.object({ .boolean() .default(false) .describe('Force replace an existing active session (runs cleanup first)'), + platform: z + .enum(['browser', 'ios', 'android']) + .describe( + 'Target platform: browser (default), ios, or android. ' + + 'Mobile platforms require @metamask/device-mcp.', + ) + .optional(), + deviceId: z + .string() + .min(1) + .describe( + 'Explicit device ID for mobile platforms (iOS UDID or Android serial). ' + + 'When omitted, auto-detects if exactly one device is connected.', + ) + .optional(), }); export const cleanupInputSchema = z.object({ diff --git a/vitest.config.mts b/vitest.config.mts index 218bebf..0cf43b2 100644 --- a/vitest.config.mts +++ b/vitest.config.mts @@ -35,10 +35,10 @@ export default defineConfig({ // Auto-update the coverage thresholds when running locally. // Disabled in CI to prevent non-deterministic config changes. autoUpdate: !process.env.CI, - branches: 89.5, - functions: 92.47, - lines: 95.55, - statements: 95.25, + branches: 89.52, + functions: 92.3, + lines: 95.5, + statements: 95.21, }, },