From 9818e02ed1a9f0786c5c5ed083e54e90d8121c07 Mon Sep 17 00:00:00 2001 From: kurihada Date: Thu, 11 Dec 2025 22:26:43 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=20AST=20RepoMap=20?= =?UTF-8?q?=E4=BB=A3=E7=A0=81=E4=BB=93=E5=BA=93=E5=9C=B0=E5=9B=BE=E5=8A=9F?= =?UTF-8?q?=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 实现基于 Tree-sitter 的代码符号提取 (支持 TS/JS/Python) - 实现 PageRank 算法进行符号相关性排序 - 支持个性化权重调整 (提及的标识符、聊天文件等) - 添加磁盘缓存避免重复解析 - 集成 repo_map 工具到工具系统 - 添加 15 个单元测试 --- .gitignore | 6 + src/repomap/cache/disk-cache.ts | 218 +++++++++ src/repomap/cache/index.ts | 6 + src/repomap/index.ts | 26 ++ src/repomap/ranking/graph.ts | 99 ++++ src/repomap/ranking/index.ts | 7 + src/repomap/ranking/pagerank.ts | 146 ++++++ src/repomap/repomap.ts | 419 +++++++++++++++++ src/repomap/tags/extractor.ts | 465 +++++++++++++++++++ src/repomap/tags/index.ts | 5 + src/repomap/tags/queries/javascript-tags.scm | 65 +++ src/repomap/tags/queries/python-tags.scm | 37 ++ src/repomap/tags/queries/typescript-tags.scm | 90 ++++ src/repomap/types.ts | 142 ++++++ src/tools/descriptions/repo_map.txt | 20 + src/tools/index.ts | 6 + src/tools/repomap/index.ts | 5 + src/tools/repomap/repo_map.ts | 261 +++++++++++ tests/repomap/repomap.test.ts | 329 +++++++++++++ 19 files changed, 2352 insertions(+) create mode 100644 src/repomap/cache/disk-cache.ts create mode 100644 src/repomap/cache/index.ts create mode 100644 src/repomap/index.ts create mode 100644 src/repomap/ranking/graph.ts create mode 100644 src/repomap/ranking/index.ts create mode 100644 src/repomap/ranking/pagerank.ts create mode 100644 src/repomap/repomap.ts create mode 100644 src/repomap/tags/extractor.ts create mode 100644 src/repomap/tags/index.ts create mode 100644 src/repomap/tags/queries/javascript-tags.scm create mode 100644 src/repomap/tags/queries/python-tags.scm create mode 100644 src/repomap/tags/queries/typescript-tags.scm create mode 100644 src/repomap/types.ts 
/**
 * Envelope written to disk for every cached value.
 */
export interface CacheEntry<T = unknown> {
  /** Original (un-hashed) cache key; compared on read to reject hash collisions. */
  key: string;
  value: T;
  /** Write time in epoch milliseconds (`Date.now()`). */
  timestamp: number;
}

/** Runtime check that parsed JSON has the shape of a cache envelope. */
function isCacheEntry(raw: unknown): raw is CacheEntry<unknown> {
  return (
    typeof raw === 'object' &&
    raw !== null &&
    typeof (raw as { key?: unknown }).key === 'string' &&
    'value' in raw
  );
}

/**
 * Disk-backed cache: one JSON file per key, fronted by an in-memory map.
 *
 * All disk access is best-effort: read/write failures are reported as a
 * cache miss (or a console warning on write) and never thrown to the caller.
 */
export class DiskCache<T = unknown> {
  private readonly cacheDir: string;
  private readonly memoryCache = new Map<string, T>();
  /** Keys whose in-memory value has not been written to disk yet. */
  private readonly dirty = new Set<string>();
  private initialized = false;

  constructor(cacheDir: string) {
    this.cacheDir = cacheDir;
  }

  /** Creates the cache directory once; a failure is logged and retried on the next call. */
  private async ensureDir(): Promise<void> {
    if (this.initialized) return;

    try {
      await fs.mkdir(this.cacheDir, { recursive: true });
      this.initialized = true;
    } catch (error) {
      console.warn(`Failed to create cache directory: ${this.cacheDir}`, error);
    }
  }

  /** Maps a key to its file path; the key is hashed so it is always a safe, short file name. */
  private getCacheFilePath(key: string): string {
    const hash = createHash('md5').update(key).digest('hex');
    return path.join(this.cacheDir, `${hash}.json`);
  }

  /** Reads and validates one cache file; returns null when it is missing or malformed. */
  private async readEntry(filePath: string): Promise<CacheEntry<T> | null> {
    try {
      const parsed: unknown = JSON.parse(await fs.readFile(filePath, 'utf-8'));
      return isCacheEntry(parsed) ? (parsed as CacheEntry<T>) : null;
    } catch {
      // Missing file or invalid JSON: treated as a cache miss.
      return null;
    }
  }

  /**
   * Looks a key up in memory first, then on disk.
   * @returns the cached value, or null on a miss.
   */
  async get(key: string): Promise<T | null> {
    if (this.memoryCache.has(key)) {
      return this.memoryCache.get(key) as T;
    }

    await this.ensureDir();

    const entry = await this.readEntry(this.getCacheFilePath(key));
    // The stored key must match: two different keys could hash to the same file.
    if (entry !== null && entry.key === key) {
      this.memoryCache.set(key, entry.value);
      return entry.value;
    }

    return null;
  }

  /** Stores a value in memory and writes it through to disk immediately. */
  async set(key: string, value: T): Promise<void> {
    this.memoryCache.set(key, value);
    this.dirty.add(key);
    await this.flush(key);
  }

  /** Removes a key from memory and deletes its file if present. */
  async delete(key: string): Promise<void> {
    this.memoryCache.delete(key);
    this.dirty.delete(key);

    await this.ensureDir();

    try {
      await fs.unlink(this.getCacheFilePath(key));
    } catch {
      // File did not exist: nothing to delete.
    }
  }

  /**
   * Reports whether a key is cached.
   * A value of literal `null` that lives only on disk reads as absent,
   * because `get` uses null to signal a miss.
   */
  async has(key: string): Promise<boolean> {
    if (this.memoryCache.has(key)) {
      return true;
    }
    return (await this.get(key)) !== null;
  }

  /** Writes one dirty key to disk; on failure the key stays dirty for a later `flushAll`. */
  private async flush(key: string): Promise<void> {
    if (!this.dirty.has(key)) return;

    await this.ensureDir();

    const value = this.memoryCache.get(key);
    if (value === undefined) return;

    const entry: CacheEntry<T> = { key, value, timestamp: Date.now() };

    try {
      await fs.writeFile(this.getCacheFilePath(key), JSON.stringify(entry), 'utf-8');
      this.dirty.delete(key);
    } catch (error) {
      console.warn(`Failed to write cache: ${key}`, error);
    }
  }

  /** Writes every dirty key to disk. */
  async flushAll(): Promise<void> {
    await Promise.all(Array.from(this.dirty, (key) => this.flush(key)));
  }

  /** Drops all in-memory entries and deletes every `.json` file in the cache directory. */
  async clear(): Promise<void> {
    this.memoryCache.clear();
    this.dirty.clear();

    try {
      const files = await fs.readdir(this.cacheDir);
      await Promise.all(
        files
          .filter((f) => f.endsWith('.json'))
          .map((f) => fs.unlink(path.join(this.cacheDir, f)))
      );
    } catch {
      // Directory missing or unreadable: nothing to clear.
    }
  }

  /** Number of entries currently held in memory (not the number of files on disk). */
  size(): number {
    return this.memoryCache.size;
  }

  /**
   * Lists every known key: in-memory keys first, then keys found on disk.
   * Files that are not valid cache envelopes are skipped.
   */
  async keys(): Promise<string[]> {
    await this.ensureDir();

    // A Set dedupes in O(1) per key and preserves insertion order.
    const seen = new Set<string>(this.memoryCache.keys());

    try {
      const files = await fs.readdir(this.cacheDir);
      for (const file of files) {
        if (!file.endsWith('.json')) continue;
        const entry = await this.readEntry(path.join(this.cacheDir, file));
        if (entry !== null) seen.add(entry.key);
      }
    } catch {
      // Directory missing: only the in-memory keys are returned.
    }

    return Array.from(seen);
  }
}

/**
 * Creates a disk cache rooted at `cacheDir`.
 */
