Support downloading one file per result set as separate files of a zip file
authorKarl O. Pinc <kop@karlpinc.com>
Wed, 2 Oct 2024 19:45:19 +0000 (14:45 -0500)
committerKarl O. Pinc <kop@karlpinc.com>
Wed, 2 Oct 2024 19:47:47 +0000 (14:47 -0500)
src/pgwui_sql/views/sql.py

index b5d583d322f2a4e0486f8e2c936a23f4f916c73f..007dfe6f65541f145d54e0a84f1a038b4f9f9ce1 100644 (file)
@@ -31,6 +31,7 @@ import pyramid.response
 import sys
 import tempfile
 import wtforms.fields
+import zipfile
 
 import pgwui_core.core
 import pgwui_core.utils
@@ -152,6 +153,8 @@ class SQLResultsHandler(pgwui_core.core.SessionDBHandler):
     search_path = attrs.field(default=None)     # requested search_path
     db_search_path = attrs.field(default=None)  # search_path of db
     tfile = attrs.field(default=None)
+    zip_fd = attrs.field(default=None)
+    dl_filename = attrs.field(default=None)
 
     def make_form(self):
         return SQLForm().build(self, ip=SQLInitialPost(), fc=SQLWTForm)
@@ -159,8 +162,9 @@ class SQLResultsHandler(pgwui_core.core.SessionDBHandler):
     def read(self):
         super().read()
         self.search_path = self.request.POST.get('search_path')
+        self.dl_filename = self.make_dl_filename()
 
-    def dl_filename(self):
+    def make_dl_filename(self):
         uf = self.uf
         return '_'.join(
             ['sql_results',
@@ -169,6 +173,11 @@ class SQLResultsHandler(pgwui_core.core.SessionDBHandler):
              '_'.join(datetime.datetime.now().isoformat(
                       sep="_", timespec="seconds").split(':'))])
 
