[Scala][XML]ScalaでXMLストリームパーザを書いてみた

巨大なXMLファイルを処理するとき、プロセスのメモリにDOMを全部置いたりするのは非常に効率が悪い。なので、世の中にはSAXというイベントベースのAPIがあるわけだが、遅延評価を利用してこれをイベントベースの処理をDOMのストリームに変換してみたい。

標準ではsacla.xml.pull.XMLEventReader*1というのが既にあるが、スレッドの扱いがちょっと不満だったので、Scalaの練習がてら、自分で書いてみた。

StAXのイベントストリームから、NodeSeq(子要素を遅延評価できるようにしてある)に変換するのが、下記のXMLStreamParser。

ソース

import scala.io.Source
import scala.concurrent.SyncVar
import scala.xml._
import scala.xml.parsing._
import scala.xml.pull._

object XMLStreamParser {
  def apply(input: Source) = nodeStream(eventStream(input))

  def nodeStream(evs: Stream[XMLEvent]): Stream[Node] = {
    if (evs.isEmpty) Stream.empty
    else {
      evs.head match {
        case EvElemStart(p, l, a, s) => Stream.cons(Elem(p, l, a, s, nodeStream(evs.tail):_*),
                                                    nodeStream(findNextSibling(evs.tail, 0)))
        case EvElemEnd(p, l) => Stream.empty
        case EvText(t) => Stream.cons(Text(t), nodeStream(evs.tail))
        case EvEntityRef(e) => Stream.cons(EntityRef(e), nodeStream(evs.tail))
        case EvComment(c) => Stream.cons(Comment(c), nodeStream(evs.tail))
        case EvProcInstr(t, x) => Stream.cons(ProcInstr(t, x), nodeStream(evs.tail))
      }
    }
  }

  private def findNextSibling(evs: Stream[XMLEvent], depth: Int): Stream[XMLEvent] = {
    if (evs.isEmpty) Stream.empty
    else {
      evs.head match {
        case EvElemStart(_,_,_,_) => findNextSibling(evs.tail, depth + 1)
        case EvElemEnd(_,_) => if (depth < 1) evs.tail
                               else findNextSibling(evs.tail, depth - 1)
        case _ => findNextSibling(evs.tail, depth)
      }
    }
  }

  def eventStream(input: Source): Stream[XMLEvent] = {
    import javax.xml.stream._
    import javax.xml.stream.XMLStreamConstants._
    val factory = XMLInputFactory.newInstance()
    val reader = factory.createXMLStreamReader(new SourceReader(input))
    def st: Stream[XMLEvent] = {
      if (reader.hasNext) {
        reader.next
        reader.getEventType match {
          case START_ELEMENT =>
            Stream.cons(EvElemStart(reader.getPrefix,
                                    reader.getLocalName,
                                    reader2metadata(reader),
                                    reader2namespacebinding(reader)),
                        st)
          case END_ELEMENT =>
            Stream.cons(EvElemEnd(reader.getPrefix, reader.getLocalName), st)
          case CDATA =>
            Stream.cons(EvText(reader.getText), st)
          case CHARACTERS =>
            Stream.cons(EvText(reader.getText), st)
          case COMMENT =>
            Stream.cons(EvComment(reader.getText), st)
          case ENTITY_REFERENCE =>
            Stream.cons(EvEntityRef(reader.getLocalName), st)
          case PROCESSING_INSTRUCTION =>
            Stream.cons(EvProcInstr(reader.getPITarget, reader.getPIData), st)
          case _:Int => st
        }
      } else {
        Stream.empty
      }
    }
    st
  }

  private def reader2metadata(reader: javax.xml.stream.XMLStreamReader) = {
    var md: MetaData = Null
    for (i <- 0 until reader.getAttributeCount()) {
      if (reader.getPrefix() == null) {
        md = new UnprefixedAttribute(reader.getAttributeLocalName(i),
                                     reader.getAttributeValue(i),
                                     md)
      } else {
        md = new PrefixedAttribute(reader.getAttributePrefix(i),
                                   reader.getAttributeLocalName(i),
                                   reader.getAttributeValue(i),
                                   md)
      }
    }
    md
  }

  private def reader2namespacebinding(reader: javax.xml.stream.XMLStreamReader) = {
    var nb: NamespaceBinding = TopScope
    for (i <- 0 until reader.getNamespaceCount()) {
      nb = new NamespaceBinding(reader.getNamespacePrefix(i), reader.getNamespaceURI(i), nb)
    }
    nb
  }
}

class SourceReader(chars: Iterator[Char]) extends java.io.Reader {
  var currentChars = chars

  override def read(buf: Array[Char], off: Int, len: Int): Int = {
    if (!chars.hasNext) return -1
    var i = off
    val end = off + len
    while (i < end) {
      if (chars.hasNext) buf(i) = chars.next
      else return i - off
      i += 1
    }
    i - off
  }

  override def close() = {
    currentChars = Iterator.empty
  }
}

使い方

も帰ってきたストリームの先頭をvalやvarで宣言した変数に入れてはいけない。ストリームの全要素がメモリ上に乗ってしまうので、XMLStreamParserを使う意味がない。

ストリームの先頭は、必要なくなったらすぐに捨てるように書く。

import scala.io._
import scala.xml._
import org.scalatest._
import org.scalatest.matchers._

class XMLTest extends Spec with ShouldMatchers {

  describe("Simple XML") {
    it("001") {
      def xr = XMLStreamParser(Source.fromString("<hoge/>"))
      xr.toList should equal (List(<hoge/>))
    }

    it("002") {
      def xr = XMLStreamParser(Source.fromString("""<?xml version="1.0" encoding="UTF-8" ?>
<hoge><piyo t="0"><moko/></piyo><puyo><ketaketa/></puyo></hoge>"""))
      xr.head.asInstanceOf[Elem].child(0) should equal (<piyo t="0"><moko/></piyo>)
      xr.toList should equal (List(<hoge><piyo t="0"><moko/></piyo><puyo><ketaketa/></puyo></hoge>))
    }

    it("003") {
      def xr = XMLStreamParser(Source.fromString("<hoge>piyopiyo</hoge>"))
      xr.toList should equal (List(<hoge>piyopiyo</hoge>))
    }

    it("004") {
      def xr = XMLStreamParser(Source.fromURL("http://d.hatena.ne.jp/t2ru/rss"))
      xr.filter(_.isInstanceOf[Elem]).map(_.asInstanceOf[Elem]).head.child
      xr.toList
    }
  }
}

object XMLTestRunner {
  def main(args: Array[String]) = {
    (new XMLTest).execute()
  }
}

*1:公式でもまだ不安定なAPIらしいけど。。。