-
Notifications
You must be signed in to change notification settings - Fork 49
Add Kotlin language support #592
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
6442af8
20d093b
692fd06
e5eda92
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,157 @@ | ||
| from pathlib import Path | ||
| from ...entities.entity import Entity | ||
| from ...entities.file import File | ||
| from typing import Optional | ||
| from ..analyzer import AbstractAnalyzer | ||
|
|
||
| from multilspy import SyncLanguageServer | ||
|
|
||
| import tree_sitter_kotlin as tskotlin | ||
| from tree_sitter import Language, Node | ||
|
|
||
| import logging | ||
| logger = logging.getLogger('code_graph') | ||
|
|
||
class KotlinAnalyzer(AbstractAnalyzer):
    """Analyzer that extracts classes, objects and functions from Kotlin sources.

    Entities are identified by tree-sitter node type. Symbols (base classes,
    implemented interfaces, calls, parameter types, return types) are collected
    per entity in ``add_symbols`` and later resolved through ``resolve_symbol``.
    The helpers ``_captures``, ``resolve`` and ``find_parent`` are not defined
    in this class; presumably they are inherited from ``AbstractAnalyzer``.
    """

    # Node types that are turned into entities.
    _ENTITY_TYPES = ('class_declaration', 'object_declaration', 'function_declaration')

    # Node types that own member functions / act as resolved type declarations.
    _TYPE_DECLARATIONS = ('class_declaration', 'object_declaration')

    # Directory names that mark build output or caches rather than project sources.
    _DEPENDENCY_DIRS = frozenset({'build', '.gradle', 'cache'})

    def __init__(self) -> None:
        super().__init__(Language(tskotlin.language()))

    def add_dependencies(self, path: Path, files: list[Path]):
        """Do nothing: dependency resolution is not implemented for Kotlin yet.

        In the future, this could parse build.gradle or pom.xml for Kotlin projects.
        """
        return None

    def get_entity_label(self, node: Node) -> str:
        """Return the graph label for an entity node.

        Returns:
            "Interface" or "Class" for ``class_declaration`` (interface when the
            node has a direct child of type ``interface``), "Object" for
            ``object_declaration``, and "Method" or "Function" for
            ``function_declaration`` (method when the direct parent is a
            ``class_body``).

        Raises:
            ValueError: if ``node`` is not one of the supported entity types.
        """
        if node.type == 'class_declaration':
            # Interfaces share the class_declaration node type; the keyword child tells them apart.
            if any(child.type == 'interface' for child in node.children):
                return "Interface"
            return "Class"
        if node.type == 'object_declaration':
            return "Object"
        if node.type == 'function_declaration':
            parent = node.parent
            if parent and parent.type == 'class_body':
                return "Method"
            return "Function"
        raise ValueError(f"Unknown entity type: {node.type}")

    def get_entity_name(self, node: Node) -> str:
        """Return the text of the first direct ``identifier`` child of an entity node.

        Raises:
            ValueError: if the node is not a supported entity type or has no
                direct ``identifier`` child.
        """
        if node.type in self._ENTITY_TYPES:
            for child in node.children:
                if child.type == 'identifier':
                    return child.text.decode('utf-8')
        raise ValueError(f"Cannot extract name from entity type: {node.type}")

    def get_entity_docstring(self, node: Node) -> Optional[str]:
        """Return the KDoc comment immediately preceding an entity, if any.

        Only a previous sibling of type ``multiline_comment`` whose text starts
        with ``/**`` counts; any other comment yields ``None``.

        Raises:
            ValueError: if ``node`` is not one of the supported entity types.
        """
        if node.type not in self._ENTITY_TYPES:
            raise ValueError(f"Unknown entity type: {node.type}")
        previous = node.prev_sibling
        if previous and previous.type == "multiline_comment":
            comment_text = previous.text.decode('utf-8')
            # A plain /* ... */ block is not documentation; require the KDoc opener.
            if comment_text.startswith('/**'):
                return comment_text
        return None

    def get_entity_types(self) -> list[str]:
        """Return the node types that are extracted as entities."""
        return list(self._ENTITY_TYPES)

    @staticmethod
    def _children_of_type(node: Node, node_type: str) -> list[Node]:
        """Return the direct children of ``node`` whose type equals ``node_type``."""
        return [child for child in node.children if child.type == node_type]

    def _get_delegation_types(self, entity: Entity) -> list[tuple[Node, bool]]:
        """Extract type identifiers from delegation specifiers in order.

        Returns list of (node, is_constructor_invocation) tuples.
        constructor_invocation indicates a superclass; plain user_type indicates an interface.
        """
        children_of_type = self._children_of_type
        types: list[tuple[Node, bool]] = []
        for specifiers in children_of_type(entity.node, 'delegation_specifiers'):
            for spec in children_of_type(specifiers, 'delegation_specifier'):
                for sub in spec.children:
                    if sub.type == 'constructor_invocation':
                        # `Base(...)`: the type name sits one level deeper, under user_type.
                        for user_type in children_of_type(sub, 'user_type'):
                            types.extend((id_node, True) for id_node in children_of_type(user_type, 'identifier'))
                    elif sub.type == 'user_type':
                        types.extend((id_node, False) for id_node in children_of_type(sub, 'identifier'))
        return types

    def add_symbols(self, entity: Entity) -> None:
        """Record the unresolved symbols referenced by ``entity``.

        Classes and objects get ``base_class`` / ``implement_interface`` symbols
        from their delegation specifiers. Functions get ``call``, ``parameters``
        and ``return_type`` symbols from tree-sitter queries.
        """
        node_type = entity.node.type
        if node_type in self._TYPE_DECLARATIONS:
            # Objects may extend a class too (`object Foo : Base()`), so they are
            # classified exactly like classes instead of always as interfaces.
            for node, is_class in self._get_delegation_types(entity):
                entity.add_symbol("base_class" if is_class else "implement_interface", node)

        elif node_type == 'function_declaration':
            # Find function calls
            captures = self._captures("(call_expression) @reference.call", entity.node)
            for caller in captures.get('reference.call', []) if hasattr(captures, 'get') else (captures['reference.call'] if 'reference.call' in captures else []):
                entity.add_symbol("call", caller)

            # Find parameters with types
            captures = self._captures("(parameter (user_type (identifier) @parameter))", entity.node)
            if 'parameter' in captures:
                for parameter in captures['parameter']:
                    entity.add_symbol("parameters", parameter)

            # Find return type
            captures = self._captures("(function_declaration (user_type (identifier) @return_type))", entity.node)
            if 'return_type' in captures:
                for return_type in captures['return_type']:
                    entity.add_symbol("return_type", return_type)

    def is_dependency(self, file_path: str) -> bool:
        """Return True when the file lives under a build-output or cache directory.

        Matches whole directory components (``build``, ``.gradle``, ``cache``)
        rather than substrings, so a directory such as ``rebuild`` is not
        mistaken for ``build``. The final path component (the file name) is
        not considered.
        """
        directories = Path(file_path).parts[:-1]
        return any(part in self._DEPENDENCY_DIRS for part in directories)

    def resolve_path(self, file_path: str, path: Path) -> str:
        """Return ``file_path`` unchanged; no Kotlin-specific remapping is done for now."""
        return file_path

    def resolve_type(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, node: Node) -> list[Entity]:
        """Resolve a type reference to the class/object entities that declare it.

        Each location returned by ``self.resolve`` is walked up to its enclosing
        class or object declaration; declarations that are not registered as
        entities of their file are skipped.
        """
        res = []
        for file, resolved_node in self.resolve(files, lsp, file_path, path, node):
            type_dec = self.find_parent(resolved_node, list(self._TYPE_DECLARATIONS))
            if type_dec in file.entities:
                res.append(file.entities[type_dec])
        return res

    def resolve_method(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, node: Node) -> list[Entity]:
        """Resolve a ``call_expression`` to the function entities it may invoke.

        Only the first direct child that is an ``identifier`` or
        ``navigation_expression`` is resolved. Results whose nearest enclosing
        declaration is a class or object (rather than a function) are dropped.
        Any other node type yields an empty list.
        """
        res = []
        if node.type != 'call_expression':
            return res
        for child in node.children:
            if child.type not in ('identifier', 'navigation_expression'):
                continue
            for file, resolved_node in self.resolve(files, lsp, file_path, path, child):
                method_dec = self.find_parent(resolved_node, ['function_declaration', 'class_declaration', 'object_declaration'])
                if method_dec and method_dec.type in self._TYPE_DECLARATIONS:
                    # Resolved to a type (e.g. a constructor call), not a function.
                    continue
                if method_dec in file.entities:
                    res.append(file.entities[method_dec])
            # The callee is the first matching child; later children are not callees.
            break
        return res

    def resolve_symbol(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, key: str, symbol: Node) -> list[Entity]:
        """Dispatch symbol resolution by symbol kind.

        Raises:
            ValueError: if ``key`` is not one of the kinds produced by ``add_symbols``.
        """
        if key in ("implement_interface", "base_class", "parameters", "return_type"):
            return self.resolve_type(files, lsp, file_path, path, symbol)
        if key == "call":
            return self.resolve_method(files, lsp, file_path, path, symbol)
        raise ValueError(f"Unknown key {key}")
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,10 +8,11 @@ | |
| from ..graph import Graph | ||
| from .analyzer import AbstractAnalyzer | ||
| # from .c.analyzer import CAnalyzer | ||
| from .java.analyzer import JavaAnalyzer | ||
| from .python.analyzer import PythonAnalyzer | ||
| from .csharp.analyzer import CSharpAnalyzer | ||
| from .java.analyzer import JavaAnalyzer | ||
| from .javascript.analyzer import JavaScriptAnalyzer | ||
| from .kotlin.analyzer import KotlinAnalyzer | ||
| from .python.analyzer import PythonAnalyzer | ||
|
|
||
| from multilspy import SyncLanguageServer | ||
| from multilspy.multilspy_config import MultilspyConfig | ||
|
|
@@ -28,7 +29,9 @@ | |
| '.py': PythonAnalyzer(), | ||
| '.java': JavaAnalyzer(), | ||
| '.cs': CSharpAnalyzer(), | ||
| '.js': JavaScriptAnalyzer()} | ||
| '.js': JavaScriptAnalyzer(), | ||
| '.kt': KotlinAnalyzer(), | ||
| '.kts': KotlinAnalyzer()} | ||
|
|
||
| class NullLanguageServer: | ||
| def start_server(self): | ||
|
|
@@ -145,8 +148,11 @@ def second_pass(self, graph: Graph, files: list[Path], path: Path) -> None: | |
| lsps[".cs"] = SyncLanguageServer.create(config, logger, str(path)) | ||
| else: | ||
| lsps[".cs"] = NullLanguageServer() | ||
| # For now, use NullLanguageServer for Kotlin as kotlin-language-server setup is not yet integrated | ||
| lsps[".kt"] = NullLanguageServer() | ||
| lsps[".kts"] = NullLanguageServer() | ||
| lsps[".js"] = NullLanguageServer() | ||
| with lsps[".java"].start_server(), lsps[".py"].start_server(), lsps[".cs"].start_server(), lsps[".js"].start_server(): | ||
| with lsps[".java"].start_server(), lsps[".py"].start_server(), lsps[".cs"].start_server(), lsps[".js"].start_server(), lsps[".kt"].start_server(), lsps[".kts"].start_server(): | ||
| files_len = len(self.files) | ||
| for i, file_path in enumerate(files): | ||
| if file_path not in self.files: | ||
|
|
@@ -158,31 +164,28 @@ def second_pass(self, graph: Graph, files: list[Path], path: Path) -> None: | |
| logging.info(f'Processing file ({i + 1}/{files_len}): {file_path}') | ||
| for _, entity in file.entities.items(): | ||
| entity.resolved_symbol(lambda key, symbol, fp=file_path: analyzers[fp.suffix].resolve_symbol(self.files, lsps[fp.suffix], fp, path, key, symbol)) | ||
| for key, symbols in entity.symbols.items(): | ||
| for symbol in symbols: | ||
| if len(symbol.resolved_symbol) == 0: | ||
| continue | ||
| resolved_symbol = next(iter(symbol.resolved_symbol)) | ||
| for key, resolved_set in entity.resolved_symbols.items(): | ||
| for resolved in resolved_set: | ||
| if key == "base_class": | ||
| graph.connect_entities("EXTENDS", entity.id, resolved_symbol.id) | ||
| graph.connect_entities("EXTENDS", entity.id, resolved.id) | ||
| elif key == "implement_interface": | ||
| graph.connect_entities("IMPLEMENTS", entity.id, resolved_symbol.id) | ||
| graph.connect_entities("IMPLEMENTS", entity.id, resolved.id) | ||
| elif key == "extend_interface": | ||
| graph.connect_entities("EXTENDS", entity.id, resolved_symbol.id) | ||
| graph.connect_entities("EXTENDS", entity.id, resolved.id) | ||
| elif key == "call": | ||
| graph.connect_entities("CALLS", entity.id, resolved_symbol.id, {"line": symbol.symbol.start_point.row, "text": symbol.symbol.text.decode("utf-8")}) | ||
| graph.connect_entities("CALLS", entity.id, resolved.id) | ||
| elif key == "return_type": | ||
| graph.connect_entities("RETURNS", entity.id, resolved_symbol.id) | ||
| graph.connect_entities("RETURNS", entity.id, resolved.id) | ||
| elif key == "parameters": | ||
| graph.connect_entities("PARAMETERS", entity.id, resolved_symbol.id) | ||
| graph.connect_entities("PARAMETERS", entity.id, resolved.id) | ||
|
|
||
| def analyze_files(self, files: list[Path], path: Path, graph: Graph) -> None: | ||
| self.first_pass(path, files, [], graph) | ||
| self.second_pass(graph, files, path) | ||
|
|
||
| def analyze_sources(self, path: Path, ignore: list[str], graph: Graph) -> None: | ||
| path = path.resolve() | ||
| files = list(path.rglob("*.java")) + list(path.rglob("*.py")) + list(path.rglob("*.cs")) + [f for f in path.rglob("*.js") if "node_modules" not in f.parts] | ||
| files = list(path.rglob("*.java")) + list(path.rglob("*.py")) + list(path.rglob("*.cs")) + [f for f in path.rglob("*.js") if "node_modules" not in f.parts] + list(path.rglob("*.kt")) + list(path.rglob("*.kts")) | ||
| # First pass analysis of the source code | ||
|
Comment on lines 186 to 189
|
||
| self.first_pass(path, files, ignore, graph) | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,31 +1,31 @@ | ||
| from typing import Callable, Self | ||
| from tree_sitter import Node | ||
|
|
||
| class Symbol: | ||
| def __init__(self, symbol: Node): | ||
| self.symbol = symbol | ||
| self.resolved_symbol = set() | ||
|
|
||
| def add_resolve_symbol(self, resolved_symbol): | ||
| self.resolved_symbol.add(resolved_symbol) | ||
|
|
||
| class Entity: | ||
| def __init__(self, node: Node): | ||
| self.node = node | ||
| self.symbols: dict[str, list[Symbol]] = {} | ||
| self.symbols: dict[str, list[Node]] = {} | ||
| self.resolved_symbols: dict[str, set[Self]] = {} | ||
| self.children: dict[Node, Self] = {} | ||
|
|
||
| def add_symbol(self, key: str, symbol: Node): | ||
| if key not in self.symbols: | ||
| self.symbols[key] = [] | ||
| self.symbols[key].append(Symbol(symbol)) | ||
| self.symbols[key].append(symbol) | ||
|
|
||
| def add_resolved_symbol(self, key: str, symbol: Self): | ||
| if key not in self.resolved_symbols: | ||
| self.resolved_symbols[key] = set() | ||
| self.resolved_symbols[key].add(symbol) | ||
|
|
||
| def add_child(self, child: Self): | ||
| child.parent = self | ||
| self.children[child.node] = child | ||
|
|
||
| def resolved_symbol(self, f: Callable[[str, Node], list[Self]]): | ||
| for key, symbols in self.symbols.items(): | ||
| self.resolved_symbols[key] = set() | ||
| for symbol in symbols: | ||
| for resolved_symbol in f(key, symbol.symbol): | ||
| symbol.add_resolve_symbol(resolved_symbol) | ||
| for resolved_symbol in f(key, symbol): | ||
| self.resolved_symbols[key].add(resolved_symbol) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,39 @@ | ||
/**
 * A base interface for logging
 */
