Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added .tox/.pkg/file.lock
Empty file.
Empty file.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
## Unreleased

### Bug Fixes

- Expand `~` in configured log file paths before opening the log.

### Internal

- Add a GitHub Actions workflow to run Codex review on pull requests.
Expand Down
19 changes: 12 additions & 7 deletions litecli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from datetime import datetime
from io import open
from time import time
from typing import Any, Generator, Iterable, cast
from typing import Any, Generator, Iterable, Literal, TextIO, cast

import click
import sqlparse
Expand Down Expand Up @@ -75,13 +75,13 @@ def __init__(
self,
sqlexecute: SQLExecute | None = None,
prompt: str | None = None,
logfile: Any | None = None,
logfile: TextIO | None = None,
auto_vertical_output: bool = False,
warn: bool | None = None,
liteclirc: str | None = None,
) -> None:
self.sqlexecute = sqlexecute
self.logfile = logfile
self.logfile: TextIO | Literal[False] | None = logfile

# Load config.
c = self.config = get_config(liteclirc)
Expand Down Expand Up @@ -249,6 +249,7 @@ def initialize_logging(self) -> None:
log_file = self.config["main"]["log_file"]
if log_file == "default":
log_file = config_location() + "log"
log_file = os.path.expanduser(log_file)
try:
ensure_dir_exists(log_file)
except OSError:
Expand Down Expand Up @@ -472,7 +473,9 @@ def one_iteration(text: str | None = None) -> None:
try:
start = time()
assert self.sqlexecute is not None
cur = self.sqlexecute.conn and self.sqlexecute.conn.cursor()
conn = self.sqlexecute.conn
assert conn is not None
cur = conn.cursor()
context, sql, duration = special.handle_llm(text, cur)
if context:
click.echo("LLM Reponse:")
Expand Down Expand Up @@ -534,7 +537,9 @@ def one_iteration(text: str | None = None) -> None:
except KeyboardInterrupt:
try:
# since connection can be sqlite3 or sqlean, it's hard to annotate the type for interrupt. so ignore the type hint warning.
sqlexecute.conn.interrupt() # type: ignore[attr-defined]
conn = sqlexecute.conn
if conn is not None:
conn.interrupt() # type: ignore[attr-defined]
except Exception as e:
self.echo(
"Encountered error while cancelling query: {}".format(e),
Expand Down Expand Up @@ -791,7 +796,7 @@ def _on_completions_refreshed(self, new_completer: SQLCompleter) -> None:

def get_completions(self, text: str, cursor_positition: int) -> Iterable[Completion]:
with self._completer_lock:
return cast(Iterable[Completion], self.completer.get_completions(Document(text=text, cursor_position=cursor_positition), None))
return self.completer.get_completions(Document(text=text, cursor_position=cursor_positition), None)

def get_prompt(self, string: str) -> str:
self.logger.debug("Getting prompt %r", string)
Expand Down Expand Up @@ -933,7 +938,7 @@ def cli(
database: str,
dbname: str,
prompt: str | None,
logfile: Any | None,
logfile: TextIO | None,
auto_vertical_output: bool,
table: bool,
csv: bool,
Expand Down
18 changes: 13 additions & 5 deletions litecli/packages/completion_engine.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any
from typing import Any, cast

import sqlparse
from sqlparse.sql import Comparison, Identifier, Where, Token
Expand Down Expand Up @@ -243,7 +243,7 @@ def is_operand(x: str | None) -> bool:
return [{"type": "user"}]
elif token_v in ("select", "where", "having"):
# Check for a table alias or schema qualification
parent = (identifier and identifier.get_parent_name()) or []
parent = _get_parent_name(identifier)

tables = extract_tables(full_text)
if parent:
Expand All @@ -265,7 +265,7 @@ def is_operand(x: str | None) -> bool:
elif (token_v.endswith("join") and isinstance(token, Token) and token.is_keyword) or (
token_v in ("copy", "from", "update", "into", "describe", "truncate", "desc", "explain")
):
schema = (identifier and identifier.get_parent_name()) or []
schema = _get_parent_name(identifier)

# Suggest tables from either the currently-selected schema or the
# public schema if no schema has been specified
Expand All @@ -284,14 +284,14 @@ def is_operand(x: str | None) -> bool:
elif token_v in ("table", "view", "function"):
# E.g. 'DROP FUNCTION <funcname>', 'ALTER TABLE <tablname>'
rel_type = token_v
schema = (identifier and identifier.get_parent_name()) or []
schema = _get_parent_name(identifier)
if schema:
return [{"type": rel_type, "schema": schema}]
else:
return [{"type": "schema"}, {"type": rel_type, "schema": []}]
elif token_v == "on":
tables = extract_tables(full_text) # [(schema, table, alias), ...]
parent = (identifier and identifier.get_parent_name()) or []
parent = _get_parent_name(identifier)
if parent:
# "ON parent.<suggestion>"
# parent can be either a schema name or table alias
Expand Down Expand Up @@ -333,3 +333,11 @@ def is_operand(x: str | None) -> bool:

def identifies(id: Any, schema: str | None, table: str, alias: str | None) -> bool:
return (id == alias) or (id == table) or (schema is not None and (id == schema + "." + table))


def _get_parent_name(identifier: Identifier | None) -> str | list[str]:
if identifier is None:
return []

parent = identifier.get_parent_name()
return cast(str, parent) if parent else []
8 changes: 5 additions & 3 deletions litecli/packages/special/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,16 @@
from __future__ import annotations
from types import FunctionType

from typing import Callable, Any
from typing import TypeVar

__all__: list[str] = []

_Exported = TypeVar("_Exported")

def export(defn: Callable[..., Any]) -> Callable[..., Any]:

def export(defn: _Exported) -> _Exported:
"""Decorator to explicitly mark functions that are exposed in a lib."""
# ty, requires explict check for callable of tyep | function type to access __name__
# ty requires an explicit callable/type check to access __name__.
if isinstance(defn, (type, FunctionType)):
globals()[defn.__name__] = defn
__all__.append(defn.__name__)
Expand Down
4 changes: 2 additions & 2 deletions litecli/packages/special/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def ensure_litecli_template(replace: bool = False) -> None:


@export
def handle_llm(text: str, cur: DBCursor) -> tuple[str, str | None, float]:
def handle_llm(text: str, cur: DBCursor) -> tuple[str, str, float]:
"""This function handles the special command `\\llm`.

If it deals with a question that results in a SQL query then it will return
Expand Down Expand Up @@ -375,7 +375,7 @@ def sql_using_llm(
cur: DBCursor,
question: str | None = None,
verbose: bool = False,
) -> tuple[str, str | None, str | None]:
) -> tuple[str, str, str | None]:
if cur is None:
raise RuntimeError("Connect to a datbase and try again.")
schema_query = """
Expand Down
47 changes: 36 additions & 11 deletions tests/test_main.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import logging
import os
import shutil
from collections import namedtuple
from datetime import datetime
from textwrap import dedent
from typing import Any, cast
from unittest.mock import patch

import click
Expand Down Expand Up @@ -148,9 +150,8 @@ def output(monkeypatch, terminal_size, testdata, explicit_pager, expect_pager):

class TestOutput:
def get_size(self):
size = namedtuple("Size", "rows columns")
size.columns, size.rows = terminal_size
return size
Size = namedtuple("Size", "rows columns")
return Size(rows=terminal_size[1], columns=terminal_size[0])

class TestExecute:
host = "test"
Expand All @@ -165,7 +166,7 @@ class PromptBuffer(PromptSession):
output = TestOutput()

m.prompt_app = PromptBuffer()
m.sqlexecute = TestExecute()
m.sqlexecute = cast(Any, TestExecute())
m.explicit_pager = explicit_pager

def echo_via_pager(s):
Expand Down Expand Up @@ -232,18 +233,15 @@ def test_conditional_pager(monkeypatch):
SPECIAL_COMMANDS["pager"].handler("")


def test_reserved_space_is_integer():
def test_reserved_space_is_integer(monkeypatch):
"""Make sure that reserved space is returned as an integer."""

def stub_terminal_size():
return (5, 5)
def stub_terminal_size(fallback=(80, 24)):
return os.terminal_size((5, 5))

old_func = shutil.get_terminal_size

shutil.get_terminal_size = stub_terminal_size # type: ignore[assignment]
monkeypatch.setattr(shutil, "get_terminal_size", stub_terminal_size)
lc = LiteCli()
assert isinstance(lc.get_reserved_space(), int)
shutil.get_terminal_size = old_func


@dbtest
Expand Down Expand Up @@ -278,6 +276,33 @@ def test_startup_commands(executor):
# implement tests on executions of the startupcommands


def test_initialize_logging_expands_user_log_file(monkeypatch, tmp_path):
home = tmp_path / "home"
log_file = home / ".cache" / "litecli" / "log"
monkeypatch.setenv("HOME", str(home))
monkeypatch.setenv("USERPROFILE", str(home))

m = cast(Any, object.__new__(LiteCli))
m.config = {"main": {"log_file": "~/.cache/litecli/log", "log_level": "INFO"}}
echo_messages = []
m.echo = lambda *args, **kwargs: echo_messages.append((args, kwargs))

root_logger = logging.getLogger("litecli")
original_handlers = list(root_logger.handlers)
try:
m.initialize_logging()

added_handlers = [handler for handler in root_logger.handlers if handler not in original_handlers]
assert log_file.exists()
assert not echo_messages
assert any(isinstance(handler, logging.FileHandler) and handler.baseFilename == str(log_file) for handler in added_handlers)
finally:
for handler in root_logger.handlers[:]:
if handler not in original_handlers:
root_logger.removeHandler(handler)
handler.close()


@patch("litecli.main.datetime") # Adjust if your module path is different
def test_get_prompt(mock_datetime):
# We'll freeze time at 2025-01-20 13:37:42 for comedic effect.
Expand Down
Loading