diff --git a/.tox/.pkg/file.lock b/.tox/.pkg/file.lock new file mode 100644 index 0000000..e69de29 diff --git "a/C:\\Users\\litecli\\litecli_test.db" "b/C:\\Users\\litecli\\litecli_test.db" new file mode 100644 index 0000000..e69de29 diff --git a/CHANGELOG.md b/CHANGELOG.md index afdb40f..e429af4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ ## Unreleased +### Bug Fixes + +- Expand `~` in configured log file paths before opening the log. + ### Internal - Add a GitHub Actions workflow to run Codex review on pull requests. diff --git a/litecli/main.py b/litecli/main.py index 8151ffa..fa732c3 100644 --- a/litecli/main.py +++ b/litecli/main.py @@ -12,7 +12,7 @@ from datetime import datetime from io import open from time import time -from typing import Any, Generator, Iterable, cast +from typing import Any, Generator, Iterable, Literal, TextIO, cast import click import sqlparse @@ -75,13 +75,13 @@ def __init__( self, sqlexecute: SQLExecute | None = None, prompt: str | None = None, - logfile: Any | None = None, + logfile: TextIO | None = None, auto_vertical_output: bool = False, warn: bool | None = None, liteclirc: str | None = None, ) -> None: self.sqlexecute = sqlexecute - self.logfile = logfile + self.logfile: TextIO | Literal[False] | None = logfile # Load config. c = self.config = get_config(liteclirc) @@ -249,6 +249,7 @@ def initialize_logging(self) -> None: log_file = self.config["main"]["log_file"] if log_file == "default": log_file = config_location() + "log" + log_file = os.path.expanduser(log_file) try: ensure_dir_exists(log_file) except OSError: @@ -472,7 +473,9 @@ def one_iteration(text: str | None = None) -> None: try: start = time() assert self.sqlexecute is not None - cur = self.sqlexecute.conn and self.sqlexecute.conn.cursor() + conn = self.sqlexecute.conn + assert conn is not None + cur = conn.cursor() context, sql, duration = special.handle_llm(text, cur) if context: click.echo("LLM Reponse:") @@ -534,7 +537,9 @@ def one_iteration(text: str | None = None) -> None: except KeyboardInterrupt: try: # since connection can be sqlite3 or sqlean, it's hard to annotate the type for interrupt. so ignore the type hint warning. - sqlexecute.conn.interrupt() # type: ignore[attr-defined] + conn = sqlexecute.conn + if conn is not None: + conn.interrupt() # type: ignore[attr-defined] except Exception as e: self.echo( "Encountered error while cancelling query: {}".format(e), @@ -791,7 +796,7 @@ def _on_completions_refreshed(self, new_completer: SQLCompleter) -> None: def get_completions(self, text: str, cursor_positition: int) -> Iterable[Completion]: with self._completer_lock: - return cast(Iterable[Completion], self.completer.get_completions(Document(text=text, cursor_position=cursor_positition), None)) + return self.completer.get_completions(Document(text=text, cursor_position=cursor_positition), None) def get_prompt(self, string: str) -> str: self.logger.debug("Getting prompt %r", string) @@ -933,7 +938,7 @@ def cli( database: str, dbname: str, prompt: str | None, - logfile: Any | None, + logfile: TextIO | None, auto_vertical_output: bool, table: bool, csv: bool, diff --git a/litecli/packages/completion_engine.py b/litecli/packages/completion_engine.py index 2a64d78..4083b4d 100644 --- a/litecli/packages/completion_engine.py +++ b/litecli/packages/completion_engine.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any +from typing import Any, cast import sqlparse from sqlparse.sql import Comparison, Identifier, Where, Token @@ -243,7 +243,7 @@ def is_operand(x: str | None) -> bool: return [{"type": "user"}] elif token_v in ("select", "where", "having"): # Check for a table alias or schema qualification - parent = (identifier and identifier.get_parent_name()) or [] + parent = _get_parent_name(identifier) tables = extract_tables(full_text) if parent: @@ -265,7 +265,7 @@ def is_operand(x: str | None) -> bool: elif (token_v.endswith("join") and isinstance(token, Token) and token.is_keyword) or ( token_v in ("copy", "from", "update", "into", "describe", "truncate", "desc", "explain") ): - schema = (identifier and identifier.get_parent_name()) or [] + schema = _get_parent_name(identifier) # Suggest tables from either the currently-selected schema or the # public schema if no schema has been specified @@ -284,14 +284,14 @@ def is_operand(x: str | None) -> bool: elif token_v in ("table", "view", "function"): # E.g. 'DROP FUNCTION ', 'ALTER TABLE ' rel_type = token_v - schema = (identifier and identifier.get_parent_name()) or [] + schema = _get_parent_name(identifier) if schema: return [{"type": rel_type, "schema": schema}] else: return [{"type": "schema"}, {"type": rel_type, "schema": []}] elif token_v == "on": tables = extract_tables(full_text) # [(schema, table, alias), ...] - parent = (identifier and identifier.get_parent_name()) or [] + parent = _get_parent_name(identifier) if parent: # "ON parent." # parent can be either a schema name or table alias @@ -333,3 +333,11 @@ def is_operand(x: str | None) -> bool: def identifies(id: Any, schema: str | None, table: str, alias: str | None) -> bool: return (id == alias) or (id == table) or (schema is not None and (id == schema + "." + table)) + + +def _get_parent_name(identifier: Identifier | None) -> str | list[str]: + if identifier is None: + return [] + + parent = identifier.get_parent_name() + return cast(str, parent) if parent else [] diff --git a/litecli/packages/special/__init__.py b/litecli/packages/special/__init__.py index 410f25e..5eddb63 100644 --- a/litecli/packages/special/__init__.py +++ b/litecli/packages/special/__init__.py @@ -3,14 +3,16 @@ from __future__ import annotations from types import FunctionType -from typing import Callable, Any +from typing import TypeVar __all__: list[str] = [] +_Exported = TypeVar("_Exported") -def export(defn: Callable[..., Any]) -> Callable[..., Any]: + +def export(defn: _Exported) -> _Exported: """Decorator to explicitly mark functions that are exposed in a lib.""" - # ty, requires explict check for callable of tyep | function type to access __name__ + # ty requires an explicit callable/type check to access __name__. if isinstance(defn, (type, FunctionType)): globals()[defn.__name__] = defn __all__.append(defn.__name__) diff --git a/litecli/packages/special/llm.py b/litecli/packages/special/llm.py index 11e71d5..e4391d7 100644 --- a/litecli/packages/special/llm.py +++ b/litecli/packages/special/llm.py @@ -264,7 +264,7 @@ def ensure_litecli_template(replace: bool = False) -> None: @export -def handle_llm(text: str, cur: DBCursor) -> tuple[str, str | None, float]: +def handle_llm(text: str, cur: DBCursor) -> tuple[str, str, float]: """This function handles the special command `\\llm`. If it deals with a question that results in a SQL query then it will return @@ -375,7 +375,7 @@ def sql_using_llm( cur: DBCursor, question: str | None = None, verbose: bool = False, -) -> tuple[str, str | None, str | None]: +) -> tuple[str, str, str | None]: if cur is None: raise RuntimeError("Connect to a datbase and try again.") schema_query = """ diff --git a/tests/test_main.py b/tests/test_main.py index 0a47c9b..21ad5b6 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,8 +1,10 @@ +import logging import os import shutil from collections import namedtuple from datetime import datetime from textwrap import dedent +from typing import Any, cast from unittest.mock import patch import click @@ -148,9 +150,8 @@ def output(monkeypatch, terminal_size, testdata, explicit_pager, expect_pager): class TestOutput: def get_size(self): - size = namedtuple("Size", "rows columns") - size.columns, size.rows = terminal_size - return size + Size = namedtuple("Size", "rows columns") + return Size(rows=terminal_size[1], columns=terminal_size[0]) class TestExecute: host = "test" @@ -165,7 +166,7 @@ class PromptBuffer(PromptSession): output = TestOutput() m.prompt_app = PromptBuffer() - m.sqlexecute = TestExecute() + m.sqlexecute = cast(Any, TestExecute()) m.explicit_pager = explicit_pager def echo_via_pager(s): @@ -232,18 +233,15 @@ def test_conditional_pager(monkeypatch): SPECIAL_COMMANDS["pager"].handler("") -def test_reserved_space_is_integer(): +def test_reserved_space_is_integer(monkeypatch): """Make sure that reserved space is returned as an integer.""" - def stub_terminal_size(): - return (5, 5) + def stub_terminal_size(fallback=(80, 24)): + return os.terminal_size((5, 5)) - old_func = shutil.get_terminal_size - - shutil.get_terminal_size = stub_terminal_size # type: ignore[assignment] + monkeypatch.setattr(shutil, "get_terminal_size", stub_terminal_size) lc = LiteCli() assert isinstance(lc.get_reserved_space(), int) - shutil.get_terminal_size = old_func @dbtest @@ -278,6 +276,33 @@ def test_startup_commands(executor): # implement tests on executions of the startupcommands +def test_initialize_logging_expands_user_log_file(monkeypatch, tmp_path): + home = tmp_path / "home" + log_file = home / ".cache" / "litecli" / "log" + monkeypatch.setenv("HOME", str(home)) + monkeypatch.setenv("USERPROFILE", str(home)) + + m = cast(Any, object.__new__(LiteCli)) + m.config = {"main": {"log_file": "~/.cache/litecli/log", "log_level": "INFO"}} + echo_messages = [] + m.echo = lambda *args, **kwargs: echo_messages.append((args, kwargs)) + + root_logger = logging.getLogger("litecli") + original_handlers = list(root_logger.handlers) + try: + m.initialize_logging() + + added_handlers = [handler for handler in root_logger.handlers if handler not in original_handlers] + assert log_file.exists() + assert not echo_messages + assert any(isinstance(handler, logging.FileHandler) and handler.baseFilename == str(log_file) for handler in added_handlers) + finally: + for handler in root_logger.handlers[:]: + if handler not in original_handlers: + root_logger.removeHandler(handler) + handler.close() + + @patch("litecli.main.datetime") # Adjust if your module path is different def test_get_prompt(mock_datetime): # We'll freeze time at 2025-01-20 13:37:42 for comedic effect.