+    def zip_at_pathname(self, suffix):
+        '''Return a value suitable for a zipfile.Path() "at" key's value
+        '''
+        return f'{self.dl_filename}/{self.dl_filename}_{suffix}'
+
     def write(self, result, errors):
         '''
         Setup dict to render resulting html form
@@ -186,7 +195,7 @@ class SQLResultsHandler(pgwui_core.core.SessionDBHandler):
         response['report_success'] = (not response['errors']
                                       and self.uf['action'] != '')
         if self.uf['download']:
-            response['dl_filename'] = self.dl_filename()
+            response['dl_filename'] = self.dl_filename
         return response
 
     def val_input(self):
@@ -270,38 +279,85 @@ class SQLResultsHandler(pgwui_core.core.SessionDBHandler):
             return sql
         return f'SET search_path TO {self.search_path};\n{sql}'
 
-    def make_csv_writer(self):
+    def open_tfile(self):
+        '''Open the file to be downloaded and save it in self.tfile
+        '''
+        if self.uf['download_as'] == MANY_FILES_VALUE:
+            self.tfile = tempfile.TemporaryFile(mode='w+b')
+            self.zip_fd = zipfile.ZipFile(self.tfile, mode='a')
+        else:
+            self.tfile = tempfile.TemporaryFile(mode='w+t', newline='')
+
+    def make_csv_writer(self, fd=None):
+        if not fd:
+            fd = self.tfile
         vinfo = sys.version_info
         if self.uf['download_fmt'] == CSV:
             if vinfo.major >= 3 and vinfo.minor >= 12:
                 quoting = csv.QUOTE_STRINGS
             else:
                 quoting = csv.QUOTE_NONNUMERIC
-            return csv.writer(self.tfile, quoting=quoting)
+            return csv.writer(fd, quoting=quoting)
         else:
-            return csv.writer(self.tfile, dialect=csv.excel_tab)
+            return csv.writer(fd, dialect=csv.excel_tab)
+
+    def write_sql(self):
+        if self.uf['include_sql']:
+            if self.uf['download_as'] == MANY_FILES_VALUE:
+                with zipfile.Path(
+                    self.zip_fd,
+                    at=self.zip_at_pathname('statements.txt')
+                ).open(mode='wb') as fd:
+                    fd.write(self.uf['sql'].encode())
+                return None
+            else:
+                writer = self.make_csv_writer()
+                # Strip trailing whitespace from sql because otherwise,
+                # after import to a spreadsheet with cells a single row tall,
+                # only the empty line after the last EOL is shown and the
+                # first cell of the sheet looks empty instead of looking like
+                # it contains the sql.
+                writer.writerow((self.uf['sql'].rstrip(),))
+                return writer
+
+    def write_resultset(self, cur, writer, null_rep):
+        # Rather than report the statusmessage first, which requires
+        # putting all the statement's results in RAM, report it last.
+        if cur.rownumber is not None:
+            writer.writerow(ResultRow().build_heading_row(cur).data)
+            for row in self.get_result_rows(cur, null_rep):
+                writer.writerow(row.data)
+        writer.writerow((ResultRow().build_rowcount_row(cur).data,))
+        writer.writerow(
+            (ResultRow().build_statusmessage_row(cur).data,))
 
     def make_download(self, cur):
         # Optimized to minimize RAM usage
         null_rep = self.uf['null_rep']
-        self.tfile = tempfile.TemporaryFile(mode='w+t', newline='')
-        writer = self.make_csv_writer()
-        if self.uf['include_sql']:
-            writer.writerow((self.uf['sql'].rstrip(),))
+        self.open_tfile()
+        if self.uf['download_as'] == MANY_FILES_VALUE:
+            self.zip_fd.mkdir(self.dl_filename)
+        writer = self.write_sql()
 
         nextset = True
+        download_as = self.uf['download_as']
+        stmt_no = 1
         while nextset is True:
-            # Rather than report the statusmessage first, which requires
-            # putting all the statement's results in RAM, report it last.
-            if cur.rownumber is not None:
-                writer.writerow(ResultRow().build_heading_row(cur).data)
-                for row in self.get_result_rows(cur, null_rep):
-                    writer.writerow(row.data)
-            writer.writerow((ResultRow().build_rowcount_row(cur).data,))
-            writer.writerow((ResultRow().build_statusmessage_row(cur).data,))
+            if download_as == MANY_FILES_VALUE:
+                with zipfile.Path(
+                        self.zip_fd,
+                        at=self.zip_at_pathname(f'stmt{stmt_no}.txt')
+                ).open(mode='w', newline='') as fd:
+                    self.write_resultset(
+                        cur, self.make_csv_writer(fd), null_rep)
+                stmt_no += 1
+            else:
+                self.write_resultset(cur, writer, null_rep)
 
             nextset = cur.nextset()
 
+        if download_as == MANY_FILES_VALUE:
+            self.zip_fd.close()
         self.tfile.seek(0)
 
     def make_sql_results(self, cur):
@@ -341,13 +397,12 @@ class SQLResultsHandler(pgwui_core.core.SessionDBHandler):
 
         try:
             if self.uf['download']:
-                try:
-                    self.make_download(cur)
-                except csv.Error as ex:
-                    raise sql_ex.CSVError(
-                        descr=f'The csv module reports: {ex}')
+                self.make_download(cur)
             else:
                 self.make_sql_results(cur)
+        except csv.Error as ex:
+            raise sql_ex.CSVError(
+                descr=f'The csv module reports: {ex}')
         except MemoryError:
             self.sql_results = []
             gc.collect()
@@ -435,7 +490,10 @@ def sql_view(request):
             'attachment;'
             f' filename={response["dl_filename"]}.{generate_suffix(uh.uf)}')
 
-        pmd_response.app_iter = codecs.iterencode(uh.tfile, 'utf_8')
+        if uh.uf['download_as'] == MANY_FILES_VALUE:
+            pmd_response.app_iter = uh.tfile
+        else:
+            pmd_response.app_iter = codecs.iterencode(uh.tfile, 'utf_8')
 
         log_response(response, uh.uf['download'])
         return pmd_response