interface Logger {
    // Single abstract member; Circle and AppConfig in this file both override it.
    fun log(message: String)
}
|
|
||
/**
 * Base class for shapes
 */
open class Shape(val name: String) {
    // Default area is 0.0; declared `open` so subclasses such as Circle can override it.
    open fun area(): Double = 0.0
}
|
|
||
class Circle(val radius: Double) : Shape("circle"), Logger {
    // Extends Shape through a constructor call and implements Logger as a plain type.

    // Area of a circle: pi * radius^2.
    override fun area(): Double {
        return Math.PI * radius * radius
    }

    // Writes the message to standard output unchanged.
    override fun log(message: String) {
        println(message)
    }
}
|
|
||
fun calculateTotal(shapes: List<Shape>): Double {
    // Sums area() over every shape; an empty list yields 0.0.
    var total = 0.0
    for (shape in shapes) {
        total += shape.area()
    }
    return total
}
|
|
||
object AppConfig : Logger {
    // Singleton object that implements Logger directly (no superclass constructor call).
    val version = "1.0"

    // Prefixes each message with the version in square brackets before printing.
    override fun log(message: String) {
        println("[$version] $message")
    }
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`CALLS` relationships no longer include call-site properties (previously `line` and `text` were stored). With the new `resolved_symbols` set-based structure, this also loses per-call-site granularity when the same callee is invoked multiple times. Consider preserving a mapping from each call node to its resolved entity(ies) so you can keep call-site metadata on edges.