Apply the requested search_path before executing the supplied SQL
authorKarl O. Pinc <kop@karlpinc.com>
Sun, 15 Sep 2024 19:42:37 +0000 (14:42 -0500)
committerKarl O. Pinc <kop@karlpinc.com>
Sun, 15 Sep 2024 19:42:37 +0000 (14:42 -0500)
src/pgwui_sql/views/sql.py

index d2a79d4f62a68266bb13f3e42cf16d2b7a515769..dc1389092acbe43958a2ddd11ea69d1c387b9020 100644 (file)
@@ -79,7 +79,8 @@ class SQLResultsHandler(pgwui_core.core.SessionDBHandler):
       cur
     '''
     sql_results = attrs.field(factory=list)
-    search_path = attrs.field(default=None)
+    search_path = attrs.field(default=None)     # requested search_path
+    db_search_path = attrs.field(default=None)  # search_path of db
 
     def make_form(self):
         return pgwui_sql.views.base.SQLForm().build(
@@ -146,6 +147,23 @@ class SQLResultsHandler(pgwui_core.core.SessionDBHandler):
             rows.append(ResultRow().build_data_row(row))
         return rows
 
+    def get_db_search_path(self):
+        self.cur.execute('SHOW search_path;')
+        sp = self.cur.fetchone()[0]
+        self.cur.fetchone()   # Exhaust results, get None
+        return sp
+
+    def set_search_path(self, sql):
+        '''Prepend code to set the requested search path to the supplied sql,
+        if the requested search path is not the default.
+
+        Returns: The adjusted sql
+        '''
+        if (self.search_path is None
+                or self.db_search_path == self.search_path):
+            return sql
+        return f'SET search_path TO {self.search_path};\n{sql}'
+
     def cleanup(self):
         '''
         Execute a series of SQL statements.
@@ -153,6 +171,14 @@ class SQLResultsHandler(pgwui_core.core.SessionDBHandler):
         interleaving errors with output.
         '''
         cur = self.cur
+
+        # Adjust the executed SQL to use the requested search_path
+        # Change the form content so that the user sees the change
+        # We can get away with this because this does not change
+        # the sql in the form that supplies the sql.
+        self.db_search_path = self.get_db_search_path()
+        self.uf['sql'] = self.set_search_path(self.uf['sql'])
+
         self.execute(cur, self.uf['sql'])
         if cur.statusmessage is None:
             raise sql_ex.NoStatementsError(descr='No SQL statements executed')