feat: 添加 AST RepoMap 代码仓库地图功能

- 实现基于 Tree-sitter 的代码符号提取 (支持 TS/JS/Python)
- 实现 PageRank 算法进行符号相关性排序
- 支持个性化权重调整 (提及的标识符、聊天文件等)
- 添加磁盘缓存避免重复解析
- 集成 repo_map 工具到工具系统
- 添加 15 个单元测试
This commit is contained in:
2025-12-11 22:26:43 +08:00
parent 4beaf088d0
commit 9818e02ed1
19 changed files with 2352 additions and 0 deletions
+6
View File
@@ -24,3 +24,9 @@ npm-debug.log*
# Test coverage # Test coverage
coverage/ coverage/
# AI Open reference code
ai-open/
# Design docs (internal)
docs/
+218
View File
@@ -0,0 +1,218 @@
/**
* 磁盘缓存实现
* 使用 JSON 文件存储,支持按文件路径索引
*/
import * as fs from 'fs/promises';
import * as path from 'path';
import { createHash } from 'crypto';
export interface CacheEntry<T> {
key: string;
value: T;
timestamp: number;
}
/**
* 磁盘缓存类
*/
export class DiskCache<T> {
private cacheDir: string;
private memoryCache: Map<string, T> = new Map();
private dirty: Set<string> = new Set();
private initialized = false;
constructor(cacheDir: string) {
this.cacheDir = cacheDir;
}
/**
* 初始化缓存目录
*/
private async ensureDir(): Promise<void> {
if (this.initialized) return;
try {
await fs.mkdir(this.cacheDir, { recursive: true });
this.initialized = true;
} catch (error) {
console.warn(`Failed to create cache directory: ${this.cacheDir}`, error);
}
}
/**
* 生成缓存文件路径
*/
private getCacheFilePath(key: string): string {
// 使用哈希避免文件名过长或包含特殊字符
const hash = createHash('md5').update(key).digest('hex');
return path.join(this.cacheDir, `${hash}.json`);
}
/**
* 获取缓存值
*/
async get(key: string): Promise<T | null> {
// 先检查内存缓存
if (this.memoryCache.has(key)) {
return this.memoryCache.get(key)!;
}
await this.ensureDir();
// 从磁盘读取
const filePath = this.getCacheFilePath(key);
try {
const content = await fs.readFile(filePath, 'utf-8');
const entry: CacheEntry<T> = JSON.parse(content);
// 验证 key 匹配
if (entry.key === key) {
this.memoryCache.set(key, entry.value);
return entry.value;
}
} catch {
// 文件不存在或解析失败
}
return null;
}
/**
* 设置缓存值
*/
async set(key: string, value: T): Promise<void> {
this.memoryCache.set(key, value);
this.dirty.add(key);
// 立即写入磁盘(可以优化为批量写入)
await this.flush(key);
}
/**
* 删除缓存
*/
async delete(key: string): Promise<void> {
this.memoryCache.delete(key);
this.dirty.delete(key);
await this.ensureDir();
const filePath = this.getCacheFilePath(key);
try {
await fs.unlink(filePath);
} catch {
// 文件不存在
}
}
/**
* 检查缓存是否存在
*/
async has(key: string): Promise<boolean> {
if (this.memoryCache.has(key)) {
return true;
}
const value = await this.get(key);
return value !== null;
}
/**
* 刷新指定 key 到磁盘
*/
private async flush(key: string): Promise<void> {
if (!this.dirty.has(key)) return;
await this.ensureDir();
const value = this.memoryCache.get(key);
if (value === undefined) return;
const entry: CacheEntry<T> = {
key,
value,
timestamp: Date.now(),
};
const filePath = this.getCacheFilePath(key);
try {
await fs.writeFile(filePath, JSON.stringify(entry), 'utf-8');
this.dirty.delete(key);
} catch (error) {
console.warn(`Failed to write cache: ${key}`, error);
}
}
/**
* 刷新所有脏数据到磁盘
*/
async flushAll(): Promise<void> {
const keys = Array.from(this.dirty);
await Promise.all(keys.map((key) => this.flush(key)));
}
/**
* 清空所有缓存
*/
async clear(): Promise<void> {
this.memoryCache.clear();
this.dirty.clear();
try {
const files = await fs.readdir(this.cacheDir);
await Promise.all(
files
.filter((f) => f.endsWith('.json'))
.map((f) => fs.unlink(path.join(this.cacheDir, f)))
);
} catch {
// 目录不存在或其他错误
}
}
/**
* 获取缓存大小(条目数)
*/
size(): number {
return this.memoryCache.size;
}
/**
* 获取所有缓存的 keys
*/
async keys(): Promise<string[]> {
await this.ensureDir();
const result: string[] = Array.from(this.memoryCache.keys());
try {
const files = await fs.readdir(this.cacheDir);
for (const file of files) {
if (!file.endsWith('.json')) continue;
const filePath = path.join(this.cacheDir, file);
try {
const content = await fs.readFile(filePath, 'utf-8');
const entry: CacheEntry<T> = JSON.parse(content);
if (!result.includes(entry.key)) {
result.push(entry.key);
}
} catch {
// 跳过无效文件
}
}
} catch {
// 目录不存在
}
return result;
}
}
/**
* 创建磁盘缓存实例
*/
export function createDiskCache<T>(cacheDir: string): DiskCache<T> {
return new DiskCache<T>(cacheDir);
}
+6
View File
@@ -0,0 +1,6 @@
/**
* 缓存模块导出
*/
export { DiskCache, createDiskCache } from './disk-cache.js';
export type { CacheEntry } from './disk-cache.js';
+26
View File
@@ -0,0 +1,26 @@
/**
* RepoMap 模块
*
* 使用 AST 分析和 PageRank 算法生成代码仓库地图
* 参考 Aider 的实现
*/
// 主类
export { RepoMap, createRepoMap } from './repomap.js';
// Tag 提取
export { TagExtractor } from './tags/index.js';
// PageRank 排序
export {
Graph,
pagerank,
distributeRanksToDefinitions,
type PageRankOptions,
} from './ranking/index.js';
// 缓存
export { DiskCache, createDiskCache } from './cache/index.js';
// 类型
export type { Tag, TagCacheEntry, RepoMapConfig, GraphEdge } from './types.js';
+99
View File
@@ -0,0 +1,99 @@
/**
* 图数据结构
* 用于 PageRank 算法
*/
import type { GraphEdge } from '../types.js';
export class Graph {
/** 邻接表:from -> edges[] */
private outEdges: Map<string, GraphEdge[]> = new Map();
/** 反向邻接表:to -> edges[] */
private inEdges: Map<string, GraphEdge[]> = new Map();
/** 所有节点 */
private nodes: Set<string> = new Set();
/**
* 添加边
*/
addEdge(edge: GraphEdge): void {
this.nodes.add(edge.from);
this.nodes.add(edge.to);
// 出边
if (!this.outEdges.has(edge.from)) {
this.outEdges.set(edge.from, []);
}
this.outEdges.get(edge.from)!.push(edge);
// 入边
if (!this.inEdges.has(edge.to)) {
this.inEdges.set(edge.to, []);
}
this.inEdges.get(edge.to)!.push(edge);
}
/**
* 获取所有节点
*/
getNodes(): string[] {
return Array.from(this.nodes);
}
/**
* 获取节点的出边
*/
getOutEdges(node: string): GraphEdge[] {
return this.outEdges.get(node) || [];
}
/**
* 获取节点的入边
*/
getInEdges(node: string): GraphEdge[] {
return this.inEdges.get(node) || [];
}
/**
* 获取节点的出度(考虑权重)
*/
getOutDegree(node: string): number {
const edges = this.outEdges.get(node) || [];
return edges.reduce((sum, e) => sum + e.weight, 0);
}
/**
* 获取节点的入度(考虑权重)
*/
getInDegree(node: string): number {
const edges = this.inEdges.get(node) || [];
return edges.reduce((sum, e) => sum + e.weight, 0);
}
/**
* 获取边数量
*/
getEdgeCount(): number {
let count = 0;
for (const edges of this.outEdges.values()) {
count += edges.length;
}
return count;
}
/**
* 清空图
*/
clear(): void {
this.outEdges.clear();
this.inEdges.clear();
this.nodes.clear();
}
/**
* 是否为空
*/
isEmpty(): boolean {
return this.nodes.size === 0;
}
}
+7
View File
@@ -0,0 +1,7 @@
/**
* 排序模块导出
*/
export { Graph } from './graph.js';
export { pagerank, distributeRanksToDefinitions } from './pagerank.js';
export type { PageRankOptions } from './pagerank.js';
+146
View File
@@ -0,0 +1,146 @@
/**
* PageRank 算法实现
* 基于 Aider 的实现,用于代码符号相关性排序
*/
import { Graph } from './graph.js';
export interface PageRankOptions {
/** 阻尼系数 (默认 0.85) */
damping?: number;
/** 最大迭代次数 (默认 100) */
iterations?: number;
/** 收敛阈值 (默认 1e-6) */
tolerance?: number;
/** 个性化向量:节点 -> 初始权重 */
personalization?: Map<string, number>;
}
/**
* PageRank 算法
*
* @param graph - 图结构
* @param options - 算法选项
* @returns 节点排名 Map<节点, 排名值>
*/
export function pagerank(
graph: Graph,
options: PageRankOptions = {}
): Map<string, number> {
const {
damping = 0.85,
iterations = 100,
tolerance = 1e-6,
personalization,
} = options;
const nodes = graph.getNodes();
const n = nodes.length;
if (n === 0) {
return new Map();
}
// 初始化排名
let ranks = new Map<string, number>();
const baseRank = 1 / n;
// 处理个性化向量
let persVector = new Map<string, number>();
if (personalization && personalization.size > 0) {
// 归一化个性化向量
const total = Array.from(personalization.values()).reduce((a, b) => a + b, 0);
if (total > 0) {
for (const [node, value] of personalization) {
persVector.set(node, value / total);
}
}
} else {
// 均匀分布
for (const node of nodes) {
persVector.set(node, baseRank);
}
}
// 初始排名 = 个性化向量
for (const node of nodes) {
ranks.set(node, persVector.get(node) || baseRank);
}
// 迭代计算
for (let iter = 0; iter < iterations; iter++) {
const newRanks = new Map<string, number>();
let diff = 0;
// 计算悬挂节点的贡献(没有出边的节点)
let danglingSum = 0;
for (const node of nodes) {
const outEdges = graph.getOutEdges(node);
if (outEdges.length === 0) {
danglingSum += ranks.get(node) || 0;
}
}
for (const node of nodes) {
// 基础分数:(1 - damping) * 个性化 + damping * 悬挂贡献
let rank =
(1 - damping) * (persVector.get(node) || baseRank) +
(damping * danglingSum) / n;
// 收集入边贡献
const inEdges = graph.getInEdges(node);
for (const edge of inEdges) {
const sourceRank = ranks.get(edge.from) || 0;
const outDegree = graph.getOutDegree(edge.from);
if (outDegree > 0) {
// 边权重占源节点总出度的比例
rank += damping * sourceRank * (edge.weight / outDegree);
}
}
newRanks.set(node, rank);
diff += Math.abs(rank - (ranks.get(node) || 0));
}
ranks = newRanks;
// 检查收敛
if (diff < tolerance) {
break;
}
}
return ranks;
}
/**
* 将 PageRank 排名分配到定义上
* 按照 Aider 的方式:将源节点的排名按边权重比例分配给目标定义
*
* @param graph - 图结构
* @param nodeRanks - 节点 PageRank 排名
* @returns 定义排名 Map<"file:ident", rank>
*/
export function distributeRanksToDefinitions(
graph: Graph,
nodeRanks: Map<string, number>
): Map<string, number> {
const definitionRanks = new Map<string, number>();
for (const src of graph.getNodes()) {
const srcRank = nodeRanks.get(src) || 0;
const outEdges = graph.getOutEdges(src);
const totalWeight = outEdges.reduce((sum, e) => sum + e.weight, 0);
if (totalWeight === 0) continue;
for (const edge of outEdges) {
const edgeRank = (srcRank * edge.weight) / totalWeight;
const key = `${edge.to}:${edge.ident}`;
definitionRanks.set(key, (definitionRanks.get(key) || 0) + edgeRank);
}
}
return definitionRanks;
}
+419
View File
@@ -0,0 +1,419 @@
/**
* RepoMap 主类
* 使用 AST 分析和 PageRank 算法生成代码仓库地图
*/
import * as fs from 'fs/promises';
import * as path from 'path';
import { TagExtractor } from './tags/extractor.js';
import { pagerank, distributeRanksToDefinitions } from './ranking/pagerank.js';
import { Graph } from './ranking/graph.js';
import { DiskCache } from './cache/disk-cache.js';
import type { Tag, RepoMapConfig, TagCacheEntry } from './types.js';
/**
* RepoMap 配置默认值
*/
const defaultConfig: RepoMapConfig = {
mapTokens: 1024,
mapMulNoFiles: 8,
maxContextWindow: 128000,
refresh: 'auto',
cacheDir: '.ai-assist/tags-cache',
verbose: false,
exclude: [
'node_modules/**',
'dist/**',
'build/**',
'.git/**',
'*.test.*',
'*.spec.*',
'**/*.d.ts',
],
include: ['**/*.ts', '**/*.tsx', '**/*.js', '**/*.jsx', '**/*.py'],
};
/**
* RepoMap 类
* 生成代码仓库的上下文地图,帮助 AI 理解代码结构
*/
export class RepoMap {
private tagExtractor: TagExtractor;
private tagsCache: DiskCache<TagCacheEntry>;
private config: RepoMapConfig;
private root: string;
constructor(root: string, config: Partial<RepoMapConfig> = {}) {
this.root = root;
this.config = { ...defaultConfig, ...config };
this.tagExtractor = new TagExtractor();
this.tagsCache = new DiskCache(
path.join(this.root, this.config.cacheDir)
);
}
/**
* 获取 repo map
* @param chatFiles 当前对话中涉及的文件
* @param otherFiles 仓库中其他文件
* @param mentionedFnames 对话中提到的文件名
* @param mentionedIdents 对话中提到的标识符
*/
async getRepoMap(
chatFiles: string[],
otherFiles: string[],
mentionedFnames: Set<string> = new Set(),
mentionedIdents: Set<string> = new Set()
): Promise<string> {
if (this.config.mapTokens <= 0 || otherFiles.length === 0) {
return '';
}
let maxMapTokens = this.config.mapTokens;
// 无聊天文件时,给更大的视图
if (chatFiles.length === 0) {
maxMapTokens = Math.min(
maxMapTokens * this.config.mapMulNoFiles,
this.config.maxContextWindow - 4096
);
}
const rankedTags = await this.getRankedTags(
chatFiles,
otherFiles,
mentionedFnames,
mentionedIdents
);
// 二分搜索找到最优 token 数量
return this.fitToTokenLimit(rankedTags, maxMapTokens, new Set(chatFiles));
}
/**
* 获取排序后的 tags
*/
private async getRankedTags(
chatFnames: string[],
otherFnames: string[],
mentionedFnames: Set<string>,
mentionedIdents: Set<string>
): Promise<Tag[]> {
// ident -> files that define it
const defines = new Map<string, Set<string>>();
// ident -> files that reference it
const references = new Map<string, string[]>();
// (file:ident) -> tag objects
const definitions = new Map<string, Tag[]>();
// personalization vector for PageRank
const personalization = new Map<string, number>();
const allFnames = [...new Set([...chatFnames, ...otherFnames])];
const chatRelFnames = new Set(chatFnames.map((f) => this.getRelFname(f)));
const basePersonalize = 100 / Math.max(allFnames.length, 1);
// 收集所有文件的 tags
for (const fname of allFnames) {
const relFname = this.getRelFname(fname);
let currentPers = 0;
// 个性化权重
if (chatFnames.includes(fname)) {
currentPers += basePersonalize;
}
if (mentionedFnames.has(relFname)) {
currentPers = Math.max(currentPers, basePersonalize);
}
// 路径组件匹配
const pathParts = relFname.split('/');
const basename = pathParts[pathParts.length - 1];
const basenameNoExt = basename.replace(/\.[^.]+$/, '');
const components = new Set([...pathParts, basename, basenameNoExt]);
if ([...components].some((c) => mentionedIdents.has(c))) {
currentPers += basePersonalize;
}
if (currentPers > 0) {
personalization.set(relFname, currentPers);
}
// 提取 tags
const tags = await this.getTags(fname, relFname);
for (const tag of tags) {
if (tag.kind === 'def') {
if (!defines.has(tag.name)) defines.set(tag.name, new Set());
defines.get(tag.name)!.add(relFname);
const key = `${relFname}:${tag.name}`;
if (!definitions.has(key)) definitions.set(key, []);
definitions.get(key)!.push(tag);
} else {
if (!references.has(tag.name)) references.set(tag.name, []);
references.get(tag.name)!.push(relFname);
}
}
}
// 构建图
const graph = new Graph();
// 找到同时有定义和引用的标识符
const idents = [...defines.keys()].filter((id) => references.has(id));
for (const ident of idents) {
const definers = defines.get(ident)!;
const refs = references.get(ident)!;
// 计算权重乘数
let mul = 1.0;
if (mentionedIdents.has(ident)) mul *= 10;
if (this.isSignificantName(ident)) mul *= 10;
if (ident.startsWith('_')) mul *= 0.1;
if (definers.size > 5) mul *= 0.1;
// 统计引用次数
const refCounts = new Map<string, number>();
for (const ref of refs) {
refCounts.set(ref, (refCounts.get(ref) || 0) + 1);
}
// 添加边
for (const [referencer, numRefs] of refCounts) {
for (const definer of definers) {
let useMul = mul;
if (chatRelFnames.has(referencer)) useMul *= 50;
graph.addEdge({
from: referencer,
to: definer,
weight: useMul * Math.sqrt(numRefs),
ident,
});
}
}
}
// 运行 PageRank
const ranked = pagerank(graph, { personalization });
// 分配排名到定义
const rankedDefinitions = distributeRanksToDefinitions(graph, ranked);
// 排序
const sorted = [...rankedDefinitions.entries()].sort((a, b) => b[1] - a[1]);
// 收集排序后的 tags
const rankedTags: Tag[] = [];
for (const [key] of sorted) {
const colonIdx = key.lastIndexOf(':');
const fname = key.substring(0, colonIdx);
if (chatRelFnames.has(fname)) continue;
const tags = definitions.get(key) || [];
rankedTags.push(...tags);
}
return rankedTags;
}
/**
* 获取文件的 tags(带缓存)
*/
private async getTags(fname: string, relFname: string): Promise<Tag[]> {
const mtime = await this.getMtime(fname);
if (!mtime) return [];
// 检查缓存
const cached = await this.tagsCache.get(fname);
if (cached && cached.mtime === mtime) {
return cached.data;
}
// 提取 tags
const tags = await this.tagExtractor.getTags(fname, relFname);
// 更新缓存
await this.tagsCache.set(fname, { mtime, data: tags });
return tags;
}
/**
* 二分搜索拟合 token 限制
*/
private fitToTokenLimit(
tags: Tag[],
maxTokens: number,
chatRelFnames: Set<string>
): string {
const numTags = tags.length;
if (numTags === 0) return '';
let lower = 0;
let upper = numTags;
let bestTree = '';
let bestTokens = 0;
let middle = Math.min(Math.floor(maxTokens / 25), numTags);
while (lower <= upper) {
const tree = this.toTree(tags.slice(0, middle), chatRelFnames);
const numTokens = this.tokenCount(tree);
const pctErr = Math.abs(numTokens - maxTokens) / maxTokens;
if ((numTokens <= maxTokens && numTokens > bestTokens) || pctErr < 0.15) {
bestTree = tree;
bestTokens = numTokens;
if (pctErr < 0.15) break;
}
if (numTokens < maxTokens) {
lower = middle + 1;
} else {
upper = middle - 1;
}
middle = Math.floor((lower + upper) / 2);
}
return bestTree;
}
/**
* 转换为树形展示
*/
private toTree(tags: Tag[], chatRelFnames: Set<string>): string {
if (tags.length === 0) return '';
const output: string[] = [];
let curFname: string | null = null;
let lois: number[] = [];
let curNames: string[] = [];
// 按文件分组
const sortedTags = [...tags].sort(
(a, b) => a.relFname.localeCompare(b.relFname) || a.line - b.line
);
for (const tag of sortedTags) {
if (chatRelFnames.has(tag.relFname)) continue;
if (tag.relFname !== curFname) {
// 输出前一个文件
if (curFname && (lois.length > 0 || curNames.length > 0)) {
output.push(this.renderFileTree(curFname, lois, curNames));
}
curFname = tag.relFname;
lois = [];
curNames = [];
}
if (tag.line >= 0) {
lois.push(tag.line);
}
curNames.push(tag.name);
}
// 输出最后一个文件
if (curFname && (lois.length > 0 || curNames.length > 0)) {
output.push(this.renderFileTree(curFname, lois, curNames));
}
return output.join('\n');
}
/**
* 渲染单个文件的树形展示
*/
private renderFileTree(
relFname: string,
lois: number[],
names: string[]
): string {
const uniqueNames = [...new Set(names)];
const uniqueLines = [...new Set(lois)].sort((a, b) => a - b);
const lines = [`${relFname}:`];
// 简化版:显示符号列表
for (const name of uniqueNames.slice(0, 10)) {
lines.push(` - ${name}`);
}
if (uniqueNames.length > 10) {
lines.push(` ... and ${uniqueNames.length - 10} more`);
}
return lines.join('\n');
}
/**
* 判断是否是有意义的名称
*/
private isSignificantName(name: string): boolean {
if (name.length < 8) return false;
const hasAlpha = /[a-zA-Z]/.test(name);
const isSnake = name.includes('_') && hasAlpha;
const isKebab = name.includes('-') && hasAlpha;
const isCamel = /[a-z]/.test(name) && /[A-Z]/.test(name);
return isSnake || isKebab || isCamel;
}
/**
* 获取相对路径
*/
private getRelFname(fname: string): string {
if (fname.startsWith(this.root)) {
return fname.slice(this.root.length + 1);
}
return fname;
}
/**
* 获取文件修改时间
*/
private async getMtime(fname: string): Promise<number | null> {
try {
const stat = await fs.stat(fname);
return stat.mtimeMs;
} catch {
return null;
}
}
/**
* 估算 token 数量
* 简化估算:约 4 字符一个 token
*/
private tokenCount(text: string): number {
return Math.ceil(text.length / 4);
}
/**
* 刷新缓存到磁盘
*/
async flushCache(): Promise<void> {
await this.tagsCache.flushAll();
}
/**
* 清空缓存
*/
async clearCache(): Promise<void> {
await this.tagsCache.clear();
}
}
/**
* 创建 RepoMap 实例
*/
export function createRepoMap(
root: string,
config?: Partial<RepoMapConfig>
): RepoMap {
return new RepoMap(root, config);
}
+465
View File
@@ -0,0 +1,465 @@
/**
* Tag 提取器
* 使用 Tree-sitter 解析代码并提取符号定义和引用
*/
import * as fs from 'fs/promises';
import * as path from 'path';
import { fileURLToPath } from 'url';
import type { Tag } from '../types.js';
import { getLanguageFromFilename } from '../types.js';
// Tree-sitter 类型(web-tree-sitter
interface TreeSitterParser {
parse(input: string): TreeSitterTree;
setLanguage(language: TreeSitterLanguage): void;
getLanguage(): TreeSitterLanguage;
}
interface TreeSitterTree {
rootNode: TreeSitterNode;
}
interface TreeSitterNode {
text: string;
startPosition: { row: number; column: number };
endPosition: { row: number; column: number };
type: string;
childCount: number;
namedChildCount: number;
children: TreeSitterNode[];
namedChildren: TreeSitterNode[];
}
interface TreeSitterLanguage {
query(source: string): TreeSitterQuery;
}
interface TreeSitterQuery {
captures(node: TreeSitterNode): TreeSitterCapture[];
}
interface TreeSitterCapture {
node: TreeSitterNode;
name: string;
}
// 动态导入 web-tree-sitter
let ParserClass: any = null;
let treeSitterInitialized = false;
/**
* 初始化 Tree-sitter
*/
async function initTreeSitter(): Promise<void> {
if (treeSitterInitialized) return;
try {
const TreeSitter = await import('web-tree-sitter');
// Parser 是命名导出的类,init 是其静态方法
await TreeSitter.Parser.init();
ParserClass = TreeSitter.Parser;
treeSitterInitialized = true;
} catch (error) {
console.warn('Failed to initialize tree-sitter:', error);
throw new Error('Tree-sitter initialization failed');
}
}
/**
* Tag 提取器类
*/
export class TagExtractor {
private parsers: Map<string, TreeSitterParser> = new Map();
private languages: Map<string, TreeSitterLanguage> = new Map();
private queries: Map<string, TreeSitterQuery> = new Map();
private initialized = false;
private queriesDir: string;
constructor() {
// 获取查询文件目录
const __filename = fileURLToPath(import.meta.url);
const __dirname = path.dirname(__filename);
this.queriesDir = path.join(__dirname, 'queries');
}
/**
* 初始化提取器
*/
async initialize(): Promise<void> {
if (this.initialized) return;
try {
await initTreeSitter();
this.initialized = true;
} catch (error) {
// Tree-sitter 初始化失败,使用回退方案
console.warn('Tree-sitter not available, using regex fallback');
this.initialized = true;
}
}
/**
* 获取文件的所有 tags
*/
async getTags(fname: string, relFname: string): Promise<Tag[]> {
await this.initialize();
const lang = getLanguageFromFilename(fname);
if (!lang) return [];
let code: string;
try {
code = await fs.readFile(fname, 'utf-8');
} catch {
return [];
}
// 尝试使用 Tree-sitter
if (ParserClass) {
try {
const tags = await this.extractWithTreeSitter(code, fname, relFname, lang);
if (tags.length > 0) {
return tags;
}
} catch (error) {
// Tree-sitter 解析失败,回退到正则
}
}
// 回退到正则表达式
return this.extractWithRegex(code, fname, relFname, lang);
}
/**
* 使用 Tree-sitter 提取 tags
*/
private async extractWithTreeSitter(
code: string,
fname: string,
relFname: string,
lang: string
): Promise<Tag[]> {
const parser = await this.getParser(lang);
const query = await this.getQuery(lang);
if (!parser || !query) {
return [];
}
const tree = parser.parse(code);
const captures = query.captures(tree.rootNode);
const tags: Tag[] = [];
const seenKinds = new Set<string>();
for (const capture of captures) {
const { node, name } = capture;
let kind: 'def' | 'ref' | null = null;
if (name.startsWith('name.definition.')) {
kind = 'def';
} else if (name.startsWith('name.reference.')) {
kind = 'ref';
} else {
continue;
}
seenKinds.add(kind);
tags.push({
relFname,
fname,
name: node.text,
kind,
line: node.startPosition.row,
});
}
// 如果只有 def 没有 ref,使用正则回退提取引用
if (seenKinds.has('def') && !seenKinds.has('ref')) {
const refTags = this.extractRefsWithRegex(code, fname, relFname);
tags.push(...refTags);
}
return tags;
}
/**
* 获取或创建解析器
*/
private async getParser(lang: string): Promise<TreeSitterParser | null> {
if (this.parsers.has(lang)) {
return this.parsers.get(lang)!;
}
try {
const language = await this.getLanguage(lang);
if (!language) return null;
const parser = new ParserClass();
parser.setLanguage(language);
this.parsers.set(lang, parser);
return parser;
} catch {
return null;
}
}
/**
* 获取语言定义
*/
private async getLanguage(lang: string): Promise<TreeSitterLanguage | null> {
if (this.languages.has(lang)) {
return this.languages.get(lang)!;
}
try {
// 尝试加载 WASM 语言文件
const wasmPath = this.getWasmPath(lang);
const language = await ParserClass.Language.load(wasmPath);
this.languages.set(lang, language);
return language;
} catch {
return null;
}
}
/**
* 获取 WASM 文件路径
*/
private getWasmPath(lang: string): string {
// tree-sitter WASM 文件的标准位置
const wasmName = `tree-sitter-${lang}.wasm`;
// 尝试多个可能的位置
const possiblePaths = [
path.join(process.cwd(), 'node_modules', 'tree-sitter-wasms', 'out', wasmName),
path.join(process.cwd(), 'node_modules', `tree-sitter-${lang}`, wasmName),
path.join(__dirname, '..', '..', '..', 'wasm', wasmName),
];
// 返回第一个路径(实际使用时需要检查存在性)
return possiblePaths[0];
}
/**
* 获取语言查询
*/
private async getQuery(lang: string): Promise<TreeSitterQuery | null> {
if (this.queries.has(lang)) {
return this.queries.get(lang)!;
}
try {
const language = await this.getLanguage(lang);
if (!language) return null;
const queryPath = path.join(this.queriesDir, `${lang}-tags.scm`);
const queryText = await fs.readFile(queryPath, 'utf-8');
const query = language.query(queryText);
this.queries.set(lang, query);
return query;
} catch {
return null;
}
}
/**
* 使用正则表达式提取 tags (回退方案)
*/
private extractWithRegex(
code: string,
fname: string,
relFname: string,
lang: string
): Tag[] {
const tags: Tag[] = [];
const lines = code.split('\n');
// 根据语言选择正则模式
const patterns = this.getRegexPatterns(lang);
lines.forEach((line, lineNum) => {
for (const pattern of patterns.definitions) {
const match = line.match(pattern.regex);
if (match && match[pattern.nameGroup]) {
tags.push({
relFname,
fname,
name: match[pattern.nameGroup],
kind: 'def',
line: lineNum,
});
}
}
});
// 提取引用
tags.push(...this.extractRefsWithRegex(code, fname, relFname));
return tags;
}
/**
* 使用正则提取引用
*/
private extractRefsWithRegex(code: string, fname: string, relFname: string): Tag[] {
const tags: Tag[] = [];
// 简单的标识符匹配 - 排除关键字
const keywords = new Set([
'if',
'else',
'for',
'while',
'do',
'switch',
'case',
'break',
'continue',
'return',
'function',
'class',
'const',
'let',
'var',
'import',
'export',
'from',
'default',
'async',
'await',
'try',
'catch',
'finally',
'throw',
'new',
'this',
'super',
'extends',
'implements',
'interface',
'type',
'enum',
'public',
'private',
'protected',
'static',
'readonly',
'abstract',
'true',
'false',
'null',
'undefined',
'void',
'never',
'any',
'unknown',
'string',
'number',
'boolean',
'object',
'symbol',
'bigint',
'def',
'class',
'self',
'None',
'True',
'False',
'and',
'or',
'not',
'in',
'is',
'lambda',
'with',
'as',
'pass',
'raise',
'yield',
'global',
'nonlocal',
'assert',
'del',
]);
// 匹配 PascalCase 或 snake_case 标识符(更可能是用户定义的)
const identRegex = /\b([A-Z][a-zA-Z0-9]*|[a-z][a-zA-Z0-9]*_[a-zA-Z0-9_]*)\b/g;
let match;
while ((match = identRegex.exec(code)) !== null) {
const name = match[1];
if (!keywords.has(name) && name.length >= 2) {
tags.push({
relFname,
fname,
name,
kind: 'ref',
line: -1, // 行号未知
});
}
}
return tags;
}
/**
* 获取语言的正则模式
*/
private getRegexPatterns(lang: string): {
definitions: Array<{ regex: RegExp; nameGroup: number }>;
} {
switch (lang) {
case 'typescript':
case 'javascript':
return {
definitions: [
// function name(
{ regex: /(?:export\s+)?(?:async\s+)?function\s+(\w+)\s*[<(]/, nameGroup: 1 },
// class Name
{ regex: /(?:export\s+)?(?:abstract\s+)?class\s+(\w+)/, nameGroup: 1 },
// interface Name
{ regex: /(?:export\s+)?interface\s+(\w+)/, nameGroup: 1 },
// type Name =
{ regex: /(?:export\s+)?type\s+(\w+)\s*[<=]/, nameGroup: 1 },
// const/let/var name = (arrow function or function)
{
regex: /(?:export\s+)?(?:const|let|var)\s+(\w+)\s*=\s*(?:async\s+)?(?:\([^)]*\)|[^=])\s*=>/,
nameGroup: 1,
},
// enum Name
{ regex: /(?:export\s+)?enum\s+(\w+)/, nameGroup: 1 },
// method name(
{ regex: /^\s*(?:async\s+)?(\w+)\s*\([^)]*\)\s*[:{]/, nameGroup: 1 },
],
};
case 'python':
return {
definitions: [
// def name(
{ regex: /^\s*(?:async\s+)?def\s+(\w+)\s*\(/, nameGroup: 1 },
// class Name
{ regex: /^\s*class\s+(\w+)/, nameGroup: 1 },
],
};
default:
return {
definitions: [
// 通用:function/def/class 后面的标识符
{ regex: /(?:function|def|class)\s+(\w+)/, nameGroup: 1 },
],
};
}
}
}
/**
* 创建 Tag 提取器实例
*/
export function createTagExtractor(): TagExtractor {
return new TagExtractor();
}
+5
View File
@@ -0,0 +1,5 @@
/**
* Tags 模块导出
*/
export { TagExtractor, createTagExtractor } from './extractor.js';
@@ -0,0 +1,65 @@
;; JavaScript/JSX 标签查询
;; 定义使用 @name.definition.* 前缀
;; 引用使用 @name.reference.* 前缀
;; ==================== 定义 ====================
;; 函数声明
(function_declaration
name: (identifier) @name.definition.function) @definition.function
;; 生成器函数
(generator_function_declaration
name: (identifier) @name.definition.function) @definition.function
;; 箭头函数赋值 (const/let)
(lexical_declaration
(variable_declarator
name: (identifier) @name.definition.function
value: [(arrow_function) (function)])) @definition.function
;; 箭头函数赋值 (var)
(variable_declaration
(variable_declarator
name: (identifier) @name.definition.function
value: [(arrow_function) (function)])) @definition.function
;; 赋值表达式中的函数
(assignment_expression
left: (identifier) @name.definition.function
right: [(arrow_function) (function)]) @definition.function
(assignment_expression
left: (member_expression
property: (property_identifier) @name.definition.function)
right: [(arrow_function) (function)]) @definition.function
;; 对象方法
(pair
key: (property_identifier) @name.definition.function
value: [(arrow_function) (function)]) @definition.function
;; 方法定义
(method_definition
name: (property_identifier) @name.definition.method) @definition.method
;; 类定义
(class_declaration
name: (identifier) @name.definition.class) @definition.class
(class
name: (identifier) @name.definition.class) @definition.class
;; ==================== 引用 ====================
;; 函数调用
(call_expression
function: (identifier) @name.reference.call) @reference.call
(call_expression
function: (member_expression
property: (property_identifier) @name.reference.call)) @reference.call
;; new 表达式
(new_expression
constructor: (identifier) @name.reference.class) @reference.class
+37
View File
@@ -0,0 +1,37 @@
;; Python 标签查询
;; 定义使用 @name.definition.* 前缀
;; 引用使用 @name.reference.* 前缀
;; ==================== 定义 ====================
;; 函数定义
(function_definition
name: (identifier) @name.definition.function) @definition.function
;; 类定义
(class_definition
name: (identifier) @name.definition.class) @definition.class
;; 装饰器函数 (通常也是定义)
(decorated_definition
(function_definition
name: (identifier) @name.definition.function)) @definition.function
(decorated_definition
(class_definition
name: (identifier) @name.definition.class)) @definition.class
;; ==================== 引用 ====================
;; 函数调用
(call
function: (identifier) @name.reference.call) @reference.call
(call
function: (attribute
attribute: (identifier) @name.reference.call)) @reference.call
;; 类继承
(class_definition
superclasses: (argument_list
(identifier) @name.reference.class)) @reference.class
@@ -0,0 +1,90 @@
;; TypeScript/TSX 标签查询
;; 定义使用 @name.definition.* 前缀
;; 引用使用 @name.reference.* 前缀
;; ==================== 定义 ====================
;; 函数定义
(function_declaration
name: (identifier) @name.definition.function) @definition.function
(function_signature
name: (identifier) @name.definition.function) @definition.function
;; 箭头函数赋值
(lexical_declaration
(variable_declarator
name: (identifier) @name.definition.function
value: (arrow_function))) @definition.function
(variable_declaration
(variable_declarator
name: (identifier) @name.definition.function
value: (arrow_function))) @definition.function
;; 方法定义
(method_definition
name: (property_identifier) @name.definition.method) @definition.method
(method_signature
name: (property_identifier) @name.definition.method) @definition.method
(abstract_method_signature
name: (property_identifier) @name.definition.method) @definition.method
;; 类定义
(class_declaration
name: (type_identifier) @name.definition.class) @definition.class
(abstract_class_declaration
name: (type_identifier) @name.definition.class) @definition.class
;; 接口定义
(interface_declaration
name: (type_identifier) @name.definition.interface) @definition.interface
;; 类型别名
(type_alias_declaration
name: (type_identifier) @name.definition.type) @definition.type
;; 枚举
(enum_declaration
name: (identifier) @name.definition.enum) @definition.enum
;; 模块/命名空间
(module
name: (identifier) @name.definition.module) @definition.module
(internal_module
name: (identifier) @name.definition.module) @definition.module
;; ==================== 引用 ====================
;; 类型注解引用
(type_annotation
(type_identifier) @name.reference.type) @reference.type
;; 类型参数中的引用
(type_arguments
(type_identifier) @name.reference.type) @reference.type
;; extends/implements
(class_heritage
(extends_clause
value: (identifier) @name.reference.class)) @reference.class
(class_heritage
(implements_clause
(type_identifier) @name.reference.interface)) @reference.interface
;; new 表达式
(new_expression
constructor: (identifier) @name.reference.class) @reference.class
;; 函数调用
(call_expression
function: (identifier) @name.reference.call) @reference.call
(call_expression
function: (member_expression
property: (property_identifier) @name.reference.call)) @reference.call
+142
View File
@@ -0,0 +1,142 @@
/**
* RepoMap 类型定义
* 基于 Aider 的 RepoMap 实现
*/
/**
* 代码标签 - 对应 Aider 的 Tag
*/
export interface Tag {
/** 相对路径 */
relFname: string;
/** 绝对路径 */
fname: string;
/** 行号 (0-indexed, -1 表示未知) */
line: number;
/** 符号名称 */
name: string;
/** 定义或引用 */
kind: 'def' | 'ref';
}
/**
* 文件缓存条目
*/
export interface TagCacheEntry {
/** 文件修改时间 (ms) */
mtime: number;
/** 标签数据 */
data: Tag[];
}
/**
* RepoMap 配置
*/
export interface RepoMapConfig {
/** 目标 token 数量 */
mapTokens: number;
/** 无聊天文件时的乘数 */
mapMulNoFiles: number;
/** 最大上下文窗口 */
maxContextWindow: number;
/** 刷新策略 */
refresh: 'auto' | 'always' | 'files' | 'manual';
/** 缓存目录 */
cacheDir: string;
/** 是否详细输出 */
verbose: boolean;
/** 排除的文件模式 */
exclude: string[];
/** 包含的文件模式 */
include: string[];
}
/**
* 默认配置
*/
export const DEFAULT_REPOMAP_CONFIG: RepoMapConfig = {
mapTokens: 2048,
mapMulNoFiles: 8,
maxContextWindow: 128000,
refresh: 'auto',
cacheDir: '.ai-assist/tags-cache',
verbose: false,
exclude: [
'node_modules/**',
'dist/**',
'build/**',
'.git/**',
'*.test.*',
'*.spec.*',
'**/*.d.ts',
],
include: ['**/*.ts', '**/*.tsx', '**/*.js', '**/*.jsx', '**/*.py'],
};
/**
* 图边
*/
export interface GraphEdge {
/** 引用文件 */
from: string;
/** 定义文件 */
to: string;
/** 边权重 */
weight: number;
/** 符号名 */
ident: string;
}
/**
* 排序后的定义
*/
export interface RankedDefinition {
fname: string;
ident: string;
rank: number;
tags: Tag[];
}
/**
* PageRank 算法配置
*/
export interface PageRankOptions {
/** 阻尼系数 (默认 0.85) */
damping?: number;
/** 最大迭代次数 (默认 100) */
maxIterations?: number;
/** 收敛容差 (默认 1e-6) */
tolerance?: number;
/** 个性化向量 */
personalization?: Map<string, number>;
}
/**
* 支持的语言映射
*/
export const LANGUAGE_MAP: Record<string, string> = {
'.ts': 'typescript',
'.tsx': 'typescript',
'.js': 'javascript',
'.jsx': 'javascript',
'.mjs': 'javascript',
'.cjs': 'javascript',
'.py': 'python',
'.go': 'go',
'.rs': 'rust',
'.java': 'java',
'.rb': 'ruby',
'.cpp': 'cpp',
'.cc': 'cpp',
'.c': 'c',
'.h': 'c',
'.hpp': 'cpp',
};
/**
* 获取文件语言
*/
export function getLanguageFromFilename(filename: string): string | null {
const ext = filename.substring(filename.lastIndexOf('.')).toLowerCase();
return LANGUAGE_MAP[ext] || null;
}
+20
View File
@@ -0,0 +1,20 @@
Generate a repository map showing the most relevant code symbols (functions, classes, methods) based on AST analysis and PageRank ranking.
This tool analyzes the codebase structure using Tree-sitter AST parsing and ranks symbols by their relevance using the PageRank algorithm. It helps AI understand the code structure by showing:
- Important function and class definitions
- Their relationships through references
- Prioritized by relevance to the current context
Parameters:
- directory: Directory to analyze (default: current working directory)
- chat_files: Files currently being discussed (excluded from map output)
- mentioned_files: File names mentioned in conversation (boost relevance)
- mentioned_identifiers: Symbol names mentioned (boost relevance)
- max_tokens: Maximum output token count (default: 1024)
The output shows relevant code symbols organized by file, helping to understand:
- What functions/classes exist in the codebase
- How different parts of the code relate to each other
- Which symbols are most important based on usage patterns
Use this tool when you need to understand the overall structure of a codebase or find relevant code locations for a task.
+6
View File
@@ -46,6 +46,9 @@ import {
gitStashTool, gitStashTool,
} from './git/index.js'; } from './git/index.js';
// RepoMap 工具
import { repoMapTool } from './repomap/index.js';
// 所有工具列表(用于注册) // 所有工具列表(用于注册)
const allToolsWithMetadata: ToolWithMetadata[] = [ const allToolsWithMetadata: ToolWithMetadata[] = [
// 核心工具 (deferLoading: false) // 核心工具 (deferLoading: false)
@@ -88,6 +91,9 @@ const allToolsWithMetadata: ToolWithMetadata[] = [
gitPullTool, gitPullTool,
gitCheckoutTool, gitCheckoutTool,
gitStashTool, gitStashTool,
// RepoMap 工具 (deferLoading: true)
repoMapTool,
]; ];
// 注册所有工具到 registry // 注册所有工具到 registry
+5
View File
@@ -0,0 +1,5 @@
/**
* RepoMap 工具模块
*/
export { repoMapTool } from './repo_map.js';
+261
View File
@@ -0,0 +1,261 @@
/**
* RepoMap 工具
* 生成代码仓库的上下文地图,帮助 AI 理解代码结构
*/
import * as path from 'path';
import * as fs from 'fs/promises';
import type { ToolResult } from '../../types/index.js';
import type { ToolWithMetadata } from '../types.js';
import { loadDescription } from '../load_description.js';
import { getPermissionManager } from '../../permission/index.js';
import { createRepoMap, type RepoMapConfig } from '../../repomap/index.js';
// 缓存 RepoMap 实例
const repoMapCache = new Map<string, ReturnType<typeof createRepoMap>>();
/**
* 获取或创建 RepoMap 实例
*/
function getRepoMap(root: string, config?: Partial<RepoMapConfig>) {
const key = root;
if (!repoMapCache.has(key)) {
repoMapCache.set(key, createRepoMap(root, config));
}
return repoMapCache.get(key)!;
}
/**
* 支持的文件扩展名
*/
const SUPPORTED_EXTENSIONS = new Set([
'.ts',
'.tsx',
'.js',
'.jsx',
'.mjs',
'.cjs',
'.py',
]);
/**
* 要排除的目录
*/
const EXCLUDED_DIRS = new Set([
'node_modules',
'dist',
'build',
'.git',
'.next',
'__pycache__',
'.pytest_cache',
'coverage',
'.nyc_output',
]);
/**
* 要排除的文件模式
*/
function shouldExcludeFile(filename: string): boolean {
return (
filename.includes('.test.') ||
filename.includes('.spec.') ||
filename.endsWith('.d.ts')
);
}
/**
* 递归获取目录中的文件
*/
async function getFilesRecursive(
directory: string,
maxDepth = 10
): Promise<string[]> {
const files: string[] = [];
async function walk(dir: string, depth: number): Promise<void> {
if (depth > maxDepth) return;
try {
const entries = await fs.readdir(dir, { withFileTypes: true });
for (const entry of entries) {
const fullPath = path.join(dir, entry.name);
if (entry.isDirectory()) {
if (!EXCLUDED_DIRS.has(entry.name) && !entry.name.startsWith('.')) {
await walk(fullPath, depth + 1);
}
} else if (entry.isFile()) {
const ext = path.extname(entry.name).toLowerCase();
if (SUPPORTED_EXTENSIONS.has(ext) && !shouldExcludeFile(entry.name)) {
files.push(fullPath);
}
}
}
} catch {
// 忽略无法访问的目录
}
}
await walk(directory, 0);
return files;
}
export const repoMapTool: ToolWithMetadata = {
name: 'repo_map',
description: loadDescription('repo_map'),
metadata: {
name: 'repo_map',
category: 'core',
description: '生成代码仓库上下文地图,帮助理解代码结构',
keywords: [
'repomap',
'repo',
'map',
'context',
'code',
'structure',
'ast',
'symbol',
'definition',
'代码地图',
'仓库结构',
'代码结构',
'符号',
'定义',
],
deferLoading: true,
},
parameters: {
directory: {
type: 'string',
description: '要分析的目录路径(默认为当前工作目录)',
required: false,
},
chat_files: {
type: 'string',
description:
'当前对话中涉及的文件列表,逗号分隔(这些文件会被排除在地图外)',
required: false,
},
mentioned_files: {
type: 'string',
description: '对话中提到的文件名,逗号分隔(用于提高相关性)',
required: false,
},
mentioned_identifiers: {
type: 'string',
description: '对话中提到的标识符/符号名,逗号分隔(用于提高相关性)',
required: false,
},
max_tokens: {
type: 'number',
description: '最大输出 token 数(默认 1024',
required: false,
},
},
execute: async (params: Record<string, unknown>): Promise<ToolResult> => {
const cwd = process.cwd();
const directory = (params.directory as string) || cwd;
const absolutePath = path.isAbsolute(directory)
? directory
: path.join(cwd, directory);
// 解析逗号分隔的字符串为数组
const chatFilesStr = (params.chat_files as string) || '';
const chatFiles = chatFilesStr
? chatFilesStr.split(',').map((s) => s.trim())
: [];
const mentionedFilesStr = (params.mentioned_files as string) || '';
const mentionedFiles = new Set(
mentionedFilesStr ? mentionedFilesStr.split(',').map((s) => s.trim()) : []
);
const mentionedIdentsStr = (params.mentioned_identifiers as string) || '';
const mentionedIdents = new Set(
mentionedIdentsStr ? mentionedIdentsStr.split(',').map((s) => s.trim()) : []
);
const maxTokens = (params.max_tokens as number) || 1024;
// 权限检查
const permissionManager = getPermissionManager();
const permResult = await permissionManager.checkFilePermission({
operation: 'search',
path: absolutePath,
workdir: cwd,
});
if (!permResult.allowed) {
if (permResult.needsConfirmation) {
return {
success: false,
output: '',
error: `需要用户确认: 分析目录 ${absolutePath}\n原因: ${permResult.reason || '需要权限确认'}`,
};
}
return {
success: false,
output: '',
error: `权限被拒绝: ${permResult.reason || '不允许分析此目录'}`,
};
}
try {
// 检查目录是否存在
const stat = await fs.stat(absolutePath);
if (!stat.isDirectory()) {
return {
success: false,
output: '',
error: `${absolutePath} 不是一个目录`,
};
}
// 获取 RepoMap 实例
const repoMap = getRepoMap(absolutePath, {
mapTokens: maxTokens,
});
// 获取文件列表
const allFiles = await getFilesRecursive(absolutePath);
// 分离聊天文件和其他文件
const chatFilesAbs = chatFiles.map((f) =>
path.isAbsolute(f) ? f : path.join(absolutePath, f)
);
const otherFiles = allFiles.filter((f) => !chatFilesAbs.includes(f));
// 生成 repo map
const mapContent = await repoMap.getRepoMap(
chatFilesAbs,
otherFiles,
mentionedFiles,
mentionedIdents
);
if (!mapContent || mapContent.trim() === '') {
return {
success: true,
output:
'未找到相关的代码符号。可能是因为目录中没有支持的代码文件,或者文件过少。',
};
}
const header = `# Repository Map\n# Directory: ${absolutePath}\n# Files analyzed: ${allFiles.length}\n\n`;
return {
success: true,
output: header + mapContent,
};
} catch (error) {
return {
success: false,
output: '',
error: error instanceof Error ? error.message : String(error),
};
}
},
};
+329
View File
@@ -0,0 +1,329 @@
/**
* RepoMap 测试
*/
import { describe, it, expect, beforeAll } from 'vitest';
import * as path from 'path';
import * as fs from 'fs/promises';
import * as os from 'os';
import { createRepoMap } from '../../src/repomap/repomap.js';
import { TagExtractor } from '../../src/repomap/tags/extractor.js';
import { Graph, pagerank, distributeRanksToDefinitions } from '../../src/repomap/ranking/index.js';
import { DiskCache } from '../../src/repomap/cache/disk-cache.js';
describe('Graph', () => {
it('should add edges and track nodes', () => {
const graph = new Graph();
graph.addEdge({ from: 'a.ts', to: 'b.ts', weight: 1, ident: 'foo' });
graph.addEdge({ from: 'a.ts', to: 'c.ts', weight: 2, ident: 'bar' });
graph.addEdge({ from: 'b.ts', to: 'c.ts', weight: 1, ident: 'baz' });
const nodes = graph.getNodes();
expect(nodes).toContain('a.ts');
expect(nodes).toContain('b.ts');
expect(nodes).toContain('c.ts');
expect(nodes.length).toBe(3);
});
it('should track in and out edges correctly', () => {
const graph = new Graph();
graph.addEdge({ from: 'a.ts', to: 'b.ts', weight: 1, ident: 'foo' });
graph.addEdge({ from: 'a.ts', to: 'c.ts', weight: 2, ident: 'bar' });
const outEdges = graph.getOutEdges('a.ts');
expect(outEdges.length).toBe(2);
const inEdges = graph.getInEdges('b.ts');
expect(inEdges.length).toBe(1);
expect(inEdges[0].from).toBe('a.ts');
});
it('should calculate out degree correctly', () => {
const graph = new Graph();
graph.addEdge({ from: 'a.ts', to: 'b.ts', weight: 1, ident: 'foo' });
graph.addEdge({ from: 'a.ts', to: 'c.ts', weight: 2, ident: 'bar' });
expect(graph.getOutDegree('a.ts')).toBe(3); // sum of weights
expect(graph.getOutDegree('b.ts')).toBe(0);
});
});
describe('PageRank', () => {
it('should compute ranks for a simple graph', () => {
const graph = new Graph();
// 创建一个简单的图: a -> b -> c
graph.addEdge({ from: 'a.ts', to: 'b.ts', weight: 1, ident: 'foo' });
graph.addEdge({ from: 'b.ts', to: 'c.ts', weight: 1, ident: 'bar' });
const ranks = pagerank(graph);
expect(ranks.size).toBe(3);
expect(ranks.get('a.ts')).toBeGreaterThan(0);
expect(ranks.get('b.ts')).toBeGreaterThan(0);
expect(ranks.get('c.ts')).toBeGreaterThan(0);
});
it('should respect personalization vector', () => {
const graph = new Graph();
graph.addEdge({ from: 'a.ts', to: 'b.ts', weight: 1, ident: 'foo' });
graph.addEdge({ from: 'b.ts', to: 'c.ts', weight: 1, ident: 'bar' });
graph.addEdge({ from: 'c.ts', to: 'a.ts', weight: 1, ident: 'baz' });
// 高度偏好 a.ts
const personalization = new Map<string, number>();
personalization.set('a.ts', 100);
personalization.set('b.ts', 1);
personalization.set('c.ts', 1);
const ranks = pagerank(graph, { personalization });
// a.ts 应该有较高的 rank
expect(ranks.get('a.ts')!).toBeGreaterThan(ranks.get('b.ts')!);
});
it('should distribute ranks to definitions', () => {
const graph = new Graph();
graph.addEdge({ from: 'a.ts', to: 'b.ts', weight: 1, ident: 'foo' });
graph.addEdge({ from: 'a.ts', to: 'b.ts', weight: 2, ident: 'bar' });
const nodeRanks = pagerank(graph);
const defRanks = distributeRanksToDefinitions(graph, nodeRanks);
expect(defRanks.has('b.ts:foo')).toBe(true);
expect(defRanks.has('b.ts:bar')).toBe(true);
});
});
describe('DiskCache', () => {
let cacheDir: string;
let cache: DiskCache<{ value: number }>;
beforeAll(async () => {
cacheDir = path.join(os.tmpdir(), `repomap-test-${Date.now()}`);
cache = new DiskCache(cacheDir);
});
it('should set and get values', async () => {
await cache.set('key1', { value: 42 });
const result = await cache.get('key1');
expect(result).not.toBeNull();
expect(result?.value).toBe(42);
});
it('should return null for missing keys', async () => {
const result = await cache.get('nonexistent');
expect(result).toBeNull();
});
it('should delete values', async () => {
await cache.set('key2', { value: 100 });
await cache.delete('key2');
const result = await cache.get('key2');
expect(result).toBeNull();
});
it('should check existence with has', async () => {
await cache.set('key3', { value: 200 });
expect(await cache.has('key3')).toBe(true);
expect(await cache.has('nonexistent2')).toBe(false);
});
});
describe('TagExtractor', () => {
let tempDir: string;
beforeAll(async () => {
tempDir = path.join(os.tmpdir(), `repomap-extractor-test-${Date.now()}`);
await fs.mkdir(tempDir, { recursive: true });
});
it('should extract tags from TypeScript code using regex fallback', async () => {
const testFile = path.join(tempDir, 'test.ts');
const code = `
export function greet(name: string): string {
return \`Hello, \${name}!\`;
}
export class Greeter {
private name: string;
constructor(name: string) {
this.name = name;
}
greet(): string {
return \`Hello, \${this.name}!\`;
}
}
export interface IGreeter {
greet(): string;
}
`;
await fs.writeFile(testFile, code);
const extractor = new TagExtractor();
const tags = await extractor.getTags(testFile, 'test.ts');
// 检查是否提取到了一些 tags
expect(tags.length).toBeGreaterThan(0);
// 检查是否有定义
const defs = tags.filter((t) => t.kind === 'def');
expect(defs.length).toBeGreaterThan(0);
// 检查是否包含函数和类
const defNames = defs.map((t) => t.name);
expect(defNames).toContain('greet');
expect(defNames).toContain('Greeter');
});
it('should extract tags from Python code using regex fallback', async () => {
const testFile = path.join(tempDir, 'test.py');
const code = `
def greet(name: str) -> str:
return f"Hello, {name}!"
class Greeter:
def __init__(self, name: str):
self.name = name
def greet(self) -> str:
return f"Hello, {self.name}!"
async def async_greet(name: str) -> str:
return f"Hello async, {name}!"
`;
await fs.writeFile(testFile, code);
const extractor = new TagExtractor();
const tags = await extractor.getTags(testFile, 'test.py');
// 检查是否提取到了一些 tags
expect(tags.length).toBeGreaterThan(0);
// 检查是否有定义
const defs = tags.filter((t) => t.kind === 'def');
expect(defs.length).toBeGreaterThan(0);
// 检查是否包含函数和类
const defNames = defs.map((t) => t.name);
expect(defNames).toContain('greet');
expect(defNames).toContain('Greeter');
expect(defNames).toContain('async_greet');
});
});
describe('RepoMap', () => {
let tempDir: string;
beforeAll(async () => {
tempDir = path.join(os.tmpdir(), `repomap-test-${Date.now()}`);
await fs.mkdir(tempDir, { recursive: true });
// 创建测试文件
await fs.writeFile(
path.join(tempDir, 'utils.ts'),
`
export function add(a: number, b: number): number {
return a + b;
}
export function multiply(a: number, b: number): number {
return a * b;
}
`
);
await fs.writeFile(
path.join(tempDir, 'calculator.ts'),
`
import { add, multiply } from './utils';
export class Calculator {
add(a: number, b: number): number {
return add(a, b);
}
multiply(a: number, b: number): number {
return multiply(a, b);
}
}
`
);
await fs.writeFile(
path.join(tempDir, 'main.ts'),
`
import { Calculator } from './calculator';
const calc = new Calculator();
console.log(calc.add(1, 2));
console.log(calc.multiply(3, 4));
`
);
});
it('should create a repo map', async () => {
const repoMap = createRepoMap(tempDir, {
mapTokens: 2048,
});
const files = [
path.join(tempDir, 'utils.ts'),
path.join(tempDir, 'calculator.ts'),
path.join(tempDir, 'main.ts'),
];
const map = await repoMap.getRepoMap([], files, new Set(), new Set());
// 应该生成了一些内容
expect(map.length).toBeGreaterThan(0);
});
it('should boost mentioned identifiers', async () => {
const repoMap = createRepoMap(tempDir, {
mapTokens: 2048,
});
const files = [
path.join(tempDir, 'utils.ts'),
path.join(tempDir, 'calculator.ts'),
path.join(tempDir, 'main.ts'),
];
const mentionedIdents = new Set(['Calculator']);
const map = await repoMap.getRepoMap([], files, new Set(), mentionedIdents);
// 应该包含 Calculator 相关内容
expect(map.length).toBeGreaterThan(0);
});
it('should exclude chat files from output', async () => {
const repoMap = createRepoMap(tempDir, {
mapTokens: 2048,
});
const chatFiles = [path.join(tempDir, 'utils.ts')];
const otherFiles = [
path.join(tempDir, 'calculator.ts'),
path.join(tempDir, 'main.ts'),
];
const map = await repoMap.getRepoMap(chatFiles, otherFiles, new Set(), new Set());
// utils.ts 不应该出现在输出中
expect(map).not.toContain('utils.ts:');
});
});