export function createDiskCache<T = unknown>(cacheDir: string): DiskCache<T> {
  return new DiskCache<T>(cacheDir);
}
Map = new Map(); + /** 所有节点 */ + private nodes: Set = new Set(); + + /** + * 添加边 + */ + addEdge(edge: GraphEdge): void { + this.nodes.add(edge.from); + this.nodes.add(edge.to); + + // 出边 + if (!this.outEdges.has(edge.from)) { + this.outEdges.set(edge.from, []); + } + this.outEdges.get(edge.from)!.push(edge); + + // 入边 + if (!this.inEdges.has(edge.to)) { + this.inEdges.set(edge.to, []); + } + this.inEdges.get(edge.to)!.push(edge); + } + + /** + * 获取所有节点 + */ + getNodes(): string[] { + return Array.from(this.nodes); + } + + /** + * 获取节点的出边 + */ + getOutEdges(node: string): GraphEdge[] { + return this.outEdges.get(node) || []; + } + + /** + * 获取节点的入边 + */ + getInEdges(node: string): GraphEdge[] { + return this.inEdges.get(node) || []; + } + + /** + * 获取节点的出度(考虑权重) + */ + getOutDegree(node: string): number { + const edges = this.outEdges.get(node) || []; + return edges.reduce((sum, e) => sum + e.weight, 0); + } + + /** + * 获取节点的入度(考虑权重) + */ + getInDegree(node: string): number { + const edges = this.inEdges.get(node) || []; + return edges.reduce((sum, e) => sum + e.weight, 0); + } + + /** + * 获取边数量 + */ + getEdgeCount(): number { + let count = 0; + for (const edges of this.outEdges.values()) { + count += edges.length; + } + return count; + } + + /** + * 清空图 + */ + clear(): void { + this.outEdges.clear(); + this.inEdges.clear(); + this.nodes.clear(); + } + + /** + * 是否为空 + */ + isEmpty(): boolean { + return this.nodes.size === 0; + } +} diff --git a/src/repomap/ranking/index.ts b/src/repomap/ranking/index.ts new file mode 100644 index 0000000..836265b --- /dev/null +++ b/src/repomap/ranking/index.ts @@ -0,0 +1,7 @@ +/** + * 排序模块导出 + */ + +export { Graph } from './graph.js'; +export { pagerank, distributeRanksToDefinitions } from './pagerank.js'; +export type { PageRankOptions } from './pagerank.js'; diff --git a/src/repomap/ranking/pagerank.ts b/src/repomap/ranking/pagerank.ts new file mode 100644 index 0000000..a031673 --- /dev/null +++ 
export interface PageRankOptions {
  /** Damping factor (default 0.85). */
  damping?: number;
  /** Maximum number of iterations (default 100). */
  iterations?: number;
  /** Convergence threshold on the summed absolute rank change per iteration (default 1e-6). */
  tolerance?: number;
  /** Personalization vector: node -> non-negative teleport weight. */
  personalization?: Map<string, number>;
}

/**
 * Weighted, optionally personalized PageRank by power iteration.
 *
 * The teleport distribution is the personalization vector restricted to nodes
 * that exist in the graph and normalized to sum to 1. Nodes missing from it
 * get zero teleport mass. When no personalization is given, or it carries no
 * positive weight on any graph node, the distribution is uniform. Rank held
 * by nodes with no positive outgoing weight is redistributed along the same
 * distribution, so the ranks always sum to 1.
 *
 * @param graph - graph to rank
 * @param options - algorithm options
 * @returns map of node -> rank; empty for an empty graph
 */
export function pagerank(
  graph: Graph,
  options: PageRankOptions = {}
): Map<string, number> {
  const {
    damping = 0.85,
    iterations = 100,
    tolerance = 1e-6,
    personalization,
  } = options;

  const nodes = graph.getNodes();
  const n = nodes.length;

  if (n === 0) {
    return new Map();
  }

  const uniform = 1 / n;

  // Only weight placed on nodes that are actually in the graph counts.
  let persTotal = 0;
  if (personalization) {
    for (const node of nodes) {
      const w = personalization.get(node) ?? 0;
      if (w > 0) persTotal += w;
    }
  }

  const teleport = new Map<string, number>();
  for (const node of nodes) {
    if (persTotal > 0) {
      const w = personalization?.get(node) ?? 0;
      teleport.set(node, w > 0 ? w / persTotal : 0);
    } else {
      teleport.set(node, uniform);
    }
  }

  // Total outgoing weight is loop-invariant: compute it once per node.
  const outWeight = new Map<string, number>();
  for (const node of nodes) {
    outWeight.set(node, graph.getOutDegree(node));
  }

  let ranks = new Map<string, number>();
  for (const node of nodes) {
    ranks.set(node, uniform);
  }

  for (let iter = 0; iter < iterations; iter++) {
    // Nodes that cannot pass rank along an edge hand it to the teleport distribution.
    let danglingMass = 0;
    for (const node of nodes) {
      if ((outWeight.get(node) ?? 0) <= 0) {
        danglingMass += ranks.get(node) ?? 0;
      }
    }

    const next = new Map<string, number>();
    let diff = 0;

    for (const node of nodes) {
      const share = teleport.get(node) ?? 0;
      let rank = (1 - damping) * share + damping * danglingMass * share;

      for (const edge of graph.getInEdges(node)) {
        const total = outWeight.get(edge.from) ?? 0;
        if (total > 0) {
          // Each edge carries its fraction of the source's total outgoing weight.
          rank += damping * (ranks.get(edge.from) ?? 0) * (edge.weight / total);
        }
      }

      next.set(node, rank);
      diff += Math.abs(rank - (ranks.get(node) ?? 0));
    }

    ranks = next;

    if (diff < tolerance) {
      break;
    }
  }

  return ranks;
}

/**
 * Spreads each node's rank over the definitions it references.
 *
 * Every source node's rank is split across its outgoing edges in proportion
 * to edge weight and accumulated under the key `"<edge.to>:<edge.ident>"`.
 *
 * @param graph - graph whose edges carry `to`, `ident` and `weight`
 * @param nodeRanks - rank per node, as returned by `pagerank`
 * @returns map of `"file:ident"` -> accumulated rank
 */
export function distributeRanksToDefinitions(
  graph: Graph,
  nodeRanks: Map<string, number>
): Map<string, number> {
  const definitionRanks = new Map<string, number>();

  for (const src of graph.getNodes()) {
    const outEdges = graph.getOutEdges(src);
    const totalWeight = outEdges.reduce((sum, e) => sum + e.weight, 0);

    // Nothing to distribute, and avoids dividing by zero.
    if (totalWeight === 0) continue;

    const srcRank = nodeRanks.get(src) ?? 0;

    for (const edge of outEdges) {
      const key = `${edge.to}:${edge.ident}`;
      const edgeRank = (srcRank * edge.weight) / totalWeight;
      definitionRanks.set(key, (definitionRanks.get(key) ?? 0) + edgeRank);
    }
  }

  return definitionRanks;
}
'*.test.*', + '*.spec.*', + '**/*.d.ts', + ], + include: ['**/*.ts', '**/*.tsx', '**/*.js', '**/*.jsx', '**/*.py'], +}; + +/** + * RepoMap 类 + * 生成代码仓库的上下文地图,帮助 AI 理解代码结构 + */ +export class RepoMap { + private tagExtractor: TagExtractor; + private tagsCache: DiskCache; + private config: RepoMapConfig; + private root: string; + + constructor(root: string, config: Partial = {}) { + this.root = root; + this.config = { ...defaultConfig, ...config }; + + this.tagExtractor = new TagExtractor(); + this.tagsCache = new DiskCache( + path.join(this.root, this.config.cacheDir) + ); + } + + /** + * 获取 repo map + * @param chatFiles 当前对话中涉及的文件 + * @param otherFiles 仓库中其他文件 + * @param mentionedFnames 对话中提到的文件名 + * @param mentionedIdents 对话中提到的标识符 + */ + async getRepoMap( + chatFiles: string[], + otherFiles: string[], + mentionedFnames: Set = new Set(), + mentionedIdents: Set = new Set() + ): Promise { + if (this.config.mapTokens <= 0 || otherFiles.length === 0) { + return ''; + } + + let maxMapTokens = this.config.mapTokens; + + // 无聊天文件时,给更大的视图 + if (chatFiles.length === 0) { + maxMapTokens = Math.min( + maxMapTokens * this.config.mapMulNoFiles, + this.config.maxContextWindow - 4096 + ); + } + + const rankedTags = await this.getRankedTags( + chatFiles, + otherFiles, + mentionedFnames, + mentionedIdents + ); + + // 二分搜索找到最优 token 数量 + return this.fitToTokenLimit(rankedTags, maxMapTokens, new Set(chatFiles)); + } + + /** + * 获取排序后的 tags + */ + private async getRankedTags( + chatFnames: string[], + otherFnames: string[], + mentionedFnames: Set, + mentionedIdents: Set + ): Promise { + // ident -> files that define it + const defines = new Map>(); + // ident -> files that reference it + const references = new Map(); + // (file:ident) -> tag objects + const definitions = new Map(); + // personalization vector for PageRank + const personalization = new Map(); + + const allFnames = [...new Set([...chatFnames, ...otherFnames])]; + const chatRelFnames = new Set(chatFnames.map((f) => 
this.getRelFname(f))); + const basePersonalize = 100 / Math.max(allFnames.length, 1); + + // 收集所有文件的 tags + for (const fname of allFnames) { + const relFname = this.getRelFname(fname); + let currentPers = 0; + + // 个性化权重 + if (chatFnames.includes(fname)) { + currentPers += basePersonalize; + } + if (mentionedFnames.has(relFname)) { + currentPers = Math.max(currentPers, basePersonalize); + } + + // 路径组件匹配 + const pathParts = relFname.split('/'); + const basename = pathParts[pathParts.length - 1]; + const basenameNoExt = basename.replace(/\.[^.]+$/, ''); + const components = new Set([...pathParts, basename, basenameNoExt]); + + if ([...components].some((c) => mentionedIdents.has(c))) { + currentPers += basePersonalize; + } + + if (currentPers > 0) { + personalization.set(relFname, currentPers); + } + + // 提取 tags + const tags = await this.getTags(fname, relFname); + for (const tag of tags) { + if (tag.kind === 'def') { + if (!defines.has(tag.name)) defines.set(tag.name, new Set()); + defines.get(tag.name)!.add(relFname); + + const key = `${relFname}:${tag.name}`; + if (!definitions.has(key)) definitions.set(key, []); + definitions.get(key)!.push(tag); + } else { + if (!references.has(tag.name)) references.set(tag.name, []); + references.get(tag.name)!.push(relFname); + } + } + } + + // 构建图 + const graph = new Graph(); + + // 找到同时有定义和引用的标识符 + const idents = [...defines.keys()].filter((id) => references.has(id)); + + for (const ident of idents) { + const definers = defines.get(ident)!; + const refs = references.get(ident)!; + + // 计算权重乘数 + let mul = 1.0; + if (mentionedIdents.has(ident)) mul *= 10; + if (this.isSignificantName(ident)) mul *= 10; + if (ident.startsWith('_')) mul *= 0.1; + if (definers.size > 5) mul *= 0.1; + + // 统计引用次数 + const refCounts = new Map(); + for (const ref of refs) { + refCounts.set(ref, (refCounts.get(ref) || 0) + 1); + } + + // 添加边 + for (const [referencer, numRefs] of refCounts) { + for (const definer of definers) { + let useMul = mul; + 
if (chatRelFnames.has(referencer)) useMul *= 50; + + graph.addEdge({ + from: referencer, + to: definer, + weight: useMul * Math.sqrt(numRefs), + ident, + }); + } + } + } + + // 运行 PageRank + const ranked = pagerank(graph, { personalization }); + + // 分配排名到定义 + const rankedDefinitions = distributeRanksToDefinitions(graph, ranked); + + // 排序 + const sorted = [...rankedDefinitions.entries()].sort((a, b) => b[1] - a[1]); + + // 收集排序后的 tags + const rankedTags: Tag[] = []; + for (const [key] of sorted) { + const colonIdx = key.lastIndexOf(':'); + const fname = key.substring(0, colonIdx); + if (chatRelFnames.has(fname)) continue; + const tags = definitions.get(key) || []; + rankedTags.push(...tags); + } + + return rankedTags; + } + + /** + * 获取文件的 tags(带缓存) + */ + private async getTags(fname: string, relFname: string): Promise { + const mtime = await this.getMtime(fname); + if (!mtime) return []; + + // 检查缓存 + const cached = await this.tagsCache.get(fname); + if (cached && cached.mtime === mtime) { + return cached.data; + } + + // 提取 tags + const tags = await this.tagExtractor.getTags(fname, relFname); + + // 更新缓存 + await this.tagsCache.set(fname, { mtime, data: tags }); + + return tags; + } + + /** + * 二分搜索拟合 token 限制 + */ + private fitToTokenLimit( + tags: Tag[], + maxTokens: number, + chatRelFnames: Set + ): string { + const numTags = tags.length; + if (numTags === 0) return ''; + + let lower = 0; + let upper = numTags; + let bestTree = ''; + let bestTokens = 0; + + let middle = Math.min(Math.floor(maxTokens / 25), numTags); + + while (lower <= upper) { + const tree = this.toTree(tags.slice(0, middle), chatRelFnames); + const numTokens = this.tokenCount(tree); + + const pctErr = Math.abs(numTokens - maxTokens) / maxTokens; + + if ((numTokens <= maxTokens && numTokens > bestTokens) || pctErr < 0.15) { + bestTree = tree; + bestTokens = numTokens; + + if (pctErr < 0.15) break; + } + + if (numTokens < maxTokens) { + lower = middle + 1; + } else { + upper = middle - 1; + } 
+ + middle = Math.floor((lower + upper) / 2); + } + + return bestTree; + } + + /** + * 转换为树形展示 + */ + private toTree(tags: Tag[], chatRelFnames: Set): string { + if (tags.length === 0) return ''; + + const output: string[] = []; + let curFname: string | null = null; + let lois: number[] = []; + let curNames: string[] = []; + + // 按文件分组 + const sortedTags = [...tags].sort( + (a, b) => a.relFname.localeCompare(b.relFname) || a.line - b.line + ); + + for (const tag of sortedTags) { + if (chatRelFnames.has(tag.relFname)) continue; + + if (tag.relFname !== curFname) { + // 输出前一个文件 + if (curFname && (lois.length > 0 || curNames.length > 0)) { + output.push(this.renderFileTree(curFname, lois, curNames)); + } + curFname = tag.relFname; + lois = []; + curNames = []; + } + + if (tag.line >= 0) { + lois.push(tag.line); + } + curNames.push(tag.name); + } + + // 输出最后一个文件 + if (curFname && (lois.length > 0 || curNames.length > 0)) { + output.push(this.renderFileTree(curFname, lois, curNames)); + } + + return output.join('\n'); + } + + /** + * 渲染单个文件的树形展示 + */ + private renderFileTree( + relFname: string, + lois: number[], + names: string[] + ): string { + const uniqueNames = [...new Set(names)]; + const uniqueLines = [...new Set(lois)].sort((a, b) => a - b); + + const lines = [`${relFname}:`]; + + // 简化版:显示符号列表 + for (const name of uniqueNames.slice(0, 10)) { + lines.push(` - ${name}`); + } + + if (uniqueNames.length > 10) { + lines.push(` ... 
and ${uniqueNames.length - 10} more`); + } + + return lines.join('\n'); + } + + /** + * 判断是否是有意义的名称 + */ + private isSignificantName(name: string): boolean { + if (name.length < 8) return false; + const hasAlpha = /[a-zA-Z]/.test(name); + const isSnake = name.includes('_') && hasAlpha; + const isKebab = name.includes('-') && hasAlpha; + const isCamel = /[a-z]/.test(name) && /[A-Z]/.test(name); + return isSnake || isKebab || isCamel; + } + + /** + * 获取相对路径 + */ + private getRelFname(fname: string): string { + if (fname.startsWith(this.root)) { + return fname.slice(this.root.length + 1); + } + return fname; + } + + /** + * 获取文件修改时间 + */ + private async getMtime(fname: string): Promise { + try { + const stat = await fs.stat(fname); + return stat.mtimeMs; + } catch { + return null; + } + } + + /** + * 估算 token 数量 + * 简化估算:约 4 字符一个 token + */ + private tokenCount(text: string): number { + return Math.ceil(text.length / 4); + } + + /** + * 刷新缓存到磁盘 + */ + async flushCache(): Promise { + await this.tagsCache.flushAll(); + } + + /** + * 清空缓存 + */ + async clearCache(): Promise { + await this.tagsCache.clear(); + } +} + +/** + * 创建 RepoMap 实例 + */ +export function createRepoMap( + root: string, + config?: Partial +): RepoMap { + return new RepoMap(root, config); +} diff --git a/src/repomap/tags/extractor.ts b/src/repomap/tags/extractor.ts new file mode 100644 index 0000000..c4c83b0 --- /dev/null +++ b/src/repomap/tags/extractor.ts @@ -0,0 +1,465 @@ +/** + * Tag 提取器 + * 使用 Tree-sitter 解析代码并提取符号定义和引用 + */ + +import * as fs from 'fs/promises'; +import * as path from 'path'; +import { fileURLToPath } from 'url'; +import type { Tag } from '../types.js'; +import { getLanguageFromFilename } from '../types.js'; + +// Tree-sitter 类型(web-tree-sitter) +interface TreeSitterParser { + parse(input: string): TreeSitterTree; + setLanguage(language: TreeSitterLanguage): void; + getLanguage(): TreeSitterLanguage; +} + +interface TreeSitterTree { + rootNode: TreeSitterNode; +} + +interface 
TreeSitterNode { + text: string; + startPosition: { row: number; column: number }; + endPosition: { row: number; column: number }; + type: string; + childCount: number; + namedChildCount: number; + children: TreeSitterNode[]; + namedChildren: TreeSitterNode[]; +} + +interface TreeSitterLanguage { + query(source: string): TreeSitterQuery; +} + +interface TreeSitterQuery { + captures(node: TreeSitterNode): TreeSitterCapture[]; +} + +interface TreeSitterCapture { + node: TreeSitterNode; + name: string; +} + +// 动态导入 web-tree-sitter +let ParserClass: any = null; +let treeSitterInitialized = false; + +/** + * 初始化 Tree-sitter + */ +async function initTreeSitter(): Promise { + if (treeSitterInitialized) return; + + try { + const TreeSitter = await import('web-tree-sitter'); + // Parser 是命名导出的类,init 是其静态方法 + await TreeSitter.Parser.init(); + ParserClass = TreeSitter.Parser; + treeSitterInitialized = true; + } catch (error) { + console.warn('Failed to initialize tree-sitter:', error); + throw new Error('Tree-sitter initialization failed'); + } +} + +/** + * Tag 提取器类 + */ +export class TagExtractor { + private parsers: Map = new Map(); + private languages: Map = new Map(); + private queries: Map = new Map(); + private initialized = false; + private queriesDir: string; + + constructor() { + // 获取查询文件目录 + const __filename = fileURLToPath(import.meta.url); + const __dirname = path.dirname(__filename); + this.queriesDir = path.join(__dirname, 'queries'); + } + + /** + * 初始化提取器 + */ + async initialize(): Promise { + if (this.initialized) return; + + try { + await initTreeSitter(); + this.initialized = true; + } catch (error) { + // Tree-sitter 初始化失败,使用回退方案 + console.warn('Tree-sitter not available, using regex fallback'); + this.initialized = true; + } + } + + /** + * 获取文件的所有 tags + */ + async getTags(fname: string, relFname: string): Promise { + await this.initialize(); + + const lang = getLanguageFromFilename(fname); + if (!lang) return []; + + let code: string; + try { + code 
= await fs.readFile(fname, 'utf-8'); + } catch { + return []; + } + + // 尝试使用 Tree-sitter + if (ParserClass) { + try { + const tags = await this.extractWithTreeSitter(code, fname, relFname, lang); + if (tags.length > 0) { + return tags; + } + } catch (error) { + // Tree-sitter 解析失败,回退到正则 + } + } + + // 回退到正则表达式 + return this.extractWithRegex(code, fname, relFname, lang); + } + + /** + * 使用 Tree-sitter 提取 tags + */ + private async extractWithTreeSitter( + code: string, + fname: string, + relFname: string, + lang: string + ): Promise { + const parser = await this.getParser(lang); + const query = await this.getQuery(lang); + + if (!parser || !query) { + return []; + } + + const tree = parser.parse(code); + const captures = query.captures(tree.rootNode); + + const tags: Tag[] = []; + const seenKinds = new Set(); + + for (const capture of captures) { + const { node, name } = capture; + + let kind: 'def' | 'ref' | null = null; + if (name.startsWith('name.definition.')) { + kind = 'def'; + } else if (name.startsWith('name.reference.')) { + kind = 'ref'; + } else { + continue; + } + + seenKinds.add(kind); + + tags.push({ + relFname, + fname, + name: node.text, + kind, + line: node.startPosition.row, + }); + } + + // 如果只有 def 没有 ref,使用正则回退提取引用 + if (seenKinds.has('def') && !seenKinds.has('ref')) { + const refTags = this.extractRefsWithRegex(code, fname, relFname); + tags.push(...refTags); + } + + return tags; + } + + /** + * 获取或创建解析器 + */ + private async getParser(lang: string): Promise { + if (this.parsers.has(lang)) { + return this.parsers.get(lang)!; + } + + try { + const language = await this.getLanguage(lang); + if (!language) return null; + + const parser = new ParserClass(); + parser.setLanguage(language); + this.parsers.set(lang, parser); + return parser; + } catch { + return null; + } + } + + /** + * 获取语言定义 + */ + private async getLanguage(lang: string): Promise { + if (this.languages.has(lang)) { + return this.languages.get(lang)!; + } + + try { + // 尝试加载 WASM 
语言文件 + const wasmPath = this.getWasmPath(lang); + const language = await ParserClass.Language.load(wasmPath); + this.languages.set(lang, language); + return language; + } catch { + return null; + } + } + + /** + * 获取 WASM 文件路径 + */ + private getWasmPath(lang: string): string { + // tree-sitter WASM 文件的标准位置 + const wasmName = `tree-sitter-${lang}.wasm`; + + // 尝试多个可能的位置 + const possiblePaths = [ + path.join(process.cwd(), 'node_modules', 'tree-sitter-wasms', 'out', wasmName), + path.join(process.cwd(), 'node_modules', `tree-sitter-${lang}`, wasmName), + path.join(__dirname, '..', '..', '..', 'wasm', wasmName), + ]; + + // 返回第一个路径(实际使用时需要检查存在性) + return possiblePaths[0]; + } + + /** + * 获取语言查询 + */ + private async getQuery(lang: string): Promise { + if (this.queries.has(lang)) { + return this.queries.get(lang)!; + } + + try { + const language = await this.getLanguage(lang); + if (!language) return null; + + const queryPath = path.join(this.queriesDir, `${lang}-tags.scm`); + const queryText = await fs.readFile(queryPath, 'utf-8'); + const query = language.query(queryText); + this.queries.set(lang, query); + return query; + } catch { + return null; + } + } + + /** + * 使用正则表达式提取 tags (回退方案) + */ + private extractWithRegex( + code: string, + fname: string, + relFname: string, + lang: string + ): Tag[] { + const tags: Tag[] = []; + const lines = code.split('\n'); + + // 根据语言选择正则模式 + const patterns = this.getRegexPatterns(lang); + + lines.forEach((line, lineNum) => { + for (const pattern of patterns.definitions) { + const match = line.match(pattern.regex); + if (match && match[pattern.nameGroup]) { + tags.push({ + relFname, + fname, + name: match[pattern.nameGroup], + kind: 'def', + line: lineNum, + }); + } + } + }); + + // 提取引用 + tags.push(...this.extractRefsWithRegex(code, fname, relFname)); + + return tags; + } + + /** + * 使用正则提取引用 + */ + private extractRefsWithRegex(code: string, fname: string, relFname: string): Tag[] { + const tags: Tag[] = []; + + // 简单的标识符匹配 - 
排除关键字 + const keywords = new Set([ + 'if', + 'else', + 'for', + 'while', + 'do', + 'switch', + 'case', + 'break', + 'continue', + 'return', + 'function', + 'class', + 'const', + 'let', + 'var', + 'import', + 'export', + 'from', + 'default', + 'async', + 'await', + 'try', + 'catch', + 'finally', + 'throw', + 'new', + 'this', + 'super', + 'extends', + 'implements', + 'interface', + 'type', + 'enum', + 'public', + 'private', + 'protected', + 'static', + 'readonly', + 'abstract', + 'true', + 'false', + 'null', + 'undefined', + 'void', + 'never', + 'any', + 'unknown', + 'string', + 'number', + 'boolean', + 'object', + 'symbol', + 'bigint', + 'def', + 'class', + 'self', + 'None', + 'True', + 'False', + 'and', + 'or', + 'not', + 'in', + 'is', + 'lambda', + 'with', + 'as', + 'pass', + 'raise', + 'yield', + 'global', + 'nonlocal', + 'assert', + 'del', + ]); + + // 匹配 PascalCase 或 snake_case 标识符(更可能是用户定义的) + const identRegex = /\b([A-Z][a-zA-Z0-9]*|[a-z][a-zA-Z0-9]*_[a-zA-Z0-9_]*)\b/g; + let match; + + while ((match = identRegex.exec(code)) !== null) { + const name = match[1]; + if (!keywords.has(name) && name.length >= 2) { + tags.push({ + relFname, + fname, + name, + kind: 'ref', + line: -1, // 行号未知 + }); + } + } + + return tags; + } + + /** + * 获取语言的正则模式 + */ + private getRegexPatterns(lang: string): { + definitions: Array<{ regex: RegExp; nameGroup: number }>; + } { + switch (lang) { + case 'typescript': + case 'javascript': + return { + definitions: [ + // function name( + { regex: /(?:export\s+)?(?:async\s+)?function\s+(\w+)\s*[<(]/, nameGroup: 1 }, + // class Name + { regex: /(?:export\s+)?(?:abstract\s+)?class\s+(\w+)/, nameGroup: 1 }, + // interface Name + { regex: /(?:export\s+)?interface\s+(\w+)/, nameGroup: 1 }, + // type Name = + { regex: /(?:export\s+)?type\s+(\w+)\s*[<=]/, nameGroup: 1 }, + // const/let/var name = (arrow function or function) + { + regex: /(?:export\s+)?(?:const|let|var)\s+(\w+)\s*=\s*(?:async\s+)?(?:\([^)]*\)|[^=])\s*=>/, + nameGroup: 1, + }, 
+ // enum Name + { regex: /(?:export\s+)?enum\s+(\w+)/, nameGroup: 1 }, + // method name( + { regex: /^\s*(?:async\s+)?(\w+)\s*\([^)]*\)\s*[:{]/, nameGroup: 1 }, + ], + }; + + case 'python': + return { + definitions: [ + // def name( + { regex: /^\s*(?:async\s+)?def\s+(\w+)\s*\(/, nameGroup: 1 }, + // class Name + { regex: /^\s*class\s+(\w+)/, nameGroup: 1 }, + ], + }; + + default: + return { + definitions: [ + // 通用:function/def/class 后面的标识符 + { regex: /(?:function|def|class)\s+(\w+)/, nameGroup: 1 }, + ], + }; + } + } +} + +/** + * 创建 Tag 提取器实例 + */ +export function createTagExtractor(): TagExtractor { + return new TagExtractor(); +} diff --git a/src/repomap/tags/index.ts b/src/repomap/tags/index.ts new file mode 100644 index 0000000..1d9b0a8 --- /dev/null +++ b/src/repomap/tags/index.ts @@ -0,0 +1,5 @@ +/** + * Tags 模块导出 + */ + +export { TagExtractor, createTagExtractor } from './extractor.js'; diff --git a/src/repomap/tags/queries/javascript-tags.scm b/src/repomap/tags/queries/javascript-tags.scm new file mode 100644 index 0000000..14827e8 --- /dev/null +++ b/src/repomap/tags/queries/javascript-tags.scm @@ -0,0 +1,65 @@ +;; JavaScript/JSX 标签查询 +;; 定义使用 @name.definition.* 前缀 +;; 引用使用 @name.reference.* 前缀 + +;; ==================== 定义 ==================== + +;; 函数声明 +(function_declaration + name: (identifier) @name.definition.function) @definition.function + +;; 生成器函数 +(generator_function_declaration + name: (identifier) @name.definition.function) @definition.function + +;; 箭头函数赋值 (const/let) +(lexical_declaration + (variable_declarator + name: (identifier) @name.definition.function + value: [(arrow_function) (function)])) @definition.function + +;; 箭头函数赋值 (var) +(variable_declaration + (variable_declarator + name: (identifier) @name.definition.function + value: [(arrow_function) (function)])) @definition.function + +;; 赋值表达式中的函数 +(assignment_expression + left: (identifier) @name.definition.function + right: [(arrow_function) (function)]) @definition.function + 
+(assignment_expression
+  left: (member_expression
+    property: (property_identifier) @name.definition.function)
+  right: [(arrow_function) (function)]) @definition.function
+
+;; Object-literal properties whose value is a function
+(pair
+  key: (property_identifier) @name.definition.function
+  value: [(arrow_function) (function)]) @definition.function
+
+;; Method definitions
+(method_definition
+  name: (property_identifier) @name.definition.method) @definition.method
+
+;; Class definitions
+(class_declaration
+  name: (identifier) @name.definition.class) @definition.class
+
+(class
+  name: (identifier) @name.definition.class) @definition.class
+
+;; ==================== References ====================
+
+;; Function calls
+(call_expression
+  function: (identifier) @name.reference.call) @reference.call
+
+(call_expression
+  function: (member_expression
+    property: (property_identifier) @name.reference.call)) @reference.call
+
+;; new expressions
+(new_expression
+  constructor: (identifier) @name.reference.class) @reference.class
diff --git a/src/repomap/tags/queries/python-tags.scm b/src/repomap/tags/queries/python-tags.scm
new file mode 100644
index 0000000..87efa81
--- /dev/null
+++ b/src/repomap/tags/queries/python-tags.scm
@@ -0,0 +1,37 @@
+;; Python tag queries
+;; Definitions use the @name.definition.* capture prefix
+;; References use the @name.reference.* capture prefix
+
+;; ==================== Definitions ====================
+
+;; Function definitions
+(function_definition
+  name: (identifier) @name.definition.function) @definition.function
+
+;; Class definitions
+(class_definition
+  name: (identifier) @name.definition.class) @definition.class
+
+;; Decorated functions/classes (NOTE(review): may duplicate the plain patterns above - confirm)
+(decorated_definition
+  (function_definition
+    name: (identifier) @name.definition.function)) @definition.function
+
+(decorated_definition
+  (class_definition
+    name: (identifier) @name.definition.class)) @definition.class
+
+;; ==================== References ====================
+
+;; Function calls
+(call
+  function: (identifier) @name.reference.call) @reference.call
+
+(call
+  function: (attribute
+    attribute: (identifier) @name.reference.call)) @reference.call
+
+;; Class inheritance (base classes)
+(class_definition
+  superclasses: (argument_list
+    (identifier) @name.reference.class)) @reference.class
diff --git a/src/repomap/tags/queries/typescript-tags.scm b/src/repomap/tags/queries/typescript-tags.scm
new file mode 100644
index 0000000..139fa41
--- /dev/null
+++ b/src/repomap/tags/queries/typescript-tags.scm
@@ -0,0 +1,90 @@
+;; TypeScript/TSX tag queries
+;; Definitions use the @name.definition.* capture prefix
+;; References use the @name.reference.* capture prefix
+
+;; ==================== Definitions ====================
+
+;; Function definitions
+(function_declaration
+  name: (identifier) @name.definition.function) @definition.function
+
+(function_signature
+  name: (identifier) @name.definition.function) @definition.function
+
+;; Arrow functions assigned to a variable
+(lexical_declaration
+  (variable_declarator
+    name: (identifier) @name.definition.function
+    value: (arrow_function))) @definition.function
+
+(variable_declaration
+  (variable_declarator
+    name: (identifier) @name.definition.function
+    value: (arrow_function))) @definition.function
+
+;; Method definitions
+(method_definition
+  name: (property_identifier) @name.definition.method) @definition.method
+
+(method_signature
+  name: (property_identifier) @name.definition.method) @definition.method
+
+(abstract_method_signature
+  name: (property_identifier) @name.definition.method) @definition.method
+
+;; Class definitions
+(class_declaration
+  name: (type_identifier) @name.definition.class) @definition.class
+
+(abstract_class_declaration
+  name: (type_identifier) @name.definition.class) @definition.class
+
+;; Interface definitions
+(interface_declaration
+  name: (type_identifier) @name.definition.interface) @definition.interface
+
+;; Type aliases
+(type_alias_declaration
+  name: (type_identifier) @name.definition.type) @definition.type
+
+;; Enums
+(enum_declaration
+  name: (identifier) @name.definition.enum) @definition.enum
+
+;; Modules / namespaces
+(module
+  name: (identifier) @name.definition.module) @definition.module
+
+(internal_module
+  name: (identifier) @name.definition.module) @definition.module
+
+;; ==================== References ====================
+
+;; References inside type annotations
+(type_annotation
+  (type_identifier) @name.reference.type) @reference.type
+
+;; References inside type arguments
+(type_arguments
+  (type_identifier) @name.reference.type) @reference.type
+
+;; extends/implements
+(class_heritage
+  (extends_clause
+    value: (identifier) @name.reference.class)) @reference.class
+
+(class_heritage
+  (implements_clause
+    (type_identifier) @name.reference.interface)) @reference.interface
+
+;; new expressions
+(new_expression
+  constructor: (identifier) @name.reference.class) @reference.class
+
+;; Function calls
+(call_expression
+  function: (identifier) @name.reference.call) @reference.call
+
+(call_expression
+  function: (member_expression
+    property: (property_identifier) @name.reference.call)) @reference.call
diff --git a/src/repomap/types.ts b/src/repomap/types.ts
new file mode 100644
index 0000000..5962991
--- /dev/null
+++ b/src/repomap/types.ts
@@ -0,0 +1,142 @@
+/**
+ * RepoMap type definitions
+ * Based on Aider's RepoMap implementation
+ */
+
+/**
+ * Code tag - corresponds to Aider's Tag
+ */
+export interface Tag {
+  /** Path relative to the repository root */
+  relFname: string;
+  /** Absolute path */
+  fname: string;
+  /** Line number (0-indexed, -1 when unknown) */
+  line: number;
+  /** Symbol name */
+  name: string;
+  /** Whether this tag is a definition or a reference */
+  kind: 'def' | 'ref';
+}
+
+/**
+ * Per-file cache entry
+ */
+export interface TagCacheEntry {
+  /** File modification time (ms) */
+  mtime: number;
+  /** Tag data */
+  data: Tag[];
+}
+
+/**
+ * RepoMap configuration
+ */
+export interface RepoMapConfig {
+  /** Target number of tokens */
+  mapTokens: number;
+  /** Multiplier applied when there are no chat files */
+  mapMulNoFiles: number;
+  /** Maximum context window */
+  maxContextWindow: number;
+  /** Refresh strategy */
+  refresh: 'auto' | 'always' | 'files' | 'manual';
+  /** Cache directory */
+  cacheDir: string;
+  /** Whether to produce verbose output */
+  verbose: boolean;
+  /** File patterns to exclude */
+  exclude: string[];
+  /** File patterns to include */
+  include: string[];
+}
+
+/**
+ * Default configuration
+ */
+export const DEFAULT_REPOMAP_CONFIG: RepoMapConfig = {
+  mapTokens: 2048,
+  mapMulNoFiles: 8,
+  maxContextWindow: 128000,
+  refresh: 'auto',
+  cacheDir: '.ai-assist/tags-cache',
+  verbose: false,
+  exclude: [
+    'node_modules/**',
+    'dist/**',
+    'build/**',
+    '.git/**',
+    '*.test.*',
+    '*.spec.*',
+    '**/*.d.ts',
+  ],
+  include: ['**/*.ts', '**/*.tsx', '**/*.js', '**/*.jsx', '**/*.py'],
+};
+
+/**
+ * Graph edge
+ */
+export interface GraphEdge {
+  /** Referencing file */
+  from: string;
+  /** Defining file */
+  to: string;
+  /** Edge weight */
+  weight: number;
+  /** Symbol name */
+  ident: string;
+}
+
+/**
+ * A ranked definition
+ */
+export interface RankedDefinition {
+  fname: string;
+  ident: string;
+  rank: number;
+  tags: Tag[];
+}
+
+/**
+ * PageRank algorithm options
+ */
+export interface PageRankOptions {
+  /** Damping factor (default 0.85) */
+  damping?: number;
+  /** Maximum number of iterations (default 100) */
+  maxIterations?: number;
+  /** Convergence tolerance (default 1e-6) */
+  tolerance?: number;
+  /** Personalization vector (node -> weight) */
+  personalization?: Map<string, number>;
+}
+
+/**
+ * Map from lower-case file extension (including the dot) to language id
+ */
+export const LANGUAGE_MAP: Record<string, string> = {
+  '.ts': 'typescript',
+  '.tsx': 'typescript',
+  '.js': 'javascript',
+  '.jsx': 'javascript',
+  '.mjs': 'javascript',
+  '.cjs': 'javascript',
+  '.py': 'python',
+  '.go': 'go',
+  '.rs': 'rust',
+  '.java': 'java',
+  '.rb': 'ruby',
+  '.cpp': 'cpp',
+  '.cc': 'cpp',
+  '.c': 'c',
+  '.h': 'c',
+  '.hpp': 'cpp',
+};
+
+/**
+ * Return the language id for a filename's extension, or null; own-key lookup so names like `constructor` yield null.
+ */
+export function getLanguageFromFilename(filename: string): string | null {
+  const ext = filename.substring(filename.lastIndexOf('.')).toLowerCase();
+  return Object.prototype.hasOwnProperty.call(LANGUAGE_MAP, ext) ? LANGUAGE_MAP[ext] : null;
+}
diff --git a/src/tools/descriptions/repo_map.txt b/src/tools/descriptions/repo_map.txt
new file mode 100644
index 0000000..eef1a81
--- /dev/null
+++ b/src/tools/descriptions/repo_map.txt
@@ -0,0 +1,20 @@
+Generate a repository map showing the most relevant code symbols (functions, classes, methods) based on AST analysis and PageRank ranking.
+
+This tool analyzes the codebase structure using Tree-sitter AST parsing and ranks symbols by their relevance using the PageRank algorithm.
It helps AI understand the code structure by showing:
+- Important function and class definitions
+- Their relationships through references
+- Prioritized by relevance to the current context
+
+Parameters:
+- directory: Directory to analyze (default: current working directory)
+- chat_files: Files currently being discussed (excluded from map output)
+- mentioned_files: File names mentioned in conversation (boost relevance)
+- mentioned_identifiers: Symbol names mentioned (boost relevance)
+- max_tokens: Maximum output token count (default: 1024)
+
+The output shows relevant code symbols organized by file, helping to understand:
+- What functions/classes exist in the codebase
+- How different parts of the code relate to each other
+- Which symbols are most important based on usage patterns
+
+Use this tool when you need to understand the overall structure of a codebase or find relevant code locations for a task.
diff --git a/src/tools/index.ts b/src/tools/index.ts
index 639927a..46b3000 100644
--- a/src/tools/index.ts
+++ b/src/tools/index.ts
@@ -46,6 +46,9 @@ import {
   gitStashTool,
 } from './git/index.js';
 
+// RepoMap tool
+import { repoMapTool } from './repomap/index.js';
+
 // 所有工具列表(用于注册)
 const allToolsWithMetadata: ToolWithMetadata[] = [
   // 核心工具 (deferLoading: false)
@@ -88,6 +91,9 @@ const allToolsWithMetadata: ToolWithMetadata[] = [
   gitPullTool,
   gitCheckoutTool,
   gitStashTool,
+
+  // RepoMap tool (deferLoading: true)
+  repoMapTool,
 ];
 
 // 注册所有工具到 registry
diff --git a/src/tools/repomap/index.ts b/src/tools/repomap/index.ts
new file mode 100644
index 0000000..3db0e61
--- /dev/null
+++ b/src/tools/repomap/index.ts
@@ -0,0 +1,5 @@
+/**
+ * RepoMap tool module
+ */
+
+export { repoMapTool } from './repo_map.js';
diff --git a/src/tools/repomap/repo_map.ts b/src/tools/repomap/repo_map.ts
new file mode 100644
index 0000000..34c981c
--- /dev/null
+++ b/src/tools/repomap/repo_map.ts
@@ -0,0 +1,261 @@
+/**
+ * RepoMap tool
+ * Generates a context map of the repository to help the AI understand the code structure
+ */
+
+import * as path from 'path';
+import * as fs from 'fs/promises';
+import type { ToolResult } from '../../types/index.js';
+import type { ToolWithMetadata } from '../types.js';
+import { loadDescription } from '../load_description.js';
+import { getPermissionManager } from '../../permission/index.js';
+import { createRepoMap, type RepoMapConfig } from '../../repomap/index.js';
+
+// Cached RepoMap instances, keyed by root directory plus the config they were created with
+const repoMapCache = new Map<string, ReturnType<typeof createRepoMap>>();
+
+/**
+ * Get or create the RepoMap for `root`; the config is part of the key so a changed max_tokens is not ignored
+ */
+function getRepoMap(root: string, config?: Partial<RepoMapConfig>) {
+  const key = `${root}\u0000${JSON.stringify(config ?? {})}`;
+  if (!repoMapCache.has(key)) {
+    repoMapCache.set(key, createRepoMap(root, config));
+  }
+  return repoMapCache.get(key)!;
+}
+
+/**
+ * Supported file extensions
+ */
+const SUPPORTED_EXTENSIONS = new Set([
+  '.ts',
+  '.tsx',
+  '.js',
+  '.jsx',
+  '.mjs',
+  '.cjs',
+  '.py',
+]);
+
+/**
+ * Directory names that are never descended into
+ */
+const EXCLUDED_DIRS = new Set([
+  'node_modules',
+  'dist',
+  'build',
+  '.git',
+  '.next',
+  '__pycache__',
+  '.pytest_cache',
+  'coverage',
+  '.nyc_output',
+]);
+
+/**
+ * Whether a file name matches an excluded pattern (test/spec files and .d.ts declarations)
+ */
+function shouldExcludeFile(filename: string): boolean {
+  return (
+    filename.includes('.test.') ||
+    filename.includes('.spec.') ||
+    filename.endsWith('.d.ts')
+  );
+}
+
+/**
+ * Recursively collect supported source files under `directory`, descending at most `maxDepth` levels
+ */
+async function getFilesRecursive(
+  directory: string,
+  maxDepth = 10
+): Promise<string[]> {
+  const files: string[] = [];
+
+  async function walk(dir: string, depth: number): Promise<void> {
+    if (depth > maxDepth) return;
+
+    try {
+      const entries = await fs.readdir(dir, { withFileTypes: true });
+
+      for (const entry of entries) {
+        const fullPath = path.join(dir, entry.name);
+
+        if (entry.isDirectory()) {
+          if (!EXCLUDED_DIRS.has(entry.name) && !entry.name.startsWith('.')) {
+            await walk(fullPath, depth + 1);
+          }
+        } else if (entry.isFile()) {
+          const ext = path.extname(entry.name).toLowerCase();
+          if (SUPPORTED_EXTENSIONS.has(ext) && !shouldExcludeFile(entry.name)) {
+            files.push(fullPath);
+          }
+        }
+      }
+    } catch {
+      // Best effort: ignore directories that cannot be read
+    }
+  }
+
+  await walk(directory, 0);
+  return files;
+}
+
+export const repoMapTool: ToolWithMetadata = {
+  name: 'repo_map',
+  description: loadDescription('repo_map'),
+  metadata: {
+    name: 'repo_map',
+    category: 'core',
+    description: '生成代码仓库上下文地图,帮助理解代码结构',
+    keywords: [
+      'repomap',
+      'repo',
+      'map',
+      'context',
+      'code',
+      'structure',
+      'ast',
+      'symbol',
+      'definition',
+      '代码地图',
+      '仓库结构',
+      '代码结构',
+      '符号',
+      '定义',
+    ],
+    deferLoading: true,
+  },
+  parameters: {
+    directory: {
+      type: 'string',
+      description: '要分析的目录路径(默认为当前工作目录)',
+      required: false,
+    },
+    chat_files: {
+      type: 'string',
+      description:
+        '当前对话中涉及的文件列表,逗号分隔(这些文件会被排除在地图外)',
+      required: false,
+    },
+    mentioned_files: {
+      type: 'string',
+      description: '对话中提到的文件名,逗号分隔(用于提高相关性)',
+      required: false,
+    },
+    mentioned_identifiers: {
+      type: 'string',
+      description: '对话中提到的标识符/符号名,逗号分隔(用于提高相关性)',
+      required: false,
+    },
+    max_tokens: {
+      type: 'number',
+      description: '最大输出 token 数(默认 1024)',
+      required: false,
+    },
+  },
+  execute: async (params: Record<string, unknown>): Promise<ToolResult> => {
+    const cwd = process.cwd();
+    const directory = (params.directory as string) || cwd;
+    const absolutePath = path.isAbsolute(directory)
+      ? directory
+      : path.join(cwd, directory);
+
+    // Parse the comma-separated strings into arrays, dropping empty entries (e.g. from a trailing comma)
+    const chatFilesStr = (params.chat_files as string) || '';
+    const chatFiles = chatFilesStr
+      ? chatFilesStr.split(',').map((s) => s.trim()).filter(Boolean)
+      : [];
+
+    const mentionedFilesStr = (params.mentioned_files as string) || '';
+    const mentionedFiles = new Set(
+      mentionedFilesStr ? mentionedFilesStr.split(',').map((s) => s.trim()).filter(Boolean) : []
+    );
+
+    const mentionedIdentsStr = (params.mentioned_identifiers as string) || '';
+    const mentionedIdents = new Set(
+      mentionedIdentsStr ? mentionedIdentsStr.split(',').map((s) => s.trim()).filter(Boolean) : []
+    );
+
+    const maxTokens = Math.max(0, Math.floor(Number(params.max_tokens) || 0)) || 1024;
+
+    // Permission check
+    const permissionManager = getPermissionManager();
+    const permResult = await permissionManager.checkFilePermission({
+      operation: 'search',
+      path: absolutePath,
+      workdir: cwd,
+    });
+
+    if (!permResult.allowed) {
+      if (permResult.needsConfirmation) {
+        return {
+          success: false,
+          output: '',
+          error: `需要用户确认: 分析目录 ${absolutePath}\n原因: ${permResult.reason || '需要权限确认'}`,
+        };
+      }
+      return {
+        success: false,
+        output: '',
+        error: `权限被拒绝: ${permResult.reason || '不允许分析此目录'}`,
+      };
+    }
+
+    try {
+      // Make sure the target exists and is a directory
+      const stat = await fs.stat(absolutePath);
+      if (!stat.isDirectory()) {
+        return {
+          success: false,
+          output: '',
+          error: `${absolutePath} 不是一个目录`,
+        };
+      }
+
+      // Get the RepoMap instance
+      const repoMap = getRepoMap(absolutePath, {
+        mapTokens: maxTokens,
+      });
+
+      // Collect the file list
+      const allFiles = await getFilesRecursive(absolutePath);
+
+      // Separate chat files from the rest; resolve() normalizes so they compare equal to walked paths
+      const chatFilesAbs = chatFiles.map((f) =>
+        path.resolve(absolutePath, f)
+      );
+      const otherFiles = allFiles.filter((f) => !chatFilesAbs.includes(f));
+
+      // Generate the repo map
+      const mapContent = await repoMap.getRepoMap(
+        chatFilesAbs,
+        otherFiles,
+        mentionedFiles,
+        mentionedIdents
+      );
+
+      if (!mapContent || mapContent.trim() === '') {
+        return {
+          success: true,
+          output:
+            '未找到相关的代码符号。可能是因为目录中没有支持的代码文件,或者文件过少。',
+        };
+      }
+
+      const header = `# Repository Map\n# Directory: ${absolutePath}\n# Files analyzed: ${allFiles.length}\n\n`;
+
+      return {
+        success: true,
+        output: header + mapContent,
+      };
+    } catch (error) {
+      return {
+        success: false,
+        output: '',
+        error: error instanceof Error ? error.message : String(error),
+      };
+    }
+  },
+};
diff --git a/tests/repomap/repomap.test.ts b/tests/repomap/repomap.test.ts
new file mode 100644
index 0000000..47a6a78
--- /dev/null
+++ b/tests/repomap/repomap.test.ts
@@ -0,0 +1,329 @@
+/**
+ * RepoMap tests
+ */
+
+import { describe, it, expect, beforeAll } from 'vitest';
+import * as path from 'path';
+import * as fs from 'fs/promises';
+import * as os from 'os';
+import { createRepoMap } from '../../src/repomap/repomap.js';
+import { TagExtractor } from '../../src/repomap/tags/extractor.js';
+import { Graph, pagerank, distributeRanksToDefinitions } from '../../src/repomap/ranking/index.js';
+import { DiskCache } from '../../src/repomap/cache/disk-cache.js';
+
+describe('Graph', () => {
+  it('should add edges and track nodes', () => {
+    const graph = new Graph();
+
+    graph.addEdge({ from: 'a.ts', to: 'b.ts', weight: 1, ident: 'foo' });
+    graph.addEdge({ from: 'a.ts', to: 'c.ts', weight: 2, ident: 'bar' });
+    graph.addEdge({ from: 'b.ts', to: 'c.ts', weight: 1, ident: 'baz' });
+
+    const nodes = graph.getNodes();
+    expect(nodes).toContain('a.ts');
+    expect(nodes).toContain('b.ts');
+    expect(nodes).toContain('c.ts');
+    expect(nodes.length).toBe(3);
+  });
+
+  it('should track in and out edges correctly', () => {
+    const graph = new Graph();
+
+    graph.addEdge({ from: 'a.ts', to: 'b.ts', weight: 1, ident: 'foo' });
+    graph.addEdge({ from: 'a.ts', to: 'c.ts', weight: 2, ident: 'bar' });
+
+    const outEdges = graph.getOutEdges('a.ts');
+    expect(outEdges.length).toBe(2);
+
+    const inEdges = graph.getInEdges('b.ts');
+    expect(inEdges.length).toBe(1);
+    expect(inEdges[0].from).toBe('a.ts');
+  });
+
+  it('should calculate out degree correctly', () => {
+    const graph = new Graph();
+
+    graph.addEdge({ from: 'a.ts', to: 'b.ts', weight: 1, ident: 'foo' });
+    graph.addEdge({ from: 'a.ts', to: 'c.ts', weight: 2, ident: 'bar' });
+
+    expect(graph.getOutDegree('a.ts')).toBe(3); // sum of weights
+    expect(graph.getOutDegree('b.ts')).toBe(0);
+  });
+});
+
+describe('PageRank', () => {
+  it('should compute ranks for a simple graph', () => {
+    const graph = new Graph();
+
+    // Build a simple chain: a -> b -> c
+    graph.addEdge({ from: 'a.ts', to: 'b.ts', weight: 1, ident: 'foo' });
+    graph.addEdge({ from: 'b.ts', to: 'c.ts', weight: 1, ident: 'bar' });
+
+    const ranks = pagerank(graph);
+
+    expect(ranks.size).toBe(3);
+    expect(ranks.get('a.ts')).toBeGreaterThan(0);
+    expect(ranks.get('b.ts')).toBeGreaterThan(0);
+    expect(ranks.get('c.ts')).toBeGreaterThan(0);
+  });
+
+  it('should respect personalization vector', () => {
+    const graph = new Graph();
+
+    graph.addEdge({ from: 'a.ts', to: 'b.ts', weight: 1, ident: 'foo' });
+    graph.addEdge({ from: 'b.ts', to: 'c.ts', weight: 1, ident: 'bar' });
+    graph.addEdge({ from: 'c.ts', to: 'a.ts', weight: 1, ident: 'baz' });
+
+    // Strongly prefer a.ts
+    const personalization = new Map();
+    personalization.set('a.ts', 100);
+    personalization.set('b.ts', 1);
+    personalization.set('c.ts', 1);
+
+    const ranks = pagerank(graph, { personalization });
+
+    // a.ts should end up with the higher rank
+    expect(ranks.get('a.ts')!).toBeGreaterThan(ranks.get('b.ts')!);
+  });
+
+  it('should distribute ranks to definitions', () => {
+    const graph = new Graph();
+
+    graph.addEdge({ from: 'a.ts', to: 'b.ts', weight: 1, ident: 'foo' });
+    graph.addEdge({ from: 'a.ts', to: 'b.ts', weight: 2, ident: 'bar' });
+
+    const nodeRanks = pagerank(graph);
+    const defRanks = distributeRanksToDefinitions(graph, nodeRanks);
+
+    expect(defRanks.has('b.ts:foo')).toBe(true);
+    expect(defRanks.has('b.ts:bar')).toBe(true);
+  });
+});
+
+describe('DiskCache', () => {
+  let cacheDir: string;
+  let cache: DiskCache<{ value: number }>;
+
+  beforeAll(async () => {
+    cacheDir = path.join(os.tmpdir(), `repomap-test-${Date.now()}`);
+    cache = new DiskCache(cacheDir);
+  });
+
+  it('should set and get values', async () => {
+    await cache.set('key1', { value: 42 });
+    const result = await cache.get('key1');
+
+    expect(result).not.toBeNull();
+    expect(result?.value).toBe(42);
+  });
+
+  it('should return null for missing keys', async () => {
+    const result = await cache.get('nonexistent');
+    expect(result).toBeNull();
+  });
+
+  it('should delete values', async () => {
+    await cache.set('key2', { value: 100 });
+    await cache.delete('key2');
+
+    const result = await cache.get('key2');
+    expect(result).toBeNull();
+  });
+
+  it('should check existence with has', async () => {
+    await cache.set('key3', { value: 200 });
+
+    expect(await cache.has('key3')).toBe(true);
+    expect(await cache.has('nonexistent2')).toBe(false);
+  });
+});
+
+describe('TagExtractor', () => {
+  let tempDir: string;
+
+  beforeAll(async () => {
+    tempDir = path.join(os.tmpdir(), `repomap-extractor-test-${Date.now()}`);
+    await fs.mkdir(tempDir, { recursive: true });
+  });
+
+  it('should extract tags from TypeScript code using regex fallback', async () => {
+    const testFile = path.join(tempDir, 'test.ts');
+    const code = `
+export function greet(name: string): string {
+  return \`Hello, \${name}!\`;
+}
+
+export class Greeter {
+  private name: string;
+
+  constructor(name: string) {
+    this.name = name;
+  }
+
+  greet(): string {
+    return \`Hello, \${this.name}!\`;
+  }
+}
+
+export interface IGreeter {
+  greet(): string;
+}
+`;
+
+    await fs.writeFile(testFile, code);
+
+    const extractor = new TagExtractor();
+    const tags = await extractor.getTags(testFile, 'test.ts');
+
+    // Some tags must have been extracted
+    expect(tags.length).toBeGreaterThan(0);
+
+    // There must be definitions among them
+    const defs = tags.filter((t) => t.kind === 'def');
+    expect(defs.length).toBeGreaterThan(0);
+
+    // The function and the class must both be present
+    const defNames = defs.map((t) => t.name);
+    expect(defNames).toContain('greet');
+    expect(defNames).toContain('Greeter');
+  });
+
+  it('should extract tags from Python code using regex fallback', async () => {
+    const testFile = path.join(tempDir, 'test.py');
+    const code = `
+def greet(name: str) -> str:
+    return f"Hello, {name}!"
+
+class Greeter:
+    def __init__(self, name: str):
+        self.name = name
+
+    def greet(self) -> str:
+        return f"Hello, {self.name}!"
+
+async def async_greet(name: str) -> str:
+    return f"Hello async, {name}!"
+`;
+
+    await fs.writeFile(testFile, code);
+
+    const extractor = new TagExtractor();
+    const tags = await extractor.getTags(testFile, 'test.py');
+
+    // Some tags must have been extracted
+    expect(tags.length).toBeGreaterThan(0);
+
+    // There must be definitions among them
+    const defs = tags.filter((t) => t.kind === 'def');
+    expect(defs.length).toBeGreaterThan(0);
+
+    // The functions and the class must all be present
+    const defNames = defs.map((t) => t.name);
+    expect(defNames).toContain('greet');
+    expect(defNames).toContain('Greeter');
+    expect(defNames).toContain('async_greet');
+  });
+});
+
+describe('RepoMap', () => {
+  let tempDir: string;
+
+  beforeAll(async () => {
+    tempDir = path.join(os.tmpdir(), `repomap-test-${Date.now()}`);
+    await fs.mkdir(tempDir, { recursive: true });
+
+    // Create the fixture files
+    await fs.writeFile(
+      path.join(tempDir, 'utils.ts'),
+      `
+export function add(a: number, b: number): number {
+  return a + b;
+}
+
+export function multiply(a: number, b: number): number {
+  return a * b;
+}
+`
+    );
+
+    await fs.writeFile(
+      path.join(tempDir, 'calculator.ts'),
+      `
+import { add, multiply } from './utils';
+
+export class Calculator {
+  add(a: number, b: number): number {
+    return add(a, b);
+  }
+
+  multiply(a: number, b: number): number {
+    return multiply(a, b);
+  }
+}
+`
+    );
+
+    await fs.writeFile(
+      path.join(tempDir, 'main.ts'),
+      `
+import { Calculator } from './calculator';
+
+const calc = new Calculator();
+console.log(calc.add(1, 2));
+console.log(calc.multiply(3, 4));
+`
+    );
+  });
+
+  it('should create a repo map', async () => {
+    const repoMap = createRepoMap(tempDir, {
+      mapTokens: 2048,
+    });
+
+    const files = [
+      path.join(tempDir, 'utils.ts'),
+      path.join(tempDir, 'calculator.ts'),
+      path.join(tempDir, 'main.ts'),
+    ];
+
+    const map = await repoMap.getRepoMap([], files, new Set(), new Set());
+
+    // Some content must have been generated
+    expect(map.length).toBeGreaterThan(0);
+  });
+
+  it('should boost mentioned identifiers', async () => {
+    const repoMap = createRepoMap(tempDir, {
+      mapTokens: 2048,
+    });
+
+    const files = [
+      path.join(tempDir, 'utils.ts'),
+      path.join(tempDir, 'calculator.ts'),
+      path.join(tempDir, 'main.ts'),
+    ];
+
+    const mentionedIdents = new Set(['Calculator']);
+    const map = await repoMap.getRepoMap([], files, new Set(), mentionedIdents);
+
+    // Intended to surface Calculator-related content; only non-emptiness is asserted
+    expect(map.length).toBeGreaterThan(0);
+  });
+
+  it('should exclude chat files from output', async () => {
+    const repoMap = createRepoMap(tempDir, {
+      mapTokens: 2048,
+    });
+
+    const chatFiles = [path.join(tempDir, 'utils.ts')];
+    const otherFiles = [
+      path.join(tempDir, 'calculator.ts'),
+      path.join(tempDir, 'main.ts'),
+    ];
+
+    const map = await repoMap.getRepoMap(chatFiles, otherFiles, new Set(), new Set());
+
+    // utils.ts must not appear in the output
+    expect(map).not.toContain('utils.ts:');
+  });
+});