# csv_loader.py (2.4 KB)
  1. import logging
  2. import csv
  3. from typing import Optional, Dict, List
  4. from langchain.document_loaders import CSVLoader as LCCSVLoader
  5. from langchain.document_loaders.helpers import detect_file_encodings
  6. from langchain.schema import Document
  # Module-level logger, named after this module via __name__.
  7. logger = logging.getLogger(__name__)
  8. class CSVLoader(LCCSVLoader):
  9. def __init__(
  10. self,
  11. file_path: str,
  12. source_column: Optional[str] = None,
  13. csv_args: Optional[Dict] = None,
  14. encoding: Optional[str] = None,
  15. autodetect_encoding: bool = True,
  16. ):
  17. self.file_path = file_path
  18. self.source_column = source_column
  19. self.encoding = encoding
  20. self.csv_args = csv_args or {}
  21. self.autodetect_encoding = autodetect_encoding
  22. def load(self) -> List[Document]:
  23. """Load data into document objects."""
  24. try:
  25. with open(self.file_path, newline="", encoding=self.encoding) as csvfile:
  26. docs = self._read_from_file(csvfile)
  27. except UnicodeDecodeError as e:
  28. if self.autodetect_encoding:
  29. detected_encodings = detect_file_encodings(self.file_path)
  30. for encoding in detected_encodings:
  31. logger.debug("Trying encoding: ", encoding.encoding)
  32. try:
  33. with open(self.file_path, newline="", encoding=encoding.encoding) as csvfile:
  34. docs = self._read_from_file(csvfile)
  35. break
  36. except UnicodeDecodeError:
  37. continue
  38. else:
  39. raise RuntimeError(f"Error loading {self.file_path}") from e
  40. return docs
  41. def _read_from_file(self, csvfile):
  42. docs = []
  43. csv_reader = csv.DictReader(csvfile, **self.csv_args) # type: ignore
  44. for i, row in enumerate(csv_reader):
  45. content = "\n".join(f"{k.strip()}: {v.strip()}" for k, v in row.items())
  46. try:
  47. source = (
  48. row[self.source_column]
  49. if self.source_column is not None
  50. else ''
  51. )
  52. except KeyError:
  53. raise ValueError(
  54. f"Source column '{self.source_column}' not found in CSV file."
  55. )
  56. metadata = {"source": source, "row": i}
  57. doc = Document(page_content=content, metadata=metadata)
  58. docs.append(doc)
  59. return docs